"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""

import torch
import torch.nn as nn
import triton
import triton.language as tl

# Largest finite float8_e4m3fn value; each weight row's |max| is mapped onto it.
E4M3_MAX = 448.0

# Problems with at most this many rows take the skinny (N-parallel) launch path.
_SKINNY_M_MAX = 64

# forward() zero-pads K up to a multiple of this so row strides stay aligned.
# Correctness does not depend on it: both kernels mask the K tail themselves.
_K_ALIGN = 128


# Tiled GEMM: C[m, n] = (sum_k A[m, k] * B[k, n]) * Scales[n], stored as bf16.
# One program per (BLOCK_M, BLOCK_N) output tile on a 1-D grid. Program ids are
# remapped in groups of GROUP_M tile-rows (grouped ordering as in the Triton
# matmul tutorial), intended to improve L2 reuse. Accumulation is in fp32.
# EVEN_K must be True only when K % BLOCK_K == 0; otherwise the K tail is
# loaded with a mask and zero-filled.
@triton.jit
def _fp8_gemm_kernel(
    A, B, C, Scales,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may contain fewer than GROUP_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    mask_m = offs_m < M
    mask_n = offs_n < N
    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
        if EVEN_K:
            a = tl.load(a_ptrs, mask=mask_m[:, None], other=0.0)
            b = tl.load(b_ptrs, mask=mask_n[None, :], other=0.0)
        else:
            # Zero-fill columns >= K so the final partial K block never reads
            # into the next row or past the end of the buffers.
            mask_k = offs_k < K - k
            a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
            b = tl.load(b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Per-output-column dequantization scale applied once after accumulation.
    scales = tl.load(Scales + offs_n, mask=mask_n, other=1.0)
    acc = acc * scales[None, :]
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])


# Same math as _fp8_gemm_kernel, for small M. It uses a 2-D grid with
# axis 0 = N-tile and axis 1 = M-tile, and no grouped reordering. The M axis
# guarantees that every row is covered even when M > BLOCK_M.
@triton.jit
def _fp8_gemm_skinny_kernel(
    A, B, C, Scales,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    mask_m = offs_m < M
    mask_n = offs_n < N
    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
        if EVEN_K:
            a = tl.load(a_ptrs, mask=mask_m[:, None], other=0.0)
            b = tl.load(b_ptrs, mask=mask_n[None, :], other=0.0)
        else:
            # Zero-fill columns >= K (see _fp8_gemm_kernel).
            mask_k = offs_k < K - k
            a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
            b = tl.load(b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    scales = tl.load(Scales + offs_n, mask=mask_n, other=1.0)
    acc = acc * scales[None, :]
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])


def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
    """Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, GROUP_M).

    ``N`` is accepted for future tuning but does not currently affect the
    choice.
    """
    if M <= _SKINNY_M_MAX:
        return 32, 128, 256, 4, 3, 1
    return 128, 256, 128, 8, 3, 4


def _fp8_gemm(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    M: int,
) -> torch.Tensor:
    """Compute ``(x[:M] @ weight.T) * weight_scale`` as a bf16 (M, N) tensor.

    Args:
        x: 2-D activation matrix with at least ``M`` rows and K columns. The
            caller in this file passes float8_e4m3fn.
        weight: (N, K) matrix. It is read as a K x N operand by swapping its
            strides, so no transpose is materialized.
        weight_scale: length-N per-output-column scale applied after the dot.
        M: number of leading rows of ``x`` to multiply.

    Returns:
        A freshly allocated (M, N) bfloat16 tensor on ``x.device``.

    Raises:
        AssertionError: if the K dimensions differ or ``x`` has fewer than M rows.
    """
    K = x.shape[1]
    N, K_w = weight.shape
    assert K == K_w, f"K mismatch: x has {K}, weight has {K_w}"
    assert x.shape[0] >= M, f"x has {x.shape[0]} rows, need at least {M}"
    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    if M == 0 or N == 0:
        # Nothing to compute; avoid launching an empty grid.
        return y

    block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
    even_k = K % block_k == 0
    if M <= _SKINNY_M_MAX:
        grid = (triton.cdiv(N, block_n), triton.cdiv(M, block_m))
        _fp8_gemm_skinny_kernel[grid](
            x, weight, y, weight_scale,
            M, N, K,
            x.stride(0), x.stride(1),
            weight.stride(1), weight.stride(0),
            y.stride(0), y.stride(1),
            BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
            NUM_STAGES=num_stages, EVEN_K=even_k,
            num_warps=num_warps,
        )
    else:
        grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
        _fp8_gemm_kernel[grid](
            x, weight, y, weight_scale,
            M, N, K,
            x.stride(0), x.stride(1),
            weight.stride(1), weight.stride(0),
            y.stride(0), y.stride(1),
            BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
            GROUP_M=group_m, NUM_STAGES=num_stages, EVEN_K=even_k,
            num_warps=num_warps,
        )
    return y


class Model(nn.Module):
    """Linear layer ``y = x @ (weight * weight_scale[:, None]).T`` with fp8 weights.

    At construction, an (N, K) weight is drawn from N(0, 0.02). Each row is
    quantized to float8_e4m3fn with scale ``amax(|row|) / E4M3_MAX``. The
    quantized weight and the float32 scales are stored as buffers.
    """

    def __init__(self, M: int, N: int, K: int):
        super().__init__()
        self.M, self.N, self.K = M, N, K
        w = torch.empty(N, K, dtype=torch.bfloat16)
        nn.init.normal_(w, std=0.02)
        # Per-row scale; the clamp avoids division by zero for all-zero rows.
        s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
        w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
        self.register_buffer("weight", w_fp8)
        self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
        # Cached K-padded copy of `weight`, plus the key of the weight tensor
        # it was built from (see _padded_weight).
        self._weight_padded: torch.Tensor | None = None
        self._weight_key: tuple | None = None

    def _padded_weight(self, pad_k: int) -> torch.Tensor:
        """Return ``self.weight`` right-padded with ``pad_k`` zero columns (cached)."""
        w = self.weight
        # data_ptr/device detect the buffer being replaced by a different tensor
        # (whose _version alone may collide with the cached one); _version
        # detects in-place writes to the same tensor.
        key = (w.data_ptr(), w.device, w._version, pad_k)
        if self._weight_padded is None or self._weight_key != key:
            self._weight_padded = torch.nn.functional.pad(w, (0, pad_k))
            self._weight_key = key
        return self._weight_padded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the bf16 product for every row of the 2-D input ``x``."""
        K = x.shape[1]
        # Check before padding: padding both sides could otherwise make
        # mismatched K values look equal.
        assert K == self.weight.shape[1], (
            f"K mismatch: x has {K}, weight has {self.weight.shape[1]}"
        )
        pad_k = (-K) % _K_ALIGN
        if pad_k:
            weight = self._padded_weight(pad_k)
            x = torch.nn.functional.pad(x, (0, pad_k))
        else:
            weight = self.weight
        return _fp8_gemm(x, weight, self.weight_scale, x.shape[0])


M = 4096
N = 4096
K = 4096


def get_inputs():
    """Return a single (M, K) float8_e4m3fn input with values drawn from [-4, 4)."""
    x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
    return [x]


def get_init_inputs():
    """Return the constructor arguments ``[M, N, K]`` for ``Model``."""
    return [M, N, K]