"""W4A16 weight-only quantized GEMM (AWQ/GPTQ-style asymmetric int4).

Fused unpack + dequant + GEMM.  Same Model/get_inputs/get_init_inputs
interface as reference.py.

Scheme:
    x:      (M, K)       bf16
    w_q:    (K//2, N)    uint8   low nibble = even-K row, high nibble = odd-K row
    scales: (K//128, N)  bf16
    zeros:  (K//128, N)  bf16
    out:    (M, N)       bf16

    w_bf[k,n] = (unpack(w_q)[k,n] - zeros[k//128,n]) * scales[k//128,n]
    out       = x @ w_bf

Kernel paths:
  - M == 1, CUDA (used when the inline extension builds): one warp per
    output column over a *transposed* weight (N, K//2), so each output's
    weight vector is contiguous in memory.  Single kernel, no split-K.
  - M == 1, Triton fallback: split-K GEMV over the same transposed weight.
    Split-K partials are kept in fp32 and combined by a fused reduce/cast
    kernel; the GEMV kernel also has a direct bf16 store mode for the
    unsplit case.
  - M > 1: tl.dot GEMM with fused dequant (original (K//2, N) layout).
"""
from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl

GROUP_SIZE = 128
_NUM_SMS = 188  # RTX PRO 6000 Blackwell

