vllm.model_executor.layers.quantization.kernels.mixed_precision.conch

_CONCH_SUPPORTED_GROUP_SIZES module-attribute

_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]

_CONCH_SUPPORTED_WEIGHT_TYPES module-attribute

_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    uint4,
    uint8,
    uint4b8,
    uint8b128,
]
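
The two biased entries, uint4b8 and uint8b128, store unsigned codes and recover the signed logical value by subtracting a fixed bias (8 for the 4-bit type, 128 for the 8-bit type); this is the same bias that apply_weights later passes to the kernel as weight_bias. A minimal illustration of the mapping (the numbers are just an example):

stored = 3            # a 4-bit code in [0, 15]
logical = stored - 8  # uint4b8: logical value in [-8, 7]
assert logical == -5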

ConchLinearKernel

Bases: MPLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
class ConchLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            error_msg = (
                f"Weight type ({c.weight_type}) not supported by "
                "ConchLinearKernel, supported types are: "
                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
            )
            return False, error_msg

        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            error_msg = (
                f"Group size ({c.group_size}) not supported by "
                "ConchLinearKernel, supported group sizes are: "
                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
            )
            return False, error_msg

        if find_spec("conch") is None:
            error_msg = (
                "conch-triton-kernels is not installed, please "
                "install it via `pip install conch-triton-kernels` "
                "and try again!"
            )
            return False, error_msg

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = x.data.contiguous()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        output = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=self.config.weight_type.size_bits,
            weight_bias=self.config.weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    from conch.ops.quantization.gemm import mixed_precision_gemm

    w_q, w_s, w_zp, _ = self._get_weight_params(layer)

    output = mixed_precision_gemm(
        x=x,
        w_q_packed=w_q.data,
        w_s=w_s.data,
        w_zp=w_zp.data if w_zp is not None else None,
        weight_size_bits=self.config.weight_type.size_bits,
        weight_bias=self.config.weight_type.bias,
        group_size=self.config.group_size,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output
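
For reference, the fused call above is equivalent (up to numerics and the bit-packing details) to dequantizing the weight and running an ordinary GEMM. The sketch below assumes the packing has already been undone and that zero points, when present, are subtracted alongside the type bias; treat the exact zero-point convention as an assumption rather than the kernel's contract.

from typing import Optional

import torch

def mixed_precision_gemm_reference(
    x: torch.Tensor,               # (M, K) float16/bfloat16 activations
    w_q: torch.Tensor,             # (K, N) integer codes, unpacked from weight_packed
    w_s: torch.Tensor,             # (num_groups, N) per-group scales
    w_zp: Optional[torch.Tensor],  # optional zero points, same shape as w_s
    weight_bias: int,              # 8 for uint4b8, 128 for uint8b128, 0 otherwise
    group_size: int,               # -1 means a single group spanning all of K
) -> torch.Tensor:
    K = w_q.shape[0]
    gsize = K if group_size == -1 else group_size
    g = torch.arange(K, device=x.device) // gsize  # group index for each row of W
    w = w_q.to(x.dtype) - weight_bias              # undo the type bias
    if w_zp is not None:
        w = w - w_zp[g].to(x.dtype)                # assumed zero-point handling
    w = w * w_s[g].to(x.dtype)                     # apply per-group scales
    return x @ w                                   # plain GEMM on the dequantized weight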

can_implement classmethod

can_implement(
    c: MPLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
    if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
        error_msg = (
            f"Weight type ({c.weight_type}) not supported by "
            "ConchLinearKernel, supported types are: "
            f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
        )
        return False, error_msg

    if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
        error_msg = (
            f"Group size ({c.group_size}) not supported by "
            "ConchLinearKernel, supported group sizes are: "
            f"{_CONCH_SUPPORTED_GROUP_SIZES}"
        )
        return False, error_msg

    if find_spec("conch") is None:
        error_msg = (
            "conch-triton-kernels is not installed, please "
            "install it via `pip install conch-triton-kernels` "
            "and try again!"
        )
        return False, error_msg

    return True, None
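
The final check guards the optional dependency: the kernel is only selectable when the conch package is importable. The same probe can be run by hand:

from importlib.util import find_spec

# find_spec returns None when the package is not installed, without importing it.
if find_spec("conch") is None:
    print("missing dependency: pip install conch-triton-kernels")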

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
@classmethod
def get_min_capability(cls) -> int:
    return 80
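
The returned value is the integer encoding of a CUDA compute capability, so 80 means 8.0 (Ampere) or newer. A quick way to compute the same encoding for the current GPU and compare:

import torch

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # e.g. 80 on A100, 89 on RTX 4090, 90 on H100
supported = capability >= ConchLinearKernel.get_min_capability()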

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    def transform_w_q(x):
        assert isinstance(x, BasevLLMParameter)
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
        x.data = x.data.contiguous()
        return x

    def transform_w_s(x):
        assert isinstance(x, BasevLLMParameter)
        permute_param_layout_(x, input_dim=0, output_dim=1)
        x.data = x.data.contiguous()
        return x

    self._transform_param(layer, self.w_q_name, transform_w_q)
    self._transform_param(layer, self.w_s_name, transform_w_s)
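
Both transforms normalize the checkpoint layout to what the kernel expects ({input_dim = 0, output_dim = 1}) and then materialize a dense buffer. As a rough illustration, assuming a hypothetical checkpoint that stored weight_scale output-major, i.e. as (N, num_groups), the effect would be equivalent to:

import torch

scale_out_major = torch.randn(4096, 32)            # hypothetical (N, num_groups) checkpoint layout
scale_in_major = scale_out_major.t().contiguous()  # (num_groups, N): input dim first, then contiguous
assert scale_in_major.stride() == (4096, 1)        # row-major buffer the Triton kernel can read directly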