Skip to content

vllm.model_executor.layers.quantization.kernels.scaled_mm.xla

XLAScaledMMLinearKernel

Bases: ScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
    """Int8 scaled-matmul linear kernel for TPU, backed by a torch_xla custom op.

    Weights are kept in ``[out, in]`` layout (unlike ``cutlass_scaled_mm``) and
    the XLA op is asked to quantize activations itself
    (``quantize_activation=True``). Consequently only dynamic, symmetric
    activation quantization with channelwise weight scales is accepted; see
    ``can_implement``.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Not applicable on TPU.

        Raises:
            NotImplementedError: always. TPUs expose no compute-capability
                number, so this kernel must not be selected by capability.
        """
        raise NotImplementedError(
            "TPU platform does not have a concept of compute capability, "
            "this method should not be called."
        )

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Check whether this kernel can serve the layer config ``c``.

        Returns:
            ``(True, None)`` if every requirement holds. Otherwise
            ``(False, reason)`` for the first failing requirement, checked in
            this order: running on TPU, dynamic activation scales, symmetric
            activation scales, channelwise weight scales.
        """
        if not current_platform.is_tpu():
            return False, "ScaledMMXLA requires running on TPU."

        if c.is_static_input_scheme:
            return False, "ScaledMMXLA requires dynamic activation scales."

        if not c.input_symmetric:
            return False, "ScaledMMXLA requires symmetric activation scales."

        if not c.is_channelwise:
            return False, "ScaledMMXLA requires channelwise weight scales"

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded weight and scale parameters for the XLA op.

        The quantized weight is re-registered as a frozen ``Parameter``. The
        weight scale is expanded to per-channel for fused modules when needed,
        squeezed along its last dim, and re-registered. The layer attributes
        named by ``i_s_name``, ``i_zp_name`` and ``azp_adj_name`` are set to
        ``None``. A warnings filter is installed for the ``torch.cond``
        specialization warning raised by ``apply_weights``.
        """
        # WEIGHT
        # [out, in] (different than cutlass_scaled_mm)
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
        )

        # WEIGHT SCALE
        # XLA kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)

        # [out_channel,] (different than cutlass_scaled_mm)
        weight_scale = weight_scale.squeeze(-1)
        replace_parameter(
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # Only support symmetric dynamic activation quantization.
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)

        # Filter warning for cond usage in apply_weights. It is okay
        # to specialize the graph since bias is not dynamic.
        warnings.filterwarnings(
            "ignore",
            message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.",  # noqa: E501
        )

    def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
        """``cond`` false-branch: return ``x`` unchanged (``bias`` is ignored)."""
        return x

    def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
        """``cond`` true-branch: return ``x + bias``."""
        return x + bias

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the int8 XLA matmul of ``x`` with the layer's weight, plus bias.

        Args:
            layer: module holding the quantized weight and its scale.
            x: activation tensor; the XLA op quantizes it internally.
            bias: optional bias added to the matmul output.

        Returns:
            The matmul output, with ``bias`` added when it is not ``None``.
        """
        w_q, w_s, _, _, _ = self._get_weight_params(layer)

        # Required to register custom ops.
        import torch_xla.experimental.custom_kernel  # noqa: F401

        out = torch.ops.xla.quantized_matmul_int8(
            x,
            w_q,
            w_s,
            quantize_activation=True,
        )

        # Explicitly capture control flow to make dynamo happy.
        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])

add_bias

add_bias(x: Tensor, bias: Optional[Tensor])
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
    """Bias branch used with ``cond``: produce ``x`` shifted by ``bias``."""
    biased = x + bias
    return biased

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the int8 XLA matmul on ``x`` and optionally add ``bias``.

    The quantized weight and its scale come from ``layer``. The XLA op is
    told to quantize the activation itself (``quantize_activation=True``),
    so no input scale is passed.
    """
    weight_q, weight_scale, _, _, _ = self._get_weight_params(layer)

    # Importing this module registers the torch.ops.xla custom kernels.
    import torch_xla.experimental.custom_kernel  # noqa: F401

    quantized_matmul = torch.ops.xla.quantized_matmul_int8
    product = quantized_matmul(x, weight_q, weight_scale, quantize_activation=True)

    # Route the optional bias through torch `cond` so dynamo sees the branch
    # as explicit control flow instead of Python-level branching.
    # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
    branch_operands = [product, bias]
    return cond(bias is None, self.no_add_bias, self.add_bias, branch_operands)