# ---------------------------------------------------------------------------
# CUDA decode GEMV (M==1): one warp per output, fully-parallel K-reduction
# (32 lanes each do K/32 of the dot product, then warp shuffle-reduce),
# vectorized 16-byte loads with next-iteration prefetch.  Single kernel,
# no split-K reduction overhead.
#
# NOTE(review): in the reviewed copy every angle-bracketed token of this
# string (the #include targets, the reinterpret_cast<...> arguments and the
# <<<grid, block>>> launch configuration) was missing, so the source could
# not compile.  They are restored below from the surrounding code; confirm
# the header list and launch configuration against the original source.
# ---------------------------------------------------------------------------
_CUDA_GEMM_SRC = r'''
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>

// One warp per output. Each lane reads 16 bytes (uint4) per K-iteration,
// covering 32 K-elements; the warp covers 512 bytes = 1024 K = 8 groups/iter.
// d1 = 1-deep prefetch (high-occupancy large-N), d2 = 2-deep (low-occupancy small-N).
#define GEMV_BODY_PREFETCH1 \
    int niters = KH >> 9; \
    uint4 wn = *reinterpret_cast<const uint4*>(wrow + 16 * lane); \
    _Pragma("unroll 4") \
    for (int it = 0; it < niters; it++) { \
        uint4 wv = wn; \
        int b = it * 512; \
        if (it + 1 < niters) wn = *reinterpret_cast<const uint4*>(wrow + b + 512 + 16 * lane); \
        int g = b / 64 + (lane >> 2); \
        float s = __bfloat162float(sc[(size_t)g * N + n]); \
        float zf = __bfloat162float(zz[(size_t)g * N + n]); \
        int k0 = 2 * b + 32 * lane; \
        __nv_bfloat16 xb[32]; \
        *reinterpret_cast<uint4*>(xb) = *reinterpret_cast<const uint4*>(x + k0); \
        *reinterpret_cast<uint4*>(xb + 8) = *reinterpret_cast<const uint4*>(x + k0 + 8); \
        *reinterpret_cast<uint4*>(xb + 16) = *reinterpret_cast<const uint4*>(x + k0 + 16); \
        *reinterpret_cast<uint4*>(xb + 24) = *reinterpret_cast<const uint4*>(x + k0 + 24); \
        const uint32_t* wp = reinterpret_cast<const uint32_t*>(&wv); \
        _Pragma("unroll") \
        for (int q = 0; q < 4; q++) { uint32_t p = wp[q]; int base = 8 * q; \
            _Pragma("unroll") \
            for (int j = 0; j < 4; j++) { unsigned int bv = (p >> (8 * j)) & 0xFFu; \
                float xe = __bfloat162float(xb[base + 2 * j]); float xo = __bfloat162float(xb[base + 2 * j + 1]); \
                float wl = ((float)(bv & 0xFu) - zf) * s; float wh = ((float)((bv >> 4) & 0xFu) - zf) * s; \
                acc += xe * wl + xo * wh; } } }

#define GEMV_BODY_PREFETCH2 \
    int niters = KH >> 9; \
    uint4 w0 = *reinterpret_cast<const uint4*>(wrow + 0 + 16 * lane); \
    uint4 w1 = (niters > 1) ? *reinterpret_cast<const uint4*>(wrow + 512 + 16 * lane) : w0; \
    _Pragma("unroll 4") \
    for (int it = 0; it < niters; it++) { \
        uint4 wv = w0; w0 = w1; \
        if (it + 2 < niters) w1 = *reinterpret_cast<const uint4*>(wrow + (it + 2) * 512 + 16 * lane); \
        int b = it * 512; \
        int g = b / 64 + (lane >> 2); \
        float s = __bfloat162float(sc[(size_t)g * N + n]); \
        float zf = __bfloat162float(zz[(size_t)g * N + n]); \
        int k0 = 2 * b + 32 * lane; \
        __nv_bfloat16 xb[32]; \
        *reinterpret_cast<uint4*>(xb) = *reinterpret_cast<const uint4*>(x + k0); \
        *reinterpret_cast<uint4*>(xb + 8) = *reinterpret_cast<const uint4*>(x + k0 + 8); \
        *reinterpret_cast<uint4*>(xb + 16) = *reinterpret_cast<const uint4*>(x + k0 + 16); \
        *reinterpret_cast<uint4*>(xb + 24) = *reinterpret_cast<const uint4*>(x + k0 + 24); \
        const uint32_t* wp = reinterpret_cast<const uint32_t*>(&wv); \
        _Pragma("unroll") \
        for (int q = 0; q < 4; q++) { uint32_t p = wp[q]; int base = 8 * q; \
            _Pragma("unroll") \
            for (int j = 0; j < 4; j++) { unsigned int bv = (p >> (8 * j)) & 0xFFu; \
                float xe = __bfloat162float(xb[base + 2 * j]); float xo = __bfloat162float(xb[base + 2 * j + 1]); \
                float wl = ((float)(bv & 0xFu) - zf) * s; float wh = ((float)((bv >> 4) & 0xFu) - zf) * s; \
                acc += xe * wl + xo * wh; } } }

__global__ void w4a16_gemv_kernel_d1(
    const __nv_bfloat16* __restrict__ x,
    const uint8_t* __restrict__ wqt,
    const __nv_bfloat16* __restrict__ sc,
    const __nv_bfloat16* __restrict__ zz,
    __nv_bfloat16* __restrict__ y,
    int N, int K, int KH)
{
    int wpb = blockDim.x >> 5;
    int warp = threadIdx.x >> 5;
    int lane = threadIdx.x & 31;
    int n = blockIdx.x * wpb + warp;
    if (n >= N) return;
    const uint8_t* wrow = wqt + (size_t)n * KH;
    float acc = 0.f;
    GEMV_BODY_PREFETCH1
    #pragma unroll
    for (int off = 16; off > 0; off >>= 1) acc += __shfl_xor_sync(0xffffffff, acc, off);
    if (lane == 0) y[n] = __float2bfloat16(acc);
}

__global__ void w4a16_gemv_kernel_d2(
    const __nv_bfloat16* __restrict__ x,
    const uint8_t* __restrict__ wqt,
    const __nv_bfloat16* __restrict__ sc,
    const __nv_bfloat16* __restrict__ zz,
    __nv_bfloat16* __restrict__ y,
    int N, int K, int KH)
{
    int wpb = blockDim.x >> 5;
    int warp = threadIdx.x >> 5;
    int lane = threadIdx.x & 31;
    int n = blockIdx.x * wpb + warp;
    if (n >= N) return;
    const uint8_t* wrow = wqt + (size_t)n * KH;
    float acc = 0.f;
    GEMV_BODY_PREFETCH2
    #pragma unroll
    for (int off = 16; off > 0; off >>= 1) acc += __shfl_xor_sync(0xffffffff, acc, off);
    if (lane == 0) y[n] = __float2bfloat16(acc);
}

void launch_w4a16_gemv(torch::Tensor x, torch::Tensor wqt, torch::Tensor sc, torch::Tensor zz,
                       torch::Tensor out, int wpb, int depth) {
    int N = out.size(1);
    int K = x.size(1);
    int KH = K / 2;
    // The kernels consume 512 packed bytes (1024 K-elements) per warp
    // iteration and issue the first prefetch load unconditionally, so any
    // other K would be truncated (KH >> 9) or read out of bounds.
    TORCH_CHECK(K > 0 && K % 1024 == 0,
                "launch_w4a16_gemv: K must be a positive multiple of 1024, got ", K);
    // Rows are addressed as base + n * KH and base + g * N with unit stride.
    TORCH_CHECK(x.is_contiguous() && wqt.is_contiguous() && sc.is_contiguous()
                    && zz.is_contiguous() && out.is_contiguous(),
                "launch_w4a16_gemv: all tensors must be contiguous");
    if (N <= 0) return;  // nothing to compute; a zero-sized grid is not launchable
    int block = wpb * 32;
    int grid_n = (N + wpb - 1) / wpb;
    if (depth <= 1)
        w4a16_gemv_kernel_d1<<<grid_n, block>>>(
            reinterpret_cast<const __nv_bfloat16*>(x.data_ptr()),
            reinterpret_cast<const uint8_t*>(wqt.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(sc.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(zz.data_ptr()),
            reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
            N, K, KH);
    else
        w4a16_gemv_kernel_d2<<<grid_n, block>>>(
            reinterpret_cast<const __nv_bfloat16*>(x.data_ptr()),
            reinterpret_cast<const uint8_t*>(wqt.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(sc.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(zz.data_ptr()),
            reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
            N, K, KH);
}
'''

