Skip to content

vllm.model_executor.layers.quantization.compressed_tensors.transform.module

HadamardTransform

Bases: Module

A class that handles weight loading, post-processing, and application of transforms. It is meant to be used with CompressedTensorsLinearTransformMethod and with the attention transforms method (not yet implemented).

Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
class HadamardTransform(torch.nn.Module):
    """
    Class which handles weight loading, postprocessing, and application of
    transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
    and attention transforms method (not implemented yet)

    One square transform weight is registered per entry of `transforms`, keyed
    by the partition index of the original weight. `forward` applies the
    transform registered for a partition and returns its input untouched for
    partitions that have none.
    """

    transforms: dict[int, TransformTuple]  # info parsed from transforms config
    weight: SharedWeightParameter  # container for shared tensors

    # reciprocal hadamard scale per partition: 1 / sqrt(matrix.size(0))
    scales: dict[int, float]

    def __init__(
        self,
        transforms: dict[int, TransformTuple],
        layer: torch.nn.Module,
        weight_loader: Callable,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
    ):
        """
        Register one square transform weight per configured partition.

        :param transforms: maps a partition index to its
            `(scheme_name, scheme, args)` tuple
        :param layer: layer the transform belongs to; only its type is
            inspected, to decide which dimension the transform spans
        :param weight_loader: forwarded to `SharedWeightParameter`
        :param input_size_per_partition: input size of the original weight
        :param output_partition_sizes: output size of each partition, indexed
            by partition index
        :raises NotImplementedError: if the tensor parallel world size is
            greater than one
        :raises ValueError: if a transform size cannot be determined or the
            registered partitions fail validation
        """
        super().__init__()
        self.transforms = transforms
        self.scales = {}

        if get_tensor_model_parallel_world_size() > 1:
            raise NotImplementedError(
                "Online transforms with tensor parallelism is not supported"
            )

        # Similar to row/col parallel params, but tensors are separate
        # to allow for loading with shared memory
        self.weight = SharedWeightParameter(weight_loader=weight_loader)

        # create shared partition data for each partition of the original weight
        input_size = input_size_per_partition
        for part_index, (_scheme_name, scheme, args) in self.transforms.items():
            output_size = output_partition_sizes[part_index]
            weight_size = self._get_weight_size(
                layer, scheme, args, input_size, output_size
            )

            data_key = self._get_data_key(scheme, weight_size)
            self.weight.add_partition(
                part_index,
                data_key,
                size=(weight_size, weight_size),
                dtype=scheme.precision,
            )

        # validate that shared tensors and schemes are correct
        self._validate_input_transforms()

    def process_weights_after_loading(self):
        """
        Finalize the loaded weights and cache the scale of every partition.

        The dense path of `forward` reads `self.scales`, so this has to run
        before that path is used.
        """
        # required by torch.compile
        # the call takes no partition argument, so it is issued once rather
        # than once per partition
        self.weight.process_weights_after_loading()

        for part_id in self.weight.partitions:
            data = self.weight.partitions[part_id].data

            # precompute scale as a runtime multiply, not division
            # do not fold into weight in order to utilize FWHT
            self.scales[part_id] = 1 / math.sqrt(data.size(0))

            # FUTURE: avoid runtime transpose by processing weights
            # prior to apply

    def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
        """
        Apply the transform registered for `part_id` to `value`.

        :param value: tensor to transform along its last dimension
        :param part_id: partition index selecting the transform
        :return: the transformed tensor, or `value` itself when no partition
            is registered under `part_id`
        """
        if part_id not in self.weight.partitions:
            return value

        transform = self.transforms[part_id]
        head_dim = transform.scheme.head_dim
        per_head = head_dim is not None

        # use hadacore if possible
        if transform.scheme.type == "hadamard":
            if per_head:
                # split the last dimension into blocks of `head_dim`
                value = value.unflatten(-1, (-1, head_dim))

            # sylvester transforms are symmetric, inv => transpose => original
            value = ops.hadacore_transform(value)
            return value.flatten(-2, -1) if per_head else value

        # fall back to dense
        weight = self.weight.partitions[part_id]
        if not transform.args.inverse:
            weight = weight.T  # linear := x(W.T)
        scale = self.scales[part_id]

        if per_head:
            value = value.unflatten(-1, (-1, weight.size(0)))

        # compute in the weight dtype, then cast back to the input dtype
        value = (
            dispatch_unquantized_gemm()(
                self, value.to(weight.dtype), weight, None
            ).to(value.dtype)
            * scale
        )
        return value.flatten(-2, -1) if per_head else value

    def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable:
        """
        Return the key under which a partition's tensor is registered.

        Equal keys are produced only for the same `scheme` object combined
        with the same `weight_size`.
        """
        return (id(scheme), weight_size)

    def _get_weight_size(
        self,
        layer: torch.nn.Module,
        scheme: TransformScheme,
        args: TransformArgs,
        input_size: int,
        output_size: int,
    ) -> int:
        """
        Return the side length of the square transform weight.

        `scheme.head_dim` wins when it is set. Otherwise the size follows from
        the layer type and `args.location`; note that the input/output roles
        are swapped for `VocabParallelEmbedding` compared to `LinearBase`.

        :raises ValueError: if the layer type / location pair is unsupported
        """
        if scheme.head_dim is not None:
            return scheme.head_dim

        if isinstance(layer, LinearBase):
            if args.location == TransformLocation.INPUT:
                return input_size

            elif args.location == TransformLocation.OUTPUT:
                return output_size

        elif isinstance(layer, VocabParallelEmbedding):
            if args.location == TransformLocation.INPUT:
                return output_size

            elif args.location == TransformLocation.OUTPUT:
                return input_size

        raise ValueError(
            "Cannot determine transform weight size for layer type "
            f"{type(layer).__name__} and location {args.location}"
        )

    def _validate_input_transforms(self):
        """
        Check that input transforms use one tensor for every partition.

        The location is taken from the first configured transform. Only when
        it is `TransformLocation.INPUT` are the partitions compared, by the
        data pointer of their tensors.

        :raises ValueError: if no transform is configured, or if a partition
            of an input transform points at different data than the first
            registered partition
        """
        # explicit raise instead of `assert`, which is stripped under -O
        if not self.transforms:
            raise ValueError("HadamardTransform requires at least one transform")

        location = next(iter(self.transforms.values())).args.location
        if location != TransformLocation.INPUT:
            return

        # compare against the first registered partition, whatever its index,
        # rather than assuming that partition 0 exists
        shared_ptr = None
        for part_id in self.weight.partitions:
            data_ptr = self.weight.partitions[part_id].data.data_ptr()
            if shared_ptr is None:
                shared_ptr = data_ptr
            elif data_ptr != shared_ptr:
                raise ValueError(
                    "Input transforms must be backed by the same tensor for "
                    f"every partition, but partition {part_id} does not share "
                    "data with the first partition"
                )

