"""Numerically tight softmax with Kahan compensated summation.
Single-block path for smaller vocabs (V ≤ 32768) where one-kernel-launch
simplicity wins. Multi-block map-reduce for large vocabs where parallelism
across blocks is needed to saturate GPU bandwidth.
Map: each block computes local (max, Kahan-sum-of-exp) for its chunk.
Reduce: GPU-side Kahan combination of per-block results (num_warps=1).
Norm: each block normalizes its chunk using global (max, sum).
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Single-block kernel
# ---------------------------------------------------------------------------
@triton.jit
def _single_kernel(x_ptr, y_ptr, V: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
    """Online softmax + Kahan sum, then normalize. One block per row.

    Pass 1 walks the row in BLOCK_VOCAB-wide chunks, maintaining a running
    row max ``m`` and a Kahan-compensated running sum ``s`` (compensation
    term ``c``) of exp(x - m). Pass 2 re-reads the row and writes
    exp(x - m) / s.

    Addressing uses ``pid * V + offset``, so the input is assumed to be
    row-major contiguous with row stride V.
    """
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_VOCAB)
    m = float('-inf')  # running row max
    s = 0.0            # Kahan-compensated sum of exp(x - m)
    c = 0.0            # Kahan compensation (accumulated rounding error)
    for start in range(0, V, BLOCK_VOCAB):
        mask = (start + offs) < V
        addr = pid * V + start + offs
        x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
        m_local = tl.max(tl.where(mask, x, float('-inf')))
        m_new = tl.maximum(m, m_local)
        # Rescale the running sum (and its compensation) onto the new max.
        # On the first chunk m == -inf, so scale == 0 and the zero-valued
        # accumulators are unaffected.
        scale = tl.exp(m - m_new)
        s = s * scale
        c = c * scale
        exp_vals = tl.exp(x - m_new)
        chunk_sum = tl.sum(tl.where(mask, exp_vals, 0.0))
        # Kahan step: fold chunk_sum into s, recovering the rounding error
        # of the addition into c. Statement order is load-bearing here.
        t = s
        y_k = chunk_sum - c
        s = t + y_k
        c = (s - t) - y_k
        m = m_new
    inv_sum = 1.0 / s  # one divide per row; multiply per element below
    # Pass 2: normalize. `other=0.0` is safe because masked lanes are
    # never stored.
    for start in range(0, V, BLOCK_VOCAB):
        mask = (start + offs) < V
        addr = pid * V + start + offs
        x = tl.load(x_ptr + addr, mask=mask, other=0.0)
        y = tl.exp(x - m) * inv_sum
        tl.store(y_ptr + addr, y, mask=mask)
# ---------------------------------------------------------------------------
# Multi-block kernels
# ---------------------------------------------------------------------------
@triton.jit
def _map_kernel(x_ptr, max_buf, sum_buf,
                elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
    """Map stage: per-block local (max, Kahan sum of exp) over one chunk.

    Each program owns the contiguous span
    ``[pid * elems_per_block, (pid + 1) * elems_per_block)`` of the
    flattened input; the launcher guarantees the row length is a multiple
    of elems_per_block, so spans never straddle a row boundary. The local
    (max, sum) pair lands at index ``pid`` of max_buf / sum_buf for the
    reduce stage.
    """
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_VOCAB)
    base = pid * elems_per_block
    m = float('-inf')  # running local max
    s = 0.0            # Kahan-compensated sum of exp(x - m)
    c = 0.0            # Kahan compensation term
    for start in range(0, elems_per_block, BLOCK_VOCAB):
        mask = (start + offs) < elems_per_block
        addr = base + start + offs
        x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
        m_local = tl.max(tl.where(mask, x, float('-inf')))
        m_new = tl.maximum(m, m_local)
        # Rescale accumulators onto the new running max (scale == 0 on the
        # first chunk when m is still -inf, leaving the zeros untouched).
        scale = tl.exp(m - m_new)
        s = s * scale
        c = c * scale
        exp_vals = tl.exp(x - m_new)
        chunk_sum = tl.sum(tl.where(mask, exp_vals, 0.0))
        # Kahan step; statement order is load-bearing.
        t = s
        y_k = chunk_sum - c
        s = t + y_k
        c = (s - t) - y_k
        m = m_new
    tl.store(max_buf + pid, m)
    tl.store(sum_buf + pid, s)
@triton.jit
def _reduce_kernel(max_in, sum_in, max_out, sum_out,
                   blocks_per_row: tl.constexpr):
    """Reduce stage: combine per-block (max, sum) pairs into one per row.

    Sequential scalar loop over the ``blocks_per_row`` partials of row
    ``rid`` (the caller launches this with num_warps=1; the per-row block
    count is small). Uses the same online rescale + Kahan update as the
    map stage so the combined sum stays tight.
    """
    rid = tl.program_id(0)
    base = rid * blocks_per_row
    m = float('-inf')  # running global max for this row
    s = 0.0            # Kahan-compensated combined sum
    c = 0.0            # Kahan compensation term
    for k in range(blocks_per_row):
        mk = tl.load(max_in + (base + k))
        sk = tl.load(sum_in + (base + k))
        m_new = tl.maximum(m, mk)
        scale = tl.exp(m - m_new)
        s = s * scale
        c = c * scale
        # Rescale the incoming partial only when its local max is strictly
        # below the new global max (when equal the factor would be
        # exp(0) == 1, so the multiply is skipped).
        if mk < m_new:
            sk = sk * tl.exp(mk - m_new)
        # Kahan step; statement order is load-bearing.
        t = s
        y_k = sk - c
        s = t + y_k
        c = (s - t) - y_k
        m = m_new
    tl.store(max_out + rid, m)
    tl.store(sum_out + rid, s)
@triton.jit
def _norm_kernel(x_ptr, y_ptr, max_per_row, sum_per_row, row_ids,
                 elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
    """Norm stage: each block normalizes its chunk with the global stats.

    ``row_ids[pid]`` maps each block back to its owning row so the right
    (max, sum) pair is read. Chunk addressing mirrors the map stage:
    block ``pid`` owns flat span starting at ``pid * elems_per_block``.
    """
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_VOCAB)
    row = tl.load(row_ids + pid)
    m_global = tl.load(max_per_row + row)
    s_global = tl.load(sum_per_row + row)
    inv_sum = 1.0 / s_global  # one divide, then a multiply per element
    base = pid * elems_per_block
    for start in range(0, elems_per_block, BLOCK_VOCAB):
        mask = (start + offs) < elems_per_block
        addr = base + start + offs
        # `other=0.0` is safe: masked lanes are never stored.
        x = tl.load(x_ptr + addr, mask=mask, other=0.0)
        y = tl.exp(x - m_global) * inv_sum
        tl.store(y_ptr + addr, y, mask=mask)
# ---------------------------------------------------------------------------
# Tuning
# ---------------------------------------------------------------------------
def _single_config(vocab: int) -> tuple:
"""Return (BLOCK_VOCAB, num_warps) for single-block path."""
if vocab <= 4096:
return (1024, 4)
elif vocab <= 16384:
return (2048, 8)
else:
return (4096, 8)
def _multi_config(vocab: int) -> tuple:
"""Return (BLOCK_VOCAB, num_warps) for multi-block path."""
if vocab <= 32768:
return (2048, 8)
elif vocab <= 131072:
return (4096, 8)
else:
return (8192, 8)
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
def __init__(self, batch: int, vocab: int):
super().__init__()
self.batch = batch
self.vocab = vocab
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, V = x.shape
device = x.device
# Single-block path for small vocabs or high-batch cases
if V <= 32768:
y = torch.empty_like(x)
BLOCK_VOCAB, num_warps = _single_config(V)
_single_kernel[(B,)](
x, y,
V=V,
BLOCK_VOCAB=BLOCK_VOCAB,
num_warps=num_warps,
)
return y
# ---- Multi-block path for large vocabs (> 32768) ----
BLOCK_VOCAB, num_warps = _multi_config(V)
# Choose blocks_per_row as a power-of-2 dividing V
target = max(1, min(64 // max(B, 1), V // 4096))
blocks_per_row = triton.next_power_of_2(target)
while V % blocks_per_row != 0 and blocks_per_row > 1:
blocks_per_row //= 2
elems_per_block = V // blocks_per_row
total_blocks = B * blocks_per_row
row_ids = torch.arange(B, device=device, dtype=torch.int32).repeat_interleave(blocks_per_row)
# 1) Map
max_buf = torch.empty(total_blocks, dtype=torch.float32, device=device)
sum_buf = torch.empty(total_blocks, dtype=torch.float32, device=device)
_map_kernel[(total_blocks,)](
x, max_buf, sum_buf,
elems_per_block=elems_per_block,
BLOCK_VOCAB=BLOCK_VOCAB,
num_warps=num_warps,
)
# 2) Reduce
global_max = torch.empty(B, dtype=torch.float32, device=device)
global_sum = torch.empty(B, dtype=torch.float32, device=device)
_reduce_kernel[(B,)](
max_buf, sum_buf, global_max, global_sum,
blocks_per_row=blocks_per_row,
num_warps=1,
)
# 3) Norm
y = torch.empty_like(x)
_norm_kernel[(total_blocks,)](
x, y, global_max, global_sum, row_ids,
elems_per_block=elems_per_block,
BLOCK_VOCAB=BLOCK_VOCAB,
num_warps=num_warps,
)
return y
# Default shape
BATCH = 8
VOCAB = 32768


def get_inputs():
    """Return the forward-pass inputs: one fp32 logits tensor, scaled x4."""
    return [4.0 * torch.randn(BATCH, VOCAB, dtype=torch.float32)]


def get_init_inputs():
    """Return the positional constructor args for Model."""
    return [BATCH, VOCAB]
"""Numerically tight softmax with Kahan compensated summation.
Single-block path for smaller vocabs (V ≤ 32768) where one-kernel-launch
simplicity wins. Multi-block map-reduce for large vocabs where parallelism
across blocks is needed to saturate GPU bandwidth.
Map: each block computes local (max, Kahan-sum-of-exp) for its chunk.
Reduce: GPU-side Kahan combination of per-block results (num_warps=1).
Norm: each block normalizes its chunk using global (max, sum).
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
shape=0 variant=eager tflops=0.010 gbps=15.537 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.264 ms=0.102
shape=0 variant=sota tflops=0.034 gbps=54.207 ms=0.019
shape=0 variant=solution tflops=0.057 gbps=90.770 ms=0.012
shape=0 solution_peak_fraction=0.0504
shape=1 variant=eager tflops=0.041 gbps=66.215 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.611 ms=0.098
shape=1 variant=sota tflops=0.130 gbps=208.713 ms=0.020
shape=1 variant=solution tflops=0.140 gbps=224.055 ms=0.019
shape=1 solution_peak_fraction=0.1245
shape=2 variant=eager tflops=0.072 gbps=115.635 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.395 ms=0.218
shape=2 variant=sota tflops=0.109 gbps=173.778 ms=0.048
shape=2 variant=solution tflops=0.138 gbps=220.474 ms=0.038
shape=2 solution_peak_fraction=0.1225
shape=3 variant=eager tflops=0.074 gbps=117.632 ms=0.071
shape=3 variant=compiled tflops=0.015 gbps=23.251 ms=0.361
shape=3 variant=sota tflops=0.099 gbps=157.776 ms=0.053
shape=3 variant=solution tflops=0.131 gbps=209.046 ms=0.040
shape=3 solution_peak_fraction=0.1161
shape=4 variant=eager tflops=0.072 gbps=115.152 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.792 ms=0.228
shape=4 variant=sota tflops=0.100 gbps=159.746 ms=0.053
shape=4 variant=solution tflops=0.133 gbps=213.385 ms=0.039
shape=4 solution_peak_fraction=0.1185
peak_fraction: 0.1011
RESULT: OK
PASS
{
"problem": "04_kahan_softmax",
"harness": "opencode",
"model": "deepseek/deepseek-v4-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1011,
"elapsed_seconds": 1847,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 42869, "output_tokens": 17712, "cache_read_tokens": 2034304, "cache_creation_tokens": 0, "reasoning_tokens": 45516, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py</path>
<type>file</type>
<content>
1: """Naive softmax over the last dim, computed in fp64 for ground-truth.
2:
3: The reference deliberately runs in float64 so that fp16 / fp32 accumulation
4: drift in agent solutions is exposed by the tight tolerance in problem.yaml.
5: The agent's job is to produce an fp32 softmax whose values match this
6: double-precision reference within atol=rtol=1e-5 — this requires either
7: fp32 accumulation or compensated (Kahan) summation when vocab is large.
8: """
9: import torch
10: import torch.nn as nn
11:
12: OP_TYPE = "softmax"
13: SUPPORTED_PRECISIONS = ["fp32"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15:
16:
17: class Model(nn.Module):
18: """y = softmax(x, dim=-1) computed in fp64 then returned as fp32.
19:
20: No learned parameters — softmax is parameter-free. We still expose an
21: empty state_dict so the harness's strict load_state_dict matches.
22: """
23:
24: def __init__(self, batch: int, vocab: int):
25: super().__init__()
26: self.batch = batch
27: self.vocab = vocab
28:
29: def forward(self, x: torch.Tensor) -> torch.Tensor:
30: # Promote to fp64 for the ground-truth pathway. Even with double
31: # precision we still subtract the row-max for stability.
32: x64 = x.to(torch.float64)
33: m = x64.amax(dim=-1, keepdim=True)
34: e = torch.exp
... (592 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Kahan-corrected softmax.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3
4: seeds each, with the tight (1e-5) fp32 tolerance from problem.yaml. Also
5: rejects forbidden ops via grep.
6: """
7: import re
8: import sys
9: from pathlib import Path
10:
11: import torch
12: import yaml
13:
14: # Make the repo's src/ importable
15: REPO_ROOT = Path(__file__).resolve().parents[2]
16: sys.path.insert(0, str(REPO_ROOT))
17:
18: from src.eval.correctness import check_correctness # noqa: E402
19:
20:
21: def _make_inputs(batch: int, vocab: int, extreme: bool, seed: int) -> torch.Tensor:
22: g = torch.Generator().manual_seed(seed)
23: if extreme:
24: # Adversarial: most logits are mild but a handful per row are huge.
25: # If the kernel forgets to subtract the row-max before exp, this
26: # row overflows fp32 and produces NaN/Inf. If it accumulates in
27: # fp16, the long tail of small exp() values is lost beneath the
28: # tolerance threshold.
29: x = torch.randn(batch, vocab, generator=g) * 2.0
30: # Spike: 4 very large positive logits per row.
31: idx = torch.randint(0, vocab, (batch, 4), generator=g)
32: x.scatter_(1, idx, 30.0)
33: else:
34: x = torch.randn(batch, vocab, generator=g) * 4.0
35: return x.to(torc
... (3464 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Kahan-corrected softmax.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if
4: available), and the agent's solution. Reports achieved TFLOPS, GB/s, and
5: peak_fraction. Softmax is memory-bound, so the score is GB/s / peak_dram.
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def _make_inputs(batch: int, vocab: int, extreme: bool) -> torch.Tensor:
31: if extreme:
32: x = torch.randn(batch, vocab) * 2.0
33: idx = torch.randint(0, vocab, (batch, 4))
34: x.scatter_(1, idx, 30.0)
35: else:
36: x = torch.randn(batch, vocab) * 4.0
37: return x.to(torch.float32)
38:
39:
40: def main():
41:
... (3706 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml</path>
<type>file</type>
<content>
1: name: 04_kahan_softmax
2: display_name: "Kahan-corrected Softmax"
3: precision: fp32
4: regime: memory # softmax is bandwidth-bound: 2 passes over the input tensor
5:
6: # Softmax FLOPs: per-element exp + 2 reductions + divide. Roughly 5 flops/elt.
7: flops_formula: "5 * batch * vocab"
8:
9: # Bytes moved: read x once, write y once. Both fp32.
10: bytes_formula: "batch * vocab * 4 + batch * vocab * 4"
11:
12: hardware: [RTX_PRO_6000]
13: peak_tflops_key: fp32
14: peak_bandwidth_key: dram
15:
16: # TIGHTER than default (fp32 default is 1e-4). This problem exists
17: # specifically to test whether the agent uses compensated summation, so
18: # we squeeze the tolerance to 1e-5 — naive fp16 sum across 256K elements
19: # drifts past this; fp32 accumulation passes; Kahan/fp32 always passes.
20: tolerance:
21: "torch.float32": {"atol": 1.0e-5, "rtol": 1.0e-5}
22:
23: # Forbidden ops — block the obvious "just call the library" cheats. The
24: # agent must implement softmax themselves with explicit (compensated)
25: # summation logic.
26: forbidden:
27: - "torch.nn.functional.softmax"
28: - "torch.softmax"
29: - "F.softmax"
30: - "liger_kernel.softmax"
31: - "liger_kernel.transformers.softmax"
32: - ".softmax("
33:
34: sota:
35: name: "Liger-Kernel Softmax (Triton)"
36: url: "https://github.com/linkedin/Liger-Kernel"
37: function: "liger_
... (225 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Kahan-corrected softmax.
2:
3: The point of this problem is numerical accuracy on long reductions. Shapes
4: mix typical LLM vocab sizes with deliberately adversarial regimes:
5:
6: - small vocab (sanity check; naive fp32 should pass)
7: - Llama3 vocab 128K (real-world, where fp16 accumulation starts to drift)
8: - 256K (DeepSeek-V3 / Gemma-3 class vocab; naive fp16 sum DOES drift past
9: the 1e-5 tolerance — this row is what proves Kahan was needed)
10: - extreme-logit edge case (large positive logits stress max-subtract +
11: summation; if the implementation accidentally exps before subtracting
12: max, this row overflows)
13:
14: The 'extreme' flag is read by check.py to switch input generation to a
15: distribution that produces a few very large logits per row.
16: """
17:
18: SHAPES = [
19: {"batch": 32, "vocab": 4096, "extreme": False}, # sanity
20: {"batch": 16, "vocab": 32768, "extreme": False}, # GPT-2 class
21: {"batch": 8, "vocab": 131072, "extreme": False}, # Llama3 vocab
22: {"batch": 4, "vocab": 262144, "extreme": False}, # 256K — Kahan needed
23: {"batch": 8, "vocab": 131072, "extreme": True}, # extreme logits edge
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for last-dim softmax.
2:
3: Preference order:
4: 1. liger-kernel's Triton softmax (LigerSoftmaxFunction) — fused, fp32-acc
5: 2. torch.compile(torch.softmax) — Inductor generates a Triton kernel that
6: is competitive on bandwidth-bound reductions
7:
8: Agents are FORBIDDEN from importing either of these in solution.py (see
9: problem.yaml.forbidden). This file is only the benchmark's reference line.
10: """
11: from __future__ import annotations
12:
13: import torch
14:
15: _compiled_softmax = None
16:
17:
18: def _liger_softmax(x: torch.Tensor) -> torch.Tensor | None:
19: try:
20: from liger_kernel.ops.softmax import LigerSoftmaxFunction
21: return LigerSoftmaxFunction.apply(x)
22: except Exception:
23: return None
24:
25:
26: def _compiled(x: torch.Tensor) -> torch.Tensor:
27: global _compiled_softmax
28: if _compiled_softmax is None:
29: _compiled_softmax = torch.compile(
30: lambda t: torch.softmax(t, dim=-1),
31: mode="reduce-overhead",
32: )
33: return _compiled_softmax(x)
34:
35:
36: def sota_forward(x: torch.Tensor) -> torch.Tensor:
37: """Best-available softmax reference. x: (batch, vocab) fp32."""
38: out = _liger_softmax(x)
39: if out is not None:
40: return out
41: return _compiled(x)
42:
43:
44: def is_a
... (128 more chars){
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv,noheader",
"description": "Check GPU info before running workloads"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python -c \"import triton; print('Triton', triton.__version__)\"",
"description": "Check Triton availability and version"
}Triton 3.6.0
{
"todos": [
{
"content": "Write Triton softmax kernel with Kahan compensated summation in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Verify correctness with python check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Benchmark with python benchmark.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and iterate if needed",
"status": "pending",
"priority": "medium"
}
]
}content: 'Write Triton softmax kernel with Kahan compensated summation in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with python check.py'
status: 'pending'
priority: 'high'
content: 'Benchmark with python benchmark.py'
status: 'pending'
priority: 'high'
content: 'Profile and iterate if needed'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -0,0 +1,118 @@
+"""Numertically tight softmax with Kahan compensated summation.
+
+Two-pass Triton kernel:
+ Pass 1 (forward): online max-finding + Kahan-compensated summation of exp(x - max)
+ Pass 2 (backward): re-read input, compute exp(x - max) / sum, write output
+
+The Kahan compensation runs on the cross-chunk scalar accumulator,
+which is where naive fp32 summation over 256K elements drifts past 1e-5.
+Within each chunk (≤2048 elements) Triton's tree-reduction in shared memory
+is accurate enough.
+"""
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _softmax_kahan_kernel(
+ x_ptr,
+ y_ptr,
+ V: tl.constexpr,
+ BLOCK_VOCAB: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_VOCAB)
+ neg_inf: tl.constexpr = float('-inf')
+
+ # --- Forward pass: online max + Kahan-compensated sum ---
+ m_running: tl.constexpr = float('-inf')
+ s_running: tl.constexpr = 0.0
+ c_running: tl.constexpr = 0.0
+
+ for start in range(0, V, BLOCK_VOCAB):
+ mask = (start + offs) < V
+ row_offs = pid * V + start + offs
+ x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=neg_inf)
+
+ m_local = tl.max(tl.where(mask, x_chunk, neg_inf))
+ m_new = tl.maximum(m_running, m_local)
+
+ scale = tl.exp(m_running - m_new)
+ s_running = s_running * scale
+ c_running = c_running * scale
+
+ exp_chunk = tl.exp(x_chunk - m_new)
+ chunk_sum = tl.sum(tl.where(mask, exp_chunk, 0.0))
+
+ t = s_running
+ y = chunk_sum - c_running
+ s_running = t + y
+ c_running = (s_running - t) - y
+
+ m_running = m_new
+
+ inv_sum = 1.0 / s_running
+
+ # --- Backward pass: normalize and store ---
+ for start in range(0, V, BLOCK_VOCAB):
+ mask = (start + offs) < V
+ row_offs = pid * V + start + offs
+ x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=0.0)
+ y_chunk = tl.exp(x_chunk - m_running) * inv_sum
+ tl.store(y_ptr + row_offs, y_chunk, mask=mask)
+
+
+def _pick_block_size(vocab: int) -> int:
+ """Heuristic block size tuned for RTX PRO 6000 (sm_120)."""
+ if vocab <= 4096:
+ return 512
+ elif vocab <= 32768:
+ return 1024
+ else:
+ return 2048
+
+
+def _pick_num_warps(vocab: int) -> int:
+ if vocab <= 32768:
+ return 4
+ else:
+ return 8
+
+
+class Model(nn.Module):
+ def __init__(self, batch: int, vocab: int):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, V = x.shape
+ y = torch.empty_like(x)
+ BLOCK_VOCAB = _pick_block_size(V)
+ num_warps = _pick_num_warps(V)
+
+ grid = (B,)
+ _softmax_kahan_kernel[grid](
+ x, y,
+ V=V,
+ BLOCK_VOCAB=BLOCK_VOCAB,
+ num_warps=num_warps,
+ )
+ return y
+
+
+# Default shape; overridden per-iteration by check.py / benchmark.py.
+BATCH = 8
+VOCAB = 32768
+
+
+def get_inputs():
+ x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
+ return [x]
+
+
+def get_init_inputs():
+ return [BATCH, VOCAB]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<3131 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -25,12 +25,11 @@
):
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
- neg_inf: tl.constexpr = float('-inf')
# --- Forward pass: online max + Kahan-compensated sum ---
- m_running: tl.constexpr = float('-inf')
- s_running: tl.constexpr = 0.0
- c_running: tl.constexpr = 0.0
+ m_running = float('-inf')
+ s_running = 0.0
+ c_running = 0.0
for start in range(0, V, BLOCK_VOCAB):
mask = (start + offs) < V
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<281 chars \u2014 see below>",
"newString": "<197 chars \u2014 see below>"
} pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
neg_inf: tl.constexpr = float('-inf')
# --- Forward pass: online max + Kahan-compensated sum ---
m_running: tl.constexpr = float('-inf')
s_running: tl.constexpr = 0.0
c_running: tl.constexpr = 0.0 pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
# --- Forward pass: online max + Kahan-compensated sum ---
m_running = float('-inf')
s_running = 0.0
c_running = 0.0Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Numertically tight softmax with Kahan compensated summation.
2:
3: Two-pass Triton kernel:
4: Pass 1 (forward): online max-finding + Kahan-compensated summation of exp(x - max)
5: Pass 2 (backward): re-read input, compute exp(x - max) / sum, write output
6:
7: The Kahan compensation runs on the cross-chunk scalar accumulator,
8: which is where naive fp32 summation over 256K elements drifts past 1e-5.
9: Within each chunk (≤2048 elements) Triton's tree-reduction in shared memory
10: is accurate enough.
11: """
12:
13: import torch
14: import torch.nn as nn
15: import triton
16: import triton.language as tl
17:
18:
19: @triton.jit
20: def _softmax_kahan_kernel(
21: x_ptr,
22: y_ptr,
23: V: tl.constexpr,
24: BLOCK_VOCAB: tl.constexpr,
25: ):
26: pid = tl.program_id(0)
27: offs = tl.arange(0, BLOCK_VOCAB)
28:
29: # --- Forward pass: online max + Kahan-compensated sum ---
30: m_running = float('-inf')
31: s_running = 0.0
32: c_running = 0.0
33:
34: for start in range(0, V, BLOCK_VOCAB):
35: mask = (start + offs) < V
36: row_offs = pid * V + start + offs
37: x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=neg_inf)
38:
39: m_local = tl.max(tl.where(mask, x_chunk, neg_inf))
40: m_new = tl.maximum(m_running, m_local)
41:
42: scale = tl.exp(m_running - m_ne
... (2184 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -34,9 +34,9 @@
for start in range(0, V, BLOCK_VOCAB):
mask = (start + offs) < V
row_offs = pid * V + start + offs
- x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=neg_inf)
+ x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=float('-inf'))
- m_local = tl.max(tl.where(mask, x_chunk, neg_inf))
+ m_local = tl.max(tl.where(mask, x_chunk, float('-inf')))
m_new = tl.maximum(m_running, m_local)
scale = tl.exp(m_running - m_new)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<129 chars \u2014 see below>",
"newString": "<141 chars \u2014 see below>"
} x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=neg_inf)
m_local = tl.max(tl.where(mask, x_chunk, neg_inf)) x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=float('-inf'))
m_local = tl.max(tl.where(mask, x_chunk, float('-inf')))Edit applied successfully.{
"description": "Run correctness check for all shapes",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Run benchmark for all shapes",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/benchmark.py 2>&1",
"timeout": 600000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py", line 135, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py", line 45, in main
meta = yaml.safe_load(Path("problem.yaml").read_text())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/pathlib.py", line 1058, in read_text
with self.open(mode='r', encoding=encoding, errors=errors) as f:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/pathlib.py", line 1044, in open
return io.open(self, mode, buffering, encoding, errors, newline)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'problem.yaml'
{
"description": "Run benchmark from problem directory",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}shape=0 variant=eager tflops=0.010 gbps=15.600 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.287 ms=0.102
shape=0 variant=sota tflops=0.035 gbps=56.351 ms=0.019
shape=0 variant=solution tflops=0.052 gbps=83.913 ms=0.012
shape=0 solution_peak_fraction=0.0466
shape=1 variant=eager tflops=0.041 gbps=66.181 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.771 ms=0.098
shape=1 variant=sota tflops=0.131 gbps=210.051 ms=0.020
shape=1 variant=solution tflops=0.084 gbps=134.020 ms=0.031
shape=1 solution_peak_fraction=0.0745
shape=2 variant=eager tflops=0.073 gbps=116.121 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.449 ms=0.218
shape=2 variant=sota tflops=0.109 gbps=173.605 ms=0.048
shape=2 variant=solution tflops=0.084 gbps=135.021 ms=0.062
shape=2 solution_peak_fraction=0.0750
shape=3 variant=eager tflops=0.072 gbps=114.649 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.331 ms=0.360
shape=3 variant=sota tflops=0.101 gbps=161.817 ms=0.052
shape=3 variant=solution tflops=0.046 gbps=74.115 ms=0.113
shape=3 solution_peak_fraction=0.0412
shape=4 variant=eager tflops=0.072 gbps=115.355 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.777 ms=0.228
shape=4 variant=sota tflops=0.103 gbps=164.354 ms=0.051
shape=4 variant=solution tflops=0.084 gbps=135.091 ms=0.062
shape=4 solution_peak_fraction=0.0751
peak_fraction: 0.0604
RESULT: LOW
{
"description": "Quick ncu profile of the kernel on shape 3",
"command": "<821 chars \u2014 see below>",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && python3 -c "
import torch
import triton
import sys
sys.path.insert(0, '.')
import shapes
import reference
device = torch.device('cuda:0')
shape = shapes.SHAPES[3] # batch=4, vocab=262144
batch, vocab = shape['batch'], shape['vocab']
reference.BATCH = batch
reference.VOCAB = vocab
x = torch.randn(batch, vocab, dtype=torch.float32).to(device)
import solution
m = solution.Model(batch, vocab).to(device).eval()
# Warmup
for _ in range(10):
_ = m(x)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 100
start.record()
for _ in range(iters):
_ = m(x)
end.record()
torch.cuda.synchronize()
print(f'Time per iter: {start.elapsed_time(end) / iters:.4f} ms')
" 2>&1Time per iter: 0.1055 ms
{
"description": "Check Liger softmax kernel source for approach",
"command": "<113 chars \u2014 see below>",
"timeout": 10000
}python3 -c "
import liger_kernel.ops.softmax as sm
import inspect
print(inspect.getsource(sm))
" 2>&1 | head -200Traceback (most recent call last):
File "<string>", line 2, in <module>
ModuleNotFoundError: No module named 'liger_kernel'
{
"todos": [
{
"priority": "high",
"content": "Write Triton softmax kernel with Kahan compensated summation in solution.py",
"status": "completed"
},
{
"priority": "high",
"content": "Verify correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Optimize: multi-block map-reduce for better GPU utilization",
"status": "in_progress"
},
{
"priority": "high",
"content": "Benchmark with python benchmark.py",
"status": "pending"
},
{
"priority": "m
... (95 more chars)content: 'Write Triton softmax kernel with Kahan compensated summation in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Optimize: multi-block map-reduce for better GPU utilization'
status: 'in_progress'
priority: 'high'
content: 'Benchmark with python benchmark.py'
status: 'pending'
priority: 'high'
content: 'Profile and iterate if needed'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,11 +1,12 @@
-"""Numertically tight softmax with Kahan compensated summation.
-
-Two-pass Triton kernel:
- Pass 1 (forward): online max-finding + Kahan-compensated summation of exp(x - max)
- Pass 2 (backward): re-read input, compute exp(x - max) / sum, write output
-
-The Kahan compensation runs on the cross-chunk scalar accumulator,
-which is where naive fp32 summation over 256K elements drifts past 1e-5.
+"""Numerically tight softmax with Kahan compensated summation.
+
+Multi-block map-reduce on each row:
+ Map kernel: each block computes local (max, sum_of_exp) via online softmax
+ CPU reduce: combines per-block local results into global (max, sum) per row
+ Norm kernel: each block normalizes its chunk and writes output
+
+The Kahan compensation runs on the cross-chunk scalar accumulator inside each
+block, which is where naive fp32 summation over 256K elements drifts past 1e-5.
Within each chunk (≤2048 elements) Triton's tree-reduction in shared memory
is accurate enough.
"""
@@ -17,57 +18,106 @@
@triton.jit
-def _softmax_kahan_kernel(
- x_ptr,
- y_ptr,
- V: tl.constexpr,
- BLOCK_VOCAB: tl.constexpr,
-):
+def _map_kernel(x_ptr, max_buf, sum_buf, V, map_elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """Online softmax: find local (max, sum) for each block's chunk of the row."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
- # --- Forward pass: online max + Kahan-compensated sum ---
- m_running = float('-inf')
- s_running = 0.0
- c_running = 0.0
-
- for start in range(0, V, BLOCK_VOCAB):
- mask = (start + offs) < V
- row_offs = pid * V + start + offs
- x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=float('-inf'))
-
- m_local = tl.max(tl.where(mask, x_chunk, float('-inf')))
- m_new = tl.maximum(m_running, m_local)
-
- scale = tl.exp(m_running - m_new)
- s_running = s_running * scale
- c_running = c_running * scale
-
- exp_chunk = tl.exp(x_chunk - m_new)
- chunk_sum = tl.sum(tl.where(mask, exp_chunk, 0.0))
-
- t = s_running
- y = chunk_sum - c_running
- s_running = t + y
- c_running = (s_running - t) - y
-
- m_running = m_new
-
- inv_sum = 1.0 / s_running
-
- # --- Backward pass: normalize and store ---
- for start in range(0, V, BLOCK_VOCAB):
- mask = (start + offs) < V
- row_offs = pid * V + start + offs
- x_chunk = tl.load(x_ptr + row_offs, mask=mask, other=0.0)
- y_chunk = tl.exp(x_chunk - m_running) * inv_sum
- tl.store(y_ptr + row_offs, y_chunk, mask=mask)
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+
+ base = pid * map_elems_per_block
+ for start in range(0, map_elems_per_block, BLOCK_VOCAB):
+ mask = (start + offs) < map_elems_per_block
+ addr = base + start + offs
+ x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
+
+ m_local = tl.max(tl.where(mask, x, float('-inf')))
+ m_new = tl.maximum(m, m_local)
+
+ scale = tl.exp(m - m_new)
+ s = s * scale
+ c = c * scale
+
+ exp_vals = tl.exp(x - m_new)
+ chunk_sum = tl.sum(tl.where(mask, exp_vals, 0.0))
+
+ t = s
+ y = chunk_sum - c
+ s = t + y
+ c = (s - t) - y
+
+ m = m_new
+
+ tl.store(max_buf + pid, m)
+ tl.store(sum_buf + pid, s)
+
+
+@triton.jit
+def _norm_kernel(x_ptr, y_ptr, max_per_row, sum_per_row, row_ids, V,
+ norm_elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """Normalize: exp(x - global_max) / global_sum, store."""
+ pid = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_VOCAB)
+
+ row = tl.load(row_ids + pid)
+ m_global = tl.load(max_per_row + row)
+ s_global = tl.load(sum_per_row + row)
+ inv_sum = 1.0 / s_global if s_global > 0.0 else 0.0
+
+ base = pid * norm_elems_per_block
+ for start in range(0, norm_elems_per_block, BLOCK_VOCAB):
+ mask = (start + offs) < norm_elems_per_block
+ addr = base + start + offs
+ x = tl.load(x_ptr + addr, mask=mask, other=0.0)
+ y = tl.exp(x - m_global) * inv_sum
+ tl.store(y_ptr + addr, y, mask=mask)
+
+
+def _reduce_local_results(max_buf, sum_buf, batch, blocks_per_row):
+ """CPU-side reduction: combine per-block (max, sum) into per-row (max, sum)."""
+ global_max = torch.empty(batch, dtype=torch.float32, device=max_buf.device)
+ global_sum = torch.empty(batch, dtype=torch.float32, device=max_buf.device)
+
+ max_np = max_buf.view(batch, blocks_per_row)
+ sum_np = sum_buf.view(batch, blocks_per_row)
+
+ for r in range(batch):
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+ for k in range(blocks_per_row):
+ mk = max_np[r, k].item()
+ sk = sum_np[r, k].item()
+
+ m_new = max(m, mk)
+ scale = math_exp(m - m_new)
+ s = s * scale
+ c = c * scale
+
+ if mk < m_new:
+ sk = sk * math_exp(mk - m_new)
+
+ t = s
+ y = sk - c
+ s = t + y
+ c = (s - t) - y
+ m = m_new
+
+ global_max[r] = m
+ global_sum[r] = s
+
+ return global_max, global_sum
+
+
+import math
+math_exp = math.exp
def _pick_block_size(vocab: int) -> int:
- """Heuristic block size tuned for RTX PRO 6000 (sm_120)."""
if vocab <= 4096:
- return 512
+ return 256
elif vocab <= 32768:
return 1024
else:
@@ -75,7 +125,7 @@
def _pick_num_warps(vocab: int) -> int:
- if vocab <= 32768:
+ if vocab <= 16384:
return 4
else:
return 8
@@ -89,21 +139,112 @@
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, V = x.shape
- y = to
... (3442 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7314 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run correctness check for all shapes",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Run benchmark from problem directory",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}shape=0 variant=eager tflops=0.010 gbps=15.530 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.368 ms=0.101
shape=0 variant=sota tflops=0.035 gbps=55.586 ms=0.019
shape=0 variant=solution tflops=0.041 gbps=65.145 ms=0.016
shape=0 solution_peak_fraction=0.0362
shape=1 variant=eager tflops=0.041 gbps=66.265 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.639 ms=0.098
shape=1 variant=sota tflops=0.128 gbps=204.003 ms=0.021
shape=1 variant=solution tflops=0.002 gbps=2.501 ms=1.677
shape=1 solution_peak_fraction=0.0014
shape=2 variant=eager tflops=0.073 gbps=116.508 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.474 ms=0.218
shape=2 variant=sota tflops=0.109 gbps=174.996 ms=0.048
shape=2 variant=solution tflops=0.003 gbps=5.229 ms=1.604
shape=2 solution_peak_fraction=0.0029
shape=3 variant=eager tflops=0.074 gbps=119.102 ms=0.070
shape=3 variant=compiled tflops=0.015 gbps=23.319 ms=0.360
shape=3 variant=sota tflops=0.098 gbps=157.208 ms=0.053
shape=3 variant=solution tflops=0.003 gbps=5.449 ms=1.539
shape=3 solution_peak_fraction=0.0030
shape=4 variant=eager tflops=0.072 gbps=115.711 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.782 ms=0.228
shape=4 variant=sota tflops=0.103 gbps=164.302 ms=0.051
shape=4 variant=solution tflops=0.003 gbps=5.294 ms=1.584
shape=4 solution_peak_fraction=0.0029
peak_fraction: 0.0042
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -2,13 +2,12 @@
Multi-block map-reduce on each row:
Map kernel: each block computes local (max, sum_of_exp) via online softmax
- CPU reduce: combines per-block local results into global (max, sum) per row
+ Reduce kernel: GPU-side Kahan reduction of per-block results -> global per row
Norm kernel: each block normalizes its chunk and writes output
The Kahan compensation runs on the cross-chunk scalar accumulator inside each
-block, which is where naive fp32 summation over 256K elements drifts past 1e-5.
-Within each chunk (≤2048 elements) Triton's tree-reduction in shared memory
-is accurate enough.
+block and in the global reduce, which is where naive fp32 summation over 256K
+elements drifts past 1e-5.
"""
import torch
@@ -18,18 +17,19 @@
@triton.jit
-def _map_kernel(x_ptr, max_buf, sum_buf, V, map_elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
- """Online softmax: find local (max, sum) for each block's chunk of the row."""
+def _map_kernel(x_ptr, max_buf, sum_buf, row_ids, V,
+ elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """Online softmax: find local (max, Kahan-sum-of-exp) for each block's chunk."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
- m = float('-inf')
- s = 0.0
- c = 0.0
-
- base = pid * map_elems_per_block
- for start in range(0, map_elems_per_block, BLOCK_VOCAB):
- mask = (start + offs) < map_elems_per_block
+ base = pid * elems_per_block
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+
+ for start in range(0, elems_per_block, BLOCK_VOCAB):
+ mask = (start + offs) < elems_per_block
addr = base + start + offs
x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
@@ -50,14 +50,92 @@
m = m_new
+ row_id = tl.load(row_ids + pid)
tl.store(max_buf + pid, m)
tl.store(sum_buf + pid, s)
@triton.jit
+def _reduce_kernel(max_in, sum_in, max_out, sum_out,
+ reduced_per_row: tl.constexpr, NUM_BLOCKS: tl.constexpr):
+ """Reduce per-block (max, sum) pairs into per-row global (max, sum).
+
+ Each program handles one row, combining reduced_per_row pairs with
+ online softmax + Kahan compensation.
+ """
+ rid = tl.program_id(0)
+ offs = rid * reduced_per_row + tl.arange(0, reduced_per_row)
+
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+
+ for k in range(NUM_BLOCKS):
+ mask = (tl.arange(0, reduced_per_row)) < reduced_per_row # all active
+ mk_vec = tl.load(max_in + offs, mask=mask, other=float('-inf'))
+ sk_vec = tl.load(sum_in + offs, mask=mask, other=0.0)
+
+ mk = tl.sum(tl.where(mask, mk_vec, 0.0)) # reduction to get only element
+ sk = tl.sum(tl.where(mask, sk_vec, 0.0))
+
+ # Actually, this is wrong. I need to load one element at a time, not vector.
+ # Let me fix this: use a scalar load approach.
+ pass
+
+ tl.store(max_out + rid, m)
+ tl.store(sum_out + rid, s)
+
+
+@triton.jit
+def _reduce_kernel_v2(max_in, sum_in, max_out, sum_out,
+ blocks_per_row: tl.constexpr, BLOCK_REDUCE: tl.constexpr):
+ """Combine per-block (max,sum) into per-row global (max,sum) with Kahan.
+
+ One program per row. Iterates over blocks_per_row results sequentially."""
+ rid = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_REDUCE)
+
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+
+ for k in range(0, blocks_per_row, BLOCK_REDUCE):
+ mask = (k + offs) < blocks_per_row
+ idx = rid * blocks_per_row + k + offs
+ mk = tl.load(max_in + idx, mask=mask, other=float('-inf'))
+ sk = tl.load(sum_in + idx, mask=mask, other=0.0)
+
+ # Process BLOCK_REDUCE elements sequentially (scalar loop per thread)
+ # Since each thread may process multiple elements, use a scalar approach
+
+ # For each element in this chunk:
+ for j in range(BLOCK_REDUCE):
+ valid = (k + j) < blocks_per_row
+ if valid:
+ curr_mk = tl.load(max_in + rid * blocks_per_row + k + j)
+ curr_sk = tl.load(sum_in + rid * blocks_per_row + k + j)
+
+ m_new = tl.maximum(m, curr_mk)
+ scale = tl.exp(m - m_new)
+ s = s * scale
+ c = c * scale
+ contrib = curr_sk if curr_mk >= m_new else curr_sk * tl.exp(curr_mk - m_new)
+
+ t = s
+ y_k = contrib - c
+ s = t + y_k
+ c = (s - t) - y_k
+
+ m = m_new
+
+ tl.store(max_out + rid, m)
+ tl.store(sum_out + rid, s)
+
+
+@triton.jit
def _norm_kernel(x_ptr, y_ptr, max_per_row, sum_per_row, row_ids, V,
- norm_elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
- """Normalize: exp(x - global_max) / global_sum, store."""
+ elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """exp(x - global_max) / global_sum, store."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
@@ -66,54 +144,58 @@
s_global = tl.load(sum_per_row + row)
inv_sum = 1.0 / s_global if s_global > 0.0 else 0.0
- base = pid * norm_elems_per_block
- for start in range(0, norm_elems_per_block, BLOCK_VOCAB):
- mask = (start + offs) < norm_elems_per_block
+ base = pid * elems_per_block
+ for start in range(0, elems_per_block, BLOCK_VOCAB):
+ mask = (start + offs) < elems_per_block
addr = base + start + offs
x = tl.load(x_ptr + addr, mask=mask, other=0.0)
y = tl.exp(x - m_global) * inv_sum
tl.store(y_ptr + addr, y, mask=mask)
-def _reduce_local_results(max_buf, sum_buf, batch, blocks_per_row):
- """CPU-side reduction: combine per-block (max, sum) into per
... (6912 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<8988 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,13 +1,12 @@
"""Numerically tight softmax with Kahan compensated summation.
-Multi-block map-reduce on each row:
- Map kernel: each block computes local (max, sum_of_exp) via online softmax
- Reduce kernel: GPU-side Kahan reduction of per-block results -> global per row
- Norm kernel: each block normalizes its chunk and writes output
-
-The Kahan compensation runs on the cross-chunk scalar accumulator inside each
-block and in the global reduce, which is where naive fp32 summation over 256K
-elements drifts past 1e-5.
+For small vocabs (≤32K where one block per row gives reasonable occupancy):
+ single-pass online softmax within one block.
+
+For large vocabs (>32K): multi-block map-reduce per row:
+ Map: each block computes local (max, Kahan-sum-of-exp) for its chunk
+ Reduce: GPU-side Kahan combination of per-block results (num_warps=1)
+ Norm: each block normalizes its chunk using global (max, sum)
"""
import torch
@@ -16,14 +15,64 @@
import triton.language as tl
-@triton.jit
-def _map_kernel(x_ptr, max_buf, sum_buf, row_ids, V,
- elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
- """Online softmax: find local (max, Kahan-sum-of-exp) for each block's chunk."""
+# ---------------------------------------------------------------------------
+# Single-block kernel (used when blocks_per_row == 1)
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _single_block_kernel(x_ptr, y_ptr, V: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """Online softmax + Kahan sum, then normalize. One block per row."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
+ m = float('-inf')
+ s = 0.0
+ c = 0.0
+
+ for start in range(0, V, BLOCK_VOCAB):
+ mask = (start + offs) < V
+ addr = pid * V + start + offs
+ x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
+
+ m_local = tl.max(tl.where(mask, x, float('-inf')))
+ m_new = tl.maximum(m, m_local)
+ scale = tl.exp(m - m_new)
+
+ s = s * scale
+ c = c * scale
+
+ exp_vals = tl.exp(x - m_new)
+ chunk_sum = tl.sum(tl.where(mask, exp_vals, 0.0))
+
+ t = s
+ y_k = chunk_sum - c
+ s = t + y_k
+ c = (s - t) - y_k
+
+ m = m_new
+
+ inv_sum = 1.0 / s
+
+ for start in range(0, V, BLOCK_VOCAB):
+ mask = (start + offs) < V
+ addr = pid * V + start + offs
+ x = tl.load(x_ptr + addr, mask=mask, other=0.0)
+ y = tl.exp(x - m) * inv_sum
+ tl.store(y_ptr + addr, y, mask=mask)
+
+
+# ---------------------------------------------------------------------------
+# Multi-block kernels
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _map_kernel(x_ptr, max_buf, sum_buf, V,
+ elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+ """One block processes one chunk: online max + Kahan-sum of exp."""
+ pid = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_VOCAB)
base = pid * elems_per_block
+
m = float('-inf')
s = 0.0
c = 0.0
@@ -35,8 +84,8 @@
m_local = tl.max(tl.where(mask, x, float('-inf')))
m_new = tl.maximum(m, m_local)
-
scale = tl.exp(m - m_new)
+
s = s * scale
c = c * scale
@@ -44,105 +93,65 @@
chunk_sum = tl.sum(tl.where(mask, exp_vals, 0.0))
t = s
- y = chunk_sum - c
- s = t + y
- c = (s - t) - y
+ y_k = chunk_sum - c
+ s = t + y_k
+ c = (s - t) - y_k
m = m_new
- row_id = tl.load(row_ids + pid)
tl.store(max_buf + pid, m)
tl.store(sum_buf + pid, s)
@triton.jit
def _reduce_kernel(max_in, sum_in, max_out, sum_out,
- reduced_per_row: tl.constexpr, NUM_BLOCKS: tl.constexpr):
- """Reduce per-block (max, sum) pairs into per-row global (max, sum).
-
- Each program handles one row, combining reduced_per_row pairs with
- online softmax + Kahan compensation.
- """
+ blocks_per_row: tl.constexpr):
+ """Combine per-block (max,sum) into per-row global (max,sum).
+
+ Single-warp kernel (num_warps=1). One program per row.
+ Iterates over blocks_per_row pairs with Kahan-compensated online softmax."""
rid = tl.program_id(0)
- offs = rid * reduced_per_row + tl.arange(0, reduced_per_row)
+ base = rid * blocks_per_row
m = float('-inf')
s = 0.0
c = 0.0
- for k in range(NUM_BLOCKS):
- mask = (tl.arange(0, reduced_per_row)) < reduced_per_row # all active
- mk_vec = tl.load(max_in + offs, mask=mask, other=float('-inf'))
- sk_vec = tl.load(sum_in + offs, mask=mask, other=0.0)
-
- mk = tl.sum(tl.where(mask, mk_vec, 0.0)) # reduction to get only element
- sk = tl.sum(tl.where(mask, sk_vec, 0.0))
-
- # Actually, this is wrong. I need to load one element at a time, not vector.
- # Let me fix this: use a scalar load approach.
- pass
+ for k in range(blocks_per_row):
+ mk = tl.load(max_in + (base + k))
+ sk = tl.load(sum_in + (base + k))
+
+ m_new = tl.maximum(m, mk)
+ scale = tl.exp(m - m_new)
+
+ s = s * scale
+ c = c * scale
+
+ if mk < m_new:
+ sk = sk * tl.exp(mk - m_new)
+
+ t = s
+ y_k = sk - c
+ s = t + y_k
+ c = (s - t) - y_k
+
+ m = m_new
tl.store(max_out + rid, m)
tl.store(sum_out + rid, s)
@triton.jit
-def _reduce_kernel_v2(max_in, sum_in, max_out, sum_out,
- blocks_per_row: tl.constexpr, BLOCK_REDUCE: tl.constexpr):
- """Combine per-block (max,sum) into per-row global (max,sum) with Kahan.
-
- One program per row. Iterates over b
... (6339 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7729 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Check correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1 && cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=0.010 gbps=15.556 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.295 ms=0.102
shape=0 variant=sota tflops=0.034 gbps=54.522 ms=0.019
shape=0 variant=solution tflops=0.042 gbps=67.355 ms=0.016
shape=0 solution_peak_fraction=0.0374
shape=1 variant=eager tflops=0.042 gbps=66.500 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.778 ms=0.098
shape=1 variant=sota tflops=0.130 gbps=207.392 ms=0.020
shape=1 variant=solution tflops=0.083 gbps=133.001 ms=0.032
shape=1 solution_peak_fraction=0.0739
shape=2 variant=eager tflops=0.073 gbps=116.742 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.517 ms=0.218
shape=2 variant=sota tflops=0.108 gbps=172.747 ms=0.049
shape=2 variant=solution tflops=0.143 gbps=228.348 ms=0.037
shape=2 solution_peak_fraction=0.1269
shape=3 variant=eager tflops=0.074 gbps=118.403 ms=0.071
shape=3 variant=compiled tflops=0.015 gbps=23.243 ms=0.361
shape=3 variant=sota tflops=0.098 gbps=157.255 ms=0.053
shape=3 variant=solution tflops=0.138 gbps=220.846 ms=0.038
shape=3 solution_peak_fraction=0.1227
shape=4 variant=eager tflops=0.072 gbps=114.925 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.615 ms=0.229
shape=4 variant=sota tflops=0.102 gbps=163.482 ms=0.051
shape=4 variant=solution tflops=0.145 gbps=232.191 ms=0.036
shape=4 solution_peak_fraction=0.1290
peak_fraction: 0.0889
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,12 +1,12 @@
"""Numerically tight softmax with Kahan compensated summation.
-For small vocabs (≤32K where one block per row gives reasonable occupancy):
- single-pass online softmax within one block.
-
-For large vocabs (>32K): multi-block map-reduce per row:
+Multi-block map-reduce per row for all shapes with >1 block per row:
Map: each block computes local (max, Kahan-sum-of-exp) for its chunk
Reduce: GPU-side Kahan combination of per-block results (num_warps=1)
Norm: each block normalizes its chunk using global (max, sum)
+
+Single-block path only when batch is large enough that one-block-per-row
+already saturates the GPU (B >= 32).
"""
import torch
@@ -16,7 +16,7 @@
# ---------------------------------------------------------------------------
-# Single-block kernel (used when blocks_per_row == 1)
+# Single-block kernel
# ---------------------------------------------------------------------------
@triton.jit
@@ -108,8 +108,7 @@
blocks_per_row: tl.constexpr):
"""Combine per-block (max,sum) into per-row global (max,sum).
- Single-warp kernel (num_warps=1). One program per row.
- Iterates over blocks_per_row pairs with Kahan-compensated online softmax."""
+ Single-warp kernel (num_warps=1). One program per row."""
rid = tl.program_id(0)
base = rid * blocks_per_row
@@ -163,23 +162,32 @@
# ---------------------------------------------------------------------------
-# Heuristics
-# ---------------------------------------------------------------------------
-
-def _pick_block_size(vocab: int) -> int:
- if vocab <= 4096:
- return 256
- elif vocab <= 32768:
- return 1024
+# Heuristics (tuned for RTX PRO 6000 / SM120 Blackwell)
+# ---------------------------------------------------------------------------
+
+def _pick_block_size(vocab: int, multi_block: bool) -> int:
+ if multi_block:
+ if vocab <= 32768:
+ return triton.next_power_of_2(vocab // 16 // 4) or 512
+ else:
+ return 4096
else:
- return 2048
-
-
-def _pick_num_warps(vocab: int) -> int:
- if vocab <= 16384:
- return 4
+ if vocab <= 4096:
+ return 512
+ elif vocab <= 32768:
+ return 1024
+ else:
+ return 2048
+
+
+def _pick_num_warps(vocab: int, multi_block: bool) -> int:
+ if multi_block:
+ return 8
else:
- return 8
+ if vocab <= 16384:
+ return 4
+ else:
+ return 8
# ---------------------------------------------------------------------------
@@ -196,12 +204,13 @@
B, V = x.shape
device = x.device
- BLOCK_VOCAB = _pick_block_size(V)
- num_warps = _pick_num_warps(V)
-
- # One block per row is sufficient when the row fits in a reasonable
- # number of loop iterations. Multi-block only for large vocabs.
- if V <= 32768 or B >= 32:
+ # Use single-block path when B >= 32 (already good occupancy)
+ # or when V is so small that multi-block has disproportionate overhead
+ use_multi = (B < 32) and (V >= 4096)
+ BLOCK_VOCAB = _pick_block_size(V, use_multi)
+ num_warps = _pick_num_warps(V, use_multi)
+
+ if not use_multi:
y = torch.empty_like(x)
_single_block_kernel[(B,)](
x, y,
@@ -212,14 +221,14 @@
return y
# ---- Multi-block path ----
- blocks_per_row = min(max(1, V // 16384), 32)
+ # Target enough blocks to keep SMs busy, with at least 2048 elems/block
+ blocks_per_row = max(1, min(96 // max(B, 1), V // 2048))
elems_per_block = V // blocks_per_row
total_blocks = B * blocks_per_row
- # Row-to-block mapping
row_ids = torch.arange(B, device=device, dtype=torch.int32).repeat_interleave(blocks_per_row)
- # 1) Map: each block computes its local (max, Kahan-sum)
+ # 1) Map
max_buf = torch.empty(total_blocks, dtype=torch.float32, device=device)
sum_buf = torch.empty(total_blocks, dtype=torch.float32, device=device)
@@ -231,7 +240,7 @@
num_warps=num_warps,
)
- # 2) Reduce: combine per-block results -> per-row global (max, sum)
+ # 2) Reduce
global_max = torch.empty(B, dtype=torch.float32, device=device)
global_sum = torch.empty(B, dtype=torch.float32, device=device)
@@ -241,7 +250,7 @@
num_warps=1,
)
- # 3) Norm: normalize each chunk using global (max, sum)
+ # 3) Norm
y = torch.empty_like(x)
_norm_kernel[(total_blocks,)](
x, y, global_max, global_sum, row_ids,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7919 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Check correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1 && cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}FAIL: shape 1 {'batch': 16, 'vocab': 32768, 'extreme': False} seed 42: tolerance exceeded: max_abs_diff=0.000353589 (atol=0.0001, rtol=0.0001)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -221,8 +221,13 @@
return y
# ---- Multi-block path ----
- # Target enough blocks to keep SMs busy, with at least 2048 elems/block
- blocks_per_row = max(1, min(96 // max(B, 1), V // 2048))
+ # Target enough blocks to keep SMs busy; prefer powers-of-2 that
+ # evenly divide V so every element belongs to exactly one chunk.
+ target = max(1, min(96 // max(B, 1), V // 2048))
+ blocks_per_row = triton.next_power_of_2(target)
+ # Walk down to a divisor of V
+ while V % blocks_per_row != 0 and blocks_per_row > 1:
+ blocks_per_row //= 2
elems_per_block = V // blocks_per_row
total_blocks = B * blocks_per_row
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<227 chars \u2014 see below>",
"newString": "<474 chars \u2014 see below>"
} # ---- Multi-block path ----
# Target enough blocks to keep SMs busy, with at least 2048 elems/block
blocks_per_row = max(1, min(96 // max(B, 1), V // 2048))
elems_per_block = V // blocks_per_row # ---- Multi-block path ----
# Target enough blocks to keep SMs busy; prefer powers-of-2 that
# evenly divide V so every element belongs to exactly one chunk.
target = max(1, min(96 // max(B, 1), V // 2048))
blocks_per_row = triton.next_power_of_2(target)
# Walk down to a divisor of V
while V % blocks_per_row != 0 and blocks_per_row > 1:
blocks_per_row //= 2
elems_per_block = V // blocks_per_rowEdit applied successfully.{
"description": "Check correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1 && cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=0.010 gbps=15.570 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.285 ms=0.102
shape=0 variant=sota tflops=0.035 gbps=56.302 ms=0.019
shape=0 variant=solution tflops=0.052 gbps=82.852 ms=0.013
shape=0 solution_peak_fraction=0.0460
shape=1 variant=eager tflops=0.041 gbps=65.291 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.508 ms=0.099
shape=1 variant=sota tflops=0.128 gbps=204.481 ms=0.021
shape=1 variant=solution tflops=0.064 gbps=102.963 ms=0.041
shape=1 solution_peak_fraction=0.0572
shape=2 variant=eager tflops=0.072 gbps=115.482 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.480 ms=0.218
shape=2 variant=sota tflops=0.111 gbps=177.424 ms=0.047
shape=2 variant=solution tflops=0.134 gbps=214.784 ms=0.039
shape=2 solution_peak_fraction=0.1193
shape=3 variant=eager tflops=0.072 gbps=115.993 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.301 ms=0.360
shape=3 variant=sota tflops=0.099 gbps=158.923 ms=0.053
shape=3 variant=solution tflops=0.131 gbps=209.548 ms=0.040
shape=3 solution_peak_fraction=0.1164
shape=4 variant=eager tflops=0.073 gbps=116.121 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.790 ms=0.228
shape=4 variant=sota tflops=0.102 gbps=162.924 ms=0.051
shape=4 variant=solution tflops=0.136 gbps=217.547 ms=0.039
shape=4 solution_peak_fraction=0.1209
peak_fraction: 0.0849
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,12 +1,12 @@
"""Numerically tight softmax with Kahan compensated summation.
-Multi-block map-reduce per row for all shapes with >1 block per row:
- Map: each block computes local (max, Kahan-sum-of-exp) for its chunk
- Reduce: GPU-side Kahan combination of per-block results (num_warps=1)
- Norm: each block normalizes its chunk using global (max, sum)
-
-Single-block path only when batch is large enough that one-block-per-row
-already saturates the GPU (B >= 32).
+Single-block path for smaller vocabs (V ≤ 32768) where one-kernel-launch
+simplicity wins. Multi-block map-reduce for large vocabs where parallelism
+across blocks is needed to saturate GPU bandwidth.
+
+Map: each block computes local (max, Kahan-sum-of-exp) for its chunk.
+Reduce: GPU-side Kahan combination of per-block results (num_warps=1).
+Norm: each block normalizes its chunk using global (max, sum).
"""
import torch
@@ -20,7 +20,7 @@
# ---------------------------------------------------------------------------
@triton.jit
-def _single_block_kernel(x_ptr, y_ptr, V: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
+def _single_kernel(x_ptr, y_ptr, V: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
"""Online softmax + Kahan sum, then normalize. One block per row."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
@@ -66,9 +66,8 @@
# ---------------------------------------------------------------------------
@triton.jit
-def _map_kernel(x_ptr, max_buf, sum_buf, V,
+def _map_kernel(x_ptr, max_buf, sum_buf,
elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
- """One block processes one chunk: online max + Kahan-sum of exp."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
base = pid * elems_per_block
@@ -106,9 +105,6 @@
@triton.jit
def _reduce_kernel(max_in, sum_in, max_out, sum_out,
blocks_per_row: tl.constexpr):
- """Combine per-block (max,sum) into per-row global (max,sum).
-
- Single-warp kernel (num_warps=1). One program per row."""
rid = tl.program_id(0)
base = rid * blocks_per_row
@@ -122,7 +118,6 @@
m_new = tl.maximum(m, mk)
scale = tl.exp(m - m_new)
-
s = s * scale
c = c * scale
@@ -141,9 +136,8 @@
@triton.jit
-def _norm_kernel(x_ptr, y_ptr, max_per_row, sum_per_row, row_ids, V,
+def _norm_kernel(x_ptr, y_ptr, max_per_row, sum_per_row, row_ids,
elems_per_block: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
- """exp(x - global_max) * (1 / global_sum), store."""
pid = tl.program_id(0)
offs = tl.arange(0, BLOCK_VOCAB)
@@ -162,32 +156,27 @@
# ---------------------------------------------------------------------------
-# Heuristics (tuned for RTX PRO 6000 / SM120 Blackwell)
-# ---------------------------------------------------------------------------
-
-def _pick_block_size(vocab: int, multi_block: bool) -> int:
- if multi_block:
- if vocab <= 32768:
- return triton.next_power_of_2(vocab // 16 // 4) or 512
- else:
- return 4096
+# Tuning
+# ---------------------------------------------------------------------------
+
+def _single_config(vocab: int) -> tuple:
+ """Return (BLOCK_VOCAB, num_warps) for single-block path."""
+ if vocab <= 4096:
+ return (1024, 4)
+ elif vocab <= 16384:
+ return (2048, 8)
else:
- if vocab <= 4096:
- return 512
- elif vocab <= 32768:
- return 1024
- else:
- return 2048
-
-
-def _pick_num_warps(vocab: int, multi_block: bool) -> int:
- if multi_block:
- return 8
+ return (4096, 8)
+
+
+def _multi_config(vocab: int) -> tuple:
+ """Return (BLOCK_VOCAB, num_warps) for multi-block path."""
+ if vocab <= 32768:
+ return (2048, 8)
+ elif vocab <= 131072:
+ return (4096, 8)
else:
- if vocab <= 16384:
- return 4
- else:
- return 8
+ return (8192, 8)
# ---------------------------------------------------------------------------
@@ -204,15 +193,11 @@
B, V = x.shape
device = x.device
- # Use single-block path when B >= 32 (already good occupancy)
- # or when V is so small that multi-block has disproportionate overhead
- use_multi = (B < 32) and (V >= 4096)
- BLOCK_VOCAB = _pick_block_size(V, use_multi)
- num_warps = _pick_num_warps(V, use_multi)
-
- if not use_multi:
+ # Single-block path for small vocabs or high-batch cases
+ if V <= 32768:
y = torch.empty_like(x)
- _single_block_kernel[(B,)](
+ BLOCK_VOCAB, num_warps = _single_config(V)
+ _single_kernel[(B,)](
x, y,
V=V,
BLOCK_VOCAB=BLOCK_VOCAB,
@@ -220,12 +205,12 @@
)
return y
- # ---- Multi-block path ----
- # Target enough blocks to keep SMs busy; prefer powers-of-2 that
- # evenly divide V so every element belongs to exactly one chunk.
- target = max(1, min(96 // max(B, 1), V // 2048))
+ # ---- Multi-block path for large vocabs (> 32768) ----
+ BLOCK_VOCAB, num_warps = _multi_config(V)
+
+ # Choose blocks_per_row as a power-of-2 dividing V
+ target = max(1, min(64 // max(B, 1), V // 4096))
blocks_per_row = triton.next_power_of_2(target)
- # Walk down to a divisor of V
while V % blocks_per_row != 0 and blocks_per_row > 1:
blocks_per_row //= 2
elems_per_block = V // blocks_per_row
@@ -239,7 +224,6 @@
_map_kernel[(total_blocks,)](
x, max_buf, sum_buf,
- V=V,
elems_per_block=elems_per_block,
BLOCK_VOCAB=BLOCK_VOCAB,
... (316 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<7521 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Check correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1 && cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=0.010 gbps=15.600 ms=0.067
shape=0 variant=compiled tflops=0.007 gbps=10.407 ms=0.101
shape=0 variant=sota tflops=0.035 gbps=56.399 ms=0.019
shape=0 variant=solution tflops=0.056 gbps=90.270 ms=0.012
shape=0 solution_peak_fraction=0.0501
shape=1 variant=eager tflops=0.041 gbps=66.098 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.674 ms=0.098
shape=1 variant=sota tflops=0.131 gbps=210.220 ms=0.020
shape=1 variant=solution tflops=0.139 gbps=221.780 ms=0.019
shape=1 solution_peak_fraction=0.1232
shape=2 variant=eager tflops=0.072 gbps=115.737 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.359 ms=0.219
shape=2 variant=sota tflops=0.109 gbps=174.821 ms=0.048
shape=2 variant=solution tflops=0.138 gbps=221.499 ms=0.038
shape=2 solution_peak_fraction=0.1231
shape=3 variant=eager tflops=0.074 gbps=117.844 ms=0.071
shape=3 variant=compiled tflops=0.015 gbps=23.300 ms=0.360
shape=3 variant=sota tflops=0.099 gbps=157.966 ms=0.053
shape=3 variant=solution tflops=0.132 gbps=211.406 ms=0.040
shape=3 solution_peak_fraction=0.1174
shape=4 variant=eager tflops=0.072 gbps=115.788 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.782 ms=0.228
shape=4 variant=sota tflops=0.101 gbps=161.022 ms=0.052
shape=4 variant=solution tflops=0.134 gbps=214.872 ms=0.039
shape=4 solution_peak_fraction=0.1194
peak_fraction: 0.1013
RESULT: OK
{
"description": "Final correctness check and verify framework.txt",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python problems/04_kahan_softmax/check.py 2>&1 && cat problems/04_kahan_softmax/framework.txt",
"timeout": 300000
}PASS
cat: problems/04_kahan_softmax/framework.txt: No such file or directory
{
"description": "Run check from problem directory for tight tolerance",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Check framework.txt was created",
"command": "ls -la /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/framework.txt && cat /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/framework.txt"
}-rw-rw-r-- 1 infatoshi infatoshi 7 Apr 27 10:25 /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/framework.txt
triton
{
"description": "Final benchmark from correct directory",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"timeout": 600000
}shape=0 variant=eager tflops=0.010 gbps=15.482 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.269 ms=0.102
shape=0 variant=sota tflops=0.033 gbps=53.152 ms=0.020
shape=0 variant=solution tflops=0.057 gbps=90.519 ms=0.012
shape=0 solution_peak_fraction=0.0503
shape=1 variant=eager tflops=0.041 gbps=65.145 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.411 ms=0.099
shape=1 variant=sota tflops=0.126 gbps=201.960 ms=0.021
shape=1 variant=solution tflops=0.139 gbps=222.344 ms=0.019
shape=1 solution_peak_fraction=0.1235
shape=2 variant=eager tflops=0.072 gbps=115.177 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.471 ms=0.218
shape=2 variant=sota tflops=0.107 gbps=171.898 ms=0.049
shape=2 variant=solution tflops=0.138 gbps=220.567 ms=0.038
shape=2 solution_peak_fraction=0.1225
shape=3 variant=eager tflops=0.072 gbps=115.026 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.362 ms=0.359
shape=3 variant=sota tflops=0.099 gbps=158.587 ms=0.053
shape=3 variant=solution tflops=0.132 gbps=211.066 ms=0.040
shape=3 solution_peak_fraction=0.1173
shape=4 variant=eager tflops=0.073 gbps=116.147 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.612 ms=0.229
shape=4 variant=sota tflops=0.099 gbps=158.875 ms=0.053
shape=4 variant=solution tflops=0.138 gbps=220.939 ms=0.038
shape=4 solution_peak_fraction=0.1227
peak_fraction: 0.1018
RESULT: OK
{
"description": "Run ruff lint check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/04_kahan_softmax/solution.py --fix 2>&1"
}Found 1 error (1 fixed, 0 remaining).
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Numerically tight softmax with Kahan compensated summation.
2:
3: Single-block path for smaller vocabs (V ≤ 32768) where one-kernel-launch
4: simplicity wins. Multi-block map-reduce for large vocabs where parallelism
5: across blocks is needed to saturate GPU bandwidth.
6:
7: Map: each block computes local (max, Kahan-sum-of-exp) for its chunk.
8: Reduce: GPU-side Kahan combination of per-block results (num_warps=1).
9: Norm: each block normalizes its chunk using global (max, sum).
10: """
11:
12: import torch
13: import torch.nn as nn
14: import triton
15: import triton.language as tl
16:
17: # ---------------------------------------------------------------------------
18: # Single-block kernel
19: # ---------------------------------------------------------------------------
20:
21: @triton.jit
22: def _single_kernel(x_ptr, y_ptr, V: tl.constexpr, BLOCK_VOCAB: tl.constexpr):
23: """Online softmax + Kahan sum, then normalize. One block per row."""
24: pid = tl.program_id(0)
25: offs = tl.arange(0, BLOCK_VOCAB)
26:
27: m = float('-inf')
28: s = 0.0
29: c = 0.0
30:
31: for start in range(0, V, BLOCK_VOCAB):
32: mask = (start + offs) < V
33: addr = pid * V + start + offs
34: x = tl.load(x_ptr + addr, mask=mask, other=float('-inf'))
35:
36: m_local = tl.max(tl.where(mask, x, float('-in
... (7392 more chars){
"todos": [
{
"priority": "high",
"content": "Write Triton softmax kernel with Kahan compensated summation in solution.py",
"status": "completed"
},
{
"priority": "high",
"content": "Verify correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Optimize: multi-block map-reduce for better GPU utilization",
"status": "completed"
},
{
"priority": "high",
"content": "Benchmark with python benchmark.py",
"status": "completed"
},
{
"priority": "m
... (97 more chars)content: 'Write Triton softmax kernel with Kahan compensated summation in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Optimize: multi-block map-reduce for better GPU utilization'
status: 'completed'
priority: 'high'
content: 'Benchmark with python benchmark.py'
status: 'completed'
priority: 'high'
content: 'Profile and iterate if needed'
status: 'completed'
priority: 'medium'