import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _softmax_single_kernel(
    X_ptr, Y_ptr,
    batch, vocab,
    stride_xb, stride_xv,
    stride_yb, stride_yv,
    BLOCK: tl.constexpr,
):
    """Row softmax with one program per row and Kahan-compensated summation.

    Program ``i`` handles row ``i`` of X, walking it in tiles of ``BLOCK``
    elements:

    1. reduce the row maximum;
    2. accumulate ``sum(exp(x - max))`` tile by tile, combining the per-tile
       sums with Kahan (compensated) summation;
    3. recompute ``exp(x - max)`` from X and store it scaled by ``1 / sum``.

    Pass 3 recomputes the exponentials instead of staging them in Y, so Y is
    written exactly once and never read back. The recomputed values are the
    same expression as in pass 2, so the result is unchanged.

    Args:
        X_ptr, Y_ptr: input / output base pointers.
        batch, vocab: number of rows and row length.
        stride_xb, stride_xv: element strides of X along rows / columns.
        stride_yb, stride_yv: element strides of Y along rows / columns.
        BLOCK: tile width (compile-time constant).
    """
    row_idx = tl.program_id(0)
    if row_idx >= batch:
        return
    # 64-bit row offset: row_idx * stride can exceed int32 for huge tensors.
    row64 = row_idx.to(tl.int64)
    x_row = X_ptr + row64 * stride_xb
    y_row = Y_ptr + row64 * stride_yb

    # Pass 1: row maximum (masked lanes contribute -inf).
    row_max = -float('inf')
    for block_start in range(0, vocab, BLOCK):
        offs = block_start + tl.arange(0, BLOCK)
        mask = offs < vocab
        x = tl.load(x_row + offs * stride_xv, mask=mask, other=-float('inf'))
        row_max = tl.maximum(row_max, tl.max(x, axis=0))

    # Pass 2: Kahan-compensated sum of exp(x - row_max).
    total = 0.0
    comp = 0.0  # compensation: low-order bits lost when adding into `total`
    for block_start in range(0, vocab, BLOCK):
        offs = block_start + tl.arange(0, BLOCK)
        mask = offs < vocab
        x = tl.load(x_row + offs * stride_xv, mask=mask, other=-float('inf'))
        block_sum = tl.sum(tl.exp(x - row_max), axis=0)
        y = block_sum - comp
        t = total + y
        comp = (t - total) - y
        total = t

    # Pass 3: recompute the exponentials and write normalized probabilities.
    inv_total = 1.0 / total
    for block_start in range(0, vocab, BLOCK):
        offs = block_start + tl.arange(0, BLOCK)
        mask = offs < vocab
        x = tl.load(x_row + offs * stride_xv, mask=mask, other=-float('inf'))
        e = tl.exp(x - row_max)
        tl.store(y_row + offs * stride_yv, e * inv_total, mask=mask)
@triton.jit
def _find_max_sum_kernel(
    X_ptr, buf_max_ptr, buf_sum_ptr,
    batch, vocab, num_chunks,
    stride_xb,
    BLOCK: tl.constexpr,
):
    """Per-chunk partial reduction for the split-row softmax path.

    The launch grid has ``batch * num_chunks`` programs. Program ``pid`` owns
    the ``BLOCK``-wide chunk ``pid % num_chunks`` of row ``pid // num_chunks``.
    It writes two values at flat index ``row * num_chunks + chunk``:

    * ``buf_max``: the chunk maximum ``m_c``;
    * ``buf_sum``: ``sum(exp(x - m_c))`` over the chunk.

    A chunk whose values are all -inf stores ``m_c = -inf`` and a sum of 0.
    Without this guard, ``-inf - (-inf)`` would produce NaN, and that NaN
    would poison the row denominator.

    The vocab axis is addressed with unit stride (no column stride is
    taken), so X must be contiguous in its last dimension.
    """
    pid = tl.program_id(0)
    row_idx = pid // num_chunks
    chunk_idx = pid % num_chunks
    if row_idx >= batch:
        return
    # 64-bit row offset: row_idx * stride can exceed int32 for huge tensors.
    x_base = X_ptr + row_idx.to(tl.int64) * stride_xb
    offs = chunk_idx * BLOCK + tl.arange(0, BLOCK)
    mask = offs < vocab
    x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
    chunk_max = tl.max(x, axis=0)
    # Shift by 0 for an all -inf chunk so exp(-inf) = 0 instead of NaN.
    shift = tl.where(chunk_max == -float('inf'), 0.0, chunk_max)
    chunk_sum = tl.sum(tl.exp(x - shift), axis=0)
    out_idx = row_idx * num_chunks + chunk_idx
    tl.store(buf_max_ptr + out_idx, chunk_max)
    tl.store(buf_sum_ptr + out_idx, chunk_sum)
@triton.jit
def _finalize_kernel(
    X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
    batch, vocab, num_chunks,
    stride_xb, stride_yb,
    BLOCK: tl.constexpr,
):
    """Combine per-chunk partials and write one normalized chunk of Y.

    Uses a ``batch * num_chunks`` grid. Program ``pid`` owns chunk
    ``pid % num_chunks`` of row ``pid // num_chunks``. It reads all
    ``num_chunks`` partials of its row from ``buf_max`` / ``buf_sum`` to form:

    * the row maximum ``M = max_c m_c``;
    * the denominator ``sum_c s_c * exp(m_c - M)``, accumulated with Kahan
      (compensated) summation.

    It then re-reads its chunk of X and stores ``exp(x - M) / denominator``.

    A chunk whose stored maximum is -inf contributes exactly 0 to the
    denominator, even if its stored sum is NaN. X and Y are addressed with
    unit stride along the vocab axis.
    """
    pid = tl.program_id(0)
    row_idx = pid // num_chunks
    chunk_idx = pid % num_chunks
    if row_idx >= batch:
        return
    part_base = row_idx * num_chunks

    # Row maximum over the per-chunk maxima.
    global_max = -float('inf')
    for i in range(num_chunks):
        global_max = tl.maximum(global_max, tl.load(buf_max_ptr + part_base + i))

    # Kahan-compensated sum of the chunk sums rescaled to the row maximum.
    total = 0.0
    comp = 0.0  # compensation: low-order bits lost when adding into `total`
    for i in range(num_chunks):
        cm = tl.load(buf_max_ptr + part_base + i)
        cs = tl.load(buf_sum_ptr + part_base + i)
        # An all -inf chunk has no mass; avoid NaN * exp(-inf) = NaN.
        adjusted = tl.where(cm == -float('inf'), 0.0, cs * tl.exp(cm - global_max))
        y = adjusted - comp
        t = total + y
        comp = (t - total) - y
        total = t
    inv_total = 1.0 / total

    # 64-bit row offset: row_idx * stride can exceed int32 for huge tensors.
    row64 = row_idx.to(tl.int64)
    x_base = X_ptr + row64 * stride_xb
    y_base = Y_ptr + row64 * stride_yb
    offs = chunk_idx * BLOCK + tl.arange(0, BLOCK)
    mask = offs < vocab
    x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
    tl.store(y_base + offs, tl.exp(x - global_max) * inv_total, mask=mask)
