vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize

DEEPEP_QUANT_BLOCK_SHAPE module-attribute

DEEPEP_QUANT_BLOCK_SHAPE = [
    DEEPEP_QUANT_BLOCK_SIZE,
    DEEPEP_QUANT_BLOCK_SIZE,
]

DEEPEP_QUANT_BLOCK_SIZE module-attribute

DEEPEP_QUANT_BLOCK_SIZE = 128
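
These module attributes describe the 128-element block quantization used by the DeepEP low-latency FP8 dispatch path. A minimal sketch, assuming a quant_config-style block_shape, of the check _do_quant performs to decide whether the dispatch-time FP8 scales can be reused:

# Minimal sketch: mirrors the block_k check in _do_quant below. A block_shape
# whose K dimension equals DEEPEP_QUANT_BLOCK_SIZE means the scales produced
# by the FP8 dispatch can be used directly; anything else forces a
# dequantize-and-requantize round trip.
DEEPEP_QUANT_BLOCK_SIZE = 128
block_shape = [128, 128]  # stand-in for quant_config.block_shape

block_k = block_shape[1] if block_shape is not None else None
can_reuse_dispatch_scales = block_k == DEEPEP_QUANT_BLOCK_SIZE
assert can_reuse_dispatch_scales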

DeepEPLLPrepareAndFinalize

Bases: FusedMoEPrepareAndFinalize

Prepare/Finalize using DeepEP low-latency kernels.

Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize using DeepEP low-latency kernels.
    """

    # DeepEP low-latency kernels are compiled only for certain
    # specific hidden sizes.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]

    def __init__(
        self,
        buffer: deep_ep.Buffer,
        max_tokens_per_rank: int,
        num_dispatchers: int,
        use_fp8_dispatch: bool = False,
    ):
        super().__init__()

        self.buffer = buffer
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch
        # The dispatch function returns a handle that the combine function
        # requires. We store the handle here so it is available to the
        # combine function.
        self.handles: list[Optional[tuple]] = [None, None]
        self.num_dispatchers_ = num_dispatchers

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> Optional[int]:
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        return torch.int64

    def _do_quant(
        self,
        x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        a1_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.use_fp8_dispatch:
            block_k = (
                quant_config.block_shape[1]
                if quant_config.block_shape is not None
                else None
            )
            if block_k == DEEPEP_QUANT_BLOCK_SIZE:
                # DeepEP kernels did the quantization for us.
                x, x_scales = x
                return x, x_scales

            # Dequant to get back the tokens in the datatype we dispatched in.
            x_fp8, x_scales = x
            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

        assert isinstance(x, torch.Tensor)

        num_experts, max_tokens, hidden_dim = x.size()

        # TODO (varun): Optimization - Use a batched version of quant
        x = x.view((-1, hidden_dim))
        x, x_scales = moe_kernel_quantize_input(
            x,
            quant_config.a1_scale,
            quant_config.quant_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
        )
        x = x.view((num_experts, -1, hidden_dim))

        if quant_config.quant_dtype is not None:
            assert x_scales is not None
            x_scales = normalize_batched_scales_shape(x_scales, num_experts)

        return x, x_scales

    def supports_async(self) -> bool:
        return True

    def prepare_async(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[Callable, mk.ReceiverType]:
        hidden_size = a1.size(1)
        assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
            f"Hidden Size {hidden_size} not in supported list of hidden sizes"
            f"{self.SUPPORTED_HIDDEN_SIZES}"
        )

        a2a_idx = dbo_current_ubatch_id()

        if self.use_fp8_dispatch:
            assert hidden_size % 128 == 0, (
                "DeepEP kernels quantize the inputs in blocks of shape 128"
            )

        has_per_token_scales = (
            quant_config.a1_scale.numel() != 1
            if quant_config.a1_scale is not None
            else (
                quant_config.a2_scale.numel() != 1
                if quant_config.a2_scale is not None
                else False
            )
        )
        assert not has_per_token_scales, (
            "low_latency kernels doesn't support dispatching per-token scales"
        )

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            a1 = a1 * topk_weights.to(a1.dtype)

        # Dispatch
        expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
            a1,
            topk_ids,
            self.max_tokens_per_rank,
            num_experts,
            use_fp8=self.use_fp8_dispatch,
            async_finish=False,
            return_recv_hook=True,
        )
        self.handles[a2a_idx] = handle

        return (
            hook,
            lambda: self._receiver(
                expert_x,
                expert_num_tokens,
                quant_config.a1_scale,
                a1.dtype,
                quant_config,
            ),
        )

    def _receiver(
        self,
        expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        expert_num_tokens: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a1_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
        )

        return expert_x, expert_x_scale, expert_tokens_meta, None, None

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        hook, receiver = self.prepare_async(
            a1,
            topk_weights,
            topk_ids,
            num_experts,
            expert_map,
            apply_router_weight_on_input,
            quant_config,
        )
        hook()
        return receiver()

    def _finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
        do_async: bool,
    ) -> tuple[Callable, Callable]:
        assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
            "Weight application and reduction happens in the combine kernel."
        )

        a2a_idx = dbo_current_ubatch_id()
        do_recv_hook = dbo_enabled() or do_async
        handle = self.handles[a2a_idx]
        assert handle is not None

        combine_topk_weights = topk_weights
        if apply_router_weight_on_input:
            # weights have already been applied.
            combine_topk_weights = torch.ones_like(topk_weights)

        # TODO (varun) : Enable zero copy mode
        dbo_maybe_run_recv_hook()
        _, _, recv_hook = self.buffer.low_latency_combine(
            fused_expert_output,
            topk_ids,
            combine_topk_weights,
            handle,
            async_finish=False,
            zero_copy=False,
            return_recv_hook=do_recv_hook,
            out=output,
        )

        return recv_hook, lambda: None

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> tuple[Callable, Callable]:
        return self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            do_async=True,
        )

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            do_async=False,
        )
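
A minimal construction sketch, assuming DeepEP is installed and a deep_ep.Buffer for the expert-parallel group already exists; the numeric values are placeholders:

from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
    DeepEPLLPrepareAndFinalize,
)

# buffer: a deep_ep.Buffer built elsewhere for the EP group (assumed).
prepare_finalize = DeepEPLLPrepareAndFinalize(
    buffer=buffer,
    max_tokens_per_rank=256,   # placeholder dispatch capacity per rank
    num_dispatchers=8,         # placeholder number of dispatching ranks
    use_fp8_dispatch=True,     # let DeepEP quantize activations on dispatch
)
assert prepare_finalize.max_num_tokens_per_rank() == 256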

SUPPORTED_HIDDEN_SIZES class-attribute instance-attribute

SUPPORTED_HIDDEN_SIZES = [
    2048,
    2560,
    4096,
    5120,
    6144,
    7168,
]

activation_format property

activation_format: FusedMoEActivationFormat

buffer instance-attribute

buffer = buffer

handles instance-attribute

handles: list[Optional[tuple]] = [None, None]

max_tokens_per_rank instance-attribute

max_tokens_per_rank = max_tokens_per_rank

num_dispatchers_ instance-attribute

num_dispatchers_ = num_dispatchers

use_fp8_dispatch instance-attribute

use_fp8_dispatch = use_fp8_dispatch

__init__

__init__(
    buffer: Buffer,
    max_tokens_per_rank: int,
    num_dispatchers: int,
    use_fp8_dispatch: bool = False,
)
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def __init__(
    self,
    buffer: deep_ep.Buffer,
    max_tokens_per_rank: int,
    num_dispatchers: int,
    use_fp8_dispatch: bool = False,
):
    super().__init__()

    self.buffer = buffer
    self.max_tokens_per_rank = max_tokens_per_rank
    self.use_fp8_dispatch = use_fp8_dispatch
    # The dispatch function returns a handle that the combine function
    # requires. We store the handle here so it is available to the
    # combine function.
    self.handles: list[Optional[tuple]] = [None, None]
    self.num_dispatchers_ = num_dispatchers

_do_quant

_do_quant(
    x: Union[Tensor, tuple[Tensor, Tensor]],
    a1_dtype: dtype,
    quant_config: FusedMoEQuantConfig,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def _do_quant(
    self,
    x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    a1_dtype: torch.dtype,
    quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if self.use_fp8_dispatch:
        block_k = (
            quant_config.block_shape[1]
            if quant_config.block_shape is not None
            else None
        )
        if block_k == DEEPEP_QUANT_BLOCK_SIZE:
            # DeepEP kernels did the quantization for us.
            x, x_scales = x
            return x, x_scales

        # Dequant to get back the tokens in the datatype we dispatched in.
        x_fp8, x_scales = x
        x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

    assert isinstance(x, torch.Tensor)

    num_experts, max_tokens, hidden_dim = x.size()

    # TODO (varun): Optimization - Use a batched version of quant
    x = x.view((-1, hidden_dim))
    x, x_scales = moe_kernel_quantize_input(
        x,
        quant_config.a1_scale,
        quant_config.quant_dtype,
        quant_config.per_act_token_quant,
        quant_config.block_shape,
    )
    x = x.view((num_experts, -1, hidden_dim))

    if quant_config.quant_dtype is not None:
        assert x_scales is not None
        x_scales = normalize_batched_scales_shape(x_scales, num_experts)

    return x, x_scales
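
The reshaping around moe_kernel_quantize_input can be followed with plain tensors. A standalone sketch of just the shape handling (quantization itself omitted; sizes are placeholders):

import torch

num_experts, max_tokens, hidden_dim = 4, 8, 2048   # placeholder sizes
x = torch.randn(num_experts, max_tokens, hidden_dim)

# Flatten the batched layout to (num_experts * max_tokens, hidden_dim),
# quantize row/block-wise (moe_kernel_quantize_input in the real code),
# then restore the (num_experts, max_tokens, hidden_dim) view.
x2d = x.view(-1, hidden_dim)
x3d = x2d.view(num_experts, -1, hidden_dim)
assert x3d.shape == x.shape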

_finalize

_finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
    do_async: bool,
) -> tuple[Callable, Callable]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def _finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: mk.TopKWeightAndReduce,
    do_async: bool,
) -> tuple[Callable, Callable]:
    assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
        "Weight application and reduction happens in the combine kernel."
    )

    a2a_idx = dbo_current_ubatch_id()
    do_recv_hook = dbo_enabled() or do_async
    handle = self.handles[a2a_idx]
    assert handle is not None

    combine_topk_weights = topk_weights
    if apply_router_weight_on_input:
        # weights have already been applied.
        combine_topk_weights = torch.ones_like(topk_weights)

    # TODO (varun) : Enable zero copy mode
    dbo_maybe_run_recv_hook()
    _, _, recv_hook = self.buffer.low_latency_combine(
        fused_expert_output,
        topk_ids,
        combine_topk_weights,
        handle,
        async_finish=False,
        zero_copy=False,
        return_recv_hook=do_recv_hook,
        out=output,
    )

    return recv_hook, lambda: None
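
The weight handling above can be isolated: when the router weights were already folded into the activations during prepare, the combine kernel is handed all-ones weights so they are not applied twice. A standalone sketch with placeholder shapes:

import torch

topk_weights = torch.rand(6, 2)      # (num_tokens, topk), placeholder
apply_router_weight_on_input = True

combine_topk_weights = topk_weights
if apply_router_weight_on_input:
    # Weights were applied in prepare; neutralize them for the combine.
    combine_topk_weights = torch.ones_like(topk_weights)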

_receiver

_receiver(
    expert_x: Union[Tensor, tuple[Tensor, Tensor]],
    expert_num_tokens: Tensor,
    a1_scale: Optional[Tensor],
    a1_dtype: dtype,
    quant_config: FusedMoEQuantConfig,
) -> PrepareResultType
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def _receiver(
    self,
    expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    expert_num_tokens: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a1_dtype: torch.dtype,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)

    expert_tokens_meta = mk.ExpertTokensMetadata(
        expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
    )

    return expert_x, expert_x_scale, expert_tokens_meta, None, None

finalize

finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> None
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
    self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        weight_and_reduce_impl,
        do_async=False,
    )

finalize_async

finalize_async(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: TopKWeightAndReduce,
) -> tuple[Callable, Callable]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def finalize_async(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> tuple[Callable, Callable]:
    return self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        weight_and_reduce_impl,
        do_async=True,
    )
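
A hedged usage sketch of the asynchronous finalize path, using the instance from the construction sketch above; all tensor arguments are placeholders assumed to exist, and the second element of the returned pair is a no-op because the combine kernel writes directly into output:

recv_hook, receiver = prepare_finalize.finalize_async(
    output, fused_expert_output, topk_weights, topk_ids,
    apply_router_weight_on_input, weight_and_reduce_impl,
)
# Other work can overlap with the in-flight DeepEP combine here.
recv_hook()   # wait for the combined result to land in output
receiver()    # no-op; weighting/reduction happened in the combine kernel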

max_num_tokens_per_rank

max_num_tokens_per_rank() -> Optional[int]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def max_num_tokens_per_rank(self) -> Optional[int]:
    return self.max_tokens_per_rank

num_dispatchers

num_dispatchers() -> int
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def num_dispatchers(self) -> int:
    return self.num_dispatchers_

prepare

prepare(
    a1: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> PrepareResultType
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def prepare(
    self,
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    hook, receiver = self.prepare_async(
        a1,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    hook()
    return receiver()

prepare_async

prepare_async(
    a1: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, ReceiverType]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def prepare_async(
    self,
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:
    hidden_size = a1.size(1)
    assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
        f"Hidden Size {hidden_size} not in supported list of hidden sizes"
        f"{self.SUPPORTED_HIDDEN_SIZES}"
    )

    a2a_idx = dbo_current_ubatch_id()

    if self.use_fp8_dispatch:
        assert hidden_size % 128 == 0, (
            "DeepEP kernels quantize the inputs in blocks of shape 128"
        )

    has_per_token_scales = (
        quant_config.a1_scale.numel() != 1
        if quant_config.a1_scale is not None
        else (
            quant_config.a2_scale.numel() != 1
            if quant_config.a2_scale is not None
            else False
        )
    )
    assert not has_per_token_scales, (
        "low_latency kernels doesn't support dispatching per-token scales"
    )

    if apply_router_weight_on_input:
        topk = topk_ids.size(1)
        # TODO: this only works for topK=1, will need to update for topK>1
        assert topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        a1 = a1 * topk_weights.to(a1.dtype)

    # Dispatch
    expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
        a1,
        topk_ids,
        self.max_tokens_per_rank,
        num_experts,
        use_fp8=self.use_fp8_dispatch,
        async_finish=False,
        return_recv_hook=True,
    )
    self.handles[a2a_idx] = handle

    return (
        hook,
        lambda: self._receiver(
            expert_x,
            expert_num_tokens,
            quant_config.a1_scale,
            a1.dtype,
            quant_config,
        ),
    )
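
The synchronous prepare (documented above) simply runs this pair back to back: hook first, then receiver. Separating the two calls lets the caller overlap other work with the in-flight dispatch. A hedged usage sketch with placeholder arguments assumed to exist:

hook, receiver = prepare_finalize.prepare_async(
    a1, topk_weights, topk_ids, num_experts,
    expert_map, apply_router_weight_on_input, quant_config,
)
# Other work can overlap with the in-flight DeepEP dispatch here.
hook()   # wait for the receive to complete
expert_x, expert_x_scale, expert_tokens_meta, _, _ = receiver()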

supports_async

supports_async() -> bool
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def supports_async(self) -> bool:
    return True

topk_indices_dtype

topk_indices_dtype() -> Optional[dtype]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    return torch.int64

dequant_fp8

dequant_fp8(
    expert_x_fp8: Tensor, expert_x_scales: Tensor
) -> Tensor

Return the dequantized tensor in fp32

Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def dequant_fp8(
    expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
) -> torch.Tensor:
    """
    Return the dequantized tensor in fp32
    """
    # TODO (varun): Optimize - leverage num_tokens_per_expert counts
    assert expert_x_fp8.is_contiguous()
    expert_x_scales = expert_x_scales.contiguous()
    num_experts = expert_x_fp8.size(0)

    expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
        num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE
    )
    expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
    return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
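
The block structure can be reproduced with plain tensors. A standalone sketch of the same math on float stand-ins (sizes are placeholders and the payload dtype is not actually FP8 here):

import torch

E, T, H = 2, 4, 256                       # experts, tokens, hidden (H % 128 == 0)
B = 128                                   # DEEPEP_QUANT_BLOCK_SIZE
payload = torch.randn(E, T, H)            # stand-in for the FP8 payload
scales = torch.rand(E, T, H // B) + 0.5   # one scale per 128-element block

# Same math as dequant_fp8: view the payload as 128-wide blocks and broadcast
# one scale over each block before restoring the original shape.
blocks = payload.to(torch.float32).view(E, -1, B)
dequant = (blocks * scales.contiguous().view(E, -1, 1)).view(payload.size())
assert dequant.shape == (E, T, H)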