_CUDA_CPP = ("void launch_w4a16_gemv(torch::Tensor x, torch::Tensor wqt, "
             "torch::Tensor sc, torch::Tensor zz, torch::Tensor out, int wpb, int depth);")


def _try_compile_cuda():
    """Build the inline CUDA GEMV extension, or return None if that fails.

    The build is best-effort: any failure (no compiler, unsupported
    architecture, source error) must not prevent the module from importing,
    because callers fall back to the Triton GEMV when this returns None.
    The failure is logged instead of being dropped so that an unexpected
    fallback can be diagnosed.

    Returns:
        The loaded extension module exposing ``launch_w4a16_gemv``, or None.
    """
    try:
        from torch.utils.cpp_extension import load_inline

        return load_inline(
            name="w4a16_gemv_sm120",
            cpp_sources=_CUDA_CPP,
            cuda_sources=_CUDA_GEMM_SRC,
            functions=["launch_w4a16_gemv"],
            verbose=False,
        )
    except Exception as exc:  # deliberate catch-all: any build failure selects the fallback
        import logging

        log = logging.getLogger(__name__)
        log.warning(
            "w4a16 CUDA GEMV extension unavailable (%s); using the Triton fallback",
            type(exc).__name__,
        )
        # Compiler output can be very long; keep it at debug level.
        log.debug("w4a16 CUDA GEMV build failure details", exc_info=True)
        return None


_CUDA_MOD = _try_compile_cuda()


