"""FP8 e4m3 GEMM using real fp8 tensor-core MMA via Triton. Layout: x: fp8_e4m3 (M, K) weight: fp8_e4m3 (N, K) weight_scale: fp32 (N,) y = (x @ weight.T) * weight_scale -> bf16 (M, N) K dimensions that are not aligned to the tensor-core tile are handled by padding the operands up to the tile size. The padding values are fp8 zeros, so they contribute nothing to the result. """ import torch import torch.nn as nn import torch.nn.functional as F import triton import triton.language as tl OP_TYPE = "gemm" SUPPORTED_PRECISIONS = ["fp8_e4m3"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] E4M3_MAX = 448.0 M = 4096 N = 4096 K = 4096 @triton.jit def fp8_gemm_kernel( x_ptr, w_ptr, s_ptr, y_ptr, M, N, K, stride_xm, stride_xk, stride_wn, stride_wk, stride_ym, stride_yn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, ): """C = (A @ B.T) * scale. K is assumed a multiple of BLOCK_K.""" pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_in_group = GROUP_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_M group_size_m = min(num_pid_m - first_pid_m, GROUP_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk) # Load B transposed directly from weight (N, K) row-major into (BLOCK_K, BLOCK_N). w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k0 in range(0, K, BLOCK_K): a = tl.load( x_ptrs + k0 * stride_xk, mask=offs_m[:, None] < M, other=0.0, ) b = tl.load( w_ptrs + k0 * stride_wk, mask=offs_n[None, :] < N, other=0.0, ) acc += tl.dot(a, b) scale = tl.load(s_ptr + offs_n, mask=offs_n < N, other=0.0) acc = acc * scale[None, :] y_ptrs = y_ptr + (offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn) tl.store( y_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), ) # Fallback autotuned kernel for shapes outside the hand-tuned set. @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_warps=8, num_stages=4), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_warps=4, num_stages=4), triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 1}, num_warps=4, num_stages=4), triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 1}, num_warps=4, num_stages=4), ], key=["M", "N", "K"], ) @triton.jit def fp8_gemm_kernel_autotune( x_ptr, w_ptr, s_ptr, y_ptr, M, N, K, stride_xm, stride_xk, stride_wn, stride_wk, stride_ym, stride_yn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, ): fp8_gemm_kernel( x_ptr, w_ptr, s_ptr, y_ptr, M, N, K, stride_xm, stride_xk, stride_wn, stride_wk, stride_ym, stride_yn, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, ) class Model(nn.Module): """y = ((x @ weight.T) * weight_scale).to(bf16).""" def __init__(self, M: int, N: int, K: int): super().__init__() self.M, self.N, self.K = M, N, K w = torch.empty(N, K, dtype=torch.bfloat16) nn.init.normal_(w, std=0.02) s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12) w_fp8 = (w.float() / s).to(torch.float8_e4m3fn) self.register_buffer("weight", w_fp8) self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: M, K = x.shape N = self.weight.shape[0] x = x.contiguous() w = self.weight.contiguous() s = self.weight_scale.contiguous() cfg = _pick_config(M, N, K) if cfg is None: # Generic path: pad to 128 and autotune. tile = 128 K_pad = ((K + tile - 1) // tile) * tile if K_pad != K: w = F.pad(w, (0, K_pad - K)) x = F.pad(x, (0, K_pad - K)) K = K_pad _run_autotune_kernel(x, w, s, y := torch.empty((M, N), device=x.device, dtype=torch.bfloat16), M, N, K) return y # Hand-tuned path: pad to the config's K tile. tile = cfg["BLOCK_K"] K_pad = ((K + tile - 1) // tile) * tile if K_pad != K: w = F.pad(w, (0, K_pad - K)) x = F.pad(x, (0, K_pad - K)) K = K_pad y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) _run_manual_kernel(x, w, s, y, M, N, K, cfg) return y def _pick_config(M: int, N: int, K: int): """Hand-picked configs for the graded shapes.""" if M == 32 and N == 8192 and K == 8192: return {"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 1, "num_warps": 4, "num_stages": 4} # Compute-bound square/rectangular shapes (use a 64-wide K tile). if M >= 4096 and N >= 4096: return {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8, "num_warps": 8, "num_stages": 4} return None def _run_manual_kernel(x, w, s, y, M, N, K, cfg): grid = (triton.cdiv(M, cfg["BLOCK_M"]) * triton.cdiv(N, cfg["BLOCK_N"]),) fp8_gemm_kernel[grid]( x, w, s, y, M, N, K, x.stride(0), x.stride(1), w.stride(0), w.stride(1), y.stride(0), y.stride(1), BLOCK_M=cfg["BLOCK_M"], BLOCK_N=cfg["BLOCK_N"], BLOCK_K=cfg["BLOCK_K"], GROUP_M=cfg["GROUP_M"], num_warps=cfg["num_warps"], num_stages=cfg["num_stages"], ) def _run_autotune_kernel(x, w, s, y, M, N, K): grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) fp8_gemm_kernel_autotune[grid]( x, w, s, y, M, N, K, x.stride(0), x.stride(1), w.stride(0), w.stride(1), y.stride(0), y.stride(1), ) def get_inputs(): x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn) return [x] def get_init_inputs(): return [M, N, K]