
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe

TritonOrDeepGemmExperts

Bases: FusedMoEPermuteExpertsUnpermute
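
Combined expert implementation that pairs a TritonExperts backend with an optional DeepGemmExperts backend. DeepGEMM is enabled only for FP8 W8A8 quantization with the DeepGEMM block shape; each call to apply() then selects DeepGEMM when the problem shape is valid (or E8M0 is in use) and otherwise falls back to the Triton kernels.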

Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        quant_config: FusedMoEQuantConfig,
        allow_deep_gemm: bool = False,
    ):
        super().__init__(quant_config)

        self.triton_expert = TritonExperts(quant_config)

        self.allow_deep_gemm = (
            allow_deep_gemm
            and self.quant_config.use_fp8_w8a8
            and self.block_shape == deep_gemm_block_shape()
        )

        self.deep_gemm_expert = (
            DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None
        )

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        assert (
            self.deep_gemm_expert is None
            or self.triton_expert.activation_formats
            == self.deep_gemm_expert.activation_formats
        )
        return self.triton_expert.activation_formats

    def supports_chunking(self) -> bool:
        dge = self.deep_gemm_expert
        te = self.triton_expert
        return (dge is None or dge.supports_chunking()) and (
            te is None or te.supports_chunking()
        )

    def supports_expert_map(self) -> bool:
        dge = self.deep_gemm_expert
        te = self.triton_expert
        return (dge is None or dge.supports_expert_map()) and (
            te is None or te.supports_expert_map()
        )

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        dge = self.deep_gemm_expert
        te = self.triton_expert
        dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
        te_war = te.finalize_weight_and_reduce_impl() if te else None
        is_dge_war = dge_war is not None
        is_te_war = te_war is not None

        if is_dge_war and is_te_war:
            assert dge_war == te_war, (
                "Both implementations should agree on WeightAndReduce impls. "
                f"Got dge_war: {dge_war}, and te_war: {te_war}"
            )

        if dge_war is not None:
            return dge_war

        assert te_war is not None
        return te_war

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # Note: the deep gemm workspaces are strictly larger than the triton
        # workspaces so we can be pessimistic here and allocate for DeepGemm
        # even if we fall back to triton later, e.g. if expert maps are set.
        if self.allow_deep_gemm and (
            is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)
        ):
            assert self.deep_gemm_expert is not None
            return self.deep_gemm_expert.workspace_shapes(
                a,
                aq,
                M,
                N,
                K,
                topk,
                global_num_experts,
                local_num_experts,
                expert_tokens_meta,
            )
        else:
            return self.triton_expert.workspace_shapes(
                a,
                aq,
                M,
                N,
                K,
                topk,
                global_num_experts,
                local_num_experts,
                expert_tokens_meta,
            )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        use_deep_gemm = self.allow_deep_gemm and (
            _valid_deep_gemm(hidden_states, w1, w2) or is_deep_gemm_e8m0_used()
        )

        experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
        assert experts is not None

        experts.apply(
            output,
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
            expert_map,
            a1q_scale,
            a2_scale,
            workspace13,
            workspace2,
            expert_tokens_meta,
            apply_router_weight_on_input,
        )

activation_formats property

allow_deep_gemm instance-attribute

allow_deep_gemm = (
    allow_deep_gemm
    and use_fp8_w8a8
    and block_shape == deep_gemm_block_shape()
)

deep_gemm_expert instance-attribute

deep_gemm_expert = (
    DeepGemmExperts(quant_config)
    if allow_deep_gemm
    else None
)

triton_expert instance-attribute

triton_expert = TritonExperts(quant_config)

__init__

__init__(
    quant_config: FusedMoEQuantConfig,
    allow_deep_gemm: bool = False,
)
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
    allow_deep_gemm: bool = False,
):
    super().__init__(quant_config)

    self.triton_expert = TritonExperts(quant_config)

    self.allow_deep_gemm = (
        allow_deep_gemm
        and self.quant_config.use_fp8_w8a8
        and self.block_shape == deep_gemm_block_shape()
    )

    self.deep_gemm_expert = (
        DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None
    )
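
A minimal usage sketch, assuming an already constructed FusedMoEQuantConfig instance named quant_config (its construction is not covered on this page). Whether the DeepGEMM backend is retained depends on the FP8 W8A8 setting and the block shape:

# Hypothetical wiring sketch -- `quant_config` is assumed to be an existing
# FusedMoEQuantConfig; it is not built on this page.
experts = TritonOrDeepGemmExperts(quant_config, allow_deep_gemm=True)

# DeepGEMM is only kept when the config is FP8 W8A8 and its block shape
# matches deep_gemm_block_shape(); otherwise the instance is Triton-only.
if experts.deep_gemm_expert is None:
    print("DeepGEMM unavailable for this config; Triton kernels will be used.")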

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
)
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
):
    use_deep_gemm = self.allow_deep_gemm and (
        _valid_deep_gemm(hidden_states, w1, w2) or is_deep_gemm_e8m0_used()
    )

    experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
    assert experts is not None

    experts.apply(
        output,
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        activation,
        global_num_experts,
        expert_map,
        a1q_scale,
        a2_scale,
        workspace13,
        workspace2,
        expert_tokens_meta,
        apply_router_weight_on_input,
    )
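
The per-call backend choice reduces to a small predicate. The standalone sketch below (plain Python, not the vLLM API) illustrates how the apply() decision composes, with the shape-validity and E8M0 checks replaced by booleans:

# Standalone illustration of the apply() dispatch predicate; the real checks
# are _valid_deep_gemm(hidden_states, w1, w2) and is_deep_gemm_e8m0_used().
def pick_backend(allow_deep_gemm: bool, shape_ok: bool, e8m0: bool) -> str:
    use_deep_gemm = allow_deep_gemm and (shape_ok or e8m0)
    return "deep_gemm" if use_deep_gemm else "triton"

assert pick_backend(True, True, False) == "deep_gemm"
assert pick_backend(True, False, False) == "triton"   # shape check fails: fall back
assert pick_backend(False, True, True) == "triton"    # DeepGEMM disabled at init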

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    dge = self.deep_gemm_expert
    te = self.triton_expert
    dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
    te_war = te.finalize_weight_and_reduce_impl() if te else None
    is_dge_war = dge_war is not None
    is_te_war = te_war is not None

    if is_dge_war and is_te_war:
        assert dge_war == te_war, (
            "Both implementations should agree on WeightAndReduce impls. "
            f"Got dge_war: {dge_war}, and te_war: {te_war}"
        )

    if dge_war is not None:
        return dge_war

    assert te_war is not None
    return te_war
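
When both backends are instantiated, the method asserts that they return equal TopKWeightAndReduce implementations and returns the DeepGEMM one; when only the Triton backend is present, its implementation is returned.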

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def supports_chunking(self) -> bool:
    dge = self.deep_gemm_expert
    te = self.triton_expert
    return (dge is None or dge.supports_chunking()) and (
        te is None or te.supports_chunking()
    )

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def supports_expert_map(self) -> bool:
    dge = self.deep_gemm_expert
    te = self.triton_expert
    return (dge is None or dge.supports_expert_map()) and (
        te is None or te.supports_expert_map()
    )
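
Both predicates are conjunctions over the instantiated backends: chunking or expert maps are reported as supported only if every active backend supports them.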

workspace_shapes

workspace_shapes(
    a: Tensor,
    aq: Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...], dtype
]
Source code in vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
def workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
    # Note: the deep gemm workspaces are strictly larger than the triton
    # workspaces so we can be pessimistic here and allocate for DeepGemm
    # even if we fall back to triton later, e.g. if expert maps are set.
    if self.allow_deep_gemm and (
        is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)
    ):
        assert self.deep_gemm_expert is not None
        return self.deep_gemm_expert.workspace_shapes(
            a,
            aq,
            M,
            N,
            K,
            topk,
            global_num_experts,
            local_num_experts,
            expert_tokens_meta,
        )
    else:
        return self.triton_expert.workspace_shapes(
            a,
            aq,
            M,
            N,
            K,
            topk,
            global_num_experts,
            local_num_experts,
            expert_tokens_meta,
        )
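
Because the DeepGEMM workspaces are at least as large as the Triton ones (per the note in the source), returning the DeepGEMM shapes whenever DeepGEMM might be used keeps the allocation safe even if apply() later falls back to the Triton kernels for a particular batch.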