can_implement classmethod

can_implement(
    c: ScaledMMLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    """Decide whether this kernel can serve the layer config ``c``.

    Requirements are evaluated lazily and in order: running on TPU, dynamic
    activation scales, symmetric activation scales, channelwise weight
    scales. Returns ``(False, reason)`` for the first one that fails, or
    ``(True, None)`` if all pass.
    """
    requirements = (
        (lambda: current_platform.is_tpu(), "ScaledMMXLA requires running on TPU."),
        (
            lambda: not c.is_static_input_scheme,
            "ScaledMMXLA requires dynamic activation scales.",
        ),
        (
            lambda: c.input_symmetric,
            "ScaledMMXLA requires symmetric activation scales.",
        ),
        (
            lambda: c.is_channelwise,
            "ScaledMMXLA requires channelwise weight scales",
        ),
    )
    for is_satisfied, failure_reason in requirements:
        if not is_satisfied():
            return False, failure_reason
    return True, None

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@classmethod
def get_min_capability(cls) -> int:
    """Not applicable on TPU.

    Raises:
        NotImplementedError: always. TPUs expose no compute-capability
            number, so this kernel must not be selected by capability.
    """
    raise NotImplementedError(
        "TPU platform does not have a concept of compute capability, "
        "this method should not be called."
    )

no_add_bias

no_add_bias(x: Tensor, bias: Optional[Tensor])
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
    """Identity branch used with ``cond``: hand back ``x`` untouched.

    ``bias`` is accepted only so both branches share one signature; it is
    ignored here.
    """
    untouched = x
    return untouched

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-wrap the loaded weight and scale parameters for the XLA op.

    Steps, in order:
      1. re-register the quantized weight as a frozen ``Parameter``;
      2. for fused modules without channelwise scales, expand the weight
         scale per-channel, then squeeze its last dim and re-register it;
      3. set the layer attributes named by ``i_s_name``, ``i_zp_name`` and
         ``azp_adj_name`` to ``None``;
      4. silence the ``torch.cond`` specialization warning from
         ``apply_weights``.
    """
    # The weight keeps its [out, in] layout (different than
    # cutlass_scaled_mm); it is only re-wrapped as a frozen Parameter.
    loaded_weight = getattr(layer, self.w_q_name)
    frozen_weight = torch.nn.Parameter(loaded_weight.data, requires_grad=False)
    replace_parameter(layer, self.w_q_name, frozen_weight)

    # XLA kernels support only per-tensor and per-channel weight scales. A
    # fused module (QKV, MLP) with per-tensor scales passes N scales to the
    # kernel, so expand those to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    scale = getattr(layer, self.w_s_name)
    needs_expansion = is_fused_module and not self.config.is_channelwise
    if needs_expansion:
        scale = convert_to_channelwise(scale, layer.logical_widths)

    # Drop the trailing dim -> [out_channel,] (different than cutlass_scaled_mm).
    flat_scale = scale.squeeze(-1)
    replace_parameter(
        layer,
        self.w_s_name,
        torch.nn.Parameter(flat_scale.data, requires_grad=False),
    )

    # Only symmetric dynamic activation quantization is supported, so these
    # activation-side slots are cleared.
    for cleared_attr in (self.i_s_name, self.i_zp_name, self.azp_adj_name):
        setattr(layer, cleared_attr, None)

    # apply_weights passes a Python-constant predicate to torch.cond; the
    # resulting graph specialization is fine because bias is not dynamic.
    warnings.filterwarnings(
        "ignore",
        message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.",  # noqa: E501
    )