vllm.attention.ops.flashmla ¶
flash_mla_sparse_prefill ¶
flash_mla_sparse_prefill(
q: Tensor,
kv: Tensor,
indices: Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[Tensor, Tensor, Tensor]
Sparse attention prefill kernel
Args:

- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or to numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:

- (output, max_logits, lse). For the definitions of output, max_logits, and lse, please refer to README.md:
  - output: [s_q, h_q, d_v], bfloat16
  - max_logits: [s_q, h_q], float
  - lse: [s_q, h_q], float, 2-based log-sum-exp
Source code in vllm/attention/ops/flashmla.py
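A minimal usage sketch, assuming a FlashMLA-capable CUDA GPU. The shapes and dtypes follow the argument list above; the concrete sizes, and the choice of d_qk = 576, are illustrative assumptions rather than requirements of the API.

```python
import torch

from vllm.attention.ops.flashmla import flash_mla_sparse_prefill

# Illustrative sizes only; shapes follow the documented layout.
s_q, s_kv, h_q, h_kv, d_qk, topk = 4, 1024, 16, 1, 576, 256

q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")

# Top-k token indices per (query token, kv head). Unused slots should be
# marked invalid with -1 (or any value >= s_kv) so the kernel skips them.
indices = torch.randint(0, s_kv, (s_q, h_kv, topk),
                        dtype=torch.int32, device="cuda")

output, max_logits, lse = flash_mla_sparse_prefill(
    q, kv, indices, sm_scale=d_qk ** -0.5
)
# output: [s_q, h_q, 512] bfloat16; max_logits, lse: [s_q, h_q] float
```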
flash_mla_with_kvcache ¶
flash_mla_with_kvcache(
q: Tensor,
k_cache: Tensor,
block_table: Tensor,
cache_seqlens: Tensor,
head_dim_v: int,
tile_scheduler_metadata: Tensor,
num_splits: Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[Tensor] = None,
descale_k: Optional[Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]
Arguments:

- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
- num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
- causal: bool. Whether to apply a causal attention mask.
- descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the FP8 KV cache format, please refer to README.md.
- indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention is enabled and only the tokens in the indices array are attended to. Invalid indices should be set to -1 or to numbers >= total_seq_len_kv. For details on how to set up indices, please refer to README.md.
Returns:

- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
Source code in vllm/attention/ops/flashmla.py
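A hedged end-to-end sketch of the paged-KV decode path, combining get_mla_metadata (documented below) with flash_mla_with_kvcache. Every concrete size here is an illustrative assumption; only the shapes and dtypes follow the argument list above.

```python
import torch

from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
)

# Illustrative sizes; head_dim = 576 with head_dim_v = 512 mirrors the usual
# MLA layout, but every number here is an assumption for the sketch.
batch_size, seq_len_q, num_heads_q, num_heads_k = 2, 1, 16, 1
head_dim, head_dim_v = 576, 512
num_blocks, page_block_size, max_num_blocks_per_seq = 128, 64, 16

q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      dtype=torch.bfloat16, device="cuda")
cache_seqlens = torch.full((batch_size,), 512,
                           dtype=torch.int32, device="cuda")
# Each row maps a sequence's logical pages to physical blocks in k_cache.
block_table = torch.arange(batch_size * max_num_blocks_per_seq,
                           dtype=torch.int32,
                           device="cuda").view(batch_size, -1)

# Scheduling metadata depends only on the sequence lengths and head layout,
# so it can be computed once per batch and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,
    num_heads_k,
)

out, softmax_lse = flash_mla_with_kvcache(
    q,
    k_cache,
    block_table,
    cache_seqlens,
    head_dim_v,
    tile_scheduler_metadata,
    num_splits,
    causal=True,
)
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v)
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32
```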
get_mla_metadata ¶
get_mla_metadata(
cache_seqlens: Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None,
) -> tuple[Tensor, Tensor]
Arguments:

- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k: Equals num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention is enabled and only the tokens in the indices array passed to flash_mla_with_kvcache_sm90 are attended to.
Returns:

- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
Source code in vllm/attention/ops/flashmla.py
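A short sketch of the metadata call for the sparse (top-k) path. The batch size, sequence length, head counts, and topk value are illustrative assumptions; the dense example above already shows the non-sparse call.

```python
import torch

from vllm.attention.ops.flashmla import get_mla_metadata

# Illustrative sizes for a single-token decode step with sparse attention.
batch_size, seq_len_q, num_heads_q, num_heads_k, topk = 2, 1, 16, 1, 256
cache_seqlens = torch.full((batch_size,), 4096,
                           dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,
    num_heads_k,
    num_heads_q=num_heads_q,  # per the notes above, needed once topk is set
    topk=topk,                # enables the sparse scheduling path
)
```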
is_flashmla_supported ¶
Return: (is_supported_flag, unsupported_reason); unsupported_reason is only set when FlashMLA is not supported.
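A minimal capability-check sketch based on the return description above:

```python
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    print(f"FlashMLA is not available: {reason}")
```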