scales instance-attribute

scales: dict[int, float] = {}

transforms instance-attribute

transforms: dict[int, TransformTuple] = transforms

weight instance-attribute

weight: SharedWeightParameter = SharedWeightParameter(
    weight_loader=weight_loader
)

__init__

__init__(
    transforms: dict[int, TransformTuple],
    layer: Module,
    weight_loader: Callable,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def __init__(
    self,
    transforms: dict[int, TransformTuple],
    layer: torch.nn.Module,
    weight_loader: Callable,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
):
    """
    Register one square transform weight per configured partition.

    :param transforms: maps a partition index to its
        `(scheme_name, scheme, args)` tuple
    :param layer: layer the transform belongs to, passed on to
        `_get_weight_size`
    :param weight_loader: forwarded to `SharedWeightParameter`
    :param input_size_per_partition: input size of the original weight
    :param output_partition_sizes: output size of each partition, indexed by
        partition index
    :raises NotImplementedError: if the tensor parallel world size is greater
        than one
    """
    super().__init__()
    self.transforms = transforms
    self.scales = {}

    tp_world_size = get_tensor_model_parallel_world_size()
    if tp_world_size > 1:
        raise NotImplementedError(
            "Online transforms with tensor parallelism is not supported"
        )

    # Like row/col parallel params, except that the tensors are kept separate
    # so that they can be loaded with shared memory
    self.weight = SharedWeightParameter(weight_loader=weight_loader)

    # every partition of the original weight gets its own square transform
    for partition_id, transform in self.transforms.items():
        _scheme_name, scheme, transform_args = transform
        side = self._get_weight_size(
            layer,
            scheme,
            transform_args,
            input_size_per_partition,
            output_partition_sizes[partition_id],
        )
        self.weight.add_partition(
            partition_id,
            self._get_data_key(scheme, side),
            size=(side, side),
            dtype=scheme.precision,
        )

    # check the registered partitions once all of them exist
    self._validate_input_transforms()

_get_data_key

_get_data_key(
    scheme: TransformScheme, weight_size: int
) -> Hashable
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable:
    """
    Return the key under which a partition's tensor is registered.

    The key pairs the identity of the `scheme` object with `weight_size`, so
    equal keys come only from the same scheme object at the same size.
    """
    scheme_identity = id(scheme)
    return scheme_identity, weight_size

_get_weight_size

_get_weight_size(
    layer: Module,
    scheme: TransformScheme,
    args: TransformArgs,
    input_size: int,
    output_size: int,
) -> int
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def _get_weight_size(
    self,
    layer: torch.nn.Module,
    scheme: TransformScheme,
    args: TransformArgs,
    input_size: int,
    output_size: int,
) -> int:
    """
    Return the side length of the square transform weight.

    `scheme.head_dim` wins when it is set. Otherwise the size follows from the
    layer type and `args.location`; note that the input/output roles are
    swapped for `VocabParallelEmbedding` compared to `LinearBase`.

    :param layer: layer the transform belongs to; only its type is inspected
    :param scheme: transform scheme, read for `head_dim`
    :param args: transform args, read for `location`
    :param input_size: input size of the original weight
    :param output_size: output size of the partition
    :raises ValueError: if the layer type / location pair is unsupported
    """
    if scheme.head_dim is not None:
        return scheme.head_dim

    if isinstance(layer, LinearBase):
        if args.location == TransformLocation.INPUT:
            return input_size

        elif args.location == TransformLocation.OUTPUT:
            return output_size

    elif isinstance(layer, VocabParallelEmbedding):
        if args.location == TransformLocation.INPUT:
            return output_size

        elif args.location == TransformLocation.OUTPUT:
            return input_size

    # name the offending combination instead of raising an empty error
    raise ValueError(
        "Cannot determine transform weight size for layer type "
        f"{type(layer).__name__} and location {args.location}"
    )

_validate_input_transforms

_validate_input_transforms()
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def _validate_input_transforms(self):
    """
    Check that input transforms use one tensor for every partition.

    The location is taken from the first configured transform. Only when it is
    `TransformLocation.INPUT` are the partitions compared, by the data pointer
    of their tensors.

    :raises ValueError: if no transform is configured, or if a partition of an
        input transform points at different data than the first registered
        partition
    """
    # explicit raise instead of `assert`, which is stripped under -O
    if not self.transforms:
        raise ValueError("HadamardTransform requires at least one transform")

    location = next(iter(self.transforms.values())).args.location
    if location != TransformLocation.INPUT:
        return

    # compare against the first registered partition, whatever its index,
    # rather than assuming that partition 0 exists
    shared_ptr = None
    for part_id in self.weight.partitions:
        data_ptr = self.weight.partitions[part_id].data.data_ptr()
        if shared_ptr is None:
            shared_ptr = data_ptr
        elif data_ptr != shared_ptr:
            raise ValueError(
                "Input transforms must be backed by the same tensor for "
                f"every partition, but partition {part_id} does not share "
                "data with the first partition"
            )

forward

forward(value: Tensor, part_id: int = 0) -> Tensor
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
    """
    Apply the transform registered for `part_id` to `value`.

    Schemes of type "hadamard" go through `ops.hadacore_transform`; every
    other scheme is applied as a dense matrix multiply followed by the
    partition's precomputed scale. When the scheme defines `head_dim`, the
    last dimension is split into blocks before the transform and flattened
    back afterwards.

    :param value: tensor to transform along its last dimension
    :param part_id: partition index selecting the transform
    :return: the transformed tensor, or `value` itself when no partition is
        registered under `part_id`
    """
    if part_id not in self.weight.partitions:
        return value

    transform = self.transforms[part_id]
    head_dim = transform.scheme.head_dim
    per_head = head_dim is not None

    # use hadacore if possible
    if transform.scheme.type == "hadamard":
        if per_head:
            value = value.unflatten(-1, (-1, head_dim))

        # sylvester transforms are symmetric, inv => transpose => original
        value = ops.hadacore_transform(value)
        return value.flatten(-2, -1) if per_head else value

    # fall back to dense
    matrix = self.weight.partitions[part_id]
    if not transform.args.inverse:
        matrix = matrix.T  # linear := x(W.T)
    scale = self.scales[part_id]

    if per_head:
        value = value.unflatten(-1, (-1, matrix.size(0)))

    # compute in the weight dtype, then cast back to the input dtype
    gemm = dispatch_unquantized_gemm()
    product = gemm(self, value.to(matrix.dtype), matrix, None)
    value = product.to(value.dtype) * scale

    return value.flatten(-2, -1) if per_head else value

process_weights_after_loading

process_weights_after_loading()
Source code in vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
def process_weights_after_loading(self):
    """
    Finalize the loaded weights and cache the scale of every partition.

    Stores `1 / sqrt(data.size(0))` in `self.scales` for each partition of
    `self.weight`.
    """
    # required by torch.compile
    # the call takes no partition argument, so it is issued once rather than
    # once per partition
    self.weight.process_weights_after_loading()

    for part_id in self.weight.partitions:
        data = self.weight.partitions[part_id].data

        # precompute scale as a runtime multiply, not division
        # do not fold into weight in order to utilize FWHT
        self.scales[part_id] = 1 / math.sqrt(data.size(0))