"""W4A16 weight-only quantized GEMM for RTX PRO 6000 (SM120). AWQ/GPTQ-style asymmetric int4 with bf16 per-group scales/zeros. Fused unpack + GEMM via Triton. The scheme: w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n] out[m, n] = sum_k x[m, k] * w_bf[k, n] Key optimizations: - Two-dot pattern: process even/odd K rows with two separate dots. Each group_size=128 K rows become two 64-wide dots that both read the same x row span but different nibbles of the packed weight byte. This matches the bit-level rounding of the reference's cuBLAS bf16 GEMM. - Triton autotune over a focused set of configs to find the best per shape. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "gemm_w4a16" SUPPORTED_PRECISIONS = ["int4_bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] GROUP_SIZE = 128 # --------------------------------------------------------------------------- # Triton kernel # --------------------------------------------------------------------------- _CONFIGS = [ # M=1 / decode variants — small BN, BLOCK_M=1 triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2), 
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 1024, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 1024, "GROUP_SZ": 128}, num_warps=4, num_stages=2), # M>=16 / prefill (BLOCK_M=16 is tensor-core minimum) triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2), ] 
def _prune_configs_for_k(configs, named_args, **kwargs):
    """Keep only autotune configs whose BLOCK_K evenly divides K.

    ``w4a16_gemm_kernel`` walks K in steps of BLOCK_K and masks its loads only
    along M and N. A config whose BLOCK_K does not divide K would therefore
    read past the end of ``w_q``, ``scales``/``zeros`` and ``x`` on the last
    iteration.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Kernel arguments by name; must contain ``"K"``.
        **kwargs: Extra keyword arguments the autotuner may pass; ignored.

    Returns:
        The subset of ``configs`` with ``K % BLOCK_K == 0``.

    Raises:
        ValueError: If no config's BLOCK_K divides K.
    """
    k_dim = named_args["K"]
    fitting = [cfg for cfg in configs if k_dim % cfg.kwargs["BLOCK_K"] == 0]
    if not fitting:
        smallest = min(cfg.kwargs["BLOCK_K"] for cfg in configs)
        raise ValueError(
            f"K={k_dim} is not a multiple of any configured BLOCK_K "
            f"(smallest is {smallest})"
        )
    return fitting


@triton.autotune(
    configs=_CONFIGS,
    key=["M", "N", "K"],
    # Without this, autotuning benchmarks configs that read out of bounds
    # whenever K is a multiple of the group size but not of BLOCK_K.
    prune_configs_by={"early_config_prune": _prune_configs_for_k},
)
@triton.jit
def w4a16_gemm_kernel(
    X, WQ, S, Z, OUT,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sg, stride_sn,
    stride_zg, stride_zn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SZ: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of OUT = X @ dequant(WQ, S, Z).

    Packed row r of WQ holds two K rows: the low nibble is K row 2*r and the
    high nibble is K row 2*r + 1. For each K block, both nibble planes are
    dequantized with their group's scale/zero (in bf16). Each plane is then
    contracted with the matching even or odd columns of X by its own
    ``tl.dot``, and the results accumulate in fp32.

    K is not masked. The autotuner prune hook guarantees K % BLOCK_K == 0,
    and the static assert below guarantees BLOCK_K % GROUP_SZ == 0.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Each K block must contain a whole number of quantization groups.
    tl.static_assert(BLOCK_K % GROUP_SZ == 0, "BLOCK_K must be a multiple of GROUP_SZ")

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_mask = offs_m < M
    mask_n = offs_n[None, :] < N  # loop-invariant column mask

    n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
    GROUP_HALF: tl.constexpr = GROUP_SZ // 2  # packed rows per group
    BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2  # packed rows per K block

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_blk in tl.range(0, K, BLOCK_K):
        # Packed weight rows covering K rows [k_blk, k_blk + BLOCK_K).
        wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
        wq = tl.load(
            WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
            mask=mask_n,
            other=0,
        )
        w_lo = (wq & 0xF).to(tl.bfloat16)  # low nibble  -> even K rows
        w_hi = (wq >> 4).to(tl.bfloat16)   # high nibble -> odd K rows

        # One scale/zero row per group in this K block. Each row is repeated
        # over the group's GROUP_SZ // 2 packed rows to line up with w_lo/w_hi.
        g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
        s = tl.load(
            S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
            mask=mask_n,
            other=0.0,
        )
        z = tl.load(
            Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
            mask=mask_n,
            other=0.0,
        )
        s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
        s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
        z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
        z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))

        w_lo = (w_lo - z) * s
        w_hi = (w_hi - z) * s

        # Even/odd activation columns matching the low/high nibble planes.
        x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
        x_offs_odd = x_offs_even + 1
        x_even = tl.load(
            X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
            mask=offs_m_mask[:, None],
            other=0.0,
        )
        x_odd = tl.load(
            X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
            mask=offs_m_mask[:, None],
            other=0.0,
        )

        acc += tl.dot(x_even, w_lo)
        acc += tl.dot(x_odd, w_hi)

    out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_m_mask[:, None] & mask_n)


# ---------------------------------------------------------------------------
# Module
# ---------------------------------------------------------------------------

class Model(nn.Module):
    """W4A16 GEMM: ``forward(x)`` returns ``x @ dequant(w_q, scales, zeros)``.

    Buffers (allocated zero-filled on ``"cuda"``; presumably populated by a
    weight loader / ``load_state_dict`` before use — confirm with caller):
        w_q:    (K // 2, N) uint8, two 4-bit weights per byte.
        scales: (K // group_size, N) bfloat16.
        zeros:  (K // group_size, N) bfloat16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Allocate buffers for an (M, K) x (K, N) product.

        Raises:
            ValueError: If ``group_size`` differs from ``GROUP_SIZE``. The
                kernel takes its group size from the autotune configs
                (``_CONFIGS``), not from this argument, so any other value
                would silently apply scales/zeros to the wrong rows.
            AssertionError: If K is not divisible by ``group_size`` or by 2.
        """
        super().__init__()
        if group_size != GROUP_SIZE:
            raise ValueError(
                f"group_size={group_size} is not supported: the kernel's "
                f"autotune configs are built for GROUP_SIZE={GROUP_SIZE}"
            )
        assert K % group_size == 0
        assert K % 2 == 0
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        w_q = torch.zeros(K // 2, N, dtype=torch.uint8, device="cuda")
        scales = torch.zeros(n_groups, N, dtype=torch.bfloat16, device="cuda")
        zeros = torch.zeros(n_groups, N, dtype=torch.bfloat16, device="cuda")
        self.register_buffer("w_q", w_q)
        self.register_buffer("scales", scales)
        self.register_buffer("zeros", zeros)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (M, N) bfloat16 product of ``x`` and the dequantized weights.

        Args:
            x: (M, K) bfloat16 activations. It is made contiguous before the
                launch. It presumably must sit on the same CUDA device as the
                buffers, since it is passed straight to the Triton kernel.
        """
        M, N, K = self.M, self.N, self.K
        assert x.shape == (M, K), f"x shape mismatch: {x.shape} vs ({M},{K})"
        assert x.dtype == torch.bfloat16
        x = x.contiguous()
        out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

        def grid(meta):
            # One program per (BLOCK_M, BLOCK_N) output tile.
            return (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        w4a16_gemm_kernel[grid](
            x, self.w_q, self.scales, self.zeros, out,
            M, N, K,
            x.stride(0), x.stride(1),
            self.w_q.stride(0), self.w_q.stride(1),
            self.scales.stride(0), self.scales.stride(1),
            self.zeros.stride(0), self.zeros.stride(1),
            out.stride(0), out.stride(1),
        )
        return out


M = 1
N = 12288
K = 4096


def get_inputs():
    """Return sample ``forward`` inputs: a single (M, K) bfloat16 tensor.

    NOTE(review): ``x`` is created on torch's default device, but ``forward``
    launches a Triton kernel. Presumably the harness moves it to CUDA first;
    confirm.
    """
    x = torch.randn(M, K, dtype=torch.bfloat16)
    return [x]


def get_init_inputs():
    """Return the positional ``Model`` constructor arguments ``[M, N, K]``."""
    return [M, N, K]