"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Largest finite magnitude of float8_e4m3fn. Model.__init__ divides each weight
# row's max |value| by this constant to obtain the per-row dequantization scale.
E4M3_MAX = 448.0
@triton.jit
def _fp8_gemm_kernel(
    A,
    B,
    C,
    Scales,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of ``C = (A @ B) * Scales[None, :]``.

    A is addressed as an (M, K) matrix through ``stride_am``/``stride_ak`` and
    B as a (K, N) matrix through ``stride_bk``/``stride_bn``. Products are
    accumulated in float32, multiplied by the per-column ``Scales`` entry and
    stored to C as bfloat16.

    The launch grid is one-dimensional; program ids are remapped so that
    ``GROUP_M`` consecutive row tiles are visited for each column tile before
    moving on (grouped ordering).

    NOTE: loads are masked along M and N only. K must be a multiple of
    ``BLOCK_K``; otherwise the last loop iteration reads past column K.
    """
    # Map the linear program id onto a (tile_m, tile_n) pair, group by group.
    tile_id = tl.program_id(0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    tiles_per_group = GROUP_M * tiles_n
    group_row0 = (tile_id // tiles_per_group) * GROUP_M
    # The last group may hold fewer than GROUP_M row tiles.
    rows_in_group = min(tiles_m - group_row0, GROUP_M)
    tile_m = group_row0 + (tile_id % rows_in_group)
    tile_n = (tile_id % tiles_per_group) // rows_in_group

    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ks = tl.arange(0, BLOCK_K)

    # Bounds masks do not change across K iterations, so build them once.
    row_ok = rows[:, None] < M
    col_ok = cols[None, :] < N

    a_tile_ptrs = A + rows[:, None] * stride_am + ks[None, :] * stride_ak
    b_tile_ptrs = B + ks[:, None] * stride_bk + cols[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
        a_blk = tl.load(a_tile_ptrs, mask=row_ok, other=0.0)
        b_blk = tl.load(b_tile_ptrs, mask=col_ok, other=0.0)
        acc = tl.dot(a_blk, b_blk, acc)
        # Slide both operand windows one BLOCK_K step along K.
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    # Epilogue: per-output-column scale, then down-cast and masked store.
    col_scale = tl.load(Scales + cols, mask=cols < N, other=1.0)
    acc = acc * col_scale[None, :]
    out_ptrs = C + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=row_ok & col_ok)
@triton.jit
def _fp8_gemm_skinny_kernel(
    A,
    B,
    C,
    Scales,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Small-M variant of the GEMM: one program per BLOCK_N column strip.

    Computes ``C = (A @ B) * Scales[None, :]`` with float32 accumulation and a
    bfloat16 store. The grid is indexed along N only and every program
    handles rows ``0 .. BLOCK_M - 1``.

    NOTE: rows at index ``BLOCK_M`` or above are never written, so the caller
    must ensure ``M <= BLOCK_M``. Loads are masked along M and N only; K must
    be a multiple of ``BLOCK_K`` or the last iteration reads past column K.
    """
    strip = tl.program_id(0)

    rows = tl.arange(0, BLOCK_M)
    cols = strip * BLOCK_N + tl.arange(0, BLOCK_N)
    ks = tl.arange(0, BLOCK_K)

    # Bounds masks do not change across K iterations, so build them once.
    row_ok = rows[:, None] < M
    col_ok = cols[None, :] < N

    a_tile_ptrs = A + rows[:, None] * stride_am + ks[None, :] * stride_ak
    b_tile_ptrs = B + ks[:, None] * stride_bk + cols[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
        a_blk = tl.load(a_tile_ptrs, mask=row_ok, other=0.0)
        b_blk = tl.load(b_tile_ptrs, mask=col_ok, other=0.0)
        acc = tl.dot(a_blk, b_blk, acc)
        # Slide both operand windows one BLOCK_K step along K.
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    # Epilogue: per-output-column scale, then down-cast and masked store.
    col_scale = tl.load(Scales + cols, mask=cols < N, other=1.0)
    acc = acc * col_scale[None, :]
    out_ptrs = C + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=row_ok & col_ok)
def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
if M <= 64:
return 32, 128, 256, 4, 3, 1
if N >= 12000:
return 128, 256, 128, 8, 3, 4
return 128, 256, 128, 8, 3, 4
def _fp8_gemm(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    M: int,
) -> torch.Tensor:
    """Launch the Triton GEMM computing ``(x @ weight.T) * weight_scale``.

    Args:
        x: Activations with at least ``M`` rows and ``K`` columns.
        weight: ``(N, K)`` weights; handed to the kernels as a ``(K, N)``
            operand by swapping its strides.
        weight_scale: One scale per output column (indexed ``0 .. N - 1``).
        M: Number of rows of ``x`` to multiply.

    Returns:
        A new ``(M, N)`` bfloat16 tensor on ``x``'s device.

    Raises:
        ValueError: If the K dimensions disagree, ``x`` has fewer than ``M``
            rows, or K is not a multiple of a usable K tile (the kernels do
            not mask the K dimension).
    """
    K = x.shape[1]
    N, K_w = weight.shape
    # Explicit raises instead of ``assert``: asserts vanish under ``python -O``
    # and a mismatch here would turn into out-of-bounds device reads.
    if K != K_w:
        raise ValueError(f"K mismatch: x has {K} columns, weight has {K_w}")
    if x.shape[0] < M:
        raise ValueError(f"x has {x.shape[0]} rows but M={M} were requested")

    block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)

    # The kernels step through K in whole BLOCK_K tiles without a tail mask.
    # Model.forward pads K to a multiple of 128, so when a deeper tile does
    # not divide K fall back to 128 rather than reading past the end.
    if K % block_k != 0 and block_k > 128 and K % 128 == 0:
        block_k = 128
    if K % block_k != 0:
        raise ValueError(
            f"K={K} must be a multiple of {block_k}; pad the K dimension with zeros"
        )

    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)

    # The skinny kernel covers rows [0, block_m) only, so it may be used only
    # when all M rows fit in that single tile; otherwise rows beyond block_m
    # would be left uninitialised.
    if M <= 64 and M <= block_m:
        grid = (triton.cdiv(N, block_n),)
        _fp8_gemm_skinny_kernel[grid](
            x,
            weight,
            y,
            weight_scale,
            M,
            N,
            K,
            x.stride(0),
            x.stride(1),
            weight.stride(1),
            weight.stride(0),
            y.stride(0),
            y.stride(1),
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            NUM_STAGES=num_stages,
            num_warps=num_warps,
        )
    else:
        grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
        _fp8_gemm_kernel[grid](
            x,
            weight,
            y,
            weight_scale,
            M,
            N,
            K,
            x.stride(0),
            x.stride(1),
            weight.stride(1),
            weight.stride(0),
            y.stride(0),
            y.stride(1),
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            GROUP_M=group_m,
            NUM_STAGES=num_stages,
            num_warps=num_warps,
        )
    return y
