class IndexerQMxFp4Kernel:
"""Eight-thread subwarps process one ``(token, head)`` row."""
def __init__(
self,
head_dim: int = 128,
rope_dim: int = 64,
num_heads: int = 64,
cos_sin_dtype: type[cutlass.Numeric] = Float32,
coarsen: int = 4,
):
self.head_dim = head_dim
self.rope_dim = rope_dim
self.nope_dim = head_dim - rope_dim
self.num_heads = num_heads
self.cos_sin_dtype = cos_sin_dtype
        # process multiple heads at the same time to amortize RoPE load costs
assert num_heads % coarsen == 0
self.coarsen = coarsen
        # later we use 32B loads = 16 BF16 elems per thread,
        # so head_dim=128 needs 8 threads per head.
        # we call this group of threads a subwarp.
self.subwarp_size = head_dim // 16
self.tb_size = 128
self.threads_per_token = (self.num_heads // self.coarsen) * self.subwarp_size
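        # e.g. with the defaults (head_dim=128, num_heads=64, coarsen=4):
        # subwarp_size = 128 // 16 = 8 and threads_per_token = (64 // 4) * 8 = 128,
        # i.e. exactly one threadblock per token.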
@cute.jit
def __call__(
self,
positions: cute.Tensor,
q: cute.Tensor,
cos_sin_cache: cute.Tensor,
weights: cute.Tensor,
q_fp4: cute.Tensor,
q_scale: cute.Tensor,
weights_out: cute.Tensor,
scale: Float32,
stream: CUstream,
):
total_threads = q.shape[0] * self.threads_per_token
grid = (cute.ceil_div(total_threads, self.tb_size), 1, 1)
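        # one subwarp per (token, head-tile); the flat thread range is chopped
        # into tb_size blocks, so the last block may hold out-of-range threads.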
self.kernel(
positions,
q,
cos_sin_cache,
weights,
q_fp4,
q_scale,
weights_out,
scale,
).launch(grid=grid, block=(self.tb_size, 1, 1), stream=stream)
@cute.kernel
def kernel(
self,
positions: cute.Tensor,
q: cute.Tensor,
cos_sin_cache: cute.Tensor,
weights: cute.Tensor,
q_fp4: cute.Tensor,
q_scale: cute.Tensor,
weights_out: cute.Tensor,
scale: Float32,
):
block_id, _, _ = cute.arch.block_idx()
tid, _, _ = cute.arch.thread_idx()
num_token_heads = q.shape[0] * self.num_heads
global_tid = block_id * self.tb_size + tid
global_subwarp_id = global_tid // self.subwarp_size
sublane = tid % self.subwarp_size
token_id = global_subwarp_id // (self.num_heads // self.coarsen)
head_tile_id = global_subwarp_id % (self.num_heads // self.coarsen)
head_start = head_tile_id * self.coarsen
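        # worked example with the defaults (16 head tiles, subwarp_size 8):
        # global_tid = 300 -> global_subwarp_id = 37, sublane = 4,
        # token_id = 37 // 16 = 2, head_tile_id = 37 % 16 = 5, head_start = 20,
        # so this subwarp covers heads 20..23 of token 2 and this lane owns
        # elems [64, 80) of each of those head rows.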
        # NOTE: token_id may exceed bounds, hence we need load/store guards.
        # we can't do early exit because CuteDSL doesn't support it, and we also
        # need all threads in a warp to be active since we use warp shuffles later.
        # must_in_bounds is constexpr: True when each token's threads span a whole
        # number of threadblocks, so no launched thread can point past the last
        # token and the compiler removes the bounds checks. this holds for the
        # defaults, where threads_per_token == tb_size == 128.
        must_in_bounds = cutlass.const_expr(self.threads_per_token % self.tb_size == 0)
in_bounds = must_in_bounds or (token_id < q.shape[0])
cp_op = cute.nvgpu.CopyUniversalOp()
_layout = cute.make_layout((self.coarsen, 8), stride=(8, 1))
q_bf16x2 = cute.make_rmem_tensor(_layout, Uint32)
if in_bounds:
            # we can't cute.copy() the whole 2D tile directly because
            # cute.copy() wants the 1st mode to be covered by the copy atom
            # and loops over the other modes, and there is no cheap way to
            # "transpose" the tensor view.
q_tile = cute.local_tile(
q[token_id, None, None],
tiler=(self.coarsen, 16),
coord=(head_tile_id, sublane),
)
cp_u32x8 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=256)
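            # one row at a time: each 256-bit copy moves 8 x u32 = 16 BF16
            # elems, this lane's [sublane*16, sublane*16 + 16) slice of head
            # row head_start + i.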
for i in cutlass.range_constexpr(self.coarsen):
src = cute.recast_tensor(q_tile[i, None], Uint32)
cute.copy(cp_u32x8, src, q_bf16x2[i, None])
        # RoPE applies only to the trailing rope_dim values. We keep the rounded
        # BF16 result in q_bf16x2 so the later amax and quantization see BF16.
# cos_sin_cache layout: [max_pos, rope_dim]
if in_bounds and sublane * 16 >= self.nope_dim:
cos_vals = cute.make_rmem_tensor((8,), Float32)
sin_vals = cute.make_rmem_tensor((8,), Float32)
pos = positions[token_id]
# select 8 elems from cos and sin
cos_id = sublane - self.nope_dim // 16
sin_id = cos_id + self.rope_dim // 16
cos_src = cute.local_tile(
cos_sin_cache[pos, None], tiler=(8,), coord=(cos_id,)
)
sin_src = cute.local_tile(
cos_sin_cache[pos, None], tiler=(8,), coord=(sin_id,)
)
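            # e.g. defaults (nope_dim = rope_dim = 64): the cache row packs 32
            # cos values then 32 sin values; sublanes 4..7 map to cos_id 0..3
            # (elems 0..31) and sin_id 4..7 (elems 32..63).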
cp_f32x8 = cute.make_copy_atom(cp_op, Float32, num_bits_per_copy=256)
cp_u32x4 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=128)
            if cutlass.const_expr(self.cos_sin_dtype is Float32):
cute.copy(cp_f32x8, cos_src, cos_vals)
cute.copy(cp_f32x8, sin_src, sin_vals)
else:
cos_bf16x2 = cute.make_rmem_tensor((4,), Uint32)
sin_bf16x2 = cute.make_rmem_tensor((4,), Uint32)
cute.copy(cp_u32x4, cute.recast_tensor(cos_src, Uint32), cos_bf16x2)
cute.copy(cp_u32x4, cute.recast_tensor(sin_src, Uint32), sin_bf16x2)
for i in cutlass.range_constexpr(4):
cos0, cos1 = _bf16x2_to_fp32(cos_bf16x2[i])
sin0, sin1 = _bf16x2_to_fp32(sin_bf16x2[i])
cos_vals[i * 2] = cos0
cos_vals[i * 2 + 1] = cos1
sin_vals[i * 2] = sin0
sin_vals[i * 2 + 1] = sin1
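            # rotate each adjacent (q0, q1) pair by its per-pair angle:
            #   rot0 = q0*cos - q1*sin
            #   rot1 = q0*sin + q1*cos
            # i.e. (q0 + i*q1) * (cos + i*sin) viewed as a complex product.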
for i in cutlass.range_constexpr(self.coarsen):
for j in cutlass.range_constexpr(8):
q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j])
rot0 = q0 * cos_vals[j] - q1 * sin_vals[j]
rot1 = q0 * sin_vals[j] + q1 * cos_vals[j]
# convert back to BF16 to match numerics
q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1)
# layout: [coarsen, 8]
q_fp4_tile = cute.local_tile(
q_fp4[token_id, None, None],
tiler=(self.coarsen, 8),
coord=(head_tile_id, sublane),
)
for i in cutlass.range_constexpr(self.coarsen):
# compute amax in packed bf16x2 to save instructions
# Each thread holds 16 elems. Two adjacent threads form one 32-elem
# MXFP4 block, so a width-2 shuffle gives the block amax.
amax_bf16x2 = _bf16x2_abs(q_bf16x2[i, 0])
for j in cutlass.range_constexpr(1, 8):
amax_bf16x2 = _bf16x2_max(amax_bf16x2, _bf16x2_abs(q_bf16x2[i, j]))
amax_bf16x2 = cute_utils.warp_reduce(
amax_bf16x2,
_bf16x2_max,
width=MXFP4_BLOCK_SIZE // 16,
)
amax_pair = _bf16x2_to_fp32(amax_bf16x2)
amax = cute_utils.fmax(amax_pair[0], amax_pair[1])
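            # e.g. MXFP4_BLOCK_SIZE = 32: width = 32 // 16 = 2, so lane pairs
            # (0,1), (2,3), ... exchange packed maxima and both end up holding
            # the amax of their shared 32-elem block.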
if in_bounds:
                # compute the block scale with bit manipulation:
                # UE8M0 stores ceil(log2(fp4_scale)) + 127. adding the mantissa
                # mask increments the exponent whenever fp4_scale is not an
                # exact power of 2.
eps = cutlass.const_expr(float.fromhex("0x6p-126"))
fp4_scale = cute_utils.fmax(amax, eps) * Float32(1.0 / 6.0)
bits = _recast_val(fp4_scale, Uint32)
ue8m0 = cute_utils.shr_u32(
bits + Uint32(0x7FFFFF), Uint32(23)
) & Uint32(0xFF)
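                # e.g. fp4_scale = 3.0 (0x40400000): the non-zero mantissa plus
                # 0x7FFFFF carries into the exponent, so ue8m0 = 129 =
                # ceil(log2(3)) + 127; an exact 4.0 also yields 129 with no bump.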
# Only one of the two threads in an MXFP4 block writes the shared scale.
if tid % 2 == 0:
mx_block = sublane // 2
q_scale[token_id, head_start + i, mx_block] = Uint8(ue8m0)
# If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent
# -A + 127 = 254 - ue8m0.
inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23)
inv_fp4_scale = _recast_val(inv_scale_bits, Float32)
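                # e.g. ue8m0 = 129 (scale 2^2): inv_scale_bits = (254 - 129) << 23
                # is biased exponent 125, so inv_fp4_scale = 2^-2 = 0.25.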
                vals = cute.make_rmem_tensor((16,), Float32)
for j in cutlass.range_constexpr(8):
q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j])
vals[j * 2] = q0 * inv_fp4_scale
vals[j * 2 + 1] = q1 * inv_fp4_scale
# pack to FP4
packed = cute.make_rmem_tensor((2,), Uint32)
packed[0] = _fp32x8_to_fp4x8(vals, 0)
packed[1] = _fp32x8_to_fp4x8(vals, 8)
dst = q_fp4_tile[i, None]
cp_u32x2 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=64)
cute.copy(cp_u32x2, packed, cute.recast_tensor(dst, Uint32))
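                # 16 FP4 values pack into 8 bytes = 2 x u32, stored with a single
                # 64-bit copy into this lane's slice of the packed q_fp4 row.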
# Weight scaling is independent of the Q subwarp work. The first
# num_tokens * num_heads logical threads cover one weight each.
if global_tid < num_token_heads:
weight_token_id = global_tid // self.num_heads
weight_head_id = global_tid % self.num_heads
weights_out[weight_token_id, weight_head_id] = (
weights[weight_token_id, weight_head_id].to(Float32) * scale
)
    @staticmethod
    @cache
def compile(
head_dim: int = 128,
rope_dim: int = 64,
num_heads: int = 64,
cos_sin_dtype: type[cutlass.Numeric] = Float32,
coarsen: int = 4,
):
num_tokens = cute.sym_int()
max_pos = cute.sym_int()
q = make_fake_tensor(
BFloat16, (num_tokens, num_heads, head_dim), divisibility=16
)
positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1)
cos_sin_cache = make_fake_tensor(
cos_sin_dtype,
(max_pos, rope_dim),
divisibility=8,
)
weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8)
q_fp4 = make_fake_tensor(
Uint8,
(num_tokens, num_heads, head_dim // 2),
divisibility=16,
)
q_scale = make_fake_tensor(
Uint8,
(num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE),
divisibility=4,
)
weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4)
kernel = IndexerQMxFp4Kernel(
head_dim, rope_dim, num_heads, cos_sin_dtype, coarsen
)
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
return cute.compile(
kernel,
positions,
q,
cos_sin_cache,
weights,
q_fp4,
q_scale,
weights_out,
Float32(0.0),
stream,
options="--enable-tvm-ffi",
)
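
# Usage sketch (hypothetical call site; names like `softmax_scale` and the
# exact tensor-wrapping layer are assumptions, not part of this file):
#
#   compiled = IndexerQMxFp4Kernel.compile()  # defaults: 128/64/64, coarsen=4
#   compiled(positions, q, cos_sin_cache, weights, q_fp4, q_scale,
#            weights_out, Float32(softmax_scale), stream)
#
# `compile` is cached per configuration; only num_tokens and max_pos are
# symbolic, so repeated calls with the same config reuse the compiled kernel.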