{
"problem": "01_fp8_gemm",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": false,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 417,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 513532, "output_tokens": 4392, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 5414, "total_cost_usd": null}
}
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive FP8 e4m3 GEMM reference (correctness only, NOT the SOTA baseline).
2:
3: We cast inputs to bf16 and use torch.matmul. The agent's solution must match
4: this numerically within the fp8 tolerance declared in problem.yaml.
5: """
6: import torch
7: import torch.nn as nn
8:
9: OP_TYPE = "gemm"
10: SUPPORTED_PRECISIONS = ["fp8_e4m3"]
11: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
12:
13:
14: class Model(nn.Module):
15: """y = (x @ w.T).to(bf16), where x is fp8_e4m3 (M, K), w is fp8_e4m3 (N, K)."""
16:
17: def __init__(self, M: int, N: int, K: int):
18: super().__init__()
19: self.M, self.N, self.K = M, N, K
20: # Weights stored as parameters so state_dict is well-defined.
21: # We initialize in bf16 then cast; the fp8 dtype is set by get_inputs.
22: self.weight = nn.Parameter(torch.empty(N, K, dtype=torch.bfloat16))
23: nn.init.normal_(self.weight, std=0.02)
24:
25: def forward(self, x: torch.Tensor) -> torch.Tensor:
26: # Upcast to bf16 for the naive reference; the kernel equivalent would
27: # use mma.sync f8f6f4 kind directly.
28: x_bf = x.to(torch.bfloat16)
29: w_bf = self.weight.to(torch.bfloat16)
30: return x_bf @ w_bf.T # (M, N) bf16
31:
32:
33: M = 4096
34: N = 4096
35: K = 4096
36:
37:
38: def get_inputs():
39: # fp8_e4m3 inp
... (227 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape c
... (2907 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formula
... (3416 more chars){
"filePath": "/testbed/problems/01_fp8_gemm/shapes.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 01_fp8_gemm
2: display_name: "FP8 e4m3 GEMM"
3: precision: fp8_e4m3
4: regime: compute
5:
6: # Dense-equivalent FLOPs per invocation. Evaluated per-shape; the benchmark
7: # substitutes M/N/K from each shapes.py entry.
8: flops_formula: "2 * M * N * K"
9:
10: # Bytes moved (approximate): fp8 inputs in, bf16 output.
11: bytes_formula: "M*K + K*N + M*N*2"
12:
13: hardware: [RTX_PRO_6000]
14: peak_tflops_key: fp8 # index into src/hardware/rtx_pro_6000.py peak_tflops_dense
15: peak_bandwidth_key: dram
16:
17: # Per-dtype tolerance override (optional). Defaults are in src/eval/correctness.py.
18: tolerance:
19: fp8_e4m3fn: 0.15
20:
21: # Forbidden ops — using any of these in solution.py fails correctness post-hoc.
22: # torch._scaled_mm is Stanford-style cheating (just call the vendor library).
23: forbidden:
24: - "torch._scaled_mm"
25: - "torch.ops.aten._scaled_mm"
26:
27: sota:
28: name: "FlashInfer FP8 GEMM"
29: url: "https://github.com/flashinfer-ai/flashinfer"
30: function: "flashinfer.gemm.fp8_gemm"
31: deps:
32: - "flashinfer>=0.6.8"
33: # Documented H100 throughput for this shape (informational, not graded):
34: reference_throughput_tflops_h100: 550
35:
36: num_correct_trials: 3
37: num_perf_trials: 30
(End of file - total 37 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for FP8 GEMM: flashinfer.gemm.fp8_gemm.
2:
3: If flashinfer is not installed or the SM120 path isn't supported, this falls
4: back to torch._scaled_mm which is the cuBLAS FP8 path. The benchmark treats
5: whichever succeeds as the SOTA reference line.
6:
7: Agents are FORBIDDEN from using torch._scaled_mm in their solution (see
8: problem.yaml.forbidden). This file is only for the benchmark's reference line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_flashinfer(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor | None:
16: try:
17: import flashinfer # noqa: F401
18: # Note: flashinfer's FP8 GEMM API surface may differ; adapt if needed.
19: # Placeholder call — replace with the actual flashinfer entry point
20: # once validated on SM120.
21: return None
22: except ImportError:
23: return None
24:
25:
26: def _scaled_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
27: # torch._scaled_mm wants per-tensor scales. Use unit scales for the reference.
28: scale_a = torch.tensor(1.0, device=x.device)
29: scale_b = torch.tensor(1.0, device=x.device)
30: out = torch._scaled_mm(
31: x,
32: w.T,
33: scale_a=scale_a,
34: scale_b=scale_b,
35: out_dtype=torch.bfloat16,
36: )
37: return out
... (576 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for FP8 GEMM.
2:
3: Mix of:
4: - square aligned (the easy case)
5: - off-alignment K (common real-world failure mode for tile-quantized kernels)
6: - skinny (decode-like, memory-bound)
7: - rectangular (prefill with grouped attention)
8: """
9:
10: SHAPES = [
11: {"M": 4096, "N": 4096, "K": 4096}, # square aligned
12: {"M": 4096, "N": 4096, "K": 4127}, # K not multiple of 128 -> forces predicated tails
13: {"M": 32, "N": 8192, "K": 8192}, # skinny M (decode)
14: {"M": 4096, "N": 14336, "K": 4096}, # Llama3 up-proj shape
15: ]
(End of file - total 15 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"command": "nvidia-smi --query-gpu=name,driver_version --format=csv,noheader 2>/dev/null && nvcc --version 2>/dev/null || echo \"nvcc not in path\"",
"description": "Check GPU and CUDA version"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "python -c \"import torch; print(f'pytorch {torch.__version__}'); print(f'cuda {torch.version.cuda}'); print(f'capability {torch.cuda.get_device_capability()}')\" 2>/dev/null || uv run python -c \"import torch; print(f'pytorch {torch.__version__}'); print(f'cuda {torch.version.cuda}'); print(f'capability {torch.cuda.get_device_capability()}')\"",
"description": "Check PyTorch and CUDA capability"
}pytorch 2.11.0+cu130
cuda 13.0
capability (12, 0)
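capability (12, 0) is sm_120, comfortably past the point where FP8 tensor cores appeared. A minimal availability guard, assuming (8, 9) (Ada) as the FP8 cutoff, which is my assumption rather than something this repo states:

import torch

# Hypothetical guard: hardware FP8 matmul is assumed to need sm_89 (Ada)
# or newer; capability (12, 0) on this box clears the bar easily.
def fp8_supported() -> bool:
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9) and hasattr(torch, "float8_e4m3fn")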
{
"command": "which triton 2>/dev/null; python -c \"import triton; print(f'triton {triton.__version__}')\" 2>/dev/null || uv run python -c \"import triton; print(f'triton {triton.__version__}')\" 2>/dev/null || echo \"triton not available\"",
"description": "Check Triton availability"
}triton not found
triton 3.6.0
{
"command": "ls /usr/local/cuda*/targets/x86_64-linux/include/cutlass* 2>/dev/null | head -5; ls /usr/local/cuda*/include/cutlass 2>/dev/null | head -5; find /usr/local -maxdepth 4 -name \"cutlass\" -type d 2>/dev/null | head -5; echo \"---\"; pip list 2>/dev/null | grep -i cutlass || echo \"no cutlass pip\"",
"description": "Check CUTLASS availability"
}zsh:1: no matches found: /usr/local/cuda*/targets/x86_64-linux/include/cutlass*
zsh:1: no matches found: /usr/local/cuda*/include/cutlass
---
no cutlass pip
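With the environment mapped out (SM120, Triton 3.6.0, no CUTLASS headers), the roofline from problem.yaml and rtx_pro_6000.py sets the target before any kernel work. A back-of-envelope sketch using only numbers quoted in those files:

# Roofline sanity check for the square 4096^3 shape, using the formulas
# from problem.yaml and the peaks from src/hardware/rtx_pro_6000.py.
M = N = K = 4096
flops = 2 * M * N * K                    # 2*M*N*K ~= 1.37e11 dense FLOPs
bytes_moved = M * K + K * N + M * N * 2  # fp8 in, bf16 out ~= 67 MB

peak_tflops = 400.0   # fp8 dense peak (rtx_pro_6000.py)
peak_gbps = 1800.0    # DRAM bandwidth (rtx_pro_6000.py)

t_compute_ms = flops / (peak_tflops * 1e12) * 1e3    # ~0.34 ms
t_memory_ms = bytes_moved / (peak_gbps * 1e9) * 1e3  # ~0.037 ms
print(f"compute-bound by {t_compute_ms / t_memory_ms:.0f}x")  # ~9x

The same arithmetic flips for the M=32 decode shape in shapes.py (about 0.011 ms of compute against 0.038 ms of traffic), matching its memory-bound label.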
{
"todos": [
{
"content": "Write initial Triton FP8 GEMM kernel in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness across all shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to measure baseline performance",
"status": "pending",
"priority": "high"
},
{
"content": "Iterate optimization (tuning, TMA, larger blocks)",
"status": "pending",
"priority": "medium"
}
]
}{
"command": "<259 chars \u2014 see below>",
"description": "Check Triton FP8 support on SM120"
}uv run python -c "
import triton
print(f'Triton version: {triton.__version__}')
print(f'Target arch: {triton.runtime.driver.active.get_current_target()}')
import triton.language as tl
print('tl.float8e4m3:', tl.float8e4m3)
print('tl.bfloat16:', tl.bfloat16)
"Traceback (most recent call last):
File "<string>", line 6, in <module>
AttributeError: module 'triton.language' has no attribute 'float8e4m3'. Did you mean: 'float8e4b8'?
Triton version: 3.6.0
Target arch: GPUTarget(backend='cuda', arch=120, warp_size=32)
{
"command": "<487 chars \u2014 see below>",
"description": "Find Triton fp8 dtype names"
}uv run python -c "
import triton.language as tl
# Check what float8 types are available
print([x for x in dir(tl) if 'float8' in x.lower() or 'fp8' in x.lower()])
# Check if we can create fp8 dtype via tl.dtype
try:
fp8 = tl.dtype('float8e4m3')
print('tl.dtype(float8e4m3):', fp8)
except Exception as e:
print('tl.dtype(float8e4m3) error:', e)
try:
fp8 = tl.float8e4m3fn
print('tl.float8e4m3fn:', fp8)
except Exception as e:
print('tl.float8e4m3fn error:', e)
"['float8e4b15', 'float8e4b8', 'float8e4nv', 'float8e5', 'float8e5b16']
tl.dtype(float8e4m3) error: float8e4m3
tl.float8e4m3fn error: module 'triton.language' has no attribute 'float8e4m3fn'
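So the available names are the five above; float8e4nv is the NVIDIA e4m3 type, while the float8e4b8 suggested by the earlier AttributeError is the AMD bias-8 variant and would be the wrong pick here. A sketch of the torch-to-Triton pairing as I read it (only the e4nv entry gets verified by the introspection that follows):

import torch
import triton.language as tl

# Assumed torch <-> Triton fp8 pairing; e4nv's bias-7, 3-mantissa-bit layout
# is confirmed by the finfo/introspection checks below.
TORCH_TO_TRITON_FP8 = {
    torch.float8_e4m3fn: tl.float8e4nv,  # NVIDIA e4m3 ("fn": no inf, one NaN code)
    torch.float8_e5m2: tl.float8e5,      # e5m2
}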
{
"tool": "",
"error": "Model tried to call unavailable tool ''. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill."
}{
"command": "<217 chars \u2014 see below>",
"description": "Inspect Triton float8e4nv type"
}uv run python -c "
import triton.language as tl
print('float8e4nv:', tl.float8e4nv)
print('type:', type(tl.float8e4nv))
# Check attributes
print('dir:', [x for x in dir(tl.float8e4nv) if not x.startswith('_')])
" 2>&1float8e4nv: fp8e4nv
type: <class 'triton.language.core.dtype'>
dir: ['FP_TYPES', 'KIND', 'OTHER_TYPES', 'SIGNEDNESS', 'SINT_TYPES', 'STANDARD_FP_TYPES', 'UINT_TYPES', 'cache_key_part', 'codegen_name', 'exponent_bias', 'fp_mantissa_width', 'get_int_max_value', 'get_int_min_value', 'is_bf16', 'is_block', 'is_bool', 'is_const', 'is_dtype', 'is_floating', 'is_fp16', 'is_fp32', 'is_fp64', 'is_fp8', 'is_fp8e4b15', 'is_fp8e4b8', 'is_fp8e4nv', 'is_fp8e5', 'is_fp8e5b16', 'is_int', 'is_int1', 'is_int16', 'is_int32', 'is_int64', 'is_int8', 'is_int_signed', 'is_int_unsigned', 'is_ptr', 'is_standard_floating', 'is_uint16', 'is_uint32', 'is_uint64', 'is_uint8', 'is_void', 'itemsize', 'kind', 'mangle', 'name', 'primitive_bitwidth', 'scalar', 'to_ir', 'with_element_ty']
{
"command": "<331 chars \u2014 see below>",
"description": "Check Triton auto and bf16 types"
}uv run python -c "
import triton.language as tl
# Check tl.auto
try:
auto_type = tl.auto
print('tl.auto:', auto_type)
except:
print('tl.auto not available')
# Check if we can use tl.bfloat16
print('tl.bfloat16:', tl.bfloat16)
# Check dot semantics
import triton
print('triton.__version__:', triton.__version__)
" 2>&1tl.auto not available
tl.bfloat16: bf16
triton.__version__: 3.6.0
{
"command": "<657 chars \u2014 see below>",
"description": "Verify fp8e4nv is e4m3 format"
}uv run python -c "
import triton.language as tl
fp8 = tl.float8e4nv
print(f'name: {fp8.name}')
print(f'bitwidth: {fp8.primitive_bitwidth}')
print(f'mantissa: {fp8.fp_mantissa_width}')
print(f'exponent_bias: {fp8.exponent_bias}')
print(f'is_fp8: {fp8.is_fp8}')
print(f'is_fp8e4nv: {fp8.is_fp8e4nv}')
# e4m3 has 4 exponent bits, 3 mantissa bits, bias = 7
print()
print('PyTorch fp8 range:')
import torch
x = torch.tensor([1.0], dtype=torch.float8_e4m3fn)
print('torch fp8e4m3 max:', torch.finfo(torch.float8_e4m3fn).max)
print('torch fp8e4m3 min:', torch.finfo(torch.float8_e4m3fn).min)
print('torch fp8e4m3 eps:', torch.finfo(torch.float8_e4m3fn).eps)
" 2>&1name: fp8e4nv
bitwidth: 8
mantissa: 3
exponent_bias: 7
is_fp8: <bound method dtype.is_fp8 of triton.language.float8e4nv>
is_fp8e4nv: <bound method dtype.is_fp8e4nv of triton.language.float8e4nv>
PyTorch fp8 range:
torch fp8e4m3 max: 448.0
torch fp8e4m3 min: -448.0
torch fp8e4m3 eps: 0.125
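Those finfo values follow directly from the bit layout just printed: 4 exponent bits with bias 7 and 3 mantissa bits, with the all-ones code reserved for NaN in the fn variant. A worked check:

# e4m3fn: exponent field 15 with mantissa 111b is NaN, so the largest
# normal value uses exponent 15 with mantissa 110b.
bias, mant_bits = 7, 3
max_normal = 2 ** (15 - bias) * (1 + 6 / 2**mant_bits)  # 256 * 1.75 = 448.0
eps = 2.0 ** -mant_bits                                  # 0.125
assert (max_normal, eps) == (448.0, 0.125)  # matches torch.finfo above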
{
"command": "<2150 chars \u2014 see below>",
"description": "Test Triton FP8 dot product"
}uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def fp8_simple_kernel(X, Y, C, M, N, K, stride_xm, stride_xk, stride_yn, stride_yk, stride_cn, stride_cm, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
off_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk,
mask=(off_m[:, None] < M) & (off_k[None, :] < K),
other=0.0,
element_ty=tl.float8e4nv)
y = tl.load(Y + off_n[:, None] * stride_yn + off_k[None, :] * stride_yk,
mask=(off_n[:, None] < N) & (off_k[None, :] < K),
other=0.0,
element_ty=tl.float8e4nv)
acc = tl.dot(x, y, acc=acc)
off_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
tl.store(C + off_cm[:, None] * stride_cm + off_cn[None, :] * stride_cn,
acc.to(tl.bfloat16),
mask=(off_cm[:, None] < M) & (off_cn[None, :] < N))
# Test
M, N, K = 64, 64, 64
x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn).cuda()
w = (torch.rand(N, K) * 8 - 4).to(torch.float8_e4m3fn).cuda()
out = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
fp8_simple_kernel[grid](x, w, out, M, N, K,
x.stride(0), x.stride(1), w.stride(0), w.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=128, BLOCK_N=128, BLOCK_K=64)
# Reference
ref = x.to(torch.bfloat16) @ w.to(torch.bfloat16).T
print('out shape:', out.shape)
print('ref shape:', ref.shape)
max_diff = (out.float() - ref.float()).abs().max().item()
print(f'max abs diff: {max_diff:.6f}')
print('close:', torch.allclose(out.float(), ref.float(), atol=0.15, rtol=0.15))
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/test_kernel.py
@@ -0,0 +1,58 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def fp8_gemm_kernel(X, W, C, M, N, K, stride_xm, stride_xk, stride_wn, stride_wk, stride_cm, stride_cn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ off_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
+
+ x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk,
+ mask=(off_m[:, None] < M) & (off_k[None, :] < K),
+ other=0.0,
+ element_ty=tl.float8e4nv)
+ w = tl.load(W + off_n[:, None] * stride_wn + off_k[None, :] * stride_wk,
+ mask=(off_n[:, None] < N) & (off_k[None, :] < K),
+ other=0.0,
+ element_ty=tl.float8e4nv)
+
+ acc = tl.dot(x, w, acc=acc)
+
+ cm_ptr = C + off_m[:, None] * stride_cm + off_n[None, :] * stride_cn
+ tl.store(cm_ptr, acc.to(tl.bfloat16),
+ mask=(off_m[:, None] < M) & (off_n[None, :] < N))
+
+
+def test_fp8_gemm():
+ M, N, K = 64, 64, 64
+ x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn).cuda()
+ w = (torch.rand(N, K) * 8 - 4).to(torch.float8_e4m3fn).cuda()
+ out = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
+
+ grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
+ fp8_gemm_kernel[grid](x, w, out, M, N, K,
+ x.stride(0), x.stride(1), w.stride(0), w.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=128, BLOCK_N=128, BLOCK_K=64)
+
+ ref = x.to(torch.bfloat16) @ w.to(torch.bfloat16).T
+ print(f'shape: out={out.shape} ref={ref.shape}')
+ max_diff = (out.float() - ref.float()).abs().max().item()
+ print(f'max abs diff: {max_diff:.6f}')
+ ok = torch.allclose(out.float(), ref.float(), atol=0.15, rtol=0.15)
+ print(f'allclose (atol=0.15, rtol=0.15): {ok}')
+ return out, ref
+
+
+if __name__ == '__main__':
+ test_fp8_gemm()
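Moving the kernel into a real file sidesteps the OSError above: inspect.getsourcelines cannot recover source for a function defined inside a python -c string, so @triton.jit fails there. Looking ahead to the tuning todo, a hedged autotune skeleton over this kernel could look like the sketch below; the config list is illustrative, not measured on SM120:

import triton
import triton.language as tl

# Illustrative autotune wrapper for the tuning todo; configs are guesses,
# not tuned winners.
configs = [
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64},
                  num_warps=8, num_stages=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64},
                  num_warps=8, num_stages=3),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128},
                  num_warps=4, num_stages=4),
]

@triton.autotune(configs=configs, key=["M", "N", "K"])
@triton.jit
def fp8_gemm_autotuned(X, W, C, M, N, K,
                       stride_xm, stride_xk, stride_wn, stride_wk,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr):
    pass  # body identical to fp8_gemm_kernel above

# Launch then omits the BLOCK_* kwargs; the grid reads them from the
# winning config:
# grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),
#                      triton.cdiv(N, META["BLOCK_N"]))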
{
"content": "<2213 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/test_kernel.py"
}Wrote file successfully.{
"command": "uv run python test_kernel.py 2>&1",
"description": "Test Triton FP8 kernel",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
TypeError: load() got an unexpected keyword argument 'element_ty'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/test_kernel.py", line 58, in <module>
test_fp8_gemm()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/test_kernel.py", line 43, in test_fp8_gemm
fp8_gemm_kernel[grid](x, w, out, M, N, K,
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^
... (1181 more chars){
"command": "<332 chars \u2014 see below>",
"description": "Check Triton load signature"
}uv run python -c "
import triton.language as tl
import inspect
# Check load signature
try:
sig = inspect.signature(tl.load)
print('load params:', list(sig.parameters.keys()))
except Exception as e:
print('load sig error:', e)
# Check the function object
print('load:', tl.load)
print('load type:', type(tl.load))
" 2>&1load params: ['pointer', 'mask', 'other', 'boundary_check', 'padding_option', 'cache_modifier', 'eviction_policy', 'volatile', '_semantic']
load: <function load at 0x7601f1de60c0>
load type: <class 'function'>
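The signature settles it: this Triton build's tl.load takes no element_ty argument, and element types instead come from the dtypes of the tensors passed at launch, torch.float8_e4m3fn here, which Triton maps to fp8e4nv. A corrected sketch of the kernel under that assumption (it also transposes the (BLOCK_N, BLOCK_K) weight tile, since tl.dot needs (M, K) @ (K, N) operands, a latent shape bug the element_ty failure masked):

import triton
import triton.language as tl

@triton.jit
def fp8_gemm_kernel_fixed(X, W, C, M, N, K,
                          stride_xm, stride_xk, stride_wn, stride_wk,
                          stride_cm, stride_cn,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                          BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        off_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
        # Plain masked loads: the fp8e4nv element type is inferred from the
        # torch.float8_e4m3fn tensors passed at launch (assumption based on
        # the load signature above).
        x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk,
                    mask=(off_m[:, None] < M) & (off_k[None, :] < K), other=0.0)
        w = tl.load(W + off_n[:, None] * stride_wn + off_k[None, :] * stride_wk,
                    mask=(off_n[:, None] < N) & (off_k[None, :] < K), other=0.0)
        # w tile is (BLOCK_N, BLOCK_K); transpose so the dot contracts over K.
        acc = tl.dot(x, tl.trans(w), acc=acc)  # fp8 x fp8 -> fp32 accumulate
    tl.store(C + off_m[:, None] * stride_cm + off_n[None, :] * stride_cn,
             acc.to(tl.bfloat16),
             mask=(off_m[:, None] < M) & (off_n[None, :] < N))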