# ---------------------------------------------------------------------------
# Decode path: M == 1, split-K GEMV on transposed weight (N, K//2).
# ---------------------------------------------------------------------------
@triton.jit
def _w4a16_gemv_kernel(
    x_ptr, wqt_ptr, s_ptr, z_ptr, out_ptr,
    N, KH,  # KH = K//2
    stride_sn, stride_sg, stride_out,
    SKB: tl.constexpr,  # split size in bytes (multiple of 64)
    BK: tl.constexpr, BN: tl.constexpr, GPT: tl.constexpr,  # BK = 64*GPT
    STORE_DIRECT: tl.constexpr,
):
    """Split-K int4 GEMV: program (pid_n, pid_k) handles BN outputs x SKB packed bytes.

    Each program dequantizes its slice of the transposed weight and
    accumulates the dot product with x in fp32.  With STORE_DIRECT the
    result is written as bf16 to ``out_ptr + n``; otherwise the fp32
    partial goes to row ``pid_k`` of ``out_ptr`` (row stride ``stride_out``).
    Only the N dimension is masked: the K range [pid_k*SKB, (pid_k+1)*SKB)
    is read unmasked, so the launcher must make it lie inside the row.
    """
    pid_n = tl.program_id(0)
    pid_k = tl.program_id(1)

    offs_n = pid_n * BN + tl.arange(0, BN)
    mn = offs_n < N
    kb_start = pid_k * SKB

    acc = tl.zeros((BN,), dtype=tl.float32)
    niters = SKB // BK
    for it in range(0, niters):
        kb = kb_start + it * BK
        gbase = kb // 64  # 64 packed bytes == one 128-element quant group
        off_kb = kb + tl.arange(0, BK)
        w = tl.load(wqt_ptr + offs_n[:, None] * KH + off_kb[None, :],
                    mask=mn[:, None], other=0)  # (BN, BK) uint8
        wlo = (w & 0xF).to(tl.bfloat16)
        whi = ((w >> 4) & 0xF).to(tl.bfloat16)

        # apply per-group scale/zero: (BN, GPT, 64) <- broadcast (BN, GPT)
        wlo = tl.reshape(wlo, (BN, GPT, 64))
        whi = tl.reshape(whi, (BN, GPT, 64))
        gg = gbase + tl.arange(0, GPT)
        s = tl.load(s_ptr + gg[:, None] * stride_sg + offs_n[None, :] * stride_sn,
                    mask=mn[None, :], other=0.0).to(tl.bfloat16)  # (GPT, BN)
        z = tl.load(z_ptr + gg[:, None] * stride_sg + offs_n[None, :] * stride_sn,
                    mask=mn[None, :], other=0.0).to(tl.bfloat16)
        s = tl.permute(s, (1, 0))  # (BN, GPT)
        z = tl.permute(z, (1, 0))
        wlo = (wlo - z[:, :, None]) * s[:, :, None]
        whi = (whi - z[:, :, None]) * s[:, :, None]
        wlo = tl.reshape(wlo, (BN, BK))
        whi = tl.reshape(whi, (BN, BK))

        # Byte kb holds K-elements 2*kb (low nibble) and 2*kb+1 (high nibble),
        # so de-interleave x into its even and odd elements.
        xk = (2 * kb) + tl.arange(0, 2 * BK)
        xf = tl.load(x_ptr + xk)  # (2*BK,) bf16
        xr = tl.reshape(xf, (BK, 2))
        xe, xo = tl.split(xr)  # (BK,) each

        acc += tl.sum(xe[None, :].to(tl.float32) * wlo.to(tl.float32), axis=1)
        acc += tl.sum(xo[None, :].to(tl.float32) * whi.to(tl.float32), axis=1)

    if STORE_DIRECT:
        tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=mn)
    else:
        tl.store(out_ptr + pid_k * stride_out + offs_n, acc, mask=mn)


@triton.jit
def _reduce_cast_kernel(p_ptr, y_ptr, N, NSPLIT, stride_pk, BN: tl.constexpr):
    """Sum NSPLIT fp32 partial rows of ``p_ptr`` and store the result as bf16."""
    pid = tl.program_id(0)
    offs = pid * BN + tl.arange(0, BN)
    mask = offs < N
    acc = tl.zeros((BN,), dtype=tl.float32)
    for i in range(0, NSPLIT):
        acc += tl.load(p_ptr + i * stride_pk + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, acc.to(tl.bfloat16), mask=mask)


