Skip to content

vllm.compilation.fusion

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

FUSED_OPS module-attribute

FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False): default,
    FusedRMSQuantKey(kFp8StaticTensorSym, True): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, False): default,
    FusedRMSQuantKey(kFp8DynamicTokenSym, True): default,
}

QUANT_OPS module-attribute

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default,
    kFp8DynamicTensorSym: default,
    kFp8DynamicTokenSym: default,
}

RMS_ADD_OP module-attribute

RMS_ADD_OP = default

RMS_OP module-attribute

RMS_OP = default

logger module-attribute

logger = init_logger(__name__)

FusedAddRMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern that rewrites RMS_ADD_OP followed by the dynamic-scale quant op
    into a single call to the fused op looked up for this key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        # NOTE(review): the positional False is presumably ScaleDesc's
        # "static" flag (i.e. the scale is dynamic) - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        # fused_add=True selects the variant that also handles the residual.
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # The base class resolves QUANT_OP / FUSED_OP from this key.
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on the given pass."""

        # Unfused form: RMS_ADD_OP, then the quant op on its output.
        # NOTE(review): the indexing below assumes auto_functionalized returns
        # the op's own return value at [0] followed by the updated mutable
        # arguments - confirm against the op schemas.
        def pattern(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                RMS_ADD_OP,
                input=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            at1 = auto_functionalized(
                self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
            )

            # result, residual, scale
            return at1[1], at[2], at1[2]

        # Fused form: one call to FUSED_OP taking the residual directly.
        def replacement(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            # (index order differs from pattern: here the residual is at [3])
            return at[1], at[3], at[2]

        # Example tensors handed to register_replacement, in the same order as
        # the parameters of pattern/replacement.
        # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
        # shapes only matter for tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """Build the fused-add key for this quant scheme and pass it to the base."""
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    # fused_add=True: this pattern covers the variant with a residual.
    super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=True))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the pattern/replacement pair on the given pass."""

    # Unfused form: RMS_ADD_OP, then the quant op on its output.
    # NOTE(review): the indexing below assumes auto_functionalized returns
    # the op's own return value at [0] followed by the updated mutable
    # arguments - confirm against the op schemas.
    def pattern(
        result: torch.Tensor,
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )
        at1 = auto_functionalized(
            self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
        )

        # result, residual, scale
        return at1[1], at[2], at1[2]

    # Fused form: one call to FUSED_OP taking the residual directly.
    def replacement(
        result: torch.Tensor,
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=residual,
        )

        # result, residual, scale
        # (index order differs from pattern: here the residual is at [3])
        return at[1], at[3], at[2]

    # Example tensors handed to register_replacement, in the same order as
    # the parameters of pattern/replacement.
    # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
    # shapes only matter for tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # input
        empty_bf16(5, 4),  # residual
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1),  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedAddRMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern that rewrites RMS_ADD_OP followed by the static-scale quant op
    into a single call to the fused op looked up for this key.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        # fused_add=True selects the variant that also handles the residual;
        # the scale descriptor is the module's kStaticTensorScale.
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        # The base class resolves QUANT_OP / FUSED_OP from this key.
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on the given pass."""

        # Unfused form: RMS_ADD_OP, then the quant op on its output.
        # NOTE(review): the indexing below assumes auto_functionalized returns
        # the op's own return value at [0] followed by the updated mutable
        # arguments - confirm against the op schemas.
        def pattern(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                RMS_ADD_OP,
                input=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            at1 = auto_functionalized(
                self.QUANT_OP, result=result, input=at[1], scale=scale
            )

            # result, residual
            return at1[1], at[2]

        # Fused form: one call to FUSED_OP taking the residual directly.
        def replacement(
            result: torch.Tensor,
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        # Example tensors handed to register_replacement, in the same order as
        # the parameters of pattern/replacement.
        # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
        # shapes only matter for tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """Build the fused-add, static-scale key and pass it to the base class."""
    static_quant = QuantKey(
        dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
    )
    # fused_add=True: this pattern covers the variant with a residual.
    super().__init__(epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=True))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the pattern/replacement pair on the given pass."""

    # Unfused form: RMS_ADD_OP, then the quant op on its output.
    # NOTE(review): the indexing below assumes auto_functionalized returns
    # the op's own return value at [0] followed by the updated mutable
    # arguments - confirm against the op schemas.
    def pattern(
        result: torch.Tensor,
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )
        at1 = auto_functionalized(
            self.QUANT_OP, result=result, input=at[1], scale=scale
        )

        # result, residual
        return at1[1], at[2]

    # Fused form: one call to FUSED_OP taking the residual directly.
    def replacement(
        result: torch.Tensor,
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            residual=residual,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result, residual
        return at[1], at[2]

    # Example tensors handed to register_replacement, in the same order as
    # the parameters of pattern/replacement.
    # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
    # shapes only matter for tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # input
        empty_bf16(5, 4),  # residual
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1),  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedRMSQuantKey

