"""W4A16 weight-only quantized GEMM — fused unpack + dequant + matmul via Triton. AWQ/GPTQ-style asymmetric int4 scheme with explicit zero-points and per-group (group_size=128) bf16 scales. The Triton kernel loads packed uint8 weights, unpacks nibbles, applies the per-group affine dequant, and accumulates with bf16 activations in a single fused pass — no materialised bf16 weight matrix. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "gemm_w4a16" SUPPORTED_PRECISIONS = ["int4_bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] GROUP_SIZE = 128 # --------------------------------------------------------------------------- # Triton kernel (fused unpack + dequant + gemm) # --------------------------------------------------------------------------- @triton.jit def _gemm_w4a16_kernel( x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr, M, N, K, stride_xm, stride_xk, stride_wq_kh, stride_wq_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_om, stride_on, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ): """Fused W4A16 GEMM — unpack + dequant + matmul in one pass. BLOCK_K must divide GROUP_SIZE so every K tile is inside one quant group. 
""" pid = tl.program_id(0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_in_group = GROUP_M * num_pid_n group_id_m = pid // num_pid_in_group first_pid_m = group_id_m * GROUP_M group_size_m = min(num_pid_m - first_pid_m, GROUP_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) mask_m = offs_m < M mask_n = offs_n < N x_base = x_ptr + offs_m[:, None] * stride_xm out_base = out_ptr + offs_m[:, None] * stride_om acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) k_even_base = tl.arange(0, BLOCK_K // 2) * 2 k_odd_base = k_even_base + 1 kh_base = tl.arange(0, BLOCK_K // 2) for k_start in range(0, K, BLOCK_K): # --- packed weights for this K tile -------------------------------- k_half_offs = (k_start // 2) + kh_base wq_ptrs = w_q_ptr + k_half_offs[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n mask_wq = (k_half_offs[:, None] < K // 2) & mask_n[None, :] w_packed = tl.load(wq_ptrs, mask=mask_wq, other=0) w_lo = w_packed & 0x0F w_hi = (w_packed >> 4) & 0x0F # --- per-group scales & zeros -------------------------------------- gid = k_start // GROUP_SIZE s_g = tl.load(scales_ptr + gid * stride_s_g + offs_n * stride_s_n, mask=mask_n, other=0.0) z_g = tl.load(zeros_ptr + gid * stride_z_g + offs_n * stride_z_n, mask=mask_n, other=0.0) # --- dequant: w_bf16 = (w_int4 - zero) * scale -------------------- # Use f32 for subtraction (uint8→f32 is cheap) then bf16 for multiplies # to save registers. 
z_g_f32 = z_g.to(tl.float32) w_lo_bf = (w_lo.to(tl.float32) - z_g_f32[None, :]).to(tl.bfloat16) * s_g[None, :] w_hi_bf = (w_hi.to(tl.float32) - z_g_f32[None, :]).to(tl.bfloat16) * s_g[None, :] # --- activations (even / odd K rows) ------------------------------- k_even = k_start + k_even_base k_odd = k_start + k_odd_base x_even_ptrs = x_base + k_even[None, :] * stride_xk x_odd_ptrs = x_base + k_odd[None, :] * stride_xk x_even = tl.load(x_even_ptrs, mask=mask_m[:, None] & (k_even[None, :] < K), other=0.0) x_odd = tl.load(x_odd_ptrs, mask=mask_m[:, None] & (k_odd[None, :] < K), other=0.0) # --- accumulate ---------------------------------------------------- acc += tl.dot(x_even, w_lo_bf) acc += tl.dot(x_odd, w_hi_bf) # --- store ------------------------------------------------------------ out_ptrs = out_base + offs_n[None, :] * stride_on tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :]) # --------------------------------------------------------------------------- # Heuristic config selection — tuned empirically on RTX PRO 6000 Blackwell. 
# ---------------------------------------------------------------------------
def _pick_config(M: int, N: int, K: int, group_size: int = GROUP_SIZE):
    """Return (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) for this shape.

    Args:
        M: Number of activation rows.
        N: Number of output columns.
        K: Reduction dimension (currently unused by the heuristic).
        group_size: Quantisation group size along K. BLOCK_K is chosen so that
            it divides ``group_size``; the kernel reads a single scale/zero
            row per K tile (``gid = k_start // GROUP_SIZE``), so a tile that
            straddled two groups would be dequantised incorrectly.

    Raises:
        ValueError: If ``group_size`` is odd (or non-positive), in which case
            no even BLOCK_K can divide it and group boundaries would fall in
            the middle of a packed byte.
    """
    # Largest power-of-two divisor of group_size, capped at 128. For the
    # default group_size=128 this is 128 — one group per K-tile is simplest.
    BK = min(128, group_size & -group_size)
    if BK < 2:
        raise ValueError(
            f"group_size must be a positive even number, got {group_size}"
        )
    # NOTE(review): tl.dot may require a larger minimum reduction depth than
    # BLOCK_K // 2 provides for very small group sizes — confirm against the
    # Triton version in use.

    if M <= 4:
        # Includes M == 1 (decode): memory-bound on weight read.
        return (4, 64, BK, 8)
    if M <= 8:
        return (8, 64, BK, 8)
    if M <= 16:
        # spec-decode-ish (M=16) or small prefill
        return (8, 64 if N >= 12288 else 128, BK, 8)
    if M <= 32:
        return (8, 128, BK, 8)
    return (16, 128, BK, 8)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
    """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros).

    Buffers registered by ``__init__``:
      * ``w_q``    — uint8, shape (K // 2, N); low nibble = even-K row,
                     high nibble = odd-K row.
      * ``scales`` — bf16, shape (K // group_size, N).
      * ``zeros``  — bf16, shape (K // group_size, N); integer zero-points
                     in [0, 15].
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Build deterministic synthetic int4 weights for an (M, K) x (K, N) GEMM.

        Note: seeds the global torch RNG as a side effect.
        """
        super().__init__()
        assert K % group_size == 0, "K must be divisible by group_size"
        assert K % 2 == 0, "K must be even (int4 packing)"
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size

        # Deterministic synthetic quant — identical to reference.Model
        torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
        w_full = torch.randn(K, N, dtype=torch.float32) * 0.02

        # Asymmetric per-group quantisation: map [w_min, w_max] onto [0, 15].
        w_g = w_full.view(n_groups, group_size, N)
        w_min = w_g.min(dim=1, keepdim=True).values
        w_max = w_g.max(dim=1, keepdim=True).values
        scales = (w_max - w_min).clamp_min(1e-8) / 15.0
        zeros = (-w_min / scales).round().clamp(0, 15)
        w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
        w_q = w_q.view(K, N)

        scales_2d = scales.squeeze(1).to(torch.bfloat16)
        zeros_2d = zeros.squeeze(1).to(torch.bfloat16)

        # Pack int4: low nibble = even-K, high nibble = odd-K
        lo = w_q[0::2].to(torch.uint8) & 0xF
        hi = w_q[1::2].to(torch.uint8) & 0xF
        w_packed = (lo | (hi << 4)).contiguous()

        self.register_buffer("w_q", w_packed)
        self.register_buffer("scales", scales_2d)
        self.register_buffer("zeros", zeros_2d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ dequant(w_q)`` with the fused Triton kernel.

        Args:
            x: 2-D activations of shape (M_out, K).
               NOTE(review): the kernel feeds x straight into tl.dot against
               bf16 weights, so x is presumably expected to be bf16 — confirm.

        Returns:
            bf16 tensor of shape (M_out, N) on ``x``'s device.
        """
        M_out, K_in = x.shape
        assert K_in == self.K, f"K mismatch: {K_in} vs {self.K}"
        device = x.device

        out = torch.empty(M_out, self.N, dtype=torch.bfloat16, device=device)

        x = x.contiguous()
        w_q = self.w_q.contiguous()
        scales = self.scales.contiguous()
        zeros = self.zeros.contiguous()

        # Pass the group size so BLOCK_K never spans more than one quant group
        # (a hard-coded 128 is only correct when group_size is a multiple of 128).
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M = _pick_config(
            M_out, self.N, self.K, self.group_size
        )

        # 1-D launch: one program per (BLOCK_M, BLOCK_N) output tile.
        grid = (triton.cdiv(M_out, BLOCK_M) * triton.cdiv(self.N, BLOCK_N),)

        _gemm_w4a16_kernel[grid](
            x, w_q, scales, zeros, out,
            M_out, self.N, self.K,
            x.stride(0), x.stride(1),
            w_q.stride(0), w_q.stride(1),
            scales.stride(0), scales.stride(1),
            zeros.stride(0), zeros.stride(1),
            out.stride(0), out.stride(1),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            GROUP_SIZE=self.group_size,
            GROUP_M=GROUP_M,
            num_warps=4,
        )
        return out


# ---------------------------------------------------------------------------
# Module-level helpers (shimmed by check.py / benchmark.py before each call)
# ---------------------------------------------------------------------------
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return the forward() inputs for the current module-level M and K."""
    x = torch.randn(M, K, dtype=torch.bfloat16)
    return [x]


def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    return [M, N, K]