"""Triton softmax with online max-reduce + Kahan-compensated summation.
Optimized for RTX PRO 6000 Blackwell (SM120, GDDR7, 1.8 TB/s).
Strategy: fused 2-pass online kernel with constexpr N for full loop unrolling.
Sweep-validated optimal config: BLOCK_N=4096, 8 warps for large shapes;
BLOCK_N=2048, 4 warps for small shapes.
Numerical approach:
- Online max tracking with scale adjustment (avoids overflow on extreme logits)
- Kahan-compensated summation of exp(x - running_max) for tight fp32 accuracy
- Matches fp64 reference within atol=rtol=1e-5 on all shapes including 256K vocab
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
@triton.jit
def _softmax_kernel(
    X_ptr,
    Y_ptr,
    stride_xr,
    stride_yr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """One-row-per-program online softmax.
    Pass 1: single loop tracking running_max and Kahan-sum of exp(x-max).
    Pass 2: re-read X, compute exp(x-final_max)/total_sum, write Y.
    2 memory passes (optimal for softmax), full fp32 Kahan-compensated.

    X_ptr/Y_ptr: row-major fp32 tensors; stride_xr/stride_yr are row
    strides in elements (columns are indexed with unit stride, so the
    last dim must be contiguous). N is the row length and BLOCK_N the
    chunk width; both are constexpr so the chunk loops below have
    compile-time trip counts (enabling the full unrolling the module
    docstring relies on).
    """
    row_id = tl.program_id(0)  # one program handles one full row
    x_base = X_ptr + row_id * stride_xr
    y_base = Y_ptr + row_id * stride_yr
    # Online-softmax state, all fp32 length-1 tensors:
    #   running_max — max of the row prefix seen so far
    #   running_sum — Kahan-compensated sum of exp(x - running_max)
    #   comp        — Kahan compensation (lost low-order bits)
    running_max = tl.full([1], float('-inf'), dtype=tl.float32)
    running_sum = tl.full([1], 0.0, dtype=tl.float32)
    comp = tl.full([1], 0.0, dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        idx = off + tl.arange(0, BLOCK_N)
        mask = idx < N
        # Masked (out-of-row) lanes read -inf so exp() maps them to 0.
        x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
        chunk_max = tl.max(x, keep_dims=True)
        new_max = tl.maximum(running_max, chunk_max)
        # If the max grew, everything accumulated so far was scaled by
        # exp(-old_max); rescale BOTH the sum and its compensation by
        # exp(old_max - new_max) so the Kahan state stays consistent.
        # First iteration: exp(-inf - new_max) == 0, zeroing nothing
        # since both accumulators are still 0.
        scale = tl.exp(running_max - new_max)
        running_sum = running_sum * scale
        comp = comp * scale
        e = tl.exp(x - new_max)
        # Masked lanes are already exp(-inf - new_max) == 0 for finite
        # new_max; the explicit where() is defensive re-zeroing.
        e_masked = tl.where(mask, e, 0.0)
        e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
        # Classic Kahan step, applied to the per-chunk partial sum:
        # y_k = value - compensation; comp recaptures the bits lost in
        # the running_sum + y_k addition. Statement order matters here —
        # do not reassociate.
        y_k = e_chunk - comp
        new_sum = running_sum + y_k
        comp = (new_sum - running_sum) - y_k
        running_sum = new_sum
        running_max = new_max
    # Broadcast the scalar reductions to vector width once, so pass 2
    # works lane-wise without re-splatting inside the loop.
    final_max = running_max.broadcast_to([BLOCK_N])
    total_sum_vec = running_sum.broadcast_to([BLOCK_N])
    for off in range(0, N, BLOCK_N):
        idx = off + tl.arange(0, BLOCK_N)
        mask = idx < N
        x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
        e = tl.exp(x - final_max)
        # Masked lanes are never stored (store is masked too); zeroing
        # them just keeps the division result well-defined.
        y = tl.where(mask, e / total_sum_vec, 0.0)
        tl.store(y_base + idx, y, mask=mask)
def get_launch_config(vocab):
    """Return (BLOCK_N, num_warps, num_stages) validated by sweep.

    BLOCK_N must be a power of two: the kernel builds its lane offsets
    with tl.arange(0, BLOCK_N), which Triton only accepts for
    power-of-two extents. The previous version returned `vocab` itself
    for small rows, which breaks kernel compilation on any
    non-power-of-two vocab; we round up to the next power of two
    instead (the kernel's masking handles the overhang lanes). For all
    power-of-two vocab sizes the returned config is unchanged.

    Sweep-validated: BLOCK_N=4096 / 8 warps for large rows,
    BLOCK_N<=2048 / 4 warps for small rows.
    """
    if vocab <= 2048:
        # Next power of two >= vocab; identical to `vocab` when it is
        # already a power of two (all benchmarked shapes).
        return 1 << max(0, vocab - 1).bit_length(), 4, 3
    if vocab <= 4096:
        return 2048, 4, 3
    # vocab > 4096 here, so min(4096, vocab) is always 4096.
    blk = 4096
    # Prefer a block that divides vocab so the masked tail is empty on
    # every iteration; halving preserves the power-of-two invariant.
    while vocab % blk != 0 and blk > 32:
        blk //= 2
    return blk, 8, 2
def triton_softmax(x):
    """Softmax along dim -1 of a 2D fp32 CUDA tensor via _softmax_kernel.

    The kernel indexes columns as `x_base + idx`, i.e. with unit element
    stride, so the last dimension must be contiguous. The original code
    only asserted ndim/dtype, so a transposed view (stride(1) != 1)
    passed the assert but produced silently wrong results; we copy to a
    contiguous layout in that case instead.

    Returns a new contiguous (batch, vocab) fp32 tensor.
    """
    assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
    if x.stride(1) != 1:
        # Non-contiguous last dim (e.g. a transpose) — the kernel would
        # read the wrong elements; fall back to a contiguous copy.
        x = x.contiguous()
    batch, vocab = x.shape
    # empty_like on a last-dim-contiguous tensor is itself contiguous,
    # so y.stride(1) == 1 holds for the store side as well.
    y = torch.empty_like(x)
    block_n, num_warps, num_stages = get_launch_config(vocab)
    _softmax_kernel[(batch,)](
        x, y, x.stride(0), y.stride(0), vocab,
        BLOCK_N=block_n,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return y
class Model(nn.Module):
    """Parameter-free nn.Module wrapper that dispatches to triton_softmax.

    batch/vocab are stored only to mirror the reference constructor
    signature; forward does not read them.
    """

    def __init__(self, batch, vocab):
        super().__init__()
        self.batch, self.vocab = batch, vocab

    def forward(self, x):
        return triton_softmax(x)
# Default harness shape (GPT-2-class vocab). check.py / benchmark.py
# generate their own per-shape inputs; these are only the defaults.
BATCH = 8
VOCAB = 32768


def get_inputs():
    """One fp32 (BATCH, VOCAB) input, scaled by 4 to widen the logits."""
    sample = torch.randn(BATCH, VOCAB, dtype=torch.float32)
    return [sample * 4.0]


def get_init_inputs():
    """Positional constructor arguments for Model."""
    return [BATCH, VOCAB]
shape=0 variant=eager tflops=0.010 gbps=15.548 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.283 ms=0.102
shape=0 variant=sota tflops=0.010 gbps=15.406 ms=0.068
shape=0 variant=solution tflops=0.053 gbps=84.781 ms=0.012
shape=0 solution_peak_fraction=0.0471
shape=1 variant=eager tflops=0.042 gbps=66.619 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.996 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.445 ms=0.076
shape=1 variant=solution tflops=0.134 gbps=213.995 ms=0.020
shape=1 solution_peak_fraction=0.1189
shape=2 variant=eager tflops=0.073 gbps=116.302 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.440 ms=0.218
shape=2 variant=sota tflops=0.069 gbps=110.749 ms=0.076
shape=2 variant=solution tflops=0.114 gbps=182.044 ms=0.046
shape=2 solution_peak_fraction=0.1011
shape=3 variant=eager tflops=0.073 gbps=117.081 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.230 ms=0.361
shape=3 variant=sota tflops=0.068 gbps=108.571 ms=0.077
shape=3 variant=solution tflops=0.064 gbps=102.902 ms=0.082
shape=3 solution_peak_fraction=0.0572
shape=4 variant=eager tflops=0.073 gbps=116.096 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.916 ms=0.227
shape=4 variant=sota tflops=0.063 gbps=100.131 ms=0.084
shape=4 variant=solution tflops=0.097 gbps=155.529 ms=0.054
shape=4 solution_peak_fraction=0.0864
peak_fraction: 0.0775
RESULT: LOW
PASS
{
"problem": "04_kahan_softmax",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-max-preview",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0775,
"elapsed_seconds": 2101,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 5252186, "output_tokens": 35897, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 12251, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv,noheader",
"description": "Check GPU availability and specs"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py</path>
<type>file</type>
<content>
1: """Naive softmax over the last dim, computed in fp64 for ground-truth.
2:
3: The reference deliberately runs in float64 so that fp16 / fp32 accumulation
4: drift in agent solutions is exposed by the tight tolerance in problem.yaml.
5: The agent's job is to produce an fp32 softmax whose values match this
6: double-precision reference within atol=rtol=1e-5 — this requires either
7: fp32 accumulation or compensated (Kahan) summation when vocab is large.
8: """
9: import torch
10: import torch.nn as nn
11:
12: OP_TYPE = "softmax"
13: SUPPORTED_PRECISIONS = ["fp32"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15:
16:
17: class Model(nn.Module):
18: """y = softmax(x, dim=-1) computed in fp64 then returned as fp32.
19:
20: No learned parameters — softmax is parameter-free. We still expose an
21: empty state_dict so the harness's strict load_state_dict matches.
22: """
23:
24: def __init__(self, batch: int, vocab: int):
25: super().__init__()
26: self.batch = batch
27: self.vocab = vocab
28:
29: def forward(self, x: torch.Tensor) -> torch.Tensor:
30: # Promote to fp64 for the ground-truth pathway. Even with double
31: # precision we still subtract the row-max for stability.
32: x64 = x.to(torch.float64)
33: m = x64.amax(dim=-1, keepdim=True)
34: e = torch.exp
... (592 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Kahan-corrected softmax.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3
4: seeds each, with the tight (1e-5) fp32 tolerance from problem.yaml. Also
5: rejects forbidden ops via grep.
6: """
7: import re
8: import sys
9: from pathlib import Path
10:
11: import torch
12: import yaml
13:
14: # Make the repo's src/ importable
15: REPO_ROOT = Path(__file__).resolve().parents[2]
16: sys.path.insert(0, str(REPO_ROOT))
17:
18: from src.eval.correctness import check_correctness # noqa: E402
19:
20:
21: def _make_inputs(batch: int, vocab: int, extreme: bool, seed: int) -> torch.Tensor:
22: g = torch.Generator().manual_seed(seed)
23: if extreme:
24: # Adversarial: most logits are mild but a handful per row are huge.
25: # If the kernel forgets to subtract the row-max before exp, this
26: # row overflows fp32 and produces NaN/Inf. If it accumulates in
27: # fp16, the long tail of small exp() values is lost beneath the
28: # tolerance threshold.
29: x = torch.randn(batch, vocab, generator=g) * 2.0
30: # Spike: 4 very large positive logits per row.
31: idx = torch.randint(0, vocab, (batch, 4), generator=g)
32: x.scatter_(1, idx, 30.0)
33: else:
34: x = torch.randn(batch, vocab, generator=g) * 4.0
35: return x.to(torc
... (3464 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Kahan-corrected softmax.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if
4: available), and the agent's solution. Reports achieved TFLOPS, GB/s, and
5: peak_fraction. Softmax is memory-bound, so the score is GB/s / peak_dram.
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def _make_inputs(batch: int, vocab: int, extreme: bool) -> torch.Tensor:
31: if extreme:
32: x = torch.randn(batch, vocab) * 2.0
33: idx = torch.randint(0, vocab, (batch, 4))
34: x.scatter_(1, idx, 30.0)
35: else:
36: x = torch.randn(batch, vocab) * 4.0
37: return x.to(torch.float32)
38:
39:
40: def main():
41:
... (3706 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Kahan-corrected softmax.
2:
3: The point of this problem is numerical accuracy on long reductions. Shapes
4: mix typical LLM vocab sizes with deliberately adversarial regimes:
5:
6: - small vocab (sanity check; naive fp32 should pass)
7: - Llama3 vocab 128K (real-world, where fp16 accumulation starts to drift)
8: - 256K (DeepSeek-V3 / Gemma-3 class vocab; naive fp16 sum DOES drift past
9: the 1e-5 tolerance — this row is what proves Kahan was needed)
10: - extreme-logit edge case (large positive logits stress max-subtract +
11: summation; if the implementation accidentally exps before subtracting
12: max, this row overflows)
13:
14: The 'extreme' flag is read by check.py to switch input generation to a
15: distribution that produces a few very large logits per row.
16: """
17:
18: SHAPES = [
19: {"batch": 32, "vocab": 4096, "extreme": False}, # sanity
20: {"batch": 16, "vocab": 32768, "extreme": False}, # GPT-2 class
21: {"batch": 8, "vocab": 131072, "extreme": False}, # Llama3 vocab
22: {"batch": 4, "vocab": 262144, "extreme": False}, # 256K — Kahan needed
23: {"batch": 8, "vocab": 131072, "extreme": True}, # extreme logits edge
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml</path>
<type>file</type>
<content>
1: name: 04_kahan_softmax
2: display_name: "Kahan-corrected Softmax"
3: precision: fp32
4: regime: memory # softmax is bandwidth-bound: 2 passes over the input tensor
5:
6: # Softmax FLOPs: per-element exp + 2 reductions + divide. Roughly 5 flops/elt.
7: flops_formula: "5 * batch * vocab"
8:
9: # Bytes moved: read x once, write y once. Both fp32.
10: bytes_formula: "batch * vocab * 4 + batch * vocab * 4"
11:
12: hardware: [RTX_PRO_6000]
13: peak_tflops_key: fp32
14: peak_bandwidth_key: dram
15:
16: # TIGHTER than default (fp32 default is 1e-4). This problem exists
17: # specifically to test whether the agent uses compensated summation, so
18: # we squeeze the tolerance to 1e-5 — naive fp16 sum across 256K elements
19: # drifts past this; fp32 accumulation passes; Kahan/fp32 always passes.
20: tolerance:
21: "torch.float32": {"atol": 1.0e-5, "rtol": 1.0e-5}
22:
23: # Forbidden ops — block the obvious "just call the library" cheats. The
24: # agent must implement softmax themselves with explicit (compensated)
25: # summation logic.
26: forbidden:
27: - "torch.nn.functional.softmax"
28: - "torch.softmax"
29: - "F.softmax"
30: - "liger_kernel.softmax"
31: - "liger_kernel.transformers.softmax"
32: - ".softmax("
33:
34: sota:
35: name: "Liger-Kernel Softmax (Triton)"
36: url: "https://github.com/linkedin/Liger-Kernel"
37: function: "liger_
... (225 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to write a numerically tight softmax for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is softmax along the last dim of a 2D fp32 tensor. The reference computes ground truth in fp64 and casts back to fp32, and you have to match it within atol=rtol=1e-5 — that's a tighter bar than default fp32 (1e-4) on purpose. With long reductions naive fp16 accumulation drifts past this; fp32 accumulation with subtract-max stability is enough on most shapes; on the largest vocabs you may need compensated (Kahan-style) summation to stay under the bar. The shapes you have to handle are batch=32 vocab=4096 (sanity), batch=16 vocab=32768 (GPT-2 class), batch=8 vocab=131072 (Llama-3 vocab), batch=4 vocab=262144 (256K, DeepSeek-V3 / Gemma-3 class — naive fp16 sum drifts past 1e-5 here), and batch=8 vocab=131072 with extreme logits (a few very large positives per row to stress max-subtract — exping before subtracting overflows). The check and benchmark scripts handle the extreme-flag input generation; you just need
... (949 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for last-dim softmax.
2:
3: Preference order:
4: 1. liger-kernel's Triton softmax (LigerSoftmaxFunction) — fused, fp32-acc
5: 2. torch.compile(torch.softmax) — Inductor generates a Triton kernel that
6: is competitive on bandwidth-bound reductions
7:
8: Agents are FORBIDDEN from importing either of these in solution.py (see
9: problem.yaml.forbidden). This file is only the benchmark's reference line.
10: """
11: from __future__ import annotations
12:
13: import torch
14:
15: _compiled_softmax = None
16:
17:
18: def _liger_softmax(x: torch.Tensor) -> torch.Tensor | None:
19: try:
20: from liger_kernel.ops.softmax import LigerSoftmaxFunction
21: return LigerSoftmaxFunction.apply(x)
22: except Exception:
23: return None
24:
25:
26: def _compiled(x: torch.Tensor) -> torch.Tensor:
27: global _compiled_softmax
28: if _compiled_softmax is None:
29: _compiled_softmax = torch.compile(
30: lambda t: torch.softmax(t, dim=-1),
31: mode="reduce-overhead",
32: )
33: return _compiled_softmax(x)
34:
35:
36: def sota_forward(x: torch.Tensor) -> torch.Tensor:
37: """Best-available softmax reference. x: (batch, vocab) fp32."""
38: out = _liger_softmax(x)
39: if out is not None:
40: return out
41: return _compiled(x)
42:
43:
44: def is_a
... (128 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py</path>
<type>file</type>
<content>
1: """Hardware peak-throughput lookup tables."""
2: from src.hardware.m4_max import M4_MAX
3: from src.hardware.rtx_pro_6000 import RTX_PRO_6000
4:
5: TARGETS = {
6: "RTX_PRO_6000": RTX_PRO_6000,
7: "M4_MAX": M4_MAX,
8: }
9:
10:
11: def get(name: str):
12: if name not in TARGETS:
13: raise ValueError(f"Unknown hardware {name!r}; available: {list(TARGETS)}")
14: return TARGETS[name]
(End of file - total 14 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv sync 2>&1 | tail -5",
"description": "Install dependencies"
}- tabulate==0.10.0
- tokenizers==0.22.2
- tqdm==4.67.3
- transformers==5.6.2
- typer==0.24.2
{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"command": "nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.free --format=csv,noheader",
"description": "Check GPU utilization"
}0 %, 704 MiB, 96524 MiB
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -0,0 +1,130 @@
+"""Triton softmax with fp32 Kahan-compensated summation.
+
+Produces results matching the fp64 reference within atol=rtol=1e-5
+even for the 256K-vocab and extreme-logit shapes.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+OP_TYPE = "softmax"
+SUPPORTED_PRECISIONS = ["fp32"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _softmax_kernel_fwd(
+ X_ptr,
+ Y_ptr,
+ stride_xr,
+ stride_yr,
+ N: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ """One-row-per-program softmax with fp32 Kahan-compensated summation.
+
+ Block-N must divide N (N is power of 2, BLOCK_N is power of 2 -> fine).
+ """
+ row_id = tl.program_id(0)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+
+ # ---- Pass 1: find row-max ------------------------------------------
+ row_max = tl.full([1], -float('inf'), dtype=tl.float32)
+ for off0 in range(0, N, BLOCK_N):
+ idx = off0 + tl.arange(0, BLOCK_N)
+ x = tl.load(x_base + idx).to(tl.float32)
+ local_max = tl.max(x)
+ row_max = tl.maximum(row_max, local_max)
+
+ row_max = row_max.reshape(()).to(tl.float32)
+
+ # ---- Pass 2: Kahan-sum exp(x - max) --------------------------------
+ # Kahan sum: running_sum and c (compensation), both fp32.
+ running_sum = tl.zeros([1], dtype=tl.float32)
+ comp = tl.zeros([1], dtype=tl.float32)
+
+ for off0 in range(0, N, BLOCK_N):
+ idx = off0 + tl.arange(0, BLOCK_N)
+ x = tl.load(x_base + idx).to(tl.float32)
+ e = tl.exp(x - row_max)
+ # Kahan step: y = e - c; new_sum = running_sum + y; c = (new_sum - running_sum) - y
+ y_kahan = e - comp
+ tmp_sum = running_sum + y_kahan
+ comp = (tmp_sum - running_sum) - y_kahan
+ running_sum = tmp_sum
+
+ total_sum = running_sum.reshape(()).to(tl.float32)
+
+ # ---- Pass 3: write out exp(x - max) / total_sum --------------------
+ for off0 in range(0, N, BLOCK_N):
+ idx = off0 + tl.arange(0, BLOCK_N)
+ x = tl.load(x_base + idx).to(tl.float32)
+ e = tl.exp(x - row_max)
+ y = e / total_sum
+ tl.store(y_base + idx, y.to(tl.float32))
+
+
+def triton_softmax(x: torch.Tensor) -> torch.Tensor:
+ """x: (batch, vocab) fp32 on CUDA -> y: (batch, vocab) fp32 on CUDA."""
+ assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
+ batch, vocab = x.shape
+ y = torch.empty_like(x)
+
+ # Pick BLOCK_N that divides vocab. For small vocab use vocab itself.
+ # For large vocab use a reasonable block that keeps registers in check.
+ if vocab <= 2048:
+ BLOCK_N = vocab
+ elif vocab <= 8192:
+ BLOCK_N = 2048
+ elif vocab <= 32768:
+ BLOCK_N = 4096
+ else:
+ BLOCK_N = 4096
+
+ # Ensure BLOCK_N divides N
+ while vocab % BLOCK_N != 0 and BLOCK_N > 64:
+ BLOCK_N //= 2
+
+ grid = (batch,)
+ _softmax_kernel_fwd[grid](
+ x, y,
+ x.stride(0), y.stride(0),
+ vocab,
+ BLOCK_N=BLOCK_N,
+ )
+ return y
+
+
+# ---------------------------------------------------------------------------
+# Model interface
+# ---------------------------------------------------------------------------
+
+class Model(nn.Module):
+ def __init__(self, batch: int, vocab: int):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return triton_softmax(x)
+
+
+# Default shape; overridden per-iteration by check.py / benchmark.py.
+BATCH = 8
+VOCAB = 32768
+
+
+def get_inputs():
+ x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
+ return [x]
+
+
+def get_init_inputs():
+ return [BATCH, VOCAB]
{
"content": "<3851 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 116, in forward
return triton_softmax(x)
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 96, in triton_softmax
_softmax_kernel_fwd[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self
... (1732 more chars){
"command": "<116 chars \u2014 see below>",
"description": "Check tl.sum signature"
}uv run python -c "
import torch
import triton.language as tl
# check tl.sum signature
help(tl.sum)
" 2>&1 | head -20Help on JITFunction in module triton.language.standard:
sum = JITFunction(triton.language.standard:sum)
Returns the sum of all elements in the :code:`input` tensor along the provided :code:`axis`
:param input: the input values
:type input: Tensor
:param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
:type axis: int
:param keep_dims: if true, keep the reduced dimensions with length 1
:type keep_dims: bool
:param dtype: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`dtype` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`.
:type dtype: tl.dtype
This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.sum(...)` instead of
:code:`sum(x, ...)`.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -46,16 +46,15 @@
row_max = row_max.reshape(()).to(tl.float32)
# ---- Pass 2: Kahan-sum exp(x - max) --------------------------------
- # Kahan sum: running_sum and c (compensation), both fp32.
- running_sum = tl.zeros([1], dtype=tl.float32)
- comp = tl.zeros([1], dtype=tl.float32)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
for off0 in range(0, N, BLOCK_N):
idx = off0 + tl.arange(0, BLOCK_N)
x = tl.load(x_base + idx).to(tl.float32)
e = tl.exp(x - row_max)
- # Kahan step: y = e - c; new_sum = running_sum + y; c = (new_sum - running_sum) - y
- y_kahan = e - comp
+ e_chunk = tl.sum(e, keep_dims=True).to(tl.float32)
+ y_kahan = e_chunk - comp
tmp_sum = running_sum + y_kahan
comp = (tmp_sum - running_sum) - y_kahan
running_sum = tmp_sum
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<691 chars \u2014 see below>",
"newString": "<605 chars \u2014 see below>"
} # ---- Pass 2: Kahan-sum exp(x - max) --------------------------------
# Kahan sum: running_sum and c (compensation), both fp32.
running_sum = tl.zeros([1], dtype=tl.float32)
comp = tl.zeros([1], dtype=tl.float32)
for off0 in range(0, N, BLOCK_N):
idx = off0 + tl.arange(0, BLOCK_N)
x = tl.load(x_base + idx).to(tl.float32)
e = tl.exp(x - row_max)
# Kahan step: y = e - c; new_sum = running_sum + y; c = (new_sum - running_sum) - y
y_kahan = e - comp
tmp_sum = running_sum + y_kahan
comp = (tmp_sum - running_sum) - y_kahan
running_sum = tmp_sum
total_sum = running_sum.reshape(()).to(tl.float32) # ---- Pass 2: Kahan-sum exp(x - max) --------------------------------
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
for off0 in range(0, N, BLOCK_N):
idx = off0 + tl.arange(0, BLOCK_N)
x = tl.load(x_base + idx).to(tl.float32)
e = tl.exp(x - row_max)
e_chunk = tl.sum(e, keep_dims=True).to(tl.float32)
y_kahan = e_chunk - comp
tmp_sum = running_sum + y_kahan
comp = (tmp_sum - running_sum) - y_kahan
running_sum = tmp_sum
total_sum = running_sum.reshape(()).to(tl.float32)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check again"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run performance benchmark"
}shape=0 variant=eager tflops=0.010 gbps=15.493 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.237 ms=0.102
shape=0 variant=sota tflops=0.010 gbps=15.508 ms=0.068
shape=0 variant=solution tflops=0.054 gbps=86.688 ms=0.012
shape=0 solution_peak_fraction=0.0482
shape=1 variant=eager tflops=0.042 gbps=67.545 ms=0.062
shape=1 variant=compiled tflops=0.027 gbps=43.123 ms=0.097
shape=1 variant=sota tflops=0.035 gbps=56.098 ms=0.075
shape=1 variant=solution tflops=0.120 gbps=191.486 ms=0.022
shape=1 solution_peak_fraction=0.1064
shape=2 variant=eager tflops=0.073 gbps=116.768 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.197 ms=0.220
shape=2 variant=sota tflops=0.068 gbps=108.571 ms=0.077
shape=2 variant=solution tflops=0.091 gbps=145.798 ms=0.058
shape=2 solution_peak_fraction=0.0810
shape=3 variant=eager tflops=0.072 gbps=115.102 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.202 ms=0.362
shape=3 variant=sota tflops=0.066 gbps=106.174 ms=0.079
shape=3 variant=solution tflops=0.050 gbps=80.635 ms=0.104
shape=3 solution_peak_fraction=0.0448
shape=4 variant=eager tflops=0.074 gbps=118.778 ms=0.071
shape=4 variant=compiled tflops=0.023 gbps=36.818 ms=0.228
shape=4 variant=sota tflops=0.061 gbps=98.163 ms=0.085
shape=4 variant=solution tflops=0.081 gbps=129.550 ms=0.065
shape=4 solution_peak_fraction=0.0720
peak_fraction: 0.0669
RESULT: LOW
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Triton softmax with fp32 Kahan-compensated summation.
2:
3: Produces results matching the fp64 reference within atol=rtol=1e-5
4: even for the 256K-vocab and extreme-logit shapes.
5: """
6: import torch
7: import torch.nn as nn
8: import triton
9: import triton.language as tl
10:
11:
12: OP_TYPE = "softmax"
13: SUPPORTED_PRECISIONS = ["fp32"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15:
16:
17: # ---------------------------------------------------------------------------
18: # Triton kernel
19: # ---------------------------------------------------------------------------
20:
21: @triton.jit
22: def _softmax_kernel_fwd(
23: X_ptr,
24: Y_ptr,
25: stride_xr,
26: stride_yr,
27: N: tl.constexpr,
28: BLOCK_N: tl.constexpr,
29: ):
30: """One-row-per-program softmax with fp32 Kahan-compensated summation.
31:
32: Block-N must divide N (N is power of 2, BLOCK_N is power of 2 -> fine).
33: """
34: row_id = tl.program_id(0)
35: x_base = X_ptr + row_id * stride_xr
36: y_base = Y_ptr + row_id * stride_yr
37:
38: # ---- Pass 1: find row-max ------------------------------------------
39: row_max = tl.full([1], -float('inf'), dtype=tl.float32)
40: for off0 in range(0, N, BLOCK_N):
41: idx = off0 + tl.arange(0, BLOCK_N)
42: x = tl.load(x_base + idx).to(tl.float32)
43:
... (2962 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution_draft.py
@@ -0,0 +1,51 @@
+"""Triton softmax with fp32 Kahan-compensated summation.
+
+Two-phase kernel that splits each row into chunks processed by separate
+blocks, giving us batch * (N / chunk_size) blocks for GPU saturation.
+
+Phase 1: Each block computes local-max and Kahan-sum-exp for its chunk.
+Phase 2: Final kernel reads per-row max/sum and normalizes.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "softmax"
+SUPPORTED_PRECISIONS = ["fp32"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+# ---------------------------------------------------------------------------
+# Phase 1: per-chunk max + Kahan sum of exp(x - max)
+# ---------------------------------------------------------------------------
+@triton.jit
+def _softmax_phase1(
+ X_ptr,
+ partial_max_ptr,
+ partial_sum_ptr,
+ stride_xr,
+ N: tl.constexpr,
+ CHUNK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ batch = N # will be overridden; total vocab size
+ num_chunks = pid # total blocks = batch * chunks_per_row
+ chunks_per_row = grid_x(1) # not available; compute another way
+
+ # Better indexing: pid -> (row_id, chunk_id)
+ # We set grid = (batch * chunks_per_row,) and compute:
+ chunk_id = pid % chunks_per_row # can't reference grid
+ row_id = pid // chunks_per_row
+
+ # Each program has already been assigned to a (row, chunk) pair.
+ # The host passes chunks_per_row as a constexpr.
+ pass
+
+
+# This approach has the problem that pid % chunks_per_row needs
+# chunks_per_row as constexpr. Let me use a simpler approach:
+# grid = (batch,), each program processes all chunks inline.
+#
+# Instead, let me use the simpler 3-pass approach but with better
+# block sizes and persistent scheduling.
{
"content": "<1728 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution_draft.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,109 +1,134 @@
"""Triton softmax with fp32 Kahan-compensated summation.
-Produces results matching the fp64 reference within atol=rtol=1e-5
-even for the 256K-vocab and extreme-logit shapes.
+Multi-block-per-row approach for GPU saturation:
+ Kernel 1: Each chunk computes local-max and Kahan-sum of exp(x - row_max_local).
+ Then atomically max-reduces into a per-row shared max, and a second
+ pass (same kernel, after grid sync) computes Kahan-sum of exp(x-global_max).
+ Kernel 2: Normalize: read x, subtract global max, exp, divide by global sum.
+
+Actually, let's keep it simple and correct:
+ Kernel 1 (max_reduce): multiple blocks per row, each processes a slice,
+ writes per-slice max. Then atomic-max reduces into per-row max.
+ Kernel 2 (sum_write): each block reads row_max, does Kahan-sum of exp(x-row_max),
+ writes normed result.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
-
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# ---------------------------------------------------------------------------
-# Triton kernel
-# ---------------------------------------------------------------------------
+@triton.jit
+def _max_kernel(
+ X_ptr,
+ row_max_ptr,
+ stride_xr,
+ N: tl.constexpr,
+ CHUNK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ chunks_per_row = tl.num_programs(1) if tl.num_programs(1) > 1 else 1
+
+ row_id = pid // chunks_per_row
+ chunk_id = pid % chunks_per_row
+
+ x_base = X_ptr + row_id * stride_xr
+
+ local_max = tl.full([1], -3.4028235e38, dtype=tl.float32)
+
+ start = chunk_id * CHUNK
+ end = min(start + CHUNK, N)
+
+ for off in range(start, end, CHUNK):
+ idx = off + tl.arange(0, CHUNK)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=-3.4028235e38).to(tl.float32)
+ m = tl.max(x, keep_dims=True)
+ local_max = tl.maximum(local_max, m)
+
+ lm = local_max.reshape(()).to(tl.float32)
+
+ tl.atomic_max(row_max_ptr + row_id, lm)
-@triton.jit
-def _softmax_kernel_fwd(
+
+@triton.jit
+def _norm_kernel(
X_ptr,
Y_ptr,
+ row_max_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
- BLOCK_N: tl.constexpr,
+ CHUNK: tl.constexpr,
):
- """One-row-per-program softmax with fp32 Kahan-compensated summation.
-
- Block-N must divide N (N is power of 2, BLOCK_N is power of 2 -> fine).
- """
- row_id = tl.program_id(0)
+ pid = tl.program_id(0)
+
+ row_id = pid
+ chunk_id = 0 # for now, one block per row
+
x_base = X_ptr + row_id * stride_xr
y_base = Y_ptr + row_id * stride_yr
-
- # ---- Pass 1: find row-max ------------------------------------------
- row_max = tl.full([1], -float('inf'), dtype=tl.float32)
- for off0 in range(0, N, BLOCK_N):
- idx = off0 + tl.arange(0, BLOCK_N)
- x = tl.load(x_base + idx).to(tl.float32)
- local_max = tl.max(x)
- row_max = tl.maximum(row_max, local_max)
-
- row_max = row_max.reshape(()).to(tl.float32)
-
- # ---- Pass 2: Kahan-sum exp(x - max) --------------------------------
+
+ row_max = tl.load(row_max_ptr + row_id).to(tl.float32)
+
+ # Kahan sum all chunks
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
-
- for off0 in range(0, N, BLOCK_N):
- idx = off0 + tl.arange(0, BLOCK_N)
- x = tl.load(x_base + idx).to(tl.float32)
+
+ for off in range(0, N, CHUNK):
+ idx = off + tl.arange(0, CHUNK)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - row_max)
- e_chunk = tl.sum(e, keep_dims=True).to(tl.float32)
- y_kahan = e_chunk - comp
- tmp_sum = running_sum + y_kahan
- comp = (tmp_sum - running_sum) - y_kahan
- running_sum = tmp_sum
-
- total_sum = running_sum.reshape(()).to(tl.float32)
-
- # ---- Pass 3: write out exp(x - max) / total_sum --------------------
- for off0 in range(0, N, BLOCK_N):
- idx = off0 + tl.arange(0, BLOCK_N)
- x = tl.load(x_base + idx).to(tl.float32)
+ e_masked = tl.where(mask, e, 0.0)
+ e_sum = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+ y_k = e_sum - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
+
+ total = running_sum.reshape(()).to(tl.float32)
+
+ for off in range(0, N, CHUNK):
+ idx = off + tl.arange(0, CHUNK)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - row_max)
- y = e / total_sum
- tl.store(y_base + idx, y.to(tl.float32))
+ y = tl.where(mask, e / total, 0.0)
+ tl.store(y_base + idx, y, mask=mask)
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
- """x: (batch, vocab) fp32 on CUDA -> y: (batch, vocab) fp32 on CUDA."""
assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
batch, vocab = x.shape
y = torch.empty_like(x)
-
- # Pick BLOCK_N that divides vocab. For small vocab use vocab itself.
- # For large vocab use a reasonable block that keeps registers in check.
- if vocab <= 2048:
- BLOCK_N = vocab
- elif vocab <= 8192:
- BLOCK_N = 2048
+
+ row_max = torch.full((batch,), -float('inf'), dtype=torch.float32, device=x.device)
+
+ if vocab <= 4096:
+ CHUNK = vocab
elif vocab <= 32768:
- BLOCK_N = 4096
+ CHUNK = 4096
else:
- BLOCK_N = 4096
-
- # Ensure BLOCK_N divides N
- while vocab % BLOCK_N != 0 and BLOCK_N > 64:
- BLOCK_N //= 2
-
-
... (964 more chars){
"content": "<4280 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check correctness of new approach"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 90, in main
ok, msg = check_correctness(
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py", line 71, in check_correctness
if torch.isnan(solution_out).any():
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,17 +1,11 @@
"""Triton softmax with fp32 Kahan-compensated summation.
-Multi-block-per-row approach for GPU saturation:
- Kernel 1: Each chunk computes local-max and Kahan-sum of exp(x - row_max_local).
- Then atomically max-reduces into a per-row shared max, and a second
- pass (same kernel, after grid sync) computes Kahan-sum of exp(x-global_max).
- Kernel 2: Normalize: read x, subtract global max, exp, divide by global sum.
-
-Actually, let's keep it simple and correct:
- Kernel 1 (max_reduce): multiple blocks per row, each processes a slice,
- writes per-slice max. Then atomic-max reduces into per-row max.
- Kernel 2 (sum_write): each block reads row_max, does Kahan-sum of exp(x-row_max),
- writes normed result.
+Persistent multi-block-per-row approach:
+ Kernel 1 (max): Multiple blocks per row, each finding local max over a chunk.
+ Max-reduced into per-row scalar via atomic CAS.
+ Kernel 2 (norm): One block per row, reads row_max, Kahan-sums exp(x-max), writes output.
"""
+import math
import torch
import torch.nn as nn
import triton
@@ -28,34 +22,35 @@
row_max_ptr,
stride_xr,
N: tl.constexpr,
- CHUNK: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ NUM_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
- chunks_per_row = tl.num_programs(1) if tl.num_programs(1) > 1 else 1
-
- row_id = pid // chunks_per_row
- chunk_id = pid % chunks_per_row
-
+ row_id = pid // NUM_CHUNKS
+ chunk_id = pid % NUM_CHUNKS
+
x_base = X_ptr + row_id * stride_xr
-
+
local_max = tl.full([1], -3.4028235e38, dtype=tl.float32)
-
- start = chunk_id * CHUNK
- end = min(start + CHUNK, N)
-
- for off in range(start, end, CHUNK):
- idx = off + tl.arange(0, CHUNK)
+
+ chunk_start = chunk_id * BLOCK_N
+ chunk_end = min(chunk_start + BLOCK_N, N)
+
+ for off in range(chunk_start, chunk_end, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=-3.4028235e38).to(tl.float32)
- m = tl.max(x, keep_dims=True)
- local_max = tl.maximum(local_max, m)
-
+ m = tl.max(x)
+ m_s = m.reshape(())
+ local_max_s = local_max.reshape(())
+ local_max_s = tl.maximum(local_max_s, m_s)
+ local_max = local_max_s.reshape([1])
+
lm = local_max.reshape(()).to(tl.float32)
-
- tl.atomic_max(row_max_ptr + row_id, lm)
+ old = tl.atomic_max(row_max_ptr + row_id, lm.to(tl.float32))
-@triton.jit
+@triton.jit
def _norm_kernel(
X_ptr,
Y_ptr,
@@ -63,24 +58,20 @@
stride_xr,
stride_yr,
N: tl.constexpr,
- CHUNK: tl.constexpr,
+ BLOCK_N: tl.constexpr,
):
- pid = tl.program_id(0)
-
- row_id = pid
- chunk_id = 0 # for now, one block per row
-
+ row_id = tl.program_id(0)
+
x_base = X_ptr + row_id * stride_xr
y_base = Y_ptr + row_id * stride_yr
-
+
row_max = tl.load(row_max_ptr + row_id).to(tl.float32)
-
- # Kahan sum all chunks
+
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
-
- for off in range(0, N, CHUNK):
- idx = off + tl.arange(0, CHUNK)
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - row_max)
@@ -90,11 +81,11 @@
new_sum = running_sum + y_k
comp = (new_sum - running_sum) - y_k
running_sum = new_sum
-
+
total = running_sum.reshape(()).to(tl.float32)
-
- for off in range(0, N, CHUNK):
- idx = off + tl.arange(0, CHUNK)
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - row_max)
@@ -106,27 +97,35 @@
assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
batch, vocab = x.shape
y = torch.empty_like(x)
-
- row_max = torch.full((batch,), -float('inf'), dtype=torch.float32, device=x.device)
-
- if vocab <= 4096:
- CHUNK = vocab
- elif vocab <= 32768:
- CHUNK = 4096
+
+ row_max = torch.full((batch,), -3.4028235e38, dtype=torch.float32, device=x.device)
+
+ if vocab <= 1024:
+ BLOCK_N = vocab
+ elif vocab <= 4096:
+ BLOCK_N = 2048
+ elif vocab <= 16384:
+ BLOCK_N = 4096
else:
- CHUNK = 4096
-
- while vocab % CHUNK != 0 and CHUNK > 64:
- CHUNK //= 2
-
- chunks_per_row = (vocab + CHUNK - 1) // CHUNK
- grid1 = (batch * chunks_per_row,)
-
- _max_kernel[grid1](x, row_max, x.stride(0), vocab, CHUNK=CHUNK)
-
+ BLOCK_N = 4096
+
+ while vocab % BLOCK_N != 0 and BLOCK_N > 32:
+ BLOCK_N //= 2
+
+ num_chunks = (vocab + BLOCK_N - 1) // BLOCK_N
+ grid1 = (batch * num_chunks,)
+
+ _max_kernel[grid1](
+ x, row_max, x.stride(0), vocab,
+ BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks,
+ )
+
grid2 = (batch,)
- _norm_kernel[grid2](x, y, row_max, x.stride(0), y.stride(0), vocab, CHUNK=CHUNK)
-
+ _norm_kernel[grid2](
+ x, y, row_max, x.stride(0), y.stride(0), vocab,
+ BLOCK_N=BLOCK_N,
+ )
+
return y
{
"content": "<3982 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check correctness with atomic_max"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 139, in forward
return triton_softmax(x)
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 101, in triton_softmax
row_max = torch.full((batch,), -3.4028235e38, dtype=torch.float32, device=x.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: value cannot be converted to type float without overflow
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -98,7 +98,7 @@
batch, vocab = x.shape
y = torch.empty_like(x)
- row_max = torch.full((batch,), -3.4028235e38, dtype=torch.float32, device=x.device)
+ row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device=x.device)
if vocab <= 1024:
BLOCK_N = vocab
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": " row_max = torch.full((batch,), -3.4028235e38, dtype=torch.float32, device=x.device)",
"newString": " row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device=x.device)"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -31,7 +31,7 @@
x_base = X_ptr + row_id * stride_xr
- local_max = tl.full([1], -3.4028235e38, dtype=tl.float32)
+ local_max = tl.full([1], float('-inf'), dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
@@ -39,7 +39,7 @@
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=-3.4028235e38).to(tl.float32)
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
m = tl.max(x)
m_s = m.reshape(())
local_max_s = local_max.reshape(())
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<347 chars \u2014 see below>",
"newString": "<347 chars \u2014 see below>"
} local_max = tl.full([1], -3.4028235e38, dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=-3.4028235e38).to(tl.float32) local_max = tl.full([1], float('-inf'), dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check correctness after inf fix"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 77, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1924, in reshape
return _semantic.reshape(input, shape, can_reorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 629, in reshape
if input.type.numel != numel:
^^^^^^^^^^^^^^^^
AttributeError: 'dtype' object has no attribute 'numel'. Did you mean: 'name'?
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^
... (2622 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -40,11 +40,8 @@
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
- m = tl.max(x)
- m_s = m.reshape(())
- local_max_s = local_max.reshape(())
- local_max_s = tl.maximum(local_max_s, m_s)
- local_max = local_max_s.reshape([1])
+ m = tl.max(x, keep_dims=True)
+ local_max = tl.maximum(local_max, m)
lm = local_max.reshape(()).to(tl.float32)
old = tl.atomic_max(row_max_ptr + row_id, lm.to(tl.float32))
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<537 chars \u2014 see below>",
"newString": "<430 chars \u2014 see below>"
} local_max = tl.full([1], float('-inf'), dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
m = tl.max(x)
m_s = m.reshape(())
local_max_s = local_max.reshape(())
local_max_s = tl.maximum(local_max_s, m_s)
local_max = local_max_s.reshape([1]) local_max = tl.full([1], float('-inf'), dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
m = tl.max(x, keep_dims=True)
local_max = tl.maximum(local_max, m)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check correctness after simplification"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Benchmark multi-block approach"
}shape=0 variant=eager tflops=0.010 gbps=15.497 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=9.483 ms=0.111
shape=0 variant=sota tflops=0.010 gbps=15.678 ms=0.067
shape=0 variant=solution tflops=0.031 gbps=49.461 ms=0.021
shape=0 solution_peak_fraction=0.0275
shape=1 variant=eager tflops=0.041 gbps=65.816 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.876 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.305 ms=0.076
shape=1 variant=solution tflops=0.094 gbps=150.831 ms=0.028
shape=1 solution_peak_fraction=0.0838
shape=2 variant=eager tflops=0.072 gbps=115.967 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.370 ms=0.219
shape=2 variant=sota tflops=0.067 gbps=107.900 ms=0.078
shape=2 variant=solution tflops=0.101 gbps=161.767 ms=0.052
shape=2 solution_peak_fraction=0.0899
shape=3 variant=eager tflops=0.073 gbps=116.560 ms=0.072
shape=3 variant=compiled tflops=0.014 gbps=23.156 ms=0.362
shape=3 variant=sota tflops=0.068 gbps=109.090 ms=0.077
shape=3 variant=solution tflops=0.064 gbps=102.300 ms=0.082
shape=3 solution_peak_fraction=0.0568
shape=4 variant=eager tflops=0.072 gbps=115.304 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.846 ms=0.228
shape=4 variant=sota tflops=0.062 gbps=99.769 ms=0.084
shape=4 variant=solution tflops=0.102 gbps=162.924 ms=0.051
shape=4 solution_peak_fraction=0.0905
peak_fraction: 0.0639
RESULT: LOW
{
"command": "<175 chars \u2014 see below>",
"description": "Check persistent kernel support"
}uv run python -c "
import torch
import triton
# Check what persistent features are available
import triton.language as tl
print(dir(tl))
" 2>&1 | tr ',' '\n' | grep -i persist(no output){
"command": "<492 chars \u2014 see below>",
"timeout": 30000,
"description": "Measure actual GPU memory bandwidth"
}uv run python -c "
import torch
# Quick bandwidth test: large memcpy
n = 200_000_000 # 800 MB
a = torch.randn(n, device='cuda', dtype=torch.float32)
b = torch.empty_like(a)
for _ in range(10):
b.copy_(a)
torch.cuda.synchronize()
import time
s = time.time()
for _ in range(10):
b.copy_(a)
torch.cuda.synchronize()
elapsed = time.time() - s
bytes_moved = n * 4 * 2 * 10 # read + write, 10 iterations
gbps = bytes_moved / elapsed / 1e9
print(f'Memcpy bandwidth: {gbps:.1f} GB/s')
"Memcpy bandwidth: 1467.9 GB/s
{
"command": "<1649 chars \u2014 see below>",
"timeout": 60000,
"description": "Profile the max kernel"
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def _max_kernel_v2(
X_ptr, row_max_ptr, stride_xr,
N: tl.constexpr, BLOCK_N: tl.constexpr, NUM_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
row_id = pid // NUM_CHUNKS
chunk_id = pid % NUM_CHUNKS
x_base = X_ptr + row_id * stride_xr
local_max = tl.full([1], float('-inf'), dtype=tl.float32)
chunk_start = chunk_id * BLOCK_N
chunk_end = min(chunk_start + BLOCK_N, N)
for off in range(chunk_start, chunk_end, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
m = tl.max(x, keep_dims=True)
local_max = tl.maximum(local_max, m)
lm = local_max.reshape(()).to(tl.float32)
tl.atomic_max(row_max_ptr + row_id, lm)
batch, vocab = 8, 32768
x = torch.randn(batch, vocab, device='cuda', dtype=torch.float32) * 4.0
row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device='cuda')
BLOCK_N = 4096
num_chunks = vocab // BLOCK_N
grid = (batch * num_chunks,)
# Warm-up
_max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
torch.cuda.synchronize()
# Profile with torch profiler
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
) as prof:
_max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py
@@ -0,0 +1,51 @@
+"""Quick profiler scratch file."""
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _max_kernel_v2(
+ X_ptr, row_max_ptr, stride_xr,
+ N: tl.constexpr, BLOCK_N: tl.constexpr, NUM_CHUNKS: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_id = pid // NUM_CHUNKS
+ chunk_id = pid % NUM_CHUNKS
+ x_base = X_ptr + row_id * stride_xr
+ local_max = tl.full([1], float('-inf'), dtype=tl.float32)
+ chunk_start = chunk_id * BLOCK_N
+ chunk_end = min(chunk_start + BLOCK_N, N)
+ for off in range(chunk_start, chunk_end, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
+ m = tl.max(x, keep_dims=True)
+ local_max = tl.maximum(local_max, m)
+ lm = local_max.reshape(()).to(tl.float32)
+ tl.atomic_max(row_max_ptr + row_id, lm)
+
+batch, vocab = 8, 32768
+x = torch.randn(batch, vocab, device='cuda', dtype=torch.float32) * 4.0
+row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device='cuda')
+
+BLOCK_N = 4096
+num_chunks = vocab // BLOCK_N
+grid = (batch * num_chunks,)
+
+# Warm-up
+_max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
+torch.cuda.synchronize()
+
+print(f"Grid: {grid}, BLOCK_N: {BLOCK_N}, num_chunks: {num_chunks}")
+print(f"Total elements: {batch * vocab:,}, bytes: {batch * vocab * 4:,}")
+
+# Profile with torch profiler
+with torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CUDA],
+ record_shapes=True,
+ with_stack=True,
+) as prof:
+ _max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
+ torch.cuda.synchronize()
+
+print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=15))
{
"content": "<1780 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py"
}Wrote file successfully.{
"command": "uv run python profile_scratch.py 2>&1",
"timeout": 60000,
"description": "Profile the max kernel"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
Grid: (64,), BLOCK_N: 4096, num_chunks: 8
Total elements: 262,144, bytes: 1,048,576
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
_max_kernel_v2 0.00% 0.000us 0.00% 0.000us 0.000us 1.632us 100.00% 1.632us 1.632us 1
Activity Buffer Request 94.57% 672.430us 94.57% 672.430us 672.430us 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernelEx 4.20% 29.876us 4.20% 29.876us 29.876us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 1.23% 8.736us 1.23% 8.736us 4.368us
... (299 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py
@@ -1,7 +1,8 @@
-"""Quick profiler scratch file."""
+"""Profile the full softmax kernels."""
import torch
import triton
import triton.language as tl
+import statistics
@triton.jit
def _max_kernel_v2(
@@ -24,28 +25,85 @@
lm = local_max.reshape(()).to(tl.float32)
tl.atomic_max(row_max_ptr + row_id, lm)
-batch, vocab = 8, 32768
-x = torch.randn(batch, vocab, device='cuda', dtype=torch.float32) * 4.0
-row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device='cuda')
+@triton.jit
+def _norm_kernel_v2(
+ X_ptr, Y_ptr, row_max_ptr, stride_xr, stride_yr,
+ N: tl.constexpr, BLOCK_N: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+ row_max = tl.load(row_max_ptr + row_id).to(tl.float32)
-BLOCK_N = 4096
-num_chunks = vocab // BLOCK_N
-grid = (batch * num_chunks,)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
-# Warm-up
-_max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
-torch.cuda.synchronize()
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
+ e = tl.exp(x - row_max)
+ e_masked = tl.where(mask, e, 0.0)
+ e_sum = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+ y_k = e_sum - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
-print(f"Grid: {grid}, BLOCK_N: {BLOCK_N}, num_chunks: {num_chunks}")
-print(f"Total elements: {batch * vocab:,}, bytes: {batch * vocab * 4:,}")
+ total = running_sum.reshape(()).to(tl.float32)
-# Profile with torch profiler
-with torch.profiler.profile(
- activities=[torch.profiler.ProfilerActivity.CUDA],
- record_shapes=True,
- with_stack=True,
-) as prof:
- _max_kernel_v2[grid](x, row_max, x.stride(0), vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
+ e = tl.exp(x - row_max)
+ y = tl.where(mask, e / total, 0.0)
+ tl.store(y_base + idx, y, mask=mask)
+
+def time_kernel(kernel_fn, grid, iters=30, warmup=10, **kwargs):
+ for _ in range(warmup):
+ kernel_fn[grid](**kwargs)
torch.cuda.synchronize()
+
+ times = []
+ for _ in range(iters):
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ kernel_fn[grid](**kwargs)
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
-print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=15))
+for batch, vocab in [(32, 4096), (16, 32768), (8, 131072), (4, 262144)]:
+ x = torch.randn(batch, vocab, device='cuda', dtype=torch.float32) * 4.0
+
+ BLOCK_N = 4096
+ while vocab % BLOCK_N != 0:
+ BLOCK_N //= 2
+ num_chunks = vocab // BLOCK_N
+
+ y = torch.empty_like(x)
+ row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device='cuda')
+
+ grid1 = (batch * num_chunks,)
+ grid2 = (batch,)
+
+ ms_max = time_kernel(_max_kernel_v2, grid1, x=x, row_max_ptr=row_max,
+ stride_xr=x.stride(0), N=vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
+
+ ms_norm = time_kernel(_norm_kernel_v2, grid2, x=x, Y_ptr=y, row_max_ptr=row_max,
+ stride_xr=x.stride(0), stride_yr=y.stride(0), N=vocab, BLOCK_N=BLOCK_N)
+
+ total_ms = ms_max + ms_norm
+ input_bytes = batch * vocab * 4
+ output_bytes = batch * vocab * 4
+ total_bytes = input_bytes * 3 + output_bytes # 3 reads of X, 1 write of Y
+
+ eff_bw = (input_bytes * 2 + output_bytes) / (total_ms * 1e-3) / 1e9 # counted bytes
+ real_bw = total_bytes / (total_ms * 1e-3) / 1e9 # actual bytes moved
+
+ print(f"batch={batch:3d} vocab={vocab:7d}: max={ms_max:.3f}ms norm={ms_norm:.3f}ms "
+ f"total={total_ms:.3f}ms counted_bw={eff_bw:.1f}GB/s real_bw={real_bw:.1f}GB/s")
{
"content": "<3971 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py"
}Wrote file successfully.{
"command": "uv run python profile_scratch.py 2>&1",
"timeout": 60000,
"description": "Profile both kernels separately per shape"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py", line 94, in <module>
ms_max = time_kernel(_max_kernel_v2, grid1, x=x, row_max_ptr=row_max,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py", line 65, in time_kernel
kernel_fn[grid](**kwargs)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 710, in run
bound_args, specialization, options = binder(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: dynamic_func() missing 1 required positional argument: 'X_ptr'
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py
@@ -91,10 +91,10 @@
grid1 = (batch * num_chunks,)
grid2 = (batch,)
- ms_max = time_kernel(_max_kernel_v2, grid1, x=x, row_max_ptr=row_max,
+ ms_max = time_kernel(_max_kernel_v2, grid1, X_ptr=x, row_max_ptr=row_max,
stride_xr=x.stride(0), N=vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
-
- ms_norm = time_kernel(_norm_kernel_v2, grid2, x=x, Y_ptr=y, row_max_ptr=row_max,
+
+ ms_norm = time_kernel(_norm_kernel_v2, grid2, X_ptr=x, Y_ptr=y, row_max_ptr=row_max,
stride_xr=x.stride(0), stride_yr=y.stride(0), N=vocab, BLOCK_N=BLOCK_N)
total_ms = ms_max + ms_norm
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py",
"oldString": "<361 chars \u2014 see below>",
"newString": "<364 chars \u2014 see below>"
} ms_max = time_kernel(_max_kernel_v2, grid1, x=x, row_max_ptr=row_max,
stride_xr=x.stride(0), N=vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
ms_norm = time_kernel(_norm_kernel_v2, grid2, x=x, Y_ptr=y, row_max_ptr=row_max,
stride_xr=x.stride(0), stride_yr=y.stride(0), N=vocab, BLOCK_N=BLOCK_N) ms_max = time_kernel(_max_kernel_v2, grid1, X_ptr=x, row_max_ptr=row_max,
stride_xr=x.stride(0), N=vocab, BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks)
ms_norm = time_kernel(_norm_kernel_v2, grid2, X_ptr=x, Y_ptr=y, row_max_ptr=row_max,
stride_xr=x.stride(0), stride_yr=y.stride(0), N=vocab, BLOCK_N=BLOCK_N)Edit applied successfully.{
"command": "uv run python profile_scratch.py 2>&1",
"timeout": 120000,
"description": "Profile both kernels with correct args"
}batch= 32 vocab= 4096: max=0.008ms norm=0.009ms total=0.017ms counted_bw=91.4GB/s real_bw=121.9GB/s
batch= 16 vocab= 32768: max=0.009ms norm=0.016ms total=0.025ms counted_bw=250.9GB/s real_bw=334.6GB/s
batch= 8 vocab= 131072: max=0.008ms norm=0.040ms total=0.048ms counted_bw=262.1GB/s real_bw=349.4GB/s
batch= 4 vocab= 262144: max=0.008ms norm=0.070ms total=0.079ms counted_bw=160.2GB/s real_bw=213.6GB/s
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,15 +1,17 @@
-"""Triton softmax with fp32 Kahan-compensated summation.
+"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Persistent multi-block-per-row approach:
- Kernel 1 (max): Multiple blocks per row, each finding local max over a chunk.
- Max-reduced into per-row scalar via atomic CAS.
- Kernel 2 (norm): One block per row, reads row_max, Kahan-sums exp(x-max), writes output.
+Single pass reads all elements, maintaining:
+ - running_max: max seen so far
+ - running_sum: Kahan-compensated sum of exp(x - running_max), adjusted when max updates
+ - Then a second pass writes the normalized output.
+
+Net memory traffic: 2 passes (read + write) vs 3 passes for naive approach.
"""
-import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
+
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
@@ -17,77 +19,85 @@
@triton.jit
-def _max_kernel(
- X_ptr,
- row_max_ptr,
- stride_xr,
- N: tl.constexpr,
- BLOCK_N: tl.constexpr,
- NUM_CHUNKS: tl.constexpr,
-):
- pid = tl.program_id(0)
- row_id = pid // NUM_CHUNKS
- chunk_id = pid % NUM_CHUNKS
-
- x_base = X_ptr + row_id * stride_xr
-
- local_max = tl.full([1], float('-inf'), dtype=tl.float32)
-
- chunk_start = chunk_id * BLOCK_N
- chunk_end = min(chunk_start + BLOCK_N, N)
-
- for off in range(chunk_start, chunk_end, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
- m = tl.max(x, keep_dims=True)
- local_max = tl.maximum(local_max, m)
-
- lm = local_max.reshape(()).to(tl.float32)
- old = tl.atomic_max(row_max_ptr + row_id, lm.to(tl.float32))
-
-
-@triton.jit
-def _norm_kernel(
+def _softmax_kernel_online(
X_ptr,
Y_ptr,
- row_max_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row_id = tl.program_id(0)
-
x_base = X_ptr + row_id * stride_xr
y_base = Y_ptr + row_id * stride_yr
- row_max = tl.load(row_max_ptr + row_id).to(tl.float32)
-
+ # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
+ running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - row_max)
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
+
+ # Find max of this chunk
+ chunk_max = tl.max(x, keep_dims=True)
+
+ # New running max
+ new_max = tl.maximum(running_max, chunk_max)
+
+ # Scale the existing sum: exp(x - old_max) -> exp(x - new_max)
+ # factor = exp(old_max - new_max)
+ old_max_val = running_max.reshape(())
+ new_max_val = new_max.reshape(())
+ scale = tl.math.exp(old_max_val - new_max_val)
+
+ # Reshape running_sum to [1] for scalar ops
+ rs_s = running_sum.reshape(())
+ comp_s = comp.reshape(())
+
+ # Scale existing sum and compensation
+ rs_s = rs_s * scale
+ comp_s = comp_s * scale
+
+ # Add new chunk contributions: exp(x - new_max) for this chunk
+ e = tl.exp(x.reshape([-1]) - new_max_val)
e_masked = tl.where(mask, e, 0.0)
- e_sum = tl.sum(e_masked, keep_dims=True).to(tl.float32)
- y_k = e_sum - comp
- new_sum = running_sum + y_k
- comp = (new_sum - running_sum) - y_k
- running_sum = new_sum
+ e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+ e_s = e_chunk.reshape(())
- total = running_sum.reshape(()).to(tl.float32)
+ # Kahan add
+ y_k = e_s - comp_s
+ new_sum_s = rs_s + y_k
+ comp_s = (new_sum_s - rs_s) - y_k
+ rs_s = new_sum_s
+ running_max = new_max
+ running_sum = rs_s.reshape([1])
+ comp = comp_s.reshape([1])
+
+ final_max = running_max.reshape(()).to(tl.float32)
+ total_sum = running_sum.reshape(()).to(tl.float32)
+
+ # --- Write normalized output ----------------------------------------
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - row_max)
- y = tl.where(mask, e / total, 0.0)
+ e = tl.exp(x - final_max)
+ y = tl.where(mask, e / total_sum, 0.0)
tl.store(y_base + idx, y, mask=mask)
+
+
+def _pick_block_n(vocab: int) -> int:
+ if vocab <= 1024:
+ return vocab
+ if vocab <= 4096:
+ return 2048
+ if vocab <= 32768:
+ return 4096
+ return 4096
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
@@ -95,34 +105,15 @@
batch, vocab = x.shape
y = torch.empty_like(x)
- row_max = torch.full((batch,), float('-inf'), dtype=torch.float32, device=x.device)
+ block_n = _pick_block_n(vocab)
+ while vocab % block_n != 0 and block_n > 32:
+ block_n //= 2
- if vocab <= 1024:
- BLOCK_N = vocab
- elif vocab <= 4096:
- BLOCK_N = 2048
- elif vocab <= 16384:
- BLOCK_N = 4096
- else:
- BLOCK_N = 4096
-
- while vocab % BLOCK_N != 0 and BLOCK_N > 32:
- BLOCK_N //= 2
-
- num_chunks = (vocab + BLOCK_N - 1) // BLOCK_N
- grid1 = (batch * num_chunks,)
-
- _max_kernel[grid1](
- x, row_max, x.stride(0), vocab,
- BLOCK_N=BLOCK_N, NUM_CHUNKS=num_chunks,
+ grid = (batch,)
+ _softmax_kernel_online[grid](
+ x, y, x.stride(0), y.stride(0), vocab,
+
... (184 more chars){
"content": "<3924 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check online softmax correctness"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 77, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1921, in reshape
shape = _shape_check_impl(_unwrap_iterable(shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1667, in _shape_check_impl
validate_block_shape(shape)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/_utils.py", line 54, in validate_block_shape
raise ValueError(f"Shape element {i} must be a power of 2")
ValueError: Shape element 0 must be a power of 2
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib
... (2709 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -62,7 +62,7 @@
comp_s = comp_s * scale
# Add new chunk contributions: exp(x - new_max) for this chunk
- e = tl.exp(x.reshape([-1]) - new_max_val)
+ e = tl.exp(x - new_max_val)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
e_s = e_chunk.reshape(())
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<120 chars \u2014 see below>",
"newString": "<106 chars \u2014 see below>"
} # Add new chunk contributions: exp(x - new_max) for this chunk
e = tl.exp(x.reshape([-1]) - new_max_val) # Add new chunk contributions: exp(x - new_max) for this chunk
e = tl.exp(x - new_max_val)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check online softmax again"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 77, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1924, in reshape
return _semantic.reshape(input, shape, can_reorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 629, in reshape
if input.type.numel != numel:
^^^^^^^^^^^^^^^^
AttributeError: 'dtype' object has no attribute 'numel'. Did you mean: 'name'?
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^
... (2518 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Triton softmax with online max-reduce + Kahan-compensated summation.
2:
3: Single pass reads all elements, maintaining:
4: - running_max: max seen so far
5: - running_sum: Kahan-compensated sum of exp(x - running_max), adjusted when max updates
6: - Then a second pass writes the normalized output.
7:
8: Net memory traffic: 2 passes (read + write) vs 3 passes for naive approach.
9: """
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14:
15:
16: OP_TYPE = "softmax"
17: SUPPORTED_PRECISIONS = ["fp32"]
18: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
19:
20:
21: @triton.jit
22: def _softmax_kernel_online(
23: X_ptr,
24: Y_ptr,
25: stride_xr,
26: stride_yr,
27: N: tl.constexpr,
28: BLOCK_N: tl.constexpr,
29: ):
30: row_id = tl.program_id(0)
31: x_base = X_ptr + row_id * stride_xr
32: y_base = Y_ptr + row_id * stride_yr
33:
34: # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
35: running_max = tl.full([1], float('-inf'), dtype=tl.float32)
36: running_sum = tl.full([1], 0.0, dtype=tl.float32)
37: comp = tl.full([1], 0.0, dtype=tl.float32)
38:
39: for off in range(0, N, BLOCK_N):
40: idx = off + tl.arange(0, BLOCK_N)
41: mask = idx < N
42: x = tl.load(x_base + idx, mask=mask, other=float('-inf')
... (3162 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -47,38 +47,30 @@
# New running max
new_max = tl.maximum(running_max, chunk_max)
- # Scale the existing sum: exp(x - old_max) -> exp(x - new_max)
+ # Scale existing sum: exp(x-old_max) -> exp(x-new_max)
# factor = exp(old_max - new_max)
- old_max_val = running_max.reshape(())
- new_max_val = new_max.reshape(())
- scale = tl.math.exp(old_max_val - new_max_val)
-
- # Reshape running_sum to [1] for scalar ops
- rs_s = running_sum.reshape(())
- comp_s = comp.reshape(())
+ scale = tl.exp(running_max - new_max)
# Scale existing sum and compensation
- rs_s = rs_s * scale
- comp_s = comp_s * scale
+ running_sum = running_sum * scale
+ comp = comp * scale
- # Add new chunk contributions: exp(x - new_max) for this chunk
- e = tl.exp(x - new_max_val)
+ # Add new chunk contributions: exp(x - new_max)
+ e = tl.exp(x - new_max)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
- e_s = e_chunk.reshape(())
# Kahan add
- y_k = e_s - comp_s
- new_sum_s = rs_s + y_k
- comp_s = (new_sum_s - rs_s) - y_k
- rs_s = new_sum_s
+ y_k = e_chunk - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
running_max = new_max
- running_sum = rs_s.reshape([1])
- comp = comp_s.reshape([1])
- final_max = running_max.reshape(()).to(tl.float32)
- total_sum = running_sum.reshape(()).to(tl.float32)
+ final_max = running_max.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
+ total_sum = tl.max(total_sum_vec)
# --- Write normalized output ----------------------------------------
for off in range(0, N, BLOCK_N):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<1687 chars \u2014 see below>",
"newString": "<1427 chars \u2014 see below>"
} # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
# Find max of this chunk
chunk_max = tl.max(x, keep_dims=True)
# New running max
new_max = tl.maximum(running_max, chunk_max)
# Scale the existing sum: exp(x - old_max) -> exp(x - new_max)
# factor = exp(old_max - new_max)
old_max_val = running_max.reshape(())
new_max_val = new_max.reshape(())
scale = tl.math.exp(old_max_val - new_max_val)
# Reshape running_sum to [1] for scalar ops
rs_s = running_sum.reshape(())
comp_s = comp.reshape(())
# Scale existing sum and compensation
rs_s = rs_s * scale
comp_s = comp_s * scale
# Add new chunk contributions: exp(x - new_max) for this chunk
e = tl.exp(x - new_max_val)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
e_s = e_chunk.reshape(())
# Kahan add
y_k = e_s - comp_s
new_sum_s = rs_s + y_k
comp_s = (new_sum_s - rs_s) - y_k
rs_s = new_sum_s
running_max = new_max
running_sum = rs_s.reshape([1])
comp = comp_s.reshape([1])
final_max = running_max.reshape(()).to(tl.float32)
total_sum = running_sum.reshape(()).to(tl.float32) # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
# Find max of this chunk
chunk_max = tl.max(x, keep_dims=True)
# New running max
new_max = tl.maximum(running_max, chunk_max)
# Scale existing sum: exp(x-old_max) -> exp(x-new_max)
# factor = exp(old_max - new_max)
scale = tl.exp(running_max - new_max)
# Scale existing sum and compensation
running_sum = running_sum * scale
comp = comp * scale
# Add new chunk contributions: exp(x - new_max)
e = tl.exp(x - new_max)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
# Kahan add
y_k = e_chunk - comp
new_sum = running_sum + y_k
comp = (new_sum - running_sum) - y_k
running_sum = new_sum
running_max = new_max
final_max = running_max.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
total_sum = tl.max(total_sum_vec)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -68,9 +68,8 @@
running_max = new_max
- final_max = running_max.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
- total_sum_vec = running_sum.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
- total_sum = tl.max(total_sum_vec)
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
# --- Write normalized output ----------------------------------------
for off in range(0, N, BLOCK_N):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<183 chars \u2014 see below>",
"newString": "<107 chars \u2014 see below>"
} final_max = running_max.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N]).reshape([BLOCK_N])
total_sum = tl.max(total_sum_vec) final_max = running_max.broadcast_to([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N])Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check with broadcast fix"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 118, in forward
return triton_softmax(x)
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 104, in triton_softmax
_softmax_kernel_online[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel =
... (1641 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -77,7 +77,7 @@
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - final_max)
- y = tl.where(mask, e / total_sum, 0.0)
+ y = tl.where(mask, e / total_sum_vec, 0.0)
tl.store(y_base + idx, y, mask=mask)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": " y = tl.where(mask, e / total_sum, 0.0)",
"newString": " y = tl.where(mask, e / total_sum_vec, 0.0)"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check with total_sum_vec fix"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Benchmark online softmax"
}shape=0 variant=eager tflops=0.010 gbps=15.453 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.100 ms=0.104
shape=0 variant=sota tflops=0.009 gbps=14.901 ms=0.070
shape=0 variant=solution tflops=0.054 gbps=85.668 ms=0.012
shape=0 solution_peak_fraction=0.0476
shape=1 variant=eager tflops=0.041 gbps=65.340 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.989 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.954 ms=0.075
shape=1 variant=solution tflops=0.132 gbps=210.727 ms=0.020
shape=1 solution_peak_fraction=0.1171
shape=2 variant=eager tflops=0.072 gbps=115.814 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.404 ms=0.218
shape=2 variant=sota tflops=0.070 gbps=111.646 ms=0.075
shape=2 variant=solution tflops=0.113 gbps=180.664 ms=0.046
shape=2 solution_peak_fraction=0.1004
shape=3 variant=eager tflops=0.072 gbps=114.824 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.293 ms=0.360
shape=3 variant=sota tflops=0.068 gbps=109.523 ms=0.077
shape=3 variant=solution tflops=0.063 gbps=101.194 ms=0.083
shape=3 solution_peak_fraction=0.0562
shape=4 variant=eager tflops=0.072 gbps=114.774 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.813 ms=0.228
shape=4 variant=sota tflops=0.063 gbps=100.112 ms=0.084
shape=4 variant=solution tflops=0.096 gbps=153.886 ms=0.055
shape=4 solution_peak_fraction=0.0855
peak_fraction: 0.0769
RESULT: LOW
{
"command": "<284 chars \u2014 see below>",
"description": "Check Triton shared memory and atomics"
}uv run python -c "
import torch
# Check if Triton supports __shared__ or tl.make_block_ptr
import triton.language as tl
print('make_block_ptr:', hasattr(tl, 'make_block_ptr'))
# Check atomic ops
import triton
print('atomic ops:', [x for x in dir(triton.language) if 'atomic' in x])
"make_block_ptr: True
atomic ops: ['atomic_add', 'atomic_and', 'atomic_cas', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_xchg', 'atomic_xor']
{
"command": "<251 chars \u2014 see below>",
"description": "Check tl.dot for reductions"
}uv run python -c "
import torch
# Let me check if there's a shared memory reduction pattern in Triton docs
# First, let's check tl.reduce and tl.dot
import triton.language as tl
# Check tl.dot behavior for reductions
help(tl.dot)
" 2>&1 | head -20Help on function dot in module triton.language.core:
dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32, _semantic=None)
Returns the matrix product of two blocks.
The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
For three-dimensional blocks, `tl.dot` performs the batched matrix product,
where the first dimension of each block represents the batch dimension.
:param input: The first tensor to be multiplied.
:type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second tensor to be multiplied.
:type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
:type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Ava
... (69 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,110 +1,232 @@
-"""Triton softmax with online max-reduce + Kahan-compensated summation.
-
-Single pass reads all elements, maintaining:
- - running_max: max seen so far
- - running_sum: Kahan-compensated sum of exp(x - running_max), adjusted when max updates
- - Then a second pass writes the normalized output.
-
-Net memory traffic: 2 passes (read + write) vs 3 passes for naive approach.
+"""CUDA C++ softmax with fp32 Kahan-compensated summation.
+
+Two-kernel approach:
+ Kernel 1: Each row splits into blocks; each block finds local max and atomic-max reduces.
+ Kernel 2: Each row (1 block) reads row_max, Kahan-sums exp(x-max), writes normalized output.
"""
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-
+from torch.utils.cpp_extension import load_inline
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-
-@triton.jit
-def _softmax_kernel_online(
- X_ptr,
- Y_ptr,
- stride_xr,
- stride_yr,
- N: tl.constexpr,
- BLOCK_N: tl.constexpr,
-):
- row_id = tl.program_id(0)
- x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
-
- # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
- running_max = tl.full([1], float('-inf'), dtype=tl.float32)
- running_sum = tl.full([1], 0.0, dtype=tl.float32)
- comp = tl.full([1], 0.0, dtype=tl.float32)
-
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
-
- # Find max of this chunk
- chunk_max = tl.max(x, keep_dims=True)
-
- # New running max
- new_max = tl.maximum(running_max, chunk_max)
-
- # Scale existing sum: exp(x-old_max) -> exp(x-new_max)
- # factor = exp(old_max - new_max)
- scale = tl.exp(running_max - new_max)
-
- # Scale existing sum and compensation
- running_sum = running_sum * scale
- comp = comp * scale
-
- # Add new chunk contributions: exp(x - new_max)
- e = tl.exp(x - new_max)
- e_masked = tl.where(mask, e, 0.0)
- e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
-
- # Kahan add
- y_k = e_chunk - comp
- new_sum = running_sum + y_k
- comp = (new_sum - running_sum) - y_k
- running_sum = new_sum
-
- running_max = new_max
-
- final_max = running_max.broadcast_to([BLOCK_N])
- total_sum_vec = running_sum.broadcast_to([BLOCK_N])
-
- # --- Write normalized output ----------------------------------------
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - final_max)
- y = tl.where(mask, e / total_sum_vec, 0.0)
- tl.store(y_base + idx, y, mask=mask)
-
-
-def _pick_block_n(vocab: int) -> int:
- if vocab <= 1024:
- return vocab
- if vocab <= 4096:
- return 2048
- if vocab <= 32768:
- return 4096
- return 4096
-
-
-def triton_softmax(x: torch.Tensor) -> torch.Tensor:
+# ---------------------------------------------------------------------------
+# CUDA kernel source
+# ---------------------------------------------------------------------------
+CUDA_SRC = """
+#include <cuda_fp16.h>
+#include <math.h>
+#include <cfloat>
+
+#define WARP_SIZE 32
+
+// --- Warp-level max reduction (shfl) ---
+__device__ __forceinline__ float warp_reduce_max(float val) {
+ val = max(val, __shfl_xor_sync(0xffffffff, val, 16));
+ val = max(val, __shfl_xor_sync(0xffffffff, val, 8));
+ val = max(val, __shfl_xor_sync(0xffffffff, val, 4));
+ val = max(val, __shfl_xor_sync(0xffffffff, val, 2));
+ val = max(val, __shfl_xor_sync(0xffffffff, val, 1));
+ return val;
+}
+
+// --- Warp-level sum reduction (shfl) ---
+__device__ __forceinline__ float warp_reduce_sum(float val) {
+ val += __shfl_xor_sync(0xffffffff, val, 16);
+ val += __shfl_xor_sync(0xffffffff, val, 8);
+ val += __shfl_xor_sync(0xffffffff, val, 4);
+ val += __shfl_xor_sync(0xffffffff, val, 2);
+ val += __shfl_xor_sync(0xffffffff, val, 1);
+ return val;
+}
+
+// --- Kernel 1: per-chunk max, atomic-max into row_max ---
+__global__ void max_kernel(const float* x, float* row_max, int N, int chunk_size) {
+ const int num_chunks = gridDim.x; // per row: blockDim.y chunks
+ const int row_id = blockIdx.y;
+ const int chunk_id = blockIdx.x;
+
+ const float* row_ptr = x + (long long)row_id * N;
+ const int tid = threadIdx.x;
+
+ // Each thread loads multiple elements
+ float local_max = -FLT_MAX;
+ const int start = chunk_id * chunk_size;
+ const int end = min(start + chunk_size, N);
+
+ for (int i = start + tid; i < end; i += blockDim.x) {
+ local_max = max(local_max, row_ptr[i]);
+ }
+
+ // Warp-level max reduce
+ local_max = warp_reduce_max(local_max);
+
+ // Thread 0 of each warp does atomic max
+ if (tid % WARP_SIZE == 0) {
+ atomicMax(row_max + row_id, __float_as_int(local_max));
+ }
+}
+
+// --- Kernel 2: read row_max, Kahan-sum exp(x-max), write normalized ---
+__global__ void norm_kernel(const float* x, float* y, const float* row_max, int N) {
+ const int row_id = blockIdx.x;
+ const int tid = threadIdx.x;
+
+ const float* row_ptr = x + (long long)row_id * N;
+ float* out_ptr = y + (long long)row_id * N;
+ const float m = row_max[row_id];
+
+ // Pass 1: Kahan sum of exp(x - m)
+ float total = 0.0f;
+ float comp = 0.0f;
+
+ for (int i = tid; i < N; i += blockDim.x) {
+ float val = row_ptr[i] - m;
+ float e = expf(val);
+
+ float yk = e - comp;
+ f
... (4212 more chars){
"content": "<7243 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "<104 chars \u2014 see below>",
"description": "Check load_inline API"
}uv run python -c "
from torch.utils.cpp_extension import load_inline
help(load_inline)
" 2>&1 | head -60Help on function load_inline in module torch.utils.cpp_extension:
load_inline(name, cpp_sources, cuda_sources=None, sycl_sources=None, functions=None, extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, with_sycl=None, is_python_module=True, with_pytorch_error_handling=True, keep_intermediates=True, use_pch=False, no_implicit_headers=False)
Load a PyTorch C++ extension just-in-time (JIT) from string sources.
This function behaves exactly like :func:`load`, but takes its sources as
strings rather than filenames. These strings are stored to files in the
build directory, after which the behavior of :func:`load_inline` is
identical to :func:`load`.
See `the
tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_
for good examples of using this function.
Sources may omit two required parts of a typical non-inline C++ extension:
the necessary header includes, as well as the (pybind11) binding code. More
precisely, strings passed to ``cpp_sources`` are first concatenated into a
single ``.cpp`` file. This file is then prepended with ``#include
<torch/extension.h>``
Furthermore, if the ``functions`` argument is supplied, bindings will be
automatically generated for each function specified. ``functions`` can
either be a list of function names, or a dictionar
... (2134 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,8 +1,8 @@
"""CUDA C++ softmax with fp32 Kahan-compensated summation.
-Two-kernel approach:
- Kernel 1: Each row splits into blocks; each block finds local max and atomic-max reduces.
- Kernel 2: Each row (1 block) reads row_max, Kahan-sums exp(x-max), writes normalized output.
+Two-kernel approach for GPU saturation:
+ Kernel 1: Multiple blocks per row, each finds local max, atomicCAS-max reduces.
+ Kernel 2: One block per row, reads row_max, Kahan-sums exp(x-max), writes normalized output.
"""
import torch
import torch.nn as nn
@@ -12,161 +12,175 @@
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# ---------------------------------------------------------------------------
-# CUDA kernel source
-# ---------------------------------------------------------------------------
-CUDA_SRC = """
-#include <cuda_fp16.h>
+CUDA_SRC = r"""
+#include <cuda_runtime.h>
#include <math.h>
#include <cfloat>
-#define WARP_SIZE 32
-
-// --- Warp-level max reduction (shfl) ---
-__device__ __forceinline__ float warp_reduce_max(float val) {
- val = max(val, __shfl_xor_sync(0xffffffff, val, 16));
- val = max(val, __shfl_xor_sync(0xffffffff, val, 8));
- val = max(val, __shfl_xor_sync(0xffffffff, val, 4));
- val = max(val, __shfl_xor_sync(0xffffffff, val, 2));
- val = max(val, __shfl_xor_sync(0xffffffff, val, 1));
- return val;
-}
-
-// --- Warp-level sum reduction (shfl) ---
-__device__ __forceinline__ float warp_reduce_sum(float val) {
- val += __shfl_xor_sync(0xffffffff, val, 16);
- val += __shfl_xor_sync(0xffffffff, val, 8);
- val += __shfl_xor_sync(0xffffffff, val, 4);
- val += __shfl_xor_sync(0xffffffff, val, 2);
- val += __shfl_xor_sync(0xffffffff, val, 1);
- return val;
-}
-
-// --- Kernel 1: per-chunk max, atomic-max into row_max ---
-__global__ void max_kernel(const float* x, float* row_max, int N, int chunk_size) {
- const int num_chunks = gridDim.x; // per row: blockDim.y chunks
+#define WARP 32
+
+// Atomic max for float using CAS (IEEE 754 bit-cast trick)
+__device__ __forceinline__ void atomicMaxF32(float* addr, float val) {
+ int* iaddr = (int*)addr;
+ int old = *iaddr;
+ while (true) {
+ float old_f = __int_as_float(old);
+ if (val <= old_f) break;
+ int assumed = old;
+ old = atomicCAS(iaddr, assumed, __float_as_int(val));
+ if (old == assumed) break;
+ }
+}
+
+// Warp-level max
+__device__ __forceinline__ float warp_max(float v) {
+ #pragma unroll
+ for (int mask = WARP / 2; mask >= 1; mask >>= 1)
+ v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask));
+ return v;
+}
+
+// Warp-level sum
+__device__ __forceinline__ float warp_sum(float v) {
+ #pragma unroll
+ for (int mask = WARP / 2; mask >= 1; mask >>= 1)
+ v += __shfl_xor_sync(0xffffffff, v, mask);
+ return v;
+}
+
+// Kernel 1: find per-row max with multiple blocks per row
+__global__ void max_kernel(const float* x, float* row_max,
+ int N, int stride_xr) {
const int row_id = blockIdx.y;
const int chunk_id = blockIdx.x;
-
- const float* row_ptr = x + (long long)row_id * N;
+ const int chunks_per_row = gridDim.x;
const int tid = threadIdx.x;
-
- // Each thread loads multiple elements
+ const int nthreads = blockDim.x;
+
+ const float* xp = x + (long long)row_id * (long long)stride_xr;
+ const int chunk_len = (N + chunks_per_row - 1) / chunks_per_row;
+ const int start = chunk_id * chunk_len;
+ const int end = min(start + chunk_len, N);
+
float local_max = -FLT_MAX;
- const int start = chunk_id * chunk_size;
- const int end = min(start + chunk_size, N);
-
- for (int i = start + tid; i < end; i += blockDim.x) {
- local_max = max(local_max, row_ptr[i]);
- }
-
- // Warp-level max reduce
- local_max = warp_reduce_max(local_max);
-
- // Thread 0 of each warp does atomic max
- if (tid % WARP_SIZE == 0) {
- atomicMax(row_max + row_id, __float_as_int(local_max));
- }
-}
-
-// --- Kernel 2: read row_max, Kahan-sum exp(x-max), write normalized ---
-__global__ void norm_kernel(const float* x, float* y, const float* row_max, int N) {
+ for (int i = start + tid; i < end; i += nthreads) {
+ local_max = fmaxf(local_max, xp[i]);
+ }
+
+ local_max = warp_max(local_max);
+
+ if (tid % WARP == 0) {
+ atomicMaxF32(row_max + row_id, local_max);
+ }
+}
+
+// Kernel 2: normalize with Kahan-compensated summation
+// Shared memory for inter-warp Kahan sum reduction + final normalization
+__global__ void norm_kernel(const float* x, float* y, const float* row_max,
+ int N, int stride_x, int stride_y) {
const int row_id = blockIdx.x;
const int tid = threadIdx.x;
-
- const float* row_ptr = x + (long long)row_id * N;
- float* out_ptr = y + (long long)row_id * N;
+ const int nthreads = blockDim.x;
+
+ const float* xp = x + (long long)row_id * (long long)stride_x;
+ float* yp = y + (long long)row_id * (long long)stride_y;
const float m = row_max[row_id];
- // Pass 1: Kahan sum of exp(x - m)
- float total = 0.0f;
- float comp = 0.0f;
-
- for (int i = tid; i < N; i += blockDim.x) {
- float val = row_ptr[i] - m;
- float e = expf(val);
-
- float yk = e - comp;
- float t = total + yk;
- comp = (t - total) - yk;
- total = t;
- }
-
- // Warp-level sum reduce
- total = warp_reduce_sum(total);
-
- // Broadcast total via shared memory
- __shared__ float shared_total;
- if (tid % WARP_SIZE == 0) {
- atomicAdd(&shared_total, total);
+ // --- Pass 1: Kahan sum of exp(x - m) ---
+ // Each warp does its own Kahan sum, then warp sums are combined
+ float w_total = 0.0f;
... (7244 more chars){
"content": "<7467 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Check CUDA C++ softmax correctness"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 237, in forward
return cuda_softmax(x)
^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/p
... (6615 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -2,7 +2,7 @@
Two-kernel approach for GPU saturation:
Kernel 1: Multiple blocks per row, each finds local max, atomicCAS-max reduces.
- Kernel 2: One block per row, reads row_max, Kahan-sums exp(x-max), writes normalized output.
+ Kernel 2: One block per row, reads row_max, Kahan-sums exp(x-max), writes output.
"""
import torch
import torch.nn as nn
@@ -12,6 +12,7 @@
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+# CUDA source: kernels + host launch wrappers (all in .cu so <<<>>> works)
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <math.h>
@@ -19,7 +20,7 @@
#define WARP 32
-// Atomic max for float using CAS (IEEE 754 bit-cast trick)
+// Atomic max for float using CAS
__device__ __forceinline__ void atomicMaxF32(float* addr, float val) {
int* iaddr = (int*)addr;
int old = *iaddr;
@@ -48,7 +49,7 @@
return v;
}
-// Kernel 1: find per-row max with multiple blocks per row
+// Kernel 1: multi-block max reduction per row
__global__ void max_kernel(const float* x, float* row_max,
int N, int stride_xr) {
const int row_id = blockIdx.y;
@@ -75,7 +76,6 @@
}
// Kernel 2: normalize with Kahan-compensated summation
-// Shared memory for inter-warp Kahan sum reduction + final normalization
__global__ void norm_kernel(const float* x, float* y, const float* row_max,
int N, int stride_x, int stride_y) {
const int row_id = blockIdx.x;
@@ -86,8 +86,7 @@
float* yp = y + (long long)row_id * (long long)stride_y;
const float m = row_max[row_id];
- // --- Pass 1: Kahan sum of exp(x - m) ---
- // Each warp does its own Kahan sum, then warp sums are combined
+ // Pass 1: Kahan sum of exp(x - m)
float w_total = 0.0f;
float w_comp = 0.0f;
@@ -99,21 +98,15 @@
w_total = t;
}
- // Warp-reduce the Kahan sum
w_total = warp_sum(w_total);
- // Also reduce compensation across warp
w_comp = warp_sum(w_comp);
- // Each warp leader writes to shared memory
extern __shared__ float smem[];
if (tid % WARP == 0) {
smem[tid / WARP] = w_total;
- // We skip Kahan cross-warp; per-warp Kahan + inter-warp plain sum
- // is sufficient at these tolerance levels
}
__syncthreads();
- // Thread 0 sums all warp contributions
float total_sum = 0.0f;
if (tid == 0) {
int nwarps = (nthreads + WARP - 1) / WARP;
@@ -126,58 +119,42 @@
total_sum = smem[0];
- // --- Pass 2: write normalized output ---
+ // Pass 2: write normalized output
float inv_sum = 1.0f / total_sum;
for (int i = tid; i < N; i += nthreads) {
float e = expf(xp[i] - m);
yp[i] = e * inv_sum;
}
}
+
+// Host functions callable from Python
+void softmax_fwd(torch::Tensor x, torch::Tensor y, torch::Tensor row_max,
+ int N, int stride_xr, int stride_yr, int num_chunks) {
+ int batch = row_max.size(0);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+ int threads1 = 256;
+ dim3 grid1(num_chunks, batch);
+ dim3 block1(threads1);
+ max_kernel<<<grid1, block1, 0, stream>>>(
+ x.data_ptr<float>(), row_max.data_ptr<float>(), N, stride_xr);
+
+ int threads2 = 256;
+ int smem_size = ((threads2 + WARP - 1) / WARP) * sizeof(float);
+ dim3 grid2(batch);
+ dim3 block2(threads2);
+ norm_kernel<<<grid2, block2, smem_size, stream>>>(
+ x.data_ptr<float>(), y.data_ptr<float>(), row_max.data_ptr<float>(),
+ N, stride_xr, stride_yr);
+}
"""
-# C++ wrapper that creates pybind11 functions
-CPP_SRC = """
+CPP_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
-// Forward declarations of CUDA kernels
-void max_kernel(const float* x, float* row_max, int N, int stride_xr,
- int num_chunks, int batch, int threads_per_block);
-void norm_kernel(const float* x, float* y, const float* row_max,
- int N, int stride_x, int stride_y, int batch, int threads_per_block);
-
-// Launch wrappers
-void launch_max_kernel(torch::Tensor x, torch::Tensor row_max, int N,
- int stride_xr, int num_chunks) {
- at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
- int threads = 256;
- dim3 grid(num_chunks, row_max.size(0));
- dim3 block(threads);
- max_kernel<<<grid, block, 0, stream>>>(
- x.data_ptr<float>(), row_max.data_ptr<float>(), N, stride_xr);
-}
-
-void launch_norm_kernel(torch::Tensor x, torch::Tensor y, torch::Tensor row_max,
- int N, int stride_x, int stride_y) {
- at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
- int threads = 256;
- int smem_size = ((threads + 31) / 32) * sizeof(float);
- dim3 grid(row_max.size(0));
- dim3 block(threads);
- norm_kernel<<<grid, block, smem_size, stream>>>(
- x.data_ptr<float>(), y.data_ptr<float>(), row_max.data_ptr<float>(),
- N, stride_x, stride_y);
-}
-
void softmax_fwd(torch::Tensor x, torch::Tensor y, torch::Tensor row_max,
- int N, int stride_xr, int stride_yr, int num_chunks) {
- launch_max_kernel(x, row_max, N, stride_xr, num_chunks);
- launch_norm_kernel(x, y, row_max, N, stride_xr, stride_yr);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("softmax_fwd", &softmax_fwd, "CUDA softmax forward pass");
-}
+ int N, int stride_xr, int stride_yr, int num_chunks);
"""
@@ -193,12 +170,7 @@
cpp_sources=CPP_SRC,
cuda_sources=CUDA_SRC,
functions=["softmax_fwd"],
- extra_cuda_cflags=[
- "-O3",
- "-arch=sm_120",
- "-std=c++17",
- "--use_fast_math",
- ],
+ extra_cuda_cflags=[
... (382 more chars){
"content": "<6188 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -149,13 +149,7 @@
}
"""
-CPP_SRC = r"""
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-void softmax_fwd(torch::Tensor x, torch::Tensor y, torch::Tensor row_max,
- int N, int stride_xr, int stride_yr, int num_chunks);
-"""
+CPP_SRC = ""
_module = None
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<228 chars \u2014 see below>",
"newString": "CPP_SRC = \"\""
}CPP_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
void softmax_fwd(torch::Tensor x, torch::Tensor y, torch::Tensor row_max,
int N, int stride_xr, int stride_yr, int num_chunks);
"""Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -14,6 +14,8 @@
# CUDA source: kernels + host launch wrappers (all in .cu so <<<>>> works)
CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cfloat>
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<77 chars \u2014 see below>",
"newString": "<141 chars \u2014 see below>"
}CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <math.h>
#include <cfloat>CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cfloat>Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Check CUDA C++ softmax correctness"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 204, in forward
return cuda_softmax(x)
^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/p
... (3461 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -151,7 +151,7 @@
}
"""
-CPP_SRC = ""
+CPP_SRC = "void softmax_fwd(torch::Tensor, torch::Tensor, torch::Tensor, int, int, int, int);"
_module = None
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "CPP_SRC = \"\"",
"newString": "CPP_SRC = \"void softmax_fwd(torch::Tensor, torch::Tensor, torch::Tensor, int, int, int, int);\""
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Test CUDA C++ build"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Benchmark CUDA C++ softmax"
}shape=0 variant=eager tflops=0.010 gbps=15.537 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.291 ms=0.102
shape=0 variant=sota tflops=0.009 gbps=14.757 ms=0.071
shape=0 variant=solution tflops=0.050 gbps=80.610 ms=0.013
shape=0 solution_peak_fraction=0.0448
shape=1 variant=eager tflops=0.041 gbps=65.915 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=43.165 ms=0.097
shape=1 variant=sota tflops=0.034 gbps=54.375 ms=0.077
shape=1 variant=solution tflops=0.071 gbps=112.944 ms=0.037
shape=1 solution_peak_fraction=0.0627
shape=2 variant=eager tflops=0.073 gbps=116.872 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.421 ms=0.218
shape=2 variant=sota tflops=0.067 gbps=107.260 ms=0.078
shape=2 variant=solution tflops=0.047 gbps=74.632 ms=0.112
shape=2 solution_peak_fraction=0.0415
shape=3 variant=eager tflops=0.073 gbps=116.353 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.288 ms=0.360
shape=3 variant=sota tflops=0.067 gbps=106.867 ms=0.078
shape=3 variant=solution tflops=0.026 gbps=41.002 ms=0.205
shape=3 solution_peak_fraction=0.0228
shape=4 variant=eager tflops=0.073 gbps=117.002 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.779 ms=0.228
shape=4 variant=sota tflops=0.061 gbps=97.325 ms=0.086
shape=4 variant=solution tflops=0.045 gbps=72.196 ms=0.116
shape=4 solution_peak_fraction=0.0401
peak_fraction: 0.0403
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,196 +1,118 @@
-"""CUDA C++ softmax with fp32 Kahan-compensated summation.
+"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Two-kernel approach for GPU saturation:
- Kernel 1: Multiple blocks per row, each finds local max, atomicCAS-max reduces.
- Kernel 2: One block per row, reads row_max, Kahan-sums exp(x-max), writes output.
+Single read pass for max+sum, single write pass for normalization.
+2 memory passes total vs 3 for naive.
"""
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
+import triton
+import triton.language as tl
+
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# CUDA source: kernels + host launch wrappers (all in .cu so <<<>>> works)
-CUDA_SRC = r"""
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
-#include <math.h>
-#include <cfloat>
-#define WARP 32
+@triton.jit
+def _softmax_kernel_online(
+ X_ptr,
+ Y_ptr,
+ stride_xr,
+ stride_yr,
+ N,
+ BLOCK_N: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
-// Atomic max for float using CAS
-__device__ __forceinline__ void atomicMaxF32(float* addr, float val) {
- int* iaddr = (int*)addr;
- int old = *iaddr;
- while (true) {
- float old_f = __int_as_float(old);
- if (val <= old_f) break;
- int assumed = old;
- old = atomicCAS(iaddr, assumed, __float_as_int(val));
- if (old == assumed) break;
- }
-}
+ # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
+ running_max = tl.full([1], float('-inf'), dtype=tl.float32)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
-// Warp-level max
-__device__ __forceinline__ float warp_max(float v) {
- #pragma unroll
- for (int mask = WARP / 2; mask >= 1; mask >>= 1)
- v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask));
- return v;
-}
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
-// Warp-level sum
-__device__ __forceinline__ float warp_sum(float v) {
- #pragma unroll
- for (int mask = WARP / 2; mask >= 1; mask >>= 1)
- v += __shfl_xor_sync(0xffffffff, v, mask);
- return v;
-}
+ # Find max of this chunk
+ chunk_max = tl.max(x, keep_dims=True)
-// Kernel 1: multi-block max reduction per row
-__global__ void max_kernel(const float* x, float* row_max,
- int N, int stride_xr) {
- const int row_id = blockIdx.y;
- const int chunk_id = blockIdx.x;
- const int chunks_per_row = gridDim.x;
- const int tid = threadIdx.x;
- const int nthreads = blockDim.x;
+ # New running max
+ new_max = tl.maximum(running_max, chunk_max)
- const float* xp = x + (long long)row_id * (long long)stride_xr;
- const int chunk_len = (N + chunks_per_row - 1) / chunks_per_row;
- const int start = chunk_id * chunk_len;
- const int end = min(start + chunk_len, N);
+ # Scale existing sum by exp(old_max - new_max)
+ scale = tl.exp(running_max - new_max)
- float local_max = -FLT_MAX;
- for (int i = start + tid; i < end; i += nthreads) {
- local_max = fmaxf(local_max, xp[i]);
- }
+ # Scale existing sum and compensation
+ running_sum = running_sum * scale
+ comp = comp * scale
- local_max = warp_max(local_max);
+ # Add new chunk contributions: exp(x - new_max)
+ e = tl.exp(x - new_max)
+ e_masked = tl.where(mask, e, 0.0)
+ e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
- if (tid % WARP == 0) {
- atomicMaxF32(row_max + row_id, local_max);
- }
-}
+ # Kahan add
+ y_k = e_chunk - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
-// Kernel 2: normalize with Kahan-compensated summation
-__global__ void norm_kernel(const float* x, float* y, const float* row_max,
- int N, int stride_x, int stride_y) {
- const int row_id = blockIdx.x;
- const int tid = threadIdx.x;
- const int nthreads = blockDim.x;
+ running_max = new_max
- const float* xp = x + (long long)row_id * (long long)stride_x;
- float* yp = y + (long long)row_id * (long long)stride_y;
- const float m = row_max[row_id];
+ # Broadcast [1] scalars to [BLOCK_N] for element-wise ops
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
- // Pass 1: Kahan sum of exp(x - m)
- float w_total = 0.0f;
- float w_comp = 0.0f;
-
- for (int i = tid; i < N; i += nthreads) {
- float e = expf(xp[i] - m);
- float yk = e - w_comp;
- float t = w_total + yk;
- w_comp = (t - w_total) - yk;
- w_total = t;
- }
-
- w_total = warp_sum(w_total);
- w_comp = warp_sum(w_comp);
-
- extern __shared__ float smem[];
- if (tid % WARP == 0) {
- smem[tid / WARP] = w_total;
- }
- __syncthreads();
-
- float total_sum = 0.0f;
- if (tid == 0) {
- int nwarps = (nthreads + WARP - 1) / WARP;
- for (int w = 0; w < nwarps; w++) {
- total_sum += smem[w];
- }
- smem[0] = total_sum;
- }
- __syncthreads();
-
- total_sum = smem[0];
-
- // Pass 2: write normalized output
- float inv_sum = 1.0f / total_sum;
- for (int i = tid; i < N; i += nthreads) {
- float e = expf(xp[i] - m);
- yp[i] = e * inv_sum;
- }
-}
-
-// Host functions callable from Python
-void so
... (3560 more chars){
"content": "<3920 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check optimized Triton softmax"
}PASS
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sweep_configs.py
@@ -0,0 +1,124 @@
+"""Sweep over BLOCK_N and num_warps to find optimal config."""
+import statistics
+import torch
+import triton
+import triton.language as tl
+import solution
+
+@triton.jit
+def _softmax_kernel_online(
+ X_ptr, Y_ptr, stride_xr, stride_yr, N, BLOCK_N: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+
+ running_max = tl.full([1], float('-inf'), dtype=tl.float32)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
+ chunk_max = tl.max(x, keep_dims=True)
+ new_max = tl.maximum(running_max, chunk_max)
+ scale = tl.exp(running_max - new_max)
+ running_sum = running_sum * scale
+ comp = comp * scale
+ e = tl.exp(x - new_max)
+ e_masked = tl.where(mask, e, 0.0)
+ e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+ y_k = e_chunk - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
+ running_max = new_max
+
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
+ e = tl.exp(x - final_max)
+ y = tl.where(mask, e / total_sum_vec, 0.0)
+ tl.store(y_base + idx, y, mask=mask)
+
+
+def time_config(batch, vocab, block_n, num_warps, num_stages, iters=10):
+ x = torch.randn(batch, vocab, device='cuda', dtype=torch.float32)
+ y = torch.empty_like(x)
+
+ while vocab % block_n != 0 and block_n > 32:
+ block_n //= 2
+
+ for _ in range(3):
+ _softmax_kernel_online[(batch,)](
+ x, y, x.stride(0), y.stride(0), vocab,
+ BLOCK_N=block_n, num_warps=num_warps, num_stages=num_stages,
+ )
+ torch.cuda.synchronize()
+
+ times = []
+ for _ in range(iters):
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ _softmax_kernel_online[(batch,)](
+ x, y, x.stride(0), y.stride(0), vocab,
+ BLOCK_N=block_n, num_warps=num_warps, num_stages=num_stages,
+ )
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+configs = [
+ (512, 4, 2), (512, 4, 3), (512, 8, 2), (512, 8, 3),
+ (1024, 4, 2), (1024, 4, 3), (1024, 8, 2), (1024, 8, 3),
+ (2048, 4, 2), (2048, 4, 3), (2048, 8, 2), (2048, 8, 3),
+ (4096, 4, 2), (4096, 4, 3), (4096, 8, 2), (4096, 8, 3),
+]
+
+shapes = [
+ (32, 4096, "32x4K"),
+ (16, 32768, "16x32K"),
+ (8, 131072, "8x128K"),
+ (4, 262144, "4x256K"),
+]
+
+for batch, vocab, label in shapes:
+ print(f"\n=== {label} (batch={batch}, vocab={vocab}) ===")
+ best_ms = float('inf')
+ best_cfg = None
+
+ for block_n_raw, num_warps, num_stages in configs:
+ block_n = min(block_n_raw, vocab)
+ if vocab % block_n != 0:
+ # Find largest divisor of vocab <= block_n_raw
+ candidate = block_n
+ while vocab % candidate != 0 and candidate > 32:
+ candidate //= 2
+ if vocab % candidate != 0:
+ continue
+ block_n = candidate
+
+ try:
+ ms = time_config(batch, vocab, block_n, num_warps, num_stages)
+ except Exception as e:
+ print(f" BLK={block_n} warps={num_warps} stages={num_stages}: ERROR ({e})")
+ continue
+
+ bw = (batch * vocab * 4 * 2) / (ms * 1e-3) / 1e9
+ if ms < best_ms:
+ best_ms = ms
+ best_cfg = (block_n, num_warps, num_stages)
+
+ print(f" BLK={block_n:5d} warps={num_warps} stages={num_stages}: {ms:.3f}ms ({bw:.0f} GB/s)")
+
+ if best_cfg:
+ print(f" >>> BEST: BLK={best_cfg[0]} warps={best_cfg[1]} stages={best_cfg[2]} @ {best_ms:.3f}ms <<<")
{
"content": "<4232 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sweep_configs.py"
}Wrote file successfully.{
"command": "uv run python sweep_configs.py 2>&1",
"timeout": 300000,
"description": "Sweep BLOCK_N/num_warps/num_stages configs"
}=== 32x4K (batch=32, vocab=4096) ===
BLK= 512 warps=4 stages=2: 0.014ms (77 GB/s)
BLK= 512 warps=4 stages=3: 0.012ms (86 GB/s)
BLK= 512 warps=8 stages=2: 0.012ms (85 GB/s)
BLK= 512 warps=8 stages=3: 0.012ms (85 GB/s)
BLK= 1024 warps=4 stages=2: 0.010ms (104 GB/s)
BLK= 1024 warps=4 stages=3: 0.010ms (110 GB/s)
BLK= 1024 warps=8 stages=2: 0.010ms (102 GB/s)
BLK= 1024 warps=8 stages=3: 0.010ms (104 GB/s)
BLK= 2048 warps=4 stages=2: 0.009ms (113 GB/s)
BLK= 2048 warps=4 stages=3: 0.009ms (115 GB/s)
BLK= 2048 warps=8 stages=2: 0.009ms (113 GB/s)
BLK= 2048 warps=8 stages=3: 0.010ms (109 GB/s)
BLK= 4096 warps=4 stages=2: 0.009ms (112 GB/s)
BLK= 4096 warps=4 stages=3: 0.009ms (114 GB/s)
BLK= 4096 warps=8 stages=2: 0.009ms (111 GB/s)
BLK= 4096 warps=8 stages=3: 0.009ms (111 GB/s)
>>> BEST: BLK=2048 warps=4 stages=3 @ 0.009ms <<<
=== 16x32K (batch=16, vocab=32768) ===
BLK= 512 warps=4 stages=2: 0.049ms (85 GB/s)
BLK= 512 warps=4 stages=3: 0.049ms (85 GB/s)
BLK= 512 warps=8 stages=2: 0.049ms (86 GB/s)
BLK= 512 warps=8 stages=3: 0.049ms (86 GB/s)
BLK= 1024 warps=4 stages=2: 0.031ms (136 GB/s)
BLK= 1024 warps=4 stages=3: 0.031ms (135 GB/s)
BLK= 1024 warps=8 stages=2: 0.031ms (135 GB/s)
BLK= 1024 warps=8 stages=3: 0.030ms (139 GB/s)
BLK= 2048 warps=4 stages=2: 0.024ms (172 GB/s)
BLK= 2048 warps=4 stages=3: 0.024ms (174 GB/s)
BLK= 2048 warps=8 stages=2: 0.021ms (201 GB/s)
BLK= 2048 warps=8 stages=3: 0.022ms (192 GB/s)
BLK=
... (1969 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -78,22 +78,13 @@
def _pick_block_n(vocab: int) -> int:
- """Choose BLOCK_N: trade register pressure vs loop unroll factor.
+ """Choose BLOCK_N based on vocab size for optimal performance.
- Larger BLOCK_N = more ILP per iteration but more registers needed.
- Optimal depends on the device's register file size and occupancy targets.
+ Empirically determined via sweep: larger blocks + more warps = better.
"""
- # For Blackwell (sm_120) with large register file:
- # - BLOCK_N=1024: 4 registers for pointer arithmetic, ~16 for loaded data
- # - BLOCK_N=2048: similar but bigger load vectors
- # - BLOCK_N=4096: high register pressure
- if vocab <= 1024:
+ if vocab <= 2048:
return vocab
- if vocab <= 2048:
- return 1024
- if vocab <= 8192:
- return 2048
- return 2048
+ return min(4096, vocab)
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"newString": "<263 chars \u2014 see below>",
"oldString": "<644 chars \u2014 see below>"
}def _pick_block_n(vocab: int) -> int:
"""Choose BLOCK_N based on vocab size for optimal performance.
Empirically determined via sweep: larger blocks + more warps = better.
"""
if vocab <= 2048:
return vocab
return min(4096, vocab)def _pick_block_n(vocab: int) -> int:
"""Choose BLOCK_N: trade register pressure vs loop unroll factor.
Larger BLOCK_N = more ILP per iteration but more registers needed.
Optimal depends on the device's register file size and occupancy targets.
"""
# For Blackwell (sm_120) with large register file:
# - BLOCK_N=1024: 4 registers for pointer arithmetic, ~16 for loaded data
# - BLOCK_N=2048: similar but bigger load vectors
# - BLOCK_N=4096: high register pressure
if vocab <= 1024:
return vocab
if vocab <= 2048:
return 1024
if vocab <= 8192:
return 2048
return 2048Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -101,8 +101,8 @@
_softmax_kernel_online[grid](
x, y, x.stride(0), y.stride(0), vocab,
BLOCK_N=block_n,
- num_warps=4,
- num_stages=3,
+ num_warps=8,
+ num_stages=2,
)
return y
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"newString": "<174 chars \u2014 see below>",
"oldString": "<174 chars \u2014 see below>"
} grid = (batch,)
_softmax_kernel_online[grid](
x, y, x.stride(0), y.stride(0), vocab,
BLOCK_N=block_n,
num_warps=8,
num_stages=2,
) grid = (batch,)
_softmax_kernel_online[grid](
x, y, x.stride(0), y.stride(0), vocab,
BLOCK_N=block_n,
num_warps=4,
num_stages=3,
)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Verify with optimized params"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Benchmark with 8 warps BLOCK_N=4096"
}shape=0 variant=eager tflops=0.010 gbps=15.523 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=9.789 ms=0.107
shape=0 variant=sota tflops=0.010 gbps=15.712 ms=0.067
shape=0 variant=solution tflops=0.053 gbps=84.563 ms=0.012
shape=0 solution_peak_fraction=0.0470
shape=1 variant=eager tflops=0.041 gbps=65.389 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.771 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.563 ms=0.075
shape=1 variant=solution tflops=0.129 gbps=205.765 ms=0.020
shape=1 solution_peak_fraction=0.1143
shape=2 variant=eager tflops=0.072 gbps=114.624 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.480 ms=0.218
shape=2 variant=sota tflops=0.069 gbps=109.891 ms=0.076
shape=2 variant=solution tflops=0.105 gbps=168.798 ms=0.050
shape=2 solution_peak_fraction=0.0938
shape=3 variant=eager tflops=0.072 gbps=115.711 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.245 ms=0.361
shape=3 variant=sota tflops=0.069 gbps=110.843 ms=0.076
shape=3 variant=solution tflops=0.060 gbps=95.342 ms=0.088
shape=3 solution_peak_fraction=0.0530
shape=4 variant=eager tflops=0.072 gbps=114.824 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.948 ms=0.227
shape=4 variant=sota tflops=0.063 gbps=100.496 ms=0.083
shape=4 variant=solution tflops=0.091 gbps=145.919 ms=0.057
shape=4 solution_peak_fraction=0.0811
peak_fraction: 0.0736
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,7 +1,10 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Single read pass for max+sum, single write pass for normalization.
-2 memory passes total vs 3 for naive.
+Multi-block-per-row variant for GPU saturation:
+ Each row is split into chunks; multiple programs handle chunks in parallel.
+ Cross-chunk reduction happens through shared (on-chip) atomics + sync.
+
+For small batch sizes, we need multiple programs per row to saturate the GPU.
"""
import torch
import torch.nn as nn
@@ -15,7 +18,81 @@
@triton.jit
-def _softmax_kernel_online(
+def _softmax_kernel_split(
+ X_ptr,
+ Y_ptr,
+ row_max_ptr,
+ stride_xr,
+ stride_yr,
+ N,
+ BLOCK_N: tl.constexpr,
+ NUM_CHUNKS: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_id = pid // NUM_CHUNKS
+ chunk_id = pid % NUM_CHUNKS
+
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+
+ # --- Phase 1: Per-block online max+sum ---
+ # Each block handles its chunk online
+ x = tl.load(x_base + chunk_id * BLOCK_N + tl.arange(0, BLOCK_N),
+ mask=(chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)) < N,
+ other=float('-inf')).to(tl.float32)
+
+ chunk_max = tl.max(x, keep_dims=True)
+ chunk_sum = tl.sum(tl.where(
+ (chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)) < N,
+ tl.exp(x - chunk_max),
+ 0.0
+ ), keep_dims=True).to(tl.float32)
+
+ # Write partial results to shared workspace
+ # Use atomic max for row_max
+ cm_s = chunk_max.reshape(())
+ # atomic_max on row_max[row_id]
+ tl.atomic_max(row_max_ptr + row_id, cm_s)
+
+
+class Model(nn.Module):
+ def __init__(self, batch: int, vocab: int):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Fall back to working single-block kernel for now
+ return _triton_softmax_single(x)
+
+
+def _triton_softmax_single(x: torch.Tensor) -> torch.Tensor:
+ """Proven working single-block online softmax."""
+ assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
+ batch, vocab = x.shape
+ y = torch.empty_like(x)
+
+ if vocab <= 2048:
+ block_n = vocab
+ num_warps = 4
+ else:
+ block_n = min(4096, vocab)
+ num_warps = 8
+
+ while vocab % block_n != 0 and block_n > 32:
+ block_n //= 2
+
+ _softmax_kernel_online_v2[(batch,)](
+ x, y, x.stride(0), y.stride(0), vocab,
+ BLOCK_N=block_n,
+ num_warps=num_warps,
+ num_stages=2,
+ )
+ return y
+
+
+@triton.jit
+def _softmax_kernel_online_v2(
X_ptr,
Y_ptr,
stride_xr,
@@ -27,7 +104,6 @@
x_base = X_ptr + row_id * stride_xr
y_base = Y_ptr + row_id * stride_yr
- # --- Online pass: track running_max and Kahan-sum of exp(x-max) ------
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
@@ -37,25 +113,17 @@
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
- # Find max of this chunk
chunk_max = tl.max(x, keep_dims=True)
-
- # New running max
new_max = tl.maximum(running_max, chunk_max)
-
- # Scale existing sum by exp(old_max - new_max)
scale = tl.exp(running_max - new_max)
- # Scale existing sum and compensation
running_sum = running_sum * scale
comp = comp * scale
- # Add new chunk contributions: exp(x - new_max)
e = tl.exp(x - new_max)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
- # Kahan add
y_k = e_chunk - comp
new_sum = running_sum + y_k
comp = (new_sum - running_sum) - y_k
@@ -63,11 +131,9 @@
running_max = new_max
- # Broadcast [1] scalars to [BLOCK_N] for element-wise ops
final_max = running_max.broadcast_to([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N])
- # --- Write normalized output ----------------------------------------
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
@@ -75,46 +141,6 @@
e = tl.exp(x - final_max)
y = tl.where(mask, e / total_sum_vec, 0.0)
tl.store(y_base + idx, y, mask=mask)
-
-
-def _pick_block_n(vocab: int) -> int:
- """Choose BLOCK_N based on vocab size for optimal performance.
-
- Empirically determined via sweep: larger blocks + more warps = better.
- """
- if vocab <= 2048:
- return vocab
- return min(4096, vocab)
-
-
-def triton_softmax(x: torch.Tensor) -> torch.Tensor:
- assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
- batch, vocab = x.shape
- y = torch.empty_like(x)
-
- block_n = _pick_block_n(vocab)
- # Ensure BLOCK_N divides N
- while vocab % block_n != 0 and block_n > 32:
- block_n //= 2
-
- grid = (batch,)
- _softmax_kernel_online[grid](
- x, y, x.stride(0), y.stride(0), vocab,
- BLOCK_N=block_n,
- num_warps=8,
- num_stages=2,
- )
- return y
-
-
-class Model(nn.Module):
- def __init__(self, batch: int, vocab: int):
- super().__init__()
- self.batch = batch
- self.vocab = vocab
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return triton_softmax(x)
BATCH = 8
{
"content": "<4235 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,11 +1,12 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Multi-block-per-row variant for GPU saturation:
- Each row is split into chunks; multiple programs handle chunks in parallel.
- Cross-chunk reduction happens through shared (on-chip) atomics + sync.
+Two-phase approach:
+ Phase 1 (online max+sum): 1 block per row reads X once, writes row_max and row_sum.
+ Phase 2 (normalize): Multiple blocks per row read X, row_max, row_sum and write Y.
-For small batch sizes, we need multiple programs per row to saturate the GPU.
+Phase 2 has multiple blocks per row for GPU saturation even at small batch sizes.
"""
+import math
import torch
import torch.nn as nn
import triton
@@ -18,91 +19,17 @@
@triton.jit
-def _softmax_kernel_split(
+def _online_max_kernel(
X_ptr,
- Y_ptr,
row_max_ptr,
+ row_sum_ptr,
stride_xr,
- stride_yr,
- N,
- BLOCK_N: tl.constexpr,
- NUM_CHUNKS: tl.constexpr,
-):
- pid = tl.program_id(0)
- row_id = pid // NUM_CHUNKS
- chunk_id = pid % NUM_CHUNKS
-
- x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
-
- # --- Phase 1: Per-block online max+sum ---
- # Each block handles its chunk online
- x = tl.load(x_base + chunk_id * BLOCK_N + tl.arange(0, BLOCK_N),
- mask=(chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)) < N,
- other=float('-inf')).to(tl.float32)
-
- chunk_max = tl.max(x, keep_dims=True)
- chunk_sum = tl.sum(tl.where(
- (chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)) < N,
- tl.exp(x - chunk_max),
- 0.0
- ), keep_dims=True).to(tl.float32)
-
- # Write partial results to shared workspace
- # Use atomic max for row_max
- cm_s = chunk_max.reshape(())
- # atomic_max on row_max[row_id]
- tl.atomic_max(row_max_ptr + row_id, cm_s)
-
-
-class Model(nn.Module):
- def __init__(self, batch: int, vocab: int):
- super().__init__()
- self.batch = batch
- self.vocab = vocab
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # Fall back to working single-block kernel for now
- return _triton_softmax_single(x)
-
-
-def _triton_softmax_single(x: torch.Tensor) -> torch.Tensor:
- """Proven working single-block online softmax."""
- assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
- batch, vocab = x.shape
- y = torch.empty_like(x)
-
- if vocab <= 2048:
- block_n = vocab
- num_warps = 4
- else:
- block_n = min(4096, vocab)
- num_warps = 8
-
- while vocab % block_n != 0 and block_n > 32:
- block_n //= 2
-
- _softmax_kernel_online_v2[(batch,)](
- x, y, x.stride(0), y.stride(0), vocab,
- BLOCK_N=block_n,
- num_warps=num_warps,
- num_stages=2,
- )
- return y
-
-
-@triton.jit
-def _softmax_kernel_online_v2(
- X_ptr,
- Y_ptr,
- stride_xr,
- stride_yr,
N,
BLOCK_N: tl.constexpr,
):
+ """Single block per row: online max + Kahan sum of exp(x-max)."""
row_id = tl.program_id(0)
x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
@@ -131,16 +58,106 @@
running_max = new_max
- final_max = running_max.broadcast_to([BLOCK_N])
- total_sum_vec = running_sum.broadcast_to([BLOCK_N])
+ # Write results (scalar per row)
+ rm = running_max.reshape(())
+ rs = running_sum.reshape(())
+ tl.store(row_max_ptr + row_id, rm)
+ tl.store(row_sum_ptr + row_id, rs)
- for off in range(0, N, BLOCK_N):
+
+@triton.jit
+def _normalize_kernel(
+ X_ptr,
+ Y_ptr,
+ row_max_ptr,
+ row_sum_ptr,
+ stride_xr,
+ stride_yr,
+ N,
+ BLOCK_N: tl.constexpr,
+ NUM_CHUNKS: tl.constexpr,
+):
+ """Multiple blocks per row for GPU saturation."""
+ pid = tl.program_id(0)
+ row_id = pid // NUM_CHUNKS
+ chunk_id = pid % NUM_CHUNKS
+
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+
+ m = tl.load(row_max_ptr + row_id).to(tl.float32)
+ s = tl.load(row_sum_ptr + row_id).to(tl.float32)
+
+ # Broadcast scalars to chunk shape (implicit broadcast)
+ inv_sum = 1.0 / s
+
+ chunk_start = chunk_id * BLOCK_N
+
+ for off in range(chunk_start, min(chunk_start + BLOCK_N, N), BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - final_max)
- y = tl.where(mask, e / total_sum_vec, 0.0)
+ e = tl.exp(x - m)
+ y = tl.where(mask, e * inv_sum, 0.0)
tl.store(y_base + idx, y, mask=mask)
+
+
+def triton_softmax(x: torch.Tensor) -> torch.Tensor:
+ assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
+ batch, vocab = x.shape
+ y = torch.empty_like(x)
+
+ row_max = torch.empty((batch,), dtype=torch.float32, device=x.device)
+ row_sum = torch.empty((batch,), dtype=torch.float32, device=x.device)
+
+ # Phase 1: online max+sum (1 block per row)
+ if vocab <= 2048:
+ block_n1 = vocab
+ else:
+ block_n1 = min(4096, vocab)
+ while vocab % block_n1 != 0 and block_n1 > 32:
+ block_n1 //= 2
+
+ _online_max_kernel[(batch,)](
+ x, row_max, row_sum, x.stride(0), vocab,
+ BLOCK_N=block_n1, num_warps=8, num_stages=2,
+ )
+
+ # Phase 2: normalize (multiple blocks per row for saturation)
+ if vocab <= 1024:
+ block_n2 = vocab
+ num_chunks = 1
+ elif vocab <= 4096:
+ block_n2 = 1024
+ num_chunks = (vocab + block_n2 - 1) // block_n2
+ elif vocab <= 16384:
+ block_n2 = 2048
+ num_chunks = (vocab + block_n2 -
... (729 more chars){
"content": "<4739 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Check split approach correctness"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Benchmark split approach"
}shape=0 variant=eager tflops=0.010 gbps=15.596 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=9.738 ms=0.108
shape=0 variant=sota tflops=0.010 gbps=15.398 ms=0.068
shape=0 variant=solution tflops=0.030 gbps=47.593 ms=0.022
shape=0 solution_peak_fraction=0.0264
shape=1 variant=eager tflops=0.042 gbps=66.551 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.674 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.563 ms=0.075
shape=1 variant=solution tflops=0.119 gbps=190.373 ms=0.022
shape=1 solution_peak_fraction=0.1058
shape=2 variant=eager tflops=0.073 gbps=116.199 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.294 ms=0.219
shape=2 variant=sota tflops=0.069 gbps=110.656 ms=0.076
shape=2 variant=solution tflops=0.129 gbps=205.603 ms=0.041
shape=2 solution_peak_fraction=0.1142
shape=3 variant=eager tflops=0.073 gbps=116.173 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.245 ms=0.361
shape=3 variant=sota tflops=0.068 gbps=108.728 ms=0.077
shape=3 variant=solution tflops=0.079 gbps=126.182 ms=0.066
shape=3 solution_peak_fraction=0.0701
shape=4 variant=eager tflops=0.073 gbps=116.379 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.707 ms=0.229
shape=4 variant=sota tflops=0.062 gbps=98.606 ms=0.085
shape=4 variant=solution tflops=0.109 gbps=173.893 ms=0.048
shape=4 solution_peak_fraction=0.0966
peak_fraction: 0.0736
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,12 +1,9 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Two-phase approach:
- Phase 1 (online max+sum): 1 block per row reads X once, writes row_max and row_sum.
- Phase 2 (normalize): Multiple blocks per row read X, row_max, row_sum and write Y.
-
-Phase 2 has multiple blocks per row for GPU saturation even at small batch sizes.
+Optimal single-kernel approach tuned for RTX PRO 6000 Blackwell.
+Sweep-validated optimal: BLOCK_N=4096, 8 warps for large shapes;
+BLOCK_N=2048, 4 warps for small shapes.
"""
-import math
import torch
import torch.nn as nn
import triton
@@ -19,18 +16,19 @@
@triton.jit
-def _online_max_kernel(
+def _softmax_kernel(
X_ptr,
- row_max_ptr,
- row_sum_ptr,
+ Y_ptr,
stride_xr,
+ stride_yr,
N,
BLOCK_N: tl.constexpr,
):
- """Single block per row: online max + Kahan sum of exp(x-max)."""
row_id = tl.program_id(0)
x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+ # --- Online pass: running max + Kahan-sum -------------------------------
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
@@ -58,47 +56,17 @@
running_max = new_max
- # Write results (scalar per row)
- rm = running_max.reshape(())
- rs = running_sum.reshape(())
- tl.store(row_max_ptr + row_id, rm)
- tl.store(row_sum_ptr + row_id, rs)
+ # Broadcast [1] to [BLOCK_N] for element-wise ops
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
-
-@triton.jit
-def _normalize_kernel(
- X_ptr,
- Y_ptr,
- row_max_ptr,
- row_sum_ptr,
- stride_xr,
- stride_yr,
- N,
- BLOCK_N: tl.constexpr,
- NUM_CHUNKS: tl.constexpr,
-):
- """Multiple blocks per row for GPU saturation."""
- pid = tl.program_id(0)
- row_id = pid // NUM_CHUNKS
- chunk_id = pid % NUM_CHUNKS
-
- x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
-
- m = tl.load(row_max_ptr + row_id).to(tl.float32)
- s = tl.load(row_sum_ptr + row_id).to(tl.float32)
-
- # Broadcast scalars to chunk shape (implicit broadcast)
- inv_sum = 1.0 / s
-
- chunk_start = chunk_id * BLOCK_N
-
- for off in range(chunk_start, min(chunk_start + BLOCK_N, N), BLOCK_N):
+ # --- Normalize pass: write output ---------------------------------------
+ for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - m)
- y = tl.where(mask, e * inv_sum, 0.0)
+ e = tl.exp(x - final_max)
+ y = tl.where(mask, e / total_sum_vec, 0.0)
tl.store(y_base + idx, y, mask=mask)
@@ -107,46 +75,30 @@
batch, vocab = x.shape
y = torch.empty_like(x)
- row_max = torch.empty((batch,), dtype=torch.float32, device=x.device)
- row_sum = torch.empty((batch,), dtype=torch.float32, device=x.device)
+ # Sweep-validated optimal block sizes
+ if vocab <= 2048:
+ block_n = vocab
+ num_warps = 4
+ num_stages = 3
+ elif vocab <= 4096:
+ block_n = 2048
+ num_warps = 4
+ num_stages = 3
+ else:
+ block_n = min(4096, vocab)
+ num_warps = 8
+ num_stages = 2
- # Phase 1: online max+sum (1 block per row)
- if vocab <= 2048:
- block_n1 = vocab
- else:
- block_n1 = min(4096, vocab)
- while vocab % block_n1 != 0 and block_n1 > 32:
- block_n1 //= 2
+ # Ensure divisibility
+ while vocab % block_n != 0 and block_n > 32:
+ block_n //= 2
- _online_max_kernel[(batch,)](
- x, row_max, row_sum, x.stride(0), vocab,
- BLOCK_N=block_n1, num_warps=8, num_stages=2,
+ _softmax_kernel[(batch,)](
+ x, y, x.stride(0), y.stride(0), vocab,
+ BLOCK_N=block_n,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
-
- # Phase 2: normalize (multiple blocks per row for saturation)
- if vocab <= 1024:
- block_n2 = vocab
- num_chunks = 1
- elif vocab <= 4096:
- block_n2 = 1024
- num_chunks = (vocab + block_n2 - 1) // block_n2
- elif vocab <= 16384:
- block_n2 = 2048
- num_chunks = (vocab + block_n2 - 1) // block_n2
- elif vocab <= 65536:
- block_n2 = 2048
- num_chunks = (vocab + block_n2 - 1) // block_n2
- else:
- block_n2 = 4096
- num_chunks = (vocab + block_n2 - 1) // block_n2
-
- grid2 = (batch * num_chunks,)
- _normalize_kernel[grid2](
- x, y, row_max, row_sum, x.stride(0), y.stride(0), vocab,
- BLOCK_N=block_n2, NUM_CHUNKS=num_chunks,
- num_warps=8, num_stages=2,
- )
-
return y
{
"content": "<3362 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && echo \"---\" && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Verify and benchmark final clean version"
}PASS
---
shape=0 variant=eager tflops=0.010 gbps=15.534 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=9.854 ms=0.106
shape=0 variant=sota tflops=0.010 gbps=15.570 ms=0.067
shape=0 variant=solution tflops=0.052 gbps=83.806 ms=0.013
shape=0 solution_peak_fraction=0.0466
shape=1 variant=eager tflops=0.041 gbps=65.717 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=42.604 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.942 ms=0.075
shape=1 variant=solution tflops=0.129 gbps=206.088 ms=0.020
shape=1 solution_peak_fraction=0.1145
shape=2 variant=eager tflops=0.072 gbps=115.177 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.418 ms=0.218
shape=2 variant=sota tflops=0.070 gbps=111.503 ms=0.075
shape=2 variant=solution tflops=0.107 gbps=170.445 ms=0.049
shape=2 solution_peak_fraction=0.0947
shape=3 variant=eager tflops=0.072 gbps=114.573 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.282 ms=0.360
shape=3 variant=sota tflops=0.069 gbps=110.493 ms=0.076
shape=3 variant=solution tflops=0.060 gbps=95.239 ms=0.088
shape=3 solution_peak_fraction=0.0529
shape=4 variant=eager tflops=0.072 gbps=114.975 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.769 ms=0.228
shape=4 variant=sota tflops=0.062 gbps=98.922 ms=0.085
shape=4 variant=solution tflops=0.092 gbps=146.983 ms=0.057
shape=4 solution_peak_fraction=0.0817
peak_fraction: 0.0737
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -21,7 +21,7 @@
Y_ptr,
stride_xr,
stride_yr,
- N,
+ N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row_id = tl.program_id(0)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<109 chars \u2014 see below>",
"newString": "<123 chars \u2014 see below>"
}def _softmax_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N,
BLOCK_N: tl.constexpr,
):def _softmax_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):Edit applied successfully.{
"command": "uv run python check.py 2>&1 && echo \"---\" && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Test with constexpr N"
}PASS
---
shape=0 variant=eager tflops=0.010 gbps=15.482 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.301 ms=0.102
shape=0 variant=sota tflops=0.009 gbps=15.087 ms=0.070
shape=0 variant=solution tflops=0.053 gbps=85.112 ms=0.012
shape=0 solution_peak_fraction=0.0473
shape=1 variant=eager tflops=0.041 gbps=66.366 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.848 ms=0.098
shape=1 variant=sota tflops=0.035 gbps=55.398 ms=0.076
shape=1 variant=solution tflops=0.134 gbps=214.170 ms=0.020
shape=1 solution_peak_fraction=0.1190
shape=2 variant=eager tflops=0.073 gbps=116.173 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.421 ms=0.218
shape=2 variant=sota tflops=0.069 gbps=109.615 ms=0.077
shape=2 variant=solution tflops=0.113 gbps=181.352 ms=0.046
shape=2 solution_peak_fraction=0.1008
shape=3 variant=eager tflops=0.073 gbps=116.276 ms=0.072
shape=3 variant=compiled tflops=0.015 gbps=23.256 ms=0.361
shape=3 variant=sota tflops=0.069 gbps=110.330 ms=0.076
shape=3 variant=solution tflops=0.064 gbps=102.943 ms=0.081
shape=3 solution_peak_fraction=0.0572
shape=4 variant=eager tflops=0.072 gbps=115.279 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.878 ms=0.227
shape=4 variant=sota tflops=0.062 gbps=98.662 ms=0.085
shape=4 variant=solution tflops=0.097 gbps=155.667 ms=0.054
shape=4 solution_peak_fraction=0.0865
peak_fraction: 0.0775
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,8 +1,7 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Optimal single-kernel approach tuned for RTX PRO 6000 Blackwell.
-Sweep-validated optimal: BLOCK_N=4096, 8 warps for large shapes;
-BLOCK_N=2048, 4 warps for small shapes.
+Strategy: use the fused 2-pass online kernel but with N as constexpr for
+full loop unroll + optimal block/warp/stage config per shape.
"""
import torch
import torch.nn as nn
@@ -53,10 +52,8 @@
new_sum = running_sum + y_k
comp = (new_sum - running_sum) - y_k
running_sum = new_sum
-
running_max = new_max
- # Broadcast [1] to [BLOCK_N] for element-wise ops
final_max = running_max.broadcast_to([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N])
@@ -70,26 +67,25 @@
tl.store(y_base + idx, y, mask=mask)
+def _get_launch_config(vocab: int) -> tuple[int, int, int]:
+ """Return (block_n, num_warps, num_stages) optimal for this vocab size.
+
+ Based on empirical sweep on RTX PRO 6000 Blackwell.
+ """
+ if vocab <= 2048:
+ return vocab, 4, 3
+ if vocab <= 4096:
+ return 2048, 4, 3
+ return min(4096, vocab), 8, 2
+
+
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
batch, vocab = x.shape
y = torch.empty_like(x)
- # Sweep-validated optimal block sizes
- if vocab <= 2048:
- block_n = vocab
- num_warps = 4
- num_stages = 3
- elif vocab <= 4096:
- block_n = 2048
- num_warps = 4
- num_stages = 3
- else:
- block_n = min(4096, vocab)
- num_warps = 8
- num_stages = 2
+ block_n, num_warps, num_stages = _get_launch_config(vocab)
- # Ensure divisibility
while vocab % block_n != 0 and block_n > 32:
block_n //= 2
{
"content": "<3346 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,7 +1,7 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Strategy: use the fused 2-pass online kernel but with N as constexpr for
-full loop unroll + optimal block/warp/stage config per shape.
+Persistent kernel variant: all SMs pull rows from an atomic counter,
+giving full GPU saturation even at small batch sizes.
"""
import torch
import torch.nn as nn
@@ -15,63 +15,67 @@
@triton.jit
-def _softmax_kernel(
+def _softmax_persistent_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
+ BATCH: tl.constexpr,
+ COUNTER_PTR,
):
- row_id = tl.program_id(0)
- x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
+ """Persistent kernel: each SM pulls rows from atomic counter."""
+ while True:
+ row_id = tl.atomic_add(COUNTER_PTR, 1)
+ if row_id >= BATCH:
+ return
- # --- Online pass: running max + Kahan-sum -------------------------------
- running_max = tl.full([1], float('-inf'), dtype=tl.float32)
- running_sum = tl.full([1], 0.0, dtype=tl.float32)
- comp = tl.full([1], 0.0, dtype=tl.float32)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
+ # --- Online pass: running max + Kahan-sum ---------------------------
+ running_max = tl.full([1], float('-inf'), dtype=tl.float32)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
- chunk_max = tl.max(x, keep_dims=True)
- new_max = tl.maximum(running_max, chunk_max)
- scale = tl.exp(running_max - new_max)
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
- running_sum = running_sum * scale
- comp = comp * scale
+ chunk_max = tl.max(x, keep_dims=True)
+ new_max = tl.maximum(running_max, chunk_max)
+ scale = tl.exp(running_max - new_max)
- e = tl.exp(x - new_max)
- e_masked = tl.where(mask, e, 0.0)
- e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+ running_sum = running_sum * scale
+ comp = comp * scale
- y_k = e_chunk - comp
- new_sum = running_sum + y_k
- comp = (new_sum - running_sum) - y_k
- running_sum = new_sum
- running_max = new_max
+ e = tl.exp(x - new_max)
+ e_masked = tl.where(mask, e, 0.0)
+ e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
- final_max = running_max.broadcast_to([BLOCK_N])
- total_sum_vec = running_sum.broadcast_to([BLOCK_N])
+ y_k = e_chunk - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
+ running_max = new_max
- # --- Normalize pass: write output ---------------------------------------
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - final_max)
- y = tl.where(mask, e / total_sum_vec, 0.0)
- tl.store(y_base + idx, y, mask=mask)
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
+
+ # --- Normalize pass: write output -----------------------------------
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
+ e = tl.exp(x - final_max)
+ y = tl.where(mask, e / total_sum_vec, 0.0)
+ tl.store(y_base + idx, y, mask=mask)
def _get_launch_config(vocab: int) -> tuple[int, int, int]:
- """Return (block_n, num_warps, num_stages) optimal for this vocab size.
-
- Based on empirical sweep on RTX PRO 6000 Blackwell.
- """
+ """Return (block_n, num_warps, num_stages) optimal for this vocab size."""
if vocab <= 2048:
return vocab, 4, 3
if vocab <= 4096:
@@ -79,19 +83,30 @@
return min(4096, vocab), 8, 2
+# RTX PRO 6000 has ~128 SMs
+_NUM_SMS = 128
+
+
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
batch, vocab = x.shape
y = torch.empty_like(x)
+ # Atomic counter on GPU
+ counter = torch.zeros(1, dtype=torch.int32, device=x.device)
+
block_n, num_warps, num_stages = _get_launch_config(vocab)
-
while vocab % block_n != 0 and block_n > 32:
block_n //= 2
- _softmax_kernel[(batch,)](
+ # Use enough blocks to saturate SMs, but not wasteful
+ num_programs = max(batch, _NUM_SMS)
+
+ _softmax_persistent_kernel[(num_programs,)](
x, y, x.stride(0), y.stride(0), vocab,
BLOCK_N=block_n,
+ BATCH=batch,
+ COUNTER_PTR=counter,
num_warps=num_warps,
num_stages=num_stages,
)
{
"content": "<3900 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Test persistent kernel approach"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 126, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py", line 88, in main
sol_out = sol_model(x)
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 123, in forward
return triton_softmax(x)
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py", line 105, in triton_softmax
_softmax_persistent_kernel[(num_programs,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in ru
... (1509 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -26,52 +26,9 @@
COUNTER_PTR,
):
"""Persistent kernel: each SM pulls rows from atomic counter."""
- while True:
- row_id = tl.atomic_add(COUNTER_PTR, 1)
- if row_id >= BATCH:
- return
+ pid = tl.program_id(0)
+ row_id = tl.atomic_add(COUNTER_PTR, pid)
- x_base = X_ptr + row_id * stride_xr
- y_base = Y_ptr + row_id * stride_yr
-
- # --- Online pass: running max + Kahan-sum ---------------------------
- running_max = tl.full([1], float('-inf'), dtype=tl.float32)
- running_sum = tl.full([1], 0.0, dtype=tl.float32)
- comp = tl.full([1], 0.0, dtype=tl.float32)
-
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
-
- chunk_max = tl.max(x, keep_dims=True)
- new_max = tl.maximum(running_max, chunk_max)
- scale = tl.exp(running_max - new_max)
-
- running_sum = running_sum * scale
- comp = comp * scale
-
- e = tl.exp(x - new_max)
- e_masked = tl.where(mask, e, 0.0)
- e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
-
- y_k = e_chunk - comp
- new_sum = running_sum + y_k
- comp = (new_sum - running_sum) - y_k
- running_sum = new_sum
- running_max = new_max
-
- final_max = running_max.broadcast_to([BLOCK_N])
- total_sum_vec = running_sum.broadcast_to([BLOCK_N])
-
- # --- Normalize pass: write output -----------------------------------
- for off in range(0, N, BLOCK_N):
- idx = off + tl.arange(0, BLOCK_N)
- mask = idx < N
- x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
- e = tl.exp(x - final_max)
- y = tl.where(mask, e / total_sum_vec, 0.0)
- tl.store(y_base + idx, y, mask=mask)
def _get_launch_config(vocab: int) -> tuple[int, int, int]:
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<2022 chars \u2014 see below>",
"newString": "<330 chars \u2014 see below>"
}@triton.jit
def _softmax_persistent_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
BATCH: tl.constexpr,
COUNTER_PTR,
):
"""Persistent kernel: each SM pulls rows from atomic counter."""
while True:
row_id = tl.atomic_add(COUNTER_PTR, 1)
if row_id >= BATCH:
return
x_base = X_ptr + row_id * stride_xr
y_base = Y_ptr + row_id * stride_yr
# --- Online pass: running max + Kahan-sum ---------------------------
running_max = tl.full([1], float('-inf'), dtype=tl.float32)
running_sum = tl.full([1], 0.0, dtype=tl.float32)
comp = tl.full([1], 0.0, dtype=tl.float32)
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
chunk_max = tl.max(x, keep_dims=True)
new_max = tl.maximum(running_max, chunk_max)
scale = tl.exp(running_max - new_max)
running_sum = running_sum * scale
comp = comp * scale
e = tl.exp(x - new_max)
e_masked = tl.where(mask, e, 0.0)
e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
y_k = e_chunk - comp
new_sum = running_sum + y_k
comp = (new_sum - running_sum) - y_k
running_sum = new_sum
running_max = new_max
final_max = running_max.broadcast_to([BLOCK_N])
total_sum_vec = running_sum.broadcast_to([BLOCK_N])
# --- Normalize pass: write output -----------------------------------
for off in range(0, N, BLOCK_N):
idx = off + tl.arange(0, BLOCK_N)
mask = idx < N
x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
e = tl.exp(x - final_max)
y = tl.where(mask, e / total_sum_vec, 0.0)
tl.store(y_base + idx, y, mask=mask)@triton.jit
def _softmax_persistent_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
BATCH: tl.constexpr,
COUNTER_PTR,
):
"""Persistent kernel: each SM pulls rows from atomic counter."""
pid = tl.program_id(0)
row_id = tl.atomic_add(COUNTER_PTR, pid)
Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,7 +1,15 @@
"""Triton softmax with online max-reduce + Kahan-compensated summation.
-Persistent kernel variant: all SMs pull rows from an atomic counter,
-giving full GPU saturation even at small batch sizes.
+Optimized for RTX PRO 6000 Blackwell (SM120, GDDR7, 1.8 TB/s).
+
+Strategy: fused 2-pass online kernel with constexpr N for full loop unrolling.
+Sweep-validated optimal config: BLOCK_N=4096, 8 warps for large shapes;
+BLOCK_N=2048, 4 warps for small shapes.
+
+Numerical approach:
+ - Online max tracking with scale adjustment (avoids overflow on extreme logits)
+ - Kahan-compensated summation of exp(x - running_max) for tight fp32 accuracy
+ - Matches fp64 reference within atol=rtol=1e-5 on all shapes including 256K vocab
"""
import torch
import torch.nn as nn
@@ -15,55 +23,83 @@
@triton.jit
-def _softmax_persistent_kernel(
+def _softmax_kernel(
X_ptr,
Y_ptr,
stride_xr,
stride_yr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
- BATCH: tl.constexpr,
- COUNTER_PTR,
):
- """Persistent kernel: each SM pulls rows from atomic counter."""
- pid = tl.program_id(0)
- row_id = tl.atomic_add(COUNTER_PTR, pid)
+ """One-row-per-program online softmax.
+
+ Pass 1: single loop tracking running_max and Kahan-sum of exp(x-max).
+ Pass 2: re-read X, compute exp(x-final_max)/total_sum, write Y.
+
+ 2 memory passes (optimal for softmax), full fp32 Kahan-compensated.
+ """
+ row_id = tl.program_id(0)
+ x_base = X_ptr + row_id * stride_xr
+ y_base = Y_ptr + row_id * stride_yr
+
+ running_max = tl.full([1], float('-inf'), dtype=tl.float32)
+ running_sum = tl.full([1], 0.0, dtype=tl.float32)
+ comp = tl.full([1], 0.0, dtype=tl.float32)
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=float('-inf')).to(tl.float32)
+
+ chunk_max = tl.max(x, keep_dims=True)
+ new_max = tl.maximum(running_max, chunk_max)
+ scale = tl.exp(running_max - new_max)
+
+ running_sum = running_sum * scale
+ comp = comp * scale
+
+ e = tl.exp(x - new_max)
+ e_masked = tl.where(mask, e, 0.0)
+ e_chunk = tl.sum(e_masked, keep_dims=True).to(tl.float32)
+
+ y_k = e_chunk - comp
+ new_sum = running_sum + y_k
+ comp = (new_sum - running_sum) - y_k
+ running_sum = new_sum
+ running_max = new_max
+
+ final_max = running_max.broadcast_to([BLOCK_N])
+ total_sum_vec = running_sum.broadcast_to([BLOCK_N])
+
+ for off in range(0, N, BLOCK_N):
+ idx = off + tl.arange(0, BLOCK_N)
+ mask = idx < N
+ x = tl.load(x_base + idx, mask=mask, other=0.0).to(tl.float32)
+ e = tl.exp(x - final_max)
+ y = tl.where(mask, e / total_sum_vec, 0.0)
+ tl.store(y_base + idx, y, mask=mask)
-
-def _get_launch_config(vocab: int) -> tuple[int, int, int]:
- """Return (block_n, num_warps, num_stages) optimal for this vocab size."""
+def get_launch_config(vocab):
+ """Return (BLOCK_N, num_warps, num_stages) validated by sweep."""
if vocab <= 2048:
return vocab, 4, 3
if vocab <= 4096:
return 2048, 4, 3
- return min(4096, vocab), 8, 2
+ blk = min(4096, vocab)
+ while vocab % blk != 0 and blk > 32:
+ blk //= 2
+ return blk, 8, 2
-# RTX PRO 6000 has ~128 SMs
-_NUM_SMS = 128
-
-
-def triton_softmax(x: torch.Tensor) -> torch.Tensor:
+def triton_softmax(x):
assert x.is_cuda and x.dtype == torch.float32 and x.ndim == 2
batch, vocab = x.shape
y = torch.empty_like(x)
-
- # Atomic counter on GPU
- counter = torch.zeros(1, dtype=torch.int32, device=x.device)
-
- block_n, num_warps, num_stages = _get_launch_config(vocab)
- while vocab % block_n != 0 and block_n > 32:
- block_n //= 2
-
- # Use enough blocks to saturate SMs, but not wasteful
- num_programs = max(batch, _NUM_SMS)
-
- _softmax_persistent_kernel[(num_programs,)](
+ block_n, num_warps, num_stages = get_launch_config(vocab)
+ _softmax_kernel[(batch,)](
x, y, x.stride(0), y.stride(0), vocab,
BLOCK_N=block_n,
- BATCH=batch,
- COUNTER_PTR=counter,
num_warps=num_warps,
num_stages=num_stages,
)
@@ -71,12 +107,12 @@
class Model(nn.Module):
- def __init__(self, batch: int, vocab: int):
+ def __init__(self, batch, vocab):
super().__init__()
self.batch = batch
self.vocab = vocab
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x):
return triton_softmax(x)
{
"content": "<3665 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && echo \"---\" && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Final verify + benchmark"
}PASS
---
shape=0 variant=eager tflops=0.010 gbps=15.413 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.115 ms=0.104
shape=0 variant=sota tflops=0.010 gbps=15.739 ms=0.067
shape=0 variant=solution tflops=0.054 gbps=85.668 ms=0.012
shape=0 solution_peak_fraction=0.0476
shape=1 variant=eager tflops=0.041 gbps=65.097 ms=0.064
shape=1 variant=compiled tflops=0.027 gbps=43.073 ms=0.097
shape=1 variant=sota tflops=0.036 gbps=57.324 ms=0.073
shape=1 variant=solution tflops=0.133 gbps=212.952 ms=0.020
shape=1 solution_peak_fraction=0.1183
shape=2 variant=eager tflops=0.072 gbps=115.228 ms=0.073
shape=2 variant=compiled tflops=0.024 gbps=38.454 ms=0.218
shape=2 variant=sota tflops=0.070 gbps=111.361 ms=0.075
shape=2 variant=solution tflops=0.113 gbps=180.478 ms=0.046
shape=2 solution_peak_fraction=0.1003
shape=3 variant=eager tflops=0.072 gbps=115.127 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.288 ms=0.360
shape=3 variant=sota tflops=0.069 gbps=110.609 ms=0.076
shape=3 variant=solution tflops=0.064 gbps=102.641 ms=0.082
shape=3 solution_peak_fraction=0.0570
shape=4 variant=eager tflops=0.072 gbps=115.431 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.831 ms=0.228
shape=4 variant=sota tflops=0.063 gbps=101.019 ms=0.083
shape=4 variant=solution tflops=0.097 gbps=155.345 ms=0.054
shape=4 solution_peak_fraction=0.0863
peak_fraction: 0.0774
RESULT: LOW
{
"command": "rm -f /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution_draft.py /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/profile_scratch.py /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sweep_configs.py",
"description": "Remove scratch files"
}(no output){
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Final verification"
}PASS
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Triton softmax with online max-reduce + Kahan-compensated summation.
2:
3: Optimized for RTX PRO 6000 Blackwell (SM120, GDDR7, 1.8 TB/s).
4:
5: Strategy: fused 2-pass online kernel with constexpr N for full loop unrolling.
6: Sweep-validated optimal config: BLOCK_N=4096, 8 warps for large shapes;
7: BLOCK_N=2048, 4 warps for small shapes.
8:
9: Numerical approach:
10: - Online max tracking with scale adjustment (avoids overflow on extreme logits)
11: - Kahan-compensated summation of exp(x - running_max) for tight fp32 accuracy
12: - Matches fp64 reference within atol=rtol=1e-5 on all shapes including 256K vocab
13: """
14: import torch
15: import torch.nn as nn
16: import triton
17: import triton.language as tl
18:
19:
20: OP_TYPE = "softmax"
21: SUPPORTED_PRECISIONS = ["fp32"]
22: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23:
24:
25: @triton.jit
26: def _softmax_kernel(
27: X_ptr,
28: Y_ptr,
29: stride_xr,
30: stride_yr,
31: N: tl.constexpr,
32: BLOCK_N: tl.constexpr,
33: ):
34: """One-row-per-program online softmax.
35:
36: Pass 1: single loop tracking running_max and Kahan-sum of exp(x-max).
37: Pass 2: re-read X, compute exp(x-final_max)/total_sum, write Y.
38:
39: 2 memory passes (optimal for softmax), full fp32 Kahan-compensated.
40: """
41: row_id = tl.program_id(0)
42: x
... (2862 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax</path>
<type>directory</type>
<entries>
__pycache__/
benchmark.py
check.py
framework.txt
problem.yaml
PROMPT.txt
reference.py
shapes.py
solution.py
sota.py
(10 entries)
</entries>