vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass

CutlassW4A8LinearKernel

Bases: MPLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
class CutlassW4A8LinearKernel(MPLinearKernel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic per-tok fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CUTLASS W4A8, only supported int4",
            )

        # TODO(czhu): support -1 (column-wise)
        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return (
                False,
                f"K and N must be divisible by 128, got {c.partition_weight_shape}",
            )

        if c.out_type != torch.bfloat16:
            return (
                False,
                f"Only bfloat16 output type currently supportedgot {c.out_type=}",
            )

        return True, None

    # NOTE: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        # TODO(czhu): optimize speed/mem usage
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        self._transform_param(layer, "weight_chan_scale", lambda x: x)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)
        w_ch_s = layer.weight_chan_scale

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(
            a=x_2d,
            b_q=w_q,
            b_group_scales=w_s,
            b_group_size=c.group_size,
            a_token_scales=act_scales,
            b_channel_scales=w_ch_s,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
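
For intuition, the math that ops.cutlass_w4a8_mm fuses can be written out as a plain PyTorch reference. The sketch below is illustrative only and makes assumptions about the exact scale composition (group scales along K applied together with per-output-channel scales, per-token scales on the activations); the real kernel fuses dequantization into the GEMM and may differ in accumulation and rounding details.

import torch

def w4a8_reference(a_fp8, a_token_scales, b_int4, b_group_scales,
                   b_channel_scales, group_size=128):
    # a_fp8:            [M, K] fp8 (e4m3) activations, dynamically quantized per token
    # a_token_scales:   [M, 1] per-token activation scales
    # b_int4:           [K, N] unpacked int4 weights held as int8 values in [-8, 7]
    # b_group_scales:   [K // group_size, N] per-group weight scales
    # b_channel_scales: [1, N] per-output-channel weight scales
    gs = b_group_scales.float().repeat_interleave(group_size, dim=0)   # [K, N]
    b_deq = b_int4.float() * gs * b_channel_scales.float()             # dequantized weights
    a_deq = a_fp8.float() * a_token_scales.float()                     # dequantized activations
    return (a_deq @ b_deq).to(torch.bfloat16)                          # [M, N] bfloat16 output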

quant_fp8 instance-attribute

quant_fp8 = QuantFP8(static=False, group_shape=PER_TOKEN)
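
QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) performs dynamic per-token quantization: each row of the flattened 2-D activation gets its own scale, computed on the fly. The snippet below is only a shape/semantics illustration, not the actual implementation (which uses fused kernels):

import torch

def per_token_fp8_quant_sketch(x_2d: torch.Tensor):
    # One scale per row (token); values are clamped into the e4m3 range and cast.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max                    # 448.0
    scales = x_2d.abs().amax(dim=-1, keepdim=True).float() / fp8_max  # [M, 1]
    scales = scales.clamp(min=1e-12)
    x_fp8 = (x_2d.float() / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scales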

__init__

__init__(*args, **kwargs)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # dynamic per-tok fp8 activation quantization
    self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    c = self.config
    w_q, w_s, _, _ = self._get_weight_params(layer)
    w_ch_s = layer.weight_chan_scale

    x_2d = x.reshape(-1, x.shape[-1])
    out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

    x_2d, act_scales = self.quant_fp8(x_2d)
    output = ops.cutlass_w4a8_mm(
        a=x_2d,
        b_q=w_q,
        b_group_scales=w_s,
        b_group_size=c.group_size,
        a_token_scales=act_scales,
        b_channel_scales=w_ch_s,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
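
apply_weights accepts activations with any number of leading dimensions, flattens them to 2-D for the GEMM, and restores the leading shape on the output. A small shape check of that reshape logic (sizes are hypothetical; the GEMM result is mocked):

import torch

batch, seq, in_features, out_features = 2, 8, 256, 512     # hypothetical sizes
x = torch.randn(batch, seq, in_features, dtype=torch.bfloat16)

x_2d = x.reshape(-1, x.shape[-1])                           # [batch*seq, in_features]
out_shape = x.shape[:-1] + (out_features,)                  # matches partition_weight_shape[1]

# Stand-in for the quantize + cutlass_w4a8_mm result:
out_2d = torch.empty(x_2d.shape[0], out_features, dtype=torch.bfloat16)
print(out_2d.reshape(out_shape).shape)                      # torch.Size([2, 8, 512])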

can_implement classmethod

can_implement(
    c: MPLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
    if not current_platform.is_cuda():
        return False, "CUTLASS only supported on CUDA"

    if not current_platform.is_device_capability(90):
        return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

    if c.act_type != torch.float8_e4m3fn:
        return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

    if c.has_g_idx:
        return False, "Act reordering not supported by CUTLASS W4A8"

    if c.zero_points:
        return False, "Zero points not supported by CUTLASS W4A8"

    if c.weight_type != scalar_types.int4:
        return (
            False,
            f"Quant type ({c.weight_type}) not supported by "
            "CUTLASS W4A8, only supported int4",
        )

    # TODO(czhu): support -1 (column-wise)
    if c.group_size != 128:
        return False, "Only group_size 128 is supported"

    in_features, out_features = c.partition_weight_shape
    if in_features % 128 or out_features % 128:
        return (
            False,
            f"K and N must be divisible by 128, got {c.partition_weight_shape}",
        )

    if c.out_type != torch.bfloat16:
        return (
            False,
            f"Only bfloat16 output type currently supportedgot {c.out_type=}",
        )

    return True, None
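
can_implement never raises; it returns (False, reason) so the caller can fall back to another kernel. A hedged sketch of how a selection loop might consult it (the function and variable names here are hypothetical, not vLLM's actual selector):

def pick_kernel(config, candidates):
    # Try kernels in priority order; collect the decline reasons for error reporting.
    reasons = []
    for kernel_cls in candidates:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No kernel can implement this layer:\n" + "\n".join(reasons))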

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
@classmethod
def get_min_capability(cls) -> int:
    return 90

process_weights_after_loading

process_weights_after_loading(layer: Module)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
def process_weights_after_loading(self, layer: torch.nn.Module):
    # TODO(czhu): optimize speed/mem usage
    def transform_w_q(x):
        assert isinstance(x, BasevLLMParameter)
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
        x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
        return x

    def transform_w_s(x):
        assert isinstance(x, BasevLLMParameter)
        permute_param_layout_(x, input_dim=0, output_dim=1)
        x.data = x.data.contiguous().to(torch.float8_e4m3fn)
        x.data = ops.cutlass_pack_scale_fp8(x.data)
        return x

    # Encode/reorder weights and pack scales
    self._transform_param(layer, self.w_q_name, transform_w_q)
    self._transform_param(layer, self.w_s_name, transform_w_s)
    self._transform_param(layer, "weight_chan_scale", lambda x: x)