"""W4A16 weight-only int4 quantized GEMM (AWQ/GPTQ-style asymmetric) for SM120.

Fused unpack + dequant + GEMM. Two paths, dispatched on M:

  * M == 1 (decode, bandwidth-bound): a hand-written CUDA GEMV (load_inline).
    It splits the K dimension at packed-row granularity (finer than the
    128-wide group) so it reaches full occupancy AND keeps VEC-wide coalesced
    loads -- the group-level split-K of a tensor-core GEMM is capped at
    NGROUPS and cannot fill the machine for M=1. The 64-row inner loop is
    fully unrolled for memory-level parallelism (the access is latency-bound,
    not raw-bandwidth-bound). Two variants: an intra-block reduction (one
    launch, no memset) and an atomic split-K (more live blocks for small N).
    Reaches ~71% of DRAM peak.

  * every other M (prefill / spec-decode; configs are tuned for M >= 16): a
    Triton tensor-core GEMM with split-K. The even/odd nibbles become two
    K-halves of the reduction; the weight tile is dequantized to bf16 then
    fed to tl.dot -- matching the reference's bf16 rounding (an
    affine-correction reformulation is faster but diverges from the reference
    at cancellation points under large-activation stress).

Scheme:
  x:      (M, K)       bf16
  w_q:    (K//2, N)    uint8 -- two int4 packed/byte (low=even-K, high=odd-K)
  scales: (K//128, N)  bf16
  zeros:  (K//128, N)  bf16
  w_bf[k,n] = (unpack(w_q)[k,n] - zeros[k//128,n]) * scales[k//128,n]

The fp32 accumulator is returned directly (correctness compares in fp32; the
roofline byte count is fixed at M*N*2 regardless), removing a bf16-convert
pass.

Configs are hand-tuned per shape using L2-COLD timing (the benchmark flushes
L2 before each call, which Triton's own autotuner does not emulate).
"""
from __future__ import annotations

import warnings

import torch
import torch.nn as nn
import triton
import triton.language as tl

# Quantization group width along K. Both kernels hard-code 128 (and 64 packed
# rows per group), so this is the only group size they actually support.
GROUP_SIZE = 128

# ---------------------------------------------------------------------------
# CUDA GEMV for M == 1 (decode). Splits the K dimension at packed-row
# granularity (finer than the 128-group) so we reach full occupancy AND keep
# VEC-wide coalesced vectorized loads -- group-level split-K is capped at
# NGROUPS and cannot fill the machine for M=1. The 64-row inner loop is fully
# unrolled (template ROWS) for memory-level parallelism, which is what actually
# limits this kernel (the access is latency-bound, not bandwidth-bound).
# ---------------------------------------------------------------------------
_CPP_SRC = r"""#include <torch/extension.h>

torch::Tensor gemv_w4a16(torch::Tensor x, torch::Tensor wq,
                         torch::Tensor scales, torch::Tensor zeros,
                         int64_t threads, int64_t vec, int64_t rows);
torch::Tensor gemv_w4a16_ib(torch::Tensor x, torch::Tensor wq,
                            torch::Tensor scales, torch::Tensor zeros,
                            int64_t bn, int64_t vec, int64_t kparts);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemv_w4a16", &gemv_w4a16, "W4A16 GEMV");
    m.def("gemv_w4a16_ib", &gemv_w4a16_ib, "W4A16 GEMV intra-block");
}
"""

