class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-matmul linear kernel backed by CUTLASS on CUDA.

    Operands are right-padded to 16-element alignment where needed so the
    CUTLASS kernel can be used instead of the Triton fallback.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return ``(supported, reason)`` for the current platform."""
        if current_platform.is_cuda():
            return True, None
        return False, "requires CUDA."

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every FP8 scaled-mm layer configuration."""
        return True, None

    @staticmethod
    def _pad_to_alignment(
        x: torch.Tensor, dim: int, alignment: int, value: float = 0.0
    ) -> torch.Tensor:
        """Right-pad ``x`` along ``dim`` with ``value`` up to the next
        multiple of ``alignment``; return ``x`` untouched when already
        aligned."""
        pad_amount = -x.shape[dim] % alignment
        if not pad_amount:
            return x
        # F.pad's spec is ordered (last_left, last_right, ..., first_left,
        # first_right); indexing from the end picks the right-side slot
        # belonging to ``dim``.
        spec = [0] * (2 * x.dim())
        spec[-(2 * dim + 1)] = pad_amount
        return torch.nn.functional.pad(x, spec, value=value)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Compute the scaled matmul via CUTLASS.

        Per-tensor/per-channel operands are padded to multiples of 16 so
        CUTLASS can be used instead of Triton; the padding is stripped from
        the output before returning.
        """
        K, N = B.shape
        k_pad = -K % 16
        n_pad = -N % 16
        if k_pad or n_pad:
            # B is column-major [K, N]: flip to a row-major [N, K] copy, pad
            # both trailing edges in one call, then flip back so the result
            # keeps the column-major layout with stride (1, K_padded).
            row_major = torch.nn.functional.pad(
                B.t().contiguous(), (0, k_pad, 0, n_pad)
            )
            B = row_major.t()
        if k_pad:
            A = self._pad_to_alignment(A, dim=1, alignment=16)
        if n_pad:
            if bias is not None:
                bias = self._pad_to_alignment(bias, dim=0, alignment=16)
            # Bs is per-tensor (numel==1) or per-channel (numel==N) in this
            # kernel class — never 2D block-wise.  Padded output channels get
            # a neutral scale of 1.0 (their columns are zero and are sliced
            # away below).
            if Bs.numel() > 1:
                Bs = self._pad_to_alignment(
                    Bs.view(-1), dim=0, alignment=16, value=1.0
                )
        if Bs.dim() == 1 and B.shape[1] > 1:
            Bs = Bs.view(-1, 1)
        result = ops.cutlass_scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        if n_pad:
            result = result[..., :N].contiguous()
        return result.view(*output_shape)