"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel.""" from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl GROUP_SIZE = 128 # --------------------------------------------------------------------------- # Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group). # --------------------------------------------------------------------------- @triton.autotune( configs=[ triton.Config({"BLOCK_N": 256}, num_warps=4), triton.Config({"BLOCK_N": 512}, num_warps=8), triton.Config({"BLOCK_N": 128}, num_warps=4), triton.Config({"BLOCK_N": 1024}, num_warps=8), ], key=["N", "K"], ) @triton.jit def _w4a16_gemv_splitk_kernel( x_ptr, wq_ptr, scales_ptr, zeros_ptr, partial_ptr, N, K, stride_wqn, stride_sg, stride_sn, stride_pg, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr, PACKED_PER_GROUP: tl.constexpr, ): pid_n = tl.program_id(0) pid_g = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) n_mask = offs_n < N s = tl.load( scales_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0 ).to(tl.float32) z = tl.load( zeros_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0 ).to(tl.float32) acc = tl.zeros((BLOCK_N,), dtype=tl.float32) base_k = pid_g * GROUP_SIZE base_packed = pid_g * PACKED_PER_GROUP for pi in tl.static_range(PACKED_PER_GROUP): packed_row = base_packed + pi k0 = base_k + 2 * pi w_packed = tl.load( wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0 ).to(tl.int32) w_lo = (w_packed & 0xF).to(tl.float32) w_hi = ((w_packed >> 4) & 0xF).to(tl.float32) x0 = tl.load(x_ptr + k0).to(tl.float32) x1 = tl.load(x_ptr + k0 + 1).to(tl.float32) acc += x0 * (w_lo - z) * s + x1 * (w_hi - z) * s tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask) @triton.jit def _reduce_partial_kernel( partial_ptr, out_ptr, N, n_groups, stride_pg, BLOCK_N: tl.constexpr, ): pid = tl.program_id(0) offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N) n_mask = offs_n < N acc = tl.zeros((BLOCK_N,), dtype=tl.float32) for g in tl.range(0, n_groups): acc += tl.load(partial_ptr + g * stride_pg + offs_n, mask=n_mask, other=0.0) tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=n_mask) # --------------------------------------------------------------------------- # General GEMM path: M > 1 — one tl.dot per quant group (128 K). # --------------------------------------------------------------------------- @triton.autotune( configs=[ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2), ], key=["M", "N", "K"], ) @triton.jit def _w4a16_gemm_kernel( x_ptr, wq_ptr, scales_ptr, zeros_ptr, out_ptr, M, N, K, stride_xm, stride_xk, stride_wqn, stride_sg, stride_sn, stride_om, stride_on, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr, PACKED_PER_GROUP: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) m_mask = offs_m < M n_mask = offs_n < N acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) n_groups = K // GROUP_SIZE for g in tl.range(0, n_groups): s = tl.load( scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0 ).to(tl.bfloat16) z = tl.load( zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0 ).to(tl.bfloat16) base_k = g * GROUP_SIZE base_packed = g * PACKED_PER_GROUP packed_rows = base_packed + tl.arange(0, PACKED_PER_GROUP) wq_ptrs = wq_ptr + packed_rows[:, None] * stride_wqn + offs_n[None, :] wq_mask = n_mask[None, :] wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32) w_lo = (wq_packed & 0xF).to(tl.bfloat16) w_hi = ((wq_packed >> 4) & 0xF).to(tl.bfloat16) w_deq_lo = (w_lo - z[None, :]) * s[None, :] w_deq_hi = (w_hi - z[None, :]) * s[None, :] offs_k_even = base_k + 2 * tl.arange(0, PACKED_PER_GROUP) offs_k_odd = offs_k_even + 1 x_even_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk x_odd_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk k_even_mask = offs_k_even[None, :] < K k_odd_mask = offs_k_odd[None, :] < K x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & k_even_mask, other=0.0).to( tl.bfloat16 ) x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & k_odd_mask, other=0.0).to( tl.bfloat16 ) acc += tl.dot(x_even, w_deq_lo, out_dtype=tl.float32) acc += tl.dot(x_odd, w_deq_hi, out_dtype=tl.float32) out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on out_mask = m_mask[:, None] & n_mask[None, :] tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask) def w4a16_gemm( x: torch.Tensor, w_q: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = GROUP_SIZE, ) -> torch.Tensor: """x: (M,K) bf16; w_q: (K//2,N) uint8; scales/zeros: (K//group,N) bf16 -> (M,N) bf16.""" assert x.is_contiguous() assert w_q.is_contiguous() assert scales.is_contiguous() assert zeros.is_contiguous() M, K = x.shape Kh, N = w_q.shape assert Kh * 2 == K assert scales.shape == zeros.shape assert scales.shape[0] == K // group_size assert scales.shape[1] == N if M == 1: n_groups = K // group_size packed_per_group = group_size // 2 partial = torch.empty((n_groups, N), dtype=torch.float32, device=x.device) grid = lambda meta: ( triton.cdiv(N, meta["BLOCK_N"]), n_groups, ) _w4a16_gemv_splitk_kernel[grid]( x, w_q, scales, zeros, partial, N, K, w_q.stride(0), scales.stride(0), scales.stride(1), partial.stride(0), GROUP_SIZE=group_size, PACKED_PER_GROUP=packed_per_group, ) out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device) reduce_grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),) _reduce_partial_kernel[reduce_grid]( partial, out, N, n_groups, partial.stride(0), BLOCK_N=256, ) return out else: out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device) packed_per_group = group_size // 2 grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]), ) _w4a16_gemm_kernel[grid]( x, w_q, scales, zeros, out, M, N, K, x.stride(0), x.stride(1), w_q.stride(0), scales.stride(0), scales.stride(1), out.stride(0), out.stride(1), GROUP_SIZE=group_size, PACKED_PER_GROUP=packed_per_group, ) return out class Model(nn.Module): """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros).""" def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE): super().__init__() assert K % group_size == 0 assert K % 2 == 0 self.M, self.N, self.K = M, N, K self.group_size = group_size n_groups = K // group_size torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K)) w_full = torch.randn(K, N, dtype=torch.float32) * 0.02 w_g = w_full.view(n_groups, group_size, N) w_min = w_g.min(dim=1, keepdim=True).values w_max = w_g.max(dim=1, keepdim=True).values scales = (w_max - w_min).clamp_min(1e-8) / 15.0 zeros = (-w_min / scales).round().clamp(0, 15) w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N) lo = w_q[0::2] & 0xF hi = w_q[1::2] & 0xF w_packed = (lo | (hi << 4)).contiguous() self.register_buffer("w_q", w_packed) self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16)) self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16)) def forward(self, x: torch.Tensor) -> torch.Tensor: return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.group_size) M = 1 N = 12288 K = 4096 def get_inputs(): x = torch.randn(M, K, dtype=torch.bfloat16) return [x] def get_init_inputs(): return [M, N, K]