import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "gemm" SUPPORTED_PRECISIONS = ["fp8_e4m3"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] E4M3_MAX = 448.0 @triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn, stride_scale, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) accumulator = tl.dot(a, b, accumulator) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk # Load scales scale_ptrs = scale_ptr + offs_bn * stride_scale scale = tl.load(scale_ptrs, mask=offs_bn < N, other=1.0) accumulator = accumulator * scale[None, :] offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator.to(tl.bfloat16), mask=c_mask) class Model(nn.Module): def __init__(self, M: int, N: int, K: int): super().__init__() self.M, self.N, self.K = M, N, K # Setup temporary buffers matching reference w = torch.empty(N, K, dtype=torch.bfloat16) nn.init.normal_(w, std=0.02) s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12) w_fp8 = (w.float() / s).to(torch.float8_e4m3fn) self.register_buffer("weight", w_fp8) self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: # Run our Triton GEMM M, K = x.shape N = self.weight.shape[0] c = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) BLOCK_M = 64 BLOCK_N = 64 BLOCK_K = 64 grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) matmul_kernel[grid]( x, self.weight, c, self.weight_scale, M, N, K, x.stride(0), x.stride(1), self.weight.stride(0), self.weight.stride(1), c.stride(0), c.stride(1), self.weight_scale.stride(0), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, ) return c M = 4096 N = 4096 K = 4096 def get_inputs(): x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn) return [x] def get_init_inputs(): return [M, N, K]