"""FP8 GEMM with Triton — heuristic + version-based repad."""
import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "gemm"
SUPPORTED_PRECISIONS = ["fp8_e4m3"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

E4M3_MAX = 448.0

# Pad K to multiple of 128 for efficient MMA.
K_BLOCK_ALIGN = 128


@triton.jit
def fp8_gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    scale_ptr,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr = True,
):
    """Compute C[m, n] = scale[n] * sum_k A[m, k] * B[n, k], stored as bf16.

    A is indexed (M, K) through (stride_am, stride_ak), and B is indexed (N, K)
    through (stride_bn, stride_bk), so B is consumed transposed. Accumulation is
    in fp32. EVEN_K must be True only when K is a multiple of BLOCK_K; otherwise
    the final K tile is loaded with a mask and zero-filled.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped tile ordering: GROUP_M consecutive row-tiles share each sweep over
    # the column tiles. The last group may hold fewer rows (group_size).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap tail rows/cols back into range so the unmasked loads below never
    # leave the A / B / scale allocations. The duplicated rows/cols this
    # produces are dropped by the masked store at the end.
    offs_am = offs_m % M
    offs_bn = offs_n % N
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            # The last K tile may extend past K: zero-fill the overhang so it
            # contributes nothing to the dot product.
            k_remaining = K - k
            a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Per-output-column dequantisation scale.
    scale = tl.load(scale_ptr + offs_bn)
    acc = acc * scale[None, :]

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mask)


def fp8_gemm_dispatch(x, w, weight_scale, M, N, K_pad):
    """Pick a tile configuration by M and launch ``fp8_gemm_kernel``.

    Args:
        x: Left operand, indexed as (M, K_pad).
        w: Right operand, indexed as (N, K_pad); it is used transposed.
        weight_scale: Per-column scale, read at indices [0, N).
        M, N: Output shape.
        K_pad: Reduction length passed to the kernel.

    Returns:
        A new (M, N) bfloat16 tensor on ``x.device``.
    """
    if M <= 32:
        # Small-M (decode-like) shapes: narrow tiles, deeper K per iteration.
        BM, BN, BK, GM, NW, NS = 32, 64, 256, 4, 4, 2
    else:
        BM, BN, BK, GM, NW, NS = 128, 256, 128, 4, 8, 3

    c = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    if M == 0 or N == 0:
        # Nothing to compute; avoid launching an empty grid.
        return c

    grid = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
    fp8_gemm_kernel[grid](
        x, w, c,
        M, N, K_pad,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        weight_scale,
        c.stride(0), c.stride(1),
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
        GROUP_M=GM,
        # K_pad is only guaranteed to be a multiple of K_BLOCK_ALIGN (128), so
        # BK=256 can overrun the last tile unless the kernel masks it.
        EVEN_K=(K_pad % BK == 0),
        num_warps=NW,
        num_stages=NS,
    )
    return c


class Model(nn.Module):
    """FP8 (E4M3) linear map computed with a Triton GEMM.

    Weights are quantised per output row: ``weight[n] * weight_scale[n]``
    approximates the original bf16 row. ``forward(x)`` returns
    ``x @ (weight * weight_scale[:, None]).T`` as bfloat16.

    When K is not a multiple of ``K_BLOCK_ALIGN``, both operands are
    zero-padded along K. The padded weight copy is refreshed only when the
    ``weight`` tensor is replaced or modified in place.
    """

    def __init__(self, M: int, N: int, K: int):
        super().__init__()
        self.M, self.N, self.K = M, N, K

        w = torch.empty(N, K, dtype=torch.bfloat16)
        nn.init.normal_(w, std=0.02)
        # Per-row scale so that each row's max |w| maps to E4M3_MAX.
        s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
        w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
        self.register_buffer("weight", w_fp8)
        self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))

        K_pad = ((K + K_BLOCK_ALIGN - 1) // K_BLOCK_ALIGN) * K_BLOCK_ALIGN
        self._K_pad = K_pad
        if K_pad != K:
            w_pad = torch.zeros((N, K_pad), dtype=w_fp8.dtype)
            w_pad[:, :K].copy_(w_fp8)
            # Non-persistent: kept out of state_dict but still moved by .to().
            self.register_buffer("_w_padded", w_pad, persistent=False)
        else:
            self._w_padded = None

        # Identity and in-place version of the weight tensor last copied into
        # _w_padded. The initial values force a sync on the first forward.
        self._w_version = -1
        self._w_src_id = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ dequant(weight).T`` in bfloat16.

        Args:
            x: FP8 activations of shape (M, K). M may differ from the value the
                module was constructed with.

        Raises:
            ValueError: If ``x`` is not 2-D with ``K`` columns.
        """
        K = self.K
        K_pad = self._K_pad
        N = self.N
        if x.dim() != 2 or x.shape[1] != K:
            raise ValueError(f"expected x of shape (M, {K}), got {tuple(x.shape)}")
        # Use the real row count: self.M would make the kernel read past x
        # when x has fewer rows.
        M = x.shape[0]

        if K_pad == K:
            return fp8_gemm_dispatch(x, self.weight, self.weight_scale, M, N, K_pad)

        weight = self.weight
        # Re-pad when the weight was modified in place (version bump) or the
        # buffer was replaced by a different tensor object.
        if weight._version != self._w_version or id(weight) != self._w_src_id:
            self._w_padded[:, :K].copy_(weight)
            self._w_version = weight._version
            self._w_src_id = id(weight)

        # Reusable zero-padded activation buffer. Its tail columns [K, K_pad)
        # are never written and stay zero. It is a plain attribute (not a
        # buffer), so it is re-created whenever device, dtype or row count
        # change.
        x_pad = getattr(self, "_x_pad", None)
        if (
            x_pad is None
            or x_pad.device != x.device
            or x_pad.dtype != x.dtype
            or x_pad.shape[0] != M
        ):
            x_pad = torch.zeros((M, K_pad), device=x.device, dtype=x.dtype)
            self._x_pad = x_pad
        x_pad[:, :K].copy_(x)

        return fp8_gemm_dispatch(x_pad, self._w_padded, self.weight_scale, M, N, K_pad)


M = 4096
N = 4096
K = 4096


def get_inputs():
    """Return a single (M, K) FP8 E4M3 input with values drawn from [-4, 4)."""
    x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
    return [x]


def get_init_inputs():
    """Return the constructor arguments [M, N, K] for ``Model``."""
    return [M, N, K]