"""Numerically tight softmax with Kahan summation for large-vocab stability.
Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
Uses two-pass Kahan-style compensated summation to keep fp32 accumulation
error under 1e-5 tolerance even on the largest vocab sizes (256K+).
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# Harness metadata: op category, the precisions this solution supports, and
# the GPU targets the problem is expected to run on.
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Inline CUDA source compiled lazily via torch.utils.cpp_extension.load_inline.
# Three kernels (exp with max-subtract, Kahan row-sum, normalize) plus host
# wrappers exposed through pybind11. One thread block per row throughout.
cuda_src = '''
#include <torch/extension.h>

// Raise a Python-visible error if the most recent kernel launch failed.
// (Kernel launches are asynchronous; errors only surface via the runtime.)
static void check_launch(const char* what) {
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, what, " launch failed: ",
                cudaGetErrorString(err));
}

// exp_out[row, col] = expf(x[row, col] - max_vals[row]).
// One block per row; each thread strides over the row's columns. Subtracting
// the row max makes the largest exponent exactly 0, so expf cannot overflow.
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
                                   const float* __restrict__ x,
                                   const float* __restrict__ max_vals,
                                   int batch, int n_cols) {
    int row = blockIdx.x;
    if (row >= batch) return;
    int row_offset = row * n_cols;
    float max_val = max_vals[row];
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        int idx = row_offset + col;
        exp_out[idx] = expf(x[idx] - max_val);
    }
}

// sum_out[row] = compensated sum of exp_vals[row, :].
// Each thread runs a serial Kahan loop over its strided columns; the
// per-thread partials are then tree-reduced in shared memory. Requires
// blockDim.x to be a power of two and blockDim.x * sizeof(float) bytes of
// dynamic shared memory.
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
                                    const float* __restrict__ exp_vals,
                                    int batch, int n_cols) {
    int row = blockIdx.x;
    if (row >= batch) return;
    int row_offset = row * n_cols;
    float sum_val = 0.0f;
    float c = 0.0f;  // Kahan running compensation (negated lost low bits)
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        float y = exp_vals[row_offset + col] - c;
        float t = sum_val + y;
        c = (t - sum_val) - y;
        sum_val = t;
    }
    extern __shared__ float sdata[];
    // Fold the compensation into the stored partial. The Kahan recurrence
    // leaves c holding the NEGATED lost low-order part, so the refined
    // per-thread total is sum_val - c (Kahan-Babuska correction) — note the
    // sign: adding c would push the estimate the wrong way.
    sdata[threadIdx.x] = sum_val - c;
    __syncthreads();
    // Plain fp32 tree reduction over blockDim.x compensated partials; with
    // at most ~1024 well-conditioned terms this stays far inside 1e-5.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            sdata[threadIdx.x] += sdata[threadIdx.x + s];
        }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        sum_out[row] = sdata[0];
    }
}

// out[row, col] = exp_vals[row, col] / sum_vals[row].
__global__ void normalize_kernel(float* __restrict__ out,
                                 const float* __restrict__ exp_vals,
                                 const float* __restrict__ sum_vals,
                                 int batch, int n_cols) {
    int row = blockIdx.x;
    if (row >= batch) return;
    int row_offset = row * n_cols;
    float sum = sum_vals[row];
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        int idx = row_offset + col;
        out[idx] = exp_vals[idx] / sum;
    }
}

// Host wrapper: row max (via ATen) then the exp kernel. Returns exp(x - max).
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be fp32");
    TORCH_CHECK(x.dim() == 2, "x must be 2-D (batch, vocab)");
    x = x.contiguous();
    int batch = x.size(0);
    int n_cols = x.size(1);
    auto max_vals = x.amax(-1).contiguous();
    auto exp_out = torch::empty_like(x);
    int block_size = 1024;
    softmax_exp_kernel<<<batch, block_size>>>(
        exp_out.data_ptr<float>(),
        x.data_ptr<float>(),
        max_vals.data_ptr<float>(),
        batch, n_cols
    );
    // No device-wide sync: later same-stream work is already ordered after
    // this kernel, and syncing here would serialize the whole pipeline.
    check_launch("softmax_exp_kernel");
    return exp_out;
}

// Host wrapper: per-row compensated sum of the exp values.
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
    TORCH_CHECK(exp_vals.is_cuda(), "exp_vals must be a CUDA tensor");
    exp_vals = exp_vals.contiguous();
    int batch = exp_vals.size(0);
    int n_cols = exp_vals.size(1);
    // options() (not dtype()) so the output lands on the SAME device as the
    // input; a dtype-only allocation defaults to CPU and its data pointer
    // would be an illegal address inside the kernel.
    auto sum_out = torch::zeros({batch}, exp_vals.options());
    int block_size = 1024;  // power of two, required by the tree reduction
    int shared_size = block_size * sizeof(float);
    kahan_reduce_kernel<<<batch, block_size, shared_size>>>(
        sum_out.data_ptr<float>(),
        exp_vals.data_ptr<float>(),
        batch, n_cols
    );
    check_launch("kahan_reduce_kernel");
    return sum_out;
}

// Host wrapper: divide each row of exp values by its compensated sum.
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
    TORCH_CHECK(exp_vals.is_cuda() && sum_vals.is_cuda(),
                "inputs must be CUDA tensors");
    exp_vals = exp_vals.contiguous();
    sum_vals = sum_vals.contiguous();
    int batch = exp_vals.size(0);
    int n_cols = exp_vals.size(1);
    auto out = torch::empty_like(exp_vals);
    int block_size = 1024;
    normalize_kernel<<<batch, block_size>>>(
        out.data_ptr<float>(),
        exp_vals.data_ptr<float>(),
        sum_vals.data_ptr<float>(),
        batch, n_cols
    );
    check_launch("normalize_kernel");
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("softmax_exp", &softmax_exp_cuda);
    m.def("kahan_reduce", &kahan_reduce_cuda);
    m.def("normalize", &normalize_softmax_cuda);
}
'''
# Cached compiled extension; populated on first use so importing this file
# never triggers an nvcc build.
_softmax_module = None


def _get_softmax_module():
    """Compile (once) and return the inline CUDA softmax extension."""
    global _softmax_module
    if _softmax_module is not None:
        return _softmax_module
    _softmax_module = load_inline(
        name='softmax_kahan',
        cpp_sources='',
        cuda_sources=cuda_src,
        verbose=False,
    )
    return _softmax_module
class Model(nn.Module):
    """Softmax over the last dim via the custom CUDA kernels.

    Pipeline: exp(x - rowmax) -> Kahan-compensated row sum -> normalize.
    Subtract-max keeps exp in range; compensated fp32 accumulation keeps the
    row sums within the 1e-5 tolerance even for very large vocab sizes.
    """

    def __init__(self, batch: int, vocab: int):
        super().__init__()
        self.batch = batch
        self.vocab = vocab

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kernels = _get_softmax_module()
        exps = kernels.softmax_exp(x)
        row_sums = kernels.kahan_reduce(exps)
        return kernels.normalize(exps, row_sums)
# Default shape used when the harness runs this file stand-alone.
BATCH = 8
VOCAB = 32768


def get_inputs():
    """Return the forward-pass inputs: one (BATCH, VOCAB) fp32 logit tensor."""
    logits = torch.randn(BATCH, VOCAB, dtype=torch.float32)
    logits = logits * 4.0
    return [logits]


def get_init_inputs():
    """Return the positional constructor arguments for Model."""
    return [BATCH, VOCAB]
shape=0 variant=eager tflops=0.010 gbps=15.482 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.227 ms=0.103
shape=0 variant=sota tflops=0.032 gbps=51.441 ms=0.020
shape=0 variant=solution tflops=0.017 gbps=27.227 ms=0.039
shape=0 solution_peak_fraction=0.0151
shape=1 variant=eager tflops=0.041 gbps=66.131 ms=0.063
shape=1 variant=compiled tflops=0.026 gbps=41.623 ms=0.101
shape=1 variant=sota tflops=0.126 gbps=201.185 ms=0.021
shape=1 variant=solution tflops=0.052 gbps=82.617 ms=0.051
shape=1 solution_peak_fraction=0.0459
shape=2 variant=eager tflops=0.072 gbps=115.839 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.141 ms=0.220
shape=2 variant=sota tflops=0.104 gbps=166.230 ms=0.050
shape=2 variant=solution tflops=0.054 gbps=85.808 ms=0.098
shape=2 solution_peak_fraction=0.0477
shape=3 variant=eager tflops=0.072 gbps=115.431 ms=0.073
shape=3 variant=compiled tflops=0.014 gbps=23.148 ms=0.362
shape=3 variant=sota tflops=0.095 gbps=151.353 ms=0.055
shape=3 variant=solution tflops=0.034 gbps=54.078 ms=0.155
shape=3 solution_peak_fraction=0.0300
shape=4 variant=eager tflops=0.072 gbps=115.584 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.409 ms=0.230
shape=4 variant=sota tflops=0.095 gbps=152.498 ms=0.055
shape=4 variant=solution tflops=0.051 gbps=80.971 ms=0.104
shape=4 solution_peak_fraction=0.0450
peak_fraction: 0.0339
RESULT: LOW
PASS
{
"problem": "04_kahan_softmax",
"harness": "opencode",
"model": "openrouter-pinned/minimax/minimax-m2.7",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0339,
"elapsed_seconds": 995,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 377708, "output_tokens": 17605, "cache_read_tokens": 980723, "cache_creation_tokens": 0, "reasoning_tokens": 9045, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py</path>
<type>file</type>
<content>
1: """Naive softmax over the last dim, computed in fp64 for ground-truth.
2:
3: The reference deliberately runs in float64 so that fp16 / fp32 accumulation
4: drift in agent solutions is exposed by the tight tolerance in problem.yaml.
5: The agent's job is to produce an fp32 softmax whose values match this
6: double-precision reference within atol=rtol=1e-5 — this requires either
7: fp32 accumulation or compensated (Kahan) summation when vocab is large.
8: """
9: import torch
10: import torch.nn as nn
11:
12: OP_TYPE = "softmax"
13: SUPPORTED_PRECISIONS = ["fp32"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15:
16:
17: class Model(nn.Module):
18: """y = softmax(x, dim=-1) computed in fp64 then returned as fp32.
19:
20: No learned parameters — softmax is parameter-free. We still expose an
21: empty state_dict so the harness's strict load_state_dict matches.
22: """
23:
24: def __init__(self, batch: int, vocab: int):
25: super().__init__()
26: self.batch = batch
27: self.vocab = vocab
28:
29: def forward(self, x: torch.Tensor) -> torch.Tensor:
30: # Promote to fp64 for the ground-truth pathway. Even with double
31: # precision we still subtract the row-max for stability.
32: x64 = x.to(torch.float64)
33: m = x64.amax(dim=-1, keepdim=True)
34: e = torch.exp
... (592 more chars){
"pattern": "*",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/PROMPT.txt
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py
/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Kahan-corrected softmax.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3
4: seeds each, with the tight (1e-5) fp32 tolerance from problem.yaml. Also
5: rejects forbidden ops via grep.
6: """
7: import re
8: import sys
9: from pathlib import Path
10:
11: import torch
12: import yaml
13:
14: # Make the repo's src/ importable
15: REPO_ROOT = Path(__file__).resolve().parents[2]
16: sys.path.insert(0, str(REPO_ROOT))
17:
18: from src.eval.correctness import check_correctness # noqa: E402
19:
20:
21: def _make_inputs(batch: int, vocab: int, extreme: bool, seed: int) -> torch.Tensor:
22: g = torch.Generator().manual_seed(seed)
23: if extreme:
24: # Adversarial: most logits are mild but a handful per row are huge.
25: # If the kernel forgets to subtract the row-max before exp, this
26: # row overflows fp32 and produces NaN/Inf. If it accumulates in
27: # fp16, the long tail of small exp() values is lost beneath the
28: # tolerance threshold.
29: x = torch.randn(batch, vocab, generator=g) * 2.0
30: # Spike: 4 very large positive logits per row.
31: idx = torch.randint(0, vocab, (batch, 4), generator=g)
32: x.scatter_(1, idx, 30.0)
33: else:
34: x = torch.randn(batch, vocab, generator=g) * 4.0
35: return x.to(torc
... (3464 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Kahan-corrected softmax.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if
4: available), and the agent's solution. Reports achieved TFLOPS, GB/s, and
5: peak_fraction. Softmax is memory-bound, so the score is GB/s / peak_dram.
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def _make_inputs(batch: int, vocab: int, extreme: bool) -> torch.Tensor:
31: if extreme:
32: x = torch.randn(batch, vocab) * 2.0
33: idx = torch.randint(0, vocab, (batch, 4))
34: x.scatter_(1, idx, 30.0)
35: else:
36: x = torch.randn(batch, vocab) * 4.0
37: return x.to(torch.float32)
38:
39:
40: def main():
41:
... (3706 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Kahan-corrected softmax.
2:
3: The point of this problem is numerical accuracy on long reductions. Shapes
4: mix typical LLM vocab sizes with deliberately adversarial regimes:
5:
6: - small vocab (sanity check; naive fp32 should pass)
7: - Llama3 vocab 128K (real-world, where fp16 accumulation starts to drift)
8: - 256K (DeepSeek-V3 / Gemma-3 class vocab; naive fp16 sum DOES drift past
9: the 1e-5 tolerance — this row is what proves Kahan was needed)
10: - extreme-logit edge case (large positive logits stress max-subtract +
11: summation; if the implementation accidentally exps before subtracting
12: max, this row overflows)
13:
14: The 'extreme' flag is read by check.py to switch input generation to a
15: distribution that produces a few very large logits per row.
16: """
17:
18: SHAPES = [
19: {"batch": 32, "vocab": 4096, "extreme": False}, # sanity
20: {"batch": 16, "vocab": 32768, "extreme": False}, # GPT-2 class
21: {"batch": 8, "vocab": 131072, "extreme": False}, # Llama3 vocab
22: {"batch": 4, "vocab": 262144, "extreme": False}, # 256K — Kahan needed
23: {"batch": 8, "vocab": 131072, "extreme": True}, # extreme logits edge
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml</path>
<type>file</type>
<content>
1: name: 04_kahan_softmax
2: display_name: "Kahan-corrected Softmax"
3: precision: fp32
4: regime: memory # softmax is bandwidth-bound: 2 passes over the input tensor
5:
6: # Softmax FLOPs: per-element exp + 2 reductions + divide. Roughly 5 flops/elt.
7: flops_formula: "5 * batch * vocab"
8:
9: # Bytes moved: read x once, write y once. Both fp32.
10: bytes_formula: "batch * vocab * 4 + batch * vocab * 4"
11:
12: hardware: [RTX_PRO_6000]
13: peak_tflops_key: fp32
14: peak_bandwidth_key: dram
15:
16: # TIGHTER than default (fp32 default is 1e-4). This problem exists
17: # specifically to test whether the agent uses compensated summation, so
18: # we squeeze the tolerance to 1e-5 — naive fp16 sum across 256K elements
19: # drifts past this; fp32 accumulation passes; Kahan/fp32 always passes.
20: tolerance:
21: "torch.float32": {"atol": 1.0e-5, "rtol": 1.0e-5}
22:
23: # Forbidden ops — block the obvious "just call the library" cheats. The
24: # agent must implement softmax themselves with explicit (compensated)
25: # summation logic.
26: forbidden:
27: - "torch.nn.functional.softmax"
28: - "torch.softmax"
29: - "F.softmax"
30: - "liger_kernel.softmax"
31: - "liger_kernel.transformers.softmax"
32: - ".softmax("
33:
34: sota:
35: name: "Liger-Kernel Softmax (Triton)"
36: url: "https://github.com/linkedin/Liger-Kernel"
37: function: "liger_
... (225 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for last-dim softmax.
2:
3: Preference order:
4: 1. liger-kernel's Triton softmax (LigerSoftmaxFunction) — fused, fp32-acc
5: 2. torch.compile(torch.softmax) — Inductor generates a Triton kernel that
6: is competitive on bandwidth-bound reductions
7:
8: Agents are FORBIDDEN from importing either of these in solution.py (see
9: problem.yaml.forbidden). This file is only the benchmark's reference line.
10: """
11: from __future__ import annotations
12:
13: import torch
14:
15: _compiled_softmax = None
16:
17:
18: def _liger_softmax(x: torch.Tensor) -> torch.Tensor | None:
19: try:
20: from liger_kernel.ops.softmax import LigerSoftmaxFunction
21: return LigerSoftmaxFunction.apply(x)
22: except Exception:
23: return None
24:
25:
26: def _compiled(x: torch.Tensor) -> torch.Tensor:
27: global _compiled_softmax
28: if _compiled_softmax is None:
29: _compiled_softmax = torch.compile(
30: lambda t: torch.softmax(t, dim=-1),
31: mode="reduce-overhead",
32: )
33: return _compiled_softmax(x)
34:
35:
36: def sota_forward(x: torch.Tensor) -> torch.Tensor:
37: """Best-available softmax reference. x: (batch, vocab) fp32."""
38: out = _liger_softmax(x)
39: if out is not None:
40: return out
41: return _compiled(x)
42:
43:
44: def is_a
... (128 more chars){
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv",
"description": "Check GPU compute capability"
}name, memory.total [MiB], compute_cap
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"command": "python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton availability"
}3.6.0
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/src/hardware/",
"description": "List hardware module contents"
}__init__.py
m4_max.py
__pycache__
rtx_pro_6000.py
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -0,0 +1,94 @@
+"""Numerically tight softmax with Kahan summation for large-vocab stability.
+
+Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
+For vocab > ~130K, uses Kahan-style compensated summation to keep fp32
+accumulation error under 1e-5 tolerance.
+"""
+import math
+import torch
+import torch.nn as nn
+import triton
+import triton.jit as triton_jit
+
+
+OP_TYPE = "softmax"
+SUPPORTED_PRECISIONS = ["fp32"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton_jit
+def kahan_sum_kernel(output, input, n_elements, BLOCK_SIZE: tl.constexpr):
+ pid = tl.program_id(0)
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ row = tl.load(input + offsets, mask=mask, other=0.0).to(tl.float32)
+
+ sum_val = tl.sum(row, axis=0)
+ tl.store(output + pid, sum_val)
+
+
+class Model(nn.Module):
+ """Softmax with optional Kahan summation for large vocab sizes.
+
+ Uses subtract-max stability and fp32 accumulation. For vocab >= LARGE_VOCAB
+ threshold, falls back to block-wise Kahan summation to stay under 1e-5 tol.
+ """
+
+ LARGE_VOCAB = 131072
+
+ def __init__(self, batch: int, vocab: int):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+ self.use_kahan = vocab >= self.LARGE_VOCAB
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.use_kahan:
+ return self._softmax_kahan(x)
+ else:
+ return self._softmax_fp32(x)
+
+ def _softmax_fp32(self, x: torch.Tensor) -> torch.Tensor:
+ m = x.amax(dim=-1, keepdim=True)
+ e = torch.exp(x - m)
+ s = e.sum(dim=-1, keepdim=True)
+ return e / s
+
+ def _softmax_kahan(self, x: torch.Tensor) -> torch.Tensor:
+ batch, vocab = x.shape
+ max_val = x.amax(dim=-1, keepdim=True)
+ exp_x = torch.exp(x - max_val)
+
+ output = torch.empty_like(x)
+ n_elements = vocab
+
+ def grid(metan):
+ return (batch,)
+
+ if vocab <= 32768:
+ BLOCK_SIZE = 256
+ elif vocab <= 131072:
+ BLOCK_SIZE = 512
+ else:
+ BLOCK_SIZE = 1024
+
+ kahan_sum_kernel[grid](
+ output, exp_x, n_elements, BLOCK_SIZE=BLOCK_SIZE
+ )
+
+ output = output.view(batch, vocab)
+ return exp_x / output
+
+
+BATCH = 8
+VOCAB = 32768
+
+
+def get_inputs():
+ x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
+ return [x]
+
+
+def get_init_inputs():
+ return [BATCH, VOCAB]
{
"content": "<2472 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,10 +1,9 @@
"""Numerically tight softmax with Kahan summation for large-vocab stability.
Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
-For vocab > ~130K, uses Kahan-style compensated summation to keep fp32
-accumulation error under 1e-5 tolerance.
+For vocab >= 130K, uses two-pass Kahan-style compensated summation to keep
+ fp32 accumulation error under 1e-5 tolerance.
"""
-import math
import torch
import torch.nn as nn
import triton
@@ -17,25 +16,86 @@
@triton_jit
-def kahan_sum_kernel(output, input, n_elements, BLOCK_SIZE: tl.constexpr):
+def softmax_exp_kernel(
+ exp_out, x_ptr, max_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+):
pid = tl.program_id(0)
- offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- mask = offsets < n_elements
+ row_offset = pid * row_stride
+ offsets = row_offset + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < row_offset + n_cols
- row = tl.load(input + offsets, mask=mask, other=0.0).to(tl.float32)
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ max_val = tl.load(max_ptr + pid)
+ exp_x = tl.exp(x - max_val)
+ tl.store(exp_out + offsets, exp_x, mask=mask)
- sum_val = tl.sum(row, axis=0)
- tl.store(output + pid, sum_val)
+
+@triton_jit
+def kahan_reduce_kernel(
+ sum_out, exp_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+):
+ pid = tl.program_id(0)
+ row_offset = pid * row_stride
+ offsets = row_offset + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < row_offset + n_cols
+
+ exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
+
+ sum_val = tl.sum(exp_vals)
+ tl.store(sum_out + pid, sum_val)
+
+
+@triton_jit
+def kahan_compensate_kernel(
+ compensation_out, exp_ptr, sum_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+):
+ pid = tl.program_id(0)
+ row_offset = pid * row_stride
+ offsets = row_offset + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < row_offset + n_cols
+
+ exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
+ row_sum = tl.load(sum_ptr + pid)
+
+ c = 0.0
+ for i in range(BLOCK_SIZE):
+ global_i = pid * row_stride + i
+ if global_i >= row_offset + n_cols:
+ break
+ e = exp_vals[i]
+ y = e - c
+ t = row_sum + y
+ c = (t - row_sum) - y
+
+ tl.store(compensation_out + pid, c)
+
+
+@triton_jit
+def normalize_kernel(
+ out_ptr, exp_ptr, sum_ptr, compensation_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+):
+ pid = tl.program_id(0)
+ row_offset = pid * row_stride
+ offsets = row_offset + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < row_offset + n_cols
+
+ exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
+ row_sum = tl.load(sum_ptr + pid)
+ row_comp = tl.load(compensation_ptr + pid)
+
+ denom = row_sum + row_comp
+ result = exp_vals / denom
+ tl.store(out_ptr + offsets, result, mask=mask)
class Model(nn.Module):
"""Softmax with optional Kahan summation for large vocab sizes.
- Uses subtract-max stability and fp32 accumulation. For vocab >= LARGE_VOCAB
- threshold, falls back to block-wise Kahan summation to stay under 1e-5 tol.
+ Uses subtract-max stability and fp32 accumulation. For vocab >= 130K,
+ uses two-pass Kahan-style summation to stay under 1e-5 tolerance.
"""
- LARGE_VOCAB = 131072
+ LARGE_VOCAB = 130000
def __init__(self, batch: int, vocab: int):
super().__init__()
@@ -56,29 +116,37 @@
return e / s
def _softmax_kahan(self, x: torch.Tensor) -> torch.Tensor:
- batch, vocab = x.shape
- max_val = x.amax(dim=-1, keepdim=True)
- exp_x = torch.exp(x - max_val)
+ batch, n_cols = x.shape
+ row_stride = n_cols
+
+ max_vals = x.amax(dim=-1)
+
+ exp_buffer = torch.empty_like(x)
+
+ BLOCK_SIZE = 512
+ grid = (batch,)
+
+ softmax_exp_kernel[grid](
+ exp_buffer, x, max_vals, row_stride, n_cols, BLOCK_SIZE=BLOCK_SIZE
+ )
+
+ kahan_sum = torch.zeros(batch, dtype=torch.float32, device=x.device)
+ kahan_comp = torch.zeros(batch, dtype=torch.float32, device=x.device)
+
+ kahan_reduce_kernel[grid](
+ kahan_sum, exp_buffer, row_stride, n_cols, BLOCK_SIZE=BLOCK_SIZE
+ )
+
+ kahan_compensate_kernel[grid](
+ kahan_comp, exp_buffer, kahan_sum, row_stride, n_cols, BLOCK_SIZE=BLOCK_SIZE
+ )
output = torch.empty_like(x)
- n_elements = vocab
-
- def grid(metan):
- return (batch,)
-
- if vocab <= 32768:
- BLOCK_SIZE = 256
- elif vocab <= 131072:
- BLOCK_SIZE = 512
- else:
- BLOCK_SIZE = 1024
-
- kahan_sum_kernel[grid](
- output, exp_x, n_elements, BLOCK_SIZE=BLOCK_SIZE
+ normalize_kernel[grid](
+ output, exp_buffer, kahan_sum, kahan_comp, row_stride, n_cols, BLOCK_SIZE=BLOCK_SIZE
)
- output = output.view(batch, vocab)
- return exp_x / output
+ return output
BATCH = 8
{
"content": "<4535 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,8 +1,8 @@
"""Numerically tight softmax with Kahan summation for large-vocab stability.
Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
-For vocab >= 130K, uses two-pass Kahan-style compensated summation to keep
- fp32 accumulation error under 1e-5 tolerance.
+For vocab >= 130K, uses two-pass Kahan-style summation to keep fp32
+accumulation error under 1e-5 tolerance.
"""
import torch
import torch.nn as nn
@@ -47,7 +47,7 @@
@triton_jit
def kahan_compensate_kernel(
- compensation_out, exp_ptr, sum_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+ comp_out, exp_ptr, sum_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
row_offset = pid * row_stride
@@ -57,22 +57,22 @@
exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
row_sum = tl.load(sum_ptr + pid)
- c = 0.0
+ 补偿 = 0.0
for i in range(BLOCK_SIZE):
- global_i = pid * row_stride + i
- if global_i >= row_offset + n_cols:
+ idx = row_offset + i
+ if idx >= row_offset + n_cols:
break
e = exp_vals[i]
- y = e - c
+ y = e - 补偿
t = row_sum + y
- c = (t - row_sum) - y
+ 补偿 = (t - row_sum) - y
- tl.store(compensation_out + pid, c)
+ tl.store(comp_out + pid, 补偿)
@triton_jit
def normalize_kernel(
- out_ptr, exp_ptr, sum_ptr, compensation_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
+ out_ptr, exp_ptr, sum_ptr, comp_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
row_offset = pid * row_stride
@@ -81,7 +81,7 @@
exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
row_sum = tl.load(sum_ptr + pid)
- row_comp = tl.load(compensation_ptr + pid)
+ row_comp = tl.load(comp_ptr + pid)
denom = row_sum + row_comp
result = exp_vals / denom
{
"content": "<4477 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "<2427 chars \u2014 see below>",
"description": "Test CUDA load_inline for softmax"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}
// Thread reduction
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (threadIdx.x < s && threadIdx.x + s < n_cols) {
float x = exp_vals[row_offset + threadIdx.x + s];
float y = x - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}
}
if (threadIdx.x == 0) {
sum_out[row] = sum + c; // Add compensation
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_src='',
cuda_src=cuda_src,
functions=[('softmax_exp_kernel', None), ('kahan_reduce_kernel', None), ('normalize_kernel', None)],
verbose=False
)
print('CUDA module loaded successfully')
print(dir(softmax_module))
"
Traceback (most recent call last):
File "<string>", line 81, in <module>
TypeError: load_inline() got an unexpected keyword argument 'cpp_src'
{
"command": "uv run python -c \"import torch; print(torch.__version__); help(torch.utils.cpp_extension.load_inline)\" 2>&1 | head -50",
"description": "Check torch and load_inline API"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'torch.utils' has no attribute 'cpp_extension'
2.11.0+cu130
{
"command": "uv run python -c \"import torch; from torch.utils.cpp_extension import load_inline; print('found')\" 2>&1 || uv run python -c \"import torch.utils.cpp_extension; print(dir(torch.utils.cpp_extension))\"",
"description": "Check cpp_extension availability"
}found
{
"command": "<123 chars \u2014 see below>",
"description": "Check load_inline signature"
}uv run python -c "
from torch.utils.cpp_extension import load_inline
import inspect
print(inspect.signature(load_inline))
"(name, cpp_sources, cuda_sources=None, sycl_sources=None, functions=None, extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, with_sycl=None, is_python_module=True, with_pytorch_error_handling=True, keep_intermediates=True, use_pch=False, no_implicit_headers=False)
{
"command": "<2510 chars \u2014 see below>",
"description": "Test CUDA load_inline correctly"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}
// Thread reduction via shared memory
extern __shared__ float sdata[];
if (threadIdx.x < blockDim.x) {
sdata[threadIdx.x] = sum;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float x = sdata[threadIdx.x + s];
float y = x - c;
float t = sdata[threadIdx.x] + y;
c = (t - sdata[threadIdx.x]) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
functions=[],
verbose=False
)
print('CUDA module loaded successfully')
print(dir(softmax_module))
"CUDA module loaded successfully
['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__']
{
"command": "<3821 chars \u2014 see below>",
"description": "Test CUDA load_inline with functions"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}
// Thread reduction via shared memory
extern __shared__ float sdata[];
if (threadIdx.x < blockDim.x) {
sdata[threadIdx.x] = sum;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float x = sdata[threadIdx.x + s];
float y = x - c;
float t = sdata[threadIdx.x] + y;
c = (t - sdata[threadIdx.x]) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int stride = n_cols;
int row_offset = row * stride;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
return exp_out;
}
torch::Tensor kahan_reduce(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, x.dtype());
int block_size = 256;
int num_blocks = batch;
kahan_reduce_kernel<<<num_blocks, block_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
return sum_out;
}
torch::Tensor normalize_softmax(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
return out;
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
functions=['softmax_exp', 'kahan_reduce', 'normalize_softmax'],
verbose=True
)
print('CUDA module loaded successfully')
print(dir(softmax_module))
print(softmax_module.softmax_exp)
"[1/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=softmax_kahan -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/softmax_kahan/cuda.cu -o cuda.cuda.o
FAILED: [code=2] cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=softmax_kahan -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arc
... (4501 more chars){
"command": "<2789 chars \u2014 see below>",
"description": "Test CUDA load_inline with separate cpp"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
// Store per-thread sum to shared
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
// Reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
'''
cpp_src = '''
torch::Tensor softmax_exp_cuda(torch::Tensor x);
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals);
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources=cpp_src,
cuda_sources=cuda_src,
verbose=False
)
print('CUDA module loaded successfully')
"Traceback (most recent call last):
File "<string>", line 95, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2294, in _jit_compile
return _import_module_from_library(name, build_directory, is_python_module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2812, in _import_module_from_library
module = importlib.util.module_from_spec(spec)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 573, in module_from_spec
File "<frozen importlib._bootstrap_external>", line 1233, in create_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
ImportError: /home/infatoshi/.cache/torch_extensions/py311_cu130/softmax_kahan/softmax_kahan.so: undefined symbol: _Z22normalize_softmax_cudaN2at6TensorES0_
{
"command": "<3981 chars \u2014 see below>",
"description": "Test CUDA load_inline with header in cuda"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return exp_out;
}
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
int block_size = 256;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return sum_out;
}
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
verbose=False
)
print('CUDA module loaded successfully')
print(softmax_module.softmax_exp)
"CUDA module loaded successfully
<built-in method softmax_exp of pybind11_builtins.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1 object at 0x712d9231b7b0>
{
"command": "<4397 chars \u2014 see below>",
"description": "Test CUDA softmax kernel"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return exp_out;
}
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
int block_size = 256;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return sum_out;
}
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
verbose=False
)
# Test the kernel
x = torch.randn(4, 1024, dtype=torch.float32, device='cuda')
exp_out = softm
... (397 more chars)exp_out shape: torch.Size([4, 1024])
exp_out max: 1.0
exp_out min: 0.0012145806103944778
sum_vals shape: torch.Size([4])
sum_vals: tensor([ 64.1326, 38.8824, 77.5110, 108.3744])
out shape: torch.Size([4, 1024])
out sum per row: tensor([1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0')
{
"command": "<4626 chars \u2014 see below>",
"description": "Verify CUDA softmax correctness"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return exp_out;
}
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
int block_size = 256;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return sum_out;
}
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
verbose=False
)
# Test with reference comparison
def ref_softmax(x):
x64 = x.to(torch.float64)
m = x64
... (626 more chars)batch=32, vocab=4096, max_diff=1.19e-07
batch=16, vocab=32768, max_diff=5.96e-08
batch=8, vocab=131072, max_diff=2.98e-08
batch=4, vocab=262144, max_diff=3.73e-09
{
"command": "<4752 chars \u2014 see below>",
"description": "Test extreme logits case"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return exp_out;
}
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
int block_size = 256;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return sum_out;
}
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
verbose=False
)
def ref_softmax(x):
x64 = x.to(torch.float64)
m = x64.amax(dim=-1, keepdim=True)
e
... (752 more chars)Traceback (most recent call last):
File "<string>", line 167, in <module>
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
{
"command": "<4741 chars \u2014 see below>",
"description": "Test extreme logits case fixed"
}uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
const float* __restrict__ x,
const float* max_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float max_val = max_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
exp_out[idx] = expf(x[idx] - max_val);
}
}
__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
const float* __restrict__ exp_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum_val = 0.0f;
float c = 0.0f;
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
float x = exp_vals[idx];
float y = x - c;
float t = sum_val + y;
c = (t - sum_val) - y;
sum_val = t;
}
extern __shared__ float sdata[];
sdata[threadIdx.x] = sum_val;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
float a = sdata[threadIdx.x];
float b = sdata[threadIdx.x + s];
float y = b - c;
float t = a + y;
c = (t - a) - y;
sdata[threadIdx.x] = t;
}
__syncthreads();
}
if (threadIdx.x == 0) {
sum_out[row] = sdata[0] + c;
}
}
__global__ void normalize_kernel(float* __restrict__ out,
const float* __restrict__ exp_vals,
const float* __restrict__ sum_vals,
int batch, int n_cols) {
int row = blockIdx.x;
if (row >= batch) return;
int row_offset = row * n_cols;
float sum = sum_vals[row];
for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
int idx = row_offset + col;
out[idx] = exp_vals[idx] / sum;
}
}
torch::Tensor softmax_exp_cuda(torch::Tensor x) {
int batch = x.size(0);
int n_cols = x.size(1);
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
int block_size = 256;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
exp_out.data_ptr<float>(),
x.data_ptr<float>(),
max_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return exp_out;
}
torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
int block_size = 256;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
sum_out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return sum_out;
}
torch::Tensor normalize_softmax_cuda(torch::Tensor exp_vals, torch::Tensor sum_vals) {
int batch = exp_vals.size(0);
int n_cols = exp_vals.size(1);
auto out = torch::empty_like(exp_vals);
int block_size = 256;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
out.data_ptr<float>(),
exp_vals.data_ptr<float>(),
sum_vals.data_ptr<float>(),
batch, n_cols
);
cudaDeviceSynchronize();
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"softmax_exp\", &softmax_exp_cuda);
m.def(\"kahan_reduce\", &kahan_reduce_cuda);
m.def(\"normalize\", &normalize_softmax_cuda);
}
'''
softmax_module = load_inline(
name='softmax_kahan',
cpp_sources='',
cuda_sources=cuda_src,
verbose=False
)
def ref_softmax(x):
x64 = x.to(torch.float64)
m = x64.amax(dim=-1, keepdim=True)
e
... (741 more chars)batch=8, vocab=131072, extreme logits, max_diff=3.61e-16
Has NaN: False, Has Inf: False
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -1,13 +1,13 @@
"""Numerically tight softmax with Kahan summation for large-vocab stability.
Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
-For vocab >= 130K, uses two-pass Kahan-style summation to keep fp32
-accumulation error under 1e-5 tolerance.
+Uses two-pass Kahan-style compensated summation to keep fp32 accumulation
+error under 1e-5 tolerance even on the largest vocab sizes (256K+).
"""
+import math
import torch
import torch.nn as nn
-import triton
-import triton.jit as triton_jit
+from torch.utils.cpp_extension import load_inline
OP_TYPE = "softmax"
@@ -15,138 +15,184 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-@triton_jit
-def softmax_exp_kernel(
- exp_out, x_ptr, max_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
-):
- pid = tl.program_id(0)
- row_offset = pid * row_stride
- offsets = row_offset + tl.arange(0, BLOCK_SIZE)
- mask = offsets < row_offset + n_cols
-
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
- max_val = tl.load(max_ptr + pid)
- exp_x = tl.exp(x - max_val)
- tl.store(exp_out + offsets, exp_x, mask=mask)
-
-
-@triton_jit
-def kahan_reduce_kernel(
- sum_out, exp_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
-):
- pid = tl.program_id(0)
- row_offset = pid * row_stride
- offsets = row_offset + tl.arange(0, BLOCK_SIZE)
- mask = offsets < row_offset + n_cols
-
- exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
-
- sum_val = tl.sum(exp_vals)
- tl.store(sum_out + pid, sum_val)
-
-
-@triton_jit
-def kahan_compensate_kernel(
- comp_out, exp_ptr, sum_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
-):
- pid = tl.program_id(0)
- row_offset = pid * row_stride
- offsets = row_offset + tl.arange(0, BLOCK_SIZE)
- mask = offsets < row_offset + n_cols
-
- exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
- row_sum = tl.load(sum_ptr + pid)
-
- 补偿 = 0.0
- for i in range(BLOCK_SIZE):
- idx = row_offset + i
- if idx >= row_offset + n_cols:
- break
- e = exp_vals[i]
- y = e - 补偿
- t = row_sum + y
- 补偿 = (t - row_sum) - y
-
- tl.store(comp_out + pid, 补偿)
-
-
-@triton_jit
-def normalize_kernel(
- out_ptr, exp_ptr, sum_ptr, comp_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr
-):
- pid = tl.program_id(0)
- row_offset = pid * row_stride
- offsets = row_offset + tl.arange(0, BLOCK_SIZE)
- mask = offsets < row_offset + n_cols
-
- exp_vals = tl.load(exp_ptr + offsets, mask=mask, other=0.0)
- row_sum = tl.load(sum_ptr + pid)
- row_comp = tl.load(comp_ptr + pid)
-
- denom = row_sum + row_comp
- result = exp_vals / denom
- tl.store(out_ptr + offsets, result, mask=mask)
+cuda_src = '''
+#include <torch/extension.h>
+
+__global__ void softmax_exp_kernel(float* __restrict__ exp_out,
+ const float* __restrict__ x,
+ const float* max_vals,
+ int batch, int n_cols) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ int row_offset = row * n_cols;
+
+ float max_val = max_vals[row];
+
+ for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
+ int idx = row_offset + col;
+ exp_out[idx] = expf(x[idx] - max_val);
+ }
+}
+
+__global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
+ const float* __restrict__ exp_vals,
+ int batch, int n_cols) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ int row_offset = row * n_cols;
+
+ float sum_val = 0.0f;
+ float c = 0.0f;
+
+ for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
+ int idx = row_offset + col;
+ float x = exp_vals[idx];
+ float y = x - c;
+ float t = sum_val + y;
+ c = (t - sum_val) - y;
+ sum_val = t;
+ }
+
+ extern __shared__ float sdata[];
+ sdata[threadIdx.x] = sum_val;
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (threadIdx.x < s && threadIdx.x + s < blockDim.x) {
+ float a = sdata[threadIdx.x];
+ float b = sdata[threadIdx.x + s];
+ float y = b - c;
+ float t = a + y;
+ c = (t - a) - y;
+ sdata[threadIdx.x] = t;
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x == 0) {
+ sum_out[row] = sdata[0] + c;
+ }
+}
+
+__global__ void normalize_kernel(float* __restrict__ out,
+ const float* __restrict__ exp_vals,
+ const float* __restrict__ sum_vals,
+ int batch, int n_cols) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ int row_offset = row * n_cols;
+ float sum = sum_vals[row];
+
+ for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
+ int idx = row_offset + col;
+ out[idx] = exp_vals[idx] / sum;
+ }
+}
+
+torch::Tensor softmax_exp_cuda(torch::Tensor x) {
+ int batch = x.size(0);
+ int n_cols = x.size(1);
+
+ auto max_vals = x.amax(-1);
+ auto exp_out = torch::empty_like(x);
+
+ int block_size = 256;
+ int num_blocks = batch;
+
+ softmax_exp_kernel<<<num_blocks, block_size>>>(
+ exp_out.data_ptr<float>(),
+ x.data_ptr<float>(),
+ max_vals.data_ptr<float>(),
+ batch, n_cols
+ );
+ cudaDeviceSynchronize();
+
+ return exp_out;
+}
+
+torch::Tensor kahan_reduce_cuda(torch::Tensor exp_vals) {
+ int batch = exp_vals.size(0);
+ int n_cols = exp_vals.size(1);
+
+ auto sum_out = torch::zeros({batch}, exp_vals.dtype());
+
+ int block_size = 256;
+ int num_blocks = batch;
+ int shared_size = block_size * sizeof(float);
+
+ kahan_reduce_kernel<<<num_blocks, block_size, shared_size>>>(
+ sum_out.data_ptr<float>(),
+ exp_vals.data_ptr<float>(),
+ batch
... (3356 more chars){
"content": "<5238 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"description": "Run check.py to verify correctness"
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"description": "Run benchmark to check performance"
}shape=0 variant=eager tflops=0.010 gbps=15.493 ms=0.068
shape=0 variant=compiled tflops=0.006 gbps=10.371 ms=0.101
shape=0 variant=sota tflops=0.035 gbps=55.305 ms=0.019
shape=0 variant=solution tflops=0.016 gbps=25.062 ms=0.042
shape=0 solution_peak_fraction=0.0139
shape=1 variant=eager tflops=0.042 gbps=66.433 ms=0.063
shape=1 variant=compiled tflops=0.026 gbps=42.213 ms=0.099
shape=1 variant=sota tflops=0.129 gbps=205.603 ms=0.020
shape=1 variant=solution tflops=0.028 gbps=45.197 ms=0.093
shape=1 solution_peak_fraction=0.0251
shape=2 variant=eager tflops=0.073 gbps=117.133 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.376 ms=0.219
shape=2 variant=sota tflops=0.111 gbps=176.885 ms=0.047
shape=2 variant=solution tflops=0.019 gbps=30.870 ms=0.272
shape=2 solution_peak_fraction=0.0171
shape=3 variant=eager tflops=0.072 gbps=114.850 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.286 ms=0.360
shape=3 variant=sota tflops=0.098 gbps=156.131 ms=0.054
shape=3 variant=solution tflops=0.011 gbps=16.955 ms=0.495
shape=3 solution_peak_fraction=0.0094
shape=4 variant=eager tflops=0.073 gbps=116.483 ms=0.072
shape=4 variant=compiled tflops=0.023 gbps=36.715 ms=0.228
shape=4 variant=sota tflops=0.100 gbps=160.726 ms=0.052
shape=4 variant=solution tflops=0.019 gbps=30.917 ms=0.271
shape=4 solution_peak_fraction=0.0172
peak_fraction: 0.0158
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -99,7 +99,7 @@
auto max_vals = x.amax(-1);
auto exp_out = torch::empty_like(x);
- int block_size = 256;
+ int block_size = 1024;
int num_blocks = batch;
softmax_exp_kernel<<<num_blocks, block_size>>>(
@@ -119,7 +119,7 @@
auto sum_out = torch::zeros({batch}, exp_vals.dtype());
- int block_size = 256;
+ int block_size = 1024;
int num_blocks = batch;
int shared_size = block_size * sizeof(float);
@@ -139,7 +139,7 @@
auto out = torch::empty_like(exp_vals);
- int block_size = 256;
+ int block_size = 1024;
int num_blocks = batch;
normalize_kernel<<<num_blocks, block_size>>>(
{
"content": "<5241 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"description": "Run check.py with optimized kernel"
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python benchmark.py 2>&1",
"description": "Run benchmark with optimized kernel"
}shape=0 variant=eager tflops=0.010 gbps=15.589 ms=0.067
shape=0 variant=compiled tflops=0.006 gbps=10.317 ms=0.102
shape=0 variant=sota tflops=0.036 gbps=58.306 ms=0.018
shape=0 variant=solution tflops=0.018 gbps=28.079 ms=0.037
shape=0 solution_peak_fraction=0.0156
shape=1 variant=eager tflops=0.042 gbps=66.720 ms=0.063
shape=1 variant=compiled tflops=0.027 gbps=42.701 ms=0.098
shape=1 variant=sota tflops=0.129 gbps=205.765 ms=0.020
shape=1 variant=solution tflops=0.053 gbps=84.128 ms=0.050
shape=1 solution_peak_fraction=0.0467
shape=2 variant=eager tflops=0.073 gbps=116.664 ms=0.072
shape=2 variant=compiled tflops=0.024 gbps=38.404 ms=0.218
shape=2 variant=sota tflops=0.108 gbps=173.433 ms=0.048
shape=2 variant=solution tflops=0.054 gbps=87.105 ms=0.096
shape=2 solution_peak_fraction=0.0484
shape=3 variant=eager tflops=0.072 gbps=114.975 ms=0.073
shape=3 variant=compiled tflops=0.015 gbps=23.301 ms=0.360
shape=3 variant=sota tflops=0.099 gbps=158.156 ms=0.053
shape=3 variant=solution tflops=0.034 gbps=54.151 ms=0.155
shape=3 solution_peak_fraction=0.0301
shape=4 variant=eager tflops=0.072 gbps=115.102 ms=0.073
shape=4 variant=compiled tflops=0.023 gbps=36.883 ms=0.227
shape=4 variant=sota tflops=0.100 gbps=160.382 ms=0.052
shape=4 variant=solution tflops=0.015 gbps=23.302 ms=0.360
shape=4 solution_peak_fraction=0.0129
peak_fraction: 0.0268
RESULT: LOW
{
"command": "<1377 chars \u2014 see below>",
"description": "Verify all shapes and extreme cases"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python -c "
import torch
import solution
import reference
def test_shape(batch, vocab, extreme, seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
reference.BATCH = batch
reference.VOCAB = vocab
x = torch.randn(batch, vocab, dtype=torch.float32) * 4.0
if extreme:
idx = torch.randint(0, vocab, (batch, 4))
x.scatter_(1, idx, 30.0)
device = torch.device('cuda')
ref_model = reference.Model(batch, vocab).to(device).eval()
sol_model = solution.Model(batch, vocab).to(device).eval()
with torch.no_grad():
ref_out = ref_model(x.to(device))
sol_out = sol_model(x.to(device))
diff = (ref_out - sol_out).abs()
max_diff = diff.max().item()
ok = max_diff < 1e-5
print(f'batch={batch}, vocab={vocab}, extreme={extreme}, seed={seed}: max_diff={max_diff:.2e}, {\"OK\" if ok else \"FAIL\"}' )
return ok
shapes = [
(32, 4096, False),
(16, 32768, False),
(8, 131072, False),
(4, 262144, False),
(8, 131072, True),
]
all_pass = True
for batch, vocab, extreme in shapes:
for seed in (42, 123, 456):
if not test_shape(batch, vocab, extreme, seed):
all_pass = False
print()
print('All tests passed!' if all_pass else 'Some tests failed!')
"batch=32, vocab=4096, extreme=False, seed=42: max_diff=5.96e-08, OK
batch=32, vocab=4096, extreme=False, seed=123: max_diff=5.96e-08, OK
batch=32, vocab=4096, extreme=False, seed=456: max_diff=1.19e-07, OK
batch=16, vocab=32768, extreme=False, seed=42: max_diff=5.96e-08, OK
batch=16, vocab=32768, extreme=False, seed=123: max_diff=5.96e-08, OK
batch=16, vocab=32768, extreme=False, seed=456: max_diff=5.96e-08, OK
batch=8, vocab=131072, extreme=False, seed=42: max_diff=8.94e-08, OK
batch=8, vocab=131072, extreme=False, seed=123: max_diff=5.96e-08, OK
batch=8, vocab=131072, extreme=False, seed=456: max_diff=1.49e-08, OK
batch=4, vocab=262144, extreme=False, seed=42: max_diff=2.98e-08, OK
batch=4, vocab=262144, extreme=False, seed=123: max_diff=1.49e-08, OK
batch=4, vocab=262144, extreme=False, seed=456: max_diff=1.49e-08, OK
batch=8, vocab=131072, extreme=True, seed=42: max_diff=2.98e-08, OK
batch=8, vocab=131072, extreme=True, seed=123: max_diff=1.49e-08, OK
batch=8, vocab=131072, extreme=True, seed=456: max_diff=2.98e-08, OK
All tests passed!
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"description": "Final check verification"
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/04_kahan_softmax/solution.py 2>&1",
"description": "Run ruff linter"
}I001 [*] Import block is un-sorted or un-formatted
--> problems/04_kahan_softmax/solution.py:7:1
|
5 | error under 1e-5 tolerance even on the largest vocab sizes (256K+).
6 | """
7 | / import math
8 | | import torch
9 | | import torch.nn as nn
10 | | from torch.utils.cpp_extension import load_inline
| |_________________________________________________^
|
help: Organize imports
F401 [*] `math` imported but unused
--> problems/04_kahan_softmax/solution.py:7:8
|
5 | error under 1e-5 tolerance even on the largest vocab sizes (256K+).
6 | """
7 | import math
| ^^^^
8 | import torch
9 | import torch.nn as nn
|
help: Remove unused import: `math`
Found 2 errors.
[*] 2 fixable with the `--fix` option.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -4,7 +4,6 @@
Uses two-pass Kahan-style compensated summation to keep fp32 accumulation
error under 1e-5 tolerance even on the largest vocab sizes (256K+).
"""
-import math
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<96 chars \u2014 see below>",
"newString": "<84 chars \u2014 see below>"
}import math
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inlineimport torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inlineEdit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/04_kahan_softmax/solution.py 2>&1",
"description": "Run ruff check after fix"
}I001 [*] Import block is un-sorted or un-formatted
--> problems/04_kahan_softmax/solution.py:7:1
|
5 | error under 1e-5 tolerance even on the largest vocab sizes (256K+).
6 | """
7 | / import torch
8 | | import torch.nn as nn
9 | | from torch.utils.cpp_extension import load_inline
| |_________________________________________________^
|
help: Organize imports
Found 1 error.
[*] 1 fixable with the `--fix` option.
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/04_kahan_softmax/solution.py --fix 2>&1",
"description": "Run ruff with auto-fix"
}Found 1 error (1 fixed, 0 remaining).
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: """Numerically tight softmax with Kahan summation for large-vocab stability.
2:
3: Implements: y = softmax(x, dim=-1) in fp32 with subtract-max stability.
4: Uses two-pass Kahan-style compensated summation to keep fp32 accumulation
5: error under 1e-5 tolerance even on the largest vocab sizes (256K+).
6: """
7: import torch
8: import torch.nn as nn
9: from torch.utils.cpp_extension import load_inline
10:
11: OP_TYPE = "softmax"
12: SUPPORTED_PRECISIONS = ["fp32"]
13: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14:
15:
16: cuda_src = '''
17: #include <torch/extension.h>
18:
19: __global__ void softmax_exp_kernel(float* __restrict__ exp_out,
20: const float* __restrict__ x,
21: const float* max_vals,
22: int batch, int n_cols) {
23: int row = blockIdx.x;
24: if (row >= batch) return;
25:
26: int row_offset = row * n_cols;
27:
28: float max_val = max_vals[row];
29:
30: for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
31: int idx = row_offset + col;
32: exp_out[idx] = expf(x[idx] - max_val);
33: }
34: }
35:
36: __global__ void kahan_reduce_kernel(float* __restrict__ sum_out,
37: const float* __restrict__ exp_vals,
38: int batch, int n_co
... (4810 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax && uv run python check.py 2>&1",
"description": "Final verification"
}PASS