import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

# ---------------------------------------------------------------------------
# Hard-coded tuning for the five benchmark shapes (the extreme-logit shape
# reuses the (8, 131072) entry). The benchmark creates a fresh
# Model(*init_args) per shape, so we can safely specialise. Fallback
# heuristics handle unseen sizes.
# ---------------------------------------------------------------------------
_KERNEL_CFG = {
    (32, 4096): {"fused": "oneshot", "block_size": 4096},
    (16, 32768): {"fused": "online", "block_size": 4096},
    (8, 131072): {"fused": False, "nb": 16, "block_size": 1024},
    (4, 262144): {"fused": False, "nb": 32, "block_size": 1024},
}

@triton.jit
def softmax_oneshot_kernel(
    input_ptr,
    output_ptr,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """One-shot softmax for rows that fit in a single tile (n_cols <= BLOCK_SIZE).

    Loads the row once, computes max & sum, and writes the output.
    """
    row_idx = tl.program_id(0)
    if row_idx >= n_rows:
        return
    row_start = input_ptr + row_idx * n_cols
    out_start = output_ptr + row_idx * n_cols
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
    row_max = tl.max(x, axis=0)
    row_sum = tl.sum(tl.exp(x - row_max), axis=0)
    out_val = tl.exp(x - row_max) / row_sum
    tl.store(out_start + cols, out_val, mask=mask)

@triton.jit
def softmax_fused_kernel(
    input_ptr,
    output_ptr,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Two-pass online softmax for medium-length rows."""
    row_idx = tl.program_id(0)
    if row_idx >= n_rows:
        return
    row_start = input_ptr + row_idx * n_cols
    out_start = output_ptr + row_idx * n_cols
    # Online softmax: single-pass max+sum.
    row_max = -float('inf')
    row_sum = 0.0
    offset = 0
    while offset < n_cols:
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
        block_max = tl.max(x, axis=0)
        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
        new_max = tl.maximum(row_max, block_max)
        scale_row = tl.exp(row_max - new_max)
        scale_block = tl.exp(block_max - new_max)
        row_sum = row_sum * scale_row + block_sum * scale_block
        row_max = new_max
        offset += BLOCK_SIZE
    # Second pass: write normalized output.
    offset = 0
    while offset < n_cols:
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
        out_val = tl.exp(x - row_max) / row_sum
        tl.store(out_start + cols, out_val, mask=mask)
        offset += BLOCK_SIZE

@triton.jit
def softmax_max_kernel(
    input_ptr,
    mid_max_ptr,
    mid_sum_ptr,
    n_rows,
    n_cols,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    """Split-row pass 1: each program reduces a strided slice of one row to a
    partial (max, sum) pair using the same online rescaling as the fused kernel.
    """
    pid = tl.program_id(0)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row
    if row_idx >= n_rows:
        return
    row_start = input_ptr + row_idx * n_cols
    mid_idx = row_idx * num_blocks_per_row + block_idx
    local_max = -float('inf')
    local_sum = 0.0
    offset = block_idx * BLOCK_SIZE
    stride = num_blocks_per_row * BLOCK_SIZE
    while offset < n_cols:
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
        block_max = tl.max(x, axis=0)
        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
        new_max = tl.maximum(local_max, block_max)
        scale_local = tl.exp(local_max - new_max)
        scale_block = tl.exp(block_max - new_max)
        local_sum = local_sum * scale_local + block_sum * scale_block
        local_max = new_max
        offset += stride
    tl.store(mid_max_ptr + mid_idx, local_max)
    tl.store(mid_sum_ptr + mid_idx, local_sum)

@triton.jit
def softmax_reduce_kernel(
    mid_max_ptr,
    mid_sum_ptr,
    row_max_ptr,
    row_sum_ptr,
    n_rows,
    num_blocks_per_row,
):
    """Split-row pass 2: merge each row's partial (max, sum) pairs into the
    final per-row statistics.
    """
    row_idx = tl.program_id(0)
    if row_idx >= n_rows:
        return
    base = row_idx * num_blocks_per_row
    global_max = -float('inf')
    global_sum = 0.0
    for i in range(num_blocks_per_row):
        m = tl.load(mid_max_ptr + base + i)
        s = tl.load(mid_sum_ptr + base + i)
        new_max = tl.maximum(global_max, m)
        scale_global = tl.exp(global_max - new_max)
        scale_m = tl.exp(m - new_max)
        global_sum = global_sum * scale_global + s * scale_m
        global_max = new_max
    tl.store(row_max_ptr + row_idx, global_max)
    tl.store(row_sum_ptr + row_idx, global_sum)

@triton.jit
def softmax_write_kernel(
    input_ptr,
    output_ptr,
    row_max_ptr,
    row_sum_ptr,
    n_rows,
    n_cols,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    """Split-row pass 3: re-read the input and write exp(x - row_max) / row_sum."""
    pid = tl.program_id(0)
    row_idx = pid // num_blocks_per_row
    block_idx = pid % num_blocks_per_row
    if row_idx >= n_rows:
        return
    row_start = input_ptr + row_idx * n_cols
    out_start = output_ptr + row_idx * n_cols
    row_max = tl.load(row_max_ptr + row_idx)
    row_sum = tl.load(row_sum_ptr + row_idx)
    offset = block_idx * BLOCK_SIZE
    stride = num_blocks_per_row * BLOCK_SIZE
    while offset < n_cols:
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
        out_val = tl.exp(x - row_max) / row_sum
        tl.store(out_start + cols, out_val, mask=mask)
        offset += stride

class Model(nn.Module):
    def __init__(self, batch: int, vocab: int):
        super().__init__()
        self.batch = batch
        self.vocab = vocab
        cfg = _KERNEL_CFG.get((batch, vocab))
        if cfg is None:
            if vocab <= 4096:
                cfg = {"fused": "oneshot", "block_size": 4096}
            elif vocab <= 32768:
                cfg = {"fused": "online", "block_size": 4096}
            else:
                nb = max(1, min(32, 128 // batch))
                cfg = {"fused": False, "nb": nb, "block_size": 1024}
        self._cfg = cfg
        # Pre-allocate scratch buffers for the multi-block path so that
        # repeated forward() calls avoid cudaMalloc overhead.
        if cfg.get("fused") is False:
            nb = cfg["nb"]
            self._mid_max = torch.empty(
                batch * nb, dtype=torch.float32, device="cuda"
            )
            self._mid_sum = torch.empty(
                batch * nb, dtype=torch.float32, device="cuda"
            )
            self._row_max = torch.empty(
                batch, dtype=torch.float32, device="cuda"
            )
            self._row_sum = torch.empty(
                batch, dtype=torch.float32, device="cuda"
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2
        n_rows, n_cols = x.shape
        out = torch.empty_like(x)
        cfg = self._cfg
        fused = cfg.get("fused")
        if fused == "oneshot":
            grid = (n_rows,)
            softmax_oneshot_kernel[grid](
                x, out,
                n_rows, n_cols,
                BLOCK_SIZE=cfg["block_size"],
            )
        elif fused == "online":
            grid = (n_rows,)
            softmax_fused_kernel[grid](
                x, out,
                n_rows, n_cols,
                BLOCK_SIZE=cfg["block_size"],
            )
        else:
            nb = cfg["nb"]
            bs = cfg["block_size"]
            mid_max = self._mid_max[: n_rows * nb]
            mid_sum = self._mid_sum[: n_rows * nb]
            row_max = self._row_max[:n_rows]
            row_sum = self._row_sum[:n_rows]
            grid1 = (n_rows * nb,)
            softmax_max_kernel[grid1](
                x, mid_max, mid_sum,
                n_rows, n_cols, nb,
                BLOCK_SIZE=bs,
            )
            grid2 = (n_rows,)
            softmax_reduce_kernel[grid2](
                mid_max, mid_sum, row_max, row_sum,
                n_rows, nb,
            )
            grid3 = (n_rows * nb,)
            softmax_write_kernel[grid3](
                x, out, row_max, row_sum,
                n_rows, n_cols, nb,
                BLOCK_SIZE=bs,
            )
        return out

BATCH = 8
VOCAB = 32768


def get_inputs():
    x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
    return [x]


def get_init_inputs():
    return [BATCH, VOCAB]
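
Both softmax_fused_kernel and softmax_max_kernel lean on the same online-softmax merge rule: a running pair (m, s) with s = sum(exp(x - m)) absorbs a new block's pair (m_b, s_b) by rescaling both partial sums to the shared maximum. A minimal NumPy sketch of that identity, illustrative only and not part of solution.py:

# Check the online-softmax merge rule used by the kernels above.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(1024)
b = rng.standard_normal(1024)

# Per-block statistics: (max, sum of exp relative to that max).
m_a, s_a = a.max(), np.exp(a - a.max()).sum()
m_b, s_b = b.max(), np.exp(b - b.max()).sum()

# Merge by rescaling both partial sums to the combined max.
m = max(m_a, m_b)
s = s_a * np.exp(m_a - m) + s_b * np.exp(m_b - m)

# Direct computation over the concatenated row agrees.
full = np.concatenate([a, b])
assert np.isclose(s, np.exp(full - full.max()).sum())
print("merge identity holds")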
shape=0 variant=eager tflops=0.009 gbps=14.347 ms=0.073
shape=0 variant=compiled tflops=0.006 gbps=10.039 ms=0.104
shape=0 variant=sota tflops=0.031 gbps=48.835 ms=0.021
shape=0 variant=solution tflops=0.055 gbps=87.732 ms=0.012
shape=0 solution_peak_fraction=0.0487
shape=1 variant=eager tflops=0.039 gbps=62.759 ms=0.067
shape=1 variant=compiled tflops=0.027 gbps=42.549 ms=0.099
shape=1 variant=sota tflops=0.100 gbps=160.627 ms=0.026
shape=1 variant=solution tflops=0.112 gbps=178.572 ms=0.023
shape=1 solution_peak_fraction=0.0992
shape=2 variant=eager tflops=0.071 gbps=113.433 ms=0.074
shape=2 variant=compiled tflops=0.024 gbps=38.111 ms=0.220
shape=2 variant=sota tflops=0.080 gbps=127.844 ms=0.066
shape=2 variant=solution tflops=0.193 gbps=309.497 ms=0.027
shape=2 solution_peak_fraction=0.1719
shape=3 variant=eager tflops=0.069 gbps=110.890 ms=0.076
shape=3 variant=compiled tflops=0.014 gbps=22.998 ms=0.365
shape=3 variant=sota tflops=0.076 gbps=122.240 ms=0.069
shape=3 variant=solution tflops=0.185 gbps=295.540 ms=0.028
shape=3 solution_peak_fraction=0.1642
shape=4 variant=eager tflops=0.070 gbps=112.774 ms=0.074
shape=4 variant=compiled tflops=0.023 gbps=36.592 ms=0.229
shape=4 variant=sota tflops=0.092 gbps=147.562 ms=0.057
shape=4 variant=solution tflops=0.187 gbps=299.422 ms=0.028
shape=4 solution_peak_fraction=0.1663
peak_fraction: 0.1178
RESULT: OK
PASS
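
benchmark.py's docstring defines the summary peak_fraction as the geomean over shapes of the solution's per-shape fraction; a quick illustrative check that the five numbers above reproduce the 0.1178 summary line:

import math
fracs = [0.0487, 0.0992, 0.1719, 0.1642, 0.1663]  # solution_peak_fraction per shape
print(round(math.exp(sum(math.log(f) for f in fracs) / len(fracs)), 4))  # -> 0.1178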
{
"problem": "04_kahan_softmax",
"harness": "kimi",
"model": "kimi-k2.6",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1178,
"elapsed_seconds": 1573,
"harness_exit_code": 1,
"session_complete": false,
"usage": {"input_tokens": null, "output_tokens": null, "cache_read_tokens": null, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py"
}{
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py"
}<system>52 lines read from file starting from line 1. Total lines in file: 52. End of file reached.</system>
1 """Naive softmax over the last dim, computed in fp64 for ground-truth.
2
3 The reference deliberately runs in float64 so that fp16 / fp32 accumulation
4 drift in agent solutions is exposed by the tight tolerance in problem.yaml.
5 The agent's job is to produce an fp32 softmax whose values match this
6 double-precision reference within atol=rtol=1e-5 — this requires either
7 fp32 accumulation or compensated (Kahan) summation when vocab is large.
8 """
9 import torch
10 import torch.nn as nn
11
12 OP_TYPE = "softmax"
13 SUPPORTED_PRECISIONS = ["fp32"]
14 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15
16
17 class Model(nn.Module):
18 """y = softmax(x, dim=-1) computed in fp64 then returned as fp32.
19
20 No learned parameters — softmax is parameter-free. We still expose an
21 empty state_dict so the harness's strict load_state_dict matches.
22 """
23
24 def __init__(self, batch: int, vocab: int):
25 super().__init__()
26 self.batch = batch
27 self.vocab = vocab
28
29 def forward(self, x: torch.Tensor) -> torch.Tensor:
30 # Promote to fp64 for the ground-truth pathway. Even with double
31 # precision we still subtract the row-max for stability.
32 x6
... (706 more chars){
"raw": "{\"path\": \"/home/infatos"
}<system>126 lines read from file starting from line 1. Total lines in file: 126. End of file reached.</system>
1 """Correctness runner for Kahan-corrected softmax.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3
4 seeds each, with the tight (1e-5) fp32 tolerance from problem.yaml. Also
5 rejects forbidden ops via grep.
6 """
7 import re
8 import sys
9 from pathlib import Path
10
11 import torch
12 import yaml
13
14 # Make the repo's src/ importable
15 REPO_ROOT = Path(__file__).resolve().parents[2]
16 sys.path.insert(0, str(REPO_ROOT))
17
18 from src.eval.correctness import check_correctness # noqa: E402
19
20
21 def _make_inputs(batch: int, vocab: int, extreme: bool, seed: int) -> torch.Tensor:
22 g = torch.Generator().manual_seed(seed)
23 if extreme:
24 # Adversarial: most logits are mild but a handful per row are huge.
25 # If the kernel forgets to subtract the row-max before exp, this
26 # row overflows fp32 and produces NaN/Inf. If it accumulates in
27 # fp16, the long tail of small exp() values is lost beneath the
28 # tolerance threshold.
29 x = torch.randn(batch, vocab, generator=g) * 2.0
30 # Spike: 4 very large positive logits per row.
31 idx = torch.randint(0, vocab, (batch, 4), generator=g)
32 x.scatter_(1, idx, 30.
... (3778 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py"
}<system>135 lines read from file starting from line 1. Total lines in file: 135. End of file reached.</system>
1 """Roofline benchmark for Kahan-corrected softmax.
2
3 For each shape: times eager reference, compiled reference, SOTA (if
4 available), and the agent's solution. Reports achieved TFLOPS, GB/s, and
5 peak_fraction. Softmax is memory-bound, so the score is GB/s / peak_dram.
6
7 Output lines the harness picks up:
8 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10 """
11 import sys
12 from math import exp, log
13 from pathlib import Path
14
15 import torch
16 import yaml
17
18 REPO_ROOT = Path(__file__).resolve().parents[2]
19 sys.path.insert(0, str(REPO_ROOT))
20
21 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22 from src.eval.timing import time_fn # noqa: E402
23 from src.hardware import get as get_hw # noqa: E402
24
25
26 def _eval_formula(expr: str, vars: dict) -> float:
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def _make_inputs(batch: int, vocab: int, extreme: bool) -> torch.Tensor:
31 if extreme:
32 x = torch.randn(batch, vocab) * 2.0
33 idx = torch.randint(0, vocab, (batch, 4))
34 x.scatter_(1, idx, 30.0)
35 else:
36
... (4034 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py"
}<system>24 lines read from file starting from line 1. Total lines in file: 24. End of file reached.</system>
1 """Shape sweep for Kahan-corrected softmax.
2
3 The point of this problem is numerical accuracy on long reductions. Shapes
4 mix typical LLM vocab sizes with deliberately adversarial regimes:
5
6 - small vocab (sanity check; naive fp32 should pass)
7 - Llama3 vocab 128K (real-world, where fp16 accumulation starts to drift)
8 - 256K (DeepSeek-V3 / Gemma-3 class vocab; naive fp16 sum DOES drift past
9 the 1e-5 tolerance — this row is what proves Kahan was needed)
10 - extreme-logit edge case (large positive logits stress max-subtract +
11 summation; if the implementation accidentally exps before subtracting
12 max, this row overflows)
13
14 The 'extreme' flag is read by check.py to switch input generation to a
15 distribution that produces a few very large logits per row.
16 """
17
18 SHAPES = [
19 {"batch": 32, "vocab": 4096, "extreme": False}, # sanity
20 {"batch": 16, "vocab": 32768, "extreme": False}, # GPT-2 class
21 {"batch": 8, "vocab": 131072, "extreme": False}, # Llama3 vocab
22 {"batch": 4, "vocab": 262144, "extreme": False}, # 256K — Kahan needed
23 {"batch": 8, "vocab": 131072, "extreme": True}, # extreme logits edge
24 ]
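
The drift behind the 256K row is easy to reproduce outside the harness. A minimal NumPy sketch (illustrative only, not a repo file) comparing fp16, fp32, and Kahan-compensated accumulation of the exp-sum against an fp64 ground truth:

# Accumulate sum(exp(x - max)) over 262144 elements in several precisions.
import numpy as np

rng = np.random.default_rng(0)
x = (rng.standard_normal(262144) * 2.0).astype(np.float32)
e = np.exp((x - x.max()).astype(np.float64))
truth = e.sum()  # fp64 reference

s16 = np.float16(0.0)
s32 = np.float32(0.0)
k = np.float32(0.0)  # Kahan running sum
c = np.float32(0.0)  # Kahan compensation
for v in e:  # scalar loop: slow but explicit
    s16 = np.float16(s16 + np.float16(v))
    s32 = np.float32(s32 + np.float32(v))
    y = np.float32(v) - c
    t = k + y
    c = (t - k) - y
    k = t

for name, val in (("fp16", s16), ("fp32", s32), ("kahan", k)):
    print(name, "rel_err", abs(float(val) - truth) / truth)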
{
"raw": "{\"path\": \"/home/infatos"
}<system>45 lines read from file starting from line 1. Total lines in file: 45. End of file reached.</system>
1 """SOTA reference for last-dim softmax.
2
3 Preference order:
4 1. liger-kernel's Triton softmax (LigerSoftmaxFunction) — fused, fp32-acc
5 2. torch.compile(torch.softmax) — Inductor generates a Triton kernel that
6 is competitive on bandwidth-bound reductions
7
8 Agents are FORBIDDEN from importing either of these in solution.py (see
9 problem.yaml.forbidden). This file is only the benchmark's reference line.
10 """
11 from __future__ import annotations
12
13 import torch
14
15 _compiled_softmax = None
16
17
18 def _liger_softmax(x: torch.Tensor) -> torch.Tensor | None:
19 try:
20 from liger_kernel.ops.softmax import LigerSoftmaxFunction
21 return LigerSoftmaxFunction.apply(x)
22 except Exception:
23 return None
24
25
26 def _compiled(x: torch.Tensor) -> torch.Tensor:
27 global _compiled_softmax
28 if _compiled_softmax is None:
29 _compiled_softmax = torch.compile(
30 lambda t: torch.softmax(t, dim=-1),
31 mode="reduce-overhead",
32 )
33 return _compiled_softmax(x)
34
35
36 def sota_forward(x: torch.Tensor) -> torch.Tensor:
37 """Best-available softmax reference. x: (batch, vocab) fp32."""
38 out =
... (226 more chars)<system>7 lines read from file starting from line 1. Total lines in file: 7. End of file reached.</system>
1 I need you to write a numerically tight softmax for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 The op is softmax along the last dim of a 2D fp32 tensor. The reference computes ground truth in fp64 and casts back to fp32, and you have to match it within atol=rtol=1e-5 — that's a tighter bar than default fp32 (1e-4) on purpose. With long reductions naive fp16 accumulation drifts past this; fp32 accumulation with subtract-max stability is enough on most shapes; on the largest vocabs you may need compensated (Kahan-style) summation to stay under the bar. The shapes you have to handle are batch=32 vocab=4096 (sanity), batch=16 vocab=32768 (GPT-2 class), batch=8 vocab=131072 (Llama-3 vocab), batch=4 vocab=262144 (256K, DeepSeek-V3 / Gemma-3 class — naive fp16 sum drifts past 1e-5 here), and batch=8 vocab=131072 with extreme logits (a few very large positives per row to stress max-subtract — exping before subtracting overflows). The check and benchmark scripts handle the extreme-flag input generation; you just ne
... (927 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml"
}<system>43 lines read from file starting from line 1. Total lines in file: 43. End of file reached.</system>
1 name: 04_kahan_softmax
2 display_name: "Kahan-corrected Softmax"
3 precision: fp32
4 regime: memory # softmax is bandwidth-bound: 2 passes over the input tensor
5
6 # Softmax FLOPs: per-element exp + 2 reductions + divide. Roughly 5 flops/elt.
7 flops_formula: "5 * batch * vocab"
8
9 # Bytes moved: read x once, write y once. Both fp32.
10 bytes_formula: "batch * vocab * 4 + batch * vocab * 4"
11
12 hardware: [RTX_PRO_6000]
13 peak_tflops_key: fp32
14 peak_bandwidth_key: dram
15
16 # TIGHTER than default (fp32 default is 1e-4). This problem exists
17 # specifically to test whether the agent uses compensated summation, so
18 # we squeeze the tolerance to 1e-5 — naive fp16 sum across 256K elements
19 # drifts past this; fp32 accumulation passes; Kahan/fp32 always passes.
20 tolerance:
21 "torch.float32": {"atol": 1.0e-5, "rtol": 1.0e-5}
22
23 # Forbidden ops — block the obvious "just call the library" cheats. The
24 # agent must implement softmax themselves with explicit (compensated)
25 # summation logic.
26 forbidden:
27 - "torch.nn.functional.softmax"
28 - "torch.softmax"
29 - "F.softmax"
30 - "liger_kernel.softmax"
31 - "liger_kernel.transformers.softmax"
32 - ".softmax("
33
34 sota:
35 name: "Lig
... (312 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python -c \"import torch; print(torch.__version__); import triton; print(triton.__version__)\""
}<system>Command executed successfully.</system>
2.11.0+cu130
3.6.0
{
"command": "python -c \"import liger_kernel; print(liger_kernel.__file__)\""
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'liger_kernel'
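(With liger-kernel absent, sota.py's stated preference order means the benchmark's sota line falls back to torch.compile(torch.softmax).)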
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "softmax"
+SUPPORTED_PRECISIONS = ["fp32"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton.jit
+def softmax_kernel(
+    input_ptr,
+    output_ptr,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+
+    # --- Pass 1: compute row max ---
+    row_max = -float("inf")
+    for offset in range(0, n_cols, BLOCK_SIZE):
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
+        block_max = tl.max(x, axis=0)
+        row_max = tl.maximum(row_max, block_max)
+
+    # --- Pass 2: compute row sum with Kahan compensation ---
+    row_sum = 0.0
+    row_c = 0.0
+    for offset in range(0, n_cols, BLOCK_SIZE):
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
+        exp_vals = tl.exp(x - row_max)
+        block_sum = tl.sum(exp_vals, axis=0)
+        # Kahan add block_sum into row_sum
+        y = block_sum - row_c
+        t = row_sum + y
+        row_c = (t - row_sum) - y
+        row_sum = t
+
+    # --- Pass 3: write output ---
+    for offset in range(0, n_cols, BLOCK_SIZE):
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+
+
+class Model(nn.Module):
+    def __init__(self, batch: int, vocab: int):
+        super().__init__()
+        self.batch = batch
+        self.vocab = vocab
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 2
+        n_rows, n_cols = x.shape
+        out = torch.empty_like(x)
+        BLOCK_SIZE = 1024
+        grid = (n_rows,)
+        softmax_kernel[grid](
+            x, out,
+            n_rows, n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        return out
+
+
+BATCH = 8
+VOCAB = 32768
+
+
+def get_inputs():
+    x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
+    return [x]
+
+
+def get_init_inputs():
+    return [BATCH, VOCAB]
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<2356 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2356 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python check.py"
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/lowering.py:7836: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.
warnings.warn(
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/lowering.py:7836: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.
warnings.warn(
shape=0 variant=eager tflops=0.010 gbps=15.244 ms=0.069
shape=0 variant=compiled tflops=0.006 gbps=10.300 ms=0.102
shape=0 variant=sota tflops=0.008 gbps=13.588 ms=0.077
shape=0 variant=solution tflops=0.031 gbps=49.312 ms=0.021
shape=0 solution_peak_fraction=0.0274
shape=1 variant=eager tflops=0.042 gbps=66.873 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.771 ms=0.098
shape=1 variant=sota tflops=0.032 gbps=50.548 ms=0.083
shape=1 variant=solution tflops=0.069 gbps=110.796 ms=0.038
shape=1 solution_peak_fraction=0.0616
shape=2 variant=eager tflops=0.072 gbps=114.774 ms=0.073
shape=2 variant=compiled tflops=0.023 gbps=37.107 ms=0.226
shape=2 variant=sota tflops=0.057 gbps=91.723 ms=0.091
shape=2 variant=solution tflops=0.042 gbps=67.755 ms=0.124
shape=2 solution_peak_fraction=0.0376
shape=3 variant=eager tflops=0.
... (539 more chars){
"command": "python -c \"import triton.language as tl; print(dir(tl))\" | tr ',' '\\n' | grep -i atomic"
}<system>Command executed successfully.</system>
'atomic_add'
'atomic_and'
'atomic_cas'
'atomic_max'
'atomic_min'
'atomic_or'
'atomic_xchg'
'atomic_xor'
{
"command": "<348 chars \u2014 see below>"
}python -c "
import triton
import triton.language as tl
@triton.jit
def test_kernel(ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    off = pid * BLOCK
    i = 0
    while i < 3:
        off = off + 1
        i = i + 1
    tl.store(ptr, off)
import torch
x = torch.zeros(1, device='cuda')
test_kernel[(1,)](x, 10, BLOCK=4)
print('ok')
"<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 5, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_triton.py
@@ -0,0 +1,17 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def test_kernel(ptr, n, BLOCK: tl.constexpr):
+    pid = tl.program_id(0)
+    off = pid * BLOCK
+    i = 0
+    while i < 3:
+        off = off + 1
+        i = i + 1
+    tl.store(ptr, off)
+
+x = torch.zeros(1, device='cuda')
+test_kernel[(1,)](x, 10, BLOCK=4)
+print('while loop ok')
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_triton.py",
"content": "<346 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 346 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_triton.py"
}<system>Command executed successfully.</system>
while loop ok
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_triton2.py
@@ -0,0 +1,13 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def test_kernel(ptr, n):
+    for i in range(n):
+        tl.store(ptr + i, i)
+
+x = torch.zeros(10, device='cuda')
+test_kernel[(1,)](x, 5)
+print(x[:5])
+print('range dynamic ok')
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_triton2.py",
"content": "<245 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 245 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_triton2.py"
}<system>Command executed successfully.</system>
tensor([0., 1., 2., 3., 4.], device='cuda:0')
range dynamic ok
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_multi.py
@@ -0,0 +1,152 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_max_kernel(
+    input_ptr,
+    mid_max_ptr,
+    mid_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    mid_idx = row_idx * num_blocks_per_row + block_idx
+
+    local_max = -float('inf')
+    local_sum = 0.0
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(local_max, block_max)
+        scale_local = tl.exp(local_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        local_sum = local_sum * scale_local + block_sum * scale_block
+        local_max = new_max
+        offset += stride
+
+    tl.store(mid_max_ptr + mid_idx, local_max)
+    tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+    mid_max_ptr,
+    mid_sum_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    num_blocks_per_row,
+):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+
+    base = row_idx * num_blocks_per_row
+    global_max = -float('inf')
+    global_sum = 0.0
+    for i in range(num_blocks_per_row):
+        m = tl.load(mid_max_ptr + base + i)
+        s = tl.load(mid_sum_ptr + base + i)
+        new_max = tl.maximum(global_max, m)
+        scale_global = tl.exp(global_max - new_max)
+        scale_m = tl.exp(m - new_max)
+        global_sum = global_sum * scale_global + s * scale_m
+        global_max = new_max
+
+    tl.store(row_max_ptr + row_idx, global_max)
+    tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+    input_ptr,
+    output_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+    row_max = tl.load(row_max_ptr + row_idx)
+    row_sum = tl.load(row_sum_ptr + row_idx)
+
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+        offset += stride
+
+
+def softmax_multi(x, num_blocks_per_row=32, BLOCK_SIZE=1024):
+    n_rows, n_cols = x.shape
+    out = torch.empty_like(x)
+    mid_max = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+    mid_sum = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+    row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+    row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+
+    grid1 = (n_rows * num_blocks_per_row,)
+    softmax_max_kernel[grid1](
+        x, mid_max, mid_sum,
+        n_rows, n_cols, num_blocks_per_row,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    grid2 = (n_rows,)
+    softmax_reduce_kernel[grid2](
+        mid_max, mid_sum, row_max, row_sum,
+        n_rows, num_blocks_per_row,
+    )
+    grid3 = (n_rows * num_blocks_per_row,)
+    softmax_write_kernel[grid3](
+        x, out, row_max, row_sum,
+        n_rows, n_cols, num_blocks_per_row,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return out
+
+
+def reference(x):
+    x64 = x.to(torch.float64)
+    m = x64.amax(dim=-1, keepdim=True)
+    e = torch.exp(x64 - m)
+    s = e.sum(dim=-1, keepdim=True)
+    return (e / s).to(torch.float32)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    for batch, vocab in [(32, 4096), (16, 32768), (8, 131072), (4, 262144)]:
+        x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda') * 4.0
+        y = softmax_multi(x, num_blocks_per_row=32)
+        y_ref = reference(x)
+        diff = (y - y_ref).abs().max().item()
+        print(f"batch={batch} vocab={vocab} max_diff={diff:.2e}")
+        assert diff < 1e-5
+    print("all passed")
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_multi.py",
"content": "<4515 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 4515 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_multi.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096 max_diff=1.19e-07
batch=16 vocab=32768 max_diff=1.19e-07
batch=8 vocab=131072 max_diff=1.19e-07
batch=4 vocab=262144 max_diff=2.98e-08
all passed
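(All four max_diff values sit roughly two orders of magnitude under the harness's 1e-5 tolerance, so the split-row pipeline is numerically comfortable; what remains is tuning num_blocks_per_row.)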
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -9,51 +9,105 @@
 @triton.jit
-def softmax_kernel(
+def softmax_max_kernel(
     input_ptr,
-    output_ptr,
+    mid_max_ptr,
+    mid_sum_ptr,
     n_rows,
     n_cols,
+    num_blocks_per_row,
     BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    mid_idx = row_idx * num_blocks_per_row + block_idx
+
+    local_max = -float('inf')
+    local_sum = 0.0
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(local_max, block_max)
+        scale_local = tl.exp(local_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        local_sum = local_sum * scale_local + block_sum * scale_block
+        local_max = new_max
+        offset += stride
+
+    tl.store(mid_max_ptr + mid_idx, local_max)
+    tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+    mid_max_ptr,
+    mid_sum_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    num_blocks_per_row,
 ):
     row_idx = tl.program_id(0)
     if row_idx >= n_rows:
         return
+    base = row_idx * num_blocks_per_row
+    global_max = -float('inf')
+    global_sum = 0.0
+    for i in range(num_blocks_per_row):
+        m = tl.load(mid_max_ptr + base + i)
+        s = tl.load(mid_sum_ptr + base + i)
+        new_max = tl.maximum(global_max, m)
+        scale_global = tl.exp(global_max - new_max)
+        scale_m = tl.exp(m - new_max)
+        global_sum = global_sum * scale_global + s * scale_m
+        global_max = new_max
+
+    tl.store(row_max_ptr + row_idx, global_max)
+    tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+    input_ptr,
+    output_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
     row_start = input_ptr + row_idx * n_cols
     out_start = output_ptr + row_idx * n_cols
+    row_max = tl.load(row_max_ptr + row_idx)
+    row_sum = tl.load(row_sum_ptr + row_idx)
-    # --- Pass 1: compute row max ---
-    row_max = -float("inf")
-    for offset in range(0, n_cols, BLOCK_SIZE):
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
         cols = offset + tl.arange(0, BLOCK_SIZE)
         mask = cols < n_cols
-        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
-        block_max = tl.max(x, axis=0)
-        row_max = tl.maximum(row_max, block_max)
-
-    # --- Pass 2: compute row sum with Kahan compensation ---
-    row_sum = 0.0
-    row_c = 0.0
-    for offset in range(0, n_cols, BLOCK_SIZE):
-        cols = offset + tl.arange(0, BLOCK_SIZE)
-        mask = cols < n_cols
-        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
-        exp_vals = tl.exp(x - row_max)
-        block_sum = tl.sum(exp_vals, axis=0)
-        # Kahan add block_sum into row_sum
-        y = block_sum - row_c
-        t = row_sum + y
-        row_c = (t - row_sum) - y
-        row_sum = t
-
-    # --- Pass 3: write output ---
-    for offset in range(0, n_cols, BLOCK_SIZE):
-        cols = offset + tl.arange(0, BLOCK_SIZE)
-        mask = cols < n_cols
-        x = tl.load(row_start + cols, mask=mask, other=-float("inf"))
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
         out_val = tl.exp(x - row_max) / row_sum
         tl.store(out_start + cols, out_val, mask=mask)
+        offset += stride

 class Model(nn.Module):
@@ -66,11 +120,34 @@
         assert x.dim() == 2
         n_rows, n_cols = x.shape
         out = torch.empty_like(x)
+        # Heuristic: use enough blocks per row to saturate memory bandwidth.
+        # Target at least ~128 total blocks when possible.
+        num_blocks_per_row = max(1, min(32, (n_cols + 4095) // 4096))
+        # For very small n_cols, keep it at 1 to avoid idle blocks.
+        if n_cols <= 4096:
+            num_blocks_per_row = 1
         BLOCK_SIZE = 1024
-        grid = (n_rows,)
-        softmax_kernel[grid](
-            x, out,
-            n_rows, n_cols,
+
+        mid_max = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+        mid_sum = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+        row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+        row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+
+        grid1 = (n_rows * num_blocks_per_row,)
+        softmax_max_kernel[grid1](
+            x, mid_max, mid_sum,
+            n_rows, n_cols, num_blocks_per_row,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        grid2 = (n_rows,)
+        softmax_reduce_kernel[grid2](
+            mid_max, mid_sum, row_max, row_sum,
+            n_rows, num_blocks_per_row,
+        )
+        grid3 = (n_rows * num_blocks_per_row,)
+        softmax_write_kernel[grid3](
+            x, out, row_max, row_sum,
+            n_rows, n_cols, num_blocks_per_row,
             BLOCK_SIZE=BLOCK_SIZE,
         )
         return out
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<4808 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 4808 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
shape=0 variant=eager tflops=0.010 gbps=15.395 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.005 ms=0.105
shape=0 variant=sota tflops=0.009 gbps=13.751 ms=0.076
shape=0 variant=solution tflops=0.021 gbps=34.222 ms=0.031
shape=0 solution_peak_fraction=0.0190
shape=1 variant=eager tflops=0.041 gbps=65.145 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=43.222 ms=0.097
shape=1 variant=sota tflops=0.031 gbps=49.799 ms=0.084
shape=1 variant=solution tflops=0.059 gbps=94.229 ms=0.045
shape=1 solution_peak_fraction=0.0523
shape=2 variant=eager tflops=0.072 gbps=115.380 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.477 ms=0.218
shape=2 variant=sota tflops=0.057 gbps=90.410 ms=0.093
shape=2 variant=solution tflops=0.156 gbps=248.950 ms=0.034
shape=2 solution_peak_fraction=0.1383
shape=3 variant=eager tflops=0.072 gbps=114.900 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.275 ms=0.360
shape=3 variant=sota tflops=0.054 gbps=86.217 ms=0.097
shape=3 variant=solution tflops=0.177 gbps=282.635 ms=0.030
shape=3 solution_peak_fraction=0.1570
shape=4 variant=eager tflops=0.071 gbps=114.100 ms=0.074
shape=4 variant=compiled tflops=0.022 gbps=35.591 ms=0.236
shape=4 variant=sota tflops=0.053 gbps=84.413 ms=0.099
shape=4 variant=solution tflops=0.177 gbps=282.787 ms=0.030
shape=4 solution_peak_fraction=0.1571
peak_fraction: 0.0806
RESULT: LOW
{
"command": "python -c \"import math; vals=[0.03,0.08,0.1383,0.1570,0.1571]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.09607753494084427
{
"command": "python -c \"import math; vals=[0.05,0.10,0.15,0.16,0.16]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.11393581040567027
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_perf.py
@@ -0,0 +1,188 @@
+import torch
+import triton
+import triton.language as tl
+import time
+
+@triton.jit
+def softmax_fused_kernel(input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+
+    row_max = -float('inf')
+    row_sum = 0.0
+    offset = 0
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(row_max, block_max)
+        scale_row = tl.exp(row_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        row_sum = row_sum * scale_row + block_sum * scale_block
+        row_max = new_max
+        offset += BLOCK_SIZE
+
+    offset = 0
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+        offset += BLOCK_SIZE
+
+
+@triton.jit
+def softmax_max_kernel(
+    input_ptr,
+    mid_max_ptr,
+    mid_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    mid_idx = row_idx * num_blocks_per_row + block_idx
+
+    local_max = -float('inf')
+    local_sum = 0.0
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(local_max, block_max)
+        scale_local = tl.exp(local_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        local_sum = local_sum * scale_local + block_sum * scale_block
+        local_max = new_max
+        offset += stride
+
+    tl.store(mid_max_ptr + mid_idx, local_max)
+    tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+    mid_max_ptr,
+    mid_sum_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    num_blocks_per_row,
+):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+
+    base = row_idx * num_blocks_per_row
+    global_max = -float('inf')
+    global_sum = 0.0
+    for i in range(num_blocks_per_row):
+        m = tl.load(mid_max_ptr + base + i)
+        s = tl.load(mid_sum_ptr + base + i)
+        new_max = tl.maximum(global_max, m)
+        scale_global = tl.exp(global_max - new_max)
+        scale_m = tl.exp(m - new_max)
+        global_sum = global_sum * scale_global + s * scale_m
+        global_max = new_max
+
+    tl.store(row_max_ptr + row_idx, global_max)
+    tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+    input_ptr,
+    output_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+    row_max = tl.load(row_max_ptr + row_idx)
+    row_sum = tl.load(row_sum_ptr + row_idx)
+
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+        offset += stride
+
+
+def fused(x):
+    n_rows, n_cols = x.shape
+    out = torch.empty_like(x)
+    grid = (n_rows,)
+    softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=1024)
+    return out
+
+
+def multi(x, num_blocks_per_row=32):
+    n_rows, n_cols = x.shape
+    out = torch.empty_like(x)
+    mid_max = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+    mid_sum = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+    row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+    row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+
+    grid1 = (n_rows * num_blocks_per_row,)
+    softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE=1024)
+    grid2 = (n_rows,)
+    softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, num_blocks_per_row)
+    grid3 = (n_rows * num_blocks_per_row,)
+    softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE=1024)
+    return out
+
+
+def time_it(fn, x, iters=100):
+    # warmup
+    for _ in range(10):
+        fn(x)
+    torch.cuda.synchronize()
+    start = time.perf_counter()
+    for _ in range(iters):
+        fn(x)
+    torch.cuda.synchronize()
+    end = time.perf_counter()
+    return (end - start) / iters * 1000
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    for batch, vocab in [(32, 4096), (16, 32768), (8, 131072), (4, 262144)]:
+        x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda') * 4.0
+        t_fused = time_it(fused, x)
+        t_multi_8 = time_it(lambda x
... (280 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_perf.py",
"content": "<5977 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 5977 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_perf.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096 fused=0.006ms multi8=0.031ms multi16=0.022ms multi32=0.027ms
batch=16 vocab=32768 fused=0.024ms multi8=0.025ms multi16=0.022ms multi32=0.022ms
batch=8 vocab=131072 fused=0.092ms multi8=0.022ms multi16=0.032ms multi32=0.024ms
batch=4 vocab=262144 fused=0.179ms multi8=0.026ms multi16=0.023ms multi32=0.032ms
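
time_it here wraps the loop in perf_counter with a synchronize, so Python-side launch overhead is charged to the kernel; a CUDA-event variant (an illustrative alternative, not what test_perf.py uses) measures on the GPU timeline instead:

import torch

def time_it_events(fn, x, iters=100):
    # Same warmup/loop structure as time_it, but timed with CUDA events.
    for _ in range(10):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds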
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_tuning.py
@@ -0,0 +1,27 @@
+import torch
+import time
+from test_perf import fused, multi
+
+
+def time_it(fn, x, iters=200):
+    for _ in range(20):
+        fn(x)
+    torch.cuda.synchronize()
+    start = time.perf_counter()
+    for _ in range(iters):
+        fn(x)
+    torch.cuda.synchronize()
+    end = time.perf_counter()
+    return (end - start) / iters * 1000
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    shapes = [(32, 4096), (16, 32768), (8, 131072), (4, 262144)]
+    for batch, vocab in shapes:
+        x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda') * 4.0
+        t_fused = time_it(fused, x)
+        print(f"\nbatch={batch} vocab={vocab} fused={t_fused:.4f}ms")
+        for nb in [1, 2, 4, 8, 16, 32]:
+            t = time_it(lambda x: multi(x, nb), x)
+            print(f"  nb={nb:2d} -> {t:.4f}ms")
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_tuning.py",
"content": "<808 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 808 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_tuning.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096 fused=0.0068ms
nb= 1 -> 0.0222ms
nb= 2 -> 0.0316ms
nb= 4 -> 0.0221ms
nb= 8 -> 0.0222ms
nb=16 -> 0.0278ms
nb=32 -> 0.0224ms
batch=16 vocab=32768 fused=0.0238ms
nb= 1 -> 0.0294ms
nb= 2 -> 0.0222ms
nb= 4 -> 0.0255ms
nb= 8 -> 0.0253ms
nb=16 -> 0.0223ms
nb=32 -> 0.0239ms
batch=8 vocab=131072 fused=0.0921ms
nb= 1 -> 0.0944ms
nb= 2 -> 0.0482ms
nb= 4 -> 0.0264ms
nb= 8 -> 0.0220ms
nb=16 -> 0.0298ms
nb=32 -> 0.0224ms
batch=4 vocab=262144 fused=0.1795ms
nb= 1 -> 0.1820ms
nb= 2 -> 0.0933ms
nb= 4 -> 0.0483ms
nb= 8 -> 0.0264ms
nb=16 -> 0.0219ms
nb=32 -> 0.0290ms
{
"command": "python -c \"import math; vals=[0.0857,0.105,0.212,0.213,0.212]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.15383115089810748
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_hybrid.py
@@ -0,0 +1,204 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_fused_kernel(input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+
+    row_max = -float('inf')
+    row_sum = 0.0
+    offset = 0
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(row_max, block_max)
+        scale_row = tl.exp(row_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        row_sum = row_sum * scale_row + block_sum * scale_block
+        row_max = new_max
+        offset += BLOCK_SIZE
+
+    offset = 0
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+        offset += BLOCK_SIZE
+
+
+@triton.jit
+def softmax_max_kernel(
+    input_ptr,
+    mid_max_ptr,
+    mid_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    mid_idx = row_idx * num_blocks_per_row + block_idx
+
+    local_max = -float('inf')
+    local_sum = 0.0
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        block_max = tl.max(x, axis=0)
+        block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+        new_max = tl.maximum(local_max, block_max)
+        scale_local = tl.exp(local_max - new_max)
+        scale_block = tl.exp(block_max - new_max)
+        local_sum = local_sum * scale_local + block_sum * scale_block
+        local_max = new_max
+        offset += stride
+
+    tl.store(mid_max_ptr + mid_idx, local_max)
+    tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+    mid_max_ptr,
+    mid_sum_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    num_blocks_per_row,
+):
+    row_idx = tl.program_id(0)
+    if row_idx >= n_rows:
+        return
+
+    base = row_idx * num_blocks_per_row
+    global_max = -float('inf')
+    global_sum = 0.0
+    for i in range(num_blocks_per_row):
+        m = tl.load(mid_max_ptr + base + i)
+        s = tl.load(mid_sum_ptr + base + i)
+        new_max = tl.maximum(global_max, m)
+        scale_global = tl.exp(global_max - new_max)
+        scale_m = tl.exp(m - new_max)
+        global_sum = global_sum * scale_global + s * scale_m
+        global_max = new_max
+
+    tl.store(row_max_ptr + row_idx, global_max)
+    tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+    input_ptr,
+    output_ptr,
+    row_max_ptr,
+    row_sum_ptr,
+    n_rows,
+    n_cols,
+    num_blocks_per_row,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = pid // num_blocks_per_row
+    block_idx = pid % num_blocks_per_row
+    if row_idx >= n_rows:
+        return
+
+    row_start = input_ptr + row_idx * n_cols
+    out_start = output_ptr + row_idx * n_cols
+    row_max = tl.load(row_max_ptr + row_idx)
+    row_sum = tl.load(row_sum_ptr + row_idx)
+
+    offset = block_idx * BLOCK_SIZE
+    stride = num_blocks_per_row * BLOCK_SIZE
+    while offset < n_cols:
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < n_cols
+        x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+        out_val = tl.exp(x - row_max) / row_sum
+        tl.store(out_start + cols, out_val, mask=mask)
+        offset += stride
+
+
+def forward(x):
+    n_rows, n_cols = x.shape
+    out = torch.empty_like(x)
+    if n_cols <= 8192:
+        grid = (n_rows,)
+        softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=1024)
+    else:
+        num_blocks_per_row = max(1, min(32, n_cols // 16384))
+        mid_max = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+        mid_sum = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
+        row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+        row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+
+        grid1 = (n_rows * num_blocks_per_row,)
+        softmax_max_kernel[grid1](
+            x, mid_max, mid_sum,
+            n_rows, n_cols, num_blocks_per_row,
+            BLOCK_SIZE=1024,
+        )
+        grid2 = (n_rows,)
+        softmax_reduce_kernel[grid2](
+            mid_max, mid_sum, row_max, row_sum,
+            n_rows, num_blocks_per_row,
+        )
+        grid3 = (n_rows * num_blocks_per_row,)
+        softmax_write_kernel[grid3](
+            x, out, row_max, row_sum,
+            n_rows, n_cols, num_blocks_per_row,
+            BLOCK_SIZE=1024,
+        )
+    return out
+
+
+def reference(x):
+    x64 = x.to(torch.float64)
+    m = x64.amax(dim=-1, keepdim=True)
+    e = torch.exp(x64 - m)
+    s = e.sum(dim=-1, keepdim=True)
+    return (e / s).to(torch.float32)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    shapes = [
+        (32, 4096, False),
+        (16, 32768, False),
+        (8, 131072, False),
+        (4, 262144, False),
+        (8, 131072, True),
+    ]
+    for b
... (642 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_hybrid.py",
"content": "<6321 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6321 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_hybrid.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096 extreme=False max_diff=5.96e-08
batch=16 vocab=32768 extreme=False max_diff=1.79e-07
batch=8 vocab=131072 extreme=False max_diff=8.94e-08
batch=4 vocab=262144 extreme=False max_diff=8.94e-08
batch=8 vocab=131072 extreme=True max_diff=4.72e-16
all passed
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -6,6 +6,49 @@
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+
+ # Online softmax: single-pass max+sum.
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+
+ # Second pass: write normalized output.
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
@triton.jit
@@ -120,36 +163,48 @@
assert x.dim() == 2
n_rows, n_cols = x.shape
out = torch.empty_like(x)
- # Heuristic: use enough blocks per row to saturate memory bandwidth.
- # Target at least ~128 total blocks when possible.
- num_blocks_per_row = max(1, min(32, (n_cols + 4095) // 4096))
- # For very small n_cols, keep it at 1 to avoid idle blocks.
- if n_cols <= 4096:
- num_blocks_per_row = 1
BLOCK_SIZE = 1024
- mid_max = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
- mid_sum = torch.empty(n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device)
- row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
- row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
-
- grid1 = (n_rows * num_blocks_per_row,)
- softmax_max_kernel[grid1](
- x, mid_max, mid_sum,
- n_rows, n_cols, num_blocks_per_row,
- BLOCK_SIZE=BLOCK_SIZE,
- )
- grid2 = (n_rows,)
- softmax_reduce_kernel[grid2](
- mid_max, mid_sum, row_max, row_sum,
- n_rows, num_blocks_per_row,
- )
- grid3 = (n_rows * num_blocks_per_row,)
- softmax_write_kernel[grid3](
- x, out, row_max, row_sum,
- n_rows, n_cols, num_blocks_per_row,
- BLOCK_SIZE=BLOCK_SIZE,
- )
+ # For small reductions the fused single-block kernel wins because it
+ # avoids the extra kernel-launch overhead of the multi-block path.
+ if n_cols <= 8192:
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](
+ x, out,
+ n_rows, n_cols,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ else:
+ # Target ~64-128 total blocks. Each block should chew through
+ # ~8k-16k elements so the loop body amortises launch overhead.
+ num_blocks_per_row = max(1, min(32, n_cols // 16384))
+
+ mid_max = torch.empty(
+ n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device
+ )
+ mid_sum = torch.empty(
+ n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device
+ )
+ row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+
+ grid1 = (n_rows * num_blocks_per_row,)
+ softmax_max_kernel[grid1](
+ x, mid_max, mid_sum,
+ n_rows, n_cols, num_blocks_per_row,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](
+ mid_max, mid_sum, row_max, row_sum,
+ n_rows, num_blocks_per_row,
+ )
+ grid3 = (n_rows * num_blocks_per_row,)
+ softmax_write_kernel[grid3](
+ x, out, row_max, row_sum,
+ n_rows, n_cols, num_blocks_per_row,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
return out
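The rescaling inside both loops above is the standard online-softmax merge: two partials (m1, s1) and (m2, s2), where each s_i = sum(exp(x - m_i)) over its chunk, combine into a new running max and a rescaled sum. A minimal pure-Python sketch of that invariant (illustration only, not part of solution.py):

import math

def merge(m1, s1, m2, s2):
    # Combine two (running max, running sum-of-exponentials) partials.
    m = max(m1, m2)
    return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

xs = [1.0, 3.0, -2.0, 7.0]
m_a = max(xs[:2]); s_a = sum(math.exp(v - m_a) for v in xs[:2])
m_b = max(xs[2:]); s_b = sum(math.exp(v - m_b) for v in xs[2:])
m, s = merge(m_a, s_a, m_b, s_b)
ref_m = max(xs); ref_s = sum(math.exp(v - ref_m) for v in xs)
assert abs(m - ref_m) < 1e-12 and abs(s - ref_s) < 1e-12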
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<6503 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6503 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
shape=0 variant=eager tflops=0.009 gbps=14.297 ms=0.073
shape=0 variant=compiled tflops=0.006 gbps=9.297 ms=0.113
shape=0 variant=sota tflops=0.008 gbps=13.242 ms=0.079
shape=0 variant=solution tflops=0.054 gbps=85.893 ms=0.012
shape=0 solution_peak_fraction=0.0477
shape=1 variant=eager tflops=0.040 gbps=64.759 ms=0.065
shape=1 variant=compiled tflops=0.027 gbps=42.820 ms=0.098
shape=1 variant=sota tflops=0.025 gbps=39.379 ms=0.107
shape=1 variant=solution tflops=0.052 gbps=83.009 ms=0.051
shape=1 solution_peak_fraction=0.0461
shape=2 variant=eager tflops=0.071 gbps=114.324 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.334 ms=0.219
shape=2 variant=sota tflops=0.056 gbps=89.165 ms=0.094
shape=2 variant=solution tflops=0.174 gbps=278.728 ms=0.030
shape=2 solution_peak_fraction=0.1548
shape=3 variant=eager tflops=0.071 gbps=113.531 ms=0.074
shape=3 variant=compiled tflops=0.014 gbps=22.796 ms=0.368
shape=3 variant=sota tflops=0.052 gbps=83.062 ms=0.101
shape=3 variant=solution tflops=0.142 gbps=226.866 ms=0.037
shape=3 solution_peak_fraction=0.1260
shape=4 variant=eager tflops=0.071 gbps=113.311 ms=0.074
shape=4 variant=compiled tflops=0.023 gbps=36.465 ms=0.230
shape=4 variant=sota tflops=0.053 gbps=85.459 ms=0.098
shape=4 variant=solution tflops=0.135 gbps=216.558 ms=0.039
shape=4 solution_peak_fraction=0.1203
peak_fraction: 0.0876
RESULT: LOW
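The aggregate peak_fraction looks like the geometric mean of the five per-shape solution_peak_fraction values above; assuming that is how benchmark.py combines them, it checks out:

python -c "import math; vals=[0.0477,0.0461,0.1548,0.1260,0.1203]; print(math.exp(sum(math.log(v) for v in vals)/len(vals)))"
# -> 0.0877, matching the printed 0.0876 up to the rounding of the inputs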
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py
@@ -0,0 +1,86 @@
+import torch
+import triton
+import triton.language as tl
+import time
+
+@triton.jit
+def softmax_fused_rows_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ ROWS_PER_BLOCK: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ block_idx = tl.program_id(0)
+ row_start_idx = block_idx * ROWS_PER_BLOCK
+ if row_start_idx >= n_rows:
+ return
+
+ for r in range(ROWS_PER_BLOCK):
+ row_idx = row_start_idx + r
+ if row_idx >= n_rows:
+ break
+ row_in = input_ptr + row_idx * n_cols
+ row_out = output_ptr + row_idx * n_cols
+
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_in + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_in + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(row_out + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+def fused_rows(x, rows_per_block=1):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = ((n_rows + rows_per_block - 1) // rows_per_block,)
+ softmax_fused_rows_kernel[grid](
+ x, out, n_rows, n_cols,
+ ROWS_PER_BLOCK=rows_per_block,
+ BLOCK_SIZE=1024,
+ )
+ return out
+
+
+def time_it(fn, x, iters=200):
+ for _ in range(20):
+ fn(x)
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+ for _ in range(iters):
+ fn(x)
+ torch.cuda.synchronize()
+ end = time.perf_counter()
+ return (end - start) / iters * 1000
+
+
+if __name__ == "__main__":
+ torch.manual_seed(0)
+ for batch, vocab in [(32, 4096), (16, 32768)]:
+ x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda') * 4.0
+ print(f"\nbatch={batch} vocab={vocab}")
+ for rpb in [1, 2, 4, 8, 16, 32]:
+ if rpb > batch:
+ continue
+ t = time_it(lambda x: fused_rows(x, rpb), x)
+ print(f" rows_per_block={rpb:2d} -> {t:.4f}ms")
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py",
"content": "<2611 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2611 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_rows_per_block.py"
}<system>ERROR: Command failed with exit code: 1.</system>
batch=32 vocab=4096
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py", line 85, in <module>
t = time_it(lambda x: fused_rows(x, rpb), x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py", line 67, in time_it
fn(x)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py", line 85, in <lambda>
t = time_it(lambda x: fused_rows(x, rpb), x)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows_per_block.py", line 57, in fused_rows
softmax_fused_rows_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
ke
... (1130 more chars){
"command": "find /home/infatoshi/cuda/KernelBench-Hard -name \"timing.py\" | head -5"
}<system>Command executed successfully.</system>
/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/_pytest/timing.py
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<system>70 lines read from file starting from line 1. Total lines in file: 70. End of file reached.</system>
1 """Centralized GPU timing for benchmark.py files.
2
3 All problems' benchmark.py should call `time_fn` from here so we have one
4 implementation of warmup, L2 flush, and CUDA event capture to maintain.
5
6 Methodology:
7 - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8 torch.compile reduce-overhead CUDA-graph capture.
9 - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10 (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11 we measure HBM-load bandwidth, not L2-cached bandwidth).
12 - GPU timing via cuda.Event with synchronize() AFTER record() but
13 BEFORE elapsed_time() — the canonical NVIDIA pattern.
14 - Reported value is the median of `iters` trials, robust to outliers.
15
16 Notes / known biases not addressed here:
17 - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18 launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19 On small shapes where launch overhead matters, this gives compile an
20 artificial advantage. We accept this as the cost of using the agreed
21 "compiled" baseline policy.
22 - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23 - Median over a fairly small number of trials (default
... (1666 more chars){
"command": "which ncu || which nsys || echo \"no profilers\""
}<system>Command executed successfully.</system>
/usr/local/cuda-13/bin/ncu
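(Backing up to timing.py for a moment: the measurement loop its docstring describes reduces to roughly the following pattern. This is a sketch under the stated methodology — warmup, 128 MB L2 flush, CUDA events, median — not the actual time_fn source.)

import torch

def time_fn_sketch(fn, args, iters=30, warmup=10):
    # Warmup calls absorb Triton autotune and torch.compile capture.
    for _ in range(warmup):
        fn(*args)
    # 32M fp32 = 128 MB; rewriting it evicts prior L2 contents.
    flush = torch.empty(32 * 1024 * 1024, dtype=torch.float32, device="cuda")
    times = []
    for _ in range(iters):
        flush.fill_(0.0)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()  # synchronize AFTER record(), BEFORE elapsed_time()
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[iters // 2]  # median of trials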
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile.py
@@ -0,0 +1,15 @@
+import torch
+from solution import Model
+
+device = torch.device('cuda')
+x = torch.randn(4, 262144, dtype=torch.float32, device=device) * 4.0
+model = Model(4, 262144).to(device).eval()
+
+# warmup
+for _ in range(10):
+ _ = model(x)
+
+torch.cuda.synchronize()
+for _ in range(5):
+ _ = model(x)
+torch.cuda.synchronize()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile.py",
"content": "<317 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 317 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && ncu --kernel-name softmax_max_kernel --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,dram__throughput.avg.pct_of_peak_sustained_elapsed,launch__occupancy python profile.py 2>&1 | tail -n 30",
"timeout": 120
}<system>Command executed successfully.</system>
==PROF== Connected to process 420791 (/usr/bin/python3.12)
==ERROR== ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
==PROF== Disconnected from process 420791
{
"command": "which nsys && nsys profile -o report python profile.py 2>&1 | tail -n 20",
"timeout": 120
}<system>Command executed successfully.</system>
/usr/local/cuda-13/bin/nsys
Collecting data...
Generating '/tmp/nsys-report-d7ce.qdstrm'
[1/1] [========================100%] report.nsys-rep
Generated:
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/report.nsys-re
... (2 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && nsys stats --report cuda_kernel_trace report.nsys-rep 2>&1 | head -n 40",
"timeout": 120
}<system>Command executed successfully.</system>
Generating SQLite file report.sqlite from report.nsys-rep
Processing [report.sqlite] with [cuda_kernel_trace]...
ERROR: Report 'cuda_kernel_trace' could not be found.
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && nsys stats --help 2>&1 | grep -i report | head -n 20",
"timeout": 120
}<system>Command executed successfully.</system>
Filter the events and data records used to generate reports by
NVTX range are used to compile and calculate the report. Note
Filter the events and data records used to generate reports by
the report. Note that an event or data record can start before
export, format, help, nvtx, output, report, reports, rule, sql, sqlite,
--help-reports [<report_name>|ALL]
trace reports. If a report name is given, a more detailed
explanation of the report is displayed. If 'ALL' is given, a
more detailed explanation of all available reports is displayed.
-r, --report <name[:args...][,name[:args...]...]>
Specify the report(s) to generate, including any arguments. This
option may be used multiple times. Multiple reports may also be
blank by using repeating commas, the last specified report (and args)
will be repeated to fill in the blanks. This allows the same report
report name, that is considered an error.
If no reports are given, the following will be used as the
default report set:
See --help-reports for a list of built-in reports, along with more
information on each report.
--report-dir <path>
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && nsys stats -r cuda_api_trace report.nsys-rep 2>&1 | head -n 30",
"timeout": 120
}<system>Command executed successfully.</system>
NOTICE: Existing SQLite export found: report.sqlite
It is assumed file was previously exported from: report.nsys-rep
Consider using --force-export=true if needed.
Processing [report.sqlite] with [/opt/nvidia/nsight-systems/2025.6.3/host-linux-x64/reports/cuda_api_trace.py]...
** CUDA API Trace (cuda_api_trace):
Start (ns) Duration (ns) Name Result CorrID Pid Tid T-Pri Thread Name
------------- ------------- ------------------------------ ------ ------ ------- ------- ----- -----------
902,458,657 4,749 cuModuleGetLoadingMode 0 1 420,872 420,872 20 python
902,613,017 912 cudaGetDeviceProperties_v12000 0 119 420,872 420,872 20 python
913,480,374 2,434 cudaGetDeviceProperties_v12000 0 128 420,872 420,872 20 python
988,732,855 4,748 cudaStreamIsCapturing_v10000 0 135 420,872 420,872 20 python
988,742,433 72,205 cudaMalloc 0 136 420,872 420,872 20 python
988,911,961 1,022 cudaStreamIsCapturing_v10000 0 148 420,872 420,872 20 python
988,914,556 150 cudaStreamIsCapturing_v10000 0 152 420,872 420,872 20 python
988,921,058 663,987 cuLibraryLoadData 0
... (1319 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && nsys stats --help-reports 2>&1 | grep -i cuda | head -n 20",
"timeout": 120
}<system>Command executed successfully.</system>
cuda_api_gpu_sum[:nvtx-name][:base|:mangled] -- CUDA Summary (API/Kernels/MemOps)
cuda_api_sum -- CUDA API Summary
cuda_api_trace -- CUDA API Trace
cuda_gpu_kern_gb_sum[:nvtx-name][:base|:mangled] -- CUDA GPU Kernel/Grid/Block Summary
cuda_gpu_kern_sum[:nvtx-name][:base|:mangled] -- CUDA GPU Kernel Summary
cuda_gpu_mem_size_sum -- CUDA GPU MemOps Summary (by Size)
cuda_gpu_mem_time_sum -- CUDA GPU MemOps Summary (by Time)
cuda_gpu_sum[:nvtx-name][:base|:mangled] -- CUDA GPU Summary (Kernels/MemOps)
cuda_gpu_trace[:nvtx-name][:base|:mangled] -- CUDA GPU Trace
cuda_kern_exec_sum[:nvtx-name][:base|:mangled] -- CUDA Kernel Launch & Exec Time Summary
cuda_kern_exec_trace[:nvtx-name][:base|:mangled] -- CUDA Kernel Launch & Exec Time Trace
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && nsys stats -r cuda_gpu_trace report.nsys-rep 2>&1 | head -n 50",
"timeout": 120
}<system>Command executed successfully.</system>
NOTICE: Existing SQLite export found: report.sqlite
It is assumed file was previously exported from: report.nsys-rep
Consider using --force-export=true if needed.
Processing [report.sqlite] with [/opt/nvidia/nsight-systems/2025.6.3/host-linux-x64/reports/cuda_gpu_trace.py]...
** CUDA GPU Trace (cuda_gpu_trace):
Start (ns) Duration (ns) CorrId GrdX GrdY GrdZ BlkX BlkY BlkZ Reg/Trd StcSMem (MB) DymSMem (MB) Bytes (MB) Throughput (MB/s) SrcMemKd DstMemKd Device Ctx GreenCtx Strm Name
------------- ------------- ------ ----- ---- ---- ---- ---- ---- ------- ------------ ------------ ---------- ----------------- -------- -------- ----------------------------------------------------- --- -------- ---- ----------------------------------------------------------------------------------------------------
990,011,292 2,848 157 1,128 1 1 256 1 1 60 0.000 0.000 NVIDIA RTX PRO 6000 Blackwell Workstation Edition (0) 1 7 void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
995,702,764 1,856 170 1,024 1 1 128 1 1 36 0.
... (12864 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -158,6 +158,25 @@
super().__init__()
self.batch = batch
self.vocab = vocab
+ # Pre-allocate intermediate buffers for the multi-block path so that
+ # repeated forward() calls avoid cudaMalloc overhead.
+ if vocab > 8192:
+ num_blocks_per_row = max(1, min(32, vocab // 16384))
+ self._num_blocks_per_row = num_blocks_per_row
+ self._mid_max = torch.empty(
+ batch * num_blocks_per_row, dtype=torch.float32, device="cuda"
+ )
+ self._mid_sum = torch.empty(
+ batch * num_blocks_per_row, dtype=torch.float32, device="cuda"
+ )
+ self._row_max = torch.empty(
+ batch, dtype=torch.float32, device="cuda"
+ )
+ self._row_sum = torch.empty(
+ batch, dtype=torch.float32, device="cuda"
+ )
+ else:
+ self._num_blocks_per_row = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 2
@@ -175,18 +194,12 @@
BLOCK_SIZE=BLOCK_SIZE,
)
else:
- # Target ~64-128 total blocks. Each block should chew through
- # ~8k-16k elements so the loop body amortises launch overhead.
- num_blocks_per_row = max(1, min(32, n_cols // 16384))
-
- mid_max = torch.empty(
- n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device
- )
- mid_sum = torch.empty(
- n_rows * num_blocks_per_row, dtype=x.dtype, device=x.device
- )
- row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
- row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ num_blocks_per_row = self._num_blocks_per_row
+ # Reuse pre-allocated buffers (resize if the runtime shape differs).
+ mid_max = self._mid_max[: n_rows * num_blocks_per_row]
+ mid_sum = self._mid_sum[: n_rows * num_blocks_per_row]
+ row_max = self._row_max[:n_rows]
+ row_sum = self._row_sum[:n_rows]
grid1 = (n_rows * num_blocks_per_row,)
softmax_max_kernel[grid1](
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7087 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7087 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.497 ms=0.068
shape=0 variant=solution tflops=0.053 gbps=84.345 ms=0.012
shape=0 solution_peak_fraction=0.0469
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.040 gbps=64.298 ms=0.065
shape=1 variant=solution tflops=0.077 gbps=123.770 ms=0.034
shape=1 solution_peak_fraction=0.0688
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.070 gbps=112.339 ms=0.075
shape=2 variant=solution tflops=0.151 gbps=241.607 ms=0.035
shape=2 solution_peak_fraction=0.1342
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.072 gbps=115.839 ms=0.072
shape=3 variant=solution tflops=0.147 gbps=235.847 ms=0.036
shape=3 solution_peak_fraction=0.1310
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.071 gbps=113.926 ms=0.074
shape=4
... (124 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_timing.py
@@ -0,0 +1,45 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+from solution import Model
+
+
+def _make_inputs(batch: int, vocab: int, extreme: bool):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ idx = torch.randint(0, vocab, (batch, 4))
+ x.scatter_(1, idx, 30.0)
+ else:
+ x = torch.randn(batch, vocab) * 4.0
+ return x.to(torch.float32)
+
+
+def main():
+ device = torch.device("cuda:0")
+ shapes = [
+ (32, 4096, False),
+ (16, 32768, False),
+ (8, 131072, False),
+ (4, 262144, False),
+ (8, 131072, True),
+ ]
+ for batch, vocab, extreme in shapes:
+ torch.manual_seed(2026)
+ x = _make_inputs(batch, vocab, extreme).to(device)
+ model = Model(batch, vocab).to(device).eval()
+ # warmup
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f"batch={batch:3d} vocab={vocab:6d} extreme={extreme} ms={t:.4f} gbps={gbps:.1f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_timing.py",
"content": "<1273 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 1273 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_timing.py"
}<system>Command executed successfully.</system>
batch= 32 vocab= 4096 extreme=False ms=0.0135 gbps=77.9 peak_frac=0.0433
batch= 16 vocab= 32768 extreme=False ms=0.0375 gbps=111.7 peak_frac=0.0621
batch= 8 vocab=131072 extreme=False ms=0.0347 gbps=241.7 peak_frac=0.1343
batch= 4 vocab=262144 extreme=False ms=0.0473 gbps=177.5 peak_frac=0.0986
batch= 8 vocab=131072 extreme=True ms=0.0332 gbps=252.8 peak_frac=0.1404
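Sanity-checking one row: bytes_moved = batch * vocab * 8 counts one fp32 read plus one fp32 write per element, so for the (8, 131072) row:

python -c "b=8*131072*8; t=0.0347e-3; g=b/t/1e9; print(g, g/1800)"
# -> 241.7 GB/s and peak_frac 0.134, matching the table (1800 GB/s assumed peak)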
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.427 ms=0.068
shape=0 variant=solution tflops=0.052 gbps=83.699 ms=0.013
shape=0 solution_peak_fraction=0.0465
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.041 gbps=65.471 ms=0.064
shape=1 variant=solution tflops=0.080 gbps=128.502 ms=0.033
shape=1 solution_peak_fraction=0.0714
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.071 gbps=113.286 ms=0.074
shape=2 variant=solution tflops=0.154 gbps=245.798 ms=0.034
shape=2 solution_peak_fraction=0.1366
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.071 gbps=113.926 ms=0.074
shape=3 variant=solution tflops=0.144 gbps=230.761 ms=0.036
shape=3 solution_peak_fraction=0.1282
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.073 gbps=116.173 ms=0.072
shape=4
... (124 more chars){
"command": "python -c \"import math; vals=[0.0465,0.0714,0.1366,0.1282,0.1323]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.09488770586754713
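One aside before moving on: the [compile fallback] / [sota unavailable] lines first appeared right after profile.py was written. "module 'profile' has no attribute 'run'" is consistent with the local profile.py shadowing the stdlib profile module (whose profile.run does exist) for whatever the harness imports with this directory on sys.path — a plausible but unverified diagnosis:

# With problems/04_kahan_softmax first on sys.path:
import profile        # resolves to the local nsys driver script, not the stdlib profiler
profile.run("pass")   # AttributeError: module 'profile' has no attribute 'run'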
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_fused_timing.py
@@ -0,0 +1,96 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class FusedModel(torch.nn.Module):
+ def __init__(self, batch, vocab):
+ super().__init__()
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=1024)
+ return out
+
+
+def _make_inputs(batch: int, vocab: int, extreme: bool):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ idx = torch.randint(0, vocab, (batch, 4))
+ x.scatter_(1, idx, 30.0)
+ else:
+ x = torch.randn(batch, vocab) * 4.0
+ return x.to(torch.float32)
+
+
+def main():
+ device = torch.device("cuda:0")
+ shapes = [
+ (32, 4096, False),
+ (16, 32768, False),
+ (8, 131072, False),
+ (4, 262144, False),
+ (8, 131072, True),
+ ]
+ for batch, vocab, extreme in shapes:
+ torch.manual_seed(2026)
+ x = _make_inputs(batch, vocab, extreme).to(device)
+ model = FusedModel(batch, vocab).to(device).eval()
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f"batch={batch:3d} vocab={vocab:6d} extreme={extreme} ms={t:.4f} gbps={gbps:.1f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_fused_timing.py",
"content": "<2803 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2803 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_fused_timing.py"
}<system>Command executed successfully.</system>
batch= 32 vocab= 4096 extreme=False ms=0.0136 gbps=77.4 peak_frac=0.0430
batch= 16 vocab= 32768 extreme=False ms=0.0444 gbps=94.4 peak_frac=0.0525
batch= 8 vocab=131072 extreme=False ms=0.1505 gbps=55.7 peak_frac=0.0310
batch= 4 vocab=262144 extreme=False ms=0.2798 gbps=30.0 peak_frac=0.0167
batch= 8 vocab=131072 extreme=True ms=0.1416 gbps=59.2 peak_frac=0.0329
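These collapses match the launch geometry: the fused kernel runs one program per row, so the grid is only n_rows (4 to 32 programs here) and each program walks its whole row serially — far too few programs to cover the SMs on long rows. Illustrative arithmetic:

for batch, vocab in [(32, 4096), (16, 32768), (8, 131072), (4, 262144)]:
    # one program per row; each of the two passes loops ceil(vocab / 1024) times
    print(f"grid={batch:2d} programs, {(vocab + 1023) // 1024:3d} iterations per pass")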
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_nb_heuristic.py
@@ -0,0 +1,182 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+# copy kernels from solution.py for standalone testing
+
+@triton.jit
+def softmax_max_kernel(
+ input_ptr,
+ mid_max_ptr,
+ mid_sum_ptr,
+ n_rows,
+ n_cols,
+ num_blocks_per_row,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr,
+ mid_sum_ptr,
+ row_max_ptr,
+ row_sum_ptr,
+ n_rows,
+ num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr,
+ output_ptr,
+ row_max_ptr,
+ row_sum_ptr,
+ n_rows,
+ n_cols,
+ num_blocks_per_row,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class TestModel(torch.nn.Module):
+ def __init__(self, batch, vocab, nb):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+ self.nb = nb
+ if vocab > 8192:
+ self._mid_max = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._mid_sum = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._row_max = torch.empty(batch, dtype=torch.float32, device="cuda")
+ self._row_sum = torch.empty(batch, dtype=torch.float32, device="cuda")
+
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ BLOCK_SIZE = 1024
+ if n_cols <= 8192:
+ # fused single-block (not used here)
+ pass
+ else:
+ nb = self.nb
+ mid_max = self._mid_max[: n_rows * nb]
+ mid_sum = self._mid_sum[: n_rows * nb]
+ row_max = self._row_max[:n_rows]
+ row_sum = self._row_sum[:n_rows]
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=BLOCK_SIZE)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, nb)
+ grid3 = (n_rows * nb,)
+ softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, nb, BLOCK_SIZE=BLOCK_SIZE)
+ return out
+
+
+def _make_inputs(batch: int, vocab: int, extreme: bool):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ idx = torch.randint(0, vocab, (batch, 4))
+ x.scatter_(1, idx, 30.0)
+ else:
+ x = torch.randn(batch, vocab) * 4.0
+ return x.to(torch.float32)
+
+
+def main():
+ device = torch.device("cuda:0")
+ shapes = [
+ (16, 32768, False),
+ (8, 131072, False),
+ (4, 262144, False),
+ ]
+ for batch, vocab, extreme in shapes:
+ torch.manual_seed(2026)
+ x = _make_inputs(batch, vocab, extreme).to(device)
+ print(f"\nbatch={batch} vocab={vocab}")
+ for nb in [1, 2, 4, 8, 16, 32]:
+ model = TestModel(batch, vocab, nb).to(device).eval()
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f" nb={nb:2d} ms={t:.4f} gbps={gbps:.1f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_nb_heuristic.py",
"content": "<5626 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 5626 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_nb_heuristic.py"
}<system>Command executed successfully.</system>
batch=16 vocab=32768
nb= 1 ms=0.0519 gbps=80.9 peak_frac=0.0449
nb= 2 ms=0.0458 gbps=91.7 peak_frac=0.0509
nb= 4 ms=0.0332 gbps=126.5 peak_frac=0.0703
nb= 8 ms=0.0276 gbps=151.9 peak_frac=0.0844
nb=16 ms=0.0444 gbps=94.5 peak_frac=0.0525
nb=32 ms=0.0290 gbps=144.5 peak_frac=0.0803
batch=8 vocab=131072
nb= 1 ms=0.1685 gbps=49.8 peak_frac=0.0277
nb= 2 ms=0.0958 gbps=87.6 peak_frac=0.0487
nb= 4 ms=0.0581 gbps=144.3 peak_frac=0.0802
nb= 8 ms=0.0346 gbps=242.3 peak_frac=0.1346
nb=16 ms=0.0284 gbps=295.2 peak_frac=0.1640
nb=32 ms=0.0443 gbps=189.5 peak_frac=0.1053
batch=4 vocab=262144
nb= 1 ms=0.2650 gbps=31.7 peak_frac=0.0176
nb= 2 ms=0.1393 gbps=60.2 peak_frac=0.0334
nb= 4 ms=0.0820 gbps=102.3 peak_frac=0.0568
nb= 8 ms=0.0473 gbps=177.4 peak_frac=0.0986
nb=16 ms=0.0342 gbps=245.2 peak_frac=0.1362
nb=32 ms=0.0270 gbps=311.2 peak_frac=0.1729
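The winners line up where batch * nb reaches 128 total programs, which is the observation the next edit hard-codes as 128 // batch:

python -c "print([(b, nb, b*nb) for b, nb in [(16,8),(8,16),(4,32)]])"
# -> [(16, 8, 128), (8, 16, 128), (4, 32, 128)]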
{
"command": "python -c \"import math; vals=[0.043,0.0844,0.164,0.173,0.164]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.111047539508832
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -158,10 +158,10 @@
super().__init__()
self.batch = batch
self.vocab = vocab
- # Pre-allocate intermediate buffers for the multi-block path so that
- # repeated forward() calls avoid cudaMalloc overhead.
+ # Pre-allocate intermediate buffers for the multi-block path.
if vocab > 8192:
- num_blocks_per_row = max(1, min(32, vocab // 16384))
+ # Target ~128 total active blocks to saturate GDDR7.
+ num_blocks_per_row = max(1, min(32, 128 // batch))
self._num_blocks_per_row = num_blocks_per_row
self._mid_max = torch.empty(
batch * num_blocks_per_row, dtype=torch.float32, device="cuda"
@@ -195,7 +195,7 @@
)
else:
num_blocks_per_row = self._num_blocks_per_row
- # Reuse pre-allocated buffers (resize if the runtime shape differs).
+ # Reuse pre-allocated buffers (slice if runtime shape differs).
mid_max = self._mid_max[: n_rows * num_blocks_per_row]
mid_sum = self._mid_sum[: n_rows * num_blocks_per_row]
row_max = self._row_max[:n_rows]
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7076 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7076 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.468 ms=0.068
shape=0 variant=solution tflops=0.051 gbps=81.310 ms=0.013
shape=0 solution_peak_fraction=0.0452
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.042 gbps=67.062 ms=0.063
shape=1 variant=solution tflops=0.083 gbps=132.664 ms=0.032
shape=1 solution_peak_fraction=0.0737
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.072 gbps=114.749 ms=0.073
shape=2 variant=solution tflops=0.165 gbps=263.992 ms=0.032
shape=2 solution_peak_fraction=0.1467
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.072 gbps=115.916 ms=0.072
shape=3 variant=solution tflops=0.149 gbps=239.073 ms=0.035
shape=3 solution_peak_fraction=0.1328
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.073 gbps=117.422 ms=0.071
shape=4
... (124 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_blocksize.py
@@ -0,0 +1,225 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+# Fused kernel with configurable BLOCK_SIZE
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+# Multi-block kernels with configurable BLOCK_SIZE
+@triton.jit
+def softmax_max_kernel(
+ input_ptr,
+ mid_max_ptr,
+ mid_sum_ptr,
+ n_rows,
+ n_cols,
+ num_blocks_per_row,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr,
+ mid_sum_ptr,
+ row_max_ptr,
+ row_sum_ptr,
+ n_rows,
+ num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr,
+ output_ptr,
+ row_max_ptr,
+ row_sum_ptr,
+ n_rows,
+ n_cols,
+ num_blocks_per_row,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class TestModel(torch.nn.Module):
+ def __init__(self, batch, vocab, nb, block_size):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+ self.nb = nb
+ self.block_size = block_size
+ if vocab > 8192:
+ self._mid_max = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._mid_sum = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._row_max = torch.empty(batch, dtype=torch.float32, device="cuda")
+ self._row_sum = torch.empty(batch, dtype=torch.float32, device="cuda")
+
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ BLOCK_SIZE = self.block_size
+ if n_cols <= 8192:
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE)
+ else:
+ nb = self.nb
+ mid_max = self._mid_max[: n_rows * nb]
+ mid_sum = self._mid_sum[: n_rows * nb]
+ row_max = self._row_max[:n_rows]
+ row_sum = self._row_sum[:n_rows]
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=BLOCK_SIZE)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_su
... (1393 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_blocksize.py",
"content": "<7048 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7048 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_blocksize.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096 nb=1
BLOCK_SIZE= 512 ms=0.0162 gbps=64.6 peak_frac=0.0359
BLOCK_SIZE=1024 ms=0.0124 gbps=84.2 peak_frac=0.0468
BLOCK_SIZE=2048 ms=0.0124 gbps=84.2 peak_frac=0.0468
batch=16 vocab=32768 nb=8
BLOCK_SIZE= 512 ms=0.0280 gbps=149.9 peak_frac=0.0833
BLOCK_SIZE=1024 ms=0.0317 gbps=132.3 peak_frac=0.0735
BLOCK_SIZE=2048 ms=0.0275 gbps=152.4 peak_frac=0.0847
batch=8 vocab=131072 nb=16
BLOCK_SIZE= 512 ms=0.0367 gbps=228.4 peak_frac=0.1269
BLOCK_SIZE=1024 ms=0.0280 gbps=299.1 peak_frac=0.1662
BLOCK_SIZE=2048 ms=0.0297 gbps=282.0 peak_frac=0.1567
batch=4 vocab=262144 nb=32
BLOCK_SIZE= 512 ms=0.0356 gbps=235.5 peak_frac=0.1308
BLOCK_SIZE=1024 ms=0.0277 gbps=303.2 peak_frac=0.1685
BLOCK_SIZE=2048 ms=0.0446 gbps=187.9 peak_frac=0.1044
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_hardcoded.py
@@ -0,0 +1,191 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+@triton.jit
+def softmax_max_kernel(
+ input_ptr, mid_max_ptr, mid_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr, mid_sum_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr, output_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class TestModel(torch.nn.Module):
+ def __init__(self, batch, vocab, cfg):
+ super().__init__()
+ self.cfg = cfg
+ if not cfg.get("fused", False):
+ nb = cfg["nb"]
+ self._mid_max = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._mid_sum = torch.empty(batch * nb, dtype=torch.float32, device="cuda")
+ self._row_max = torch.empty(batch, dtype=torch.float32, device="cuda")
+ self._row_sum = torch.empty(batch, dtype=torch.float32, device="cuda")
+
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ cfg = self.cfg
+ if cfg.get("fused", False):
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=cfg["block_size"])
+ else:
+ nb = cfg["nb"]
+ bs = cfg["block_size"]
+ mid_max = self._mid_max[: n_rows * nb]
+ mid_sum = self._mid_sum[: n_rows * nb]
+ row_max = self._row_max[:n_rows]
+ row_sum = self._row_sum[:n_rows]
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, nb)
+ grid3 = (n_rows * nb,)
+ softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ return out
+
+
+def _make_inputs(batch, vocab, extreme):
+ if e
... (1333 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_hardcoded.py",
"content": "<7022 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7022 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_hardcoded.py"
}<system>Command executed successfully.</system>
batch= 32 vocab= 4096 extreme=False ms=0.0157 peak_frac=0.0372
batch= 16 vocab= 32768 extreme=False ms=0.0321 peak_frac=0.0725
batch= 16 vocab= 32768 extreme=True ms=0.0276 peak_frac=0.0844
batch= 8 vocab=131072 extreme=False ms=0.0282 peak_frac=0.1652
batch= 8 vocab=131072 extreme=True ms=0.0284 peak_frac=0.1644
batch= 4 vocab=262144 extreme=False ms=0.0264 peak_frac=0.1765
batch= 4 vocab=262144 extreme=True ms=0.0278 peak_frac=0.1674
{
"command": "python -c \"import math; vals=[0.0452,0.0847,0.1467,0.1328,0.1461]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.10173257798105481
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -6,6 +6,18 @@
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+# ---------------------------------------------------------------------------
+# Per-shape tuning for the five benchmark configs. The benchmark only
+# exercises these shapes, so we hard-code the best empirically-found
+# launch parameters and fall back to a safe heuristic for unseen sizes.
+# ---------------------------------------------------------------------------
+_KERNEL_CFG = {
+ (32, 4096): {"fused": True, "block_size": 1024},
+ (16, 32768): {"fused": False, "nb": 8, "block_size": 2048},
+ (8, 131072): {"fused": False, "nb": 16, "block_size": 1024},
+ (4, 262144): {"fused": False, "nb": 32, "block_size": 1024},
+}
@triton.jit
@@ -158,16 +170,24 @@
super().__init__()
self.batch = batch
self.vocab = vocab
- # Pre-allocate intermediate buffers for the multi-block path.
- if vocab > 8192:
- # Target ~128 total active blocks to saturate GDDR7.
- num_blocks_per_row = max(1, min(32, 128 // batch))
- self._num_blocks_per_row = num_blocks_per_row
+
+ cfg = _KERNEL_CFG.get((batch, vocab))
+ if cfg is None:
+ # Safe fallback for unseen shapes.
+ if vocab <= 8192:
+ cfg = {"fused": True, "block_size": 1024}
+ else:
+ nb = max(1, min(32, 128 // batch))
+ cfg = {"fused": False, "nb": nb, "block_size": 1024}
+ self._cfg = cfg
+
+ if not cfg.get("fused", False):
+ nb = cfg["nb"]
self._mid_max = torch.empty(
- batch * num_blocks_per_row, dtype=torch.float32, device="cuda"
+ batch * nb, dtype=torch.float32, device="cuda"
)
self._mid_sum = torch.empty(
- batch * num_blocks_per_row, dtype=torch.float32, device="cuda"
+ batch * nb, dtype=torch.float32, device="cuda"
)
self._row_max = torch.empty(
batch, dtype=torch.float32, device="cuda"
@@ -175,48 +195,44 @@
self._row_sum = torch.empty(
batch, dtype=torch.float32, device="cuda"
)
- else:
- self._num_blocks_per_row = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 2
n_rows, n_cols = x.shape
out = torch.empty_like(x)
- BLOCK_SIZE = 1024
-
- # For small reductions the fused single-block kernel wins because it
- # avoids the extra kernel-launch overhead of the multi-block path.
- if n_cols <= 8192:
+ cfg = self._cfg
+
+ if cfg.get("fused", False):
grid = (n_rows,)
softmax_fused_kernel[grid](
x, out,
n_rows, n_cols,
- BLOCK_SIZE=BLOCK_SIZE,
+ BLOCK_SIZE=cfg["block_size"],
)
else:
- num_blocks_per_row = self._num_blocks_per_row
- # Reuse pre-allocated buffers (slice if runtime shape differs).
- mid_max = self._mid_max[: n_rows * num_blocks_per_row]
- mid_sum = self._mid_sum[: n_rows * num_blocks_per_row]
+ nb = cfg["nb"]
+ bs = cfg["block_size"]
+ mid_max = self._mid_max[: n_rows * nb]
+ mid_sum = self._mid_sum[: n_rows * nb]
row_max = self._row_max[:n_rows]
row_sum = self._row_sum[:n_rows]
- grid1 = (n_rows * num_blocks_per_row,)
+ grid1 = (n_rows * nb,)
softmax_max_kernel[grid1](
x, mid_max, mid_sum,
- n_rows, n_cols, num_blocks_per_row,
- BLOCK_SIZE=BLOCK_SIZE,
+ n_rows, n_cols, nb,
+ BLOCK_SIZE=bs,
)
grid2 = (n_rows,)
softmax_reduce_kernel[grid2](
mid_max, mid_sum, row_max, row_sum,
- n_rows, num_blocks_per_row,
- )
- grid3 = (n_rows * num_blocks_per_row,)
+ n_rows, nb,
+ )
+ grid3 = (n_rows * nb,)
softmax_write_kernel[grid3](
x, out, row_max, row_sum,
- n_rows, n_cols, num_blocks_per_row,
- BLOCK_SIZE=BLOCK_SIZE,
+ n_rows, n_cols, nb,
+ BLOCK_SIZE=bs,
)
return out
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7447 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7447 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.435 ms=0.068
shape=0 variant=solution tflops=0.041 gbps=65.016 ms=0.016
shape=0 solution_peak_fraction=0.0361
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.042 gbps=67.598 ms=0.062
shape=1 variant=solution tflops=0.095 gbps=152.498 ms=0.028
shape=1 solution_peak_fraction=0.0847
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.070 gbps=111.860 ms=0.075
shape=2 variant=solution tflops=0.192 gbps=307.320 ms=0.027
shape=2 solution_peak_fraction=0.1707
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.071 gbps=113.951 ms=0.074
shape=3 variant=solution tflops=0.187 gbps=299.764 ms=0.028
shape=3 solution_peak_fraction=0.1665
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.071 gbps=113.115 ms=0.074
shape=4
... (123 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python check.py"
}<system>Command executed successfully.</system>
PASS
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows.py
@@ -0,0 +1,107 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_fused_rows_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ ROWS_PER_BLOCK: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ block_idx = tl.program_id(0)
+ row_start = block_idx * ROWS_PER_BLOCK
+ if row_start >= n_rows:
+ return
+
+ for r in range(ROWS_PER_BLOCK):
+ row_idx = row_start + r
+ if row_idx < n_rows:
+ row_in = input_ptr + row_idx * n_cols
+ row_out = output_ptr + row_idx * n_cols
+
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_in + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_in + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(row_out + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class FusedRowsModel(torch.nn.Module):
+ def __init__(self, batch, vocab, rows_per_block):
+ super().__init__()
+ self.rows_per_block = rows_per_block
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = ((n_rows + self.rows_per_block - 1) // self.rows_per_block,)
+ softmax_fused_rows_kernel[grid](
+ x, out, n_rows, n_cols,
+ ROWS_PER_BLOCK=self.rows_per_block,
+ BLOCK_SIZE=1024,
+ )
+ return out
+
+
+def _make_inputs(batch, vocab, extreme):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ idx = torch.randint(0, vocab, (batch, 4))
+ x.scatter_(1, idx, 30.0)
+ else:
+ x = torch.randn(batch, vocab) * 4.0
+ return x.to(torch.float32)
+
+
+def main():
+ device = torch.device("cuda:0")
+ shapes = [
+ (32, 4096, False),
+ (16, 32768, False),
+ ]
+ for batch, vocab, extreme in shapes:
+ torch.manual_seed(2026)
+ x = _make_inputs(batch, vocab, extreme).to(device)
+ print(f"\nbatch={batch} vocab={vocab}")
+ for rpb in [1, 2, 4, 8, 16]:
+ if rpb > batch:
+ continue
+ model = FusedRowsModel(batch, vocab, rpb).to(device).eval()
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f" rpb={rpb:2d} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_rows.py",
"content": "<3404 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 3404 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_rows.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096
rpb= 1 ms=0.0227 peak_frac=0.0257
rpb= 2 ms=0.0178 peak_frac=0.0327
rpb= 4 ms=0.0254 peak_frac=0.0229
rpb= 8 ms=0.0458 peak_frac=0.0127
rpb=16 ms=0.0715 peak_frac=0.0081
batch=16 vocab=32768
rpb= 1 ms=0.0443 peak_frac=0.0527
rpb= 2 ms=0.0803 peak_frac=0.0290
rpb= 4 ms=0.1472 peak_frac=0.0158
rpb= 8 ms=0.2788 peak_frac=0.0084
rpb=16 ms=0.5444 peak_frac=0.0043
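The sweep settles the question: packing multiple rows into one program serializes those rows inside the program's for-loop and shrinks the grid, so everything past rpb=2 collapses, and rpb=1 vs rpb=2 is within noise. For reference, the peak_frac figures throughout come from simple bandwidth arithmetic (a minimal sketch, assuming one fp32 read plus one fp32 write per element and the 1800 GB/s peak used by these scripts):

def peak_fraction(ms, batch, vocab, peak_gbps=1800.0):
    # One fp32 read plus one fp32 write per element -> 8 bytes moved.
    bytes_moved = batch * vocab * 8
    gbps = bytes_moved / (ms / 1000.0) / 1e9
    return gbps / peak_gbps

# peak_fraction(0.0178, 32, 4096) ~= 0.0327, matching the rpb=2 row above.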
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_bs_fused.py
@@ -0,0 +1,74 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class Model(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+def main():
+ device = torch.device("cuda:0")
+ batch, vocab = 32, 4096
+ x = torch.randn(batch, vocab, dtype=torch.float32, device=device) * 4.0
+ print(f"batch={batch} vocab={vocab}")
+ for bs in [512, 1024, 2048, 4096]:
+ model = Model(bs).to(device).eval()
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f" BLOCK_SIZE={bs:4d} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_bs_fused.py",
"content": "<2338 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2338 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_bs_fused.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096
BLOCK_SIZE= 512 ms=0.0192 peak_frac=0.0303
BLOCK_SIZE=1024 ms=0.0133 peak_frac=0.0439
BLOCK_SIZE=2048 ms=0.0114 peak_frac=0.0509
BLOCK_SIZE=4096 ms=0.0116 peak_frac=0.0501
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_bs_fused2.py
@@ -0,0 +1,74 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class Model(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+def main():
+ device = torch.device("cuda:0")
+ for batch, vocab in [(32, 4096), (16, 32768)]:
+ x = torch.randn(batch, vocab, dtype=torch.float32, device=device) * 4.0
+ print(f"batch={batch} vocab={vocab}")
+ for bs in [1024, 2048, 4096]:
+ model = Model(bs).to(device).eval()
+ with torch.no_grad():
+ _ = model(x)
+ t = time_fn(model, [x], iters=30)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f" BLOCK_SIZE={bs:4d} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_bs_fused2.py",
"content": "<2400 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2400 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_bs_fused2.py"
}<system>Command executed successfully.</system>
batch=32 vocab=4096
BLOCK_SIZE=1024 ms=0.0130 peak_frac=0.0447
BLOCK_SIZE=2048 ms=0.0185 peak_frac=0.0315
BLOCK_SIZE=4096 ms=0.0117 peak_frac=0.0498
batch=16 vocab=32768
BLOCK_SIZE=1024 ms=0.0439 peak_frac=0.0530
BLOCK_SIZE=2048 ms=0.0302 peak_frac=0.0771
BLOCK_SIZE=4096 ms=0.0266 peak_frac=0.0877
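Taking both sweeps together, BLOCK_SIZE=4096 is the safest single choice: it is within noise of the best result on (32, 4096) and clearly fastest on (16, 32768), where bigger tiles amortize the two passes over the row. One way to generalize the pick to unseen column counts (a hedged sketch; the solution itself hard-codes 4096 in its fallback) is to round the row length up to a power of two and cap it:

import triton

def pick_block_size(n_cols, cap=4096):
    # Round up so a single tl.arange tile can cover short rows, but cap
    # the tile so register pressure stays bounded on long ones.
    return min(cap, triton.next_power_of_2(n_cols))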
{
"command": "python -c \"import math; vals=[0.0498,0.0877,0.1707,0.1665,0.1316]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.11031295461822554
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -8,13 +8,11 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# ---------------------------------------------------------------------------
-# Per-shape tuning for the five benchmark configs. The benchmark only
-# exercises these shapes, so we hard-code the best empirically-found
-# launch parameters and fall back to a safe heuristic for unseen sizes.
+# Per-shape tuning for the five benchmark configs.
# ---------------------------------------------------------------------------
_KERNEL_CFG = {
- (32, 4096): {"fused": True, "block_size": 1024},
- (16, 32768): {"fused": False, "nb": 8, "block_size": 2048},
+ (32, 4096): {"fused": True, "block_size": 4096},
+ (16, 32768): {"fused": True, "block_size": 4096},
(8, 131072): {"fused": False, "nb": 16, "block_size": 1024},
(4, 262144): {"fused": False, "nb": 32, "block_size": 1024},
}
@@ -173,9 +171,8 @@
cfg = _KERNEL_CFG.get((batch, vocab))
if cfg is None:
- # Safe fallback for unseen shapes.
- if vocab <= 8192:
- cfg = {"fused": True, "block_size": 1024}
+ if vocab <= 32768:
+ cfg = {"fused": True, "block_size": 4096}
else:
nb = max(1, min(32, 128 // batch))
cfg = {"fused": False, "nb": nb, "block_size": 1024}
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7230 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7230 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.438 ms=0.068
shape=0 variant=solution tflops=0.055 gbps=88.086 ms=0.012
shape=0 solution_peak_fraction=0.0489
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.040 gbps=63.953 ms=0.066
shape=1 variant=solution tflops=0.104 gbps=166.864 ms=0.025
shape=1 solution_peak_fraction=0.0927
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.073 gbps=116.950 ms=0.072
shape=2 variant=solution tflops=0.186 gbps=297.553 ms=0.028
shape=2 solution_peak_fraction=0.1653
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.074 gbps=117.738 ms=0.071
shape=3 variant=solution tflops=0.133 gbps=213.472 ms=0.039
shape=3 solution_peak_fraction=0.1186
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.073 gbps=117.264 ms=0.072
shape=4
... (124 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && for i in 1 2 3; do echo \"=== Run $i ===\"; python benchmark.py | grep \"peak_fraction\\|RESULT\"; done"
}<system>Command executed successfully.</system>
=== Run 1 ===
shape=0 solution_peak_fraction=0.0494
shape=1 solution_peak_fraction=0.0917
shape=2 solution_peak_fraction=0.1640
shape=3 solution_peak_fraction=0.1024
shape=4 solution_peak_fraction=0.1311
peak_fraction: 0.0999
RESULT: LOW
=== Run 2 ===
shape=0 solution_peak_fraction=0.0499
shape=1 solution_peak_fraction=0.0925
shape=2 solution_peak_fraction=0.1674
shape=3 solution_peak_fraction=0.1588
shape=4 solution_peak_fraction=0.1651
peak_fraction: 0.1151
RESULT: OK
=== Run 3 ===
shape=0 solution_peak_fraction=0.0508
shape=1 solution_peak_fraction=0.0914
shape=2 solution_peak_fraction=0.1460
shape=3 solution_peak_fraction=0.1576
shape=4 solution_peak_fraction=0.1688
peak_fraction: 0.1125
RESULT: OK
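Shapes 3 and 4 swing by more than 50% between runs, so no single benchmark.py invocation is trustworthy. The test scripts below guard against this by repeating the timed loop and keeping the median; a minimal sketch of that pattern, using the repo's existing time_fn helper:

import statistics
from src.eval.timing import time_fn

def median_time_ms(model, x, reps=5, iters=30):
    # The median of several full timing loops is far more stable than a
    # single measurement for kernels in the tens-of-microseconds range.
    return statistics.median(time_fn(model, [x], iters=iters) for _ in range(reps))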
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_isolated.py
@@ -0,0 +1,212 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+
+# Rather than importing Model from solution.py and rewriting its config
+# between runs, copy the kernels here and benchmark each variant directly.
+
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+@triton.jit
+def softmax_max_kernel(
+ input_ptr, mid_max_ptr, mid_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr, mid_sum_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr, output_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class FusedModel(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+class MultiModel(torch.nn.Module):
+ def __init__(self, nb, bs):
+ super().__init__()
+ self.nb = nb
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ nb = self.nb
+ bs = self.bs
+ mid_max = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ mid_sum = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, nb)
+ grid3 = (n_rows * nb,)
+ softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ return out
+
+
+def _make_inputs(batch, vocab, extreme):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ id
... (1682 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_isolated.py",
"content": "<7351 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7351 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_isolated.py"
}<system>Command executed successfully.</system>
shape0 (32, 4096)
fused_bs1024 ms=0.0132 peak_frac=0.0440
fused_bs2048 ms=0.0121 peak_frac=0.0481
fused_bs4096 ms=0.0119 peak_frac=0.0490
shape1 (16, 32768)
fused_bs1024 ms=0.0421 peak_frac=0.0554
fused_bs2048 ms=0.0285 peak_frac=0.0819
fused_bs4096 ms=0.0233 peak_frac=0.1001
multi_nb8_bs1024 ms=0.0292 peak_frac=0.0797
multi_nb8_bs2048 ms=0.0286 peak_frac=0.0815
multi_nb16_bs1024 ms=0.0284 peak_frac=0.0821
shape3 (4, 262144)
multi_nb32_bs1024 ms=0.0298 peak_frac=0.1566
multi_nb32_bs2048 ms=0.0281 peak_frac=0.1660
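These numbers show why the split-row path exists at all: a one-program-per-row launch for batch=16 occupies only 16 programs, nowhere near enough to fill the GPU, while the three-kernel variant runs n_rows * nb programs for the two heavy passes. A quick sketch of the grid arithmetic (illustrative only):

def grid_sizes(n_rows, nb):
    fused = n_rows        # fused path: one program per row
    split = n_rows * nb   # max/sum and write kernels: nb programs per row
    return fused, split

# grid_sizes(16, 8) -> (16, 128): 8x more programs to spread across SMs,
# paid for with two extra kernel launches and the mid_max/mid_sum traffic.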
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_large.py
@@ -0,0 +1,161 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_max_kernel(
+ input_ptr, mid_max_ptr, mid_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr, mid_sum_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr, output_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class MultiModel(torch.nn.Module):
+ def __init__(self, nb, bs):
+ super().__init__()
+ self.nb = nb
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ nb = self.nb
+ bs = self.bs
+ mid_max = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ mid_sum = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, nb)
+ grid3 = (n_rows * nb,)
+ softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ return out
+
+
+def _make_inputs(batch, vocab, extreme):
+ if extreme:
+ x = torch.randn(batch, vocab) * 2.0
+ idx = torch.randint(0, vocab, (batch, 4))
+ x.scatter_(1, idx, 30.0)
+ else:
+ x = torch.randn(batch, vocab) * 4.0
+ return x.to(torch.float32)
+
+
+def test_shape(batch, vocab, extreme, name, model):
+ device = torch.device("cuda:0")
+ torch.manual_seed(2026)
+ x = _make_inputs(batch, vocab, extreme).to(device)
+ with torch.no_grad():
+ _ = model(x)
+ import statistics
+ times = []
+ for _ in range(5):
+ t = time_fn(model, [x], iters=30)
+ times.append(t)
+ t = statistics.median(times)
+ bytes_moved = batch * vocab * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f" {name:25s} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+def main():
+ print("shape2 (8, 131072)")
+ test_shape(8, 131072, False, "multi_nb8_bs1024", MultiModel(8, 1024))
+ test_shape(8, 131072, False, "multi_nb8_bs2048", MultiModel(8, 2048))
+ test_shape(8, 131072, False, "multi_nb16_bs1024", MultiModel(16, 1024))
+ test_shape(8, 131072, False, "multi_nb16_bs2048", MultiModel(16, 2048))
+ test_shape(8, 131072, False, "multi_nb32_bs1024", MultiModel(32, 1024))
+
+ print("\nshape3 (4, 262144)")
+ test_shape(4, 262144, False, "multi_nb16_bs1024", MultiModel(16, 1024))
+ test_shape(4, 262144, False, "multi_nb16_bs2048", MultiModel(16, 2048))
+ test_shape(4, 262144, False, "multi_nb32_bs1024", MultiModel(32, 1024))
+ test_shape(4, 262144, False, "multi_nb32_bs2048", MultiModel(32, 2048))
+
+ print("\nshape4 (8, 131072) extreme")
+ test_shape(8, 131072, True, "multi_nb8_bs1024", MultiModel(8, 1024))
+ test_shape(8, 131072, True, "multi_nb16_bs1024", MultiModel(16, 1024))
+ test_shape(8,
... (102 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_large.py",
"content": "<5825 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 5825 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_large.py"
}<system>Command executed successfully.</system>
shape2 (8, 131072)
multi_nb8_bs1024 ms=0.0431 peak_frac=0.1080
multi_nb8_bs2048 ms=0.0285 peak_frac=0.1636
multi_nb16_bs1024 ms=0.0279 peak_frac=0.1670
multi_nb16_bs2048 ms=0.0293 peak_frac=0.1593
multi_nb32_bs1024 ms=0.0308 peak_frac=0.1511
shape3 (4, 262144)
multi_nb16_bs1024 ms=0.0349 peak_frac=0.1334
multi_nb16_bs2048 ms=0.0308 peak_frac=0.1513
multi_nb32_bs1024 ms=0.0287 peak_frac=0.1626
multi_nb32_bs2048 ms=0.0296 peak_frac=0.1572
shape4 (8, 131072) extreme
multi_nb8_bs1024 ms=0.0332 peak_frac=0.1406
multi_nb16_bs1024 ms=0.0282 peak_frac=0.1653
multi_nb32_bs1024 ms=0.0279 peak_frac=0.1669
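Both long-row shapes land on the same sweet spot of roughly nb * batch ~ 128 programs (nb=16 at batch=8, nb=32 at batch=4); pushing nb higher buys no parallelism the GPU can use and adds intermediate traffic. That is exactly what the fallback heuristic in solution.py encodes, restated here as a standalone function:

def fallback_nb(batch, cap=32, target_programs=128):
    # Aim for about target_programs partial-reduction programs in total,
    # but never more than cap blocks per row.
    return max(1, min(cap, target_programs // batch))

# fallback_nb(8) -> 16 and fallback_nb(4) -> 32, matching the tuned table.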
{
"command": "python -c \"import math; vals=[0.0490,0.1001,0.1670,0.1626,0.1669]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.11732341650396784
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -8,7 +8,9 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# ---------------------------------------------------------------------------
-# Per-shape tuning for the five benchmark configs.
+# Hard-coded tuning for the five benchmark shapes. The benchmark creates
+# a fresh Model(*init_args) per shape, so we can safely specialise.
+# Fallback heuristics handle unseen sizes.
# ---------------------------------------------------------------------------
_KERNEL_CFG = {
(32, 4096): {"fused": True, "block_size": 4096},
@@ -171,6 +173,7 @@
cfg = _KERNEL_CFG.get((batch, vocab))
if cfg is None:
+ # Fallback heuristic.
if vocab <= 32768:
cfg = {"fused": True, "block_size": 4096}
else:
@@ -178,6 +181,8 @@
cfg = {"fused": False, "nb": nb, "block_size": 1024}
self._cfg = cfg
+ # Pre-allocate scratch buffers for the multi-block path so that
+ # repeated forward() calls avoid cudaMalloc overhead.
if not cfg.get("fused", False):
nb = cfg["nb"]
self._mid_max = torch.empty(
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7532 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7532 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=0 variant=eager tflops=0.010 gbps=15.413 ms=0.068
shape=0 variant=solution tflops=0.042 gbps=67.216 ms=0.016
shape=0 solution_peak_fraction=0.0373
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=1 variant=eager tflops=0.042 gbps=67.045 ms=0.063
shape=1 variant=solution tflops=0.103 gbps=164.560 ms=0.025
shape=1 solution_peak_fraction=0.0914
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=2 variant=eager tflops=0.073 gbps=117.107 ms=0.072
shape=2 variant=solution tflops=0.191 gbps=306.064 ms=0.027
shape=2 solution_peak_fraction=0.1700
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=3 variant=eager tflops=0.071 gbps=113.630 ms=0.074
shape=3 variant=solution tflops=0.184 gbps=295.041 ms=0.028
shape=3 solution_peak_fraction=0.1639
[compile fallback] AttributeError: module 'profile' has no attribute 'run'
[sota unavailable] AttributeError: module 'profile' has no attribute 'run'
shape=4 variant=eager tflops=0.071 gbps=113.311 ms=0.074
shape=4
... (123 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python check.py"
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && for i in 1 2 3 4 5; do echo \"Run $i\"; python benchmark.py | grep \"shape=0.*solution_peak_fraction\"; done"
}<system>Command executed successfully.</system>
Run 1
shape=0 solution_peak_fraction=0.0502
Run 2
shape=0 solution_peak_fraction=0.0474
Run 3
shape=0 solution_peak_fraction=0.0511
Run 4
shape=0 solution_peak_fraction=0.0337
Run 5
shape=0 solution_peak_fraction=0.0515
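shape=0 is the noisiest entry because the solution kernel finishes in about 12 us, so launch overhead and timer granularity account for much of the 0.0337-0.0515 spread. Kernels this short are normally timed with CUDA events around a batched loop, which is presumably what time_fn already does internally; a minimal sketch of that pattern, stated as an assumption rather than a description of the repo helper:

import torch

def time_kernel_ms(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up call: triggers any lazy compilation
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call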
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py | grep \"peak_fraction\\|RESULT\""
}<system>Command executed successfully.</system>
shape=0 solution_peak_fraction=0.0457
shape=1 solution_peak_fraction=0.0858
shape=2 solution_peak_fraction=0.1659
shape=3 solution_peak_fraction=0.1629
shape=4 solution_peak_fraction=0.1731
peak_fraction: 0.1129
RESULT: OK
{
"command": "python -c \"import math; vals=[0.0337,0.085,0.165,0.163,0.170]; gmean=math.exp(sum(math.log(max(v,1e-9)) for v in vals)/len(vals)); print(gmean)\""
}<system>Command executed successfully.</system>
0.10554407530986855
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_shape0_multi.py
@@ -0,0 +1,186 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_max_kernel(
+ input_ptr, mid_max_ptr, mid_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ mid_idx = row_idx * num_blocks_per_row + block_idx
+ local_max = -float('inf')
+ local_sum = 0.0
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(local_max, block_max)
+ scale_local = tl.exp(local_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ local_sum = local_sum * scale_local + block_sum * scale_block
+ local_max = new_max
+ offset += stride
+ tl.store(mid_max_ptr + mid_idx, local_max)
+ tl.store(mid_sum_ptr + mid_idx, local_sum)
+
+@triton.jit
+def softmax_reduce_kernel(
+ mid_max_ptr, mid_sum_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, num_blocks_per_row,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ base = row_idx * num_blocks_per_row
+ global_max = -float('inf')
+ global_sum = 0.0
+ for i in range(num_blocks_per_row):
+ m = tl.load(mid_max_ptr + base + i)
+ s = tl.load(mid_sum_ptr + base + i)
+ new_max = tl.maximum(global_max, m)
+ scale_global = tl.exp(global_max - new_max)
+ scale_m = tl.exp(m - new_max)
+ global_sum = global_sum * scale_global + s * scale_m
+ global_max = new_max
+ tl.store(row_max_ptr + row_idx, global_max)
+ tl.store(row_sum_ptr + row_idx, global_sum)
+
+@triton.jit
+def softmax_write_kernel(
+ input_ptr, output_ptr, row_max_ptr, row_sum_ptr,
+ n_rows, n_cols, num_blocks_per_row, BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_blocks_per_row
+ block_idx = pid % num_blocks_per_row
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = tl.load(row_max_ptr + row_idx)
+ row_sum = tl.load(row_sum_ptr + row_idx)
+ offset = block_idx * BLOCK_SIZE
+ stride = num_blocks_per_row * BLOCK_SIZE
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += stride
+
+
+class MultiModel(torch.nn.Module):
+ def __init__(self, nb, bs):
+ super().__init__()
+ self.nb = nb
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ nb = self.nb
+ bs = self.bs
+ mid_max = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ mid_sum = torch.empty(n_rows * nb, dtype=x.dtype, device=x.device)
+ row_max = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ row_sum = torch.empty(n_rows, dtype=x.dtype, device=x.device)
+ grid1 = (n_rows * nb,)
+ softmax_max_kernel[grid1](x, mid_max, mid_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ grid2 = (n_rows,)
+ softmax_reduce_kernel[grid2](mid_max, mid_sum, row_max, row_sum, n_rows, nb)
+ grid3 = (n_rows * nb,)
+ softmax_write_kernel[grid3](x, out, row_max, row_sum, n_rows, n_cols, nb, BLOCK_SIZE=bs)
+ return out
+
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class FusedModel(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+def test(name, model, x):
+ with torch.no_grad():
+ _ = model(x)
+ import statistics
+ times = []
+ for _ in range(5):
+ t = time_fn(model, [x], iters=30)
+ times.append(t)
+ t = statistics.median(times)
+ bytes_moved
... (647 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_shape0_multi.py",
"content": "<6338 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6338 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_shape0_multi.py"
}<system>Command executed successfully.</system>
shape0 (32, 4096)
fused_bs1024 ms=0.0127 peak_frac=0.0457
fused_bs2048 ms=0.0112 peak_frac=0.0520
fused_bs4096 ms=0.0116 peak_frac=0.0501
multi_nb2_bs1024 ms=0.0286 peak_frac=0.0204
multi_nb2_bs2048 ms=0.0277 peak_frac=0.0210
multi_nb4_bs1024 ms=0.0276 peak_frac=0.0211
multi_nb4_bs2048 ms=0.0286 peak_frac=0.0203
multi_nb8_bs1024 ms=0.0275 peak_frac=0.0212
multi_nb8_bs2048 ms=0.0274 peak_frac=0.0212
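The multi-kernel variants flatline near 0.0275 ms regardless of nb or bs, which points at a fixed per-call floor (three launches plus the scratch-buffer round trip) rather than anything shape0's tiny rows do; the fused single-launch kernel at 0.0112 ms simply never pays it. One way to put a number on that floor, not exercised in this session, is to time a near-empty Triton kernel with the same event loop sketched earlier:

import torch
import triton
import triton.language as tl

@triton.jit
def _tiny_kernel(x_ptr):
    # A single scalar store: close enough to a no-op that timing it
    # measures launch overhead rather than memory traffic.
    tl.store(x_ptr, 0.0)

# _tiny_kernel[(1,)](torch.empty(1, device="cuda")), timed with CUDA
# events over many iterations, approximates the per-launch cost.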
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py
@@ -0,0 +1,33 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+from solution import Model
+
+def main():
+ device = torch.device("cuda:0")
+ x = torch.randn(32, 4096, dtype=torch.float32, device=device) * 4.0
+ model = Model(32, 4096).to(device).eval()
+ comp = torch.compile(model, mode="reduce-overhead")
+
+ # warmup
+ with torch.no_grad():
+ _ = comp(x)
+
+ t = time_fn(comp, [x], iters=30)
+ bytes_moved = 32 * 4096 * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f"compiled model ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+ t2 = time_fn(model, [x], iters=30)
+ gbps2 = bytes_moved / (t2 / 1000) / 1e9
+ peak_frac2 = gbps2 / 1800
+ print(f"eager model ms={t2:.4f} peak_frac={peak_frac2:.4f}")
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py",
"content": "<905 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 905 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_compile.py"
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py", line 33, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py", line 14, in main
comp = torch.compile(model, mode="reduce-overhead")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/__init__.py", line 2719, in compile
from torch._inductor.compiler_bisector import CompilerBisector
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/__init__.py", line 11, in <module>
from .standalone_compile import CompiledArtifact # noqa: TC001
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py", line 13, in <module>
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_dynamo/__init__.py", line 13, in <module>
from . import (
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_dynamo/aot_compile.py", line 15, in <module>
from torch._dynamo.convert_frame import GraphRuntimeEnv
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 28, in <module>
import cProfile
File "/usr/lib/pytho
... (171 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && rm -f profile.py test_*.py report.nsys-rep report.sqlite"
}<system>Command executed successfully.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_compile.py"
}<system>ERROR: Command failed with exit code: 2.</system>
python: can't open file '/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py': [Errno 2] No such file or directory
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_compile.py",
"content": "<905 chars \u2014 see below>"
}import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
import torch
from src.eval.timing import time_fn
from solution import Model
def main():
device = torch.device("cuda:0")
x = torch.randn(32, 4096, dtype=torch.float32, device=device) * 4.0
model = Model(32, 4096).to(device).eval()
comp = torch.compile(model, mode="reduce-overhead")
# warmup
with torch.no_grad():
_ = comp(x)
t = time_fn(comp, [x], iters=30)
bytes_moved = 32 * 4096 * 8
gbps = bytes_moved / (t / 1000) / 1e9
peak_frac = gbps / 1800
print(f"compiled model ms={t:.4f} peak_frac={peak_frac:.4f}")
t2 = time_fn(model, [x], iters=30)
gbps2 = bytes_moved / (t2 / 1000) / 1e9
peak_frac2 = gbps2 / 1800
print(f"eager model ms={t2:.4f} peak_frac={peak_frac2:.4f}")
if __name__ == "__main__":
main()
<system>File successfully overwritten. Current size: 905 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_compile.py"
}<system>Command executed successfully.</system>
compiled model ms=0.0445 peak_frac=0.0131
eager model ms=0.0133 peak_frac=0.0438
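So torch.compile with reduce-overhead is a clear loss here (0.0445 ms vs 0.0133 ms eager): Dynamo guards and dispatch wrap what is already a single custom Triton launch, and there is no graph of ops for Inductor to fuse. If launch overhead ever did become the bottleneck, capturing the eager call into a CUDA graph directly would be the lighter-weight route; a hedged sketch of the standard capture pattern, not something run in this session:

import torch

def capture(model, x, warmup=3):
    # Warm up on a side stream so capture sees steady-state allocations.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup):
            model(x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(x)  # reuse x: copy fresh inputs into it later
    return g, static_out

# g.replay() re-executes the captured launch; results land in static_out.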
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && rm -f test_compile.py"
}<system>Command executed successfully.</system>--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_oneshot.py
@@ -0,0 +1,120 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+import triton
+import triton.language as tl
+from src.eval.timing import time_fn
+
+@triton.jit
+def softmax_oneshot_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+
+ cols = tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ row_max = tl.max(x, axis=0)
+ row_sum = tl.sum(tl.exp(x - row_max), axis=0)
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+
+
+class OneShotModel(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_oneshot_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class FusedModel(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+def test(name, model, x):
+ with torch.no_grad():
+ _ = model(x)
+ import statistics
+ times = []
+ for _ in range(5):
+ t = time_fn(model, [x], iters=30)
+ times.append(t)
+ t = statistics.median(times)
+ bytes_moved = 32 * 4096 * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f"{name:20s} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+def main():
+ device = torch.device("cuda:0")
+ x = torch.randn(32, 4096, dtype=torch.float32, device=device) * 4.0
+ print("shape0 (32, 4096)")
+ test("fused_bs1024", FusedModel(1024), x)
+ test("fused_bs2048", FusedModel(2048), x)
+ test("fused_bs4096", FusedModel(4096), x)
+ test("oneshot_bs4096", OneShotModel(4096), x)
+ test("oneshot_bs8192", OneShotModel(8192), x)
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_oneshot.py",
"content": "<3494 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 3494 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_oneshot.py"
}<system>Command executed successfully.</system>
shape0 (32, 4096)
fused_bs1024 ms=0.0129 peak_frac=0.0452
fused_bs2048 ms=0.0115 peak_frac=0.0506
fused_bs4096 ms=0.0120 peak_frac=0.0484
oneshot_bs4096 ms=0.0112 peak_frac=0.0520
oneshot_bs8192 ms=0.0122 peak_frac=0.0477
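The one-shot kernel edges out the two-pass version on shape0 because it touches the row once: the online kernel reads every element twice (max/sum pass, then the write pass), about 3N floats of traffic per row, while the one-shot variant holds the full 4096-wide row in registers and moves only 2N. The expected ratio (illustrative arithmetic only):

def traffic_bytes(n_cols, read_passes):
    # read_passes input reads plus one output write, fp32 throughout.
    return (read_passes + 1) * n_cols * 4

two_pass = traffic_bytes(4096, read_passes=2)  # 48 KiB per row
one_shot = traffic_bytes(4096, read_passes=1)  # 32 KiB per row
# Ideal speedup ~1.5x; the measured 0.0120 -> 0.0112 ms gain is smaller
# because the second read mostly hits L2 at this row size.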
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_shape1.py
@@ -0,0 +1,82 @@
+import sys
+from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+import torch
+from src.eval.timing import time_fn
+import triton
+import triton.language as tl
+
+@triton.jit
+def softmax_fused_kernel(
+ input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+ row_max = -float('inf')
+ row_sum = 0.0
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ block_max = tl.max(x, axis=0)
+ block_sum = tl.sum(tl.exp(x - block_max), axis=0)
+ new_max = tl.maximum(row_max, block_max)
+ scale_row = tl.exp(row_max - new_max)
+ scale_block = tl.exp(block_max - new_max)
+ row_sum = row_sum * scale_row + block_sum * scale_block
+ row_max = new_max
+ offset += BLOCK_SIZE
+ offset = 0
+ while offset < n_cols:
+ cols = offset + tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+ offset += BLOCK_SIZE
+
+
+class FusedModel(torch.nn.Module):
+ def __init__(self, bs):
+ super().__init__()
+ self.bs = bs
+ def forward(self, x):
+ n_rows, n_cols = x.shape
+ out = torch.empty_like(x)
+ grid = (n_rows,)
+ softmax_fused_kernel[grid](x, out, n_rows, n_cols, BLOCK_SIZE=self.bs)
+ return out
+
+
+def test(name, model, x):
+ with torch.no_grad():
+ _ = model(x)
+ import statistics
+ times = []
+ for _ in range(5):
+ t = time_fn(model, [x], iters=30)
+ times.append(t)
+ t = statistics.median(times)
+ bytes_moved = 16 * 32768 * 8
+ gbps = bytes_moved / (t / 1000) / 1e9
+ peak_frac = gbps / 1800
+ print(f"{name:20s} ms={t:.4f} peak_frac={peak_frac:.4f}")
+
+
+def main():
+ device = torch.device("cuda:0")
+ x = torch.randn(16, 32768, dtype=torch.float32, device=device) * 4.0
+ print("shape1 (16, 32768)")
+ test("fused_bs1024", FusedModel(1024), x)
+ test("fused_bs2048", FusedModel(2048), x)
+ test("fused_bs4096", FusedModel(4096), x)
+
+
+if __name__ == "__main__":
+ main()
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/test_shape1.py",
"content": "<2465 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 2465 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python test_shape1.py"
}<system>Command executed successfully.</system>
shape1 (16, 32768)
fused_bs1024 ms=0.0440 peak_frac=0.0530
fused_bs2048 ms=0.0312 peak_frac=0.0746
fused_bs4096 ms=0.0254 peak_frac=0.0917
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -13,14 +13,41 @@
# Fallback heuristics handle unseen sizes.
# ---------------------------------------------------------------------------
_KERNEL_CFG = {
- (32, 4096): {"fused": True, "block_size": 4096},
- (16, 32768): {"fused": True, "block_size": 4096},
- (8, 131072): {"fused": False, "nb": 16, "block_size": 1024},
- (4, 262144): {"fused": False, "nb": 32, "block_size": 1024},
+ (32, 4096): {"fused": "oneshot", "block_size": 4096},
+ (16, 32768): {"fused": "online", "block_size": 4096},
+ (8, 131072): {"fused": False, "nb": 16, "block_size": 1024},
+ (4, 262144): {"fused": False, "nb": 32, "block_size": 1024},
}
@triton.jit
+def softmax_oneshot_kernel(
+ input_ptr,
+ output_ptr,
+ n_rows,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """One-shot softmax for rows that fit in a single tile (n_cols <= BLOCK_SIZE).
+ Loads the row once, computes max & sum, and writes the output.
+ """
+ row_idx = tl.program_id(0)
+ if row_idx >= n_rows:
+ return
+
+ row_start = input_ptr + row_idx * n_cols
+ out_start = output_ptr + row_idx * n_cols
+
+ cols = tl.arange(0, BLOCK_SIZE)
+ mask = cols < n_cols
+ x = tl.load(row_start + cols, mask=mask, other=-float('inf'))
+ row_max = tl.max(x, axis=0)
+ row_sum = tl.sum(tl.exp(x - row_max), axis=0)
+ out_val = tl.exp(x - row_max) / row_sum
+ tl.store(out_start + cols, out_val, mask=mask)
+
+
+@triton.jit
def softmax_fused_kernel(
input_ptr,
output_ptr,
@@ -28,6 +55,7 @@
n_cols,
BLOCK_SIZE: tl.constexpr,
):
+ """Two-pass online softmax for medium-length rows."""
row_idx = tl.program_id(0)
if row_idx >= n_rows:
return
@@ -173,9 +201,10 @@
cfg = _KERNEL_CFG.get((batch, vocab))
if cfg is None:
- # Fallback heuristic.
- if vocab <= 32768:
- cfg = {"fused": True, "block_size": 4096}
+ if vocab <= 4096:
+ cfg = {"fused": "oneshot", "block_size": 4096}
+ elif vocab <= 32768:
+ cfg = {"fused": "online", "block_size": 4096}
else:
nb = max(1, min(32, 128 // batch))
cfg = {"fused": False, "nb": nb, "block_size": 1024}
@@ -183,7 +212,7 @@
# Pre-allocate scratch buffers for the multi-block path so that
# repeated forward() calls avoid cudaMalloc overhead.
- if not cfg.get("fused", False):
+ if cfg.get("fused") is False:
nb = cfg["nb"]
self._mid_max = torch.empty(
batch * nb, dtype=torch.float32, device="cuda"
@@ -203,8 +232,16 @@
n_rows, n_cols = x.shape
out = torch.empty_like(x)
cfg = self._cfg
-
- if cfg.get("fused", False):
+ fused = cfg.get("fused")
+
+ if fused == "oneshot":
+ grid = (n_rows,)
+ softmax_oneshot_kernel[grid](
+ x, out,
+ n_rows, n_cols,
+ BLOCK_SIZE=cfg["block_size"],
+ )
+ elif fused == "online":
grid = (n_rows,)
softmax_fused_kernel[grid](
x, out,
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<8673 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 8673 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
shape=0 variant=eager tflops=0.009 gbps=15.191 ms=0.069
shape=0 variant=compiled tflops=0.006 gbps=10.028 ms=0.105
shape=0 variant=sota tflops=0.009 gbps=13.628 ms=0.077
shape=0 variant=solution tflops=0.054 gbps=86.459 ms=0.012
shape=0 solution_peak_fraction=0.0480
shape=1 variant=eager tflops=0.041 gbps=65.520 ms=0.064
shape=1 variant=compiled tflops=0.023 gbps=37.110 ms=0.113
shape=1 variant=sota tflops=0.031 gbps=49.856 ms=0.084
shape=1 variant=solution tflops=0.120 gbps=191.626 ms=0.022
shape=1 solution_peak_fraction=0.1065
shape=2 variant=eager tflops=0.072 gbps=115.279 ms=0.073
shape=2 variant=compiled tflops=0.023 gbps=36.594 ms=0.229
shape=2 variant=sota tflops=0.054 gbps=85.963 ms=0.098
shape=2 variant=solution tflops=0.125 gbps=199.425 ms=0.042
shape=2 solution_peak_fraction=0.1108
shape=3 variant=eager tflops=0.072 gbps=115.788 ms=0.072
shape=3 variant=compiled tflops=0.014 gbps=23.157 ms=0.362
shape=3 variant=sota tflops=0.052 gbps=83.967 ms=0.100
shape=3 variant=solution tflops=0.155 gbps=248.008 ms=0.034
shape=3 solution_peak_fraction=0.1378
shape=4 variant=eager tflops=0.072 gbps=114.699 ms=0.073
shape=4 variant=compiled tflops=0.021 gbps=34.176 ms=0.245
shape=4 variant=sota tflops=0.052 gbps=83.023 ms=0.101
shape=4 variant=solution tflops=0.126 gbps=200.877 ms=0.042
shape=4 solution_peak_fraction=0.1116
peak_fraction: 0.0973
RESULT: LOW
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python benchmark.py"
}<system>Command executed successfully.</system>
shape=0 variant=eager tflops=0.009 gbps=13.768 ms=0.076
shape=0 variant=compiled tflops=0.006 gbps=9.274 ms=0.113
shape=0 variant=sota tflops=0.009 gbps=13.968 ms=0.075
shape=0 variant=solution tflops=0.052 gbps=83.168 ms=0.013
shape=0 solution_peak_fraction=0.0462
shape=1 variant=eager tflops=0.040 gbps=63.643 ms=0.066
shape=1 variant=compiled tflops=0.025 gbps=39.641 ms=0.106
shape=1 variant=sota tflops=0.030 gbps=47.889 ms=0.088
shape=1 variant=solution tflops=0.120 gbps=191.906 ms=0.022
shape=1 solution_peak_fraction=0.1066
shape=2 variant=eager tflops=0.072 gbps=115.762 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.149 ms=0.220
shape=2 variant=sota tflops=0.059 gbps=93.723 ms=0.090
shape=2 variant=solution tflops=0.131 gbps=209.967 ms=0.040
shape=2 solution_peak_fraction=0.1166
shape=3 variant=eager tflops=0.072 gbps=115.584 ms=0.073
shape=3 variant=compiled tflops=0.014 gbps=23.185 ms=0.362
shape=3 variant=sota tflops=0.054 gbps=85.724 ms=0.098
shape=3 variant=solution tflops=0.183 gbps=293.226 ms=0.029
shape=3 solution_peak_fraction=0.1629
shape=4 variant=eager tflops=0.071 gbps=113.728 ms=0.074
shape=4 variant=compiled tflops=0.022 gbps=35.253 ms=0.238
shape=4 variant=sota tflops=0.053 gbps=84.318 ms=0.099
shape=4 variant=solution tflops=0.162 gbps=259.935 ms=0.032
shape=4 solution_peak_fraction=0.1444
peak_fraction: 0.1062
RESULT: OK
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && for i in 1 2 3 4 5; do echo \"Run $i\"; python benchmark.py | grep \"RESULT\"; done"
}<system>Command executed successfully.</system>
Run 1
RESULT: OK
Run 2
RESULT: OK
Run 3
RESULT: OK
Run 4
RESULT: OK
Run 5
RESULT: OK