# NOTE(review): the header names, template parameter lists, cast target types
# and <<<...>>> launch configurations below were reconstructed from how the
# code uses them (the text between angle brackets had been lost). Confirm
# against the original file if it is available.
_CUDA_SRC = r"""#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Shape checks shared by both host wrappers. The kernels index x and wq with
// hard-coded 128-wide groups and VEC-wide vector loads, so anything else would
// read out of bounds or from a misaligned address.
static void check_gemv_inputs(const torch::Tensor& x, const torch::Tensor& wq,
                              int64_t vec) {
    TORCH_CHECK(x.dim() == 2 && wq.dim() == 2, "x and wq must be 2-D");
    TORCH_CHECK(x.size(1) % 128 == 0, "K must be a multiple of 128");
    TORCH_CHECK(wq.size(0) * 2 == x.size(1), "wq must have K/2 rows");
    TORCH_CHECK(wq.size(1) % vec == 0, "N must be a multiple of vec");
}

// Atomic split-K GEMV: blockIdx.y selects (group, ROWS-row slice of the
// group); every block atomically adds its partial sums into a zeroed output.
template <int VEC, int ROWS>
__global__ void gemv_w4a16_kernel(
    const __nv_bfloat16* __restrict__ x,
    const uint8_t* __restrict__ wq,
    const __nv_bfloat16* __restrict__ scales,
    const __nv_bfloat16* __restrict__ zeros,
    float* __restrict__ out,
    int N, int K, int NGROUPS)
{
    int col0 = (blockIdx.x * blockDim.x + threadIdx.x) * VEC;
    if (col0 >= N) return;
    const int subs = 64 / ROWS;
    int g = blockIdx.y / subs;
    int sub = blockIdx.y % subs;
    int r_start = sub * ROWS;
    int kbase = g * 128;

    // raw[c] accumulates sum_k x[k] * q[k, c]; sx accumulates sum_k x[k] so
    // the zero-point can be applied once per group after the loop.
    float raw[VEC];
    #pragma unroll
    for (int c = 0; c < VEC; c++) raw[c] = 0.f;
    float sx = 0.f;

    #pragma unroll
    for (int rr = 0; rr < ROWS; rr++) {
        int r = r_start + rr;
        const uint8_t* wptr = wq + (size_t)(g * 64 + r) * N + col0;
        float xe = __bfloat162float(x[kbase + 2 * r]);
        float xo = __bfloat162float(x[kbase + 2 * r + 1]);
        sx += xe + xo;
        if (VEC == 4) {
            uint32_t v = *reinterpret_cast<const uint32_t*>(wptr);
            #pragma unroll
            for (int c = 0; c < 4; c++) {
                uint8_t b = (v >> (8 * c)) & 0xFF;
                raw[c] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
            }
        } else if (VEC == 8) {
            uint2 v = *reinterpret_cast<const uint2*>(wptr);
            uint32_t w0 = v.x, w1 = v.y;
            #pragma unroll
            for (int c = 0; c < 4; c++) {
                uint8_t b = (w0 >> (8 * c)) & 0xFF;
                raw[c] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
            }
            #pragma unroll
            for (int c = 0; c < 4; c++) {
                uint8_t b = (w1 >> (8 * c)) & 0xFF;
                raw[c + 4] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
            }
        } else if (VEC == 2) {
            uint16_t v = *reinterpret_cast<const uint16_t*>(wptr);
            #pragma unroll
            for (int c = 0; c < 2; c++) {
                uint8_t b = (v >> (8 * c)) & 0xFF;
                raw[c] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
            }
        } else {
            uint8_t b = wptr[0];
            raw[0] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
        }
    }

    #pragma unroll
    for (int c = 0; c < VEC; c++) {
        if (col0 + c < N) {
            float s = __bfloat162float(scales[(size_t)g * N + col0 + c]);
            float z = __bfloat162float(zeros[(size_t)g * N + col0 + c]);
            // sum_k x*(q - z)*s == s*raw - z*s*sx
            atomicAdd(&out[col0 + c], s * raw[c] - z * s * sx);
        }
    }
}

// Intra-block split-K: KPARTS partitions of the group dimension reduce inside
// the block via shared memory, then thread-group 0 directly stores the result.
// One launch, no atomics, no pre-zeroed output -> avoids the ~6us memset.
template <int VEC, int KPARTS>
__global__ void gemv_w4a16_ib_kernel(
    const __nv_bfloat16* __restrict__ x,
    const uint8_t* __restrict__ wq,
    const __nv_bfloat16* __restrict__ scales,
    const __nv_bfloat16* __restrict__ zeros,
    float* __restrict__ out,
    int N, int K, int NGROUPS, int COLS)
{
    extern __shared__ float red[];  // [KPARTS][COLS*VEC]
    int cslot = threadIdx.x % COLS;
    int kpart = threadIdx.x / COLS;
    int col0 = (blockIdx.x * COLS + cslot) * VEC;

    float acc[VEC];
    #pragma unroll
    for (int c = 0; c < VEC; c++) acc[c] = 0.f;

    if (col0 < N) {
        for (int g = kpart; g < NGROUPS; g += KPARTS) {
            int kbase = g * 128;
            float raw[VEC];
            #pragma unroll
            for (int c = 0; c < VEC; c++) raw[c] = 0.f;
            float sx = 0.f;
            #pragma unroll
            for (int r = 0; r < 64; r++) {
                const uint8_t* wptr = wq + (size_t)(g * 64 + r) * N + col0;
                float xe = __bfloat162float(x[kbase + 2 * r]);
                float xo = __bfloat162float(x[kbase + 2 * r + 1]);
                sx += xe + xo;
                if (VEC == 4) {
                    uint32_t v = *reinterpret_cast<const uint32_t*>(wptr);
                    #pragma unroll
                    for (int c = 0; c < 4; c++) {
                        uint8_t b = (v >> (8 * c)) & 0xFF;
                        raw[c] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
                    }
                } else if (VEC == 2) {
                    uint16_t v = *reinterpret_cast<const uint16_t*>(wptr);
                    #pragma unroll
                    for (int c = 0; c < 2; c++) {
                        uint8_t b = (v >> (8 * c)) & 0xFF;
                        raw[c] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
                    }
                } else {
                    uint8_t b = wptr[0];
                    raw[0] += xe * (float)(b & 0xF) + xo * (float)(b >> 4);
                }
            }
            #pragma unroll
            for (int c = 0; c < VEC; c++) {
                float s = __bfloat162float(scales[(size_t)g * N + col0 + c]);
                float z = __bfloat162float(zeros[(size_t)g * N + col0 + c]);
                acc[c] += s * raw[c] - z * s * sx;
            }
        }
    }

    // reduce across kparts (every thread reaches the barrier, including the
    // ones whose columns are out of range)
    #pragma unroll
    for (int c = 0; c < VEC; c++)
        red[kpart * (COLS * VEC) + cslot * VEC + c] = acc[c];
    __syncthreads();

    if (kpart == 0 && col0 < N) {
        #pragma unroll
        for (int c = 0; c < VEC; c++) {
            float sum = 0.f;
            for (int p = 0; p < KPARTS; p++)
                sum += red[p * (COLS * VEC) + cslot * VEC + c];
            if (col0 + c < N) out[col0 + c] = sum;
        }
    }
}

torch::Tensor gemv_w4a16_ib(torch::Tensor x, torch::Tensor wq,
                            torch::Tensor scales, torch::Tensor zeros,
                            int64_t bn, int64_t vec, int64_t kparts) {
    // Only these template instantiations exist below; anything else would be
    // launched with a mismatched VEC/KPARTS and silently produce garbage.
    TORCH_CHECK(vec == 1 || vec == 2 || vec == 4,
                "gemv_w4a16_ib: vec must be 1, 2 or 4");
    TORCH_CHECK(kparts == 8 || kparts == 16 || kparts == 32,
                "gemv_w4a16_ib: kparts must be 8, 16 or 32");
    TORCH_CHECK(bn > 0 && bn % vec == 0,
                "gemv_w4a16_ib: bn must be a positive multiple of vec");
    check_gemv_inputs(x, wq, vec);

    int N = wq.size(1);
    int K = x.size(1);
    int NGROUPS = K / 128;
    auto out = torch::empty({1, N}, torch::dtype(torch::kFloat32).device(x.device()));
    int COLS = bn / vec;
    dim3 grid((N + bn - 1) / bn);
    dim3 block(COLS * kparts);
    size_t shmem = (size_t)kparts * COLS * vec * sizeof(float);
    const __nv_bfloat16* xp = reinterpret_cast<const __nv_bfloat16*>(x.data_ptr());
    const uint8_t* wp = wq.data_ptr<uint8_t>();
    const __nv_bfloat16* sp = reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr());
    const __nv_bfloat16* zp = reinterpret_cast<const __nv_bfloat16*>(zeros.data_ptr());
    float* op = out.data_ptr<float>();
    auto stream = c10::cuda::getCurrentCUDAStream();

#define LAUNCH_IB(VEC, KP) \
    gemv_w4a16_ib_kernel<VEC, KP><<<grid, block, shmem, stream>>>(xp, wp, sp, zp, op, N, K, NGROUPS, COLS)

    if (vec == 4) {
        if (kparts == 8) LAUNCH_IB(4, 8);
        else if (kparts == 16) LAUNCH_IB(4, 16);
        else LAUNCH_IB(4, 32);
    } else if (vec == 2) {
        if (kparts == 8) LAUNCH_IB(2, 8);
        else if (kparts == 16) LAUNCH_IB(2, 16);
        else LAUNCH_IB(2, 32);
    } else {
        if (kparts == 8) LAUNCH_IB(1, 8);
        else if (kparts == 16) LAUNCH_IB(1, 16);
        else LAUNCH_IB(1, 32);
    }
    // `out` is uninitialized memory: a failed launch must not go unnoticed.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}

#define LAUNCH(VEC, ROWS) \
    gemv_w4a16_kernel<VEC, ROWS><<<grid, block, 0, stream>>>(xp, wp, sp, zp, op, N, K, NGROUPS)

torch::Tensor gemv_w4a16(torch::Tensor x, torch::Tensor wq,
                         torch::Tensor scales, torch::Tensor zeros,
                         int64_t threads, int64_t vec, int64_t rows) {
    // Only these template instantiations exist below; the grid is sized from
    // `rows`, so an unsupported value would mismatch the kernel's ROWS.
    TORCH_CHECK(vec == 1 || vec == 2 || vec == 4 || vec == 8,
                "gemv_w4a16: vec must be 1, 2, 4 or 8");
    TORCH_CHECK(rows == 4 || rows == 8 || rows == 16 || rows == 32 || rows == 64,
                "gemv_w4a16: rows must be 4, 8, 16, 32 or 64");
    TORCH_CHECK(threads > 0, "gemv_w4a16: threads must be positive");
    check_gemv_inputs(x, wq, vec);

    int N = wq.size(1);
    int K = x.size(1);
    int NGROUPS = K / 128;
    // Zeroed because every block accumulates into it with atomicAdd.
    auto out = torch::zeros({1, N}, torch::dtype(torch::kFloat32).device(x.device()));
    int cols_per_block = threads * vec;
    int subs = 64 / rows;
    dim3 grid((N + cols_per_block - 1) / cols_per_block, NGROUPS * subs);
    dim3 block(threads);
    const __nv_bfloat16* xp = reinterpret_cast<const __nv_bfloat16*>(x.data_ptr());
    const uint8_t* wp = wq.data_ptr<uint8_t>();
    const __nv_bfloat16* sp = reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr());
    const __nv_bfloat16* zp = reinterpret_cast<const __nv_bfloat16*>(zeros.data_ptr());
    float* op = out.data_ptr<float>();
    auto stream = c10::cuda::getCurrentCUDAStream();

    if (vec == 4) {
        if (rows == 64) LAUNCH(4, 64);
        else if (rows == 32) LAUNCH(4, 32);
        else if (rows == 16) LAUNCH(4, 16);
        else if (rows == 8) LAUNCH(4, 8);
        else LAUNCH(4, 4);
    } else if (vec == 8) {
        if (rows == 64) LAUNCH(8, 64);
        else if (rows == 32) LAUNCH(8, 32);
        else if (rows == 16) LAUNCH(8, 16);
        else if (rows == 8) LAUNCH(8, 8);
        else LAUNCH(8, 4);
    } else if (vec == 2) {
        if (rows == 64) LAUNCH(2, 64);
        else if (rows == 32) LAUNCH(2, 32);
        else if (rows == 16) LAUNCH(2, 16);
        else if (rows == 8) LAUNCH(2, 8);
        else LAUNCH(2, 4);
    } else {
        if (rows == 64) LAUNCH(1, 64);
        else if (rows == 32) LAUNCH(1, 32);
        else if (rows == 16) LAUNCH(1, 16);
        else if (rows == 8) LAUNCH(1, 8);
        else LAUNCH(1, 4);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}
"""

