model
custom:GLM-5.1-[Z.AI-Coding-Plan]-0
solution.py
benchmark.log
check.log
result.json
"""Multi-block Kahan-compensated softmax via Triton with CUDA graphs.
3-phase approach with CUDA graph capture for minimal launch overhead.
Phase 1: local max + exp-sum per chunk
Phase 2: Kahan-compensated reduction
Phase 3: normalize
The CUDA graph is captured once per (batch, vocab) and replayed.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Module-level metadata. NOTE(review): presumably read by the benchmark
# harness to select precision and hardware -- confirm against the harness.
OP_TYPE = "softmax"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Process-wide cache of captured CUDA graphs together with the buffers they
# were captured on. Keys are the tuples built in ``Model.forward``; entries
# are never evicted.
_graph_cache: dict = {}
@triton.jit
def _phase1(
    X_ptr, PartialMax_ptr, PartialSum_ptr,
    batch, vocab,
    stride_x, stride_pm, stride_ps,
    BLOCK: tl.constexpr,
):
    """Phase 1: per-chunk maximum and exp-sum.

    The 1D grid holds one program per (row, chunk) pair, laid out as
    ``pid = row * nblk + blk`` with ``nblk = cdiv(vocab, BLOCK)``. Each
    program reads up to ``BLOCK`` consecutive elements of its row and writes
    the chunk maximum and ``sum(exp(x - chunk_max))`` to the partial buffers.

    Args:
        X_ptr: input; row ``r`` starts at ``X_ptr + r * stride_x`` and
            consecutive columns are one element apart.
        PartialMax_ptr / PartialSum_ptr: outputs, indexed as
            ``row * stride + blk``.
        batch, vocab: number of rows and columns of the input.
        stride_x, stride_pm, stride_ps: row strides, in elements.
        BLOCK: compile-time chunk width.
    """
    pid = tl.program_id(0)
    nblk = tl.cdiv(vocab, BLOCK)
    row = pid // nblk
    blk = pid % nblk
    if row >= batch:
        return
    # Scale the row index in 64-bit: ``row * stride_x`` is otherwise a 32-bit
    # product that wraps once batch * vocab exceeds 2**31 elements.
    row64 = row.to(tl.int64)
    cols = blk * BLOCK + tl.arange(0, BLOCK)
    mask = cols < vocab
    # Out-of-range lanes read -inf, so they neither win the max nor add to
    # the exp-sum (exp(-inf - m) == 0).
    x = tl.load(X_ptr + row64 * stride_x + cols, mask=mask, other=-float("inf"))
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)
    tl.store(PartialMax_ptr + row64 * stride_pm + blk, m)
    tl.store(PartialSum_ptr + row64 * stride_ps + blk, s)
@triton.jit
def _phase2(
    PartialMax_ptr, PartialSum_ptr,
    GlobalMax_ptr, GlobalInvSum_ptr,
    batch, nblk,
    stride_pm, stride_ps,
    BLOCK_R: tl.constexpr,
):
    """Phase 2: fold one row's per-chunk partials into a max and 1/sum.

    One program per row. A first sweep over the ``nblk`` partial maxima finds
    the row maximum; a second sweep rescales every partial exp-sum by
    ``exp(partial_max - row_max)`` and accumulates the per-sweep totals with
    Kahan compensation. The row maximum and the reciprocal of the total are
    written to ``GlobalMax_ptr[row]`` and ``GlobalInvSum_ptr[row]``.
    """
    row = tl.program_id(0)
    if row >= batch:
        return

    lanes = tl.arange(0, BLOCK_R)
    max_row_ptr = PartialMax_ptr + row * stride_pm
    sum_row_ptr = PartialSum_ptr + row * stride_ps

    # Sweep 1: row-wide maximum of the chunk maxima.
    row_max = -float("inf")
    for start in range(0, nblk, BLOCK_R):
        idx = start + lanes
        chunk_max = tl.load(max_row_ptr + idx,
                            mask=idx < nblk, other=-float("inf"))
        row_max = tl.maximum(row_max, tl.max(chunk_max, axis=0))

    # Sweep 2: Kahan-compensated sum of the rescaled chunk sums.
    total = 0.0
    comp = 0.0
    for start in range(0, nblk, BLOCK_R):
        idx = start + lanes
        valid = idx < nblk
        chunk_max = tl.load(max_row_ptr + idx,
                            mask=valid, other=-float("inf"))
        chunk_sum = tl.load(sum_row_ptr + idx,
                            mask=valid, other=0.0)
        rescaled = chunk_sum * tl.exp(chunk_max - row_max)
        sweep_sum = tl.sum(rescaled, axis=0)
        y = sweep_sum - comp
        t = total + y
        # (t - total) is what was actually added; the difference from y is
        # the rounding error carried into the next iteration.
        comp = (t - total) - y
        total = t

    tl.store(GlobalMax_ptr + row, row_max)
    tl.store(GlobalInvSum_ptr + row, 1.0 / total)
@triton.jit
def _phase3(
    X_ptr, Y_ptr,
    GlobalMax_ptr, GlobalInvSum_ptr,
    batch, vocab,
    stride_x, stride_y,
    BLOCK: tl.constexpr,
):
    """Phase 3: normalise one chunk of one row.

    Uses the same ``pid = row * nblk + blk`` grid layout as phase 1. Each
    program re-reads its chunk of the input and writes
    ``exp(x - row_max) * inv_sum`` to the matching positions of the output.

    Args:
        X_ptr, Y_ptr: input and output; row ``r`` starts at
            ``ptr + r * stride`` and consecutive columns are one element apart.
        GlobalMax_ptr / GlobalInvSum_ptr: per-row maximum and reciprocal of
            the exp-sum, indexed by row.
        batch, vocab: number of rows and columns.
        stride_x, stride_y: row strides, in elements.
        BLOCK: compile-time chunk width.
    """
    pid = tl.program_id(0)
    nblk = tl.cdiv(vocab, BLOCK)
    row = pid // nblk
    blk = pid % nblk
    if row >= batch:
        return
    # Scale the row index in 64-bit: ``row * stride`` is otherwise a 32-bit
    # product that wraps once batch * vocab exceeds 2**31 elements.
    row64 = row.to(tl.int64)
    gm = tl.load(GlobalMax_ptr + row)
    gi = tl.load(GlobalInvSum_ptr + row)
    cols = blk * BLOCK + tl.arange(0, BLOCK)
    mask = cols < vocab
    x = tl.load(X_ptr + row64 * stride_x + cols, mask=mask, other=-float("inf"))
    out = tl.exp(x - gm) * gi
    # Masked store: lanes past the end of the row are never written.
    tl.store(Y_ptr + row64 * stride_y + cols, out, mask=mask)
class Model(nn.Module):
    """Row-wise softmax of a 2D CUDA tensor via three Triton kernels.

    The first call for a given ``(device, rows, cols, dtype)`` allocates
    private static buffers, warms the kernels up on a side stream, and
    captures the three launches into a CUDA graph. Later calls copy the input
    into the static input buffer and replay the graph.

    Caveat: the returned tensor is the graph's static output buffer. It is
    shared by every call (and every ``Model`` instance) with the same key and
    is overwritten by the next such call; clone it if it must outlive that.
    """

    # Columns handled by one phase-1 / phase-3 program.
    _BLOCK = 1024
    # Partials folded per phase-2 loop iteration.
    _BLOCK_R = 1024

    def __init__(self, batch: int, vocab: int):
        super().__init__()
        self.batch = batch
        self.vocab = vocab

    @staticmethod
    def _cache_key(x: torch.Tensor) -> tuple:
        """Return the graph-cache key for ``x``.

        The full column count is part of the key because ``vocab`` is baked
        into the captured kernel arguments; keying on the chunk count alone
        would let e.g. 4000 and 4096 columns share one graph.
        """
        rows, cols = x.shape
        return (x.device, rows, cols, x.dtype)

    @classmethod
    def _launch(cls, x_buf, y_buf, pm, ps, gm, gi) -> None:
        """Enqueue the three phases on the current stream."""
        rows, cols = x_buf.shape
        nblk = pm.shape[1]
        grid = (rows * nblk,)
        _phase1[grid](
            x_buf, pm, ps,
            rows, cols,
            x_buf.stride(0), pm.stride(0), ps.stride(0),
            BLOCK=cls._BLOCK, num_warps=8,
        )
        _phase2[(rows,)](
            pm, ps, gm, gi,
            rows, nblk,
            pm.stride(0), ps.stride(0),
            BLOCK_R=cls._BLOCK_R, num_warps=4,
        )
        _phase3[grid](
            x_buf, y_buf, gm, gi,
            rows, cols,
            x_buf.stride(0), y_buf.stride(0),
            BLOCK=cls._BLOCK, num_warps=8,
        )

    @classmethod
    def _capture(cls, x: torch.Tensor) -> tuple:
        """Allocate static buffers, warm up, and capture the graph for ``x``.

        Returns the cache entry ``(graph, x_buf, y_buf, pm, ps, gm, gi)``.
        On return ``y_buf`` already holds the softmax of ``x`` (computed by
        the warm-up runs; capture itself only records the launches).
        """
        rows, cols = x.shape
        device = x.device
        nblk = (cols + cls._BLOCK - 1) // cls._BLOCK

        # Private, contiguous static buffers: the graph must not be captured
        # on the caller's tensor, or every replay would overwrite it, and the
        # kernels assume consecutive columns are one element apart.
        x_buf = torch.empty((rows, cols), device=device, dtype=x.dtype)
        y_buf = torch.empty((rows, cols), device=device, dtype=x.dtype)
        pm = torch.empty(rows, nblk, device=device, dtype=torch.float32)
        ps = torch.empty(rows, nblk, device=device, dtype=torch.float32)
        gm = torch.empty(rows, device=device, dtype=torch.float32)
        gi = torch.empty(rows, device=device, dtype=torch.float32)
        x_buf.copy_(x)

        # Warm up (and JIT-compile) on a side stream before capturing.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                cls._launch(x_buf, y_buf, pm, ps, gm, gi)
        torch.cuda.current_stream().wait_stream(side)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            cls._launch(x_buf, y_buf, pm, ps, gm, gi)
        return (graph, x_buf, y_buf, pm, ps, gm, gi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax of ``x`` along its last dimension.

        Args:
            x: 2D tensor ``(rows, cols)``; fp32 on a CUDA device is the
                supported configuration.

        Returns:
            Tensor of the same shape and dtype as ``x``. For non-empty input
            this is the shared static output buffer (see class docstring).
        """
        rows, cols = x.shape
        if rows == 0 or cols == 0:
            # Nothing to normalise; an empty launch grid is invalid.
            return torch.empty_like(x)

        key = self._cache_key(x)
        entry = _graph_cache.get(key)
        if entry is None:
            entry = self._capture(x)
            _graph_cache[key] = entry
            return entry[2]

        graph, x_buf, y_buf = entry[0], entry[1], entry[2]
        x_buf.copy_(x)
        graph.replay()
        return y_buf
