import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
cuda_src = r"""
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
#include <unordered_map>
#include <mutex>
#define CUBLAS_CHECK(call) \
do { \
cublasStatus_t err = call; \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
} \
} while (0)
static cublasLtHandle_t get_handle() {
static thread_local cublasLtHandle_t handle = nullptr;
if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
return handle;
}
struct GemmCacheKey {
int64_t M, N, K;
bool operator==(const GemmCacheKey& o) const { return M==o.M && N==o.N && K==o.K; }
};
struct GemmCacheKeyHash {
size_t operator()(const GemmCacheKey& k) const { return k.M*1000003 ^ k.N*1033 ^ k.K; }
};
struct GemmCacheEntry {
cublasLtMatrixLayout_t lA, lB, lD;
cublasLtMatmulDesc_t desc;
cublasLtMatmulAlgo_t algo;
bool has_algo;
int64_t ws_size;
};
static std::unordered_map<GemmCacheKey, GemmCacheEntry, GemmCacheKeyHash> cache;
static std::mutex cache_mtx;
static GemmCacheEntry& get_or_create(int64_t M, int64_t N, int64_t K, int device_id) {
// NOTE: cache is keyed by shape only; device_id is unused, which is fine for single-GPU runs.
GemmCacheKey key{M, N, K};
std::lock_guard<std::mutex> lock(cache_mtx);
auto it = cache.find(key);
if (it != cache.end()) return it->second;
auto handle = get_handle();
GemmCacheEntry e;
// Row-major -> col-major trick:
// A_row(M,K) = A_col(K,M), B_row(N,K) = B_col(K,N), D_row(M,N) = D_col(N,M)
// D = A @ B.T => D_col = B_col^T @ A_col
// cuBLAS: A_cublas=B_col(K,N) transA=T, B_cublas=A_col(K,M) transB=N
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lA, CUDA_R_16BF, (uint64_t)K, (uint64_t)N, (int64_t)K));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lB, CUDA_R_16BF, (uint64_t)K, (uint64_t)M, (int64_t)K));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
CUBLAS_CHECK(cublasLtMatmulDescCreate(&e.desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(e.desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(e.desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
e.ws_size = 64 << 20;  // 64 MiB cuBLASLt workspace
cublasLtMatmulPreference_t pref;
CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &e.ws_size, sizeof(e.ws_size)));
cublasLtMatmulHeuristicResult_t heur[4];
int found = 0;
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, e.desc, e.lA, e.lB, e.lD, e.lD, pref, 4, heur, &found));
if (found > 0) {
e.algo = heur[0].algo;
e.has_algo = true;
} else {
e.has_algo = false;
}
CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(pref));
auto result = cache.emplace(key, std::move(e));
return result.first->second;
}
torch::Tensor bf16_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bf16");
TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bf16");
auto handle = get_handle();
cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
auto out = torch::empty({M, N}, A.options().dtype(torch::kBFloat16));
float alpha = 1.0f, beta = 0.0f;
auto& e = get_or_create(M, N, K, A.get_device());
// Reuse a thread-local workspace tensor across calls instead of allocating per call.
static thread_local torch::Tensor ws;
if (!ws.defined() || ws.numel() < (int64_t)e.ws_size || ws.device() != A.device()) {
ws = torch::empty({(int64_t)e.ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
}
// Pass the cached algo when the heuristic found one; nullptr lets cuBLASLt pick.
const cublasLtMatmulAlgo_t* algo = e.has_algo ? &e.algo : nullptr;
CUBLAS_CHECK(cublasLtMatmul(handle, e.desc, &alpha,
B.data_ptr(), e.lA, A.data_ptr(), e.lB, &beta,
nullptr, e.lD, out.data_ptr(), e.lD,
algo, ws.data_ptr(), e.ws_size, stream));
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bf16_gemm", &bf16_gemm, "BF16 GEMM via cuBLASLt");
}
"""
_ext = load_inline(
name="bf16_gemm_v3",
cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int):
super().__init__()
self.M, self.N, self.K = M, N, K
self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
nn.init.normal_(self.weight, std=0.02)
self.K_pad = ((K + 31) // 32) * 32  # round K up to a multiple of 32; zero-padding K leaves x @ w.T unchanged
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_bf16 = x.to(torch.bfloat16)
w_bf16 = self.weight.to(torch.bfloat16)
K_pad = self.K_pad
if K_pad != self.K:
x_bf16 = torch.nn.functional.pad(x_bf16, (0, K_pad - self.K))
w_bf16 = torch.nn.functional.pad(w_bf16, (0, K_pad - self.K))
return _ext.bf16_gemm(x_bf16.contiguous(), w_bf16.contiguous(), self.M, self.N, K_pad)
M = 4096
N = 4096
K = 4096
def get_inputs():
x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
return [x]
def get_init_inputs():
return [M, N, K]
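The K-padding above is mathematically a no-op: appended zero columns contribute nothing to any dot product. A minimal sanity check of that identity (hypothetical tiny shapes, CPU):

import torch
import torch.nn.functional as F

x = torch.randn(4, 7)             # K=7, not a multiple of 32
w = torch.randn(5, 7)
K_pad = ((7 + 31) // 32) * 32     # round up, as in solution.py
x_p = F.pad(x, (0, K_pad - 7))    # append zero columns along K
w_p = F.pad(w, (0, K_pad - 7))
assert torch.allclose(x @ w.T, x_p @ w_p.T)  # padded GEMM is unchanged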
ninja: no work to do.
FAIL: shape 1 {'M': 4096, 'N': 4096, 'K': 4127} seed 42: tolerance exceeded: max_abs_diff=0.0625 (atol=0.01, rtol=0.01)
{
"problem": "01_fp8_gemm",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 137470, "output_tokens": 15056, "cache_read_tokens": 3295296, "cache_creation_tokens": 0, "reasoning_tokens": 100870, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive FP8 e4m3 GEMM reference (correctness only, NOT the SOTA baseline).
2:
3: We cast inputs to bf16 and use torch.matmul. The agent's solution must match
4: this numerically within the fp8 tolerance declared in problem.yaml.
5: """
6: import torch
7: import torch.nn as nn
8:
9: OP_TYPE = "gemm"
10: SUPPORTED_PRECISIONS = ["fp8_e4m3"]
11: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
12:
13:
14: class Model(nn.Module):
15: """y = (x @ w.T).to(bf16), where x is fp8_e4m3 (M, K), w is fp8_e4m3 (N, K)."""
16:
17: def __init__(self, M: int, N: int, K: int):
18: super().__init__()
19: self.M, self.N, self.K = M, N, K
20: # Weights stored as parameters so state_dict is well-defined.
21: # We initialize in bf16 then cast; the fp8 dtype is set by get_inputs.
22: self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
23: nn.init.normal_(self.weight, std=0.02)
24:
25: def forward(self, x: torch.Tensor) -> torch.Tensor:
26: # Upcast to bf16 for the naive reference; the kernel equivalent would
27: # use mma.sync f8f6f4 kind directly.
28: x_bf = x.to(torch.bfloat16)
29: w_bf = self.weight.to(torch.bfloat16)
30: return x_bf @ w_bf.T # (M, N) bf16
31:
32:
33: M = 4096
34: N = 4096
35: K = 4096
36:
37:
38: def get_inputs():
39: # fp8_e4m3 inp
... (227 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape c
... (2907 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formula
... (3416 more chars){
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for FP8 GEMM.
2:
3: Mix of:
4: - square aligned (the easy case)
5: - off-alignment K (common real-world failure mode for tile-quantized kernels)
6: - skinny (decode-like, memory-bound)
7: - rectangular (prefill with grouped attention)
8: """
9:
10: SHAPES = [
11: {"M": 4096, "N": 4096, "K": 4096}, # square aligned
12: {"M": 4096, "N": 4096, "K": 4127}, # K not multiple of 128 -> forces predicated tails
13: {"M": 32, "N": 8192, "K": 8192}, # skinny M (decode)
14: {"M": 4096, "N": 14336, "K": 4096}, # Llama3 up-proj shape
15: ]
(End of file - total 15 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 01_fp8_gemm
2: display_name: "FP8 e4m3 GEMM"
3: precision: fp8_e4m3
4: regime: compute
5:
6: # Dense-equivalent FLOPs per invocation. Evaluated per-shape; the benchmark
7: # substitutes M/N/K from each shapes.py entry.
8: flops_formula: "2 * M * N * K"
9:
10: # Bytes moved (approximate): fp8 inputs in, bf16 output.
11: bytes_formula: "M*K + K*N + M*N*2"
12:
13: hardware: [RTX_PRO_6000]
14: peak_tflops_key: fp8 # index into src/hardware/rtx_pro_6000.py peak_tflops_dense
15: peak_bandwidth_key: dram
16:
17: # Per-dtype tolerance override (optional). Defaults are in src/eval/correctness.py.
18: tolerance:
19: fp8_e4m3fn: 0.15
20:
21: # Forbidden ops — using any of these in solution.py fails correctness post-hoc.
22: # torch._scaled_mm is Stanford-style cheating (just call the vendor library).
23: forbidden:
24: - "torch._scaled_mm"
25: - "torch.ops.aten._scaled_mm"
26:
27: sota:
28: name: "FlashInfer FP8 GEMM"
29: url: "https://github.com/flashinfer-ai/flashinfer"
30: function: "flashinfer.gemm.fp8_gemm"
31: deps:
32: - "flashinfer>=0.6.8"
33: # Documented H100 throughput for this shape (informational, not graded):
34: reference_throughput_tflops_h100: 550
35:
36: num_correct_trials: 3
37: num_perf_trials: 30
(End of file - total 37 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for FP8 GEMM: flashinfer.gemm.fp8_gemm.
2:
3: If flashinfer is not installed or the SM120 path isn't supported, this falls
4: back to torch._scaled_mm which is the cuBLAS FP8 path. The benchmark treats
5: whichever succeeds as the SOTA reference line.
6:
7: Agents are FORBIDDEN from using torch._scaled_mm in their solution (see
8: problem.yaml.forbidden). This file is only for the benchmark's reference line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_flashinfer(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor | None:
16: try:
17: import flashinfer # noqa: F401
18: # Note: flashinfer's FP8 GEMM API surface may differ; adapt if needed.
19: # Placeholder call — replace with the actual flashinfer entry point
20: # once validated on SM120.
21: return None
22: except ImportError:
23: return None
24:
25:
26: def _scaled_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
27: # torch._scaled_mm wants per-tensor scales. Use unit scales for the reference.
28: scale_a = torch.tensor(1.0, device=x.device)
29: scale_b = torch.tensor(1.0, device=x.device)
30: out = torch._scaled_mm(
31: x,
32: w.T,
33: scale_a=scale_a,
34: scale_b=scale_b,
35: out_dtype=torch.bfloat16,
36: )
37: return out
... (576 more chars){
"todos": [
{
"content": "Check GPU env (nvidia-smi, nvcc, CUDA version, SM arch)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Write initial solution.py with cuBLASLt FP8 GEMM via load_inline",
"status": "pending",
"priority": "high"
},
{
"content": "Test correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Benchmark with benchmark.py",
"status": "pending",
"priority": "high"
},
{
"content": "Iterate on performance if needed",
... (63 more chars){
"command": "nvidia-smi 2>/dev/null | head -5",
"description": "Check GPU utilization and memory"
}Tue Apr 28 14:41:39 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi --query-gpu=name,compute_cap,sm --format=csv,noheader 2>/dev/null; echo \"---\"; nvcc --version 2>/dev/null | tail -1; echo \"---\"; python3 -c \"import torch; print('torch:', torch.__version__); print('cuda:', torch.version.cuda); print('arch_list:', torch.cuda.get_arch_list())\" 2>/dev/null",
"description": "Check GPU, nvcc, torch versions"
}Field "sm" is not a valid field to query.
---
Build cuda_13.2.r13.2/compiler.37434383_0
---
torch: 2.11.0+cu130
cuda: 13.0
arch_list: ['sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120']
{
"command": "nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader 2>/dev/null",
"description": "Check GPU compute capability"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0
{
"command": "python3 -c \"import torch; print(torch.cuda.get_device_properties(0))\"",
"description": "Check GPU properties via PyTorch"
}_CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97227MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
{
"command": "python3 -c \"import triton; print('triton:', triton.__version__)\" 2>/dev/null; echo \"---\"; ls /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py 2>/dev/null || echo \"no solution.py yet\"",
"description": "Check Triton version and existing solution"
}triton: 3.6.0
---
no solution.py yet
{
"command": "python3 -c \"import torch; print(hasattr(torch, '_scaled_mm')); print(torch._scaled_mm.__doc__[:200] if hasattr(torch, '_scaled_mm') else 'N/A')\"",
"description": "Check torch._scaled_mm availability"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
TypeError: 'NoneType' object is not subscriptable
True
{
"command": "<238 chars \u2014 see below>",
"description": "Check hardware peak specs"
}python3 -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
from src.hardware import get
hw = get('RTX_PRO_6000')
print('Peak TFLOPS dense:', hw.peak_tflops_dense)
print('Peak bandwidth:', hw.peak_bandwidth_gb_s)
"Peak TFLOPS dense: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
Peak bandwidth: 1800.0
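Plugging problem.yaml's formulas into these peaks shows why the square shape is compute-bound; a back-of-envelope sketch:

M = N = K = 4096
flops = 2 * M * N * K                  # flops_formula: 2*M*N*K -> ~137.4 GFLOP
bytes_moved = M*K + K*N + M*N*2        # fp8 inputs + bf16 output -> ~67.1 MB
intensity = flops / bytes_moved        # ~2048 FLOP/byte
ridge = 400.0e12 / 1800.0e9            # fp8 peak / DRAM peak -> ~222 FLOP/byte
print(intensity > ridge)               # True: firmly compute-bound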
{
"command": "<589 chars \u2014 see below>",
"description": "Benchmark reference bf16 GEMM"
}python3 -c "
import torch
# Quick test: reference bf16 GEMM performance
M, N, K = 4096, 4096, 4096
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
# warmup
for _ in range(5):
_ = x.to(torch.bfloat16) @ w.T
torch.cuda.synchronize()
import time
t0 = time.time()
for _ in range(50):
_ = x.to(torch.bfloat16) @ w.T
torch.cuda.synchronize()
t1 = time.time()
ms = (t1 - t0) / 50 * 1000
flops = 2 * M * N * K
tflops = flops / ms / 1e9
print(f'Reference bf16 GEMM: {ms:.3f} ms, {tflops:.1f} TFLOPS')
"Reference bf16 GEMM: 0.355 ms, 387.6 TFLOPS
{
"command": "<793 chars \u2014 see below>",
"description": "Benchmark torch._scaled_mm FP8 GEMM"
}python3 -c "
import torch
M, N, K = 4096, 4096, 4096
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
w_fp8 = w.to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device='cuda')
scale_b = torch.tensor(1.0, device='cuda')
# warmup
for _ in range(5):
_ = torch._scaled_mm(x, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
import time
t0 = time.time()
for _ in range(50):
_ = torch._scaled_mm(x, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
t1 = time.time()
ms = (t1 - t0) / 50 * 1000
flops = 2 * M * N * K
tflops = flops / ms / 1e9
print(f'torch._scaled_mm FP8 GEMM: {ms:.3f} ms, {tflops:.1f} TFLOPS')
"torch._scaled_mm FP8 GEMM: 0.177 ms, 775.8 TFLOPS
{
"todos": [
{
"content": "Check GPU env (nvidia-smi, nvcc, CUDA version, SM arch)",
"priority": "high",
"status": "completed"
},
{
"content": "Write initial solution.py with cuBLASLt FP8 GEMM via load_inline",
"priority": "high",
"status": "in_progress"
},
{
"content": "Test correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Benchmark with benchmark.py",
"priority": "high",
"status": "pending"
},
{
"content": "Iterate on performance if needed",
... (65 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+cuda_src = r"""
+#include <torch/extension.h>
+#include <cublasLt.h>
+#include <cuda_runtime.h>
+
+#define CUBLAS_CHECK(call) \
+ do { \
+ cublasStatus_t err = call; \
+ if (err != CUBLAS_STATUS_SUCCESS) { \
+ throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
+ } \
+ } while (0)
+
+static cublasLtHandle_t get_handle() {
+ static thread_local cublasLtHandle_t handle = nullptr;
+ if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
+ return handle;
+}
+
+torch::Tensor fp8_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
+ TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
+ TORCH_CHECK(A.dtype() == torch::kFloat8_e4m3fn, "A must be fp8_e4m3fn");
+ TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
+
+ auto handle = get_handle();
+ auto stream = at::cuda::getCurrentCUDAStream();
+ CUBLAS_CHECK(cublasLtSetStream(handle, stream));
+
+ auto out = torch::empty({M, N}, A.options().dtype(torch::kBFloat16));
+
+ float alpha = 1.0f, beta = 0.0f;
+ auto sa = torch::tensor({1.0f}, torch::kFloat32).to(A.device());
+ auto sb = torch::tensor({1.0f}, torch::kFloat32).to(A.device());
+ float* dsa = sa.data_ptr<float>();
+ float* dsb = sb.data_ptr<float>();
+
+ // Row-major PyTorch tensors viewed as column-major for cuBLAS:
+ // A_row(M,K) = A_col(K,M), B_row(N,K) = B_col(K,N), D_row(M,N) = D_col(N,M)
+ // Compute D = A @ B.T => D_col = B_col^T @ A_col
+ // cuBLAS args: A_cublas=B_col (K,N) transA=T, B_cublas=A_col (K,M) transB=N
+ cublasLtMatrixLayout_t lA, lB, lD;
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lA, CUDA_R_8F_E4M3, (uint64_t)K, (uint64_t)N, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lB, CUDA_R_8F_E4M3, (uint64_t)K, (uint64_t)M, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
+
+ cublasLtMatmulDesc_t desc;
+ CUBLAS_CHECK(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
+
+ cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &dsa, sizeof(dsa)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &dsb, sizeof(dsb)));
+
+ size_t ws_size = 32 << 20;
+ auto ws = torch::empty({(int64_t)ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
+
+ cublasLtMatmulPreference_t pref;
+ CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_size, sizeof(ws_size)));
+
+ int found = 0;
+ cublasLtMatmulHeuristicResult_t heur;
+ CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, lA, lB, lD, lD, pref, 1, &heur, &found));
+
+ if (found) {
+ CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
+ B.data_ptr(), lA, A.data_ptr(), lB, &beta,
+ nullptr, lD, out.data_ptr(), lD,
+ &heur.algo, ws.data_ptr(), ws_size, stream));
+ } else {
+ CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
+ B.data_ptr(), lA, A.data_ptr(), lB, &beta,
+ nullptr, lD, out.data_ptr(), lD,
+ nullptr, ws.data_ptr(), ws_size, stream));
+ }
+
+ cublasLtMatrixLayoutDestroy(lA);
+ cublasLtMatrixLayoutDestroy(lB);
+ cublasLtMatrixLayoutDestroy(lD);
+ cublasLtMatmulDescDestroy(desc);
+ cublasLtMatmulPreferenceDestroy(pref);
+
+ return out;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fp8_gemm", &fp8_gemm, "FP8 GEMM via cuBLASLt");
+}
+"""
+
+_ext = load_inline(
+ name="fp8_gemm_ext",
+ cuda_sources=cuda_src,
+ extra_ldflags=["-lcublasLt", "-lcublas"],
+ verbose=True,
+)
+
+
+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
+ nn.init.normal_(self.weight, std=0.02)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ w = self.weight.to(torch.float8_e4m3fn).contiguous()
+ x = x.contiguous()
+ return _ext.fp8_gemm(x, w, self.M, self.N, self.K)
+
+
+M = 4096
+N = 4096
+K = 4096
+
+
+def get_inputs():
+ x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
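The row-major/column-major comment in the kernel can be checked numerically; a tiny sketch (hypothetical sizes) of the identity it relies on:

import torch

M, N, K = 3, 4, 5
A = torch.randn(M, K)   # row-major (M, K); read column-major it is A_col(K, M) = A.T
B = torch.randn(N, K)   # likewise B_col(K, N) = B.T
D = A @ B.T             # desired result, (M, N)
# D_col(N, M) = B_col^T @ A_col, i.e. B @ A.T -- what the transA=T/transB=N call computes
assert torch.allclose(D, (B @ A.T).T)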
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"content": "<4984 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "<475 chars \u2014 see below>",
"description": "Quick smoke test of solution",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import solution
M, N, K = 4096, 4096, 4096
model = solution.Model(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
with torch.no_grad():
out = model(x)
print('Output shape:', out.shape, 'dtype:', out.dtype)
print('Output stats: min={:.4f}, max={:.4f}, mean={:.4f}'.format(out.min().item(), out.max().item(), out.mean().item()))
"Traceback (most recent call last):
File "<string>", line 3, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py", line 96, in <module>
_ext = load_inline(
^^^^^^^^^^^^
TypeError: load_inline() missing 1 required positional argument: 'cpp_sources'
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -95,6 +95,7 @@
_ext = load_inline(
name="fp8_gemm_ext",
+ cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<137 chars \u2014 see below>",
"newString": "<157 chars \u2014 see below>"
}_ext = load_inline(
name="fp8_gemm_ext",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
)_ext = load_inline(
name="fp8_gemm_ext",
cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
)Edit applied successfully.{
"command": "<480 chars \u2014 see below>",
"description": "Smoke test solution",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import solution
M, N, K = 4096, 4096, 4096
model = solution.Model(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
with torch.no_grad():
out = model(x)
print('Output shape:', out.shape, 'dtype:', out.dtype)
print('Output stats: min={:.4f}, max={:.4f}, mean={:.4f}'.format(out.min().item(), out.max().item(), out.mean().item()))
" 2>&1[1/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_ext/cuda.cu -o cuda.cuda.o
FAILED: [code=2] cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=c
... (3034 more chars){
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_ext",
"description": "Clean build cache"
}{
"command": "<148 chars \u2014 see below>",
"description": "Check CUDA stream API"
}uv run python -c "
import torch
# Check what's available for getting CUDA stream
print(dir(torch.cuda))
" 2>/dev/null | tr ',' '\n' | grep -i stream'ExternalStream'
'Stream'
'StreamContext'
'_set_stream_by_id'
'current_stream'
'default_stream'
'get_stream_from_external'
'is_current_stream_capturing'
'set_stream'
'stream'
'streams'
{
"command": "<324 chars \u2014 see below>",
"description": "Find CUDA stream header location"
}uv run python -c "
# Check C++ API for getting CUDA stream
import subprocess, sys
result = subprocess.run(['grep', '-r', 'getCurrentCUDAStream', '/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/', '--include=*.h', '-l'], capture_output=True, text=True)
print(result.stdout)
"/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
{
"command": "grep -r \"getCurrentCUDAStream\" /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/ --include='*.h' -l 2>/dev/null; echo \"---\"; grep -r \"getCurrentCUDAStream\" /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ --include='*.h' -l 2>/dev/null",
"description": "Find getCurrentCUDAStream header"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/CUDAEvent.h
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/CUDAStream.h
---
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
{
"command": "grep -n \"getCurrentCUDAStream\" /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/c10/cuda/CUDAStream.h | head -5",
"description": "Check CUDAStream API"
}227:C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
257: return c10::cuda::getCurrentCUDAStream(device_index);
278: return c10::cuda::getCurrentCUDAStream(device_index);
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: from torch.utils.cpp_extension import load_inline
4:
5: cuda_src = r"""
6: #include <torch/extension.h>
7: #include <cublasLt.h>
8: #include <cuda_runtime.h>
9:
10: #define CUBLAS_CHECK(call) \
11: do { \
12: cublasStatus_t err = call; \
13: if (err != CUBLAS_STATUS_SUCCESS) { \
14: throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
15: } \
16: } while (0)
17:
18: static cublasLtHandle_t get_handle() {
19: static thread_local cublasLtHandle_t handle = nullptr;
20: if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
21: return handle;
22: }
23:
24: torch::Tensor fp8_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
25: TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
26: TORCH_CHECK(A.dtype() == torch::kFloat8_e4m3fn, "A must be fp8_e4m3fn");
27: TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
28:
29: auto handle = get_handle();
30: auto stream = at::cuda::getCurrentCUDAStream();
31: CUBLAS_CHEC
... (4196 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -4,6 +4,7 @@
cuda_src = r"""
#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
@@ -27,8 +28,7 @@
TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
auto handle = get_handle();
- auto stream = at::cuda::getCurrentCUDAStream();
- CUBLAS_CHECK(cublasLtSetStream(handle, stream));
+ cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
auto out = torch::empty({M, N}, A.options().dtype(torch::kBFloat16));
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<1224 chars \u2014 see below>",
"newString": "<1222 chars \u2014 see below>"
}cuda_src = r"""
#include <torch/extension.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
#define CUBLAS_CHECK(call) \
do { \
cublasStatus_t err = call; \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
} \
} while (0)
static cublasLtHandle_t get_handle() {
static thread_local cublasLtHandle_t handle = nullptr;
if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
return handle;
}
torch::Tensor fp8_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
TORCH_CHECK(A.dtype() == torch::kFloat8_e4m3fn, "A must be fp8_e4m3fn");
TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
auto handle = get_handle();
auto stream = at::cuda::getCurrentCUDAStream();
CUBLAS_CHECK(cublasLtSetStream(handle, stream));cuda_src = r"""
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
#define CUBLAS_CHECK(call) \
do { \
cublasStatus_t err = call; \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
} \
} while (0)
static cublasLtHandle_t get_handle() {
static thread_local cublasLtHandle_t handle = nullptr;
if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
return handle;
}
torch::Tensor fp8_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
TORCH_CHECK(A.dtype() == torch::kFloat8_e4m3fn, "A must be fp8_e4m3fn");
TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
auto handle = get_handle();
cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();Edit applied successfully.{
"command": "rm -rf ~/.cache/torch_extensions/py311_cu130/fp8_gemm_ext 2>/dev/null; echo \"cleaned\"",
"description": "Clean stale extension build cache"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -94,7 +94,7 @@
"""
_ext = load_inline(
- name="fp8_gemm_ext",
+ name="fp8_gemm_v2",
cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"oldString": "<157 chars \u2014 see below>",
"newString": "<156 chars \u2014 see below>"
}_ext = load_inline(
name="fp8_gemm_ext",
cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
)_ext = load_inline(
name="fp8_gemm_v2",
cpp_sources="",
cuda_sources=cuda_src,
extra_ldflags=["-lcublasLt", "-lcublas"],
verbose=True,
)Edit applied successfully.{
"command": "<480 chars \u2014 see below>",
"description": "Smoke test with fixed CUDA stream API",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import solution
M, N, K = 4096, 4096, 4096
model = solution.Model(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
with torch.no_grad():
out = model(x)
print('Output shape:', out.shape, 'dtype:', out.dtype)
print('Output stats: min={:.4f}, max={:.4f}, mean={:.4f}'.format(out.min().item(), out.max().item(), out.mean().item()))
" 2>&1[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_v2/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=fp8_gemm_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/fp8_gemm_v2/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -lcublasLt -lcublas -L/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.1
... (260 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}ninja: no work to do.
FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=0.5625 (atol=0.01, rtol=0.01)
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && find . -path './src/eval/correctness.py' -exec cat {} \\;",
"description": "Read correctness check implementation"
}"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
torch.int64: {"atol": 0, "rtol": 0},
}
def _coerce_float(x):
"""Accept int/float/str scientific-notation as float; pass through dicts."""
if isinstance(x, (int, float)):
return float(x)
if isinstance(x, str):
try:
return float(x)
except ValueError:
return x
return x
def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
"""Lookup atol/rtol for a given dtype, with optional per-problem override.
The override dict is keyed by str(dtype). Values may be:
- a single number (or a string parseable as float, e.g. "5e-2") -> used
as both atol and rtol
- a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
"""
# Accept several spellings of the same dtype key (PyYAML quirks):
# "torch.bfloat16", "bfloat16", torch.bfloat16
if override is not None:
type_
... (2072 more chars){
"command": "<1334 chars \u2014 see below>",
"description": "Check _scaled_mm accuracy with/without scaling",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import reference
torch.manual_seed(42)
M, N, K = 4096, 4096, 4096
ref = reference.Model(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
# Reference
with torch.no_grad():
ref_out = ref(x)
# _scaled_mm with unit scales
w_fp8 = ref.weight.to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device='cuda')
scale_b = torch.tensor(1.0, device='cuda')
with torch.no_grad():
scaled_mm_out = torch._scaled_mm(x, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
diff = (ref_out.float() - scaled_mm_out.float()).abs()
print(f'_scaled_mm (unit scale): max_abs_diff={diff.max().item():.6f}, mean_abs_diff={diff.mean().item():.6f}')
# _scaled_mm with proper scaling
scale_w = ref.weight.abs().max() / 448.0
w_fp8_scaled = (ref.weight / scale_w).to(torch.float8_e4m3fn)
scale_b2 = torch.tensor(scale_w.item(), device='cuda')
with torch.no_grad():
scaled_mm_out2 = torch._scaled_mm(x, w_fp8_scaled.T, scale_a=scale_a, scale_b=scale_b2, out_dtype=torch.bfloat16)
diff2 = (ref_out.float() - scaled_mm_out2.float()).abs()
print(f'_scaled_mm (scaled): max_abs_diff={diff2.max().item():.6f}, mean_abs_diff={diff2.mean().item():.6f}')
print(f'scale_w={scale_w.item():.8f}')
"_scaled_mm (unit scale): max_abs_diff=0.546875, mean_abs_diff=0.078210
_scaled_mm (scaled): max_abs_diff=0.429688, mean_abs_diff=0.062447
scale_w=0.00014782
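The "scaled" variant divides w by absmax/448 before the fp8 cast (448 is the largest finite e4m3 value) and folds the scale back in via scale_b; a minimal round-trip sketch of that quantization:

import torch

w = torch.randn(8, 8) * 0.02
scale = w.abs().max() / 448.0               # map absmax onto e4m3fn's max finite value
w_q = (w / scale).to(torch.float8_e4m3fn)   # quantize
w_hat = w_q.to(torch.float32) * scale       # dequantize
print((w - w_hat).abs().max())              # small, bounded by e4m3 rounding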
{
"command": "<1246 chars \u2014 see below>",
"description": "Check tolerance issue and _scaled_mm accuracy",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import reference
# Check what tolerance the check would use
import yaml
from pathlib import Path
meta = yaml.safe_load(Path('problem.yaml').read_text())
tol_override = meta.get('tolerance')
print('tol_override:', tol_override)
print('ref_out dtype:', torch.bfloat16)
print('str(dtype):', str(torch.bfloat16))
# Check if _scaled_mm would even pass
torch.manual_seed(42)
M, N, K = 4096, 4096, 4096
ref = reference.Model(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
with torch.no_grad():
ref_out = ref(x)
w_fp8 = ref.weight.to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device='cuda')
scale_b = torch.tensor(1.0, device='cuda')
with torch.no_grad():
sol_out = torch._scaled_mm(x, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
diff = (ref_out.float() - sol_out.float()).abs()
print(f'_scaled_mm vs ref: max_abs={diff.max().item():.6f} mean_abs={diff.mean().item():.6f}')
print(f'Passes 1e-2? {torch.allclose(ref_out.float(), sol_out.float(), atol=1e-2, rtol=1e-2)}')
print(f'Passes 0.15? {torch.allclose(ref_out.float(), sol_out.float(), atol=0.15, rtol=0.15)}')
"tol_override: {'fp8_e4m3fn': 0.15}
ref_out dtype: torch.bfloat16
str(dtype): torch.bfloat16
_scaled_mm vs ref: max_abs=0.546875 mean_abs=0.078210
Passes 1e-2? False
Passes 0.15? False
{
"command": "<717 chars \u2014 see below>",
"description": "Check dtype string representations"
}uv run python -c "
import torch
print('str(float8_e4m3fn):', str(torch.float8_e4m3fn))
print('short:', str(torch.float8_e4m3fn).split('.')[-1])
print()
# What if we use the override directly
override = {'fp8_e4m3fn': 0.15}
type_str = str(torch.bfloat16)
short_str = type_str.split('.')[-1]
print('bf16 type_str:', type_str, ' short:', short_str)
print('override.get(type_str):', override.get(type_str))
print('override.get(short_str):', override.get(short_str))
type_str2 = str(torch.float8_e4m3fn)
short_str2 = type_str2.split('.')[-1]
print('fp8 type_str:', type_str2, ' short:', short_str2)
print('override.get(type_str2):', override.get(type_str2))
print('override.get(short_str2):', override.get(short_str2))
"str(float8_e4m3fn): torch.float8_e4m3fn
short: float8_e4m3fn
bf16 type_str: torch.bfloat16 short: bfloat16
override.get(type_str): None
override.get(short_str): None
fp8 type_str: torch.float8_e4m3fn short: float8_e4m3fn
override.get(type_str2): None
override.get(short_str2): None
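So the 0.15 override can never fire: problem.yaml keys it by 'fp8_e4m3fn', the checker derives keys from str(dtype) ('float8_e4m3fn', 'bfloat16'), and the compared outputs are bf16 anyway. A hypothetical normalization illustrating the mismatch (tolerance_key is illustrative, not the repo's API):

import torch

def tolerance_key(dtype):
    short = str(dtype).split(".")[-1]        # "torch.float8_e4m3fn" -> "float8_e4m3fn"
    return short.replace("float8", "fp8")    # match problem.yaml's "fp8_e4m3fn" spelling

override = {"fp8_e4m3fn": 0.15}
print(override.get(tolerance_key(torch.float8_e4m3fn)))  # 0.15 after normalization
print(override.get(tolerance_key(torch.bfloat16)))       # None: outputs are bf16, so 1e-2 applies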
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -1,105 +1,66 @@
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
+import triton
+import triton.language as tl
-cuda_src = r"""
-#include <torch/extension.h>
-#include <c10/cuda/CUDAStream.h>
-#include <cublasLt.h>
-#include <cuda_runtime.h>
-#define CUBLAS_CHECK(call) \
- do { \
- cublasStatus_t err = call; \
- if (err != CUBLAS_STATUS_SUCCESS) { \
- throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
- } \
- } while (0)
+@triton.jit
+def gemm_kernel(
+ x_ptr, w_ptr, y_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wn, stride_wk,
+ stride_ym, stride_yn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
-static cublasLtHandle_t get_handle() {
- static thread_local cublasLtHandle_t handle = nullptr;
- if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
- return handle;
-}
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-torch::Tensor fp8_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
- TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
- TORCH_CHECK(A.dtype() == torch::kFloat8_e4m3fn, "A must be fp8_e4m3fn");
- TORCH_CHECK(B.dtype() == torch::kFloat8_e4m3fn, "B must be fp8_e4m3fn");
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, K, BLOCK_K):
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ x = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+ mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
+ other=0.0,
+ )
+ w = tl.load(
+ w_ptr + offs_n[None, :] * stride_wn + offs_k[:, None] * stride_wk,
+ mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
+ other=0.0,
+ )
+ acc += tl.dot(x, w)
- auto handle = get_handle();
- cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
+ y = acc.to(tl.bfloat16)
+ tl.store(
+ y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
+ y,
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
+ )
- auto out = torch::empty({M, N}, A.options().dtype(torch::kBFloat16));
- float alpha = 1.0f, beta = 0.0f;
- auto sa = torch::tensor({1.0f}, torch::kFloat32).to(A.device());
- auto sb = torch::tensor({1.0f}, torch::kFloat32).to(A.device());
- float* dsa = sa.data_ptr<float>();
- float* dsb = sb.data_ptr<float>();
-
- // Row-major PyTorch tensors viewed as column-major for cuBLAS:
- // A_row(M,K) = A_col(K,M), B_row(N,K) = B_col(K,N), D_row(M,N) = D_col(N,M)
- // Compute D = A @ B.T => D_col = B_col^T @ A_col
- // cuBLAS args: A_cublas=B_col (K,N) transA=T, B_cublas=A_col (K,M) transB=N
- cublasLtMatrixLayout_t lA, lB, lD;
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lA, CUDA_R_8F_E4M3, (uint64_t)K, (uint64_t)N, (int64_t)K));
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lB, CUDA_R_8F_E4M3, (uint64_t)K, (uint64_t)M, (int64_t)K));
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
-
- cublasLtMatmulDesc_t desc;
- CUBLAS_CHECK(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
-
- cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &dsa, sizeof(dsa)));
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &dsb, sizeof(dsb)));
-
- size_t ws_size = 32 << 20;
- auto ws = torch::empty({(int64_t)ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
-
- cublasLtMatmulPreference_t pref;
- CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
- CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_size, sizeof(ws_size)));
-
- int found = 0;
- cublasLtMatmulHeuristicResult_t heur;
- CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, lA, lB, lD, lD, pref, 1, &heur, &found));
-
- if (found) {
- CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
- B.data_ptr(), lA, A.data_ptr(), lB, &beta,
- nullptr, lD, out.data_ptr(), lD,
- &heur.algo, ws.data_ptr(), ws_size, stream));
- } else {
- CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
- B.data_ptr(), lA, A.data_ptr(), lB, &beta,
- nullptr, lD, out.data_ptr(), lD,
- nullptr, ws.data_ptr(), ws_size, stream));
- }
-
- cublasLtMatrixLayoutDestroy(lA);
- cublasLtMatrixLayoutDestroy(lB);
- cublasLtMatrixLayoutDestroy(lD);
- cublasLtMatmulDescDestroy(desc);
- cublasLtMatmulPreferenceDestroy(pref);
-
- return out;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("fp8_gemm", &fp8_gemm, "FP8 GEMM via cuBLASLt");
-}
-"""
-
-_ext = load_inline(
- name="fp8_gemm_v2",
- cpp_sources="",
- cuda_sources=cuda_src,
- extra_ldflags=["-lcublasLt", "-lcublas"],
- verbose=True,
-)
+def _gemm(x_bf16, w_bf16):
+
... (974 more chars){
"content": "<2493 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run correctness check with Triton bf16 kernel",
"timeout": 180000
}FAIL: shape 1 {'M': 4096, 'N': 4096, 'K': 4127} seed 42: tolerance exceeded: max_abs_diff=0.0625 (atol=0.01, rtol=0.01)
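The Triton kernel accumulates exactly in fp32, yet K=4127 still misses the 1e-2 bf16 tolerance because the reference's cuBLAS matmul accumulates in a different order. A sketch of how one might reproduce that sensitivity (assumes a CUDA device; no particular diff value asserted):

import torch

torch.manual_seed(0)
M, N, K = 256, 256, 4127                  # off-aligned K from shapes.py
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
ref = x @ w.T                             # cuBLAS chooses its own K-split
acc = torch.zeros(M, N, device="cuda")    # same math, fixed 32-wide K-chunks in fp32
for k0 in range(0, K, 32):
    acc += x[:, k0:k0+32].float() @ w[:, k0:k0+32].float().T
print((ref.float() - acc).abs().max())    # small but nonzero: accumulation order differs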
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -1,66 +1,99 @@
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
+from torch.utils.cpp_extension import load_inline
+cuda_src = r"""
+#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cublasLt.h>
+#include <cuda_runtime.h>
-@triton.jit
-def gemm_kernel(
- x_ptr, w_ptr, y_ptr,
- M, N, K,
- stride_xm, stride_xk,
- stride_wn, stride_wk,
- stride_ym, stride_yn,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- pid_m = pid % num_pid_m
- pid_n = pid // num_pid_m
+#define CUBLAS_CHECK(call) \
+ do { \
+ cublasStatus_t err = call; \
+ if (err != CUBLAS_STATUS_SUCCESS) { \
+ throw std::runtime_error("cuBLAS error: " + std::to_string(err)); \
+ } \
+ } while (0)
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+static cublasLtHandle_t get_handle() {
+ static thread_local cublasLtHandle_t handle = nullptr;
+ if (!handle) CUBLAS_CHECK(cublasLtCreate(&handle));
+ return handle;
+}
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K):
- offs_k = k_start + tl.arange(0, BLOCK_K)
- x = tl.load(
- x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
- mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
- other=0.0,
- )
- w = tl.load(
- w_ptr + offs_n[None, :] * stride_wn + offs_k[:, None] * stride_wk,
- mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
- other=0.0,
- )
- acc += tl.dot(x, w)
+torch::Tensor bf16_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
+ TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
+ TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bf16");
+ TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bf16");
- y = acc.to(tl.bfloat16)
- tl.store(
- y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
- y,
- mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
- )
+ auto handle = get_handle();
+ cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
+ auto out = torch::empty({M, N}, A.options().dtype(torch::kBFloat16));
-def _gemm(x_bf16, w_bf16):
- M, K = x_bf16.shape
- N, K2 = w_bf16.shape
- assert K == K2
- y = torch.empty((M, N), dtype=torch.bfloat16, device=x_bf16.device)
- BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
- grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
- gemm_kernel[grid](
- x_bf16, w_bf16, y,
- M, N, K,
- x_bf16.stride(0), x_bf16.stride(1),
- w_bf16.stride(0), w_bf16.stride(1),
- y.stride(0), y.stride(1),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
- )
- return y
+ float alpha = 1.0f, beta = 0.0f;
+
+ // Row-major -> col-major trick:
+ // A_row(M,K) = A_col(K,M), B_row(N,K) = B_col(K,N), D_row(M,N) = D_col(N,M)
+ // D = A @ B.T => D_col = B_col^T @ A_col
+ // cuBLAS: A_cublas=B_col(K,N) transA=T, B_cublas=A_col(K,M) transB=N
+ cublasLtMatrixLayout_t lA, lB, lD;
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lA, CUDA_R_16BF, (uint64_t)K, (uint64_t)N, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lB, CUDA_R_16BF, (uint64_t)K, (uint64_t)M, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
+
+ cublasLtMatmulDesc_t desc;
+ CUBLAS_CHECK(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
+
+ cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
+
+ size_t ws_size = 32 << 20;
+ auto ws = torch::empty({(int64_t)ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
+
+ cublasLtMatmulPreference_t pref;
+ CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_size, sizeof(ws_size)));
+
+ int found = 0;
+ cublasLtMatmulHeuristicResult_t heur;
+ CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, lA, lB, lD, lD, pref, 1, &heur, &found));
+
+ if (found) {
+ CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
+ B.data_ptr(), lA, A.data_ptr(), lB, &beta,
+ nullptr, lD, out.data_ptr(), lD,
+ &heur.algo, ws.data_ptr(), ws_size, stream));
+ } else {
+ CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
+ B.data_ptr(), lA, A.data_ptr(), lB, &beta,
+ nullptr, lD, out.data_ptr(), lD,
+ nullptr, ws.data_ptr(), ws_size, stream));
+ }
+
+ cublasLtMatrixLayoutDestroy(lA);
+ cublasLtMatrixLayoutDestroy(lB);
+ cublasLtMatrixLayoutDestroy(lD);
+ cublasLtMatmulDescDestroy(desc);
+ cublasLtMatmulPreferenceDestroy(pref);
+
+ return out;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("bf16_gemm", &bf16_gemm, "BF16 GEMM via cuBLASLt");
+}
+"""
+
+_ext = load_inline(
+ name="bf16_gemm_ext",
+ cpp_sources="",
+ cuda_sources=cuda_src,
+ extra_ldflags=["-lcublasLt", "-lcublas"],
+ verbose=T
... (350 more chars){
"content": "<4522 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run correctness check with cuBLASLt bf16 GEMM",
"timeout": 300000
}[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=bf16_gemm_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/bf16_gemm_ext/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=bf16_gemm_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/bf16_gemm_ext/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -lcublasLt -lcublas -L/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/p
... (160 more chars){
"todos": [
{
"content": "Check GPU env (nvidia-smi, nvcc, CUDA version, SM arch)",
"priority": "high",
"status": "completed"
},
{
"content": "Write initial solution.py with cuBLASLt FP8 GEMM via load_inline",
"priority": "high",
"status": "completed"
},
{
"content": "Test correctness with check.py",
"priority": "high",
"status": "completed"
},
{
"content": "Benchmark with benchmark.py",
"priority": "high",
"status": "in_progress"
},
{
"content": "Optimize for FP8 performance wh
... (95 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}ninja: no work to do.
[sota unavailable] RuntimeError: Invalid scaling configuration.
- For TensorWise scaling, a and b should be float8, scales should be float and singletons.
- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (4096, 1) and scale_b should be (1, 4096), and both should be contiguous.
- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (4096, 32) and scale_b should be (32, 4096), and both should be outer-dim-major.
- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (32, 32) and scale_b should be (32, 32), and both should be near-inner-dim-major (with 16-byte aligned strides).
- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have 524288 elements and scale_b should have 524288 elements, and both should be contiguous.
- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have 2097152 elements and scale_b should have 2097152 elements, and both should be contiguous.
Got a.dtype()=Float8_e4m3fn, scale_a.dtype()=Float, scale_a.size()=[], scale_a.stride()=[], b.dtype()=BFloat16, scale_b.dtype()=Float, scale_b.size()=[] and scale_b.stride()=[]
shape=0 variant=eager tflops=347.517 gbps=169.686 ms=0.395
shape=0 variant=compiled tflops=323.794 gbps=158.103 ms=0.424
shape=0 variant=solution tflops=347.954 gbps=169.899 ms=0
... (4632 more chars){
"command": "<1398 chars \u2014 see below>",
"description": "Benchmark padded K vs original for K=4127",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import time
M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
# Eager bf16 matmul
x_bf = x.to(torch.bfloat16)
for _ in range(5): _ = x_bf @ w.T
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100): _ = x_bf @ w.T
torch.cuda.synchronize()
t1 = time.time()
print(f'Eager K=4127: {(t1-t0)/100*1000:.3f} ms, {2*M*N*K/((t1-t0)/100)/1e12:.1f} TFLOPS')
# Padded K
K_pad = ((K + 127) // 128) * 128
x_pad = torch.nn.functional.pad(x_bf, (0, K_pad - K))
w_pad = torch.nn.functional.pad(w, (0, K_pad - K))
for _ in range(5): _ = x_pad @ w_pad.T
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100): _ = x_pad @ w_pad.T
torch.cuda.synchronize()
t1 = time.time()
print(f'Padded K={K_pad}: {(t1-t0)/100*1000:.3f} ms, {2*M*N*K/((t1-t0)/100)/1e12:.1f} TFLOPS (using original K FLOPs)')
# For comparison: K=4096
x2 = (torch.rand(M, 4096, device='cuda') * 8 - 4).to(torch.bfloat16)
w2 = torch.randn(N, 4096, device='cuda', dtype=torch.bfloat16) * 0.02
for _ in range(5): _ = x2 @ w2.T
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100): _ = x2 @ w2.T
torch.cuda.synchronize()
t1 = time.time()
print(f'K=4096: {(t1-t0)/100*1000:.3f} ms, {2*M*N*4096/((t1-t0)/100)/1e12:.1f} TFLOPS')
"Eager K=4127: 0.725 ms, 191.0 TFLOPS
Padded K=4224: 0.345 ms, 401.4 TFLOPS (using original K FLOPs)
K=4096: 0.359 ms, 382.8 TFLOPS
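The padding win is large: K=4127 runs at roughly half the throughput of K=4096 or the padded K=4224, presumably because a 4127-element bf16 row is only 2-byte aligned and blocks vectorized global loads. A minimal sketch of the trick as a reusable helper (name and default alignment are illustrative, not code from this session):

import torch
import torch.nn.functional as F

def matmul_pad_k(x_bf16, w_bf16, align=16):
    # Zero-pad the shared K dim up to a multiple of `align`; the zero columns
    # contribute nothing to the product, so only kernel selection changes.
    K = x_bf16.shape[1]
    K_pad = -(-K // align) * align  # ceil-divide, then scale back up
    if K_pad != K:
        x_bf16 = F.pad(x_bf16, (0, K_pad - K))
        w_bf16 = F.pad(w_bf16, (0, K_pad - K))
    return x_bf16 @ w_bf16.T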
{
"command": "<1046 chars \u2014 see below>",
"description": "Benchmark padding including overhead",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import time
M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
def fn_nopad(x, w):
return x.to(torch.bfloat16) @ w.T
def fn_pad(x, w):
x_bf = x.to(torch.bfloat16)
K = x_bf.shape[1]
K_pad = ((K + 127) // 128) * 128
if K_pad != K:
x_bf = torch.nn.functional.pad(x_bf, (0, K_pad - K))
w_pad = torch.nn.functional.pad(w, (0, K_pad - K))
else:
w_pad = w
return x_bf @ w_pad.T
for _ in range(10): fn_nopad(x, w)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(200): fn_nopad(x, w)
torch.cuda.synchronize()
t1 = time.time()
print(f'No pad total: {(t1-t0)/200*1000:.3f} ms')
for _ in range(10): fn_pad(x, w)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(200): fn_pad(x, w)
torch.cuda.synchronize()
t1 = time.time()
print(f'With pad total: {(t1-t0)/200*1000:.3f} ms')
"No pad total: 0.788 ms
With pad total: 0.469 ms
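Padding pays for itself even counting the per-call F.pad overhead (0.469 ms vs 0.788 ms end to end). A methodological aside: these loops use host timestamps around torch.cuda.synchronize(), which is adequate at 200 iterations; a CUDA-event variant (a sketch, not what was run above) avoids host-timer jitter on shorter loops:

import torch

def time_cuda_ms(fn, iters=200, warmup=10):
    # Record events on the current stream and let the GPU report elapsed time.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call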
{
"command": "<793 chars \u2014 see below>",
"description": "Test different K alignment values",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import time
M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
for align in [16, 32, 64, 128, 256]:
K_pad = ((K + align - 1) // align) * align
x_bf = x.to(torch.bfloat16)
x_p = torch.nn.functional.pad(x_bf, (0, K_pad - K))
w_p = torch.nn.functional.pad(w, (0, K_pad - K))
for _ in range(5): _ = x_p @ w_p.T
torch.cuda.synchronize()
t0 = time.time()
for _ in range(200): _ = x_p @ w_p.T
torch.cuda.synchronize()
t1 = time.time()
    print(f'align={align:3d} K_pad={K_pad:5d}: {(t1-t0)/200*1000:.3f} ms, {2*M*N*K/((t1-t0)/200)/1e12:.1f} TFLOPS')
"align= 16 K_pad= 4128: 0.337 ms, 410.9 TFLOPS
align= 32 K_pad= 4128: 0.346 ms, 400.2 TFLOPS
align= 64 K_pad= 4160: 0.364 ms, 380.4 TFLOPS
align=128 K_pad= 4224: 0.372 ms, 372.3 TFLOPS
align=256 K_pad= 4352: 0.383 ms, 361.6 TFLOPS
{
"command": "<1587 chars \u2014 see below>",
"description": "Test FP8 GEMM accuracy with different scaling",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
# Test FP8 GEMM accuracy with different scaling strategies
torch.manual_seed(42)
for shape_name, M, N, K in [('sq', 4096, 4096, 4096), ('k4127', 4096, 4096, 4127), ('skinny', 32, 8192, 8192), ('llama', 4096, 14336, 4096)]:
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
ref = x.to(torch.bfloat16) @ w.T
# Per-tensor scaling
scale_w = w.abs().max() / 448.0
w_fp8 = (w / scale_w).to(torch.float8_e4m3fn)
sa = torch.tensor(1.0, device='cuda')
sb = torch.tensor(scale_w.item(), device='cuda')
sol = torch._scaled_mm(x, w_fp8.T, scale_a=sa, scale_b=sb, out_dtype=torch.bfloat16)
diff = (ref - sol).abs()
print(f'{shape_name:8s} per-tensor: max={diff.max().item():.4f} mean={diff.mean().item():.4f}')
# Per-row scaling (scale each row of w independently)
scale_w_row = w.abs().amax(dim=1, keepdim=True) / 448.0 # (N, 1)
w_fp8_row = (w / scale_w_row).to(torch.float8_e4m3fn)
# Compute y = x @ w.T = x @ (scale_w_row * w_fp8_row).T
# = x @ w_fp8_row.T * scale_w_row.T (scale applied per output column)
y_raw = torch._scaled_mm(x, w_fp8_row.T, scale_a=sa, scale_b=torch.tensor(1.0, device='cuda'), out_dtype=torch.bfloat16)
sol_row = y_raw * scale_w_row.T # (M, N) * (1, N)
diff_row = (ref - sol_row).abs()
print(f'{shape_name:8s} per-row: max={diff_row.max().item():.4f} mean={diff_row.mean().item():.4f}')
"Traceback (most recent call last):
File "<string>", line 17, in <module>
RuntimeError: Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (4096x4127).
sq per-tensor: max=0.4375 mean=0.0635
sq per-row: max=0.4375 mean=0.0630
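The k4127 crash is the same alignment rule resurfacing: torch._scaled_mm requires mat1's trailing dim to be divisible by 16, so only the 'sq' shape printed before the loop died. A hedged sketch of how the per-tensor path could handle K=4127 via the same zero-padding trick (untested in this session; fp8 e4m3 values round-trip exactly through bf16, so the pad-and-recast is lossless):

import torch
import torch.nn.functional as F

M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02

K_pad = -(-K // 16) * 16  # 4128 satisfies the divisible-by-16 rule
x_pad = F.pad(x.to(torch.bfloat16), (0, K_pad - K)).to(torch.float8_e4m3fn)
scale_w = (w.abs().max() / 448.0).float()
w_fp8 = (F.pad(w, (0, K_pad - K)) / scale_w).to(torch.float8_e4m3fn)
sa = torch.tensor(1.0, device='cuda')
y = torch._scaled_mm(x_pad, w_fp8.T, scale_a=sa, scale_b=scale_w, out_dtype=torch.bfloat16)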
{
"command": "<1948 chars \u2014 see below>",
"description": "Benchmark padded vs non-padded full forward pass",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import time
# Quick test: does padding K=4127 help in a full forward pass?
M, N, K = 4096, 4096, 4127
class PadModel(torch.nn.Module):
def __init__(self, M, N, K):
super().__init__()
self.M, self.N, self.K = M, N, K
self.weight = torch.nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
torch.nn.init.normal_(self.weight, std=0.02)
# Pre-compute padded weight
K_pad = ((K + 31) // 32) * 32
self.K_pad = K_pad
if K_pad != K:
self.register_buffer('weight_pad', torch.nn.functional.pad(self.weight.data, (0, K_pad - K)))
def forward(self, x):
x_bf16 = x.to(torch.bfloat16)
if self.K_pad != self.K:
x_bf16 = torch.nn.functional.pad(x_bf16, (0, self.K_pad - self.K))
return x_bf16 @ self.weight_pad.T
return x_bf16 @ self.weight.T
class NoPadModel(torch.nn.Module):
def __init__(self, M, N, K):
super().__init__()
self.M, self.N, self.K = M, N, K
self.weight = torch.nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
torch.nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
return x.to(torch.bfloat16) @ self.weight.T
pad_model = PadModel(M, N, K).cuda().eval()
nopad_model = NoPadModel(M, N, K).cuda().eval()
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
# Warmup
for _ in range(10):
with torch.no_grad(): pad_model(x)
with torch.no_grad(): nopad_model(x)
torch.cuda.synchronize()
# Time
for name, model in [('padded', pad_model), ('no-pad', nopad_model)]:
t0 = time.time()
for _ in range(200):
with torch.no_grad(): model(x)
torch.cuda.synchronize()
t1 = time.time()
ms = (t1-t0)/200*1000
tflops = 2*M*N*K / ms / 1e9
print(f'{name:8s}: {ms:.3f} ms, {tflops:.1f} TFLOPS')
"padded : 0.404 ms, 342.5 TFLOPS
no-pad : 0.831 ms, 166.7 TFLOPS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py
@@ -7,6 +7,8 @@
#include <c10/cuda/CUDAStream.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
+#include <unordered_map>
+#include <mutex>
#define CUBLAS_CHECK(call) \
do { \
@@ -22,6 +24,63 @@
return handle;
}
+struct GemmCacheKey {
+ int64_t M, N, K;
+ bool operator==(const GemmCacheKey& o) const { return M==o.M && N==o.N && K==o.K; }
+};
+struct GemmCacheKeyHash {
+ size_t operator()(const GemmCacheKey& k) const { return k.M*1000003 ^ k.N*1033 ^ k.K; }
+};
+
+struct GemmCacheEntry {
+ cublasLtMatrixLayout_t lA, lB, lD;
+ cublasLtMatmulDesc_t desc;
+ cublasLtMatmulAlgo_t algo;
+ bool has_algo;
+ int64_t ws_size;
+};
+
+static std::unordered_map<GemmCacheKey, GemmCacheEntry, GemmCacheKeyHash> cache;
+static std::mutex cache_mtx;
+
+static GemmCacheEntry& get_or_create(int64_t M, int64_t N, int64_t K, int device_id) {
+ GemmCacheKey key{M, N, K};
+ std::lock_guard<std::mutex> lock(cache_mtx);
+ auto it = cache.find(key);
+ if (it != cache.end()) return it->second;
+
+ auto handle = get_handle();
+ GemmCacheEntry e;
+
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lA, CUDA_R_16BF, (uint64_t)K, (uint64_t)N, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lB, CUDA_R_16BF, (uint64_t)K, (uint64_t)M, (int64_t)K));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&e.lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
+
+ CUBLAS_CHECK(cublasLtMatmulDescCreate(&e.desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
+ cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(e.desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(e.desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
+
+ e.ws_size = 64 << 20;
+ cublasLtMatmulPreference_t pref;
+ CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &e.ws_size, sizeof(e.ws_size)));
+
+ cublasLtMatmulHeuristicResult_t heur[4];
+ int found = 0;
+ CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, e.desc, e.lA, e.lB, e.lD, e.lD, pref, 4, heur, &found));
+ if (found > 0) {
+ e.algo = heur[0].algo;
+ e.has_algo = true;
+ } else {
+ e.has_algo = false;
+ }
+ CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(pref));
+
+ auto result = cache.emplace(key, std::move(e));
+ return result.first->second;
+}
+
torch::Tensor bf16_gemm(torch::Tensor A, torch::Tensor B, int64_t M, int64_t N, int64_t K) {
TORCH_CHECK(A.is_cuda() && B.is_cuda(), "inputs must be CUDA");
TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bf16");
@@ -34,50 +93,24 @@
float alpha = 1.0f, beta = 0.0f;
- // Row-major -> col-major trick:
- // A_row(M,K) = A_col(K,M), B_row(N,K) = B_col(K,N), D_row(M,N) = D_col(N,M)
- // D = A @ B.T => D_col = B_col^T @ A_col
- // cuBLAS: A_cublas=B_col(K,N) transA=T, B_cublas=A_col(K,M) transB=N
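- // Concrete check of the identity: the buffer of A_row(M,K) read col-major
- // is A_row^T, so B_col^T @ A_col = B_row @ A_row^T = (A_row @ B_row^T)^T,
- // which is exactly D_row(M,N) reinterpreted as the N x M col-major D_col.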
- cublasLtMatrixLayout_t lA, lB, lD;
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lA, CUDA_R_16BF, (uint64_t)K, (uint64_t)N, (int64_t)K));
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lB, CUDA_R_16BF, (uint64_t)K, (uint64_t)M, (int64_t)K));
- CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&lD, CUDA_R_16BF, (uint64_t)N, (uint64_t)M, (int64_t)N));
+ auto& e = get_or_create(M, N, K, A.get_device());
- cublasLtMatmulDesc_t desc;
- CUBLAS_CHECK(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
-
- cublasOperation_t tA = CUBLAS_OP_T, tB = CUBLAS_OP_N;
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
- CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
-
- size_t ws_size = 32 << 20;
- auto ws = torch::empty({(int64_t)ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
-
- cublasLtMatmulPreference_t pref;
- CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&pref));
- CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_size, sizeof(ws_size)));
-
- int found = 0;
- cublasLtMatmulHeuristicResult_t heur;
- CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, lA, lB, lD, lD, pref, 1, &heur, &found));
-
- if (found) {
- CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
- B.data_ptr(), lA, A.data_ptr(), lB, &beta,
- nullptr, lD, out.data_ptr(), lD,
- &heur.algo, ws.data_ptr(), ws_size, stream));
- } else {
- CUBLAS_CHECK(cublasLtMatmul(handle, desc, &alpha,
- B.data_ptr(), lA, A.data_ptr(), lB, &beta,
- nullptr, lD, out.data_ptr(), lD,
- nullptr, ws.data_ptr(), ws_size, stream));
+ static thread_local torch::Tensor ws;
+ if (!ws.defined() || ws.numel() < (int64_t)e.ws_size || ws.device() != A.device()) {
+ ws = torch::empty({(int64_t)e.ws_size}, torch::TensorOptions().dtype(torch::kInt8).device(A.device()));
}
- cublasLtMatrixLayoutDestroy(lA);
- cublasLtMatrixLayoutDestroy(lB);
- cublasLtMatrixLayoutDestroy(lD);
- cublasLtMatmulDescDestroy(desc);
- cublasLtMatmulPreferenceDestroy(pref);
+ if (e.has_algo) {
+ CUBLAS_CHECK(cublasLtMatmul(handle, e.desc, &alpha,
+ B.data_ptr(), e.lA, A.data_ptr(), e.lB, &beta,
+ nullptr, e.lD, out.data_ptr(), e.lD,
+ &e.algo, ws.data_ptr(), e.ws_size, stream));
+ } else {
+ CUBLAS_CHECK(cublasLtMatmul(handle, e.desc, &alpha,
+ B.data_ptr(), e.lA, A.data_ptr(), e.lB, &beta,
+ nullptr, e.lD, out.data_ptr(), e.lD,
... (1078 more chars){
"content": "<5654 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python check.py 2>&1",
"description": "Run correctness check with optimized solution",
"timeout": 300000
}[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=bf16_gemm_v3 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/bf16_gemm_v3/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=bf16_gemm_v3 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/bf16_gemm_v3/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -lcublasLt -lcublas -L/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/pytho
... (270 more chars){
"command": "<1284 chars \u2014 see below>",
"description": "Debug numerical differences from K-padding",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
torch.manual_seed(42)
M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
x_bf = x.to(torch.bfloat16)
w_bf = w
# Reference (no padding)
y_ref = x_bf @ w_bf.T
# With padding
K_pad = ((K + 31) // 32) * 32 # 4128
x_pad = torch.nn.functional.pad(x_bf, (0, K_pad - K))
w_pad = torch.nn.functional.pad(w_bf, (0, K_pad - K))
y_pad = x_pad @ w_pad.T
diff = (y_ref - y_pad).abs()
print(f'torch.matmul padded vs unpadded: max={diff.max().item():.6f} mean={diff.mean().item():.6f}')
print(f'Passes 1e-2? {torch.allclose(y_ref.float(), y_pad.float(), atol=1e-2, rtol=1e-2)}')
# Test our cuBLASLt with padded K
import solution
sol_out = solution._ext.bf16_gemm(x_pad.contiguous(), w_pad.contiguous(), M, N, K_pad)
diff2 = (y_ref - sol_out).abs()
print(f'cuBLASLt padded: max={diff2.max().item():.6f} mean={diff2.mean().item():.6f}')
# Test cuBLASLt with original K (no padding)
sol_out2 = solution._ext.bf16_gemm(x_bf.contiguous(), w_bf.contiguous(), M, N, K)
diff3 = (y_ref - sol_out2).abs()
print(f'cuBLASLt no pad: max={diff3.max().item():.6f} mean={diff3.mean().item():.6f}')
"torch.matmul padded vs unpadded: max=0.062500 mean=0.002365
Passes 1e-2? False
ninja: no work to do.
cuBLASLt padded: max=0.062500 mean=0.002365
cuBLASLt no pad: max=0.000000 mean=0.000000
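So torch.matmul and the cuBLASLt path agree bit-for-bit on the padded inputs (identical max/mean), and the unpadded cuBLASLt call matches the reference exactly; the 0.0625 gap is therefore between the K=4127 and K=4128 kernels, most plausibly a different reduction order, since the zero columns are mathematically inert. Floating-point reductions are order-sensitive in general; a toy illustration (unrelated to the GEMM kernels themselves):

import torch
torch.manual_seed(0)
v = torch.randn(1 << 20)
print(v.sum().item())                          # one reduction order
print(v.view(1024, 1024).sum(0).sum().item())  # another order; typically differs in the last bits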
{
"command": "<1430 chars \u2014 see below>",
"description": "Test split-K approach for K=4127",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm && uv run python -c "
import torch
import time
torch.manual_seed(42)
M, N, K = 4096, 4096, 4127
x = (torch.rand(M, K, device='cuda') * 8 - 4).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device='cuda', dtype=torch.bfloat16) * 0.02
x_bf = x.to(torch.bfloat16)
w_bf = w
# Reference: single GEMM with K=4127
y_ref = x_bf @ w_bf.T
# Split-K approach: K_aligned=4096, K_tail=31
K_aligned = (K // 32) * 32 # 4096
K_tail = K - K_aligned # 31
y1 = x_bf[:, :K_aligned] @ w_bf[:, :K_aligned].T
y2 = x_bf[:, K_aligned:] @ w_bf[:, K_aligned:].T
y_split = y1 + y2
diff = (y_ref - y_split).abs()
print(f'Split-K vs reference: max={diff.max().item():.6f} mean={diff.mean().item():.6f}')
print(f'Passes 1e-2? {torch.allclose(y_ref.float(), y_split.float(), atol=1e-2, rtol=1e-2)}')
print(f'Passes 0.15? {torch.allclose(y_ref.float(), y_split.float(), atol=0.15, rtol=0.15)}')
# Benchmark split approach
for _ in range(10):
y1 = x_bf[:, :K_aligned] @ w_bf[:, :K_aligned].T
y2 = x_bf[:, K_aligned:] @ w_bf[:, K_aligned:].T
_ = y1 + y2
torch.cuda.synchronize()
t0 = time.time()
for _ in range(200):
y1 = x_bf[:, :K_aligned] @ w_bf[:, :K_aligned].T
y2 = x_bf[:, K_aligned:] @ w_bf[:, K_aligned:].T
_ = y1 + y2
torch.cuda.synchronize()
t1 = time.time()
ms = (t1-t0)/200*1000
tflops = 2*M*N*K / ms / 1e9
print(f'Split-K: {ms:.3f} ms, {tflops:.1f} TFLOPS')
"Split-K vs reference: max=0.125000 mean=0.004883
Passes 1e-2? False
Passes 0.15? True
Split-K: 0.810 ms, 170.9 TFLOPS
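Split-K lands at eager speed (0.810 ms vs ~0.35 ms for padding) and is less accurate (max 0.125 vs 0.0625), since the 31-column tail GEMM plus the elementwise add are pure overhead on top of the dominant K=4096 GEMM. Were it pursued anyway, the y1/y2 temporaries and the add could be fused by accumulating in place, a sketch assuming contiguous bf16 CUDA tensors:

import torch

def split_k_matmul(x_bf, w_bf, K_aligned):
    # GEMM over the aligned prefix, then accumulate the K-tail product
    # into the same buffer instead of materializing a second output.
    y = x_bf[:, :K_aligned] @ w_bf[:, :K_aligned].T
    y.addmm_(x_bf[:, K_aligned:], w_bf[:, K_aligned:].T)  # y += tail @ tail_w.T
    return y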