# Lazily-built extension module; populated on first use by _get_gemv_mod().
_GEMV_MOD = None


def _get_gemv_mod():
    """Build (once) and return the inline CUDA GEMV extension module.

    Raises whatever ``load_inline`` raises when the build fails; the result
    is only cached on success, so a failed build is retried on the next call.
    """
    global _GEMV_MOD
    if _GEMV_MOD is None:
        # Imported lazily so that merely importing this file does not pull in
        # the C++ extension build machinery.
        from torch.utils.cpp_extension import load_inline

        _GEMV_MOD = load_inline(
            name="w4a16_gemv_ext",
            cpp_sources=_CPP_SRC,
            cuda_sources=_CUDA_SRC,
            extra_cuda_cflags=[
                "-O3",
                "--use_fast_math",
                "-gencode",
                "arch=compute_120,code=sm_120",
            ],
        )
    return _GEMV_MOD


# (N,K) -> (method, a, b, c). method "ib": intra-block split-K (1 launch, no
# memset) used where it ties/beats the atomic kernel; "atomic": grid-split-K
# (needs a zeroed buffer) used for small N where it keeps more blocks alive.
_GEMV_CFG = {
    (12288, 4096): ("ib", 128, 4, 32),  # bn, vec, kparts
    (4096, 4096): ("atomic", 64, 2, 32),  # threads, vec, rows
}
_GEMV_DEFAULT = ("ib", 128, 4, 32)