def _gemv_plan(N, K):
    """Choose tile sizes and the split-K factor for the Triton GEMV.

    Args:
        N: number of output columns (must be positive).
        K: reduction length (must be a positive multiple of 128, the
            quantization group size hard-coded in the GEMV kernel).

    Returns:
        dict with keys ``BN``, ``GPT``, ``BK``, ``nw``, ``ns``, ``nspl``,
        ``SKB`` and ``nnt``.  ``nspl * SKB == K // 2`` and ``SKB % BK == 0``
        always hold, so the kernel covers every packed byte exactly once.

    Raises:
        ValueError: if N or K is outside the supported range.
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    if K <= 0 or K % 128 != 0:
        raise ValueError(f"K must be a positive multiple of 128, got {K}")

    KH = K // 2
    BN = 32
    n_groups = KH // 64
    nnt = triton.cdiv(N, BN)
    # The two-groups-per-tile config needs whole 128-byte tiles; with an odd
    # number of groups it would silently skip the last one.
    if N >= 8192 and KH % 128 == 0:
        GPT, BK, nw, ns = 2, 128, 4, 3
    else:
        GPT, BK, nw, ns = 1, 64, 4, 4
    # split-K to ~6x SMs total blocks (caps occupancy-driven stalls).
    nspl = max(1, (6 * _NUM_SMS) // nnt)
    nspl = min(nspl, n_groups)
    # A split is usable only if it divides KH exactly AND each slice is a
    # whole number of BK tiles; otherwise tail bytes would never be read.
    while nspl > 1 and (KH % nspl != 0 or (KH // nspl) % BK != 0):
        nspl -= 1
    SKB = KH // nspl
    return dict(BN=BN, GPT=GPT, BK=BK, nw=nw, ns=ns, nspl=nspl, SKB=SKB, nnt=nnt)


def _gemv(x, wqt, scales, zeros, N, K, plan, partial):
    """Run the Triton GEMV for a single input row and return a (1, N) bf16 tensor.

    Args:
        x: input row; read as 2*KH consecutive bf16 values from its data pointer.
        wqt: packed weight, indexed by the kernel as row n at ``n * (K // 2)`` bytes.
        scales, zeros: per-group parameters, indexed through their own strides.
        N, K: problem sizes.
        plan: dict produced by ``_gemv_plan(N, K)``.
        partial: fp32 scratch whose rows receive the split-K partials; it is
            only written when ``plan["nspl"] > 1``.

    When the plan does not split K there is nothing to reduce, so the kernel
    stores bf16 straight into the output and the reduce kernel is skipped.
    """
    KH = K // 2
    y = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
    nspl = plan["nspl"]
    direct = nspl == 1
    target = y if direct else partial

    _w4a16_gemv_kernel[(plan["nnt"], nspl)](
        x, wqt, scales, zeros, target,
        N, KH,
        scales.stride(1), scales.stride(0), target.stride(0),
        SKB=plan["SKB"], BK=plan["BK"], BN=plan["BN"], GPT=plan["GPT"],
        STORE_DIRECT=direct,
        num_warps=plan["nw"], num_stages=plan["ns"],
    )
    if not direct:
        _reduce_cast_kernel[(triton.cdiv(N, plan["BN"]),)](
            partial, y, N, nspl, partial.stride(0),
            BN=plan["BN"], num_warps=4, num_stages=1,
        )
    return y


# ---------------------------------------------------------------------------
# Prefill path: M > 1, tl.dot GEMM with fused dequant (original layout).
# ---------------------------------------------------------------------------
@triton.jit
def _w4a16_gemm_kernel(
    x_ptr, wq_ptr, s_ptr, z_ptr, y_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wpk, stride_wn,
    stride_sg, stride_sn,
    stride_ym, stride_yn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, GROUP: tl.constexpr,
):
    """Tiled GEMM with fused int4 dequant; one program per (BM, BN) output tile.

    M and N are masked; K is consumed in ``K // BK`` unmasked chunks, so K
    is expected to be a multiple of BK.  One scale/zero row is loaded per
    chunk (group index ``k0 // GROUP``).
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)
    mask_m = offs_m < M
    mask_n = offs_n < N

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    n_kc = K // BK  # K-chunks of size BK (BK divides GROUP, so a quant
    #                 group spans GROUP/BK chunks; scale indexed by k0//GROUP)
    BK_HALF: tl.constexpr = BK // 2

    for kc in range(0, n_kc):
        k0 = kc * BK
        g = k0 // GROUP

        x_ptrs = x_ptr + offs_m[:, None] * stride_xm + (k0 + offs_k)[None, :] * stride_xk
        x = tl.load(x_ptrs, mask=mask_m[:, None], other=0.0)
        # Split x into even-K / odd-K columns to match the low / high nibbles.
        x_r = tl.reshape(x, (BM, BK_HALF, 2))
        x_lo, x_hi = tl.split(x_r)

        pk_offs = (k0 // 2) + tl.arange(0, BK_HALF)
        w_ptrs = wq_ptr + pk_offs[:, None] * stride_wpk + offs_n[None, :] * stride_wn
        w_packed = tl.load(w_ptrs, mask=mask_n[None, :], other=0)
        w_lo = (w_packed & 0xF).to(tl.bfloat16)
        w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)

        s = tl.load(s_ptr + g * stride_sg + offs_n * stride_sn,
                    mask=mask_n, other=0.0).to(tl.bfloat16)
        z = tl.load(z_ptr + g * stride_sg + offs_n * stride_sn,
                    mask=mask_n, other=0.0).to(tl.bfloat16)
        w_lo = (w_lo - z[None, :]) * s[None, :]
        w_hi = (w_hi - z[None, :]) * s[None, :]

        acc = tl.dot(x_lo, w_lo, acc=acc, allow_tf32=False)
        acc = tl.dot(x_hi, w_hi, acc=acc, allow_tf32=False)

    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    tl.store(y_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])


