import torch import torch.nn as nn import triton import triton.language as tl E4M3_MAX = 448.0 @triton.jit def _fp8_gemm_kernel( x_ptr, w_ptr, scale_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, ): pid = tl.program_id(0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) group_size = GROUP_M * num_pid_n group_id = pid // group_size first_pid_m = group_id * GROUP_M group_m = tl.minimum(num_pid_m - first_pid_m, GROUP_M) pid_in_group = pid - group_id * group_size pid_m = first_pid_m + (pid_in_group % group_m) pid_n = pid_in_group // group_m offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32) full_k = (K // BLOCK_K) * BLOCK_K for k0 in range(0, full_k, BLOCK_K): k = k0 + offs_k a = tl.load(x_ptr + offs_m[:, None] * K + k[None, :]) b = tl.load(w_ptr + offs_n[None, :] * K + k[:, None]) acc = tl.dot(a, b, acc, out_dtype=tl.float32) if full_k < K: k = full_k + offs_k a = tl.load( x_ptr + offs_m[:, None] * K + k[None, :], mask=k[None, :] < K, other=0.0, ) b = tl.load( w_ptr + offs_n[None, :] * K + k[:, None], mask=k[:, None] < K, other=0.0, ) acc = tl.dot(a, b, acc, out_dtype=tl.float32) scales = tl.load(scale_ptr + offs_n, mask=offs_n < N, other=0.0) acc = acc * scales[None, :] tl.store( y_ptr + offs_m[:, None] * N + offs_n[None, :], acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), ) def _launch_fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, M: int, N: int, K: int): y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) if M <= 64: bm, bn, bk = 32, 32, 256 warps, stages, group_m = 4, 4, 1 elif K == 4224 and M == 4096 and N == 4096: bm, bn, bk = 256, 128, 128 warps, stages, group_m = 8, 3, 4 elif K % 128 != 0: bm, bn, bk = 128, 128, 128 warps, stages, group_m = 4, 3, 8 elif N >= 8192: bm, bn, bk = 128, 256, 64 warps, stages, group_m = 8, 4, 2 else: bm, bn, bk = 256, 128, 64 warps, stages, group_m = 8, 4, 4 grid = (triton.cdiv(M, bm) * triton.cdiv(N, bn),) _fp8_gemm_kernel[grid]( x, weight, weight_scale, y, M, N, K, BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk, GROUP_M=group_m, num_warps=warps, num_stages=stages, ) return y class Model(nn.Module): def __init__(self, M: int, N: int, K: int): super().__init__() self.M, self.N, self.K = M, N, K self._weight_pad = None self._weight_pad_key = None self._x_pad = None w = torch.empty(N, K, dtype=torch.bfloat16) nn.init.normal_(w, std=0.02) s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12) w_fp8 = (w.float() / s).to(torch.float8_e4m3fn) self.register_buffer("weight", w_fp8) self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: if not x.is_cuda: x_bf = x.to(torch.bfloat16) w_bf = self.weight.to(torch.bfloat16) y = (x_bf @ w_bf.T).float() y = y * self.weight_scale[None, :] return y.to(torch.bfloat16) if self.K % 128 != 0: k_pad = ((self.K + 127) // 128) * 128 w_key = (self.weight.data_ptr(), self.weight._version, self.weight.device, self.weight.shape) if self._weight_pad is None or self._weight_pad_key != w_key: weight_pad = torch.empty((self.N, k_pad), device=x.device, dtype=self.weight.dtype) weight_pad[:, : self.K].copy_(self.weight) weight_pad[:, self.K :].zero_() self._weight_pad = weight_pad self._weight_pad_key = w_key if ( self._x_pad is None or self._x_pad.device != x.device or self._x_pad.shape != (self.M, k_pad) ): self._x_pad = torch.empty((self.M, k_pad), device=x.device, dtype=x.dtype) self._x_pad[:, : self.K].copy_(x) self._x_pad[:, self.K :].zero_() return _launch_fp8_gemm(self._x_pad, self._weight_pad, self.weight_scale, self.M, self.N, k_pad) return _launch_fp8_gemm(x, self.weight, self.weight_scale, self.M, self.N, self.K) M = 4096 N = 4096 K = 4096 def get_inputs(): x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn) return [x] def get_init_inputs(): return [M, N, K]