vllm.model_executor.kernels.linear.scaled_mm.cutlass

CutlassFP8ScaledMMLinearKernel

Bases: FP8ScaledMMLinearKernel

Source code in vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    @staticmethod
    def _pad_to_alignment(
        x: torch.Tensor, dim: int, alignment: int, value: float = 0.0
    ) -> torch.Tensor:
        """Pad tensor ``x`` along ``dim`` to the next multiple of
        ``alignment``."""
        remainder = x.shape[dim] % alignment
        if remainder == 0:
            return x
        pad_size = alignment - remainder
        pad_spec = [0] * (2 * x.dim())
        pad_spec[-(2 * dim + 1)] = pad_size
        return torch.nn.functional.pad(x, pad_spec, value=value)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Pad K and N up to multiples of 16 so the CUTLASS kernel can be used
        # (instead of the Triton fallback) for per-tensor/per-channel scales.
        K, N = B.shape
        pad_k = (16 - K % 16) % 16
        pad_n = (16 - N % 16) % 16

        if pad_k > 0 or pad_n > 0:
            # B is column-major [K, N].  Transpose to row-major [N, K],
            # pad both dims in one call, then transpose back so the
            # result keeps column-major layout with stride (1, K_padded).
            B = torch.nn.functional.pad(B.t().contiguous(), (0, pad_k, 0, pad_n)).t()

            if pad_k > 0:
                A = self._pad_to_alignment(A, dim=1, alignment=16)
            if pad_n > 0:
                if bias is not None:
                    bias = self._pad_to_alignment(bias, dim=0, alignment=16)
                # Bs is per-tensor (numel==1) or per-channel (numel==N)
                # in this kernel class — never 2D block-wise.
                if Bs.numel() > 1:
                    # Per-channel scales: pad with 1.0 (the padded columns of
                    # B are all zero, so their scale value is irrelevant) and
                    # reshape to a column vector for the call below.
                    Bs = self._pad_to_alignment(
                        Bs.view(-1), dim=0, alignment=16, value=1.0
                    ).view(-1, 1)

        output = ops.cutlass_scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )

        if pad_n > 0:
            output = output[..., :N].contiguous()

        return output.view(*output_shape)
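
A minimal standalone sketch (plain PyTorch, no vLLM required) of the column-major padding trick used in apply_scaled_mm above; the K and N values are illustrative, not taken from any real layer.

import torch
import torch.nn.functional as F

K, N = 40, 30               # neither dim is a multiple of 16
pad_k = (16 - K % 16) % 16  # 8
pad_n = (16 - N % 16) % 16  # 2

# Column-major [K, N] weight: stride (1, K), matching the layout
# apply_scaled_mm assumes for B.
B = torch.randn(N, K).t()
assert B.stride() == (1, K)

# Transpose to row-major [N, K], pad both dims in one call, transpose back.
B_padded = F.pad(B.t().contiguous(), (0, pad_k, 0, pad_n)).t()
assert B_padded.shape == (K + pad_k, N + pad_n)
assert B_padded.stride() == (1, K + pad_k)  # still column-major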

_pad_to_alignment staticmethod

_pad_to_alignment(
    x: Tensor, dim: int, alignment: int, value: float = 0.0
) -> Tensor

Pad tensor x along dim to the next multiple of alignment.

Source code in vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
@staticmethod
def _pad_to_alignment(
    x: torch.Tensor, dim: int, alignment: int, value: float = 0.0
) -> torch.Tensor:
    """Pad tensor ``x`` along ``dim`` to the next multiple of
    ``alignment``."""
    remainder = x.shape[dim] % alignment
    if remainder == 0:
        return x
    pad_size = alignment - remainder
    pad_spec = [0] * (2 * x.dim())
    pad_spec[-(2 * dim + 1)] = pad_size
    return torch.nn.functional.pad(x, pad_spec, value=value)
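
A minimal usage sketch of the helper, assuming the class is importable from the module path shown above; the size 30 and the 1.0 fill value mirror how apply_scaled_mm pads a per-channel Bs so the extra (all-zero) output columns are left unscaled.

import torch

from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
)

scales = torch.rand(30)  # per-channel scales, N == 30
padded = CutlassFP8ScaledMMLinearKernel._pad_to_alignment(
    scales, dim=0, alignment=16, value=1.0
)
assert padded.shape == (32,)             # padded to the next multiple of 16
assert torch.equal(padded[:30], scales)  # original values preserved
assert torch.all(padded[30:] == 1.0)     # new tail filled with 1.0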