def _gemm(x, wq, scales, zeros, M, N, K, y=None):
    """Launch the fused-dequant GEMM and return the (M, N) bf16 result.

    Args:
        x: (M, K) activations.
        wq: packed weight in the original (K//2, N) layout.
        scales, zeros: per-group parameters.
        M, N, K: problem sizes.
        y: optional preallocated output with at least M rows and N columns;
            a fresh tensor is allocated when omitted.
    """
    if y is None:
        y = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

    # Only the M tile and pipeline depth vary with the batch size.
    if M >= 128:
        BM, BN, nw, ns, BK = 64, 128, 8, 3, 128
    elif M >= 32:
        BM, BN, nw, ns, BK = 32, 128, 8, 5, 128
    else:
        BM, BN, nw, ns, BK = 16, 128, 8, 5, 128

    grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
    _w4a16_gemm_kernel[grid](
        x, wq, scales, zeros, y,
        M, N, K,
        x.stride(0), x.stride(1),
        wq.stride(0), wq.stride(1),
        scales.stride(0), scales.stride(1),
        y.stride(0), y.stride(1),
        BM=BM, BN=BN, BK=BK, GROUP=GROUP_SIZE,
        num_warps=nw, num_stages=ns,
    )
    return y


class Model(nn.Module):
    """W4A16 linear layer: ``forward(x)`` returns ``x @ dequant(w_q)`` in bf16.

    Buffers:
        w_q:    (K//2, N) uint8 packed int4 weights.
        scales: (K//group_size, N) bf16.
        zeros:  (K//group_size, N) bf16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """
        Args:
            M: expected batch size (forward follows the actual input rows).
            N: output features.
            K: input features; must be a positive multiple of ``group_size``.
            group_size: quantization group size.  The kernels hard-code
                GROUP_SIZE, so any other value is rejected rather than
                producing silently wrong results.

        Raises:
            ValueError: on an unsupported ``group_size`` or ``K``.
        """
        super().__init__()
        if group_size != GROUP_SIZE:
            raise ValueError(
                f"group_size must be {GROUP_SIZE} (hard-coded in the kernels), got {group_size}")
        if K <= 0 or K % group_size != 0:
            raise ValueError(f"K must be a positive multiple of {group_size}, got {K}")

        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
        self.register_buffer("scales", torch.empty((K // group_size, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.empty((K // group_size, N), dtype=torch.bfloat16))

        # Lazily built decode-path state (see _ensure_setup).
        self._wqt = None
        self._wqt_key = None
        self._partial = None
        self._plan = None
        self._yout = None
        self._gemm_out = None
        self._wpb = 4
        self._depth = 1 if N >= 8192 else 2

    def _ensure_setup(self):
        """Build or refresh the cached state used by the M == 1 path."""
        w_q = self.w_q
        dev = w_q.device
        # Rebuild the transposed copy if the weights were moved or written
        # in place (e.g. load_state_dict) since it was made.
        key = (w_q.data_ptr(), w_q._version, dev)
        if self._wqt is None or self._wqt_key != key:
            self._wqt = w_q.t().contiguous()
            self._wqt_key = key
        if self._plan is None:
            self._plan = _gemv_plan(self.N, self.K)
            self._wpb = 4
            self._depth = 1 if self.N >= 8192 else 2
        if self._partial is None or self._partial.device != dev:
            self._partial = torch.empty(
                (self._plan["nspl"], self.N), dtype=torch.float32, device=dev)
        if self._yout is None or self._yout.device != dev:
            self._yout = torch.empty((1, self.N), dtype=torch.bfloat16, device=dev)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ dequant(w_q)``.

        Args:
            x: (M, K) tensor; converted to bf16 if needed.

        Returns:
            (M, N) bf16 tensor.  On the CUDA decode path and the M > 1 path
            the result lives in a buffer owned by the module that the next
            call overwrites; clone it if it must outlive that call.

        Raises:
            ValueError: if ``x`` is not 2-D with K columns.
        """
        if x.dim() != 2 or x.shape[1] != self.K:
            raise ValueError(f"expected x of shape (M, {self.K}), got {tuple(x.shape)}")
        if x.dtype != torch.bfloat16:
            x = x.to(torch.bfloat16)

        M = x.shape[0]
        if M == 0:
            return torch.empty((0, self.N), dtype=torch.bfloat16, device=x.device)

        if M == 1:
            # Both GEMV kernels read x as a flat contiguous row.
            x = x.contiguous()
            self._ensure_setup()
            # The CUDA kernel consumes 1024 K-elements per warp iteration and
            # has no tail handling, so other K values use the Triton path.
            if _CUDA_MOD is not None and self.K % 1024 == 0:
                _CUDA_MOD.launch_w4a16_gemv(
                    x, self._wqt, self.scales, self.zeros,
                    self._yout, self._wpb, self._depth)
                return self._yout
            return _gemv(x, self._wqt, self.scales, self.zeros,
                         self.N, self.K, self._plan, self._partial)

        # Size the cached output from the actual batch, not the constructor's
        # M: a larger batch would otherwise be written past the buffer end.
        out = self._gemm_out
        if out is None or out.shape[0] != M or out.device != x.device:
            out = torch.empty((M, self.N), dtype=torch.bfloat16, device=x.device)
            self._gemm_out = out
        return _gemm(x, self.w_q, self.scales, self.zeros, M, self.N, self.K, out)


M = 1
N = 12288
K = 4096


def get_inputs():
    """Return the benchmark input list: one random (M, K) bf16 CUDA tensor."""
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    return [x]


def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    return [M, N, K]