Bases: CutlassScaledMMLinearKernel
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    """Triton-backed scaled-mm linear kernel.

    Delegates weight post-processing and the scaled matmul itself to the
    Cutlass parent implementation; this subclass only narrows the
    platforms and configurations on which the kernel may be selected.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability required (75, i.e. 7.5)."""
        return 75

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Return ``(supported, reason)`` for the given layer config.

        ``reason`` is ``None`` when the kernel is usable. The kernel is
        reported unavailable on CPU and for asymmetric input quantization.
        """
        if current_platform.is_cpu():
            return (
                False,
                # Implicit literal concatenation; message unchanged.
                "TritonScaledMMLinearKernel requires Triton which is not "
                "currently supported on CPU.",
            )
        if not c.input_symmetric:
            return (
                False,
                "TritonScaledMMLinearKernel only supports symmetric quantization.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize ``layer``'s weights after loading; parent logic suffices."""
        super().process_weights_after_loading(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the quantized linear transform to ``x`` via the parent kernel."""
        return super().apply_weights(layer, x, bias)
apply_weights
apply_weights(layer: Module, x: Tensor, bias: Tensor | None = None) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the scaled matmul on ``x``, delegating to the parent kernel."""
    result = super().apply_weights(layer, x, bias)
    return result
can_implement classmethod
can_implement(c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    """Check whether this kernel supports the supplied config.

    Returns a ``(can_implement, failure_reason)`` pair; the reason is
    ``None`` on success.
    """
    # Triton kernels are reported unavailable when running on CPU.
    if current_platform.is_cpu():
        reason = (
            "TritonScaledMMLinearKernel requires Triton which is not "
            "currently supported on CPU."
        )
        return False, reason
    # Asymmetric input quantization is not handled by this kernel.
    if not c.input_symmetric:
        reason = "TritonScaledMMLinearKernel only supports symmetric quantization."
        return False, reason
    return True, None
get_min_capability classmethod
get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum GPU compute capability as an int (75 == 7.5)."""
    minimum_capability = 75
    return minimum_capability
process_weights_after_loading
process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load weight preparation for ``layer``; inherited behavior is kept."""
    super().process_weights_after_loading(layer)