def _gemv(x, w_q, scales, zeros, N, K, cfg=None):
    """Run the M == 1 CUDA GEMV and return a (1, N) fp32 tensor.

    Args:
        x: (1, K) bf16 activations.
        w_q: (K//2, N) uint8 packed int4 weights.
        scales: per-group bf16 scales.
        zeros: per-group bf16 zero points.
        N, K: problem size, used only to look up a tuned config.
        cfg: optional ``(method, a, b, c)`` override; see ``_GEMV_CFG`` for
            the meaning of a/b/c under each method.

    Raises:
        Any exception from building the extension or from the argument
        checks inside the extension's host wrappers.
    """
    if cfg is None:
        cfg = _GEMV_CFG.get((N, K), _GEMV_DEFAULT)
    method, a, b, c = cfg
    mod = _get_gemv_mod()
    if method == "ib":
        return mod.gemv_w4a16_ib(x, w_q, scales, zeros, a, b, c)  # bn, vec, kparts
    return mod.gemv_w4a16(x, w_q, scales, zeros, a, b, c)  # threads, vec, rows


@triton.jit
def _gemm_kernel(
    x_ptr, wq_ptr, s_ptr, z_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sg, stride_zg,
    stride_om, stride_on,
    NGROUPS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SPLIT_K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # One program per (N tile, M tile, K split). Only the M dimension is
    # masked: the launcher must guarantee N % BLOCK_N == 0 and 64 % BLOCK_R == 0.
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_k = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    m_mask = offs_m < M
    rr = tl.arange(0, BLOCK_R)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Each K split walks the quantization groups with stride SPLIT_K.
    for g in range(pid_k, NGROUPS, SPLIT_K):
        s = tl.load(s_ptr + g * stride_sg + offs_n).to(tl.float32)
        z = tl.load(z_ptr + g * stride_zg + offs_n).to(tl.float32)
        kbase = g * 128
        # A group is 64 packed rows (128 K values), consumed BLOCK_R rows at a time.
        for ri in tl.static_range(0, 64, BLOCK_R):
            rows = g * 64 + ri + rr
            P = tl.load(wq_ptr + rows[:, None] * stride_wk + offs_n[None, :] * stride_wn)
            lo = (P & 0xF).to(tl.float32)
            hi = ((P >> 4) & 0xF).to(tl.float32)
            # Low nibble pairs with even K, high nibble with odd K.
            offs_ke = kbase + 2 * (ri + rr)
            offs_ko = offs_ke + 1
            xe = tl.load(
                x_ptr + offs_m[:, None] * stride_xm + offs_ke[None, :] * stride_xk,
                mask=m_mask[:, None],
                other=0.0,
            )
            xo = tl.load(
                x_ptr + offs_m[:, None] * stride_xm + offs_ko[None, :] * stride_xk,
                mask=m_mask[:, None],
                other=0.0,
            )
            # Dequant-then-dot in bf16 to match the reference's rounding.
            w_lo = ((lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
            w_hi = ((hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
            acc += tl.dot(xe, w_lo) + tl.dot(xo, w_hi)

    out_off = offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    if SPLIT_K == 1:
        tl.store(out_ptr + out_off, acc, mask=m_mask[:, None])
    else:
        # Several K splits contribute to the same tile; output must be zeroed.
        tl.atomic_add(out_ptr + out_off, acc, mask=m_mask[:, None])


# (BLOCK_M, BLOCK_N, SPLIT_K, num_warps, num_stages, BLOCK_R) keyed by (M, N, K).
# Defaults chosen for the benchmark's L2-cold regime; see sweep_configs.py.
_DEFAULTS = {
    1: (16, 64, 16, 4, 3, 64),
    16: (16, 128, 8, 8, 4, 64),
    32: (16, 128, 8, 4, 3, 64),
    256: (64, 256, 4, 8, 3, 64),
}
_CONFIG: dict[tuple[int, int, int], tuple] = {
    (32, 12288, 4096): (16, 128, 8, 8, 3, 32),
    (256, 12288, 4096): (64, 256, 4, 4, 3, 16),
    (16, 14336, 4096): (16, 128, 8, 8, 4, 64),
}


def _config_for(M, N, K):
    """Return the Triton launch config for a problem size.

    An exact ``(M, N, K)`` entry in ``_CONFIG`` wins; otherwise the default
    for the smallest M bucket (1, 16, 32, 256) that covers ``M`` is used.
    """
    c = _CONFIG.get((M, N, K))
    if c is not None:
        return c
    if M <= 1:
        return _DEFAULTS[1]
    if M <= 16:
        return _DEFAULTS[16]
    if M <= 32:
        return _DEFAULTS[32]
    return _DEFAULTS[256]


def _gemm(x, w_q, scales, zeros, M, N, K, cfg=None):
    """Run the Triton split-K GEMM and return an (M, N) fp32 tensor.

    Args:
        x: (M, K) bf16 activations.
        w_q: (K//2, N) uint8 packed int4 weights.
        scales: per-group bf16 scales.
        zeros: per-group bf16 zero points.
        M, N, K: problem size.
        cfg: optional ``(BLOCK_M, BLOCK_N, SPLIT_K, num_warps, num_stages,
            BLOCK_R)`` override.

    Raises:
        ValueError: if the shape/config combination would make the kernel
            (which masks only along M) access memory out of bounds.
    """
    if cfg is None:
        cfg = _config_for(M, N, K)
    BLOCK_M, BLOCK_N, SPLIT_K, num_warps, num_stages, BLOCK_R = cfg

    # The kernel has no mask along N or K and hard-codes 64 packed rows per
    # 128-wide group, so reject anything that would read/write out of bounds.
    packed_rows = GROUP_SIZE // 2
    if K % GROUP_SIZE != 0:
        raise ValueError(f"K={K} must be a multiple of {GROUP_SIZE}")
    if N % BLOCK_N != 0:
        raise ValueError(f"N={N} must be a multiple of BLOCK_N={BLOCK_N}")
    if packed_rows % BLOCK_R != 0:
        raise ValueError(f"BLOCK_R={BLOCK_R} must divide {packed_rows}")

    ngroups = K // GROUP_SIZE
    # Zero-initialized because SPLIT_K > 1 accumulates with atomic adds.
    out = torch.zeros((M, N), dtype=torch.float32, device=x.device)
    grid = (triton.cdiv(N, BLOCK_N), triton.cdiv(M, BLOCK_M), SPLIT_K)
    _gemm_kernel[grid](
        x, w_q, scales, zeros, out,
        M, N, K,
        x.stride(0), x.stride(1),
        w_q.stride(0), w_q.stride(1),
        scales.stride(0), zeros.stride(0),
        out.stride(0), out.stride(1),
        NGROUPS=ngroups,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        SPLIT_K=SPLIT_K,
        num_warps=num_warps,
        num_stages=num_stages,
        BLOCK_R=BLOCK_R,
    )
    return out


class Model(nn.Module):
    """W4A16 linear layer: ``x @ dequant(w_q)`` returned as fp32.

    Buffers (all zero-initialized here; real values are expected to be loaded
    by the caller):
        w_q:    (K//2, N) uint8, two int4 values per byte.
        scales: (K//group_size, N) bf16.
        zeros:  (K//group_size, N) bf16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        super().__init__()
        assert K % group_size == 0 and K % 2 == 0
        self.M, self.N, self.K = M, N, K
        # NOTE(review): the kernels hard-code a group size of 128; a different
        # group_size only changes the buffer shapes allocated below.
        self.group_size = group_size
        n_groups = K // group_size
        self.register_buffer("w_q", torch.zeros((K // 2, N), dtype=torch.uint8))
        self.register_buffer("scales", torch.zeros((n_groups, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.zeros((n_groups, N), dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the quantized matmul for ``x`` of shape (M, K).

        M == 1 uses the CUDA GEMV; if that fails to build or run, a
        ``RuntimeWarning`` is emitted and the Triton GEMM is used instead.

        Raises:
            ValueError: if ``x``'s K does not match the packed weights.
        """
        M, K = x.shape
        # Both kernels index w_q from K alone; a mismatch would read past it.
        if K != 2 * self.w_q.shape[0]:
            raise ValueError(
                f"x has K={K} but w_q packs K={2 * self.w_q.shape[0]}"
            )
        if M == 1:
            try:
                return _gemv(x, self.w_q, self.scales, self.zeros, self.N, K)
            except Exception as exc:
                # robustness: if the CUDA extension fails to build/run, fall
                # back to the Triton GEMM path (handles M=1 via BLOCK_M=16),
                # but say so -- the fallback is much slower for decode.
                warnings.warn(
                    f"W4A16 CUDA GEMV unavailable, using Triton GEMM: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return _gemm(x, self.w_q, self.scales, self.zeros, M, self.N, K)


# Default benchmark problem size.
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return the forward() inputs for the default problem size."""
    x = torch.randn(M, K, dtype=torch.bfloat16)
    return [x]


def get_init_inputs():
    """Return the Model constructor arguments for the default problem size."""
    return [M, N, K]