"""Custom FP8→bf16 GEMM kernel using bf16 tensor-core MMA on SM120 Blackwell.
Strategy: upcast fp8 inputs to bf16, keep weights as bf16, then use
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 to hit tensor cores.
This matches the reference's numerical path exactly.
"""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# Benchmark-harness metadata: op category, input precision, and target GPUs.
OP_TYPE = "gemm"
SUPPORTED_PRECISIONS = ["fp8_e4m3"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# ---------------------------------------------------------------------------
# CUDA kernel source
# ---------------------------------------------------------------------------
CUDA_SRC = r"""
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include <ATen/cuda/CUDAContext.h>

// ---------------------------------------------------------------------------
// Tile configuration.
//
// Block tile: BM x BN = 64 x 64 output per block, K-step BK = 16 (exactly one
// m16n8k16 MMA in the K dimension). 4 warps per block in a 2x2 (M x N) grid;
// each warp owns a 32x32 sub-tile:
//   WM_MMA = 32/16 = 2 MMAs in M, WN_MMA = 32/8 = 4 MMAs in N
//   -> 8 MMAs per warp per K-step, 32 f32 accumulators per thread.
// NOTE: the warp grid must tile the block tile exactly:
//   2 warps * WM_TILE == BM and 2 warps * WN_TILE == BN.
// ---------------------------------------------------------------------------
#define BM 64
#define BN 64
#define BK 16
#define WM_TILE 32
#define WN_TILE 32
#define WM_MMA 2   // WM_TILE / 16 (MMA M extent)
#define WN_MMA 4   // WN_TILE / 8  (MMA N extent)

// fp8(e4m3) x bf16 -> bf16 GEMM: C[M,N] = A[M,K] @ B[N,K]^T.
//
// A is upcast fp8->bf16 while staging into shared memory; the math then runs
// through mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 (requires SM80+;
// built here for sm_120a). Per the PTX ISA, the bf16 m16n8k16 MMA takes a
// 4-register A fragment (8 bf16/thread) and a 2-register B fragment
// (4 bf16/thread) — passing fewer registers is a ptxas error
// ("Argument vector size mismatch for instruction 'mma'").
//
// Launch: grid(ceil(N/BN), ceil(M/BM)), block(128) == 4 warps, no dynamic smem.
// Out-of-range tiles are zero-padded on load and predicated on store, so any
// M/N/K is handled (including K not a multiple of BK).
__global__ void fp8_gemm_bf16_kernel(
    const __nv_fp8_e4m3* __restrict__ A_in,   // M x K, row-major, fp8
    const __nv_bfloat16* __restrict__ B_in,   // N x K, row-major, bf16
    __nv_bfloat16* __restrict__ C,            // M x N, row-major, bf16
    int M, int N, int K)
{
    // Block indices: x walks N, y walks M.
    int bx = blockIdx.x;
    int by = blockIdx.y;
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    int warp_m = warp_id / 2;      // 0..1, warp row within the block tile
    int warp_n = warp_id % 2;      // 0..1, warp column within the block tile
    int group_id  = lane_id >> 2;  // 0..7  ("groupID" in the PTX fragment figures)
    int thread_id = lane_id & 3;   // 0..3  ("threadID_in_group")
    int c_row = by * BM + warp_m * WM_TILE;
    int c_col = bx * BN + warp_n * WN_TILE;

    // Shared-memory tiles. As holds A after the fp8->bf16 upcast.
    __shared__ __nv_bfloat16 As[BM][BK];
    __shared__ __nv_bfloat16 Bs[BN][BK];

    // 2 x 4 MMA sites x 4 f32 each = 32 accumulators per thread.
    float accum[WM_MMA][WN_MMA][4] = {};

    for (int k_block = 0; k_block < K; k_block += BK) {
        // Stage A cooperatively, converting fp8 -> bf16; zero-pad past edges
        // so the MMAs never read garbage on ragged M/K tails.
        for (int i = tid; i < BM * BK; i += blockDim.x) {
            int row = i / BK;
            int col = i % BK;
            int g_row = by * BM + row;
            int g_col = k_block + col;
            As[row][col] = (g_row < M && g_col < K)
                ? __float2bfloat16(float(A_in[(size_t)g_row * K + g_col]))
                : __float2bfloat16(0.0f);
        }
        // Stage B (already bf16), same zero-padding policy.
        for (int i = tid; i < BN * BK; i += blockDim.x) {
            int n_idx = i / BK;
            int k_idx = i % BK;
            int g_n = bx * BN + n_idx;
            int g_k = k_block + k_idx;
            Bs[n_idx][k_idx] = (g_n < N && g_k < K)
                ? B_in[(size_t)g_n * K + g_k]
                : __float2bfloat16(0.0f);
        }
        __syncthreads();

        #pragma unroll
        for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
            int a_row = warp_m * WM_TILE + m_idx * 16;
            // m16n8k16 A fragment: 4 x .b32, 2 bf16 per register.
            //   reg0: (groupID,     threadID*2 + {0,1})
            //   reg1: (groupID + 8, threadID*2 + {0,1})
            //   reg2: (groupID,     threadID*2 + 8 + {0,1})
            //   reg3: (groupID + 8, threadID*2 + 8 + {0,1})
            // thread_id*2 is even, so each uint32 load is 4-byte aligned.
            uint32_t a0 = *reinterpret_cast<const uint32_t*>(
                &As[a_row + group_id    ][thread_id * 2    ]);
            uint32_t a1 = *reinterpret_cast<const uint32_t*>(
                &As[a_row + group_id + 8][thread_id * 2    ]);
            uint32_t a2 = *reinterpret_cast<const uint32_t*>(
                &As[a_row + group_id    ][thread_id * 2 + 8]);
            uint32_t a3 = *reinterpret_cast<const uint32_t*>(
                &As[a_row + group_id + 8][thread_id * 2 + 8]);
            #pragma unroll
            for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
                int b_n = warp_n * WN_TILE + n_idx * 8 + group_id;
                // m16n8k16 B fragment: 2 x .b32. B is consumed col-major, and
                // element (k, n) lives at Bs[n][k]:
                //   reg0: (threadID*2 + {0,1},     groupID)
                //   reg1: (threadID*2 + 8 + {0,1}, groupID)
                uint32_t b0 = *reinterpret_cast<const uint32_t*>(
                    &Bs[b_n][thread_id * 2    ]);
                uint32_t b1 = *reinterpret_cast<const uint32_t*>(
                    &Bs[b_n][thread_id * 2 + 8]);
                asm volatile(
                    "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                    "{%0, %1, %2, %3}, "
                    "{%4, %5, %6, %7}, "
                    "{%8, %9}, "
                    "{%10, %11, %12, %13};"
                    : "=f"(accum[m_idx][n_idx][0]),
                      "=f"(accum[m_idx][n_idx][1]),
                      "=f"(accum[m_idx][n_idx][2]),
                      "=f"(accum[m_idx][n_idx][3])
                    : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
                      "r"(b0), "r"(b1),
                      "f"(accum[m_idx][n_idx][0]),
                      "f"(accum[m_idx][n_idx][1]),
                      "f"(accum[m_idx][n_idx][2]),
                      "f"(accum[m_idx][n_idx][3])
                );
            }
        }
        __syncthreads();
    }

    // Store bf16 output. Accumulator layout per thread:
    //   c0: (groupID,     threadID*2)      c1: (groupID,     threadID*2 + 1)
    //   c2: (groupID + 8, threadID*2)      c3: (groupID + 8, threadID*2 + 1)
    #pragma unroll
    for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
        #pragma unroll
        for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
            int r0 = c_row + m_idx * 16 + group_id;
            int r1 = r0 + 8;
            int c0 = c_col + n_idx * 8 + thread_id * 2;
            int c1 = c0 + 1;
            if (r0 < M) {
                if (c0 < N) C[(size_t)r0 * N + c0] = __float2bfloat16(accum[m_idx][n_idx][0]);
                if (c1 < N) C[(size_t)r0 * N + c1] = __float2bfloat16(accum[m_idx][n_idx][1]);
            }
            if (r1 < M) {
                if (c0 < N) C[(size_t)r1 * N + c0] = __float2bfloat16(accum[m_idx][n_idx][2]);
                if (c1 < N) C[(size_t)r1 * N + c1] = __float2bfloat16(accum[m_idx][n_idx][3]);
            }
        }
    }
}

// Host launcher: validates inputs, allocates the bf16 output, and launches the
// kernel on PyTorch's current CUDA stream (stream 0 would break stream capture
// and benchmark timing that uses non-default streams).
torch::Tensor fp8_gemm_launch(torch::Tensor x, torch::Tensor w_bf16) {
    TORCH_CHECK(x.is_cuda() && w_bf16.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && w_bf16.is_contiguous(), "inputs must be contiguous");
    TORCH_CHECK(x.scalar_type() == at::kFloat8_e4m3fn, "x must be fp8_e4m3fn");
    TORCH_CHECK(w_bf16.scalar_type() == at::kBFloat16, "w must be bf16");
    int M = x.size(0);
    int K_x = x.size(1);
    int N = w_bf16.size(0);
    int K_w = w_bf16.size(1);
    TORCH_CHECK(K_x == K_w, "K dimension mismatch");
    auto output = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(x.device()));
    dim3 block(128);  // 4 warps, matching the 2x2 warp layout
    dim3 grid(
        (N + BN - 1) / BN,
        (M + BM - 1) / BM
    );
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    fp8_gemm_bf16_kernel<<<grid, block, 0, stream>>>(
        reinterpret_cast<const __nv_fp8_e4m3*>(x.const_data_ptr()),
        reinterpret_cast<const __nv_bfloat16*>(w_bf16.const_data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
        M, N, K_x
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface launch-config errors immediately
    return output;
}
"""
# ---------------------------------------------------------------------------
# JIT compilation
# ---------------------------------------------------------------------------
# Lazily-populated cache for the compiled extension and its GEMM entry point;
# filled in by _ensure_module() on first forward() call.
_module = None
_fp8_gemm_fn = None
def _ensure_module():
    """JIT-compile the inline CUDA extension once and cache its entry point.

    Idempotent: subsequent calls are no-ops after the first successful build.
    Populates the module-level `_module` and `_fp8_gemm_fn` caches.
    """
    global _module, _fp8_gemm_fn
    if _module is None:
        compiled = load_inline(
            name="fp8_gemm_bf16_kernel_v2",
            cpp_sources="torch::Tensor fp8_gemm_launch(torch::Tensor x, torch::Tensor w_bf16);",
            cuda_sources=[CUDA_SRC],
            functions=["fp8_gemm_launch"],
            extra_cflags=["-O3", "-std=c++17"],
            # sm_120a: RTX PRO 6000 Blackwell workstation target.
            extra_cuda_cflags=["-O3", "--use_fast_math", "-arch=sm_120a", "-std=c++17"],
            verbose=False,
        )
        _module = compiled
        _fp8_gemm_fn = compiled.fp8_gemm_launch
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
    """y = (x @ w.T).to(bf16) — fp8 activation, bf16 weight, bf16 output."""

    def __init__(self, M: int, N: int, K: int):
        super().__init__()
        self.M, self.N, self.K = M, N, K
        # Weight matches the reference layout: (N, K) in bf16.
        self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Build the extension lazily so import of this file stays cheap.
        _ensure_module()
        # x arrives as fp8_e4m3; the kernel performs the bf16 upcast itself.
        return _fp8_gemm_fn(x.contiguous(), self.weight.contiguous())
M = 4096
N = 4096
K = 4096
def get_inputs():
x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
return [x]
def get_init_inputs():
return [M, N, K]
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 255.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 254, in forward
_ensure_module()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 223, in _ensure_module
_module = load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2268, in _jit_compile
_write_ninja_file_and_build_library(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2429, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2791, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'fp8_gemm_bf16_kernel_v2': [1/2] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_bf16_kernel_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -arch=sm_120a -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_bf16_kernel_v2/cuda.cu -o cuda.cuda.o
FAILED: [code=255] cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_bf16_kernel_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -arch=sm_120a -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_bf16_kernel_v2/cuda.cu -o cuda.cuda.o
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 706; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 706; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 710; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 710; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 714; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 714; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 718; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 718; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 724; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 724; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 728; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 728; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 732; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 732; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 736; error : Argument vector size mismatch for instruction 'mma'
ptxas /tmp/tmpxft_00102670_00000000-7_cuda.compute_120.ptx, line 736; error : Argument vector size mismatch for instruction 'mma'
ptxas fatal : Ptx assembly aborted due to errors
ninja: build stopped: subcommand failed.
{
"problem": "01_fp8_gemm",
"harness": "opencode",
"model": "deepseek/deepseek-v4-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 75814, "output_tokens": 28539, "cache_read_tokens": 3605120, "cache_creation_tokens": 0, "reasoning_tokens": 52812, "total_cost_usd": null}
}
{
"command": "ls -la",
"description": "List current directory"
}total 40
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 27 05:14 .
drwxrwxr-x 10 infatoshi infatoshi 4096 Apr 24 00:30 ..
-rw-rw-r-- 1 infatoshi infatoshi 4228 Apr 26 03:11 benchmark.py
-rw-rw-r-- 1 infatoshi infatoshi 3805 Apr 24 00:37 check.py
-rw-rw-r-- 1 infatoshi infatoshi 1128 Apr 24 00:36 problem.yaml
-rw-rw-r-- 1 infatoshi infatoshi 1850 Apr 27 03:21 PROMPT.txt
-rw-rw-r-- 1 infatoshi infatoshi 1401 Apr 24 00:40 reference.py
-rw-rw-r-- 1 infatoshi infatoshi 564 Apr 24 00:36 shapes.py
-rw-rw-r-- 1 infatoshi infatoshi 1725 Apr 24 00:37 sota.py
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive FP8 e4m3 GEMM reference (correctness only, NOT the SOTA baseline).
2:
3: We cast inputs to bf16 and use torch.matmul. The agent's solution must match
4: this numerically within the fp8 tolerance declared in problem.yaml.
5: """
6: import torch
7: import torch.nn as nn
8:
9: OP_TYPE = "gemm"
10: SUPPORTED_PRECISIONS = ["fp8_e4m3"]
11: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
12:
13:
14: class Model(nn.Module):
15: """y = (x @ w.T).to(bf16), where x is fp8_e4m3 (M, K), w is fp8_e4m3 (N, K)."""
16:
17: def __init__(self, M: int, N: int, K: int):
18: super().__init__()
19: self.M, self.N, self.K = M, N, K
20: # Weights stored as parameters so state_dict is well-defined.
21: # We initialize in bf16 then cast; the fp8 dtype is set by get_inputs.
22: self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
23: nn.init.normal_(self.weight, std=0.02)
24:
25: def forward(self, x: torch.Tensor) -> torch.Tensor:
26: # Upcast to bf16 for the naive reference; the kernel equivalent would
27: # use mma.sync f8f6f4 kind directly.
28: x_bf = x.to(torch.bfloat16)
29: w_bf = self.weight.to(torch.bfloat16)
30: return x_bf @ w_bf.T # (M, N) bf16
31:
32:
33: M = 4096
34: N = 4096
35: K = 4096
36:
37:
38: def get_inputs():
39: # fp8_e4m3 inp
... (227 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formula
... (3416 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape c
... (2907 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for FP8 GEMM.
2:
3: Mix of:
4: - square aligned (the easy case)
5: - off-alignment K (common real-world failure mode for tile-quantized kernels)
6: - skinny (decode-like, memory-bound)
7: - rectangular (prefill with grouped attention)
8: """
9:
10: SHAPES = [
11: {"M": 4096, "N": 4096, "K": 4096}, # square aligned
12: {"M": 4096, "N": 4096, "K": 4127}, # K not multiple of 128 -> forces predicated tails
13: {"M": 32, "N": 8192, "K": 8192}, # skinny M (decode)
14: {"M": 4096, "N": 14336, "K": 4096}, # Llama3 up-proj shape
15: ]
(End of file - total 15 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 01_fp8_gemm
2: display_name: "FP8 e4m3 GEMM"
3: precision: fp8_e4m3
4: regime: compute
5:
6: # Dense-equivalent FLOPs per invocation. Evaluated per-shape; the benchmark
7: # substitutes M/N/K from each shapes.py entry.
8: flops_formula: "2 * M * N * K"
9:
10: # Bytes moved (approximate): fp8 inputs in, bf16 output.
11: bytes_formula: "M*K + K*N + M*N*2"
12:
13: hardware: [RTX_PRO_6000]
14: peak_tflops_key: fp8 # index into src/hardware/rtx_pro_6000.py peak_tflops_dense
15: peak_bandwidth_key: dram
16:
17: # Per-dtype tolerance override (optional). Defaults are in src/eval/correctness.py.
18: tolerance:
19: fp8_e4m3fn: 0.15
20:
21: # Forbidden ops — using any of these in solution.py fails correctness post-hoc.
22: # torch._scaled_mm is Stanford-style cheating (just call the vendor library).
23: forbidden:
24: - "torch._scaled_mm"
25: - "torch.ops.aten._scaled_mm"
26:
27: sota:
28: name: "FlashInfer FP8 GEMM"
29: url: "https://github.com/flashinfer-ai/flashinfer"
30: function: "flashinfer.gemm.fp8_gemm"
31: deps:
32: - "flashinfer>=0.6.8"
33: # Documented H100 throughput for this shape (informational, not graded):
34: reference_throughput_tflops_h100: 550
35:
36: num_correct_trials: 3
37: num_perf_trials: 30
(End of file - total 37 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for FP8 GEMM: flashinfer.gemm.fp8_gemm.
2:
3: If flashinfer is not installed or the SM120 path isn't supported, this falls
4: back to torch._scaled_mm which is the cuBLAS FP8 path. The benchmark treats
5: whichever succeeds as the SOTA reference line.
6:
7: Agents are FORBIDDEN from using torch._scaled_mm in their solution (see
8: problem.yaml.forbidden). This file is only for the benchmark's reference line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_flashinfer(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor | None:
16: try:
17: import flashinfer # noqa: F401
18: # Note: flashinfer's FP8 GEMM API surface may differ; adapt if needed.
19: # Placeholder call — replace with the actual flashinfer entry point
20: # once validated on SM120.
21: return None
22: except ImportError:
23: return None
24:
25:
26: def _scaled_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
27: # torch._scaled_mm wants per-tensor scales. Use unit scales for the reference.
28: scale_a = torch.tensor(1.0, device=x.device)
29: scale_b = torch.tensor(1.0, device=x.device)
30: out = torch._scaled_mm(
31: x,
32: w.T,
33: scale_a=scale_a,
34: scale_b=scale_b,
35: out_dtype=torch.bfloat16,
36: )
37: return out
... (576 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize an FP8 GEMM for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The input is an fp8_e4m3 activation tensor (M, K) and an fp8_e4m3 weight (N, K), and you return y = x @ w.T as bf16 (M, N). Correctness tolerance on the bf16 output is 0.15 abs/rel — fp8 multiply has real noise so the bar is loose, but it has to match the reference within that. The shapes you have to handle are M=4096 N=4096 K=4096 (square aligned, the easy case), M=4096 N=4096 K=4127 (K not a multiple of 128, which forces predicated tails on tile-quantized kernels), M=32 N=8192 K=8192 (skinny M, decode-style), and M=4096 N=14336 K=4096 (Llama-3 up-proj).
4:
5: This needs to be a real custom kernel. Don't reach for torch._scaled_mm or torch.ops.aten._scaled_mm — they're off-limits and using them fails correctness. Try CUDA C++ via torch.utils.cpp_extension.load_inline, inline PTX, or CUTLASS — whatever fits. Anything you're uncertain about, look up PTX docs, clone CUTLASS or other reference repos, read library headers, and investiga
... (509 more chars){
"command": "cat /home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py 2>/dev/null || echo \"Not found\"",
"description": "Check RTX PRO 6000 hardware specs"
}"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
name: str
sm: str
vram_gb: int
peak_bandwidth_gb_s: float # DRAM
peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
"fp4": 800.0,
"nvfp4": 800.0,
"mxfp4": 800.0,
"fp6": 800.0,
"fp8": 400.0,
"bf16": 200.0,
"fp16": 200.0,
"tf32": 100.0,
"fp32": 12.0, # non-tensor-core SIMT fp32
"int8": 400.0,
"int4": 800.0,
},
)
{
"command": "nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03, 12.0
{
"pattern": "**/src/eval/correctness.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py{
"description": "Check CUDA version and nvcc",
"command": "nvcc --version 2>&1 | head -5; echo \"---\"; ls /usr/local/cuda*/include/cuda_fp8.h 2>/dev/null; echo \"---\"; python -c \"import torch; print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}')\""
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
---
/usr/local/cuda-12.8/include/cuda_fp8.h
/usr/local/cuda-13.2/include/cuda_fp8.h
/usr/local/cuda-13/include/cuda_fp8.h
/usr/local/cuda/include/cuda_fp8.h
---
PyTorch 2.11.0+cu130, CUDA 13.0
{
"description": "Check for MMA-related headers and FP8 intrinsics",
"command": "ls /usr/local/cuda-12.8/include/cuda_fp8.h 2>/dev/null && head -100 /usr/local/cuda-12.8/include/cuda_fp8.h"
}/usr/local/cuda-12.8/include/cuda_fp8.h
/*
* Copyright 2022-2024 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
*
... (2842 more chars){
"description": "Look for MMA FP8 related intrinsics",
"command": "grep -r \"m16n8k32\\|f8e4m3\\|tcgen05\" /usr/local/cuda-13.2/include/ 2>/dev/null | head -30"
}/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_alloc.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_commit.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_cp.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_fence.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_ld.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_mma.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_mma_ws.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_shift.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_st.h>
/usr/local/cuda-13.2/include/cccl/cuda/ptx:#include <cuda/__ptx/instructions/tcgen05_wait.h>
/usr/local/cuda-13.2/include/cccl/cuda/__ptx/instructions/tcgen05_mma_ws.h:#include <cuda/__ptx/instructions/generated/tcgen05_mma_ws.h>
/usr/local/cuda-13.2/include/cccl/cuda/__ptx/instructions/tcgen05_mma.h:#include <cuda/__ptx/instructions/generated/tcgen05_mma.h>
/usr/local/cuda-13.2/include/cccl/cuda/__ptx/instructions/tcgen05_alloc.h:#include <cuda/__ptx/instructions/generated/tcgen05_alloc.h>
/usr/local/cuda-13.2/include/cccl/cuda/__ptx/instructions/tcgen05_shift.h:#include <cuda/__ptx/instructions/generated/tcgen05_shift.h>
/usr/local/cuda-13.2/include/ccc
... (2919 more chars){
"description": "Check PTX ISA version and FP8 MMA support",
"command": "ls /usr/local/cuda-13.2/nvvm/libdevice/ 2>/dev/null | head -10; nvcc --ptxas-options --help 2>&1 | head -5"
}libdevice.10.bc
nvcc fatal : No input files specified; use option --help for more information
{
"description": "Look at tcgen05 MMA instructions for FP8",
"command": "grep -i \"f8\\|m16n8k32\\|mma.*cta_group\" /usr/local/cuda-13.2/include/cccl/cuda/__ptx/instructions/generated/tcgen05_mma.h 2>/dev/null | head -60"
}// tcgen05.mma.cta_group.kind [d_tmem], a_desc, b_desc, idesc, disable_output_lane, enable_input_d, scale_input_d; //
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d, %9;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d, %9;\n\t"
// tcgen05.mma.cta_group.kind [d_tmem], a_desc, b_desc, idesc, disable_output_lane, enable_input_d, scale_input_d; //
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%4, %5, %6, %7, %8, %9, %10, %11}, PRED_enable_input_d, "
"tcgen05.mma.cta_group::2.kind::tf32 [%0], %1, %2, %3, {%4, %5, %6, %7, %8, %9, %10, %11}, PRED_enable_input_d, "
// tcgen05.mma.cta_group.kind [d_tmem], a_desc, b_desc, idesc, disable_output_lane, enable_input_d; // PTX ISA 86,
// .kind = { .kind::f16, .kind::tf32, .kind::f8f6f4, .kind::i8 }
static_assert(__kind == kind_f16 || __kind == kind_tf32 || __kind == kind_f8f6f4 || __kind == kind_i8, "");
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d;\n\t"
else if constexpr (__kind == kind_f8f6f4)
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%4, %5, %6, %7}, PRED_enable_input_d;\n\t"
// tcgen05.mma.cta_group.kind [d_tmem], a_desc, b_
... (4324 more chars){
"description": "Check if there are older mma PTX instruction intrinsics for FP8",
"command": "find /usr/local/cuda-13.2/include -name \"*.h\" -exec grep -l \"f8e4m3\\|mma.*m16n8k32\" {} \\; 2>/dev/null | head -20"
}(no output){
"description": "Search for mma instruction PTX intrinsics",
"command": "find /usr/local/cuda-13.2/include -path \"*/instructions/generated/mma.h\" 2>/dev/null"
}(no output){
"description": "Check for mma intrinsics with FP8",
"command": "find /usr/local/cuda-13.2/include -name \"mma.h\" -path \"*/generated/*\" 2>/dev/null"
}(no output){
"description": "Find all mma-related headers",
"command": "find /usr/local/cuda-13.2/include -name \"*mma*\" -o -name \"*mma*\" 2>/dev/null | head -20; find /usr/local/cuda-13.2/include/cccl -name \"*.h\" | xargs grep -l \"mma.sync\" 2>/dev/null | head -10"
}(no output){
"url": "https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma",
"format": "text",
"timeout": 30
}1. Introduction — PTX ISA 9.2 documentation
1. Introduction
1.1. Scalable Data-Parallel Computing using GPUs
1.2. Goals of PTX
1.3. PTX ISA Version 9.2
1.4. Document Structure
2. Programming Model
2.1. A Highly Multithreaded Coprocessor
2.2. Thread Hierarchy
2.2.1. Cooperative Thread Arrays
2.2.2. Cluster of Cooperative Thread Arrays
2.2.3. Grid of Clusters
2.3. Memory Hierarchy
3. PTX Machine Model
3.1. A Set of SIMT Multiprocessors
3.2. Independent Thread Scheduling
3.3. On-chip Shared Memory
4. Syntax
4.1. Source Format
4.2. Comments
4.3. Statements
4.3.1. Directive Statements
4.3.2. Instruction Statements
4.4. Identifiers
4.5. Constants
4.5.1. Integer Constants
4.5.2. Floating-Point Constants
4.5.3. Predicate Constants
4.5.4. Constant Expressions
4.5.5. Integer Constant Expression Evaluation
4.5.6. Summary of Constant Expression Evaluation Rules
5. State Spaces, Types, and Variables
5.1. State Spaces
5.1.1. Register State Space
5.1.2. Special Register State Space
5.1.3. Constant State Space
5.1.3.1. Banked Constant State Space (deprecated)
5.1.4. Global State Space
5.1.5. Local State Space
5.1.6. Parameter State Space
5.1.6.1. Kernel Function Parameters
5.1.6.2. Kernel Function Parameter Attributes
5.1.6.3. Kernel Parameter Attribute: .ptr
5.1.6.4. Device Function Parameters
5.1.
... (49932 more chars){
"description": "Find mma m16n8k32 fragment layout",
"prompt": "Search the file /home/infatoshi/.local/share/opencode/tool-output/tool_dceaa8f6e001vznYo1n4HSy7Ne for the section about \"Matrix Fragments for mma.m16n8k32\" \u2014 this is section 9.7.14.5.10 in the PTX ISA docs. Read the entire section about the fragment layout for m16n8k32, including the table showing which thread gets which elements. Return the full content of that section including the table.",
"subagent_type": "general"
}task_id: ses_231551b53ffe7k6sHQiBLESw1g (for resuming to continue this task if needed)
<task_result>
Here is the complete content of **Section 9.7.14.5.10 — Matrix Fragments for mma.m16n8k32**, from line 27424 through 27713:
---
```
9.7.14.5.10. Matrix Fragments for mma.m16n8k32
A warp executing mma.m16n8k32 will compute an MMA operation of shape .m16n8k32.
Elements of the matrix are distributed across the threads in a warp so each thread of the warp holds
a fragment of the matrix.
Multiplicand A:
.s4 or .u4 :
.atype Fragment Elements (low to high)
.s4 / .u4 A vector expression containing two .b32 registers, with each a0, a1, …, a14, a15
register containing eight .u4 / .s4 elements from the
matrix A.
The layout of the fragments held by different threads is shown in Figure 87.
Figure 87 MMA .m16n8k32 fragment layout for matrix A with .u4 / .s4 type.
The row and column of a matrix fragment can be computed as:
groupID = %laneid >> 2
threadID_in_group = %laneid % 4
row = groupID for ai where i < 8
groupID + 8 for ai where i >= 8
col = (threadID_in_group * 8) + (i & 0x7) for ai where i = {0,..,15}
.s8 or .u8 or .e4m3 or .e5m2 or .e3m2 or .e2m3 or .e2m1:
.atype Fragment
... (5073 more chars){
"command": "nvidia-smi --query-gpu=name,compute_cap,gpu_bus_id,memory.total --format=csv,noheader && python3 -c \"import torch; props = torch.cuda.get_device_properties(0); print(f'SMs: {props.multi_processor_count}, Max threads/SM: {props.max_threads_per_multi_processor}, Shared mem/SM: {props.shared_mem_per_multiprocessor/1024:.0f} KB, Max shared mem/block: {props.shared_mem_per_block_optin/1024:.0f} KB, Registers/SM: {props.regs_per_multiprocessor}')\"",
"description": "Get SM count and GPU properties"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 00000000:01:00.0, 97887 MiB
Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_mem_per_multiprocessor'. Did you mean: 'shared_memory_per_multiprocessor'?
{
"command": "<578 chars \u2014 see below>",
"description": "Get detailed GPU properties"
}python3 -c "
import torch
props = torch.cuda.get_device_properties(0)
print(f'SMs: {props.multi_processor_count}')
print(f'Max threads per SM: {props.max_threads_per_multi_processor}')
print(f'Shared mem per SM: {props.shared_memory_per_multiprocessor/1024:.0f} KB')
print(f'Shared mem per block opt-in: {props.shared_memory_per_block_optin/1024:.0f} KB')
print(f'Registers per SM: {props.regs_per_multiprocessor}')
print(f'Max threads per block: {props.max_threads_per_block}')
print(f'Warp size: {props.warp_size}')
print(f'L2 cache: {props.l2_cache_size/1024/1024:.1f} MB')
"Traceback (most recent call last):
File "<string>", line 11, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'l2_cache_size'. Did you mean: 'L2_cache_size'?
SMs: 188
Max threads per SM: 1536
Shared mem per SM: 100 KB
Shared mem per block opt-in: 99 KB
Registers per SM: 65536
Max threads per block: 1024
Warp size: 32
{
"todos": [
{
"content": "Write CUDA FP8 GEMM kernel in solution.py using mma.sync PTX",
"status": "in_progress",
"priority": "high"
},
{
"content": "Verify correctness with python check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize kernel performance",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to measure TFLOPS",
"status": "pending",
"priority": "medium"
}
]
}content: 'Write CUDA FP8 GEMM kernel in solution.py using mma.sync PTX'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with python check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to measure TFLOPS'
status: 'pending'
priority: 'medium'{
"command": "ls -la solution.py 2>/dev/null; cat solution.py 2>/dev/null || echo \"No solution.py yet\"",
"description": "Check if solution.py exists"
}No solution.py yet
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -0,0 +1,271 @@
+"""Custom FP8 e4m3 GEMM kernel using inline PTX mma.sync on SM120 Blackwell.
+
+Uses mma.sync.aligned.m16n8k32.row.col.f32.f8e4m3.f8e4m3.f32 to hit FP8 tensor cores
+directly. Tiling: BM=64, BN=64, BK=32 with 128-thread blocks (4 warps).
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+OP_TYPE = "gemm"
+SUPPORTED_PRECISIONS = ["fp8_e4m3"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+# ---------------------------------------------------------------------------
+# CUDA source
+# ---------------------------------------------------------------------------
+
+CUDA_SRC = r"""
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define WM 32
+#define WN 32
+#define WM_MMA 2 // 32/16 = 2 MMAs in M
+#define WN_MMA 4 // 32/8 = 4 MMAs in N
+
+__global__ void fp8_gemm_kernel(
+ const __nv_fp8_e4m3* __restrict__ A,
+ const __nv_fp8_e4m3* __restrict__ B,
+ __nv_bfloat16* __restrict__ C,
+ int M, int N, int K)
+{
+ // Block indices
+ int bx = blockIdx.x; // N-dimension tile index
+ int by = blockIdx.y; // M-dimension tile index
+
+ // Warp layout: 2 warps in M (warps 0-1, 2-3), 2 warps in N (warps 0,2 and 1,3)
+ int tid = threadIdx.x;
+ int warp_id = tid / 32;
+ int lane_id = tid % 32;
+ int warp_m = warp_id / 2; // 0 or 1
+ int warp_n = warp_id % 2; // 0 or 1
+
+ int group_id = lane_id >> 2; // 0..7
+ int thread_id = lane_id & 3; // 0..3
+
+ // Global output position for this thread block
+ int c_row = by * BM + warp_m * WM;
+ int c_col = bx * BN + warp_n * WN;
+
+ // Shared memory
+ __shared__ __nv_fp8_e4m3 As[BM][BK];
+ __shared__ __nv_fp8_e4m3 Bs[BN][BK]; // stored as w: [n_tile][k_tile]
+
+ // Accumulators: 2 MMAs in M x 4 MMAs in N = 8 MMA sets x 4 f32 = 32 f32
+ float accum[WM_MMA][WN_MMA][4] = {0.0f};
+
+ // K loop
+ for (int k_block = 0; k_block < K; k_block += BK) {
+ // Cooperative load A into shared memory
+ // Each thread loads 4 elements
+ for (int i = tid; i < BM * BK; i += blockDim.x) {
+ int row = i / BK;
+ int col = i % BK;
+ int global_row = by * BM + row;
+ int global_col = k_block + col;
+ if (global_row < M && global_col < K) {
+ As[row][col] = A[global_row * K + global_col];
+ } else {
+ As[row][col] = __nv_fp8_e4m3(0);
+ }
+ }
+
+ // Cooperative load B into shared memory
+ // B is stored as w[N][K], treated as KxN col-major for MMA
+ for (int i = tid; i < BN * BK; i += blockDim.x) {
+ int n_idx = i / BK;
+ int k_idx = i % BK;
+ int global_n = bx * BN + n_idx;
+ int global_k = k_block + k_idx;
+ if (global_n < N && global_k < K) {
+ Bs[n_idx][k_idx] = B[global_n * K + global_k];
+ } else {
+ Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
+ }
+ }
+
+ __syncthreads();
+
+ // Process MMAs for this warp
+ #pragma unroll
+ for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
+ int a_row_base = warp_m * WM + m_idx * 16;
+
+ // A fragment: 4 registers, each holding 4 fp8 values
+ // a0..a3: row=group_id, col=thread_id*4 + 0..3
+ // a4..a7: row=group_id+8, col=thread_id*4 + 0..3
+ // a8..a11: row=group_id+8, col=thread_id*4 + 16..19
+ // a12..a15:row=group_id, col=thread_id*4 + 16..19
+ uint32_t a_reg[4];
+ a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
+ a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
+ a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
+ a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
+
+ #pragma unroll
+ for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
+ int b_n_base = warp_n * WN + n_idx * 8;
+
+ // B fragment: 2 registers, each holding 4 fp8 values
+ // b0..b3: row=thread_id*4+0..3, col=group_id
+ // b4..b7: row=thread_id*4+16..19, col=group_id
+ // In w (NxK row-major), this is w[group_id][thread_id*4 + offset]
+ uint32_t b_reg[2];
+ b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4]));
+ b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4 + 16]));
+
+ // MMA instruction
+ asm volatile(
+ "mma.sync.aligned.m16n8k32.row.col.f32.f8e4m3.f8e4m3.f32 "
+ "{%0, %1, %2, %3}, "
+ "{%4, %5, %6, %7}, "
+ "{%8, %9}, "
+ "{%10, %11, %12, %13};"
+ : "=f"(accum[m_idx][n_idx][0]),
+ "=f"(accum[m_idx][n_idx][1]),
+ "=f"(accum[m_idx][n_idx][2]),
+ "=f"(accum[m_idx][n_idx][3])
+ : "r"(a_reg[0]), "r"(a_reg[1]), "r"(a_reg[2]), "r"(a_reg[3]),
+ "r"(b_reg[0]), "r"(b_reg[1]),
+ "f"(accum[m_idx][n_idx][0]),
+ "f"(accum[m_idx][n_idx][1]),
+ "f"(accum[m_idx][n_idx][2]),
+ "f"(accum[m_idx][n_idx][3])
+ );
+ }
+ }
+
+ __syncthreads();
+ }
+
+ // Store output: convert f32 accumulators to bf16
+ #pragma unroll
+
... (3785 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"content": "<9405 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run check.py for correctness",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 255.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 254, in forward
_ensure_module()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
... (6895 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -125,7 +125,7 @@
// MMA instruction
asm volatile(
- "mma.sync.aligned.m16n8k32.row.col.f32.f8e4m3.f8e4m3.f32 "
+ "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<830 chars \u2014 see below>",
"newString": "<826 chars \u2014 see below>"
} asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.f8e4m3.f8e4m3.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
: "=f"(accum[m_idx][n_idx][0]),
"=f"(accum[m_idx][n_idx][1]),
"=f"(accum[m_idx][n_idx][2]),
"=f"(accum[m_idx][n_idx][3])
: "r"(a_reg[0]), "r"(a_reg[1]), "r"(a_reg[2]), "r"(a_reg[3]),
"r"(b_reg[0]), "r"(b_reg[1]),
"f"(accum[m_idx][n_idx][0]),
"f"(accum[m_idx][n_idx][1]),
"f"(accum[m_idx][n_idx][2]),
"f"(accum[m_idx][n_idx][3])
); asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
: "=f"(accum[m_idx][n_idx][0]),
"=f"(accum[m_idx][n_idx][1]),
"=f"(accum[m_idx][n_idx][2]),
"=f"(accum[m_idx][n_idx][3])
: "r"(a_reg[0]), "r"(a_reg[1]), "r"(a_reg[2]), "r"(a_reg[3]),
"r"(b_reg[0]), "r"(b_reg[1]),
"f"(accum[m_idx][n_idx][0]),
"f"(accum[m_idx][n_idx][1]),
"f"(accum[m_idx][n_idx][2]),
"f"(accum[m_idx][n_idx][3])
);Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -216,7 +216,8 @@
if _module is not None:
return
- extra_cflags = [
+ extra_cflags = ["-O3", "-std=c++17"]
+ extra_cuda_cflags = [
"-O3",
"--use_fast_math",
"-arch=sm_120a",
@@ -229,9 +230,7 @@
cuda_sources=[CUDA_SRC],
functions=["fp8_gemm_launch"],
extra_cflags=extra_cflags,
- extra_cuda_cflags=extra_cflags + [
- "--ptxas-options=-v",
- ],
+ extra_cuda_cflags=extra_cuda_cflags,
verbose=False,
)
_fp8_gemm_fn = _module.fp8_gemm_launch
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<487 chars \u2014 see below>",
"newString": "<490 chars \u2014 see below>"
} extra_cflags = [
"-O3",
"--use_fast_math",
"-arch=sm_120a",
"-std=c++17",
]
_module = load_inline(
name="fp8_gemm_kernel",
cpp_sources="torch::Tensor fp8_gemm_launch(torch::Tensor x, torch::Tensor w);",
cuda_sources=[CUDA_SRC],
functions=["fp8_gemm_launch"],
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cflags + [
"--ptxas-options=-v",
],
verbose=False,
) extra_cflags = ["-O3", "-std=c++17"]
extra_cuda_cflags = [
"-O3",
"--use_fast_math",
"-arch=sm_120a",
"-std=c++17",
]
_module = load_inline(
name="fp8_gemm_kernel",
cpp_sources="torch::Tensor fp8_gemm_launch(torch::Tensor x, torch::Tensor w);",
cuda_sources=[CUDA_SRC],
functions=["fp8_gemm_launch"],
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
verbose=False,
)Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run check.py after fixes",
"timeout": 120000
}FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=15.9062 (atol=0.01, rtol=0.01)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -101,15 +101,15 @@
int a_row_base = warp_m * WM + m_idx * 16;
// A fragment: 4 registers, each holding 4 fp8 values
- // a0..a3: row=group_id, col=thread_id*4 + 0..3
- // a4..a7: row=group_id+8, col=thread_id*4 + 0..3
- // a8..a11: row=group_id+8, col=thread_id*4 + 16..19
- // a12..a15:row=group_id, col=thread_id*4 + 16..19
+ // a0..a3: row=group_id, col=thread_id*4 + 0..3
+ // a4..a7: row=group_id+8, col=thread_id*4 + 0..3
+ // a8..a11: row=group_id, col=thread_id*4 + 16..19
+ // a12..a15: row=group_id+8, col=thread_id*4 + 16..19
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
- a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
- a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
+ a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
+ a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<790 chars \u2014 see below>",
"newString": "<794 chars \u2014 see below>"
} // A fragment: 4 registers, each holding 4 fp8 values
// a0..a3: row=group_id, col=thread_id*4 + 0..3
// a4..a7: row=group_id+8, col=thread_id*4 + 0..3
// a8..a11: row=group_id+8, col=thread_id*4 + 16..19
// a12..a15:row=group_id, col=thread_id*4 + 16..19
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16])); // A fragment: 4 registers, each holding 4 fp8 values
// a0..a3: row=group_id, col=thread_id*4 + 0..3
// a4..a7: row=group_id+8, col=thread_id*4 + 0..3
// a8..a11: row=group_id, col=thread_id*4 + 16..19
// a12..a15: row=group_id+8, col=thread_id*4 + 16..19
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run check.py after fixing A fragment",
"timeout": 120000
}FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=0.53125 (atol=0.01, rtol=0.01)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"command": "<562 chars \u2014 see below>",
"description": "Check if fp8 tensor ops work on this GPU"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
# Quick test: does torch._scaled_mm work on this GPU?
x = (torch.rand(64, 32, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = (torch.rand(64, 32, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
try:
y = torch._scaled_mm(x, w.T, scale_a=torch.tensor(1.0, device='cuda'), scale_b=torch.tensor(1.0, device='cuda'), out_dtype=torch.bfloat16)
print('scaled_mm works, shape:', y.shape, 'dtype:', y.dtype)
except Exception as e:
print(f'scaled_mm failed: {e}')
"scaled_mm works, shape: torch.Size([64, 64]) dtype: torch.bfloat16
{
"command": "<2903 chars \u2014 see below>",
"description": "Test MMA instruction compilation",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = '''
#include <cuda_fp8.h>
__global__ void test_kernel(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, float* C) {
int lane_id = threadIdx.x % 32;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
__shared__ __nv_fp8_e4m3 As[16][32];
__shared__ __nv_fp8_e4m3 Bs[8][32];
// Load dummy data
for (int i = threadIdx.x; i < 16*32; i += 32) {
int row = i / 32;
int col = i % 32;
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = threadIdx.x; i < 8*32; i += 32) {
int row = i / 32;
int col = i % 32;
Bs[row][col] = __nv_fp8_e4m3(0);
}
__syncthreads();
uint32_t a_reg[4], b_reg[2];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[group_id + 8][thread_id * 4 + 16]));
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[group_id][thread_id * 4 + 16]));
float acc[4] = {0,0,0,0};
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(acc[0]), \"=f\"(acc[1]), \"=f\"(acc[2]), \"=f\"(acc[3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(acc[0]), \"f\"(acc[1]), \"f\"(acc[2]), \"f\"(acc[3])
);
C[threadIdx.x * 4 + 0] = acc[0];
C[threadIdx.x * 4 + 1] = acc[1];
C[threadIdx.x * 4 + 2] = acc[2];
C[threadIdx.x * 4 + 3] = acc[3];
}
torch::Tensor launch() {
auto out = torch::empty({128}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
auto a = torch::empty({16*32}, torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
auto b = torch::empty({8*32}, torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
test_kernel<<<1, 32>>>(
reinterpret_cast<const __nv_fp8_e4m3*>(a.const_data_ptr()),
reinterpret_cast<const __nv_fp8_e4m3*>(b.const_data_ptr()),
out.mutable_data_ptr<float>()
);
return out;
}
'''
try:
mod = load_inline(
name='test_mma',
cpp_sources='torch::Tensor launch();',
cuda_sources=[src],
functions=['launch'],
extra_cuda_cflags=['-O3', '-arch=sm_120a', '-std=c++17', '--keep'],
verbose=True,
)
out = mod.launch()
print('MMA test passed, output:', out[:4])
except Exception as e:
print(f'MMA test failed: {e}')
" 2>&1 | tail -20[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_mma -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_mma -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120a -std=c++17 --keep -c /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -L/home/infatoshi/.local/lib/python3.12/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda-13/lib64 -lcudart -o test_mma.so
MMA test passed, output: tensor([0., 0., 0., 0.], device='cuda:0')
{
"command": "<3377 chars \u2014 see below>",
"description": "Test single MMA computation against reference",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = '''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
__global__ void test_fp8_mma(
const __nv_fp8_e4m3* A, // 16x32, row-major
const __nv_fp8_e4m3* B, // 8x32, row-major (treated as 32x8 col-major)
float* C_out // 16x8 output in f32 for checking
) {
int lane_id = threadIdx.x % 32;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
__shared__ __nv_fp8_e4m3 As[16][32];
__shared__ __nv_fp8_e4m3 Bs[8][32];
// Each thread loads its portion of A
for (int i = threadIdx.x; i < 16*32; i += 32) {
As[i/32][i%32] = A[i];
}
for (int i = threadIdx.x; i < 8*32; i += 32) {
Bs[i/32][i%32] = B[i];
}
__syncthreads();
uint32_t a_reg[4], b_reg[2];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[group_id + 8][thread_id * 4 + 16]));
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[group_id][thread_id * 4 + 16]));
float acc[4] = {0,0,0,0};
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(acc[0]), \"=f\"(acc[1]), \"=f\"(acc[2]), \"=f\"(acc[3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(acc[0]), \"f\"(acc[1]), \"f\"(acc[2]), \"f\"(acc[3])
);
// Store to output in the MMA fragment layout
int r0 = group_id;
int r1 = group_id + 8;
int c0 = thread_id * 2;
int c1 = thread_id * 2 + 1;
C_out[(r0) * 8 + c0] = acc[0];
C_out[(r0) * 8 + c1] = acc[1];
C_out[(r1) * 8 + c0] = acc[2];
C_out[(r1) * 8 + c1] = acc[3];
}
torch::Tensor launch(torch::Tensor a, torch::Tensor b) {
auto out = torch::zeros({16, 8}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
test_fp8_mma<<<1, 32>>>(
reinterpret_cast<const __nv_fp8_e4m3*>(a.const_data_ptr()),
reinterpret_cast<const __nv_fp8_e4m3*>(b.const_data_ptr()),
out.mutable_data_ptr<float>()
);
return out;
}
'''
mod = load_inline(
name='test_fp8_mma_full',
cpp_sources='torch::Tensor launch(torch::Tensor a, torch::Tensor b);',
cuda_sources=[src],
functions=['launch'],
extra_cuda_cflags=['-O3', '-arch=sm_120a', '-std=c++17'],
verbose=False,
)
# Create test matrices
torch.manual_seed(42)
a_fp8 = (torch.rand(16, 32) * 2 - 1).to(torch.float8_e4m3fn)
b_fp8 = (torch.rand(8, 32) * 2 - 1).to(torch.float8_e4m3fn)
mma_out = mod.launch(a_fp8, b_fp8)
# Reference: x @ w.T where x=a (M=16,K=32), w=b (N=8,K=32)
ref_out = a_fp8.to(torch.float32) @ b_fp8.to(torch.float32).T
print('MMA output:')
print(mma_out)
print()
print('Reference:')
print(ref_out)
print()
print('Diff:')
diff = (mma_out - ref_out).abs()
print(diff)
print()
print(f'Max diff: {diff.max().item()}')
print(f'Mean diff: {diff.mean().item()}')
" 2>&1Traceback (most recent call last):
File "<string>", line 96, in <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
MMA output:
tensor([[-2.0397, 4.3123, 0.1939, 0.7910, 2.1230, 0.8066, 2.4675, 1.2237],
[ 0.0288, -0.0763, -1.3528, 2.2675, 1.9339, 0.0629, 1.4968, -1.4283],
[-0.7389, -0.4175, 2.6858, 0.3521, -0.0449, -0.0850, -0.6050, 0.7443],
[-1.0031, 1.5126, 1.3962, 3.0055, -0.4978, -2.3778, -1.3828, -1.7230],
[-4.2261, 0.2886, -5.0140, -0.4471, 1.7966, -2.8521, 5.2173, -0.4017],
[ 1.2911, -0.0166, 1.3037, 2.3794, -2.5284, 1.9436, 0.7437, -0.0864],
[-0.8485, -0.0861, -1.5728, 3.4196, -0.5233, -0.5842, 2.6978, 1.5140],
[-2.0581, 0.5074, 2.1452, 2.1805, -0.0065, 0.0079, -0.3975, 0.8380],
[ 4.7699, 3.1807, -0.4404, -0.3566, -0.8828, -3.3740, 2.9771, -1.2673],
[ 1.5197, -2.0945, -0.6326, -1.4501, 0.4269, 0.7590, 1.8896, -2.8092],
[-1.4343, 1.7009, -3.1567, -1.0857, 1.0572, -5.8605, 5.5454, 0.4465],
[-2.1106, 0.2716, 0.7831, 1.9332, 0.3499, 2.5971, 1.1561, -1.4697],
[-2.0002, 2.6079, 1.3279, -1.4916, -2.9670, 0.2734, -2.3752, -0.8746],
[-0.5668, -2.9500, 3.0677, -1.0364, 0.6293, 2.9496, -2.8300, -3.6202],
[ 2.8798, -3.4182, 0.2415, -0.2503, -0.2463, 1.2915, -3.1165, -1.6692],
[ 1.3029, 1.4653, 2.0859, -0.5103, -0.2332, 1.7084, -1.6769, -0
... (1365 more chars){
"command": "<6041 chars \u2014 see below>",
"description": "Test full kernel with small matrix",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#define BM 64
#define BN 64
#define BK 32
#define WM 32
#define WN 32
#define WM_MMA 2
#define WN_MMA 4
__global__ void fp8_gemm_kernel(
const __nv_fp8_e4m3* __restrict__ A,
const __nv_fp8_e4m3* __restrict__ B,
float* __restrict__ C,
int M, int N, int K)
{
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
int warp_m = warp_id / 2;
int warp_n = warp_id % 2;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
int c_row = by * BM + warp_m * WM;
int c_col = bx * BN + warp_n * WN;
__shared__ __nv_fp8_e4m3 As[BM][BK];
__shared__ __nv_fp8_e4m3 Bs[BN][BK];
float accum[WM_MMA][WN_MMA][4] = {0.0f};
for (int k_block = 0; k_block < K; k_block += BK) {
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int global_row = by * BM + row;
int global_col = k_block + col;
if (global_row < M && global_col < K)
As[row][col] = A[global_row * K + global_col];
else
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int global_n = bx * BN + n_idx;
int global_k = k_block + k_idx;
if (global_n < N && global_k < K)
Bs[n_idx][k_idx] = B[global_n * K + global_k];
else
Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
}
__syncthreads();
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
int a_row_base = warp_m * WM + m_idx * 16;
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int b_n_base = warp_n * WN + n_idx * 8;
uint32_t b_reg[2];
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4 + 16]));
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(accum[m_idx][n_idx][0]), \"=f\"(accum[m_idx][n_idx][1]),
\"=f\"(accum[m_idx][n_idx][2]), \"=f\"(accum[m_idx][n_idx][3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(accum[m_idx][n_idx][0]), \"f\"(accum[m_idx][n_idx][1]),
\"f\"(accum[m_idx][n_idx][2]), \"f\"(accum[m_idx][n_idx][3])
);
}
}
__syncthreads();
}
// Store (f32 for debugging, not bf16)
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int out_row_base = c_row + m_idx * 16;
int out_col_base = c_col + n_idx * 8;
int r0 = out_row_base + group_id;
int r1 = out_row_base + group_id + 8;
int c0 = out_col_base + thread_i
... (2041 more chars)64x64x32 - Max diff: 0.00000000
PASS
{
"command": "<6001 chars \u2014 see below>",
"description": "Test kernel with various shapes",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#define BM 64
#define BN 64
#define BK 32
#define WM 32
#define WN 32
#define WM_MMA 2
#define WN_MMA 4
__global__ void fp8_gemm_kernel(
const __nv_fp8_e4m3* __restrict__ A,
const __nv_fp8_e4m3* __restrict__ B,
float* __restrict__ C,
int M, int N, int K)
{
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
int warp_m = warp_id / 2;
int warp_n = warp_id % 2;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
int c_row = by * BM + warp_m * WM;
int c_col = bx * BN + warp_n * WN;
__shared__ __nv_fp8_e4m3 As[BM][BK];
__shared__ __nv_fp8_e4m3 Bs[BN][BK];
float accum[WM_MMA][WN_MMA][4] = {0.0f};
for (int k_block = 0; k_block < K; k_block += BK) {
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int global_row = by * BM + row;
int global_col = k_block + col;
if (global_row < M && global_col < K)
As[row][col] = A[global_row * K + global_col];
else
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int global_n = bx * BN + n_idx;
int global_k = k_block + k_idx;
if (global_n < N && global_k < K)
Bs[n_idx][k_idx] = B[global_n * K + global_k];
else
Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
}
__syncthreads();
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
int a_row_base = warp_m * WM + m_idx * 16;
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int b_n_base = warp_n * WN + n_idx * 8;
uint32_t b_reg[2];
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4 + 16]));
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(accum[m_idx][n_idx][0]), \"=f\"(accum[m_idx][n_idx][1]),
\"=f\"(accum[m_idx][n_idx][2]), \"=f\"(accum[m_idx][n_idx][3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(accum[m_idx][n_idx][0]), \"f\"(accum[m_idx][n_idx][1]),
\"f\"(accum[m_idx][n_idx][2]), \"f\"(accum[m_idx][n_idx][3])
);
}
}
__syncthreads();
}
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int out_row_base = c_row + m_idx * 16;
int out_col_base = c_col + n_idx * 8;
int r0 = out_row_base + group_id;
int r1 = out_row_base + group_id + 8;
int c0 = out_col_base + thread_id * 2;
int c1 = out_col_base +
... (2001 more chars)Traceback (most recent call last):
File "<string>", line 152, in <module>
NameError: name 'PASS' is not defined
{
"command": "<6027 chars \u2014 see below>",
"description": "Test kernel with various shapes using f32 output",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#define BM 64
#define BN 64
#define BK 32
#define WM 32
#define WN 32
#define WM_MMA 2
#define WN_MMA 4
__global__ void fp8_gemm_kernel(
const __nv_fp8_e4m3* __restrict__ A,
const __nv_fp8_e4m3* __restrict__ B,
float* __restrict__ C,
int M, int N, int K)
{
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
int warp_m = warp_id / 2;
int warp_n = warp_id % 2;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
int c_row = by * BM + warp_m * WM;
int c_col = bx * BN + warp_n * WN;
__shared__ __nv_fp8_e4m3 As[BM][BK];
__shared__ __nv_fp8_e4m3 Bs[BN][BK];
float accum[WM_MMA][WN_MMA][4] = {0.0f};
for (int k_block = 0; k_block < K; k_block += BK) {
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int global_row = by * BM + row;
int global_col = k_block + col;
if (global_row < M && global_col < K)
As[row][col] = A[global_row * K + global_col];
else
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int global_n = bx * BN + n_idx;
int global_k = k_block + k_idx;
if (global_n < N && global_k < K)
Bs[n_idx][k_idx] = B[global_n * K + global_k];
else
Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
}
__syncthreads();
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
int a_row_base = warp_m * WM + m_idx * 16;
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int b_n_base = warp_n * WN + n_idx * 8;
uint32_t b_reg[2];
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4 + 16]));
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(accum[m_idx][n_idx][0]), \"=f\"(accum[m_idx][n_idx][1]),
\"=f\"(accum[m_idx][n_idx][2]), \"=f\"(accum[m_idx][n_idx][3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(accum[m_idx][n_idx][0]), \"f\"(accum[m_idx][n_idx][1]),
\"f\"(accum[m_idx][n_idx][2]), \"f\"(accum[m_idx][n_idx][3])
);
}
}
__syncthreads();
}
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int out_row_base = c_row + m_idx * 16;
int out_col_base = c_col + n_idx * 8;
int r0 = out_row_base + group_id;
int r1 = out_row_base + group_id + 8;
int c0 = out_col_base + thread_id * 2;
int c1 = out_col_base +
... (2027 more chars)64x64x128 - Max diff: 0.00000000 PASS
128x128x128 - Max diff: 0.00000000 PASS
64x64x4127 - Max diff: 0.00000763 PASS
32x64x64 - Max diff: 0.00000000 PASS
4096x4096x4096 - Max diff: 0.00003052 PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Re-run check.py with A-fragment fix",
"timeout": 120000
}FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=0.53125 (atol=0.01, rtol=0.01)
{
"command": "<1221 chars \u2014 see below>",
"description": "Compare solution output vs reference directly",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
import sys
sys.path.insert(0, '.')
import reference
import solution as sol
device = torch.device('cuda:0')
# Test shape 0
reference.M = 4096
reference.N = 4096
reference.K = 4096
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = sol.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out.float() - sol_out.float()).abs()
print(f'Max diff: {diff.max().item():.6f}')
print(f'Mean diff: {diff.mean().item():.6f}')
print(f'Num diff > 0.01: {(diff > 0.01).sum().item()}')
# Check specific elements
bad = diff > 0.5
if bad.any():
idx = torch.where(bad)
for i in range(min(20, idx[0].shape[0])):
r, c = idx[0][i].item(), idx[1][i].item()
print(f' [{r},{c}] sol={sol_out[r,c].item():.6f} ref={ref_out[r,c].item():.6f} diff={diff[r,c].item():.6f}')
else:
print('No diffs > 0.5')
" 2>&1Max diff: 0.539062
Mean diff: 0.078263
Num diff > 0.01: 15309541
[179,1740] sol=2.031250 ref=1.492188 diff=0.539062
[447,2507] sol=-1.171875 ref=-0.667969 diff=0.503906
[2746,3595] sol=0.402344 ref=0.910156 diff=0.507812
[3991,3003] sol=2.171875 ref=1.656250 diff=0.515625
{
"command": "<6041 chars \u2014 see below>",
"description": "Test identical kernel with bf16 output",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#define BM 64
#define BN 64
#define BK 32
#define WM 32
#define WN 32
#define WM_MMA 2
#define WN_MMA 4
__global__ void fp8_gemm_kernel(
const __nv_fp8_e4m3* __restrict__ A,
const __nv_fp8_e4m3* __restrict__ B,
__nv_bfloat16* __restrict__ C,
int M, int N, int K)
{
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
int warp_m = warp_id / 2;
int warp_n = warp_id % 2;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
int c_row = by * BM + warp_m * WM;
int c_col = bx * BN + warp_n * WN;
__shared__ __nv_fp8_e4m3 As[BM][BK];
__shared__ __nv_fp8_e4m3 Bs[BN][BK];
float accum[WM_MMA][WN_MMA][4] = {0.0f};
for (int k_block = 0; k_block < K; k_block += BK) {
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int global_row = by * BM + row;
int global_col = k_block + col;
if (global_row < M && global_col < K)
As[row][col] = A[global_row * K + global_col];
else
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int global_n = bx * BN + n_idx;
int global_k = k_block + k_idx;
if (global_n < N && global_k < K)
Bs[n_idx][k_idx] = B[global_n * K + global_k];
else
Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
}
__syncthreads();
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
int a_row_base = warp_m * WM + m_idx * 16;
uint32_t a_reg[4];
a_reg[0] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4]));
a_reg[1] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4]));
a_reg[2] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id][thread_id * 4 + 16]));
a_reg[3] = *(reinterpret_cast<const uint32_t*>(&As[a_row_base + group_id + 8][thread_id * 4 + 16]));
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int b_n_base = warp_n * WN + n_idx * 8;
uint32_t b_reg[2];
b_reg[0] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4]));
b_reg[1] = *(reinterpret_cast<const uint32_t*>(&Bs[b_n_base + group_id][thread_id * 4 + 16]));
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(accum[m_idx][n_idx][0]), \"=f\"(accum[m_idx][n_idx][1]),
\"=f\"(accum[m_idx][n_idx][2]), \"=f\"(accum[m_idx][n_idx][3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(accum[m_idx][n_idx][0]), \"f\"(accum[m_idx][n_idx][1]),
\"f\"(accum[m_idx][n_idx][2]), \"f\"(accum[m_idx][n_idx][3])
);
}
}
__syncthreads();
}
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int out_row_base = c_row + m_idx * 16;
int out_col_base = c_col + n_idx * 8;
int r0 = out_row_base + group_id;
int r1 = out_row_base + group_id + 8;
int c0 = out_col_base + thread_id * 2;
int c1 = out_col
... (2041 more chars)Max diff: 0.531250
Num diff > 0.01: 15309800
{
"command": "<6154 chars \u2014 see below>",
"description": "Debug: test kernel with f32 output vs reference",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#define BM 64
#define BN 64
#define BK 32
#define WM 32
#define WN 32
#define WM_MMA 2
#define WN_MMA 4
__global__ void fp8_gemm_kernel_f32(
const __nv_fp8_e4m3* __restrict__ A,
const __nv_fp8_e4m3* __restrict__ B,
float* __restrict__ C,
int M, int N, int K)
{
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
int warp_m = warp_id / 2;
int warp_n = warp_id % 2;
int group_id = lane_id >> 2;
int thread_id = lane_id & 3;
int c_row = by * BM + warp_m * WM;
int c_col = bx * BN + warp_n * WN;
__shared__ __nv_fp8_e4m3 As[BM][BK];
__shared__ __nv_fp8_e4m3 Bs[BN][BK];
float accum[WM_MMA][WN_MMA][4] = {0.0f};
for (int k_block = 0; k_block < K; k_block += BK) {
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int global_row = by * BM + row;
int global_col = k_block + col;
if (global_row < M && global_col < K)
As[row][col] = A[global_row * K + global_col];
else
As[row][col] = __nv_fp8_e4m3(0);
}
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int global_n = bx * BN + n_idx;
int global_k = k_block + k_idx;
if (global_n < N && global_k < K)
Bs[n_idx][k_idx] = B[global_n * K + global_k];
else
Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
}
__syncthreads();
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
int a_row_base = warp_m * WM + m_idx * 16;
uint32_t a_reg[4];
a_reg[0] = *(const uint32_t*)(&As[a_row_base + group_id][thread_id * 4]);
a_reg[1] = *(const uint32_t*)(&As[a_row_base + group_id + 8][thread_id * 4]);
a_reg[2] = *(const uint32_t*)(&As[a_row_base + group_id][thread_id * 4 + 16]);
a_reg[3] = *(const uint32_t*)(&As[a_row_base + group_id + 8][thread_id * 4 + 16]);
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int b_n_base = warp_n * WN + n_idx * 8;
uint32_t b_reg[2];
b_reg[0] = *(const uint32_t*)(&Bs[b_n_base + group_id][thread_id * 4]);
b_reg[1] = *(const uint32_t*)(&Bs[b_n_base + group_id][thread_id * 4 + 16]);
asm volatile(
\"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \"
\"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"
: \"=f\"(accum[m_idx][n_idx][0]), \"=f\"(accum[m_idx][n_idx][1]),
\"=f\"(accum[m_idx][n_idx][2]), \"=f\"(accum[m_idx][n_idx][3])
: \"r\"(a_reg[0]), \"r\"(a_reg[1]), \"r\"(a_reg[2]), \"r\"(a_reg[3]),
\"r\"(b_reg[0]), \"r\"(b_reg[1]),
\"f\"(accum[m_idx][n_idx][0]), \"f\"(accum[m_idx][n_idx][1]),
\"f\"(accum[m_idx][n_idx][2]), \"f\"(accum[m_idx][n_idx][3])
);
}
}
__syncthreads();
}
#pragma unroll
for (int m_idx = 0; m_idx < WM_MMA; ++m_idx) {
#pragma unroll
for (int n_idx = 0; n_idx < WN_MMA; ++n_idx) {
int out_row_base = c_row + m_idx * 16;
int out_col_base = c_col + n_idx * 8;
int r0 = out_row_base + group_id;
int r1 = out_row_base + group_id + 8;
int c0 = out_col_base + thread_id * 2;
int c1 = out_col_base + thread_id * 2 + 1;
if (r0 < M) {
if (c0 < N) C[r0 * N + c0] = accum[m_idx][n
... (2154 more chars)Kernel f32 vs ref - Max diff: 0.527828
Num diff > 0.01: 15415938
Num diff > 0.1: 5171740
Kernel f32 vs float ref - Max diff: 0.00000000
Num diff > 1e-4: 0
{
"todos": [
{
"priority": "high",
"content": "Write CUDA FP8 GEMM kernel in solution.py using mma.sync PTX",
"status": "completed"
},
{
"priority": "high",
"content": "Rewrite kernel to use bf16 MMA (match reference precision)",
"status": "in_progress"
},
{
"priority": "high",
"content": "Verify correctness with python check.py",
"status": "pending"
},
{
"priority": "high",
"content": "Profile and optimize kernel performance",
"status": "pending"
},
{
"priority": "medium",
... (87 more chars)content: 'Write CUDA FP8 GEMM kernel in solution.py using mma.sync PTX'
status: 'completed'
priority: 'high'
content: 'Rewrite kernel to use bf16 MMA (match reference precision)'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with python check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to measure TFLOPS'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -1,11 +1,10 @@
-"""Custom FP8 e4m3 GEMM kernel using inline PTX mma.sync on SM120 Blackwell.
-
-Uses mma.sync.aligned.m16n8k32.row.col.f32.f8e4m3.f8e4m3.f32 to hit FP8 tensor cores
-directly. Tiling: BM=64, BN=64, BK=32 with 128-thread blocks (4 warps).
+"""Custom FP8→bf16 GEMM kernel using bf16 tensor-core MMA on SM120 Blackwell.
+
+Strategy: upcast fp8 inputs to bf16, keep weights as bf16, then use
+mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 to hit tensor cores.
+This matches the reference's numerical path exactly.
"""
from __future__ import annotations
-
-from pathlib import Path
import torch
import torch.nn as nn
@@ -16,126 +15,129 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# ---------------------------------------------------------------------------
-# CUDA source
+# CUDA kernel source
# ---------------------------------------------------------------------------
CUDA_SRC = r"""
#include <cuda_fp8.h>
#include <cuda_bf16.h>
-#include <cuda_runtime.h>
-
-#define BM 64
-#define BN 64
-#define BK 32
-#define WM 32
-#define WN 32
-#define WM_MMA 2 // 32/16 = 2 MMAs in M
-#define WN_MMA 4 // 32/8 = 4 MMAs in N
-
-__global__ void fp8_gemm_kernel(
- const __nv_fp8_e4m3* __restrict__ A,
- const __nv_fp8_e4m3* __restrict__ B,
- __nv_bfloat16* __restrict__ C,
+
+// Tile sizes
+// BM and BN are tuned for occupancy: 128×128 output per block
+// BK=16 matches bf16 MMA K-dimension
+#define BM 128
+#define BN 128
+#define BK 16
+
+// Warp layout: 8 warps in a 2×4 grid (2 in M, 4 in N)
+// Each warp covers 64 columns in N and 32 rows in M? Wait, let me recalculate.
+// With 8 warps: 2 in M, 4 in N → each warp covers BM/2=64 rows, BN/4=32 cols
+// But an MMA gives 16×8 → 64/16=4 MMAs in M, 32/8=4 MMAs in N → 16 MMAs per warp per K-step
+// 16×4=64 f32 accumulators per thread. That's tight on registers.
+
+// Actually let me use 4 warps per block for lower register pressure.
+// BM=128, BN=64 with 4 warps (2×2): each warp covers 64×32, with 4×4=16 MMAs.
+// Hmm, still 16 MMAs.
+
+// Let me keep it simple: BM=64, BN=64, 4 warps (2×2), each warp: 32×32 → 2×4=8 MMAs.
+// Warp tile 32×32, 8 MMAs per K-step → 32 accumulators per thread. Good.
+#define WM_TILE 32
+#define WN_TILE 32
+#define WM_MMA 2 // 32/16
+#define WN_MMA 4 // 32/8
+
+__global__ void fp8_gemm_bf16_kernel(
+ const __nv_fp8_e4m3* __restrict__ A_in, // M x K, row-major, fp8
+ const __nv_bfloat16* __restrict__ B_in, // N x K, row-major, bf16
+ __nv_bfloat16* __restrict__ C, // M x N, row-major, bf16
int M, int N, int K)
{
- // Block indices
- int bx = blockIdx.x; // N-dimension tile index
- int by = blockIdx.y; // M-dimension tile index
-
- // Warp layout: 2 warps in M (warps 0-1, 2-3), 2 warps in N (warps 0,2 and 1,3)
+ // Block indices: x=N-dim, y=M-dim
+ int bx = blockIdx.x;
+ int by = blockIdx.y;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
- int warp_m = warp_id / 2; // 0 or 1
- int warp_n = warp_id % 2; // 0 or 1
-
- int group_id = lane_id >> 2; // 0..7
- int thread_id = lane_id & 3; // 0..3
-
- // Global output position for this thread block
- int c_row = by * BM + warp_m * WM;
- int c_col = bx * BN + warp_n * WN;
-
- // Shared memory
- __shared__ __nv_fp8_e4m3 As[BM][BK];
- __shared__ __nv_fp8_e4m3 Bs[BN][BK]; // stored as w: [n_tile][k_tile]
-
- // Accumulators: 2 MMAs in M x 4 MMAs in N = 8 MMA sets x 4 f32 = 32 f32
+ int warp_m = warp_id / 2; // 0..1
+ int warp_n = warp_id % 2; // 0..1
+ int group_id = lane_id >> 2; // 0..7
+ int thread_id = lane_id & 3; // 0..3
+
+ int c_row = by * BM + warp_m * WM_TILE;
+ int c_col = bx * BN + warp_n * WN_TILE;
+
+ // Shared memory tiles
+ __shared__ __nv_bfloat16 As[BM][BK]; // A tile after fp8→bf16 upcast
+ __shared__ __nv_bfloat16 Bs[BN][BK]; // B tile (already bf16)
+
+ // Accumulators: 8 MMA sets × 4 f32 = 32 f32
float accum[WM_MMA][WN_MMA][4] = {0.0f};
- // K loop
for (int k_block = 0; k_block < K; k_block += BK) {
- // Cooperative load A into shared memory
- // Each thread loads 4 elements
+ // Cooperative load A (fp8 → bf16 conversion on the fly)
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
- int global_row = by * BM + row;
- int global_col = k_block + col;
- if (global_row < M && global_col < K) {
- As[row][col] = A[global_row * K + global_col];
+ int g_row = by * BM + row;
+ int g_col = k_block + col;
+ if (g_row < M && g_col < K) {
+ As[row][col] = __nv_bfloat16(float(A_in[g_row * K + g_col]));
} else {
- As[row][col] = __nv_fp8_e4m3(0);
- }
- }
-
- // Cooperative load B into shared memory
- // B is stored as w[N][K], treated as KxN col-major for MMA
+ As[row][col] = __nv_bfloat16(0.0f);
+ }
+ }
+
+ // Cooperative load B (already bf16)
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
- int global_n = bx * BN + n_idx;
- int global_k = k_block + k_idx;
- if (global_n < N && global_k < K) {
- Bs[n_idx][k_idx] = B[global_n * K + global_k];
+ int g_n = bx * BN + n_idx;
+ int g_k = k_block + k_idx;
+ if (g_n < N && g_k < K) {
+ Bs[n_idx][k_idx] = B_in[g_n * K + g_k];
} else {
- Bs[n_idx][k_idx] = __nv_fp8_e4m3(0);
+ Bs[n_idx][k_idx] = __nv_bfloat16(0.0f);
... (7591 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"content": "<9309 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run check.py with bf16 MMA kernel",
"timeout": 180000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 2.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 253, in forward
_ensure_module()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", li
... (7626 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -81,9 +81,10 @@
int g_row = by * BM + row;
int g_col = k_block + col;
if (g_row < M && g_col < K) {
- As[row][col] = __nv_bfloat16(float(A_in[g_row * K + g_col]));
+ float val = float(A_in[g_row * K + g_col]);
+ As[row][col] = __float2bfloat16(val);
} else {
- As[row][col] = __nv_bfloat16(0.0f);
+ As[row][col] = __float2bfloat16(0.0f);
}
}
@@ -96,7 +97,7 @@
if (g_n < N && g_k < K) {
Bs[n_idx][k_idx] = B_in[g_n * K + g_k];
} else {
- Bs[n_idx][k_idx] = __nv_bfloat16(0.0f);
+ Bs[n_idx][k_idx] = __float2bfloat16(0.0f);
}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<918 chars \u2014 see below>",
"newString": "<960 chars \u2014 see below>"
} // Cooperative load A (fp8 → bf16 conversion on the fly)
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int g_row = by * BM + row;
int g_col = k_block + col;
if (g_row < M && g_col < K) {
As[row][col] = __nv_bfloat16(float(A_in[g_row * K + g_col]));
} else {
As[row][col] = __nv_bfloat16(0.0f);
}
}
// Cooperative load B (already bf16)
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int g_n = bx * BN + n_idx;
int g_k = k_block + k_idx;
if (g_n < N && g_k < K) {
Bs[n_idx][k_idx] = B_in[g_n * K + g_k];
} else {
Bs[n_idx][k_idx] = __nv_bfloat16(0.0f);
}
} // Cooperative load A (fp8 → bf16 conversion on the fly)
for (int i = tid; i < BM * BK; i += blockDim.x) {
int row = i / BK;
int col = i % BK;
int g_row = by * BM + row;
int g_col = k_block + col;
if (g_row < M && g_col < K) {
float val = float(A_in[g_row * K + g_col]);
As[row][col] = __float2bfloat16(val);
} else {
As[row][col] = __float2bfloat16(0.0f);
}
}
// Cooperative load B (already bf16)
for (int i = tid; i < BN * BK; i += blockDim.x) {
int n_idx = i / BK;
int k_idx = i % BK;
int g_n = bx * BN + n_idx;
int g_k = k_block + k_idx;
if (g_n < N && g_k < K) {
Bs[n_idx][k_idx] = B_in[g_n * K + g_k];
} else {
Bs[n_idx][k_idx] = __float2bfloat16(0.0f);
}
}Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run check.py after bf16 fix",
"timeout": 180000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 255.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 254, in forward
_ensure_module()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
... (4735 more chars){
"command": "<2411 chars \u2014 see below>",
"description": "Test bf16 m16n8k16 MMA compilation",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = r'''
#include <cuda_bf16.h>
__global__ void test_mma_bf16(const __nv_bfloat16* A, const __nv_bfloat16* B, float* C) {
int tid = threadIdx.x;
int lane = tid % 32;
int quad = lane >> 2;
int lane_in_quad = lane & 3;
__shared__ __nv_bfloat16 As[16][16];
__shared__ __nv_bfloat16 Bs[8][16];
// Load dummy data
for (int i = tid; i < 16*16; i += 32) {
As[i/16][i%16] = __float2bfloat16(0.0f);
}
for (int i = tid; i < 8*16; i += 32) {
Bs[i/16][i%16] = __float2bfloat16(0.0f);
}
__syncthreads();
uint32_t a0 = *(const uint32_t*)(&As[quad][lane_in_quad * 2]);
uint32_t a1 = *(const uint32_t*)(&As[quad + 8][lane_in_quad * 2]);
uint32_t b0 = *(const uint32_t*)(&Bs[quad][lane_in_quad * 2]);
float acc[4] = {0,0,0,0};
asm volatile(
\"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"
\"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\"
: \"=f\"(acc[0]), \"=f\"(acc[1]), \"=f\"(acc[2]), \"=f\"(acc[3])
: \"r\"(a0), \"r\"(a1), \"r\"(b0),
\"f\"(acc[0]), \"f\"(acc[1]), \"f\"(acc[2]), \"f\"(acc[3])
);
int r0 = quad;
int r1 = quad + 8;
int c0 = lane_in_quad * 2;
int c1 = lane_in_quad * 2 + 1;
if (r0 < 16 && c0 < 8) C[r0*8 + c0] = acc[0];
if (r0 < 16 && c1 < 8) C[r0*8 + c1] = acc[1];
if (r1 < 16 && c0 < 8) C[r1*8 + c0] = acc[2];
if (r1 < 16 && c1 < 8) C[r1*8 + c1] = acc[3];
}
torch::Tensor launch() {
auto out = torch::zeros({128}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
auto a = torch::zeros({16*16}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
auto b = torch::zeros({8*16}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
test_mma_bf16<<<1, 32>>>(
reinterpret_cast<const __nv_bfloat16*>(a.const_data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(b.const_data_ptr()),
out.mutable_data_ptr<float>()
);
return out;
}
'''
mod = load_inline(
name='test_mma_bf16',
cpp_sources='torch::Tensor launch();',
cuda_sources=[src],
functions=['launch'],
extra_cuda_cflags=['-O3', '-arch=sm_120a', '-std=c++17', '--keep'],
verbose=True,
)
out = mod.launch()
print('bf16 MMA test passed')
" 2>&1 | tail -30[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_mma_bf16 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma_bf16/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_mma_bf16 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120a -std=c++17 --keep -c /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma_bf16/cuda.cu -o cuda.cuda.o
FAILED: cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_mma_bf16 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12
... (1829 more chars){
"command": "find /home/infatoshi/.cache/torch_extensions -name \"cuda.compute_120.ptx\" -newer /tmp 2>/dev/null | head -3",
"description": "Find generated PTX file"
}(no output){
"command": "find /home/infatoshi/.cache/torch_extensions -name \"*.ptx\" -mmin -5 2>/dev/null | head -5",
"description": "Find recent PTX files"
}/home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma_bf16/cuda.compute_120.ptx
{
"command": "grep -n \"mma\" /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma_bf16/cuda.compute_120.ptx 2>/dev/null",
"description": "Check MMA lines in PTX"
}13: // .globl _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf
14:// _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2As has been demoted
15:// _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2Bs has been demoted
63:.visible .entry _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf(
64: .param .u64 .ptr .align 1 _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf_param_0,
65: .param .u64 .ptr .align 1 _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf_param_1,
66: .param .u64 .ptr .align 1 _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf_param_2
74: .shared .align 2 .b8 _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2As[512];
76: .shared .align 2 .b8 _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2Bs[256];
77: ld.param.b64 %rd1, [_Z13test_mma_bf16PK13__nv_bfloat16S1_Pf_param_2];
87: mov.b32 %r6, _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2As;
106: mov.b32 %r9, _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2Bs;
121: mov.b32 %r20, _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2As;
128: mov.b32 %r25, _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2Bs;
134: mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%r10, %r11, %r12, %r13}, {%r14, %r15}, {%r16}, {%r17, %r17, %r17, %r17};
{
"command": "head -15 /home/infatoshi/.cache/torch_extensions/py312_cu130/test_mma_bf16/cuda.compute_120.ptx",
"description": "Check PTX header version"
}//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: UNKNOWN
// Cuda compilation tools, release 13.2, V13.2.51
// Based on NVVM 22.0.0
//
.version 9.2
.target sm_120
.address_size 64
// .globl _Z13test_mma_bf16PK13__nv_bfloat16S1_Pf
// _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2As has been demoted
// _ZZ13test_mma_bf16PK13__nv_bfloat16S1_PfE2Bs has been demoted
{
"command": "<990 chars \u2014 see below>",
"description": "Test if f16 mma works on SM120",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && python3 -c "
from torch.utils.cpp_extension import load_inline
src = r'''
__global__ void test() {
float acc[4] = {0,0,0,0};
unsigned a0=0, a1=0, b0=0;
asm volatile(
\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"
\"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\"
: \"=f\"(acc[0]), \"=f\"(acc[1]), \"=f\"(acc[2]), \"=f\"(acc[3])
: \"r\"(a0), \"r\"(a1), \"r\"(b0),
\"f\"(acc[0]), \"f\"(acc[1]), \"f\"(acc[2]), \"f\"(acc[3])
);
}
void launch() {
test<<<1,32>>>();
}
'''
try:
mod = load_inline(
name='test_f16_mma',
cpp_sources='void launch();',
cuda_sources=[src],
functions=['launch'],
extra_cuda_cflags=['-O3', '-arch=sm_120a', '-std=c++17'],
verbose=False,
)
mod.launch()
print('f16 MMA works on SM120')
except Exception as e:
print(f'f16 MMA FAILED: {type(e).__name__}')
" 2>&1f16 MMA FAILED: RuntimeError