"""Triton GEMM with 4-bit packed weights dequantized on the fly, plus a thin nn.Module wrapper."""
from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl

# Number of consecutive K rows that share one (scale, zero) pair per output
# column.  The kernel hard-codes the same value (128) in its group arithmetic.
GROUP_SIZE = 128


@triton.jit
def _w4a16_gemm_kernel(
    x_ptr,
    wq_ptr,
    scales_ptr,
    zeros_ptr,
    out_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute ``out = x @ dequant(wq)`` for one (BLOCK_M, BLOCK_N) output tile.

    Addressing used by the kernel:
      * ``x_ptr``      -- element (m, k) is read at offset ``m * K + k``.
      * ``wq_ptr``     -- element for (k, n) is read at ``(k // 2) * N + n``;
                          even ``k`` takes the low nibble, odd ``k`` the high one.
      * ``scales_ptr`` / ``zeros_ptr`` -- group ``g`` for column ``n`` is read at
                          ``g * N + n`` with ``g = k // 128``.
      * ``out_ptr``    -- element (m, n) is written at ``m * N + n`` as bfloat16.

    Each weight is dequantized as ``(q - zero) * scale``, cast to bfloat16 and
    fed to ``tl.dot`` with float32 accumulation.

    Only the M and N edges are masked.  There is NO mask along K, so the caller
    must pass a ``BLOCK_K`` that divides ``K`` (and that either divides 128 or
    equals 256, so that a tile never mixes two quantization groups).
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)

    if BLOCK_K == 256:
        # A 256-wide K tile spans two quantization groups, so it is processed
        # as two 128-wide halves, each with its own scale / zero row.
        offs_k128 = tl.arange(0, 128)
        for k0 in range(0, K, 256):
            for part in tl.static_range(0, 2):
                k = k0 + part * 128 + offs_k128
                a = tl.load(
                    x_ptr + offs_m[:, None] * K + k[None, :],
                    mask=offs_m[:, None] < M,
                    other=0.0,
                )
                packed = tl.load(
                    wq_ptr + (k[:, None] // 2) * N + offs_n[None, :],
                    mask=offs_n[None, :] < N,
                    other=0,
                )
                # Two 4-bit values per byte: even k -> low nibble, odd k -> high.
                q_lo = packed & 0x0F
                q_hi = (packed >> 4) & 0x0F
                q = tl.where((k[:, None] & 1) == 0, q_lo, q_hi).to(tl.float32)
                group = k0 // 128 + part
                s = tl.load(scales_ptr + group * N + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
                z = tl.load(zeros_ptr + group * N + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
                b = ((q - z[None, :]) * s[None, :]).to(tl.bfloat16)
                acc += tl.dot(a, b, out_dtype=tl.float32)
    else:
        # Generic path: one scale / zero row per K tile (group = k0 // 128).
        offs_k = tl.arange(0, BLOCK_K)
        for k0 in range(0, K, BLOCK_K):
            k = k0 + offs_k
            a = tl.load(
                x_ptr + offs_m[:, None] * K + k[None, :],
                mask=offs_m[:, None] < M,
                other=0.0,
            )
            packed = tl.load(
                wq_ptr + (k[:, None] // 2) * N + offs_n[None, :],
                mask=offs_n[None, :] < N,
                other=0,
            )
            q_lo = packed & 0x0F
            q_hi = (packed >> 4) & 0x0F
            q = tl.where((k[:, None] & 1) == 0, q_lo, q_hi).to(tl.float32)
            group = k0 // 128
            s = tl.load(scales_ptr + group * N + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
            z = tl.load(zeros_ptr + group * N + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
            b = ((q - z[None, :]) * s[None, :]).to(tl.bfloat16)
            acc += tl.dot(a, b, out_dtype=tl.float32)

    tl.store(
        out_ptr + offs_m[:, None] * N + offs_n[None, :],
        acc.to(tl.bfloat16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )


class Model(nn.Module):
    """Linear-style layer computing ``x @ dequant(w_q)`` with packed 4-bit weights.

    Buffers (all allocated uninitialized with ``torch.empty``; presumably they
    are filled later, e.g. from a state dict -- confirm against the caller):
      * ``w_q``    -- ``(K // 2, N)`` uint8, two 4-bit values per byte.
      * ``scales`` -- ``(K // group_size, N)`` bfloat16.
      * ``zeros``  -- ``(K // group_size, N)`` bfloat16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Record the problem shape and allocate the quantized-weight buffers.

        Args:
            M: Number of rows of the activation matrix passed to ``forward``.
            N: Number of output columns.
            K: Reduction dimension; must be a multiple of ``GROUP_SIZE``.
            group_size: Quantization group size; only ``GROUP_SIZE`` is
                supported because the kernel hard-codes it.

        Raises:
            AssertionError: If ``group_size`` or ``K`` violates the constraints.
        """
        super().__init__()
        assert group_size == GROUP_SIZE, "only group_size == GROUP_SIZE is supported"
        assert K % GROUP_SIZE == 0, "K must be a multiple of GROUP_SIZE"
        assert K % 2 == 0, "K must be even (two 4-bit values per byte)"
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
        self.register_buffer("scales", torch.empty((K // group_size, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.empty((K // group_size, N), dtype=torch.bfloat16))

    @staticmethod
    def _select_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
        """Return ``(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)`` for a shape.

        The tile sizes are the hand-picked values per M / N range.  Because the
        kernel does not mask along K, a preferred ``BLOCK_K`` that does not
        divide ``K`` (only possible for the 256-wide tiles, since ``K`` is a
        multiple of 128) is replaced by ``GROUP_SIZE`` to avoid reading past
        the end of the activation, weight, scale and zero buffers.
        """
        if M == 1:
            if N <= 4096:
                config = (1, 32, 256, 2, 4)
            else:
                config = (1, 64, 256, 8, 2)
        elif M <= 16:
            config = (16, 128, 128, 8, 4)
        elif M <= 32:
            config = (32, 64, 128, 4, 3)
        else:
            config = (128, 64, 32, 4, 3)

        bm, bn, bk, warps, stages = config
        if K % bk != 0:
            # e.g. K = 384 with BLOCK_K = 256: the last tile would overrun K.
            bk = GROUP_SIZE
        return bm, bn, bk, warps, stages

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the quantized GEMM and return an ``(M, N)`` bfloat16 tensor.

        Args:
            x: Activations holding ``M * K`` elements laid out as ``(M, K)``.
                Expected to be bfloat16 on the same device as the buffers
                (``get_inputs`` produces bfloat16; the kernel does not check).

        Returns:
            A new ``(M, N)`` bfloat16 tensor on ``x.device``.
        """
        M, N, K = self.M, self.N, self.K
        # The kernel addresses x with a fixed row stride of K, so a strided
        # view would be read incorrectly; this is a no-op if already contiguous.
        x = x.contiguous()
        out = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)

        bm, bn, bk, warps, stages = self._select_config(M, N, K)
        grid = (triton.cdiv(M, bm), triton.cdiv(N, bn))
        _w4a16_gemm_kernel[grid](
            x,
            self.w_q,
            self.scales,
            self.zeros,
            out,
            M,
            N,
            K,
            BLOCK_M=bm,
            BLOCK_N=bn,
            BLOCK_K=bk,
            num_warps=warps,
            num_stages=stages,
        )
        return out


# Default problem shape used by the input helpers below.
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return a one-element list with a random ``(M, K)`` bfloat16 activation tensor."""
    x = torch.randn(M, K, dtype=torch.bfloat16)
    return [x]


def get_init_inputs():
    """Return the positional constructor arguments ``[M, N, K]`` for ``Model``."""
    return [M, N, K]