"""Fused W4A16 weight-only quantized GEMM (AWQ/GPTQ-style asymmetric int4).

Kernels fuse int4 unpack, per-group scale/zero dequant, and bf16 GEMM in one
pass so the weight stream stays at 0.5 B/elem.
"""
from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "gemm_w4a16"
SUPPORTED_PRECISIONS = ["int4_bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

# Default number of consecutive K elements that share one scale/zero pair.
GROUP_SIZE = 128


@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_q_ptr, s_ptr, z_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wq_k, stride_wq_n,
    stride_s_g, stride_s_n,
    stride_z_g, stride_z_n,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
):
    """Compute out = x @ dequant(w_q, scales, zeros) for one output tile.

    Grid: (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)). Each program walks K in tiles
    of BLOCK_K elements; every tile lies inside a single quantization group,
    so one scale/zero row is loaded per tile.

    Packed weights: byte at (k//2, n) holds even-k nibble in low bits and
    odd-k nibble in high bits. Weights are dequantized as (q - zero) * scale,
    products are accumulated in fp32, and the result is stored as bf16.

    Preconditions (not checked at runtime):
      * K is a multiple of GROUP_SIZE; K indices are never masked, so a tail
        shorter than BLOCK_K would be dropped.
      * x is expected to be bf16 so both tl.dot operands share a dtype
        (the dequantized weights are cast to bf16 here).
    """
    # The even/odd split below needs whole bytes, and a tile must not straddle
    # two quantization groups or it would be dequantized with the wrong
    # scale/zero (and, for BLOCK_K > GROUP_SIZE, read past the end of K).
    tl.static_assert(BLOCK_K % 2 == 0, "BLOCK_K must be even: two int4 values share a byte")
    tl.static_assert(GROUP_SIZE % BLOCK_K == 0, "BLOCK_K must evenly divide GROUP_SIZE")

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Split the K tile into even/odd halves for the packed layout.
    offs_k_half = tl.arange(0, BLOCK_K // 2)
    offs_k_even = 2 * offs_k_half        # 0, 2, 4, ...
    offs_k_odd = 2 * offs_k_half + 1     # 1, 3, 5, ...

    mask_m = offs_m < M
    mask_n = offs_n < N

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    n_tiles = K // BLOCK_K
    for t in tl.range(0, n_tiles):
        k0 = t * BLOCK_K
        # Quantization group that owns this tile (equals t when
        # BLOCK_K == GROUP_SIZE).
        g = k0 // GROUP_SIZE

        # x even/odd k slices: (BLOCK_M, BLOCK_K//2)
        x_even_ptrs = (
            x_ptr
            + (offs_m[:, None] * stride_xm)
            + ((k0 + offs_k_even)[None, :] * stride_xk)
        )
        x_odd_ptrs = (
            x_ptr
            + (offs_m[:, None] * stride_xm)
            + ((k0 + offs_k_odd)[None, :] * stride_xk)
        )
        x_even = tl.load(x_even_ptrs, mask=mask_m[:, None], other=0.0)
        x_odd = tl.load(x_odd_ptrs, mask=mask_m[:, None], other=0.0)

        # packed weights: (BLOCK_K//2, BLOCK_N); packed row r holds k = 2r
        # (low nibble) and k = 2r + 1 (high nibble).
        wq_ptrs = (
            w_q_ptr
            + ((k0 // 2 + offs_k_half[:, None]) * stride_wq_k)
            + (offs_n[None, :] * stride_wq_n)
        )
        wq_tile = tl.load(wq_ptrs, mask=mask_n[None, :], other=0)
        w_lo = (wq_tile & 0xF).to(tl.bfloat16)
        w_hi = ((wq_tile >> 4) & 0xF).to(tl.bfloat16)

        # Per-group scale and zero: (BLOCK_N,)
        s = tl.load(s_ptr + g * stride_s_g + offs_n * stride_s_n, mask=mask_n, other=0.0)
        z = tl.load(z_ptr + g * stride_z_g + offs_n * stride_z_n, mask=mask_n, other=0.0)

        # Dequant and accumulate. Even-k activations pair with low nibbles,
        # odd-k activations with high nibbles; the two partial products sum
        # to the full K-tile contribution.
        w_lo = (w_lo - z[None, :]) * s[None, :]
        w_hi = (w_hi - z[None, :]) * s[None, :]
        acc += tl.dot(x_even, w_lo)
        acc += tl.dot(x_odd, w_hi)

    out_ptrs = (
        out_ptr
        + (offs_m[:, None] * stride_om)
        + (offs_n[None, :] * stride_on)
    )
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])


def _grid(M, N, BLOCK_M, BLOCK_N):
    """Return the 3-D launch grid: one program per (BLOCK_M, BLOCK_N) output tile."""
    return (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1)


# Shape-specific configs chosen to keep weight reads coalesced and occupancy high.
_CONFIGS = {
    (1, 12288, 4096): {"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "num_stages": 4, "num_warps": 8},
    (1, 4096, 4096): {"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "num_stages": 2, "num_warps": 8},
    (32, 12288, 4096): {"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "num_stages": 3, "num_warps": 8},
    (256, 12288, 4096): {"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128, "num_stages": 1, "num_warps": 4},
    (16, 14336, 4096): {"BLOCK_M": 8, "BLOCK_N": 64, "BLOCK_K": 128, "num_stages": 2, "num_warps": 4},
}

# Launch parameters used when (M, N, K) has no tuned entry in _CONFIGS.
_DEFAULT_CONFIG = {"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "num_stages": 2, "num_warps": 4}

# Smallest group size the kernel can tile: it feeds BLOCK_K // 2 elements to
# each tl.dot as the reduction dimension, which Triton requires to be >= 16.
_MIN_GROUP_SIZE = 32


def w4a16_gemm(x: torch.Tensor, w_q: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int) -> torch.Tensor:
    """Launch the fused W4A16 kernel and return ``x @ dequant(w_q, scales, zeros)``.

    Args:
        x: Activations of shape (M, K). Expected to be bf16 (the kernel casts
            the dequantized weights to bf16 before ``tl.dot``).
        w_q: Packed int4 weights of shape (K // 2, N), two values per byte.
        scales: Per-group scales of shape (K // group_size, N).
        zeros: Per-group zero points of shape (K // group_size, N).
        group_size: Number of consecutive K elements sharing one scale/zero
            pair. Must be a power of two and at least 32.

    Returns:
        A new bf16 tensor of shape (M, N) on ``x.device``.

    Raises:
        ValueError: If the tensor shapes are inconsistent or ``group_size``
            cannot be tiled by the kernel.
    """
    if x.dim() != 2 or w_q.dim() != 2:
        raise ValueError(
            f"x and w_q must be 2-D, got x.dim()={x.dim()} and w_q.dim()={w_q.dim()}"
        )
    M, K = x.shape
    Kh, N = w_q.shape
    if Kh * 2 != K:
        raise ValueError(
            f"w_q has {Kh} packed rows, expected K // 2 = {K // 2} for K = {K}"
        )
    # tl.arange needs a power-of-two extent, and the K tile equals group_size.
    if group_size < _MIN_GROUP_SIZE or group_size & (group_size - 1):
        raise ValueError(
            f"group_size must be a power of two >= {_MIN_GROUP_SIZE}, got {group_size}"
        )
    if K % group_size != 0:
        raise ValueError(f"K = {K} is not a multiple of group_size = {group_size}")
    n_groups = K // group_size
    if tuple(scales.shape) != (n_groups, N):
        raise ValueError(
            f"scales has shape {tuple(scales.shape)}, expected {(n_groups, N)}"
        )
    if tuple(zeros.shape) != (n_groups, N):
        raise ValueError(
            f"zeros has shape {tuple(zeros.shape)}, expected {(n_groups, N)}"
        )

    out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
    if out.numel() == 0:
        # Nothing to compute; avoid compiling/launching for an empty grid.
        return out

    cfg = _CONFIGS.get((M, N, K), _DEFAULT_CONFIG)
    block_m = cfg["BLOCK_M"]
    block_n = cfg["BLOCK_N"]
    # Pin the K tile to the group size so every tile lies inside exactly one
    # quantization group (one scale/zero row per tile). Launching with the
    # config's fixed BLOCK_K of 128 and a different group_size would mix
    # groups and silently produce wrong results. All tuned entries use 128,
    # so this is a no-op for the default group size.
    block_k = group_size

    w4a16_gemm_kernel[_grid(M, N, block_m, block_n)](
        x, w_q, scales, zeros, out,
        M, N, K,
        x.stride(0), x.stride(1),
        w_q.stride(0), w_q.stride(1),
        scales.stride(0), scales.stride(1),
        zeros.stride(0), zeros.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        GROUP_SIZE=group_size,
        num_stages=cfg["num_stages"],
        num_warps=cfg["num_warps"],
    )
    return out


class Model(nn.Module):
    """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros).

    Holds the packed weights and per-group quantization parameters as
    buffers. The buffers are allocated with ``torch.empty`` and therefore
    uninitialized; presumably they are filled from a checkpoint before use.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Allocate buffers for a (K, N) int4 weight quantized in K-groups.

        Args:
            M: Number of activation rows; stored but not used for allocation.
            N: Output features.
            K: Input features; must be even and a multiple of ``group_size``.
            group_size: K elements per scale/zero pair.

        Raises:
            ValueError: If ``K`` is odd or not a multiple of ``group_size``.
        """
        super().__init__()
        if K % group_size != 0:
            raise ValueError(f"K = {K} is not a multiple of group_size = {group_size}")
        if K % 2 != 0:
            raise ValueError(f"K = {K} must be even to pack two int4 values per byte")
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
        self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.empty((n_groups, N), dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the bf16 product of ``x`` with the dequantized weights."""
        return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.group_size)


# Module-level shims for get_inputs / get_init_inputs.
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return a single-element list holding a random bf16 activation of shape (M, K)."""
    x = torch.randn(M, K, dtype=torch.bfloat16)
    return [x]


def get_init_inputs():
    """Return the positional arguments ``[M, N, K]`` for constructing ``Model``."""
    return [M, N, K]