"""FP8 e4m3 GEMM — Triton kernel using fp8 tensor cores on Blackwell SM120. Model interface: same as reference.py — holds `weight` (fp8_e4m3, N×K) and `weight_scale` (float32, N). Forward: y = ((x @ w.T) * weight_scale).to(bf16). Uses a 2D grid launch and dynamically pads K to a multiple of BLOCK_K so every inner-loop iteration loads full, unmasked tiles (avoiding the Triton tail slowdown on fp8 masked loads). """ import torch import torch.nn as nn import triton import triton.language as tl E4M3_MAX = 448.0 # --------------------------------------------------------------------------- # Triton kernel — 2D grid (M-blocks × N-blocks) # --------------------------------------------------------------------------- @triton.jit def _fp8_gemm_kernel( a_ptr, # fp8 activation (M, K_padded) b_ptr, # fp8 weight (N, K_padded) — read transposed as (K, N) c_ptr, # bf16 output (M, N) scale_ptr, # float32 scale (N,) M, N, K_padded, # Padded K (multiple of BLOCK_K) stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_n[None, :] * stride_bn + offs_k[:, None] * stride_bk acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # K_padded is a multiple of BLOCK_K — every iteration loads a full tile. for k in range(0, K_padded, BLOCK_K): a = tl.load(a_ptrs) b = tl.load(b_ptrs) acc = tl.dot(a, b, acc) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk # Apply per-channel dequant scale scale = tl.load(scale_ptr + offs_n, mask=offs_n < N, other=1.0) acc = acc * scale[None, :] # Store with edge masks c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)) # --------------------------------------------------------------------------- # Padding helpers # --------------------------------------------------------------------------- def _pad_tensor(t: torch.Tensor, K_padded: int) -> torch.Tensor: """Pad the last dimension of an fp8 tensor to *K_padded* with zeros.""" K = t.shape[1] if K == K_padded: return t padded = torch.zeros(t.shape[0], K_padded, dtype=t.dtype, device=t.device) padded[:, :K].copy_(t) return padded # --------------------------------------------------------------------------- # Kernel dispatch # --------------------------------------------------------------------------- def _run_kernel( x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, out: torch.Tensor, ): M, K = x.shape N = weight.shape[0] # Choose tile sizes. # - Skinny M (≤64): narrower M, wider K to amortise memory latency. # - All other shapes: balanced tile that fits 3-stage pipelining in # 128 KB of shared memory. if M <= 64: BLOCK_M, BLOCK_N, BLOCK_K = 64, 128, 256 else: BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 # Pad K so every inner-loop iteration loads a full tile. K_padded = ((K + BLOCK_K - 1) // BLOCK_K) * BLOCK_K x_padded = _pad_tensor(x, K_padded) w_padded = _pad_tensor(weight, K_padded) grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) _fp8_gemm_kernel[grid]( x_padded, w_padded, out, weight_scale, M, N, K_padded, x_padded.stride(0), x_padded.stride(1), w_padded.stride(0), w_padded.stride(1), out.stride(0), out.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, ) # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- class Model(nn.Module): """FP8 GEMM: y = (x @ w.T) * weight_scale → bf16.""" def __init__(self, M: int, N: int, K: int): super().__init__() self.M, self.N, self.K = M, N, K w = torch.empty(N, K, dtype=torch.bfloat16) nn.init.normal_(w, std=0.02) s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12) w_fp8 = (w.float() / s).to(torch.float8_e4m3fn) self.register_buffer("weight", w_fp8) self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: M, K = x.shape N = self.weight.shape[0] out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device) _run_kernel(x, self.weight, self.weight_scale, out) return out # --------------------------------------------------------------------------- # Entry points — identical to reference # --------------------------------------------------------------------------- M = 4096 N = 4096 K = 4096 def get_inputs(): x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn) return [x] def get_init_inputs(): return [M, N, K]