class Model(nn.Module):
    """Softmax over the last dimension, computed by Triton kernels.

    Input of any rank (at least 1-D) is flattened to ``(rows, vocab)``, where
    ``vocab`` is the last dimension, and made contiguous. It is then
    dispatched as follows:

    * ``rows >= SINGLE_KERNEL_MIN_ROWS``: ``_softmax_single_kernel``, one
      program per row;
    * otherwise: ``_find_max_sum_kernel`` followed by ``_finalize_kernel``.
      These launch one program per ``BLOCK``-wide chunk, i.e.
      ``rows * cdiv(vocab, BLOCK)`` programs.

    The constructor's ``batch`` / ``vocab`` are stored as attributes, but the
    launch uses the actual shape of ``x``. The model has no parameters.
    """

    # Tile width used by every kernel launch.
    BLOCK = 4096
    # Row-count threshold for the one-program-per-row kernel.
    SINGLE_KERNEL_MIN_ROWS = 32

    def __init__(self, batch: int, vocab: int):
        super().__init__()
        self.batch = batch
        self.vocab = vocab

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(x, dim=-1)`` with the same shape and dtype as ``x``."""
        if x.numel() == 0:
            # Nothing to normalize; also sidesteps the ambiguous reshape(-1, 0).
            return torch.empty_like(x)
        out_shape = x.shape
        # The chunked kernels index the vocab axis with unit stride.
        x2 = x.reshape(-1, out_shape[-1]).contiguous()
        batch, vocab = x2.shape
        y2 = torch.empty_like(x2)
        block = self.BLOCK

        if batch >= self.SINGLE_KERNEL_MIN_ROWS:
            _softmax_single_kernel[(batch,)](
                x2, y2,
                batch, vocab,
                x2.stride(0), x2.stride(1),
                y2.stride(0), y2.stride(1),
                BLOCK=block,
            )
        else:
            num_chunks = triton.cdiv(vocab, block)
            total_programs = batch * num_chunks
            buf_max = torch.empty(batch, num_chunks, dtype=torch.float32, device=x2.device)
            buf_sum = torch.empty_like(buf_max)
            _find_max_sum_kernel[(total_programs,)](
                x2, buf_max, buf_sum,
                batch, vocab, num_chunks,
                x2.stride(0),
                BLOCK=block,
            )
            _finalize_kernel[(total_programs,)](
                x2, y2, buf_max, buf_sum,
                batch, vocab, num_chunks,
                x2.stride(0), y2.stride(0),
                BLOCK=block,
            )
        return y2.view(out_shape)
BATCH = 8
VOCAB = 32768


def get_inputs():
    """Build the forward() arguments: a single fp32 ``(BATCH, VOCAB)`` tensor.

    Entries are standard-normal samples scaled by 4.0.
    """
    logits = 4.0 * torch.randn(BATCH, VOCAB, dtype=torch.float32)
    return [logits]


def get_init_inputs():
    """Return the positional Model constructor arguments ``[batch, vocab]``."""
    return [BATCH, VOCAB]
shape=0 variant=eager tflops=0.009 gbps=13.965 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.049 ms=0.130
shape=0 variant=sota tflops=0.007 gbps=10.634 ms=0.099
shape=0 variant=solution tflops=0.034 gbps=53.718 ms=0.020
shape=0 solution_peak_fraction=0.0298
shape=1 variant=eager tflops=0.037 gbps=58.619 ms=0.072
shape=1 variant=compiled tflops=0.018 gbps=28.979 ms=0.145
shape=1 variant=sota tflops=0.016 gbps=25.960 ms=0.162
shape=1 variant=solution tflops=0.065 gbps=103.206 ms=0.041
shape=1 solution_peak_fraction=0.0573
shape=2 variant=eager tflops=0.062 gbps=98.439 ms=0.085
shape=2 variant=compiled tflops=0.017 gbps=27.147 ms=0.309
shape=2 variant=sota tflops=0.029 gbps=47.080 ms=0.178
shape=2 variant=solution tflops=0.079 gbps=126.793 ms=0.066
shape=2 solution_peak_fraction=0.0704
shape=3 variant=eager tflops=0.064 gbps=102.580 ms=0.082
shape=3 variant=compiled tflops=0.012 gbps=18.623 ms=0.450
shape=3 variant=sota tflops=0.046 gbps=73.812 ms=0.114
shape=3 variant=solution tflops=0.076 gbps=121.899 ms=0.069
shape=3 solution_peak_fraction=0.0677
shape=4 variant=eager tflops=0.066 gbps=105.491 ms=0.080
shape=4 variant=compiled tflops=0.017 gbps=26.960 ms=0.311
shape=4 variant=sota tflops=0.031 gbps=50.291 ms=0.167
shape=4 variant=solution tflops=0.076 gbps=122.383 ms=0.069
shape=4 solution_peak_fraction=0.0680
peak_fraction: 0.0561
RESULT: LOW
PASS
{
"problem": "04_kahan_softmax",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0561,
"elapsed_seconds": 1908,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 233050, "output_tokens": 15837, "cache_read_tokens": 2817408, "cache_creation_tokens": 0, "reasoning_tokens": 62983, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement numerically tight softmax kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Verify correctness with check.py (all shapes PASS)",
"status": "pending",
"priority": "high"
},
{
"content": "Benchmark and optimize performance",
"status": "pending",
"priority": "medium"
}
... (3 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints'
status: 'in_progress'
priority: 'high'
content: 'Implement numerically tight softmax kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py (all shapes PASS)'
status: 'pending'
priority: 'high'
content: 'Benchmark and optimize performance'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Kahan-corrected softmax.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if
4: available), and the agent's solution. Reports achieved TFLOPS, GB/s, and
5: peak_fraction. Softmax is memory-bound, so the score is GB/s / peak_dram.
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def _make_inputs(batch: int, vocab: int, extreme: bool) -> torch.Tensor:
31: if extreme:
32: x = torch.randn(batch, vocab) * 2.0
33: idx = torch.randint(0, vocab, (batch, 4))
34: x.scatter_(1, idx, 30.0)
35: else:
36: x = torch.randn(batch, vocab) * 4.0
37: return x.to(torch.float32)
38:
39:
40: def main():
41:
... (3706 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Kahan-corrected softmax.
2:
3: The point of this problem is numerical accuracy on long reductions. Shapes
4: mix typical LLM vocab sizes with deliberately adversarial regimes:
5:
6: - small vocab (sanity check; naive fp32 should pass)
7: - Llama3 vocab 128K (real-world, where fp16 accumulation starts to drift)
8: - 256K (DeepSeek-V3 / Gemma-3 class vocab; naive fp16 sum DOES drift past
9: the 1e-5 tolerance — this row is what proves Kahan was needed)
10: - extreme-logit edge case (large positive logits stress max-subtract +
11: summation; if the implementation accidentally exps before subtracting
12: max, this row overflows)
13:
14: The 'extreme' flag is read by check.py to switch input generation to a
15: distribution that produces a few very large logits per row.
16: """
17:
18: SHAPES = [
19: {"batch": 32, "vocab": 4096, "extreme": False}, # sanity
20: {"batch": 16, "vocab": 32768, "extreme": False}, # GPT-2 class
21: {"batch": 8, "vocab": 131072, "extreme": False}, # Llama3 vocab
22: {"batch": 4, "vocab": 262144, "extreme": False}, # 256K — Kahan needed
23: {"batch": 8, "vocab": 131072, "extreme": True}, # extreme logits edge
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/reference.py</path>
<type>file</type>
<content>
1: """Naive softmax over the last dim, computed in fp64 for ground-truth.
2:
3: The reference deliberately runs in float64 so that fp16 / fp32 accumulation
4: drift in agent solutions is exposed by the tight tolerance in problem.yaml.
5: The agent's job is to produce an fp32 softmax whose values match this
6: double-precision reference within atol=rtol=1e-5 — this requires either
7: fp32 accumulation or compensated (Kahan) summation when vocab is large.
8: """
9: import torch
10: import torch.nn as nn
11:
12: OP_TYPE = "softmax"
13: SUPPORTED_PRECISIONS = ["fp32"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15:
16:
17: class Model(nn.Module):
18: """y = softmax(x, dim=-1) computed in fp64 then returned as fp32.
19:
20: No learned parameters — softmax is parameter-free. We still expose an
21: empty state_dict so the harness's strict load_state_dict matches.
22: """
23:
24: def __init__(self, batch: int, vocab: int):
25: super().__init__()
26: self.batch = batch
27: self.vocab = vocab
28:
29: def forward(self, x: torch.Tensor) -> torch.Tensor:
30: # Promote to fp64 for the ground-truth pathway. Even with double
31: # precision we still subtract the row-max for stability.
32: x64 = x.to(torch.float64)
33: m = x64.amax(dim=-1, keepdim=True)
34: e = torch.exp
... (592 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Kahan-corrected softmax.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3
4: seeds each, with the tight (1e-5) fp32 tolerance from problem.yaml. Also
5: rejects forbidden ops via grep.
6: """
7: import re
8: import sys
9: from pathlib import Path
10:
11: import torch
12: import yaml
13:
14: # Make the repo's src/ importable
15: REPO_ROOT = Path(__file__).resolve().parents[2]
16: sys.path.insert(0, str(REPO_ROOT))
17:
18: from src.eval.correctness import check_correctness # noqa: E402
19:
20:
21: def _make_inputs(batch: int, vocab: int, extreme: bool, seed: int) -> torch.Tensor:
22: g = torch.Generator().manual_seed(seed)
23: if extreme:
24: # Adversarial: most logits are mild but a handful per row are huge.
25: # If the kernel forgets to subtract the row-max before exp, this
26: # row overflows fp32 and produces NaN/Inf. If it accumulates in
27: # fp16, the long tail of small exp() values is lost beneath the
28: # tolerance threshold.
29: x = torch.randn(batch, vocab, generator=g) * 2.0
30: # Spike: 4 very large positive logits per row.
31: idx = torch.randint(0, vocab, (batch, 4), generator=g)
32: x.scatter_(1, idx, 30.0)
33: else:
34: x = torch.randn(batch, vocab, generator=g) * 4.0
35: return x.to(torc
... (3464 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/problem.yaml</path>
<type>file</type>
<content>
1: name: 04_kahan_softmax
2: display_name: "Kahan-corrected Softmax"
3: precision: fp32
4: regime: memory # softmax is bandwidth-bound: 2 passes over the input tensor
5:
6: # Softmax FLOPs: per-element exp + 2 reductions + divide. Roughly 5 flops/elt.
7: flops_formula: "5 * batch * vocab"
8:
9: # Bytes moved: read x once, write y once. Both fp32.
10: bytes_formula: "batch * vocab * 4 + batch * vocab * 4"
11:
12: hardware: [RTX_PRO_6000]
13: peak_tflops_key: fp32
14: peak_bandwidth_key: dram
15:
16: # TIGHTER than default (fp32 default is 1e-4). This problem exists
17: # specifically to test whether the agent uses compensated summation, so
18: # we squeeze the tolerance to 1e-5 — naive fp16 sum across 256K elements
19: # drifts past this; fp32 accumulation passes; Kahan/fp32 always passes.
20: tolerance:
21: "torch.float32": {"atol": 1.0e-5, "rtol": 1.0e-5}
22:
23: # Forbidden ops — block the obvious "just call the library" cheats. The
24: # agent must implement softmax themselves with explicit (compensated)
25: # summation logic.
26: forbidden:
27: - "torch.nn.functional.softmax"
28: - "torch.softmax"
29: - "F.softmax"
30: - "liger_kernel.softmax"
31: - "liger_kernel.transformers.softmax"
32: - ".softmax("
33:
34: sota:
35: name: "Liger-Kernel Softmax (Triton)"
36: url: "https://github.com/linkedin/Liger-Kernel"
37: function: "liger_
... (225 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for last-dim softmax.
2:
3: Preference order:
4: 1. liger-kernel's Triton softmax (LigerSoftmaxFunction) — fused, fp32-acc
5: 2. torch.compile(torch.softmax) — Inductor generates a Triton kernel that
6: is competitive on bandwidth-bound reductions
7:
8: Agents are FORBIDDEN from importing either of these in solution.py (see
9: problem.yaml.forbidden). This file is only the benchmark's reference line.
10: """
11: from __future__ import annotations
12:
13: import torch
14:
15: _compiled_softmax = None
16:
17:
18: def _liger_softmax(x: torch.Tensor) -> torch.Tensor | None:
19: try:
20: from liger_kernel.ops.softmax import LigerSoftmaxFunction
21: return LigerSoftmaxFunction.apply(x)
22: except Exception:
23: return None
24:
25:
26: def _compiled(x: torch.Tensor) -> torch.Tensor:
27: global _compiled_softmax
28: if _compiled_softmax is None:
29: _compiled_softmax = torch.compile(
30: lambda t: torch.softmax(t, dim=-1),
31: mode="reduce-overhead",
32: )
33: return _compiled_softmax(x)
34:
35:
36: def sota_forward(x: torch.Tensor) -> torch.Tensor:
37: """Best-available softmax reference. x: (batch, vocab) fp32."""
38: out = _liger_softmax(x)
39: if out is not None:
40: return out
41: return _compiled(x)
42:
43:
44: def is_a
... (128 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints",
"priority": "high",
"status": "completed"
},
{
"content": "Implement numerically tight softmax kernel in solution.py",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py (all shapes PASS)",
"priority": "high",
"status": "pending"
},
{
"content": "Benchmark and optimize performance",
"priority": "medium",
"status": "pending"
}
... (5 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints'
status: 'completed'
priority: 'high'
content: 'Implement numerically tight softmax kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py (all shapes PASS)'
status: 'pending'
priority: 'high'
content: 'Benchmark and optimize performance'
status: 'pending'
priority: 'medium'{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null || echo \"nvidia-smi failed\"",
"description": "Check GPU availability"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__); import torch; print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0))\"",
"description": "Check Triton and PyTorch versions"
}3.6.0
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
(12, 0)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _softmax_kernel(
+ X_ptr, Y_ptr,
+ batch, vocab,
+ stride_xb, stride_xv,
+ stride_yb, stride_yv,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= batch:
+ return
+
+ x_row = X_ptr + row_idx * stride_xb
+ y_row = Y_ptr + row_idx * stride_yb
+
+ row_max = -float('inf')
+ for block_start in range(0, vocab, BLOCK_SIZE):
+ offs = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offs < vocab
+ x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+ row_max = tl.maximum(row_max, tl.max(x, axis=0))
+
+ total = 0.0
+ comp = 0.0
+ for block_start in range(0, vocab, BLOCK_SIZE):
+ offs = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offs < vocab
+ x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+ e = tl.exp(x - row_max)
+ tl.store(y_row + offs, e, mask=mask)
+ block_sum = tl.sum(e, axis=0)
+ y = block_sum - comp
+ t = total + y
+ comp = (t - total) - y
+ total = t
+
+ inv_total = 1.0 / total
+ for block_start in range(0, vocab, BLOCK_SIZE):
+ offs = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offs < vocab
+ e = tl.load(y_row + offs, mask=mask, other=0.0)
+ tl.store(y_row + offs, e * inv_total, mask=mask)
+
+
+class Model(nn.Module):
+ def __init__(self, batch: int, vocab: int):
+ super().__init__()
+ self.batch = batch
+ self.vocab = vocab
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = torch.empty_like(x)
+ BLOCK_SIZE = 2048
+ grid = (self.batch,)
+ _softmax_kernel[grid](
+ x, y,
+ self.batch, self.vocab,
+ x.stride(0), x.stride(1),
+ y.stride(0), y.stride(1),
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ return y
+
+
+BATCH = 8
+VOCAB = 32768
+
+
+def get_inputs():
+ x = torch.randn(BATCH, VOCAB, dtype=torch.float32) * 4.0
+ return [x]
+
+
+def get_init_inputs():
+ return [BATCH, VOCAB]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"content": "<2090 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax",
"timeout": 120000
}PASS
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints",
"priority": "high",
"status": "completed"
},
{
"content": "Implement numerically tight softmax kernel in solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Verify correctness with check.py (all shapes PASS)",
"priority": "high",
"status": "completed"
},
{
"content": "Benchmark and optimize performance",
"priority": "medium",
"status": "in_progress"
... (9 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints'
status: 'completed'
priority: 'high'
content: 'Implement numerically tight softmax kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py (all shapes PASS)'
status: 'completed'
priority: 'high'
content: 'Benchmark and optimize performance'
status: 'in_progress'
priority: 'medium'{
"command": "uv run python benchmark.py",
"description": "Run performance benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py:7836: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.
warnings.warn(
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py:7836: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.
warnings.warn(
shape=0 variant=eager tflops=0.009 gbps=14.461 ms=0.073
shape=0 variant=compiled tflops=0.007 gbps=11.102 ms=0.094
shape=0 variant=sota tflops=0.030 gbps=48.188 ms=0.022
shape=0 variant=solution tflops=0.040 gbps=64.000 ms=0.016
shape=0 solution_peak_fraction=0.0356
shape=1 variant=eager tflops=0.040 gbps=64.346 ms=0.065
shape=1 variant=compiled tflops=0.029 gbps=46.332 ms=0.091
shape=1 variant=sota tflops=0.042 gbps=67.424 ms=0.062
shape=1 variant=solution tflops=0.081 gbps=129.135 ms=0.032
shape=1 solution_peak_fraction=0.0717
shape=2 variant=eager tflops=0.069 gbps=111.078 ms=0.076
shape=2 variant=compiled tflops=0.023 gbps=36.536 ms=0.230
shape=2 variant=sota tflops=0.135 gbps=216.112 ms=0.039
shape=2 variant=solution tflops=0.044 gbps=71.186 ms=0.118
shape=2 solution_peak_fraction=0.0395
shape=3 variant=eager tflops=0.069 g
... (536 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -5,47 +5,91 @@
@triton.jit
-def _softmax_kernel(
- X_ptr, Y_ptr,
- batch, vocab,
- stride_xb, stride_xv,
- stride_yb, stride_yv,
- BLOCK_SIZE: tl.constexpr,
+def _find_chunk_max_kernel(
+ X_ptr, buf_max_ptr,
+ batch, vocab, num_chunks,
+ stride_xb,
+ BLOCK: tl.constexpr,
):
- row_idx = tl.program_id(0)
+ pid = tl.program_id(0)
+ row_idx = pid // num_chunks
+ chunk_idx = pid % num_chunks
if row_idx >= batch:
return
- x_row = X_ptr + row_idx * stride_xb
- y_row = Y_ptr + row_idx * stride_yb
+ chunk_start = chunk_idx * BLOCK
+ x_base = X_ptr + row_idx * stride_xb
+
+ offs = chunk_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
+ local_max = tl.max(x, axis=0)
+
+ tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, local_max)
+
+
+@triton.jit
+def _exp_sum_kernel(
+ X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
+ batch, vocab, num_chunks,
+ stride_xb, stride_yb,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_chunks
+ chunk_idx = pid % num_chunks
+ if row_idx >= batch:
+ return
row_max = -float('inf')
- for block_start in range(0, vocab, BLOCK_SIZE):
- offs = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offs < vocab
- x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
- row_max = tl.maximum(row_max, tl.max(x, axis=0))
+ for i in range(num_chunks):
+ m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ row_max = tl.maximum(row_max, m)
+
+ chunk_start = chunk_idx * BLOCK
+ x_base = X_ptr + row_idx * stride_xb
+ y_base = Y_ptr + row_idx * stride_yb
+
+ offs = chunk_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
+ e = tl.exp(x - row_max)
+ tl.store(y_base + offs, e, mask=mask)
+
+ chunk_sum = tl.sum(e, axis=0)
+ tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
+
+
+@triton.jit
+def _normalize_kernel(
+ Y_ptr, buf_sum_ptr,
+ batch, vocab, num_chunks,
+ stride_yb,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_chunks
+ chunk_idx = pid % num_chunks
+ if row_idx >= batch:
+ return
total = 0.0
comp = 0.0
- for block_start in range(0, vocab, BLOCK_SIZE):
- offs = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offs < vocab
- x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
- e = tl.exp(x - row_max)
- tl.store(y_row + offs, e, mask=mask)
- block_sum = tl.sum(e, axis=0)
- y = block_sum - comp
+ for i in range(num_chunks):
+ s = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
+ y = s - comp
t = total + y
comp = (t - total) - y
total = t
inv_total = 1.0 / total
- for block_start in range(0, vocab, BLOCK_SIZE):
- offs = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offs < vocab
- e = tl.load(y_row + offs, mask=mask, other=0.0)
- tl.store(y_row + offs, e * inv_total, mask=mask)
+ chunk_start = chunk_idx * BLOCK
+ y_base = Y_ptr + row_idx * stride_yb
+
+ offs = chunk_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ e = tl.load(y_base + offs, mask=mask, other=0.0)
+ tl.store(y_base + offs, e * inv_total, mask=mask)
class Model(nn.Module):
@@ -56,14 +100,30 @@
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
- BLOCK_SIZE = 2048
- grid = (self.batch,)
- _softmax_kernel[grid](
- x, y,
- self.batch, self.vocab,
- x.stride(0), x.stride(1),
- y.stride(0), y.stride(1),
- BLOCK_SIZE=BLOCK_SIZE,
+ BLOCK = 2048
+ num_chunks = triton.cdiv(self.vocab, BLOCK)
+ total_programs = self.batch * num_chunks
+
+ buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+ buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+
+ _find_chunk_max_kernel[(total_programs,)](
+ x, buf_max,
+ self.batch, self.vocab, num_chunks,
+ x.stride(0),
+ BLOCK=BLOCK,
+ )
+ _exp_sum_kernel[(total_programs,)](
+ x, y, buf_max, buf_sum,
+ self.batch, self.vocab, num_chunks,
+ x.stride(0), y.stride(0),
+ BLOCK=BLOCK,
+ )
+ _normalize_kernel[(total_programs,)](
+ y, buf_sum,
+ self.batch, self.vocab, num_chunks,
+ y.stride(0),
+ BLOCK=BLOCK,
)
return y
{
"content": "<3622 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py",
"description": "Verify correctness of 3-kernel approach",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
{
"command": "uv run python benchmark.py",
"description": "Benchmark 3-kernel approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}shape=0 variant=eager tflops=0.009 gbps=14.079 ms=0.074
shape=0 variant=compiled tflops=0.005 gbps=8.680 ms=0.121
shape=0 variant=sota tflops=0.008 gbps=13.331 ms=0.079
shape=0 variant=solution tflops=0.022 gbps=35.832 ms=0.029
shape=0 solution_peak_fraction=0.0199
shape=1 variant=eager tflops=0.037 gbps=59.878 ms=0.070
shape=1 variant=compiled tflops=0.022 gbps=35.622 ms=0.118
shape=1 variant=sota tflops=0.029 gbps=47.182 ms=0.089
shape=1 variant=solution tflops=0.093 gbps=148.020 ms=0.028
shape=1 solution_peak_fraction=0.0822
shape=2 variant=eager tflops=0.065 gbps=104.774 ms=0.080
shape=2 variant=compiled tflops=0.020 gbps=32.512 ms=0.258
shape=2 variant=sota tflops=0.055 gbps=87.207 ms=0.096
shape=2 variant=solution tflops=0.175 gbps=280.518 ms=0.030
shape=2 solution_peak_fraction=0.1558
shape=3 variant=eager tflops=0.065 gbps=104.690 ms=0.080
shape=3 variant=compiled tflops=0.013 gbps=20.063 ms=0.418
shape=3 variant=sota tflops=0.049 gbps=78.888 ms=0.106
shape=3 variant=solution tflops=0.166 gbps=265.328 ms=0.032
shape=3 solution_peak_fraction=0.1474
shape=4 variant=eager tflops=0.067 gbps=106.541 ms=0.079
shape=4 variant=compiled tflops=0.020 gbps=32.552 ms=0.258
shape=4 variant=sota tflops=0.053 gbps=85.445 ms=0.098
shape=4 variant=solution tflops=0.171 gbps=274.209 ms=0.031
shape=4 solution_peak_fraction=0.1523
peak_fraction: 0.0895
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -5,31 +5,7 @@
@triton.jit
-def _find_chunk_max_kernel(
- X_ptr, buf_max_ptr,
- batch, vocab, num_chunks,
- stride_xb,
- BLOCK: tl.constexpr,
-):
- pid = tl.program_id(0)
- row_idx = pid // num_chunks
- chunk_idx = pid % num_chunks
- if row_idx >= batch:
- return
-
- chunk_start = chunk_idx * BLOCK
- x_base = X_ptr + row_idx * stride_xb
-
- offs = chunk_start + tl.arange(0, BLOCK)
- mask = offs < vocab
- x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
- local_max = tl.max(x, axis=0)
-
- tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, local_max)
-
-
-@triton.jit
-def _exp_sum_kernel(
+def _compute_exp_kernel(
X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_xb, stride_yb,
@@ -41,11 +17,6 @@
if row_idx >= batch:
return
- row_max = -float('inf')
- for i in range(num_chunks):
- m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
- row_max = tl.maximum(row_max, m)
-
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
y_base = Y_ptr + row_idx * stride_yb
@@ -53,16 +24,18 @@
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
- e = tl.exp(x - row_max)
+ chunk_max = tl.max(x, axis=0)
+ e = tl.exp(x - chunk_max)
tl.store(y_base + offs, e, mask=mask)
chunk_sum = tl.sum(e, axis=0)
+ tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
-def _normalize_kernel(
- Y_ptr, buf_sum_ptr,
+def _rescale_normalize_kernel(
+ Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_yb,
BLOCK: tl.constexpr,
@@ -73,23 +46,30 @@
if row_idx >= batch:
return
+ global_max = -float('inf')
+ for i in range(num_chunks):
+ m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ global_max = tl.maximum(global_max, m)
+
total = 0.0
comp = 0.0
for i in range(num_chunks):
- s = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
- y = s - comp
+ cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
+ adjusted = cs * tl.exp(cm - global_max)
+ y = adjusted - comp
t = total + y
comp = (t - total) - y
total = t
- inv_total = 1.0 / total
+ chunk_max = tl.load(buf_max_ptr + row_idx * num_chunks + chunk_idx)
+ rescale = tl.exp(chunk_max - global_max) / total
chunk_start = chunk_idx * BLOCK
y_base = Y_ptr + row_idx * stride_yb
-
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
e = tl.load(y_base + offs, mask=mask, other=0.0)
- tl.store(y_base + offs, e * inv_total, mask=mask)
+ tl.store(y_base + offs, e * rescale, mask=mask)
class Model(nn.Module):
@@ -107,20 +87,14 @@
buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
- _find_chunk_max_kernel[(total_programs,)](
- x, buf_max,
- self.batch, self.vocab, num_chunks,
- x.stride(0),
- BLOCK=BLOCK,
- )
- _exp_sum_kernel[(total_programs,)](
+ _compute_exp_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
)
- _normalize_kernel[(total_programs,)](
- y, buf_sum,
+ _rescale_normalize_kernel[(total_programs,)](
+ y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
y.stride(0),
BLOCK=BLOCK,
{
"content": "<3224 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py",
"description": "Verify correctness of 2-kernel approach",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
{
"command": "uv run python benchmark.py",
"description": "Benchmark 2-kernel approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}shape=0 variant=eager tflops=0.009 gbps=14.048 ms=0.075
shape=0 variant=compiled tflops=0.006 gbps=8.832 ms=0.119
shape=0 variant=sota tflops=0.009 gbps=13.662 ms=0.077
shape=0 variant=solution tflops=0.027 gbps=43.144 ms=0.024
shape=0 solution_peak_fraction=0.0240
shape=1 variant=eager tflops=0.038 gbps=60.935 ms=0.069
shape=1 variant=compiled tflops=0.023 gbps=36.263 ms=0.116
shape=1 variant=sota tflops=0.029 gbps=46.920 ms=0.089
shape=1 variant=solution tflops=0.112 gbps=178.451 ms=0.024
shape=1 solution_peak_fraction=0.0991
shape=2 variant=eager tflops=0.065 gbps=104.294 ms=0.080
shape=2 variant=compiled tflops=0.020 gbps=32.713 ms=0.256
shape=2 variant=sota tflops=0.054 gbps=86.932 ms=0.096
shape=2 variant=solution tflops=0.191 gbps=305.886 ms=0.027
shape=2 solution_peak_fraction=0.1699
shape=3 variant=eager tflops=0.066 gbps=105.406 ms=0.080
shape=3 variant=compiled tflops=0.013 gbps=20.034 ms=0.419
shape=3 variant=sota tflops=0.049 gbps=78.769 ms=0.106
shape=3 variant=solution tflops=0.167 gbps=266.678 ms=0.031
shape=3 solution_peak_fraction=0.1482
shape=4 variant=eager tflops=0.066 gbps=106.110 ms=0.079
shape=4 variant=compiled tflops=0.021 gbps=32.803 ms=0.256
shape=4 variant=sota tflops=0.053 gbps=85.459 ms=0.098
shape=4 variant=solution tflops=0.195 gbps=312.448 ms=0.027
shape=4 solution_peak_fraction=0.1736
peak_fraction: 0.1008
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: import triton
4: import triton.language as tl
5:
6:
7: @triton.jit
8: def _compute_exp_kernel(
9: X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
10: batch, vocab, num_chunks,
11: stride_xb, stride_yb,
12: BLOCK: tl.constexpr,
13: ):
14: pid = tl.program_id(0)
15: row_idx = pid // num_chunks
16: chunk_idx = pid % num_chunks
17: if row_idx >= batch:
18: return
19:
20: chunk_start = chunk_idx * BLOCK
21: x_base = X_ptr + row_idx * stride_xb
22: y_base = Y_ptr + row_idx * stride_yb
23:
24: offs = chunk_start + tl.arange(0, BLOCK)
25: mask = offs < vocab
26: x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
27: chunk_max = tl.max(x, axis=0)
28: e = tl.exp(x - chunk_max)
29: tl.store(y_base + offs, e, mask=mask)
30:
31: chunk_sum = tl.sum(e, axis=0)
32: tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
33: tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
34:
35:
36: @triton.jit
37: def _rescale_normalize_kernel(
38: Y_ptr, buf_max_ptr, buf_sum_ptr,
39: batch, vocab, num_chunks,
40: stride_yb,
41: BLOCK: tl.constexpr,
42: ):
43: pid = tl.program_id(0)
44: row_idx = pid // num_chunks
45: chunk_idx = pid % num_chunks
46: if row_idx >= batch:
47: return
48:
... (2351 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -2,6 +2,50 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+@triton.jit
+def _softmax_single_kernel(
+ X_ptr, Y_ptr,
+ batch, vocab,
+ stride_xb, stride_xv,
+ stride_yb, stride_yv,
+ BLOCK: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ if row_idx >= batch:
+ return
+
+ x_row = X_ptr + row_idx * stride_xb
+ y_row = Y_ptr + row_idx * stride_yb
+
+ row_max = -float('inf')
+ for block_start in range(0, vocab, BLOCK):
+ offs = block_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+ row_max = tl.maximum(row_max, tl.max(x, axis=0))
+
+ total = 0.0
+ comp = 0.0
+ for block_start in range(0, vocab, BLOCK):
+ offs = block_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+ e = tl.exp(x - row_max)
+ tl.store(y_row + offs, e, mask=mask)
+ block_sum = tl.sum(e, axis=0)
+ y = block_sum - comp
+ t = total + y
+ comp = (t - total) - y
+ total = t
+
+ inv_total = 1.0 / total
+ for block_start in range(0, vocab, BLOCK):
+ offs = block_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ e = tl.load(y_row + offs, mask=mask, other=0.0)
+ tl.store(y_row + offs, e * inv_total, mask=mask)
@triton.jit
@@ -82,23 +126,32 @@
y = torch.empty_like(x)
BLOCK = 2048
num_chunks = triton.cdiv(self.vocab, BLOCK)
- total_programs = self.batch * num_chunks
- buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
- buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+ if self.batch >= 16:
+ _softmax_single_kernel[(self.batch,)](
+ x, y,
+ self.batch, self.vocab,
+ x.stride(0), x.stride(1),
+ y.stride(0), y.stride(1),
+ BLOCK=BLOCK,
+ )
+ else:
+ total_programs = self.batch * num_chunks
+ buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+ buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
- _compute_exp_kernel[(total_programs,)](
- x, y, buf_max, buf_sum,
- self.batch, self.vocab, num_chunks,
- x.stride(0), y.stride(0),
- BLOCK=BLOCK,
- )
- _rescale_normalize_kernel[(total_programs,)](
- y, buf_max, buf_sum,
- self.batch, self.vocab, num_chunks,
- y.stride(0),
- BLOCK=BLOCK,
- )
+ _compute_exp_kernel[(total_programs,)](
+ x, y, buf_max, buf_sum,
+ self.batch, self.vocab, num_chunks,
+ x.stride(0), y.stride(0),
+ BLOCK=BLOCK,
+ )
+ _rescale_normalize_kernel[(total_programs,)](
+ y, buf_max, buf_sum,
+ self.batch, self.vocab, num_chunks,
+ y.stride(0),
+ BLOCK=BLOCK,
+ )
return y
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<3050 chars \u2014 see below>",
"newString": "<4682 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _compute_exp_kernel(
X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_xb, stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
chunk_max = tl.max(x, axis=0)
e = tl.exp(x - chunk_max)
tl.store(y_base + offs, e, mask=mask)
chunk_sum = tl.sum(e, axis=0)
tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
def _rescale_normalize_kernel(
Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
global_max = -float('inf')
for i in range(num_chunks):
m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
global_max = tl.maximum(global_max, m)
total = 0.0
comp = 0.0
for i in range(num_chunks):
cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
adjusted = cs * tl.exp(cm - global_max)
y = adjusted - comp
t = total + y
comp = (t - total) - y
total = t
chunk_max = tl.load(buf_max_ptr + row_idx * num_chunks + chunk_idx)
rescale = tl.exp(chunk_max - global_max) / total
chunk_start = chunk_idx * BLOCK
y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
e = tl.load(y_base + offs, mask=mask, other=0.0)
tl.store(y_base + offs, e * rescale, mask=mask)
class Model(nn.Module):
def __init__(self, batch: int, vocab: int):
super().__init__()
self.batch = batch
self.vocab = vocab
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
BLOCK = 2048
num_chunks = triton.cdiv(self.vocab, BLOCK)
total_programs = self.batch * num_chunks
buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
_compute_exp_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
)
_rescale_normalize_kernel[(total_programs,)](
y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
y.stride(0),
BLOCK=BLOCK,
)
return yimport torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _softmax_single_kernel(
X_ptr, Y_ptr,
batch, vocab,
stride_xb, stride_xv,
stride_yb, stride_yv,
BLOCK: tl.constexpr,
):
row_idx = tl.program_id(0)
if row_idx >= batch:
return
x_row = X_ptr + row_idx * stride_xb
y_row = Y_ptr + row_idx * stride_yb
row_max = -float('inf')
for block_start in range(0, vocab, BLOCK):
offs = block_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
row_max = tl.maximum(row_max, tl.max(x, axis=0))
total = 0.0
comp = 0.0
for block_start in range(0, vocab, BLOCK):
offs = block_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
e = tl.exp(x - row_max)
tl.store(y_row + offs, e, mask=mask)
block_sum = tl.sum(e, axis=0)
y = block_sum - comp
t = total + y
comp = (t - total) - y
total = t
inv_total = 1.0 / total
for block_start in range(0, vocab, BLOCK):
offs = block_start + tl.arange(0, BLOCK)
mask = offs < vocab
e = tl.load(y_row + offs, mask=mask, other=0.0)
tl.store(y_row + offs, e * inv_total, mask=mask)
@triton.jit
def _compute_exp_kernel(
X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_xb, stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
chunk_max = tl.max(x, axis=0)
e = tl.exp(x - chunk_max)
tl.store(y_base + offs, e, mask=mask)
chunk_sum = tl.sum(e, axis=0)
tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
def _rescale_normalize_kernel(
Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
global_max = -float('inf')
for i in range(num_chunks):
m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
global_max = tl.maximum(global_max, m)
total = 0.0
comp = 0.0
for i in range(num_chunks):
cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
adjusted = cs * tl.exp(cm - global_max)
y = adjusted - comp
t = total + y
comp = (t - total) - y
total = t
chunk_max = tl.load(buf_max_ptr + row_idx * num_chunks + chunk_idx)
rescale = tl.exp(chunk_max - global_max) / total
chunk_start = chunk_idx * BLOCK
y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
e = tl.load(y_base + offs, mask=mask, other=0.0)
tl.store(y_base + offs, e * rescale, mask=mask)
class Model(nn.Module):
def __init__(self, batch: int, vocab: int):
super().__init__()
self.batch = batch
self.vocab = vocab
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
BLOCK = 2048
num_chunks = triton.cdiv(self.vocab, BLOCK)
if self.batch >= 16:
_softmax_single_kernel[(self.batch,)](
x, y,
self.batch, self.vocab,
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK=BLOCK,
)
else:
total_programs =
... (682 more chars)Edit applied successfully.{
"command": "uv run python check.py",
"description": "Verify hybrid approach correctness",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
{
"command": "uv run python benchmark.py",
"description": "Benchmark hybrid approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}shape=0 variant=eager tflops=0.009 gbps=13.751 ms=0.076
shape=0 variant=compiled tflops=0.005 gbps=8.684 ms=0.121
shape=0 variant=sota tflops=0.008 gbps=13.076 ms=0.080
shape=0 variant=solution tflops=0.044 gbps=70.017 ms=0.015
shape=0 solution_peak_fraction=0.0389
shape=1 variant=eager tflops=0.037 gbps=58.750 ms=0.071
shape=1 variant=compiled tflops=0.021 gbps=33.851 ms=0.124
shape=1 variant=sota tflops=0.029 gbps=47.174 ms=0.089
shape=1 variant=solution tflops=0.063 gbps=101.136 ms=0.041
shape=1 solution_peak_fraction=0.0562
shape=2 variant=eager tflops=0.065 gbps=104.774 ms=0.080
shape=2 variant=compiled tflops=0.020 gbps=32.774 ms=0.256
shape=2 variant=sota tflops=0.053 gbps=84.508 ms=0.099
shape=2 variant=solution tflops=0.192 gbps=307.320 ms=0.027
shape=2 solution_peak_fraction=0.1707
shape=3 variant=eager tflops=0.066 gbps=105.236 ms=0.080
shape=3 variant=compiled tflops=0.013 gbps=20.072 ms=0.418
shape=3 variant=sota tflops=0.050 gbps=79.558 ms=0.105
shape=3 variant=solution tflops=0.172 gbps=274.496 ms=0.031
shape=3 solution_peak_fraction=0.1525
shape=4 variant=eager tflops=0.066 gbps=105.279 ms=0.080
shape=4 variant=compiled tflops=0.020 gbps=32.650 ms=0.257
shape=4 variant=sota tflops=0.053 gbps=85.570 ms=0.098
shape=4 variant=solution tflops=0.199 gbps=317.750 ms=0.026
shape=4 solution_peak_fraction=0.1765
peak_fraction: 0.1001
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -127,7 +127,7 @@
BLOCK = 2048
num_chunks = triton.cdiv(self.vocab, BLOCK)
- if self.batch >= 16:
+ if self.batch >= 32:
_softmax_single_kernel[(self.batch,)](
x, y,
self.batch, self.vocab,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": " if self.batch >= 16:",
"newString": " if self.batch >= 32:"
}Edit applied successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Verify and benchmark with threshold=32",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=14.048 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.703 ms=0.120
shape=0 variant=sota tflops=0.009 gbps=13.841 ms=0.076
shape=0 variant=solution tflops=0.043 gbps=68.985 ms=0.015
shape=0 solution_peak_fraction=0.0383
shape=1 variant=eager tflops=0.038 gbps=61.263 ms=0.068
shape=1 variant=compiled tflops=0.022 gbps=35.632 ms=0.118
shape=1 variant=sota tflops=0.030 gbps=47.516 ms=0.088
shape=1 variant=solution tflops=0.112 gbps=179.551 ms=0.023
shape=1 solution_peak_fraction=0.0998
shape=2 variant=eager tflops=0.065 gbps=103.922 ms=0.081
shape=2 variant=compiled tflops=0.020 gbps=32.648 ms=0.257
shape=2 variant=sota tflops=0.055 gbps=87.236 ms=0.096
shape=2 variant=solution tflops=0.191 gbps=305.529 ms=0.027
shape=2 solution_peak_fraction=0.1697
shape=3 variant=eager tflops=0.066 gbps=105.215 ms=0.080
shape=3 variant=compiled tflops=0.013 gbps=20.090 ms=0.418
shape=3 variant=sota tflops=0.049 gbps=78.923 ms=0.106
shape=3 variant=solution tflops=0.172 gbps=274.496 ms=0.031
shape=3 solution_peak_fraction=0.1525
shape=4 variant=eager tflops=0.066 gbps=105.661 ms=0.079
shape=4 variant=compiled tflops=0.020 gbps=32.644 ms=0.257
shape=4 variant=sota tflops=0.054 gbps=86.516 ms=0.097
shape=4 variant=solution tflops=0.198 gbps=316.599 ms=0.026
shape=4 solution_peak_fraction=0.1759
peak_fraction: 0.1117
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -49,10 +49,10 @@
@triton.jit
-def _compute_exp_kernel(
- X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
+def _find_max_sum_kernel(
+ X_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
- stride_xb, stride_yb,
+ stride_xb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
@@ -63,25 +63,23 @@
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
- y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
chunk_max = tl.max(x, axis=0)
e = tl.exp(x - chunk_max)
- tl.store(y_base + offs, e, mask=mask)
+ chunk_sum = tl.sum(e, axis=0)
- chunk_sum = tl.sum(e, axis=0)
tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
-def _rescale_normalize_kernel(
- Y_ptr, buf_max_ptr, buf_sum_ptr,
+def _finalize_kernel(
+ X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
- stride_yb,
+ stride_xb, stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
@@ -106,14 +104,16 @@
comp = (t - total) - y
total = t
- chunk_max = tl.load(buf_max_ptr + row_idx * num_chunks + chunk_idx)
- rescale = tl.exp(chunk_max - global_max) / total
+ inv_total = 1.0 / total
chunk_start = chunk_idx * BLOCK
+ x_base = X_ptr + row_idx * stride_xb
y_base = Y_ptr + row_idx * stride_yb
+
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
- e = tl.load(y_base + offs, mask=mask, other=0.0)
- tl.store(y_base + offs, e * rescale, mask=mask)
+ x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
+ e = tl.exp(x - global_max)
+ tl.store(y_base + offs, e * inv_total, mask=mask)
class Model(nn.Module):
@@ -140,16 +140,16 @@
buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
- _compute_exp_kernel[(total_programs,)](
+ _find_max_sum_kernel[(total_programs,)](
+ x, buf_max, buf_sum,
+ self.batch, self.vocab, num_chunks,
+ x.stride(0),
+ BLOCK=BLOCK,
+ )
+ _finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
- BLOCK=BLOCK,
- )
- _rescale_normalize_kernel[(total_programs,)](
- y, buf_max, buf_sum,
- self.batch, self.vocab, num_chunks,
- y.stride(0),
BLOCK=BLOCK,
)
return y
{
"content": "<4745 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Verify and benchmark optimized 2-kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=14.054 ms=0.075
shape=0 variant=compiled tflops=0.006 gbps=8.831 ms=0.119
shape=0 variant=sota tflops=0.009 gbps=13.659 ms=0.077
shape=0 variant=solution tflops=0.044 gbps=69.645 ms=0.015
shape=0 solution_peak_fraction=0.0387
shape=1 variant=eager tflops=0.037 gbps=59.728 ms=0.070
shape=1 variant=compiled tflops=0.023 gbps=36.273 ms=0.116
shape=1 variant=sota tflops=0.029 gbps=46.463 ms=0.090
shape=1 variant=solution tflops=0.117 gbps=186.979 ms=0.022
shape=1 solution_peak_fraction=0.1039
shape=2 variant=eager tflops=0.066 gbps=105.300 ms=0.080
shape=2 variant=compiled tflops=0.021 gbps=32.811 ms=0.256
shape=2 variant=sota tflops=0.054 gbps=86.530 ms=0.097
shape=2 variant=solution tflops=0.196 gbps=313.195 ms=0.027
shape=2 solution_peak_fraction=0.1740
shape=3 variant=eager tflops=0.065 gbps=104.461 ms=0.080
shape=3 variant=compiled tflops=0.012 gbps=19.942 ms=0.421
shape=3 variant=sota tflops=0.048 gbps=76.684 ms=0.109
shape=3 variant=solution tflops=0.167 gbps=267.904 ms=0.031
shape=3 solution_peak_fraction=0.1488
shape=4 variant=eager tflops=0.065 gbps=104.149 ms=0.081
shape=4 variant=compiled tflops=0.020 gbps=32.770 ms=0.256
shape=4 variant=sota tflops=0.054 gbps=86.846 ms=0.097
shape=4 variant=solution tflops=0.200 gbps=320.274 ms=0.026
shape=4 solution_peak_fraction=0.1779
peak_fraction: 0.1131
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -134,6 +134,7 @@
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK=BLOCK,
+ num_warps=8,
)
else:
total_programs = self.batch * num_chunks
@@ -145,12 +146,14 @@
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
+ num_warps=8,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
+ num_warps=8,
)
return y
{
"content": "<4832 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Test with num_warps=8",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=13.991 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.691 ms=0.121
shape=0 variant=sota tflops=0.008 gbps=13.248 ms=0.079
shape=0 variant=solution tflops=0.043 gbps=68.913 ms=0.015
shape=0 solution_peak_fraction=0.0383
shape=1 variant=eager tflops=0.039 gbps=62.164 ms=0.067
shape=1 variant=compiled tflops=0.022 gbps=35.768 ms=0.117
shape=1 variant=sota tflops=0.029 gbps=46.031 ms=0.091
shape=1 variant=solution tflops=0.105 gbps=167.397 ms=0.025
shape=1 solution_peak_fraction=0.0930
shape=2 variant=eager tflops=0.064 gbps=102.661 ms=0.082
shape=2 variant=compiled tflops=0.020 gbps=32.059 ms=0.262
shape=2 variant=sota tflops=0.049 gbps=78.983 ms=0.106
shape=2 variant=solution tflops=0.177 gbps=283.706 ms=0.030
shape=2 solution_peak_fraction=0.1576
shape=3 variant=eager tflops=0.064 gbps=102.963 ms=0.081
shape=3 variant=compiled tflops=0.013 gbps=20.014 ms=0.419
shape=3 variant=sota tflops=0.049 gbps=78.510 ms=0.107
shape=3 variant=solution tflops=0.157 gbps=250.616 ms=0.033
shape=3 solution_peak_fraction=0.1392
shape=4 variant=eager tflops=0.065 gbps=104.795 ms=0.080
shape=4 variant=compiled tflops=0.020 gbps=32.684 ms=0.257
shape=4 variant=sota tflops=0.053 gbps=84.494 ms=0.099
shape=4 variant=solution tflops=0.191 gbps=304.996 ms=0.028
shape=4 solution_peak_fraction=0.1694
peak_fraction: 0.1058
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -134,7 +134,6 @@
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK=BLOCK,
- num_warps=8,
)
else:
total_programs = self.batch * num_chunks
@@ -146,14 +145,12 @@
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
- num_warps=8,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
- num_warps=8,
)
return y
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": "<1055 chars \u2014 see below>",
"newString": "<968 chars \u2014 see below>"
} if self.batch >= 32:
_softmax_single_kernel[(self.batch,)](
x, y,
self.batch, self.vocab,
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK=BLOCK,
num_warps=8,
)
else:
total_programs = self.batch * num_chunks
buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
_find_max_sum_kernel[(total_programs,)](
x, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
num_warps=8,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
num_warps=8,
) if self.batch >= 32:
_softmax_single_kernel[(self.batch,)](
x, y,
self.batch, self.vocab,
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK=BLOCK,
)
else:
total_programs = self.batch * num_chunks
buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
_find_max_sum_kernel[(total_programs,)](
x, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
)Edit applied successfully.{
"command": "<3512 chars \u2014 see below>",
"description": "Compare BLOCK sizes for 2-kernel path",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}uv run python -c "
import torch, triton, triton.language as tl
@triton.jit
def _find_max_sum_kernel_4k(
X_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_xb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
chunk_max = tl.max(x, axis=0)
e = tl.exp(x - chunk_max)
chunk_sum = tl.sum(e, axis=0)
tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
def _finalize_kernel_4k(
X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
batch, vocab, num_chunks,
stride_xb, stride_yb,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
chunk_idx = pid % num_chunks
if row_idx >= batch:
return
global_max = -float('inf')
for i in range(num_chunks):
m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
global_max = tl.maximum(global_max, m)
total = 0.0
comp = 0.0
for i in range(num_chunks):
cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
adjusted = cs * tl.exp(cm - global_max)
y = adjusted - comp
t = total + y
comp = (t - total) - y
total = t
inv_total = 1.0 / total
chunk_start = chunk_idx * BLOCK
x_base = X_ptr + row_idx * stride_xb
y_base = Y_ptr + row_idx * stride_yb
offs = chunk_start + tl.arange(0, BLOCK)
mask = offs < vocab
x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
e = tl.exp(x - global_max)
tl.store(y_base + offs, e * inv_total, mask=mask)
def bench_shape(batch, vocab, BLOCK, n_warm=5, n_iter=50):
x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda')
y = torch.empty_like(x)
num_chunks = triton.cdiv(vocab, BLOCK)
total_programs = batch * num_chunks
buf_max = torch.empty(batch, num_chunks, dtype=torch.float32, device='cuda')
buf_sum = torch.empty(batch, num_chunks, dtype=torch.float32, device='cuda')
for _ in range(n_warm):
_find_max_sum_kernel_4k[(total_programs,)](x, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), BLOCK=BLOCK)
_finalize_kernel_4k[(total_programs,)](x, y, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), y.stride(0), BLOCK=BLOCK)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
_find_max_sum_kernel_4k[(total_programs,)](x, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), BLOCK=BLOCK)
_finalize_kernel_4k[(total_programs,)](x, y, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), y.stride(0), BLOCK=BLOCK)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / n_iter
data_gb = 2 * batch * vocab * 4 / 1e9
gbps = data_gb / (ms / 1000)
return ms, gbps
for batch, vocab in [(16,32768), (8,131072), (4,262144), (8,131072)]:
for BS in [1024, 2048, 4096]:
ms, gbps = bench_shape(batch, vocab, BS)
print(f'batch={batch} vocab={vocab} BLOCK={BS}: {ms:.3f} ms, {gbps:.0f} GB/s')
"
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 4, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/_bench_block.py
@@ -0,0 +1,100 @@
+import torch
+import triton
+import triton.language as tl
+import sys
+sys.path.insert(0, '.')
+
+
+@triton.jit
+def _find_max_sum_kernel(
+ X_ptr, buf_max_ptr, buf_sum_ptr,
+ batch, vocab, num_chunks,
+ stride_xb,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_chunks
+ chunk_idx = pid % num_chunks
+ if row_idx >= batch:
+ return
+ chunk_start = chunk_idx * BLOCK
+ x_base = X_ptr + row_idx * stride_xb
+ offs = chunk_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
+ chunk_max = tl.max(x, axis=0)
+ e = tl.exp(x - chunk_max)
+ chunk_sum = tl.sum(e, axis=0)
+ tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
+ tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
+
+
+@triton.jit
+def _finalize_kernel(
+ X_ptr, Y_ptr, buf_max_ptr, buf_sum_ptr,
+ batch, vocab, num_chunks,
+ stride_xb, stride_yb,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ row_idx = pid // num_chunks
+ chunk_idx = pid % num_chunks
+ if row_idx >= batch:
+ return
+ global_max = -float('inf')
+ for i in range(num_chunks):
+ m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ global_max = tl.maximum(global_max, m)
+ total = 0.0
+ comp = 0.0
+ for i in range(num_chunks):
+ cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
+ adjusted = cs * tl.exp(cm - global_max)
+ y = adjusted - comp
+ t = total + y
+ comp = (t - total) - y
+ total = t
+ inv_total = 1.0 / total
+ chunk_start = chunk_idx * BLOCK
+ x_base = X_ptr + row_idx * stride_xb
+ y_base = Y_ptr + row_idx * stride_yb
+ offs = chunk_start + tl.arange(0, BLOCK)
+ mask = offs < vocab
+ x = tl.load(x_base + offs, mask=mask, other=-float('inf'))
+ e = tl.exp(x - global_max)
+ tl.store(y_base + offs, e * inv_total, mask=mask)
+
+
+def bench_shape(batch, vocab, BLOCK, n_warm=10, n_iter=100):
+ x = torch.randn(batch, vocab, dtype=torch.float32, device='cuda')
+ y = torch.empty_like(x)
+ num_chunks = triton.cdiv(vocab, BLOCK)
+ total_programs = batch * num_chunks
+ buf_max = torch.empty(batch, num_chunks, dtype=torch.float32, device='cuda')
+ buf_sum = torch.empty(batch, num_chunks, dtype=torch.float32, device='cuda')
+
+ for _ in range(n_warm):
+ _find_max_sum_kernel[(total_programs,)](x, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), BLOCK=BLOCK)
+ _finalize_kernel[(total_programs,)](x, y, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), y.stride(0), BLOCK=BLOCK)
+ torch.cuda.synchronize()
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(n_iter):
+ _find_max_sum_kernel[(total_programs,)](x, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), BLOCK=BLOCK)
+ _finalize_kernel[(total_programs,)](x, y, buf_max, buf_sum, batch, vocab, num_chunks, x.stride(0), y.stride(0), BLOCK=BLOCK)
+ end.record()
+ torch.cuda.synchronize()
+ ms = start.elapsed_time(end) / n_iter
+ data_gb = 2 * batch * vocab * 4 / 1e9
+ gbps = data_gb / (ms / 1000)
+ return ms, gbps
+
+
+for batch, vocab in [(16, 32768), (8, 131072), (4, 262144)]:
+ for BS in [1024, 2048, 4096]:
+ ms, gbps = bench_shape(batch, vocab, BS)
+ print(f'batch={batch} vocab={vocab} BLOCK={BS}: {ms:.3f} ms, {gbps:.0f} GB/s')
+ print()
{
"content": "<3529 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/_bench_block.py"
}Wrote file successfully.{
"command": "uv run python _bench_block.py",
"description": "Compare BLOCK sizes",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}batch=16 vocab=32768 BLOCK=1024: 0.012 ms, 355 GB/s
batch=16 vocab=32768 BLOCK=2048: 0.011 ms, 387 GB/s
batch=16 vocab=32768 BLOCK=4096: 0.013 ms, 317 GB/s
batch=8 vocab=131072 BLOCK=1024: 0.015 ms, 576 GB/s
batch=8 vocab=131072 BLOCK=2048: 0.012 ms, 716 GB/s
batch=8 vocab=131072 BLOCK=4096: 0.012 ms, 708 GB/s
batch=4 vocab=262144 BLOCK=1024: 0.025 ms, 333 GB/s
batch=4 vocab=262144 BLOCK=2048: 0.012 ms, 701 GB/s
batch=4 vocab=262144 BLOCK=4096: 0.011 ms, 733 GB/s
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -124,7 +124,7 @@
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
- BLOCK = 2048
+ BLOCK = 4096
num_chunks = triton.cdiv(self.vocab, BLOCK)
if self.batch >= 32:
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py",
"oldString": " BLOCK = 2048",
"newString": " BLOCK = 4096"
}Edit applied successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Test BLOCK=4096",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=13.914 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.669 ms=0.121
shape=0 variant=sota tflops=0.009 gbps=13.662 ms=0.077
shape=0 variant=solution tflops=0.043 gbps=69.058 ms=0.015
shape=0 solution_peak_fraction=0.0384
shape=1 variant=eager tflops=0.038 gbps=60.625 ms=0.069
shape=1 variant=compiled tflops=0.023 gbps=36.253 ms=0.116
shape=1 variant=sota tflops=0.029 gbps=46.603 ms=0.090
shape=1 variant=solution tflops=0.115 gbps=184.349 ms=0.023
shape=1 solution_peak_fraction=0.1024
shape=2 variant=eager tflops=0.065 gbps=104.067 ms=0.081
shape=2 variant=compiled tflops=0.021 gbps=32.813 ms=0.256
shape=2 variant=sota tflops=0.053 gbps=85.209 ms=0.098
shape=2 variant=solution tflops=0.214 gbps=343.120 ms=0.024
shape=2 solution_peak_fraction=0.1906
shape=3 variant=eager tflops=0.065 gbps=104.046 ms=0.081
shape=3 variant=compiled tflops=0.012 gbps=19.996 ms=0.420
shape=3 variant=sota tflops=0.049 gbps=78.124 ms=0.107
shape=3 variant=solution tflops=0.192 gbps=307.320 ms=0.027
shape=3 solution_peak_fraction=0.1707
shape=4 variant=eager tflops=0.065 gbps=104.711 ms=0.080
shape=4 variant=compiled tflops=0.020 gbps=32.516 ms=0.258
shape=4 variant=sota tflops=0.054 gbps=86.331 ms=0.097
shape=4 variant=solution tflops=0.218 gbps=349.292 ms=0.024
shape=4 solution_peak_fraction=0.1941
peak_fraction: 0.1199
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: import triton
4: import triton.language as tl
5:
6:
7: @triton.jit
8: def _softmax_single_kernel(
9: X_ptr, Y_ptr,
10: batch, vocab,
11: stride_xb, stride_xv,
12: stride_yb, stride_yv,
13: BLOCK: tl.constexpr,
14: ):
15: row_idx = tl.program_id(0)
16: if row_idx >= batch:
17: return
18:
19: x_row = X_ptr + row_idx * stride_xb
20: y_row = Y_ptr + row_idx * stride_yb
21:
22: row_max = -float('inf')
23: for block_start in range(0, vocab, BLOCK):
24: offs = block_start + tl.arange(0, BLOCK)
25: mask = offs < vocab
26: x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
27: row_max = tl.maximum(row_max, tl.max(x, axis=0))
28:
29: total = 0.0
30: comp = 0.0
31: for block_start in range(0, vocab, BLOCK):
32: offs = block_start + tl.arange(0, BLOCK)
33: mask = offs < vocab
34: x = tl.load(x_row + offs, mask=mask, other=-float('inf'))
35: e = tl.exp(x - row_max)
36: tl.store(y_row + offs, e, mask=mask)
37: block_sum = tl.sum(e, axis=0)
38: y = block_sum - comp
39: t = total + y
40: comp = (t - total) - y
41: total = t
42:
43: inv_total = 1.0 / total
44: for block_start in range(0, vocab, BLOCK):
45: offs = block_start + tl.ara
... (4137 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -2,6 +2,8 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+_MAX_CHUNKS = 64
@triton.jit
@@ -54,6 +56,7 @@
batch, vocab, num_chunks,
stride_xb,
BLOCK: tl.constexpr,
+ MAX_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
@@ -71,8 +74,8 @@
e = tl.exp(x - chunk_max)
chunk_sum = tl.sum(e, axis=0)
- tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
- tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
+ tl.store(buf_max_ptr + row_idx * MAX_CHUNKS + chunk_idx, chunk_max)
+ tl.store(buf_sum_ptr + row_idx * MAX_CHUNKS + chunk_idx, chunk_sum)
@triton.jit
@@ -81,6 +84,7 @@
batch, vocab, num_chunks,
stride_xb, stride_yb,
BLOCK: tl.constexpr,
+ MAX_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
@@ -88,21 +92,14 @@
if row_idx >= batch:
return
- global_max = -float('inf')
- for i in range(num_chunks):
- m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
- global_max = tl.maximum(global_max, m)
+ chunk_offs = tl.arange(0, MAX_CHUNKS)
+ chunk_mask = chunk_offs < num_chunks
- total = 0.0
- comp = 0.0
- for i in range(num_chunks):
- cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
- cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
- adjusted = cs * tl.exp(cm - global_max)
- y = adjusted - comp
- t = total + y
- comp = (t - total) - y
- total = t
+ chunk_maxes = tl.load(buf_max_ptr + row_idx * MAX_CHUNKS + chunk_offs, mask=chunk_mask, other=-float('inf'))
+ global_max = tl.max(chunk_maxes, axis=0)
+
+ chunk_sums = tl.load(buf_sum_ptr + row_idx * MAX_CHUNKS + chunk_offs, mask=chunk_mask, other=0.0)
+ total = tl.sum(chunk_sums * tl.exp(chunk_maxes - global_max), axis=0)
inv_total = 1.0 / total
chunk_start = chunk_idx * BLOCK
@@ -137,20 +134,22 @@
)
else:
total_programs = self.batch * num_chunks
- buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
- buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+ buf_max = torch.empty(self.batch, _MAX_CHUNKS, dtype=torch.float32, device=x.device)
+ buf_sum = torch.empty(self.batch, _MAX_CHUNKS, dtype=torch.float32, device=x.device)
_find_max_sum_kernel[(total_programs,)](
x, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
+ MAX_CHUNKS=_MAX_CHUNKS,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
+ MAX_CHUNKS=_MAX_CHUNKS,
)
return y
{
"content": "<4821 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Test vectorized buffer reads",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=14.061 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.662 ms=0.121
shape=0 variant=sota tflops=0.009 gbps=13.679 ms=0.077
shape=0 variant=solution tflops=0.043 gbps=68.840 ms=0.015
shape=0 solution_peak_fraction=0.0382
shape=1 variant=eager tflops=0.038 gbps=61.435 ms=0.068
shape=1 variant=compiled tflops=0.023 gbps=36.258 ms=0.116
shape=1 variant=sota tflops=0.024 gbps=38.670 ms=0.108
shape=1 variant=solution tflops=0.078 gbps=124.534 ms=0.034
shape=1 solution_peak_fraction=0.0692
shape=2 variant=eager tflops=0.065 gbps=104.732 ms=0.080
shape=2 variant=compiled tflops=0.020 gbps=32.770 ms=0.256
shape=2 variant=sota tflops=0.054 gbps=86.374 ms=0.097
shape=2 variant=solution tflops=0.227 gbps=363.836 ms=0.023
shape=2 solution_peak_fraction=0.2021
shape=3 variant=eager tflops=0.066 gbps=105.597 ms=0.079
shape=3 variant=compiled tflops=0.013 gbps=20.088 ms=0.418
shape=3 variant=sota tflops=0.049 gbps=78.864 ms=0.106
shape=3 variant=solution tflops=0.224 gbps=357.632 ms=0.023
shape=3 solution_peak_fraction=0.1987
shape=4 variant=eager tflops=0.065 gbps=104.774 ms=0.080
shape=4 variant=compiled tflops=0.021 gbps=32.852 ms=0.255
shape=4 variant=sota tflops=0.054 gbps=86.019 ms=0.098
shape=4 variant=solution tflops=0.225 gbps=359.594 ms=0.023
shape=4 solution_peak_fraction=0.1998
peak_fraction: 0.1162
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py
@@ -2,8 +2,6 @@
import torch.nn as nn
import triton
import triton.language as tl
-
-_MAX_CHUNKS = 64
@triton.jit
@@ -56,7 +54,6 @@
batch, vocab, num_chunks,
stride_xb,
BLOCK: tl.constexpr,
- MAX_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
@@ -74,8 +71,8 @@
e = tl.exp(x - chunk_max)
chunk_sum = tl.sum(e, axis=0)
- tl.store(buf_max_ptr + row_idx * MAX_CHUNKS + chunk_idx, chunk_max)
- tl.store(buf_sum_ptr + row_idx * MAX_CHUNKS + chunk_idx, chunk_sum)
+ tl.store(buf_max_ptr + row_idx * num_chunks + chunk_idx, chunk_max)
+ tl.store(buf_sum_ptr + row_idx * num_chunks + chunk_idx, chunk_sum)
@triton.jit
@@ -84,7 +81,6 @@
batch, vocab, num_chunks,
stride_xb, stride_yb,
BLOCK: tl.constexpr,
- MAX_CHUNKS: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // num_chunks
@@ -92,14 +88,21 @@
if row_idx >= batch:
return
- chunk_offs = tl.arange(0, MAX_CHUNKS)
- chunk_mask = chunk_offs < num_chunks
+ global_max = -float('inf')
+ for i in range(num_chunks):
+ m = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ global_max = tl.maximum(global_max, m)
- chunk_maxes = tl.load(buf_max_ptr + row_idx * MAX_CHUNKS + chunk_offs, mask=chunk_mask, other=-float('inf'))
- global_max = tl.max(chunk_maxes, axis=0)
-
- chunk_sums = tl.load(buf_sum_ptr + row_idx * MAX_CHUNKS + chunk_offs, mask=chunk_mask, other=0.0)
- total = tl.sum(chunk_sums * tl.exp(chunk_maxes - global_max), axis=0)
+ total = 0.0
+ comp = 0.0
+ for i in range(num_chunks):
+ cm = tl.load(buf_max_ptr + row_idx * num_chunks + i)
+ cs = tl.load(buf_sum_ptr + row_idx * num_chunks + i)
+ adjusted = cs * tl.exp(cm - global_max)
+ y = adjusted - comp
+ t = total + y
+ comp = (t - total) - y
+ total = t
inv_total = 1.0 / total
chunk_start = chunk_idx * BLOCK
@@ -134,22 +137,20 @@
)
else:
total_programs = self.batch * num_chunks
- buf_max = torch.empty(self.batch, _MAX_CHUNKS, dtype=torch.float32, device=x.device)
- buf_sum = torch.empty(self.batch, _MAX_CHUNKS, dtype=torch.float32, device=x.device)
+ buf_max = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
+ buf_sum = torch.empty(self.batch, num_chunks, dtype=torch.float32, device=x.device)
_find_max_sum_kernel[(total_programs,)](
x, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0),
BLOCK=BLOCK,
- MAX_CHUNKS=_MAX_CHUNKS,
)
_finalize_kernel[(total_programs,)](
x, y, buf_max, buf_sum,
self.batch, self.vocab, num_chunks,
x.stride(0), y.stride(0),
BLOCK=BLOCK,
- MAX_CHUNKS=_MAX_CHUNKS,
)
return y
{
"content": "<4745 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py && uv run python benchmark.py",
"description": "Final verification and benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
shape=0 variant=eager tflops=0.009 gbps=14.057 ms=0.075
shape=0 variant=compiled tflops=0.005 gbps=8.683 ms=0.121
shape=0 variant=sota tflops=0.008 gbps=13.600 ms=0.077
shape=0 variant=solution tflops=0.043 gbps=68.985 ms=0.015
shape=0 solution_peak_fraction=0.0383
shape=1 variant=eager tflops=0.038 gbps=60.527 ms=0.069
shape=1 variant=compiled tflops=0.022 gbps=35.964 ms=0.117
shape=1 variant=sota tflops=0.030 gbps=47.533 ms=0.088
shape=1 variant=solution tflops=0.114 gbps=183.061 ms=0.023
shape=1 solution_peak_fraction=0.1017
shape=2 variant=eager tflops=0.066 gbps=105.554 ms=0.079
shape=2 variant=compiled tflops=0.020 gbps=32.206 ms=0.260
shape=2 variant=sota tflops=0.053 gbps=85.084 ms=0.099
shape=2 variant=solution tflops=0.186 gbps=297.047 ms=0.028
shape=2 solution_peak_fraction=0.1650
shape=3 variant=eager tflops=0.065 gbps=103.369 ms=0.081
shape=3 variant=compiled tflops=0.013 gbps=20.130 ms=0.417
shape=3 variant=sota tflops=0.049 gbps=78.805 ms=0.106
shape=3 variant=solution tflops=0.193 gbps=308.768 ms=0.027
shape=3 solution_peak_fraction=0.1715
shape=4 variant=eager tflops=0.066 gbps=105.194 ms=0.080
shape=4 variant=compiled tflops=0.020 gbps=32.797 ms=0.256
shape=4 variant=sota tflops=0.054 gbps=86.260 ms=0.097
shape=4 variant=solution tflops=0.218 gbps=349.525 ms=0.024
shape=4 solution_peak_fraction=0.1942
peak_fraction: 0.1165
RESULT: OK
{
"command": "uv run python check.py",
"description": "Final correctness check",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax"
}PASS
{
"command": "rm -f /home/infatoshi/cuda/KernelBench-Hard/problems/04_kahan_softmax/_bench_block.py",
"description": "Clean up scratch file"
}(no output){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints",
"priority": "high",
"status": "completed"
},
{
"content": "Implement numerically tight softmax kernel in solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Verify correctness with check.py (all shapes PASS)",
"priority": "high",
"status": "completed"
},
{
"content": "Benchmark and optimize performance",
"priority": "medium",
"status": "completed"
... (7 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand interfaces and constraints'
status: 'completed'
priority: 'high'
content: 'Implement numerically tight softmax kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py (all shapes PASS)'
status: 'completed'
priority: 'high'
content: 'Benchmark and optimize performance'
status: 'completed'
priority: 'medium'