
vllm.attention.ops.rocm_aiter_mla

tags module-attribute

tags = ()

aiter_mla_decode_fwd

aiter_mla_decode_fwd(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    sm_scale: float,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    logit_cap: float = 0.0,
)
Source code in vllm/attention/ops/rocm_aiter_mla.py
def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(
        q,
        # View the paged KV buffer as (-1, 1, 1, head_size) to match the layout
        # expected by the AITER decode kernel.
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        max_seqlen_qo,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
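
Example usage (not part of the module source): a minimal calling sketch assuming a ROCm build of vLLM with AITER installed. The head count, head size, page layout, softmax scale, and the one-query-token-per-request decode setup below are illustrative assumptions, not requirements of the kernel.

import torch

from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd

# Illustrative sizes (assumptions): 4 decode requests, 16 heads, head size 576,
# one token per KV page, two pages of context per request.
num_reqs, num_heads, head_size = 4, 16, 576
num_pages, page_size = 8, 1

q = torch.randn(num_reqs, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
kv_buffer = torch.randn(num_pages, page_size, 1, head_size, dtype=torch.bfloat16, device="cuda")
o = torch.empty_like(q)

qo_indptr = torch.arange(num_reqs + 1, dtype=torch.int32, device="cuda")         # 1 query token per request
kv_indptr = torch.arange(0, num_pages + 1, 2, dtype=torch.int32, device="cuda")  # 2 pages per request
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
kv_last_page_lens = torch.full((num_reqs,), page_size, dtype=torch.int32, device="cuda")

aiter_mla_decode_fwd(
    q,
    kv_buffer,
    o,
    sm_scale=head_size ** -0.5,  # placeholder; real callers pass the model's softmax scale
    qo_indptr=qo_indptr,
    max_seqlen_qo=1,
    kv_indptr=kv_indptr,
    kv_indices=kv_indices,
    kv_last_page_lens=kv_last_page_lens,
)
# The kernel writes the attention output into `o` in place; nothing is returned.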

get_aiter_mla_metadata

get_aiter_mla_metadata(
    max_batch_size: int,
    block_size: int,
    max_block_per_batch: int,
    device: device,
) -> tuple[Tensor, ...]
Source code in vllm/attention/ops/rocm_aiter_mla.py
def get_aiter_mla_metadata(
    max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
    # Page ids for every (request, page) slot, flattened across the batch.
    paged_kv_indices = torch.zeros(
        max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
    )
    # Per-request offsets into paged_kv_indices (CSR-style indptr).
    paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
    # Number of valid tokens in each request's last page; initialized to a full page.
    paged_kv_last_page_lens = torch.full(
        (max_batch_size,), block_size, dtype=torch.int32
    )
    # Per-request offsets of query tokens (CSR-style indptr).
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
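
Example usage (not part of the module source): a short sketch of allocating the decode metadata buffers; the batch size and page budget are illustrative assumptions.

import torch

from vllm.attention.ops.rocm_aiter_mla import get_aiter_mla_metadata

paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr = get_aiter_mla_metadata(
    max_batch_size=32,         # assumption: up to 32 decode requests per batch
    block_size=1,              # assumption: one token per KV page
    max_block_per_batch=1024,  # assumption: up to 1024 pages per request
    device=torch.device("cuda"),
)
# paged_kv_indices:        (32 * 1024,) int32 zeros, to be filled with page ids
# paged_kv_indptr:         (33,) int32 zeros, per-request offsets into paged_kv_indices
# paged_kv_last_page_lens: (32,) int32 filled with block_size (last page assumed full)
# qo_indptr:               (33,) int32 zeros, per-request query-token offsets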

mla_decode_fwd_fake

mla_decode_fwd_fake(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None
Source code in vllm/attention/ops/rocm_aiter_mla.py
def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    # No-op fake (meta) implementation: the op returns None and only mutates `o`,
    # so there is nothing to compute during tracing / shape propagation.
    pass
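
mla_decode_fwd_fake is the no-op counterpart of mla_decode_fwd_impl, typically used as the fake (meta) implementation when the custom op torch.ops.vllm.rocm_aiter_mla_decode_fwd is registered, so that tracing and torch.compile can propagate through the op without executing the AITER kernel. A simplified, hypothetical sketch of that general pattern using the public torch.library API (not the registration code in this module; all names below are placeholders) might look like:

import torch

@torch.library.custom_op("demo::mla_decode_fwd", mutates_args=("o",))
def demo_mla_decode_fwd(q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor) -> None:
    o.copy_(q)  # stand-in for the real kernel, which writes its result into `o`

@demo_mla_decode_fwd.register_fake
def _(q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor) -> None:
    # Nothing to do: the op returns None and only mutates `o`,
    # so shape propagation requires no extra work.
    pass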

mla_decode_fwd_impl

mla_decode_fwd_impl(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None
Source code in vllm/attention/ops/rocm_aiter_mla.py
def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    # Imported lazily; AITER is only required when this op actually runs.
    from aiter.mla import mla_decode_fwd

    mla_decode_fwd(
        q,
        # Same flattening as in aiter_mla_decode_fwd: (-1, 1, 1, head_size).
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )
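
Note that aiter.mla.mla_decode_fwd takes max_seqlen_qo positionally after kv_last_page_lens, whereas the wrappers above accept it immediately after qo_indptr; the implementation reorders the arguments accordingly. In normal use this function is not called directly: it is reached through torch.ops.vllm.rocm_aiter_mla_decode_fwd via the aiter_mla_decode_fwd wrapper.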