"""W4A16 GEMM: bfloat16 activations times group-quantized 4-bit weights.

The weight matrix is stored packed, two 4-bit values per ``uint8``: row ``r``
of the packed tensor carries the weight for ``k = 2 * r`` in its low nibble
and the weight for ``k = 2 * r + 1`` in its high nibble.  Every
``group_size`` consecutive ``k`` rows share one row of ``scales`` and
``zeros``; a weight is dequantized as ``(q - zero) * scale``.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl

GROUP_SIZE = 128

# Smallest K tile the dispatcher will shrink to.  The kernel feeds
# BLOCK_SIZE_K // 2 to tl.dot as the reduction dimension.
# NOTE(review): tl.dot has historically required every dimension to be >= 16,
# hence 32 here -- confirm against the Triton version in use.
_MIN_BLOCK_SIZE_K = 32

# Hand-tuned launch configurations keyed by (M, N).
_TUNED_CONFIGS = {
    (1, 12288): {
        'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128,
        'num_warps': 4, 'num_stages': 5,
    },
    (32, 12288): {
        'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128,
        'num_warps': 8, 'num_stages': 2,
    },
    (256, 12288): {
        'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
        'num_warps': 4, 'num_stages': 3,
    },
    (1, 4096): {
        'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128,
        'num_warps': 8, 'num_stages': 3,
    },
    (16, 14336): {
        'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128,
        'num_warps': 8, 'num_stages': 2,
    },
}

# Used for every (M, N) that has no hand-tuned entry.
_DEFAULT_CONFIG = {
    'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128,
    'num_warps': 4, 'num_stages': 2,
}


def _fit_block_size_k(block_size_k, group_size):
    """Return the largest halving of ``block_size_k`` that divides ``group_size``.

    The kernel loads a single scale/zero row per K tile, so a tile must never
    straddle two quantization groups.

    Raises:
        ValueError: if no tile size >= ``_MIN_BLOCK_SIZE_K`` divides
            ``group_size``.
    """
    while block_size_k > _MIN_BLOCK_SIZE_K and group_size % block_size_k != 0:
        block_size_k //= 2
    if group_size % block_size_k != 0:
        raise ValueError(
            f"group_size={group_size} is not a multiple of any supported "
            f"BLOCK_SIZE_K (smallest tried: {block_size_k})"
        )
    return block_size_k


@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``x @ dequant(w)``.

    Each program walks K in steps of BLOCK_SIZE_K.  Per step it loads the
    even-k and odd-k columns of ``x`` separately, loads BLOCK_SIZE_K // 2
    packed weight rows, splits them into low (even k) and high (odd k)
    nibbles, dequantizes both as ``(q - zero) * scale`` in bfloat16 and
    accumulates two ``tl.dot`` products in float32.  The result is stored as
    bfloat16.

    ``scales``/``zeros`` are indexed as ``[k // GROUP_SIZE, n]`` and only one
    row of each is loaded per K step, so GROUP_SIZE must be a multiple of
    BLOCK_SIZE_K.
    """
    # One scale/zero row is applied to the whole K tile; a tile spanning two
    # groups would silently use the wrong quantization parameters.
    tl.static_assert(
        GROUP_SIZE % BLOCK_SIZE_K == 0,
        "GROUP_SIZE must be a multiple of BLOCK_SIZE_K",
    )

    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # Map the 1D program id onto the 2D tile grid in bands of up to 8 tile
    # rows (program-id grouping/swizzling intended for L2 cache reuse).
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Index of the packed weight row inside the current K tile; packed row r
    # corresponds to activation columns 2r (even) and 2r + 1 (odd).
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    b_mask = offs_bn[None, :] < N
    b_col_mask = offs_bn < N

    # Loop-invariant pointer offsets, hoisted out of the K loop.
    a_base_even_ptrs = x_ptr + (
        offs_am[:, None] * stride_xm + (2 * offs_k_half)[None, :] * stride_xk
    )
    a_base_odd_ptrs = x_ptr + (
        offs_am[:, None] * stride_xm
        + (2 * offs_k_half + 1)[None, :] * stride_xk
    )
    b_base_ptrs = w_ptr + (
        offs_k_half[:, None] * stride_wk + offs_bn[None, :] * stride_wn
    )
    scale_base_ptrs = scales_ptr + offs_bn * stride_sn
    zero_base_ptrs = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_curr in range(0, K, BLOCK_SIZE_K):
        # Even-k and odd-k activation columns of this tile.
        a_even = tl.load(
            a_base_even_ptrs + k_curr * stride_xk,
            mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K),
            other=0.0,
        )
        a_odd = tl.load(
            a_base_odd_ptrs + k_curr * stride_xk,
            mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K),
            other=0.0,
        )

        # Packed weights: K // 2 rows in total, two 4-bit values per byte.
        b_packed = tl.load(
            b_base_ptrs + (k_curr // 2) * stride_wk,
            mask=b_mask & ((k_curr // 2 + offs_k_half[:, None]) < (K // 2)),
            other=0,
        )

        # Low nibble pairs with even k, high nibble with odd k.
        b_even_uint8 = b_packed & 0xF
        b_odd_uint8 = (b_packed >> 4) & 0xF

        # Scale and zero point of the group this K tile belongs to.
        k_group = k_curr // GROUP_SIZE
        scale = tl.load(
            scale_base_ptrs + k_group * stride_sm, mask=b_col_mask, other=0.0
        )
        zero = tl.load(
            zero_base_ptrs + k_group * stride_zm, mask=b_col_mask, other=0.0
        )

        # Dequantize to bfloat16.
        w_even = (b_even_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
        w_odd = (b_odd_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    # Write the tile back as bfloat16, masking rows/columns past M and N.
    c = accumulator.to(tl.bfloat16)
    out_ptrs = out_ptr + (
        offs_am[:, None] * stride_om + offs_bn[None, :] * stride_on
    )
    tl.store(out_ptrs, c, mask=a_mask & b_mask)


class Model(nn.Module):
    """Linear layer with packed int4 weights and per-group scale/zero point.

    Buffers (filled via ``load_state_dict``; allocated uninitialized):
        w_q:    ``(K // 2, N)`` uint8, two 4-bit weights per byte.
        scales: ``(K // group_size, N)`` bfloat16.
        zeros:  ``(K // group_size, N)`` bfloat16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        super().__init__()
        assert K % group_size == 0, "K must be divisible by group_size"
        assert K % 2 == 0, "K must be even (int4 packing)"
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        # Register buffers so load_state_dict works
        self.register_buffer(
            "w_q", torch.empty((K // 2, N), dtype=torch.uint8)
        )
        self.register_buffer(
            "scales", torch.empty((n_groups, N), dtype=torch.bfloat16)
        )
        self.register_buffer(
            "zeros", torch.empty((n_groups, N), dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor, config_override=None) -> torch.Tensor:
        """Return ``x @ dequant(w_q)`` as a bfloat16 ``(M, N)`` tensor.

        Args:
            x: 2D activation tensor of shape ``(M, K)`` where ``K`` matches
                the packed weight (``2 * w_q.shape[0]``).
                NOTE(review): the kernel passes x tiles straight to tl.dot
                next to bfloat16 weights, so x is presumably bfloat16 --
                confirm with callers.
            config_override: optional mapping with keys ``BLOCK_SIZE_M``,
                ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``num_warps`` and
                ``num_stages`` that replaces the built-in launch config.

        Raises:
            ValueError: if ``x`` is not 2D, if its K does not match the
                packed weight, or if the K tile cannot be made to divide
                ``group_size``.
        """
        if x.dim() != 2:
            raise ValueError(
                f"x must be 2D (M, K), got shape {tuple(x.shape)}"
            )
        M, K = x.shape
        packed_k, N = self.w_q.shape
        # The kernel bounds its weight/scale/zero reads with x's K, so a
        # mismatch would read past the ends of those buffers.
        if K != 2 * packed_k:
            raise ValueError(
                f"x has K={K} but the packed weight expects K={2 * packed_k}"
            )

        out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

        if config_override is not None:
            BLOCK_SIZE_M = config_override['BLOCK_SIZE_M']
            BLOCK_SIZE_N = config_override['BLOCK_SIZE_N']
            BLOCK_SIZE_K = config_override['BLOCK_SIZE_K']
            num_warps = config_override['num_warps']
            num_stages = config_override['num_stages']
            # An explicit config is never altered silently; reject it if a K
            # tile would span two quantization groups.
            if self.group_size % BLOCK_SIZE_K != 0:
                raise ValueError(
                    f"group_size={self.group_size} must be a multiple of "
                    f"BLOCK_SIZE_K={BLOCK_SIZE_K}"
                )
        else:
            config = _TUNED_CONFIGS.get((M, N), _DEFAULT_CONFIG)
            BLOCK_SIZE_M = config['BLOCK_SIZE_M']
            BLOCK_SIZE_N = config['BLOCK_SIZE_N']
            # Shrink the tuned K tile when group_size is smaller than (or not
            # a multiple of) it, instead of computing with the wrong scales.
            BLOCK_SIZE_K = _fit_block_size_k(
                config['BLOCK_SIZE_K'], self.group_size
            )
            num_warps = config['num_warps']
            num_stages = config['num_stages']

        # One program per output tile; the kernel un-flattens the id itself.
        grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
        w4a16_gemm_kernel[grid](
            x, self.w_q, self.scales, self.zeros, out,
            M, N,
            x.stride(0), x.stride(1),
            self.w_q.stride(0), self.w_q.stride(1),
            self.scales.stride(0), self.scales.stride(1),
            self.zeros.stride(0), self.zeros.stride(1),
            out.stride(0), out.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE=self.group_size,
            K=K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return out


def get_inputs():
    """Return whatever ``reference.get_inputs()`` returns."""
    from reference import get_inputs as ref_get_inputs
    return ref_get_inputs()


def get_init_inputs():
    """Return whatever ``reference.get_init_inputs()`` returns."""
    from reference import get_init_inputs as ref_get_init_inputs
    return ref_get_init_inputs()