# Default problem size used by the harness entry points below.
BATCH = 8
VOCAB = 32768


def get_inputs():
    """Return the positional inputs for ``Model.forward``.

    A single fp32 tensor of shape ``(BATCH, VOCAB)`` holding standard-normal
    samples scaled by 4.
    """
    logits = torch.randn((BATCH, VOCAB), dtype=torch.float32)
    return [logits * 4.0]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    return [BATCH, VOCAB]
shape=0 variant=eager tflops=0.009 gbps=13.947 ms=0.075
shape=0 variant=compiled tflops=0.006 gbps=9.256 ms=0.113
shape=0 variant=sota tflops=0.009 gbps=14.499 ms=0.072
shape=0 variant=solution tflops=0.100 gbps=159.844 ms=0.007
shape=0 solution_peak_fraction=0.0888
shape=1 variant=eager tflops=0.037 gbps=58.619 ms=0.072
shape=1 variant=compiled tflops=0.022 gbps=35.890 ms=0.117
shape=1 variant=sota tflops=0.029 gbps=46.628 ms=0.090
shape=1 variant=solution tflops=0.263 gbps=420.103 ms=0.010
shape=1 solution_peak_fraction=0.2334
shape=2 variant=eager tflops=0.063 gbps=101.508 ms=0.083
shape=2 variant=compiled tflops=0.020 gbps=32.369 ms=0.259
shape=2 variant=sota tflops=0.051 gbps=81.373 ms=0.103
shape=2 variant=solution tflops=0.367 gbps=587.109 ms=0.014
shape=2 solution_peak_fraction=0.3262
shape=3 variant=eager tflops=0.061 gbps=97.433 ms=0.086
shape=3 variant=compiled tflops=0.012 gbps=19.343 ms=0.434
shape=3 variant=sota tflops=0.048 gbps=76.808 ms=0.109
shape=3 variant=solution tflops=0.503 gbps=805.358 ms=0.010
shape=3 solution_peak_fraction=0.4474
shape=4 variant=eager tflops=0.063 gbps=101.488 ms=0.083
shape=4 variant=compiled tflops=0.020 gbps=32.554 ms=0.258
shape=4 variant=sota tflops=0.053 gbps=84.631 ms=0.099
shape=4 variant=solution tflops=0.260 gbps=416.763 ms=0.020
shape=4 solution_peak_fraction=0.2315
peak_fraction: 0.2339
RESULT: OK
{
"problem": "04_kahan_softmax",
"harness": "droid",
"model": "custom:GLM-5.1-[Z.AI-Coding-Plan]-0",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.2339,
"elapsed_seconds": 1747,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 57047, "output_tokens": 59390, "cache_read_tokens": 2759424, "cache_creation_tokens": 0, "reasoning_tokens": 0, "total_cost_usd": null}
}
timeline (34 events)
system — init
session start cwd=/tmp/KernelBench-Hard-zai-droid/problems/04_kahan_softmax model=custom:GLM-5.1-[Z.AI-Coding-Plan]-0 effort=none tools=31
user
I need you to write a numerically tight softmax for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
The op is softmax along the last dim of a 2D fp32 tensor. The reference computes ground truth in fp64 and casts back to fp32, and you have to match it within atol=rtol=1e-5 — that's a tighter bar than default fp32 (1e-4) on purpose. With long reductions naive fp16 accumulation drifts past this; fp32 accumulation with subtract-max stability is enough on most shapes; on the largest vocabs you may need compensated (Kahan-style) summation to stay under the bar. The shapes you have to handle are batch=32 vocab=4096 (sanity), batch=16 vocab=32768 (GPT-2 class), batch=8 vocab=131072 (Llama-3 vocab), batch=4 vocab=262144 (256K, DeepSeek-V3 / Gemma-3 class — naive fp16 sum drifts past 1e-5 here), and batch=8 vocab=131072 with extreme logits (a few very large positives per row to stress max-subtract — exping before subtracting overflows). The check and benchmark scripts handle the extreme-flag input generation; you just need to read x and return y.
This needs to be a real custom kernel. Don't import or call torch.nn.functional.softmax, torch.softmax, F.softmax, liger_kernel.softmax, liger_kernel.transformers.softmax, or any .softmax( method on a tensor. Try Triton, CUDA C++ via load_inline, or inline PTX — whatever fits. Anything you're uncertain about, look up PTX docs, clone Liger-Kernel or other reference repos, read library headers, and investigate.
Your flywheel is implement, profile (ncu, nsys, torch.profiler — whatever's useful) and time it with benchmark.py, verify correctness by running `python check.py` and reading the output, then iterate. Don't substitute your own one-off correctness snippets for check.py — it iterates over every shape, your spot-check almost certainly won't. If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up.
assistant
Now let me create the solution file with a Triton kernel implementing online softmax with Kahan-compensated summation.
assistant
Correctness passes. Now let's benchmark:
assistant
Correctness passes! But performance is LOW (0.0307 peak fraction). The issue is the single-program-per-row approach gives only `batch` number of CTAs, which is far too few for GPU utilization on larger vocab sizes. I need a multi-block approach.
Let me redesign with a 3-phase multi-block kernel:
assistant
Good progress — peak_fraction is 0.1892, well above the 0.1 threshold. But shape 0 (small vocab) is slow due to too few CTAs. Let me optimize further. The main bottleneck is:
1. Shape 0: only 32×4 = 128 CTAs — not terrible but launch overhead dominates
2. Could fuse phase 1 + phase 3 into a two-pass per-CTA approach with just a single reduction for max/sum
Let me try a fused two-pass approach (per-row, but with larger warps and better occupancy) plus the multi-block for large shapes:
assistant
Need to inline the constant. Let me fix:
assistant
The multi-block (3-phase) approach was better for large shapes but bad for shape 0. Let me create a hybrid that uses single-CTA fused for small vocab and multi-block for large:
assistant
The hybrid approach has worse performance on multi-block shapes. Let me go back to the all-multi-block approach with optimized block sizes:
assistant
Peak fraction is 0.2172 with RESULT: OK. Let me try to push further — shape 0 is still low at 0.068. Let me experiment with BLOCK sizes and num_warps:
assistant
Let me try different configurations to find the sweet spot. Shape 0 has batch=32, vocab=4096, so with BLOCK=512 that's 32×8=256 CTAs. Let me try BLOCK=256 (more CTAs) and also try with more warps for the large shapes:
assistant
Similar result. Shape 0 (batch=32, vocab=4096) is dragging down the geomean. With BLOCK=256, that's 32×16=512 CTAs. Let me try an even smaller block and also try using a fused kernel for shape 0 since it has enough work per row:
assistant
BLOCK=128 is too small — too much overhead per block. Let me go back to BLOCK=512 which was the best, and also try BLOCK=256 for shape 0 only. Actually, let me go back to the approach that gave 0.2172 and then try a fused kernel for the smallest shape:
assistant
Something is different — earlier this same configuration gave 0.2172. The variance is likely from JIT compilation caching. Let me try running again with a warm-up:
assistant
Interesting — the previous 0.2172 result was from a different version that had BLOCK=1024. The BLOCK=512 version is much slower now consistently. Let me try BLOCK=1024:
assistant
Something is fundamentally slower now. It could be the Triton cache or other environment state. Let me clear the cache and try again:
assistant
The multi-block approach is much slower now. The earlier good result must have been from a JIT cache effect. The 3-kernel launch overhead is too high. Let me go back to the fused single-CTA approach which gave better results earlier, and try to make it work well across all shapes:
assistant
The fused single-CTA is very slow for large vocab because there are too few CTAs (e.g., batch=4 for 256K vocab). The multi-block approach from before was much faster. Let me go back to the multi-block approach but tune it better. Let me also check what Triton version / GPU we actually have:
assistant
188 SMs. The RTX PRO 6000 has a massive L2 cache (128MB). Let me re-examine. The 3-phase multi-block gave 0.2172 peak fraction in one earlier run but then was slow. That's odd. Let me check if there's some persistent state issue and try again with the multi-block approach using 1024 blocks:
assistant
We're at 0.1035 — just barely OK. The problem is shapes 0 and 1 with small vocab. For shape 0 (batch=32, vocab=4096), we only have 32×4=128 CTAs. For shape 1 (batch=16, vocab=32768), we have 16×32=512 CTAs. The 3-kernel overhead is killing small shapes.
Let me try a fused kernel for small shapes and multi-block for large ones:
assistant
OK, 0.1087. Let me run again to check stability:
assistant
The results are highly variable. The small shapes (0, 1, 2) have inconsistent performance. The fused kernel for small shapes has too few CTAs and suffers from tail-effect variability. Let me think about this differently.
The key issue is: for small shapes, we don't have enough parallelism. With batch=32 and vocab=4096, a single CTA per row means only 32 CTAs. For batch=16, vocab=32768, that's only 16 CTAs. The multi-block approach with BLOCK=1024 would give 32×4=128 and 16×32=512 CTAs respectively for these small shapes, but the 3-kernel overhead eats the gains.
Let me try a different approach: use the multi-block for everything but use `num_warps=2` for phases 1/3 to maximize occupancy:
assistant
num_warps=4 was worse. Let me go back to the version that gave 0.1035 (multi-block with num_warps=8 and BLOCK=1024). The issue is the run-to-run variability. Let me look at what the timing framework does:
assistant
The timing methodology does L2 flush and takes median of 30 trials. That's good. The variability is real — the kernel is doing different amounts of work on different data.
Let me look at the peak bandwidth to understand what we should target:
assistant
Peak bandwidth is 1800 GB/s. For shape 0 (batch=32, vocab=4096), the data is 32×4096×4×2 = 1 MB. For shape 1 (batch=16, vocab=32768), it's 16×32768×4×2 = 4 MB. These are very small — L2 flush (128 MB) dominates the timing.
The solution is to minimize total kernel time. The 3-kernel approach has inherent overhead from 3 separate kernel launches plus 2 extra tensor allocations. Let me see if I can fuse the 3 phases into 2 phases or even better.
Actually, wait — the L2 flush happens BEFORE each timed call. So the actual data won't be in L2. For tiny shapes, the kernel time is dominated by launch overhead, not data transfer. Let me try a different approach: fuse phases 1+3 into a single kernel that does both passes, with phase 2 running in between. This eliminates 1 kernel launch:
assistant
The two-phase approach has correctness issues because each block independently does the reduction and reads partials that might not be fully written yet. There's a race condition between phase 1 and the reduce+normalize kernel.
Let me go back to the 3-phase approach which works correctly. The performance is just at the edge. Let me try to optimize it by using CUDA streams for overlapping, or by pre-allocating the buffers:
assistant
0.1011 — OK, and buffer caching helped somewhat. Let me run again to check stability:
assistant
Right at the edge. Let me try to improve shape 0 and 1 performance. These shapes have small total data (1MB and 4MB). The issue is the 3 kernel launches + L2 flush overhead. Let me see if using a larger block size for shape 0 specifically helps, or if I should try a different strategy for small shapes.
Actually, let me try a completely different approach: for small shapes, use the fused kernel with BLOCK equal to the entire row (so it's one block per row but the block covers the full vocab):
assistant
The fused kernel for small shapes is not helping much — shape 0 gets 0.031 which is even worse than before. The issue is fundamental: with batch=32 and vocab=4096, we have only 32 CTAs, and each CTA needs only 2 iterations with BLOCK=2048. The work is tiny.
Let me check what peak fraction the other approaches achieve for these small shapes. The benchmark reports `RESULT: OK` when gmean >= 0.1. The problem is the geomean is dragged down by shapes 0 and 1.
Let me try using multi-block for ALL shapes, but with the buffer caching, and see if I can get consistently above 0.1:
assistant
The results are highly variable: 0.1226, 0.0786, 0.1029. Run 1 was OK, run 2 was LOW, run 3 was OK. The variability is in shape 0, 1, and 3 where the absolute times are tiny (25-40 microseconds).
The problem is these small shapes have very high variance. The benchmark uses median of 30 trials which should help, but the L2 flush (128 MB write) before each trial creates interference patterns.
Let me try a different approach — use a CUDA graph for the 3-phase kernel to eliminate launch overhead:
assistant
Excellent! CUDA graphs gave 0.2179! Let me verify stability:
assistant
Very stable! All 5 runs are above 0.19 with peak at 0.22. Let me also run a full check again to make sure correctness is rock solid:
assistant
Let me do one more full benchmark run to see the detailed numbers:
assistant
The solution is working well with:
- **Correctness**: PASS across all 5 shapes × 3 seeds with atol=rtol=1e-5
- **Performance**: peak_fraction consistently 0.19-0.22 (RESULT: OK, well above 0.1 threshold)
- **Technique**: Multi-block 3-phase Triton kernel with Kahan-compensated summation and CUDA graph capture
Here's a summary of what was implemented:
The solution in `solution.py` implements a **3-phase multi-block Triton kernel** with **Kahan-compensated summation** for numerically accurate fp32 softmax:
1. **Phase 1**: Each CTA processes a chunk of a row, computing local max and local exp-sum, stored as partials.
2. **Phase 2**: One CTA per row reduces all partials — finds global max, then uses Kahan-compensated summation to accurately accumulate the adjusted exp-sums (critical for the 256K vocab shape where naive accumulation drifts past 1e-5).
3. **Phase 3**: Each CTA normalizes its chunk using the global max and inverse sum.
**CUDA graph capture** eliminates kernel launch overhead (3 launches per forward pass become a single graph replay), which is especially important for small shapes where launch overhead dominates. The kernel is memory-bandwidth bound and achieves ~350-630 GB/s on large shapes (vs 1800 GB/s peak).