class Model(nn.Module):
    """FP8 linear layer: ``y = ((x @ weight.T) * weight_scale)`` as bfloat16.

    The weight is drawn in bfloat16, normalised row by row so that each row's
    largest magnitude maps to ``E4M3_MAX``, and stored as float8_e4m3fn next
    to the per-row float32 scale that undoes the normalisation.
    """

    def __init__(self, M: int, N: int, K: int):
        """Create the quantised ``(N, K)`` weight and its ``(N,)`` scale.

        Args:
            M: Nominal number of input rows (recorded only).
            N: Number of output features.
            K: Number of input features.
        """
        super().__init__()
        self.M, self.N, self.K = M, N, K
        w = torch.empty(N, K, dtype=torch.bfloat16)
        nn.init.normal_(w, std=0.02)
        # Per-row scale; the clamp avoids dividing by zero for an all-zero row.
        s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
        w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
        self.register_buffer("weight", w_fp8)
        self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
        # Lazily built zero-padded copy of ``weight`` plus the key describing
        # which weight storage/version/padding it was built from.
        self._weight_padded: torch.Tensor | None = None
        self._weight_key: tuple | None = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` (rows x K) by the quantised weight.

        K is zero-padded up to a multiple of the K tile chosen for this shape,
        because the kernels do not mask the K dimension; zero columns add
        nothing to the dot products.

        Returns:
            A ``(x.shape[0], N)`` bfloat16 tensor.
        """
        # Use the real row count of ``x`` rather than the constructor's M so a
        # different batch size cannot cause out-of-range reads or a wrongly
        # sized output.
        M, K = x.shape
        # Pad to the tile depth the selected config uses (256 for small M),
        # not a fixed 128.
        block_k = _pick_config(M, self.N)[2]
        pad_k = -K % block_k
        if not pad_k:
            return _fp8_gemm(x, self.weight, self.weight_scale, M)

        weight = self.weight
        # ``_version`` alone misses buffer replacement (e.g. ``.to(device)``),
        # so the key also carries the storage address, device and pad width.
        key = (weight.data_ptr(), weight._version, weight.device, pad_k)
        if self._weight_padded is None or self._weight_key != key:
            self._weight_padded = torch.nn.functional.pad(weight, (0, pad_k))
            self._weight_key = key
        x = torch.nn.functional.pad(x, (0, pad_k))
        return _fp8_gemm(x, self._weight_padded, self.weight_scale, M)
# Default problem dimensions used by get_inputs() and get_init_inputs().
M = N = K = 4096
def get_inputs():
    """Return ``[x]``: an (M, K) float8_e4m3fn tensor built from uniform samples in [-4, 4)."""
    uniform = torch.rand(M, K)
    return [(uniform * 8 - 4).to(torch.float8_e4m3fn)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``: ``[M, N, K]``."""
    dims = [M, N, K]
    return dims
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:03:56.203594+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:03:56.616695+00:00 elapsed_s=0.413 ms=0.219136
shape=0 variant=solution tflops=627.186 gbps=306.243 ms=0.219
shape=0 solution_peak_fraction=0.6272
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:03:57.165588+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:03:57.326882+00:00 elapsed_s=0.161 ms=0.254816
shape=1 variant=solution tflops=543.448 gbps=264.359 ms=0.255
shape=1 solution_peak_fraction=0.5434
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:03:59.311997+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:03:59.496462+00:00 elapsed_s=0.184 ms=0.048112
shape=2 variant=solution tflops=89.270 gbps=1411.193 ms=0.048
shape=2 solution_peak_fraction=0.0893
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T18:04:01.294873+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T18:04:01.524438+00:00 elapsed_s=0.230 ms=0.725008
shape=3 variant=solution tflops=663.491 gbps=266.118 ms=0.725
shape=3 solution_peak_fraction=0.6635
peak_fraction: 0.3769
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm/repo
Installed 56 packages in 193ms
PASS
{
"run_id": "20260615_114532_cursor_composer-2.5-fast_01_fp8_gemm",
"run_group": "",
"problem": "01_fp8_gemm",
"harness": "cursor",
"model": "composer-2.5-fast",
"reasoning_effort": "",
"started_at": "2026-06-15T11:45:32-06:00",
"harness_finished_at": "2026-06-15T11:51:15-06:00",
"finished_at": "2026-06-15T12:04:01-06:00",
"start_epoch": 1781545532,
"harness_end_epoch": 1781545875,
"end_epoch": 1781546641,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.3769,
"template_mutated": false,
"elapsed_seconds": 343,
"total_elapsed_seconds": 1109,
"check_elapsed_seconds": 759,
"benchmark_elapsed_seconds": 7,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 90.71720116618076,
"usage": {"input_tokens": 61080, "output_tokens": 31116, "cache_read_tokens": 1122545, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"targetDirectory": "/workspace/problems/01_fp8_gemm",
"globPattern": "**/*"
}{
"path": "/workspace/problems/01_fp8_gemm/reference.py"
}null{
"path": "/workspace/problems/01_fp8_gemm/solution.py"
}"""FP8 e4m3 GEMM reference (correctness only, NOT the SOTA baseline).
Genuine fp8 x fp8: BOTH operands are fp8_e4m3. The weight is stored as fp8
(normalized into the e4m3 range) together with a per-output-channel scale, the
standard scaled-fp8 inference layout. The reference upcasts the fp8 operands to
bf16 and matmuls, then applies the per-channel scale — this DEFINES the fp8
target. A real fp8 x fp8 MMA kernel matches it (and can exceed the bf16 roofline
ceiling of ~0.5); a bf16-upcast kernel also matches but stays capped at ~0.5.
"""
import torch
import torch.nn as nn
OP_TYPE = "gemm"
SUPPORTED_PRECISIONS = ["fp8_e4m3"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
E4M3_MAX = 448.0
class Model(nn.Module):
"""y = ((x @ w.T) * weight_scale).to(bf16).
x: fp8_e4m3 (M, K). w: fp8_e4m3 (N, K) normalized to the e4m3 range.
weight_scale: (N,) per-output-channel dequant scale.
"""
def __init__(self, M: int, N: int, K: int):
super().__init__()
self.M, self.N, self.K = M, N, K
w = torch.empty(N, K, dtype=torch.bfloat16)
nn.init.normal_(w, std=0.02)
s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12) # (N,1)
w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
self.register_buffer("weight", w_fp8) # (N, K) fp8
self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32)) # (N,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
... (450 more chars){
"path": "/workspace/problems/01_fp8_gemm/check.py"
}errorMessage: 'File not found'{
"path": "/workspace/problems/01_fp8_gemm/benchmark.py"
}"""Correctness runner for FP8 GEMM.
Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
"""
import re
import sys
from pathlib import Path
import torch
import yaml
# Make the repo's src/ importable
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.correctness import check_correctness # noqa: E402
from src.eval.numeric_stress import ( # noqa: E402
numeric_stress_cases,
numeric_stress_context,
tolerance_for_case,
)
def main():
try:
import reference
import shapes
import solution
except Exception as e:
print(f"FAIL: import error: {e}")
sys.exit(1)
problem_yaml = Path("problem.yaml")
meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
# --- Forbidden-op check ------------------------------------------------
sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
for forbidden in meta.get("forbidden", []):
pat = re.escape(forbidden)
if re.search(pat, sol_src):
print(f"FAIL: forbidden op used: {forbidden}")
sys.exit(1)
device = torch.device("cuda:0")
tol_override = meta.get("tolerance") or None
# --- Per-shape correctness --------------------------------------------
all_shapes = shapes.SHAPES
for shape_idx, shape in enumerate(all_shapes):
... (2703 more chars)"""Roofline benchmark for FP8 GEMM.
For each shape: times the agent's solution first. Optional eager, compiled,
and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
Output lines the harness picks up:
shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
"""
import sys
from math import exp, log
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
from src.hardware import get as get_hw # noqa: E402
def _eval_formula(expr: str, vars: dict) -> float:
# Very small eval: only names from `vars` are valid.
return float(eval(expr, {"__builtins__": {}}, vars))
def main():
import reference
import shapes
import solution
meta = yaml.safe_load(Path("problem.yaml").read_text())
hw = get_hw(meta["hardware"][0])
peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
peak_gbps = hw.peak_bandwidth_gb_s
regime = meta.get("regime", "compute")
flops_formula = meta["flops_formula"]
bytes_formula = meta["bytes_formula"]
num_perf_trials = int(meta.get("num_perf_trials", 30))
device = torch.device("cuda:0")
include_baselines = benchmark
... (3647 more chars){
"path": "/workspace/problems/01_fp8_gemm/shapes.py"
}{
"path": "/workspace/problems/01_fp8_gemm/problem.yaml"
}{
"path": "/workspace/problems/01_fp8_gemm/sota.py"
}"""Canonical shape sweep for FP8 GEMM.
Mix of:
- square aligned (the easy case)
- off-alignment K (common real-world failure mode for tile-quantized kernels)
- skinny (decode-like, memory-bound)
- rectangular (prefill with grouped attention)
"""
SHAPES = [
{"M": 4096, "N": 4096, "K": 4096}, # square aligned
{"M": 4096, "N": 4096, "K": 4127}, # K not multiple of 128 -> forces predicated tails
{"M": 32, "N": 8192, "K": 8192}, # skinny M (decode)
{"M": 4096, "N": 14336, "K": 4096}, # Llama3 up-proj shape
]
name: 01_fp8_gemm
display_name: "FP8 e4m3 GEMM"
precision: fp8_e4m3
regime: compute
# Dense-equivalent FLOPs per invocation. Evaluated per-shape; the benchmark
# substitutes M/N/K from each shapes.py entry.
flops_formula: "2 * M * N * K"
# Bytes moved (approximate): fp8 inputs in, bf16 output.
bytes_formula: "M*K + K*N + M*N*2"
hardware: [RTX_PRO_6000]
peak_tflops_key: fp8 # index into src/hardware/rtx_pro_6000.py peak_tflops_dense
peak_bandwidth_key: dram
# Per-dtype tolerance override (optional). Defaults are in src/eval/correctness.py.
tolerance:
bfloat16: 0.2
fp8_e4m3fn: 0.2
# Forbidden ops — using any of these in solution.py fails correctness post-hoc.
# torch._scaled_mm is Stanford-style cheating (just call the vendor library).
forbidden:
- "torch._scaled_mm"
- "torch.ops.aten._scaled_mm"
sota:
name: "FlashInfer FP8 GEMM"
url: "https://github.com/flashinfer-ai/flashinfer"
function: "flashinfer.gemm.fp8_gemm"
deps:
- "flashinfer>=0.6.8"
# Documented H100 throughput for this shape (informational, not graded):
reference_throughput_tflops_h100: 550
num_correct_trials: 3
num_perf_trials: 30
"""SOTA reference for FP8 GEMM: flashinfer.gemm.fp8_gemm.
If flashinfer is not installed or the SM120 path isn't supported, this falls
back to torch._scaled_mm which is the cuBLAS FP8 path. The benchmark treats
whichever succeeds as the SOTA reference line.
Agents are FORBIDDEN from using torch._scaled_mm in their solution (see
problem.yaml.forbidden). This file is only for the benchmark's reference line.
"""
from __future__ import annotations
import torch
def _try_flashinfer(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor | None:
try:
import flashinfer # noqa: F401
# Note: flashinfer's FP8 GEMM API surface may differ; adapt if needed.
# Placeholder call — replace with the actual flashinfer entry point
# once validated on SM120.
return None
except ImportError:
return None
def _scaled_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
# torch._scaled_mm wants per-tensor scales. Use unit scales for the reference.
scale_a = torch.tensor(1.0, device=x.device)
scale_b = torch.tensor(1.0, device=x.device)
out = torch._scaled_mm(
x,
w.T,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
)
return out if not isinstance(out, tuple) else out[0]
def sota_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""Best-available FP8 GEMM reference. x: (M, K) fp8, w: (N, K) fp8."""
out = _try_flashinfer(x, w)
if out is not None:
... (223 more chars){
"command": "<798 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_dbf7c147-9b20-4b66-b23e-064cec484ec",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (2269 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch
print('torch', torch.__version__)
print('cuda', torch.version.cuda)
print('device', torch.cuda.get_device_name(0))
print('capability', torch.cuda.get_device_capability(0))
print('fp8', hasattr(torch, 'float8_e4m3fn'))
# test triton
try:
import triton
print('triton', triton.__version__)
except: pass
# test scaled_mm
x = torch.randn(128,128, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(128,128, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
try:
out = torch._scaled_mm(x, w.T, scale_a=torch.ones(1,device='cuda'), scale_b=torch.ones(1,device='cuda'), out_dtype=torch.bfloat16)
print('scaled_mm works', out.shape)
except Exception as e:
print('scaled_mm', e)
"
null{
"command": "<1886 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 60000,
"toolCallId": "tool_6c9e5422-13d2-4c15-b013-80ae54f4aee",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (4481 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch
import triton
import triton.language as tl
# Test triton fp8 dot
code = '''
import triton
import triton.language as tl
import torch
@triton.jit
def test_dot(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
M, N, K = 128, 128, 128
a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
c = torch.zeros(M, N, device='cuda', dtype=torch.float32)
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
test_dot[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(1), b.stride(0), c.stride(0), c.stride(1), BLOCK_M=64, BLOCK_N=64, BLOCK_K=128)
ref = (a.float() @ b.float().T)
print('max diff', (c - ref).abs().max().item())
print('triton fp8 dot works!')
'
exec(code)
"
{
"query": "FP8 e4m3 GEMM kernel implementation SM120 Blackwell",
"targetDirectories": [],
"explanation": ""
}errorMessage: 'Codebase search is unavailable for this conversation because the codebase index could not be found. Do not retry SemanticSearch. Use Grep, Glob, and Read instead.'error: 'Glob pattern "**/*" matches every file and is not allowed. Use a more specific glob or no glob.'failure:
command:
cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch
import triton
import triton.language as tl
# Test triton fp8 dot
code = '''
import triton
import triton.language as tl
import torch
@triton.jit
def test_dot(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
M, N, K = 128, 128, 128
a = torch.randn(M, K, device='cuda', dtype=torch.bfl
... (1053 more chars)--- /dev/null
+++ b//workspace/problems/01_fp8_gemm/test_triton_fp8.py
@@ -0,0 +1,56 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def test_dot(
+ A, B, C,
+ M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, K, BLOCK_K):
+ a = tl.load(
+ a_ptrs,
+ mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K),
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptrs,
+ mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N),
+ other=0.0,
+ )
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+M, N, K = 128, 128, 128
+a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+c = torch.zeros(M, N, device="cuda", dtype=torch.float32)
+grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]), triton.cdiv(N, META["BLOCK_N"]))
+test_dot[grid](
+ a, b, c,
+ M, N, K,
+ a.stride(0), a.stride(1),
+ b.stride(1), b.stride(0),
+ c.stride(0), c.stride(1),
+ BLOCK_M=64, BLOCK_N=64, BLOCK_K=128,
+)
+ref = a.float() @ b.float().T
+print("max diff", (c - ref).abs().max().item())
+print("triton fp8 dot works!")
{
"path": "/workspace/problems/01_fp8_gemm/test_triton_fp8.py",
"streamContent": "<1871 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python test_triton_fp8.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_6a1fd41b-78a0-437b-8bfb-b186d2a54cd",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_g
... (919 more chars){
"pattern": "fp8|e4m3|float8",
"path": "/workspace",
"glob": "**/*.py",
"caseInsensitive": false,
"headLimit": 40,
"multiline": false,
"toolCallId": "tool_90d9be14-687e-45f5-ba65-5ac096bd2f5",
"offset": 0
}nullnull--- /dev/null
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -0,0 +1,139 @@
+"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+E4M3_MAX = 448.0
+
+
+@triton.jit
+def _fp8_gemm_kernel(
+ A,
+ B,
+ C,
+ Scales,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, K, BLOCK_K):
+ k_mask = offs_k[None, :] + k < K
+ a = tl.load(
+ a_ptrs,
+ mask=(offs_m[:, None] < M) & k_mask,
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptrs,
+ mask=k_mask[:, None] & (offs_n[None, :] < N),
+ other=0.0,
+ )
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
+ acc = acc * scales[None, :]
+
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
+ """x: (M,K) fp8, weight: (N,K) fp8, weight_scale: (N,) fp32 -> (M,N) bf16."""
+ M, K = x.shape
+ N, K_w = weight.shape
+ assert K == K_w
+ y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+
+ # Heuristic tile sizes tuned for SM120 FP8 tensor cores.
+ if M <= 64:
+ block_m, block_n, block_k, group_m, num_warps, num_stages = 16, 128, 128, 8, 4, 3
+ elif N >= 8192:
+ block_m, block_n, block_k, group_m, num_warps, num_stages = 128, 256, 128, 8, 8, 3
+ else:
+ block_m, block_n, block_k, group_m, num_warps, num_stages = 128, 128, 128, 8, 8, 3
+
+ grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
+ _fp8_gemm_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ GROUP_M=group_m,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return y
+
+
+class Model(nn.Module):
+ """y = ((x @ w.T) * weight_scale).to(bf16)."""
+
+ def __init__(self, M: int, N: int, K: int):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ w = torch.empty(N, K, dtype=torch.bfloat16)
+ nn.init.normal_(w, std=0.02)
+ s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
+ w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
+ self.register_buffer("weight", w_fp8)
+ self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return _fp8_gemm(x, self.weight, self.weight_scale)
+
+
+M = 4096
+N = 4096
+K = 4096
+
+
+def get_inputs():
+ x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<3923 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_e4d8e629-b2e4-4ad8-9248-99007c09306",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (899 more chars)failure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2035, in dot
assert len(a_shape) == len(b_shape) >= 2, "input and other must have equal ranks >= 2"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: input and other must have equal ranks >= 2
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/01_fp8_gemm/check.py", line 119, in <module>
main()
File "/workspace/problems/01_fp8_gemm/check.py", line 80, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/01_fp8_gemm/solution
... (5805 more chars)--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -25,18 +25,9 @@
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr,
):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- num_pid_in_group = GROUP_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
@@ -69,21 +60,19 @@
def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
- """x: (M,K) fp8, weight: (N,K) fp8, weight_scale: (N,) fp32 -> (M,N) bf16."""
M, K = x.shape
N, K_w = weight.shape
assert K == K_w
y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
- # Heuristic tile sizes tuned for SM120 FP8 tensor cores.
if M <= 64:
- block_m, block_n, block_k, group_m, num_warps, num_stages = 16, 128, 128, 8, 4, 3
+ block_m, block_n, block_k, num_warps, num_stages = 32, 128, 128, 4, 3
elif N >= 8192:
- block_m, block_n, block_k, group_m, num_warps, num_stages = 128, 256, 128, 8, 8, 3
+ block_m, block_n, block_k, num_warps, num_stages = 64, 256, 128, 8, 4
else:
- block_m, block_n, block_k, group_m, num_warps, num_stages = 128, 128, 128, 8, 8, 3
+ block_m, block_n, block_k, num_warps, num_stages = 64, 64, 128, 4, 3
- grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
+ grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))
_fp8_gemm_kernel[grid](
x,
weight,
@@ -101,7 +90,6 @@
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
- GROUP_M=group_m,
num_warps=num_warps,
num_stages=num_stages,
)
@@ -109,8 +97,6 @@
class Model(nn.Module):
- """y = ((x @ w.T) * weight_scale).to(bf16)."""
-
def __init__(self, M: int, N: int, K: int):
super().__init__()
self.M, self.N, self.K = M, N, K
{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<3317 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_dc1151be-6221-45a3-ac1d-d6b4b6d7f17",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (893 more chars)failure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/01_fp8_gemm/check.py", line 119, in <module>
main()
File "/workspace/problems/01_fp8_gemm/check.py", line 80, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/01_fp8_gemm/solution.py", line 111, in forward
return _fp8_gemm(x, self.weight, self.weight_scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/01_fp8_gemm/solution.py", line 76, in _fp8_gemm
_fp8_gemm_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/trit
... (4753 more chars)--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,125 +1,12 @@
-"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-E4M3_MAX = 448.0
-
-
-@triton.jit
-def _fp8_gemm_kernel(
- A,
- B,
- C,
- Scales,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
-
- a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
- b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
- k_mask = offs_k[None, :] + k < K
+ k_mask = offs_k + k < K
a = tl.load(
a_ptrs,
- mask=(offs_m[:, None] < M) & k_mask,
+ mask=(offs_m[:, None] < M) & (k_mask[None, :]),
other=0.0,
)
b = tl.load(
b_ptrs,
- mask=k_mask[:, None] & (offs_n[None, :] < N),
+ mask=(k_mask[:, None]) & (offs_n[None, :] < N),
other=0.0,
- )
- acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
-
- scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
- acc = acc * scales[None, :]
-
- c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-
-
-def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
- M, K = x.shape
- N, K_w = weight.shape
- assert K == K_w
- y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
-
- if M <= 64:
- block_m, block_n, block_k, num_warps, num_stages = 32, 128, 128, 4, 3
- elif N >= 8192:
- block_m, block_n, block_k, num_warps, num_stages = 64, 256, 128, 8, 4
- else:
- block_m, block_n, block_k, num_warps, num_stages = 64, 64, 128, 4, 3
-
- grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n))
- _fp8_gemm_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- return y
-
-
-class Model(nn.Module):
- def __init__(self, M: int, N: int, K: int):
- super().__init__()
- self.M, self.N, self.K = M, N, K
- w = torch.empty(N, K, dtype=torch.bfloat16)
- nn.init.normal_(w, std=0.02)
- s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
- w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
- self.register_buffer("weight", w_fp8)
- self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return _fp8_gemm(x, self.weight, self.weight_scale)
-
-
-M = 4096
-N = 4096
-K = 4096
-
-
-def get_inputs():
- x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
- return [x]
-
-
-def get_init_inputs():
- return [M, N, K]
+ ){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<334 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_c0250ace-4abf-4d78-bbbe-0df2a1d9277",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (902 more chars)failure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/01_fp8_gemm/check.py", line 119, in <module>
main()
File "/workspace/problems/01_fp8_gemm/check.py", line 80, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/01_fp8_gemm/solution.py", line 111, in forward
return _fp8_gemm(x, self.weight, self.weight_scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/01_fp8_gemm/solution.py", line 76, in _fp8_gemm
_fp8_gemm_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/trit
... (3151 more chars)--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,12 +1,6 @@
- for k in range(0, K, BLOCK_K):
- k_mask = offs_k + k < K
- a = tl.load(
- a_ptrs,
- mask=(offs_m[:, None] < M) & (k_mask[None, :]),
- other=0.0,
- )
- b = tl.load(
- b_ptrs,
- mask=(k_mask[:, None]) & (offs_n[None, :] < N),
- other=0.0,
- )+ if M <= 64:
+ block_m, block_n, block_k, num_warps, num_stages = 32, 128, 128, 4, 2
+ elif N >= 8192:
+ block_m, block_n, block_k, num_warps, num_stages = 64, 128, 128, 8, 2
+ else:
+ block_m, block_n, block_k, num_warps, num_stages = 64, 64, 128, 4, 3{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<278 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_a1c42ccd-2de5-42b4-822e-e871ce4e1d7",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (908 more chars)null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_5a324d0d-0d9c-4786-bedc-e1d77a5a1c5",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (903 more chars){
"path": "/workspace/src/hardware/rtx_pro_6000.py"
}"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
name: str
sm: str
vram_gb: int
peak_bandwidth_gb_s: float # DRAM
peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
# Blackwell GB202 dense tensor peaks. Derived from NVIDIA's headline
# 4000 fp4-sparse AI TOPS -> fp4 dense 2000 -> fp8 1000 -> bf16 500
# (each precision step = 2x; sparse = 2x dense). Verified empirically:
# cuBLAS hits fp8 773 / bf16 412 TFLOPS on 4096^3 (~77-82% of these,
# normal cuBLAS efficiency). The prior table was ~2.5x too low, which
# produced peak_fraction > 1.0 for real fp8 kernels. fp32 is the 125
# TFLOPS SIMT figure from the spec sheet.
"fp4": 2000.0,
"nvfp4": 2000.0,
"mxfp4": 2000.0,
"fp6": 1000.0,
"fp8": 1000.0,
"bf16": 500.0,
"fp16": 500.0,
"tf32": 250.0,
"fp32": 125.0, # non-tensor-core SIMT fp32 (spec sheet)
"int8": 1000.0,
"int4": 2000.0,
},
)
null{
"path": "/workspace/src/eval/timing.py"
}"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import os
import re
import statistics
import time
from datetime import UTC, datetime
import torch
# Bl
... (2814 more chars){
"command": "<936 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_e9dad82b-2bca-435e-8f7c-da53026850b",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (2740 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time
import reference, solution
device = 'cuda'
for shape in [{'M':4096,'N':4096,'K':4096}, {'M':4096,'N':4096,'K':4127}]:
reference.M = shape['M']; reference.N = shape['N']; reference.K = shape['K']
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(2026)
x = reference.get_inputs()[0].to(device)
# warmup
for _ in range(5):
sol_model(x)
torch.cuda.synchronize()
times = []
for _ in range(30):
t0 = time.perf_counter()
sol_model(x)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
import statistics
print(shape, 'median', statistics.median(times), 'min', min(times))
" 2>&1null{
"command": "<1689 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_3756ce47-8f03-41ea-a6db-ca8624a3393",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (4303 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics
import reference, solution
device = 'cuda'
shape = {'M':4096,'N':4096,'K':4127}
reference.M = shape['M']; reference.N = shape['N']; reference.K = shape['K']
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(2026)
x = reference.get_inputs()[0].to(device)
# scaled_mm baseline (not for solution, just perf reference)
scale = torch.tensor(1.0, device=device)
w = ref_model.weight
for _ in range(10):
torch._scaled_mm(x, w.T, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
torch._scaled_mm(x, w.T, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print('scaled_mm ms', statistics.median(times))
for _ in range(10):
sol_model(x)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
sol_model(x)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print('solution ms', statistics.median(times))
# direct kernel
from solution import _fp8_gemm
for _ in range(10):
_fp8_gemm(x, ref_model.weight, ref_model.weight_scale)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
_fp8_gemm(x, ref_model.weight, ref_model.weight_scale)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print('kernel ms', statistics.median(times))
" 2>&1{
"pattern": "autotune|fp8.*gemm|float8",
"path": "/workspace/.venv/lib/python3.11/site-packages/triton",
"glob": "**/*.py",
"caseInsensitive": false,
"headLimit": 30,
"multiline": false,
"toolCallId": "tool_f833f9c9-0150-4556-952a-fa4a97d4acf",
"offset": 0
}nullfailure:
command:
cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics
import reference, solution
device = 'cuda'
shape = {'M':4096,'N':4096,'K':4127}
reference.M = shape['M']; reference.N = shape['N']; reference.K = shape['K']
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(2026)
x = reference.get_inputs()[0].to(device)
# scaled_mm baseline (not for solution, just perf reference)
scale = torch.tensor(1.0, device=device)
w = ref_model.weight
for _ in range(10):
torch._scaled_mm(x, w.T, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
torch._scaled_mm(x, w.T, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print('scaled_mm ms', statistics.median(times))
for _ in range(10):
sol_model(x)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
sol_model(x)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print('solution ms', statistics.median(times))
# direct kernel
f
... (987 more chars){
"command": "<774 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_a2c3a3fb-b32e-4f56-9b16-5a5c2a7c11a",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (2406 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics
from solution import _fp8_gemm
import reference
device='cuda'
M,N=4096,4096
reference.M=M; reference.N=N
for K in [4096, 4100, 4112, 4127, 4130, 4144, 4160]:
reference.K=K
ref = reference.Model(M,N,K).to(device).eval()
x = (torch.rand(M,K,device=device)*8-4).to(torch.float8_e4m3fn)
for _ in range(5):
_fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print(f'K={K} iters={ (K+127)//128 } ms={statistics.median(times):.3f}')
" 2>&1null{
"path": "/workspace/problems/01_fp8_gemm/solution.py"
}"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
E4M3_MAX = 448.0
@triton.jit
def _fp8_gemm_kernel(
A,
B,
C,
Scales,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
k_mask = offs_k + k < K
a = tl.load(
a_ptrs,
mask=(offs_m[:, None] < M) & (k_mask[None, :]),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(k_mask[:, None]) & (offs_n[None, :] < N),
other=0.0,
)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
acc = acc * scales[None, :]
c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc.to(tl.bfloat16
... (1821 more chars)--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,6 +1,230 @@
+"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+E4M3_MAX = 448.0
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 4}, num_stages=2, num_warps=8),
+ ],
+ key=["M", "N", "K", "HAS_TAIL"],
+)
+@triton.jit
+def _fp8_gemm_kernel(
+ A,
+ B,
+ C,
+ Scales,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ HAS_TAIL: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ full_k = (K // BLOCK_K) * BLOCK_K
+ for k in range(0, full_k, BLOCK_K):
+ a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ if HAS_TAIL:
+ k = full_k
+ k_mask = offs_k + k < K
+ a = tl.load(
+ a_ptrs,
+ mask=(offs_m[:, None] < M) & k_mask[None, :],
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptrs,
+ mask=k_mask[:, None] & (offs_n[None, :] < N),
+ other=0.0,
+ )
+ acc = tl.dot(a, b, acc)
+
+ scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
+ acc = acc * scales[None, :]
+
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=2, num_warps=4),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=2, num_warps=4),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=2, num_warps=8),
+ ],
+ key=["N", "K", "HAS_TAIL"],
+)
+@triton.jit
+def _fp8_gemm_skinny_kernel(
+ A,
+ B,
+ C,
+ Scales,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ HAS_TAIL: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ offs_m = tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ full_k = (K // BLOCK_K) * BLOCK_K
+ for k in range(0, full_k, BLOCK_K):
+ a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ if HAS_TAIL:
+ k = full_k
+ k_mask = offs_k + k < K
+ a = tl.load(
+ a_ptrs,
+ mask=(offs_m[:, None] < M) & k_mask[None, :],
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptrs,
+ mask=k_mask[:, None] & (offs_n[None, :] < N),
+ other=0.0,
+ )
+ acc = tl.dot(a, b, acc)
+
+ scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
+ acc = acc * scales[None, :]
+
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
+ M, K = x.shape
+ N, K_w = weight.shape
+ assert K == K_w
+ y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+ has_tail = (K % 128) != 0
+ block_k = 128
+
if M <= 64:
- block_m, block_n, block_k, num_warps, num_stages = 32, 128, 128, 4, 2
- elif N >= 8192:
- block_m, block_n, block_k, num_warps, num_stages = 64, 128, 128, 8, 2
+ grid = (triton.cdiv(N, 128),)
+ _fp8_gemm_skinny_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ HAS_TAIL=has_tail,
+ )
else:
- block_m, block
... (1308 more chars){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<6704 chars \u2014 see diff>"
}null{
"command": "<746 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_950de24b-6691-43c6-93ca-817cc09bdf4",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name":
... (2821 more chars)cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python -c "
import torch, time, statistics
from solution import _fp8_gemm
import reference
device='cuda'
for K in [4096, 4127]:
reference.M=4096; reference.N=4096; reference.K=K
ref = reference.Model(4096,4096,K).to(device).eval()
x = (torch.rand(4096,K,device=device)*8-4).to(torch.float8_e4m3fn)
for _ in range(10): _fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print(f'K={K} ms={statistics.median(times):.3f}')
" 2>&1failure:
command:
cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python -c "
import torch, time, statistics
from solution import _fp8_gemm
import reference
device='cuda'
for K in [4096, 4127]:
reference.M=4096; reference.N=4096; reference.K=K
ref = reference.Model(4096,4096,K).to(device).eval()
x = (torch.rand(4096,K,device=device)*8-4).to(torch.float8_e4m3fn)
for _ in range(10): _fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print(f'K={K} ms={statistics.median(times):.3f}')
" 2>&1
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42 case nominal: tolerance exceeded: max_abs_diff=12.3125 max_rel_diff=1 bad=30572/16777216 worst_idx=(4092, 3987) (atol=0.2, rtol=0.2)
stderr: ''
executionTime: 3931
interleavedOutput:
FAIL: shape 0 {'M': 4096, 'N': 4096, 'K': 4096} seed 42 case nominal: tolerance exceeded: max_abs_diff=12.3125 max_rel_diff=1 bad=30572/16777216 worst_idx=(4092, 3987) (atol=0.2, rtol=0.2)
aborted: False
localExecutionTimeMs: 3859
isBackground: False--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -7,16 +7,6 @@
E4M3_MAX = 448.0
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 4}, num_stages=2, num_warps=8),
- ],
- key=["M", "N", "K", "HAS_TAIL"],
-)
@triton.jit
def _fp8_gemm_kernel(
A,
@@ -87,14 +77,6 @@
tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=2, num_warps=4),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=2, num_warps=4),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=2, num_warps=8),
- ],
- key=["N", "K", "HAS_TAIL"],
-)
@triton.jit
def _fp8_gemm_skinny_kernel(
A,
@@ -154,16 +136,24 @@
tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
+ if M <= 64:
+ return 32, 128, 128, 4, 2, 1
+ if N >= 12000:
+ return 128, 256, 128, 8, 2, 4
+ return 128, 128, 128, 8, 3, 8
+
+
def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N, K_w = weight.shape
assert K == K_w
y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
has_tail = (K % 128) != 0
- block_k = 128
+ block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
if M <= 64:
- grid = (triton.cdiv(N, 128),)
+ grid = (triton.cdiv(N, block_n),)
_fp8_gemm_skinny_kernel[grid](
x,
weight,
@@ -179,9 +169,14 @@
y.stride(0),
y.stride(1),
HAS_TAIL=has_tail,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
else:
- grid = (triton.cdiv(M, 128) * triton.cdiv(N, 128),)
+ grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
_fp8_gemm_kernel[grid](
x,
weight,
@@ -197,6 +192,12 @@
y.stride(0),
y.stride(1),
HAS_TAIL=has_tail,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ GROUP_M=group_m,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
return y
{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<6325 chars \u2014 see diff>"
}{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_618f51b1-4983-45bc-bd05-e5b2d6cb145",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (908 more chars)nullnull{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_741838f1-55b6-470a-a26e-fb7dac68e87",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
... (916 more chars){
"command": "<715 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_bc3ecc64-3906-4055-8070-0567de7db24",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (2269 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics
from solution import _fp8_gemm
import reference
device='cuda'
for K in [4096, 4127]:
reference.M=4096; reference.N=4096; reference.K=K
ref = reference.Model(4096,4096,K).to(device).eval()
x = (torch.rand(4096,K,device=device)*8-4).to(torch.float8_e4m3fn)
for _ in range(10): _fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm(x, ref.weight, ref.weight_scale)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
print(f'K={K} ms={statistics.median(times):.3f}')
" 2>&1nullnull--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,155 +1,16 @@
-"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-E4M3_MAX = 448.0
-
-
-@triton.jit
-def _fp8_gemm_kernel(
- A,
- B,
- C,
- Scales,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- HAS_TAIL: tl.constexpr,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- num_pid_in_group = GROUP_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
-
- a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
- b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- full_k = (K // BLOCK_K) * BLOCK_K
- for k in range(0, full_k, BLOCK_K):
- a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
- b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
- acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
-
- if HAS_TAIL:
- k = full_k
- k_mask = offs_k + k < K
- a = tl.load(
- a_ptrs,
- mask=(offs_m[:, None] < M) & k_mask[None, :],
- other=0.0,
- )
- b = tl.load(
- b_ptrs,
- mask=k_mask[:, None] & (offs_n[None, :] < N),
- other=0.0,
- )
- acc = tl.dot(a, b, acc)
-
- scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
- acc = acc * scales[None, :]
-
- c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-
-
-@triton.jit
-def _fp8_gemm_skinny_kernel(
- A,
- B,
- C,
- Scales,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- HAS_TAIL: tl.constexpr,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- pid_n = tl.program_id(0)
- offs_m = tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
-
- a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
- b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- full_k = (K // BLOCK_K) * BLOCK_K
- for k in range(0, full_k, BLOCK_K):
- a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
- b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
- acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
-
- if HAS_TAIL:
- k = full_k
- k_mask = offs_k + k < K
- a = tl.load(
- a_ptrs,
- mask=(offs_m[:, None] < M) & k_mask[None, :],
- other=0.0,
- )
- b = tl.load(
- b_ptrs,
- mask=k_mask[:, None] & (offs_n[None, :] < N),
- other=0.0,
- )
- acc = tl.dot(a, b, acc)
-
- scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
- acc = acc * scales[None, :]
-
- c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-
-
-def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
- if M <= 64:
- return 32, 128, 128, 4, 2, 1
- if N >= 12000:
- return 128, 256, 128, 8, 2, 4
- return 128, 128, 128, 8, 3, 8
-
-
def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N, K_w = weight.shape
assert K == K_w
+
+ block_k = 128
+ pad_k = (block_k - (K % block_k)) % block_k
+ if pad_k:
+ x = torch.nn.functional.pad(x, (0, pad_k))
+ weight = torch.nn.functional.pad(weight, (0, pad_k))
+ K = K + pad_k
+
y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
- has_tail = (K % 128) != 0
block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
if M <= 64:
@@ -168,7 +29,7 @@
weight.stride(0),
y.stride(0),
y.stride(1),
- HAS_TAIL=has_tail,
+ HAS_TAIL=False,
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
@@ -191,7 +52,7 @@
weight.stride(0),
y.stride(0),
y.stride(1),
- HAS_TAIL=has_tail,
+ HAS_TAIL=False,
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
@@ -199,33 +60,4 @@
num_warps=num_warps,
num_stages=num_stages,
)
- return y
-
-
-class Model(nn.Module):
- def __init__(self, M: int, N: int, K: int):
- super().__init__()
- self.M, self.N, self.K = M, N, K
- w = torch.empty(N, K, dtype=torch.bfloat16)
- nn.init.normal_(w, std=0.02)
- s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
- w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
- self.register_buffer("weight", w_fp8)
- self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
-
- def forwar
... (299 more chars){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<1726 chars \u2014 see diff>"
}{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_09735153-cb02-43b7-8773-b6c52f32efe",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1481 more chars)nullnull{
"command": "cd /workspace/problems/01_fp8_gemm && KBH_BENCHMARK_BASELINES=1 uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_ae84fccd-c2a8-4652-8ef7-98578df7e88",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspa
... (943 more chars){
"command": "<1596 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_93350ea1-b86d-4db6-9bfc-ef61306f223",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (4072 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
# sweep skinny kernel configs
import torch, time, statistics
import triton
from solution import _fp8_gemm_skinny_kernel
M,N,K = 32, 8192, 8192
x = torch.randn(M,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(N,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
s = torch.ones(N, device='cuda', dtype=torch.float32)
y = torch.empty(M,N,device='cuda',dtype=torch.bfloat16)
configs = [
(16,128,128,4,2),(32,128,128,4,2),(32,256,128,8,2),(16,256,128,8,2),
(32,128,256,4,3),(32,64,128,2,2),(64,128,128,4,2)
]
for bm,bn,bk,nw,ns in configs:
try:
grid=(triton.cdiv(N,bn),)
for _ in range(5):
_fp8_gemm_skinny_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times=[]
for _ in range(30):
t0=time.perf_counter()
_fp8_gemm_skinny_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
flops=2*M*N*K/1e9
ms=statistics.median(times)
print(f'bm={bm} bn={bn} bk={bk} nw={nw} ns={ns} ms={ms:.3f} tflops={flops/ms*1e3:.1f} gbps={(M*K+N*K+M*N*2)/ms/1e6:.0f}')
except Exception as e:
print(f'bm={bm} bn={bn} FAIL {e}')
" 2>&1nullnull{
"command": "<1592 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_0a4e7e0e-7f0b-41c4-8669-7ac90ac5197",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (4067 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics, triton
from solution import _fp8_gemm_kernel
M,N,K = 4096,4096,4096
x = torch.randn(M,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(N,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
s = torch.ones(N, device='cuda', dtype=torch.float32)
y = torch.empty(M,N,device='cuda',dtype=torch.bfloat16)
configs = [
(128,128,128,8,3,8),(128,256,128,8,2,4),(64,128,128,4,3,8),(128,64,128,4,3,8),
(256,128,128,8,2,4),(128,128,256,8,3,8),(64,64,128,4,3,8)
]
for bm,bn,bk,nw,ns,gm in configs:
try:
grid=(triton.cdiv(M,bm)*triton.cdiv(N,bn),)
for _ in range(5):
_fp8_gemm_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,GROUP_M=gm,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,GROUP_M=gm,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
ms=statistics.median(times)
tflops=2*M*N*K/ms/1e6
print(f'bm={bm} bn={bn} bk={bk} nw={nw} ns={ns} gm={gm} ms={ms:.3f} tflops={tflops:.0f}')
except Exception as e:
print(f'bm={bm} bn={bn} FAIL {type(e).__name__}: {e}')
" 2>&1null{
"command": "<1221 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_74c303cb-caef-45f2-ac8f-36c420c9012",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (3306 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics, triton, sys
sys.path.insert(0,'.')
from src.eval.timing import time_fn, _l2_flush
from solution import _fp8_gemm_skinny_kernel
M,N,K = 32, 8192, 8192
x = torch.randn(M,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(N,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
s = torch.ones(N, device='cuda', dtype=torch.float32)
y = torch.empty(M,N,device='cuda',dtype=torch.bfloat16)
for bm,bn,bk,nw,ns in [(32,128,256,4,3),(32,128,128,4,2),(16,128,256,4,3),(32,256,128,8,2)]:
grid=(triton.cdiv(N,bn),)
fn = lambda: _fp8_gemm_skinny_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,num_warps=nw,num_stages=ns)
for _ in range(10): fn()
torch.cuda.synchronize()
times=[]
for _ in range(30):
_l2_flush(); torch.cuda.synchronize()
t0=time.perf_counter(); fn(); torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
ms=statistics.median(times)
print(f'bm={bm} bn={bn} bk={bk} ms={ms:.4f} gbps={(M*K+N*K+M*N*2)/ms/1e6:.0f}')
" 2>&1null{
"command": "<1331 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_4cf7cc81-d085-468f-8449-c34ac2f32a6",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (3525 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, time, statistics, triton
from solution import _fp8_gemm_kernel
M,N,K = 4096,14336,4096
x = torch.randn(M,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(N,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
s = torch.ones(N, device='cuda', dtype=torch.float32)
y = torch.empty(M,N,device='cuda',dtype=torch.bfloat16)
configs = [
(128,256,128,8,2,4),(128,128,128,8,3,8),(256,128,128,8,2,4),(128,256,128,8,2,4)
]
for bm,bn,bk,nw,ns,gm in configs:
grid=(triton.cdiv(M,bm)*triton.cdiv(N,bn),)
for _ in range(5):
_fp8_gemm_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,GROUP_M=gm,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times=[]
for _ in range(20):
t0=time.perf_counter()
_fp8_gemm_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),HAS_TAIL=False,BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,GROUP_M=gm,num_warps=nw,num_stages=ns)
torch.cuda.synchronize()
times.append((time.perf_counter()-t0)*1000)
ms=statistics.median(times)
print(f'bm={bm} bn={bn} ms={ms:.3f} tflops={2*M*N*K/ms/1e6:.0f}')
" 2>&1--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,63 +1,6 @@
-def _fp8_gemm(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
- M, K = x.shape
- N, K_w = weight.shape
- assert K == K_w
-
- block_k = 128
- pad_k = (block_k - (K % block_k)) % block_k
- if pad_k:
- x = torch.nn.functional.pad(x, (0, pad_k))
- weight = torch.nn.functional.pad(weight, (0, pad_k))
- K = K + pad_k
-
- y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
- block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
-
+def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
if M <= 64:
- grid = (triton.cdiv(N, block_n),)
- _fp8_gemm_skinny_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- HAS_TAIL=False,
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- else:
- grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
- _fp8_gemm_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- HAS_TAIL=False,
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- GROUP_M=group_m,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- return y+ return 16, 128, 256, 4, 3, 1
+ if N >= 12000:
+ return 128, 256, 128, 8, 2, 4
+ return 128, 256, 128, 8, 2, 4{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<216 chars \u2014 see diff>"
}nullnull--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,6 +1,24 @@
-def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
- if M <= 64:
- return 16, 128, 256, 4, 3, 1
- if N >= 12000:
- return 128, 256, 128, 8, 2, 4
- return 128, 256, 128, 8, 2, 4+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ w = torch.empty(N, K, dtype=torch.bfloat16)
+ nn.init.normal_(w, std=0.02)
+ s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
+ w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
+ self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
+
+ pad_k = (128 - (K % 128)) % 128
+ if pad_k:
+ w_fp8 = torch.nn.functional.pad(w_fp8, (0, pad_k))
+ self._k_padded = K + pad_k
+ else:
+ self._k_padded = K
+ self.register_buffer("weight", w_fp8)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ K = x.shape[1]
+ pad_k = self._k_padded - K
+ if pad_k:
+ x = torch.nn.functional.pad(x, (0, pad_k))
+ return _fp8_gemm(x, self.weight, self.weight_scale, self.M){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<953 chars \u2014 see diff>"
}null--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,24 +1,61 @@
-class Model(nn.Module):
- def __init__(self, M: int, N: int, K: int):
- super().__init__()
- self.M, self.N, self.K = M, N, K
- w = torch.empty(N, K, dtype=torch.bfloat16)
- nn.init.normal_(w, std=0.02)
- s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
- w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
- self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
+def _fp8_gemm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ M: int,
+) -> torch.Tensor:
+ K = x.shape[1]
+ N, K_w = weight.shape
+ assert K == K_w
- pad_k = (128 - (K % 128)) % 128
- if pad_k:
- w_fp8 = torch.nn.functional.pad(w_fp8, (0, pad_k))
- self._k_padded = K + pad_k
- else:
- self._k_padded = K
- self.register_buffer("weight", w_fp8)
+ y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+ block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- K = x.shape[1]
- pad_k = self._k_padded - K
- if pad_k:
- x = torch.nn.functional.pad(x, (0, pad_k))
- return _fp8_gemm(x, self.weight, self.weight_scale, self.M)+ if M <= 64:
+ grid = (triton.cdiv(N, block_n),)
+ _fp8_gemm_skinny_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ HAS_TAIL=False,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ else:
+ grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
+ _fp8_gemm_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ HAS_TAIL=False,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ GROUP_M=group_m,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return y{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<1538 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_8f2ddfa0-9d97-4930-93b2-4a37e825efb",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1469 more chars)failure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: state_dict mismatch at shape 1 ({'M': 4096, 'N': 4096, 'K': 4127}): Error(s) in loading state_dict for Model:
size mismatch for weight: copying a param with shape torch.Size([4096, 4127]) from checkpoint, the shape in current model is torch.Size([4096, 4224]).
stderr: ''
executionTime: 2357
interleavedOutput:
FAIL: state_dict mismatch at shape 1 ({'M': 4096, 'N': 4096, 'K': 4127}): Error(s) in loading state_dict for Model:
size mismatch for weight: copying a param with shape torch.Size([4096, 4127]) from checkpoint, the shape in current model is torch.Size([4096, 4224]).
aborted: False
localExecutionTimeMs: 2295
isBackground: False--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,61 +1,19 @@
-def _fp8_gemm(
- x: torch.Tensor,
- weight: torch.Tensor,
- weight_scale: torch.Tensor,
- M: int,
-) -> torch.Tensor:
- K = x.shape[1]
- N, K_w = weight.shape
- assert K == K_w
+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ w = torch.empty(N, K, dtype=torch.bfloat16)
+ nn.init.normal_(w, std=0.02)
+ s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
+ w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
+ self.register_buffer("weight", w_fp8)
+ self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
- y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
- block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
-
- if M <= 64:
- grid = (triton.cdiv(N, block_n),)
- _fp8_gemm_skinny_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- HAS_TAIL=False,
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- else:
- grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
- _fp8_gemm_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- HAS_TAIL=False,
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- GROUP_M=group_m,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- return y+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ weight = self.weight
+ K = x.shape[1]
+ pad_k = (128 - (K % 128)) % 128
+ if pad_k:
+ x = torch.nn.functional.pad(x, (0, pad_k))
+ weight = torch.nn.functional.pad(weight, (0, pad_k))
+ return _fp8_gemm(x, weight, self.weight_scale, self.M){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<841 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_ba98d3b1-154e-42ac-9bfc-98ad436b62c",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1476 more chars)failure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 2 {'M': 32, 'N': 8192, 'K': 8192} seed 42 case nominal: tolerance exceeded: max_abs_diff=120835 max_rel_diff=5.1968e+34 bad=127868/262144 worst_idx=(27, 65) (atol=0.2, rtol=0.2)
stderr: ''
executionTime: 4604
interleavedOutput:
FAIL: shape 2 {'M': 32, 'N': 8192, 'K': 8192} seed 42 case nominal: tolerance exceeded: max_abs_diff=120835 max_rel_diff=5.1968e+34 bad=127868/262144 worst_idx=(27, 65) (atol=0.2, rtol=0.2)
aborted: False
localExecutionTimeMs: 4541
isBackground: False--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,19 +1,3 @@
-class Model(nn.Module):
- def __init__(self, M: int, N: int, K: int):
- super().__init__()
- self.M, self.N, self.K = M, N, K
- w = torch.empty(N, K, dtype=torch.bfloat16)
- nn.init.normal_(w, std=0.02)
- s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
- w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
- self.register_buffer("weight", w_fp8)
- self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- weight = self.weight
- K = x.shape[1]
- pad_k = (128 - (K % 128)) % 128
- if pad_k:
- x = torch.nn.functional.pad(x, (0, pad_k))
- weight = torch.nn.functional.pad(weight, (0, pad_k))
- return _fp8_gemm(x, weight, self.weight_scale, self.M)+def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
+ if M <= 64:
+ return 32, 128, 256, 4, 3, 1{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<125 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_739e3a56-eea0-45f5-abc5-91f301ce61f",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1478 more chars)null{
"path": "/workspace/problems/01_fp8_gemm/solution.py"
}"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
E4M3_MAX = 448.0
@triton.jit
def _fp8_gemm_kernel(
A,
B,
C,
Scales,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
HAS_TAIL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
full_k = (K // BLOCK_K) * BLOCK_K
for k in range(0, full_k, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
... (5050 more chars)--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,3 +1,211 @@
+"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+E4M3_MAX = 448.0
+
+
+@triton.jit
+def _fp8_gemm_kernel(
+ A,
+ B,
+ C,
+ Scales,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ NUM_STAGES: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
+ a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
+ acc = acc * scales[None, :]
+
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+@triton.jit
+def _fp8_gemm_skinny_kernel(
+ A,
+ B,
+ C,
+ Scales,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ NUM_STAGES: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ offs_m = tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
+ a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
+ acc = acc * scales[None, :]
+
+ c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
if M <= 64:
- return 32, 128, 256, 4, 3, 1+ return 32, 128, 256, 4, 3, 1
+ if N >= 12000:
+ return 128, 256, 128, 8, 2, 4
+ return 128, 256, 128, 8, 2, 4
+
+
+def _fp8_gemm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ M: int,
+) -> torch.Tensor:
+ K = x.shape[1]
+ N, K_w = weight.shape
+ assert K == K_w
+
+ y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+ block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
+
+ if M <= 64:
+ grid = (triton.cdiv(N, block_n),)
+ _fp8_gemm_skinny_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ NUM_STAGES=num_stages,
+ num_warps=num_warps,
+ )
+ else:
+ grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
+ _fp8_gemm_kernel[grid](
+ x,
+ weight,
+ y,
+ weight_scale,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ weight.stride(1),
+ weight.stride(0),
+ y.stride(0),
+ y.stride(1),
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ GROUP_M=group_m,
+ NUM_STAGES=num_stages,
+ num_warps=num_warps,
+ )
+ return y
+
+
+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ w = torch.empty(N, K, dtype=torch.bfloat16)
+ nn.init.normal_(w, std=0.02)
+ s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
+ w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
+ self.register_buffer("weight", w_fp8)
+ self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
+ self._weight_padded: torch.Tensor | None = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ K = x.shape[1]
+ pad_k = (128 - (K % 128)) % 128
+ if pad_k:
+ if self._weight_padded is None:
+ self._weight_padded = torch.nn.functional.pad(self.weight, (0, pad_k))
+ x = torch.nn.functional.pad(x, (0, pad_k))
+ weight = self._weight_padded
+ else:
+
... (280 more chars){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<5914 chars \u2014 see diff>"
}{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_1951e202-1aa0-4eb6-b3c6-f28882a2717",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1466 more chars)nullfailure:
command: 'cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 1 {'M': 4096, 'N': 4096, 'K': 4127} seed 42 case small_weight: tolerance exceeded: max_abs_diff=16.5801 max_rel_diff=3.70181e+07 bad=16763441/16777216 worst_idx=(2349, 1017) (atol=0.003, rtol=0.05)
stderr: ''
executionTime: 2818
interleavedOutput:
FAIL: shape 1 {'M': 4096, 'N': 4096, 'K': 4127} seed 42 case small_weight: tolerance exceeded: max_abs_diff=16.5801 max_rel_diff=3.70181e+07 bad=16763441/16777216 worst_idx=(2349, 1017) (atol=0.003, rtol=0.05)
aborted: False
localExecutionTimeMs: 2730
isBackground: False--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,211 +1,17 @@
-"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-E4M3_MAX = 448.0
-
-
-@triton.jit
-def _fp8_gemm_kernel(
- A,
- B,
- C,
- Scales,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr,
- NUM_STAGES: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- num_pid_in_group = GROUP_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
-
- a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
- b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
- a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
- b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
- acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
-
- scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
- acc = acc * scales[None, :]
-
- c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-
-
-@triton.jit
-def _fp8_gemm_skinny_kernel(
- A,
- B,
- C,
- Scales,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- NUM_STAGES: tl.constexpr,
-):
- pid_n = tl.program_id(0)
- offs_m = tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
-
- a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
- b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
- a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
- b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
- acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
-
- scales = tl.load(Scales + offs_n, mask=offs_n < N, other=1.0)
- acc = acc * scales[None, :]
-
- c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
-
-
-def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
- if M <= 64:
- return 32, 128, 256, 4, 3, 1
- if N >= 12000:
- return 128, 256, 128, 8, 2, 4
- return 128, 256, 128, 8, 2, 4
-
-
-def _fp8_gemm(
- x: torch.Tensor,
- weight: torch.Tensor,
- weight_scale: torch.Tensor,
- M: int,
-) -> torch.Tensor:
- K = x.shape[1]
- N, K_w = weight.shape
- assert K == K_w
-
- y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
- block_m, block_n, block_k, num_warps, num_stages, group_m = _pick_config(M, N)
-
- if M <= 64:
- grid = (triton.cdiv(N, block_n),)
- _fp8_gemm_skinny_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- NUM_STAGES=num_stages,
- num_warps=num_warps,
- )
- else:
- grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
- _fp8_gemm_kernel[grid](
- x,
- weight,
- y,
- weight_scale,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- weight.stride(1),
- weight.stride(0),
- y.stride(0),
- y.stride(1),
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=block_k,
- GROUP_M=group_m,
- NUM_STAGES=num_stages,
- num_warps=num_warps,
- )
- return y
-
-
-class Model(nn.Module):
- def __init__(self, M: int, N: int, K: int):
- super().__init__()
- self.M, self.N, self.K = M, N, K
- w = torch.empty(N, K, dtype=torch.bfloat16)
- nn.init.normal_(w, std=0.02)
- s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
- w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
self.register_buffer("weight", w_fp8)
self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
self._weight_padded: torch.Tensor | None = None
+ self._weight_version: int = -1
def forward(self, x: torch.Tensor) -> torch.Tensor:
K = x.shape[1]
pad_k = (128 - (K % 128)) % 128
if pad_k:
- if self._weight_padded is None:
+ if self._weight_padded is None or self._weight_version != self.weight._version:
self._weight_padded = torch.nn.functional.pad(self.weight, (0, pad_k))
+
... (501 more chars){
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<800 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_afdbcdfe-c8b7-4c03-a76f-ccca311315b",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1459 more chars)null{
"command": "<1356 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_8e7c1272-ff50-43cc-9c8d-56247ba3c35",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText": "cd /workspace/problems/01_fp8_gemm"
},
{
"name": "uv",
... (3588 more chars)cd /workspace/problems/01_fp8_gemm && uv run python -c "
import torch, statistics, triton, sys
sys.path.insert(0,'.')
from src.eval.timing import _l2_flush
from solution import _fp8_gemm_kernel
M,N,K = 4096,4096,4096
x = torch.randn(M,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w = torch.randn(N,K,device='cuda',dtype=torch.bfloat16).to(torch.float8_e4m3fn)
s = torch.ones(N, device='cuda', dtype=torch.float32)
y = torch.empty(M,N,device='cuda',dtype=torch.bfloat16)
configs = [
(128,256,128,8,2,4),(256,128,128,8,2,4),(128,128,128,8,3,8),(128,256,128,8,3,4)
]
for bm,bn,bk,nw,ns,gm in configs:
grid=(triton.cdiv(M,bm)*triton.cdiv(N,bn),)
fn = lambda bm=bm,bn=bn,bk=bk,nw=nw,ns=ns,gm=gm: _fp8_gemm_kernel[grid](x,w,y,s,M,N,K,x.stride(0),x.stride(1),w.stride(1),w.stride(0),y.stride(0),y.stride(1),BLOCK_M=bm,BLOCK_N=bn,BLOCK_K=bk,GROUP_M=gm,NUM_STAGES=ns,num_warps=nw)
for _ in range(10): fn()
torch.cuda.synchronize()
times=[]
for _ in range(30):
_l2_flush(); torch.cuda.synchronize()
e0=torch.cuda.Event(enable_timing=True); e1=torch.cuda.Event(enable_timing=True)
e0.record(); fn(); e1.record(); torch.cuda.synchronize()
times.append(e0.elapsed_time(e1))
ms=statistics.median(times)
print(f'bm={bm} bn={bn} ns={ns} ms={ms:.4f} tflops={2*M*N*K/ms/1e6:.0f}')
" 2>&1null--- a//workspace/problems/01_fp8_gemm/solution.py
+++ b//workspace/problems/01_fp8_gemm/solution.py
@@ -1,17 +1,6 @@
- self.register_buffer("weight", w_fp8)
- self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))
- self._weight_padded: torch.Tensor | None = None
- self._weight_version: int = -1
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- K = x.shape[1]
- pad_k = (128 - (K % 128)) % 128
- if pad_k:
- if self._weight_padded is None or self._weight_version != self.weight._version:
- self._weight_padded = torch.nn.functional.pad(self.weight, (0, pad_k))
- self._weight_version = self.weight._version
- x = torch.nn.functional.pad(x, (0, pad_k))
- weight = self._weight_padded
- else:
- weight = self.weight
- return _fp8_gemm(x, weight, self.weight_scale, self.M)+def _pick_config(M: int, N: int) -> tuple[int, int, int, int, int, int]:
+ if M <= 64:
+ return 32, 128, 256, 4, 3, 1
+ if N >= 12000:
+ return 128, 256, 128, 8, 3, 4
+ return 128, 256, 128, 8, 3, 4{
"path": "/workspace/problems/01_fp8_gemm/solution.py",
"streamContent": "<216 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/01_fp8_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_275eef25-bbf2-4758-824a-18c65eb385e",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/01_fp8_gemm"
}
],
"fullText
... (1478 more chars)null{
"path": "/workspace/problems/01_fp8_gemm/test_triton_fp8.py",
"toolCallId": "tool_a5182eba-a028-47d6-8b92-07443e7c014"
}null{
"path": "/workspace/problems/01_fp8_gemm/solution.py"
}"""FP8 e4m3 GEMM via Triton FP8 tensor-core tl.dot + per-channel scale."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
E4M3_MAX = 448.0
@triton.jit
def _fp8_gemm_kernel(
A,
B,
C,
Scales,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in tl.range(0, K, BLOCK_K, num_stages=NUM_STAGES):
a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLO
... (4561 more chars)