Bases: NamedTuple

Named tuple for identifying the type of RMSNorm + quant fusion. `quant`: type of quantization. `fused_add`: whether the op also performs the residual add.

Source code in vllm/compilation/fusion.py
class FusedRMSQuantKey(NamedTuple):
    """
    Key identifying one RMSNorm + quant fusion variant.

    quant: the quantization scheme being fused
    fused_add: whether the fused op also performs the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self):
        # Renders "with residual" or "without residual" depending on fused_add.
        suffix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{suffix} residual)"

fused_add instance-attribute

fused_add: bool

quant instance-attribute

quant: QuantKey

__str__

__str__()
Source code in vllm/compilation/fusion.py
def __str__(self):
    """Return a readable form of the key, noting whether a residual is fused."""
    suffix = "" if self.fused_add else "out"
    return f"FusedQuantKey({self.quant}, with{suffix} residual)"

RMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """
    Pattern that rewrites RMS_OP followed by the dynamic-scale quant op
    into a single call to the fused op looked up for this key.
    """

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        # NOTE(review): the positional False is presumably ScaleDesc's
        # "static" flag (i.e. the scale is dynamic) - confirm against ScaleDesc.
        scale = ScaleDesc(torch.float32, False, group_shape)
        # fused_add=False selects the variant without a residual add.
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        # The base class resolves QUANT_OP / FUSED_OP from this key.
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on the given pass."""

        # Unfused form: RMS_OP writes into result_rms, then the quant op
        # consumes that output.
        # NOTE(review): the indexing below assumes auto_functionalized returns
        # the op's own return value at [0] followed by the updated mutable
        # arguments - confirm against the op schemas.
        def pattern(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at1 = auto_functionalized(
                RMS_OP,
                result=result_rms,
                input=input,
                weight=weight,
                epsilon=self.epsilon,
            )
            at2 = auto_functionalized(
                self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
            )

            # result, scale
            return at2[1], at2[2]

        # Fused form: one call to FUSED_OP with no residual. result_rms is not
        # used in this body; presumably it is kept so the signature mirrors
        # pattern - confirm.
        def replacement(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        # Example tensors handed to register_replacement, in the same order as
        # the parameters of pattern/replacement.
        # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
        # shapes only matter for tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    """Build the no-residual key for this quant scheme and pass it to the base."""
    quant_key = QuantKey(
        dtype=quant_dtype,
        scale=ScaleDesc(torch.float32, False, group_shape),
        symmetric=symmetric,
    )
    # fused_add=False: this pattern covers the variant without a residual.
    super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=False))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the pattern/replacement pair on the given pass."""

    # Unfused form: RMS_OP writes into result_rms, then the quant op
    # consumes that output.
    # NOTE(review): the indexing below assumes auto_functionalized returns
    # the op's own return value at [0] followed by the updated mutable
    # arguments - confirm against the op schemas.
    def pattern(
        result: torch.Tensor,
        result_rms: torch.Tensor,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at1 = auto_functionalized(
            RMS_OP,
            result=result_rms,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        at2 = auto_functionalized(
            self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
        )

        # result, scale
        return at2[1], at2[2]

    # Fused form: one call to FUSED_OP with no residual. result_rms is not
    # used in this body; presumably it is kept so the signature mirrors
    # pattern - confirm.
    def replacement(
        result: torch.Tensor,
        result_rms: torch.Tensor,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=None,
        )

        # result, scale
        return at[1], at[2]

    # Example tensors handed to register_replacement, in the same order as
    # the parameters of pattern/replacement.
    # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
    # shapes only matter for tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # result_rms
        empty_bf16(5, 4),  # input
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1),  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

RMSNormQuantFusionPass

Bases: VllmPatternMatcherPass

This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.

Source code in vllm/compilation/fusion.py
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses rms_norm and quant custom ops into one fused rms_norm_quant op.
    The fused_add_rms_norm variant is handled as well.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Registered in this order for every epsilon:
        #   rms_norm + static fp8 quant
        #   fused_add_rms_norm + static fp8 quant
        #   rms_norm + dynamic per-token fp8 quant
        #   fused_add_rms_norm + dynamic per-token fp8 quant
        pattern_classes = (
            RMSNormStaticQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
        for epsilon in (1e-5, 1e-6):
            for pattern_cls in pattern_classes:
                pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        # Apply every registered pattern and record how many were replaced.
        num_replaced = self.patterns.apply(graph)
        self.matched_count = num_replaced
        logger.debug("Replaced %s patterns", num_replaced)

    def uuid(self) -> Any:
        # Hash this pass together with every pattern class it registers.
        hashed_sources = (
            self,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
        )
        return self.hash_source(*hashed_sources)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rmsnorm_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
    """Apply all registered patterns to graph and record the match count."""
    num_replaced = self.patterns.apply(graph)
    self.matched_count = num_replaced
    logger.debug("Replaced %s patterns", num_replaced)

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig):
    """Create the pattern-matcher pass and register every fusion pattern."""
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rmsnorm_quant_fusion_pass"
    )

    # Registered in this order for every epsilon:
    #   rms_norm + static fp8 quant
    #   fused_add_rms_norm + static fp8 quant
    #   rms_norm + dynamic per-token fp8 quant
    #   fused_add_rms_norm + dynamic per-token fp8 quant
    pattern_classes = (
        RMSNormStaticQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )
    for epsilon in (1e-5, 1e-6):
        for pattern_cls in pattern_classes:
            pattern_cls(epsilon, FP8_DTYPE).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> Any
Source code in vllm/compilation/fusion.py
def uuid(self) -> Any:
    """Return hash_source over this pass and every pattern class it registers."""
    hashed_sources = (
        self,
        RMSNormQuantPattern,
        RMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
    )
    return self.hash_source(*hashed_sources)

RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormQuantPattern:
    """
    Base for the RMSNorm + quant patterns: stores epsilon and the quant dtype,
    and looks up the quant op and the fused op for the given key.
    """

    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        quant_key = key.quant
        self.epsilon = epsilon
        self.quant_dtype = quant_key.dtype

        # Standalone quant op used in the unfused pattern.
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        # Fused op used in the replacement.
        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

FUSED_OP instance-attribute

FUSED_OP = FUSED_OPS[key]

QUANT_OP instance-attribute

QUANT_OP = QUANT_OPS[quant]

epsilon instance-attribute

epsilon = epsilon

quant_dtype instance-attribute

quant_dtype = dtype

__init__

__init__(epsilon: float, key: FusedRMSQuantKey)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
    """Store epsilon and the quant dtype, and look up the ops for key."""
    quant_key = key.quant
    self.epsilon = epsilon
    self.quant_dtype = quant_key.dtype

    # Standalone quant op used in the unfused pattern.
    assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
    self.QUANT_OP = QUANT_OPS[quant_key]

    # Fused op used in the replacement.
    assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
    self.FUSED_OP = FUSED_OPS[key]

RMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """
    Pattern that rewrites RMS_OP followed by the static-scale quant op
    into a single call to the fused op looked up for this key.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        # fused_add=False selects the variant without a residual add; the
        # scale descriptor is the module's kStaticTensorScale.
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        # The base class resolves QUANT_OP / FUSED_OP from this key.
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on the given pass."""

        # Cannot use methods, as the self argument affects tracing
        # Unfused form: RMS_OP writes into result_rms, then the quant op
        # consumes that output.
        # NOTE(review): the indexing below assumes auto_functionalized returns
        # the op's own return value at [0] followed by the updated mutable
        # arguments - confirm against the op schemas.
        def pattern(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at1 = auto_functionalized(
                RMS_OP,
                result=result_rms,
                input=input,
                weight=weight,
                epsilon=self.epsilon,
            )
            at2 = auto_functionalized(
                self.QUANT_OP, result=result, input=at1[1], scale=scale
            )

            # result
            return at2[1]

        # Fused form: one call to FUSED_OP. result_rms is not used in this
        # body; presumably it is kept so the signature mirrors pattern -
        # confirm.
        def replacement(
            result: torch.Tensor,
            result_rms: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        # Example tensors handed to register_replacement, in the same order as
        # the parameters of pattern/replacement.
        # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
        # shapes only matter for tracing - confirm.
        inputs = [
            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
            empty_bf16(5, 4),  # result_rms
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
            empty_fp32(1, 1),  # scale
        ]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    """Build the no-residual, static-scale key and pass it to the base class."""
    static_quant = QuantKey(
        dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
    )
    # fused_add=False: this pattern covers the variant without a residual.
    super().__init__(epsilon, FusedRMSQuantKey(quant=static_quant, fused_add=False))

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    """Register the pattern/replacement pair on the given pass."""

    # Cannot use methods, as the self argument affects tracing
    # Unfused form: RMS_OP writes into result_rms, then the quant op
    # consumes that output.
    # NOTE(review): the indexing below assumes auto_functionalized returns
    # the op's own return value at [0] followed by the updated mutable
    # arguments - confirm against the op schemas.
    def pattern(
        result: torch.Tensor,
        result_rms: torch.Tensor,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at1 = auto_functionalized(
            RMS_OP,
            result=result_rms,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        at2 = auto_functionalized(
            self.QUANT_OP, result=result, input=at1[1], scale=scale
        )

        # result
        return at2[1]

    # Fused form: one call to FUSED_OP. result_rms is not used in this
    # body; presumably it is kept so the signature mirrors pattern -
    # confirm.
    def replacement(
        result: torch.Tensor,
        result_rms: torch.Tensor,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
    ):
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result
        return at[1]

    # Example tensors handed to register_replacement, in the same order as
    # the parameters of pattern/replacement.
    # NOTE(review): weight is (1, 5) while input is (5, 4); presumably the
    # shapes only matter for tracing - confirm.
    inputs = [
        torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
        empty_bf16(5, 4),  # result_rms
        empty_bf16(5, 4),  # input
        empty_bf16(1, 5),  # weight
        empty_fp32(1, 1),  # scale
    ]

    pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

empty_bf16

empty_bf16(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_bf16(*args, **kwargs):
    """Call torch.empty with dtype pinned to bfloat16 and device to "cuda".

    Passing ``dtype`` or ``device`` in kwargs raises TypeError (duplicate
    keyword), since both are fixed here.
    """
    return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)

empty_fp32

empty_fp32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_fp32(*args, **kwargs):
    """Call torch.empty with dtype pinned to float32 and device to "cuda".

    Passing ``dtype`` or ``device`` in kwargs raises TypeError (duplicate
    keyword), since both are fixed here.
    """
    return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs)

empty_i32

empty_i32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_i32(*args, **kwargs):
    """Call torch.empty with dtype pinned to int32 and device to "cuda".

    Passing ``dtype`` or ``device`` in kwargs raises TypeError (duplicate
    keyword), since both are fixed here.
    """
    return torch.empty(*args, dtype=torch.int32, device="cuda", **kwargs)