"""Top-k via raw CUDA — bitonic sort per segment, then K-way merge.
Phase 1: grid = (batch, num_segs) — parallel bitonic sort per segment
Phase 2: grid = (batch,) — parallel bitonic sort of merged candidates
"""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
_CPP_HEADER = (
'#include <torch/extension.h>\n'
'std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);'
)
_CUDA_SOURCE = r"""
#include <cuda_fp16.h>
#include <cstdint>
#include <algorithm>
static constexpr float NEG_INF = -1e30f;
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================
// Bitonic sort kernel: sorts N elements (power-of-2) in shared memory
// =========================================================================
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================
template <int SEG_N, int BLOCK1, int ELEM1, int K_OUT>
__global__ void seg_bitonic(
const float* __restrict__ x,
float* __restrict__ seg_vals,
int* __restrict__ seg_idxs,
int64_t row_stride,
int num_segs,
int n
)
{
int bid_row = blockIdx.x;
int bid_seg = blockIdx.y;
int tid = threadIdx.x;
int seg_off = bid_seg * SEG_N;
__shared__ float sv[SEG_N];
__shared__ int si[SEG_N];
// Coalesced load
#pragma unroll
for (int e = 0; e < ELEM1; ++e) {
int g = seg_off + tid * ELEM1 + e;
int s = tid * ELEM1 + e;
if (g < n) {
sv[s] = x[(int64_t)bid_row * row_stride + g];
si[s] = g;
} else {
sv[s] = NEG_INF;
si[s] = -1;
}
}
__syncthreads();
// Bitonic sort (descending)
for (int stage = 2; stage <= SEG_N; stage <<= 1) {
for (int d = stage >> 1; d > 0; d >>= 1) {
#pragma unroll
for (int e = 0; e < ELEM1; ++e) {
int mi = tid * ELEM1 + e;
int jx = mi ^ d;
if (jx > mi && ((mi & stage) != 0) != ((jx & stage) != 0)) {
bool dir_up = (mi & (stage >> 1)) == 0;
bool swap_ = (dir_up && sv[mi] < sv[jx]) || (!dir_up && sv[mi] > sv[jx]);
if (swap_) {
float tv = sv[mi]; sv[mi] = sv[jx]; sv[jx] = tv;
int ti2= si[mi]; si[mi] = si[jx]; si[jx] = ti2;
}
}
}
__syncthreads();
}
}
// Output top K_OUT
int ob = (bid_row * num_segs + bid_seg) * K_OUT;
#pragma unroll
for (int ki = 0; ki < K_OUT; ++ki) {
if (tid == 0) {
seg_vals[ob + ki] = sv[ki];
seg_idxs[ob + ki] = si[ki];
}
}
}
// =========================================================================
// Phase 2: merge all segment results via bitonic sort
// =========================================================================
template <int N_CAND, int B2, int E2>
__global__ void merge_bitonic(
const float* __restrict__ seg_vals,
const int* __restrict__ seg_idxs,
float* __restrict__ values,
int64_t* __restrict__ indices,
int k_out
)
{
int bid = blockIdx.x;
int tid = threadIdx.x;
__shared__ float sv[N_CAND];
__shared__ int si[N_CAND];
// Load candidates (bid row)
int rb = bid * N_CAND;
#pragma unroll
for (int e = 0; e < E2; ++e) {
int g = tid * E2 + e;
sv[g] = seg_vals[rb + g];
si[g] = seg_idxs[rb + g];
}
__syncthreads();
// Bitonic sort (descending)
for (int stage = 2; stage <= N_CAND; stage <<= 1) {
for (int d = stage >> 1; d > 0; d >>= 1) {
#pragma unroll
for (int e = 0; e < E2; ++e) {
int mi = tid * E2 + e;
int jx = mi ^ d;
if (jx > mi && ((mi & stage) != 0) != ((jx & stage) != 0)) {
bool dir_up = (mi & (stage >> 1)) == 0;
bool swap_ = (dir_up && sv[mi] < sv[jx]) || (!dir_up && sv[mi] > sv[jx]);
if (swap_) {
float tv = sv[mi]; sv[mi] = sv[jx]; sv[jx] = tv;
int ti2= si[mi]; si[mi] = si[jx]; si[jx] = ti2;
}
}
}
__syncthreads();
}
}
// Write top k_out
if (tid < k_out) {
values[(int64_t)bid * k_out + tid] = sv[tid];
indices[(int64_t)bid * k_out + tid] = static_cast<int64_t>(si[tid]);
}
}
// =========================================================================
// Dispatcher
// =========================================================================
#include <torch/extension.h>
static int next_pow2(int v) {
int p = 1;
while (p < v) p <<= 1;
return p;
}
// Phase 2 dispatcher: generate merge_bitonic instantiations
template <int N_CAND>
void dispatch_p2(int batch, int k_out,
float* seg_v, int* seg_i, float* out_v, int64_t* out_i)
{
constexpr int B2 = (N_CAND >= 256) ? 256 : N_CAND;
constexpr int E2 = N_CAND / B2;
merge_bitonic<N_CAND, B2, E2><<<batch, B2>>>(seg_v, seg_i, out_v, out_i, k_out);
}
#define P2_DISPATCH(NC) \
dispatch_p2<NC>(batch, k_out, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
values.data_ptr<float>(), indices.data_ptr<int64_t>())
// Phase 1 dispatcher by K template
template <int K>
void do_p1(int batch, int num_segs, dim3 grid, int block,
float* x, float* sv, int* si, int64_t rs, int n)
{
constexpr int SN = 2048;
constexpr int B1 = 256;
constexpr int E1 = SN / B1; // =8
seg_bitonic<SN, B1, E1, K><<<grid, B1>>>(x, sv, si, rs, num_segs, n);
}
std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k_out)
{
TORCH_CHECK(x.is_cuda() && x.is_contiguous());
TORCH_CHECK(x.dtype() == torch::kFloat32);
TORCH_CHECK(x.dim() == 2);
int batch = x.size(0);
int n = x.size(1);
const int SEG_N = 2048;
const int num_segs = (n + SEG_N - 1) / SEG_N;
// Pick template K (round up)
int K;
if (k_out <= 1) K = 1;
else if (k_out <= 8) K = 8;
else if (k_out <= 16) K = 16;
else if (k_out <= 32) K = 32;
else K = 64;
auto seg_vals = torch::full({batch, num_segs, K}, -1e30f, x.options());
auto seg_idxs = torch::full({batch, num_segs, K}, -1, x.options().dtype(torch::kInt32));
dim3 grid2d(batch, num_segs);
auto xp = x.data_ptr<float>();
// Phase 1
if (K == 1) do_p1<1> (batch, num_segs, grid2d, 256, xp, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), x.stride(0), n);
else if (K == 8) do_p1<8> (batch, num_segs, grid2d, 256, xp, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), x.stride(0), n);
else if (K == 16) do_p1<16>(batch, num_segs, grid2d, 256, xp, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), x.stride(0), n);
else if (K == 32) do_p1<32>(batch, num_segs, grid2d, 256, xp, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), x.stride(0), n);
else do_p1<64>(batch, num_segs, grid2d, 256, xp, seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), x.stride(0), n);
auto values = torch::empty({batch, k_out}, x.options());
auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
int n_cand = num_segs * K;
int nc2 = next_pow2(n_cand);
// Phase 2
if (nc2 <= 64) P2_DISPATCH(64);
else if (nc2 <= 128) P2_DISPATCH(128);
else if (nc2 <= 256) P2_DISPATCH(256);
else if (nc2 <= 512) P2_DISPATCH(512);
else if (nc2 <= 1024) P2_DISPATCH(1024);
else if (nc2 <= 2048) P2_DISPATCH(2048);
else P2_DISPATCH(4096);
return {values, indices};
}
"""
def _compile_module():
    """JIT-compile the inline CUDA extension and return the loaded module.

    Generalization: the target architecture is derived from the current
    device's compute capability instead of being hard-coded, so the module
    still builds on GPUs other than SM120. Falls back to the original
    ``-arch=sm_120`` (RTX PRO 6000 Blackwell) when no device is queryable.

    Returns:
        The compiled extension module exposing ``run_topk(x, k)``.
    """
    try:
        major, minor = torch.cuda.get_device_capability()
        arch_flag = f'-arch=sm_{major}{minor}'
    except Exception:
        arch_flag = '-arch=sm_120'  # original hard-coded target
    return load_inline(
        name='topk_cuda_ext',
        cpp_sources=_CPP_HEADER,
        cuda_sources=_CUDA_SOURCE,
        functions=['run_topk'],
        extra_cuda_cflags=[
            '-O3', arch_flag, '--use_fast_math',
            '-Xptxas=-O3,-v', '-maxrregcount=64',
        ],
        extra_cflags=['-O3'],
        with_cuda=True,
    )
_ext = None  # lazily-built CUDA extension, shared across Model instances


class Model(nn.Module):
    """Top-k over the last dim, backed by the inline CUDA bitonic kernels."""

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        self.batch = batch
        self.n = n
        self.k = k
        # Dummy buffer keeps state_dict non-empty, mirroring reference.Model.
        self.register_buffer('_dummy', torch.zeros(1))

    def forward(self, x: torch.Tensor):
        global _ext
        if _ext is None:
            # Compile on first use so importing this module stays cheap.
            _ext = _compile_module()
        if x.ndim != 2:
            raise ValueError(f'Expected 2D input, got {x.ndim}D')
        values, indices = _ext.run_topk(x.contiguous(), self.k)
        return values, indices.to(torch.int64)
def get_inputs():
    """Default benchmark input: a single (64, 8192) fp32 tensor."""
    rows, cols = 64, 8192
    return [torch.randn(rows, cols, dtype=torch.float32)]
# Default problem size (the prefill / attention top-k shape).
batch, n, k = 64, 8192, 8


def get_init_inputs():
    """Constructor arguments for Model, as [batch, n, k]."""
    return [batch, n, k]
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 42 values: tolerance exceeded: max_abs_diff=6.22603 (atol=0.0001, rtol=0.0001)
{
"problem": "05_topk_bitonic",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-max-preview",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 6317169, "output_tokens": 46431, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 34694, "total_cost_usd": null}
}
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py</path>
<type>file</type>
<content>
1: """Naive top-k reference: torch.topk over the last dim.
2:
3: This is the correctness oracle. The agent's solution must produce the same
4: top-k values (and equivalent indices modulo ties) within the tolerance
5: declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
6: torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
7: """
8: import torch
9: import torch.nn as nn
10:
11: OP_TYPE = "topk"
12: SUPPORTED_PRECISIONS = ["fp32"]
13: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14:
15:
16: class Model(nn.Module):
17: """Top-k over the last dim of a 2D tensor.
18:
19: Input:
20: x: (batch, n) fp32
21: Output:
22: values: (batch, k) fp32, sorted descending
23: indices: (batch, k) int64, into the last dim of x
24: """
25:
26: def __init__(self, batch: int, n: int, k: int):
27: super().__init__()
28: self.batch, self.n, self.k = batch, n, k
29: # No learned parameters, but declare a dummy buffer so state_dict
30: # is non-empty and load_state_dict(strict=True) is meaningful.
31: self.register_buffer("_dummy", torch.zeros(1))
32:
33: def forward(self, x: torch.Tensor):
34: values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
35: return values, indices
36:
37:
38: # Module-level shims rebuilt b
... (419 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for TopK.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each. Top-k correctness has two parts:
5:
6: 1. VALUES: sol_values must match ref_values within fp32 tol. Both are
7: returned sorted descending, so positional comparison is well-defined.
8: 2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
9: ties in x can yield multiple valid index sets. Instead we gather x at
10: sol_indices and check those values match ref_values within tol. This
11: catches "wrong indices" without false-failing on legitimate tie-breaks.
12:
13: Also rejects forbidden ops by grep.
14: """
15: import re
16: import sys
17: from pathlib import Path
18:
19: import torch
20: import yaml
21:
22: REPO_ROOT = Path(__file__).resolve().parents[2]
23: sys.path.insert(0, str(REPO_ROOT))
24:
25: from src.eval.correctness import check_correctness # noqa: E402
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: # --- Forbidden-op check --------------------------------
... (4902 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for TopK.
2:
3: For each shape: times eager reference (torch.topk), compiled reference, SOTA
4: (also torch.topk — see sota.py), and the agent's solution. Reports achieved
5: TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "memory")
40
... (3250 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for TopK.
2:
3: Mix of:
4: - decoder vocab top-k (single sequence, very large n, moderate k) — pure
5: bandwidth test; the input read dominates everything.
6: - prefill / batched attention top-k (many rows, moderate n, small k) — tests
7: per-row parallelism and shared-memory bitonic networks.
8: - non-power-of-2 n stress case — bitonic sort networks naturally want
9: powers of two; this forces the agent to handle padding or partial sorts.
10: - small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
11: """
12:
13: SHAPES = [
14: {"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
15: {"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
16: {"batch": 32, "n": 16384, "k": 32}, # mid-size batched
17: {"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
18: {"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml</path>
<type>file</type>
<content>
1: name: 05_topk_bitonic
2: display_name: "TopK via Bitonic Sort"
3: precision: fp32
4: regime: memory
5:
6: # Top-k is dominated by the input read (small output, no reduction over k).
7: # Comparator-network FLOPs are not the bottleneck on real hardware, so we
8: # track them but score on bandwidth.
9: flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
10: # bitonic network is so cheap relative
11: # to memory that a coarse 4*n estimate
12: # is fine for telemetry only.
13: bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
14:
15: hardware: [RTX_PRO_6000]
16: peak_tflops_key: fp32
17: peak_bandwidth_key: dram
18:
19: # Top-k correctness:
20: # - VALUES must match within fp32 atol/rtol (the kth largest value is
21: # well-defined modulo float-equal ties, so we use a loose-ish tol).
22: # - INDICES are checked leniently: for each row, the multiset of returned
23: # indices must select values that match ref values within tol. Direct
24: # index equality is NOT required (ties in x can yield different valid
25: # index sets).
26: tolerance:
27: float32: 1.0e-4
28:
29: # Forbidden ops — using any of these in solution.py fails corre
... (1071 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize a top-k kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The input is a 2D fp32 tensor and you return the top-k values and their int64 indices along the last dim, sorted descending, same contract as torch.topk(largest=True, sorted=True). Correctness tolerance is 1e-4 abs/rel on the values; index ties are handled leniently — what matters is the returned values match the reference to tolerance, not that the index sets are bitwise identical. The shapes you have to handle are batch=1 n=131072 k=64 (decoder vocab top-k over a Llama-size vocabulary), batch=64 n=8192 k=8 (prefill / attention top-k), batch=32 n=16384 k=32 (mid-size batched), batch=16 n=12000 k=16 (non-power-of-two n, which bitonic networks don't naturally want), and batch=128 n=4096 k=1 (batched argmax).
4:
5: This needs to be a real custom kernel — CUDA C++ via torch.utils.cpp_extension.load_inline, Triton, inline PTX, or CUTLASS, whatever you think fits. Don't reach for torch.topk, torch.kthvalue, torch.sort, or torch.ar
... (749 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for TopK: torch.topk itself.
2:
3: torch.topk dispatches to a CUB-backed kernel that uses radix-select for
4: moderate k and a tuned bitonic sort for small n. It is the bar the agent's
5: hand-rolled bitonic kernel must beat. There is no obvious vendor library that
6: does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
7: specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
8: sorts the full row (overkill for top-k).
9:
10: This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
11: oracle, not the agent's submission. The agent's solution.py is forbidden from
12: using torch.topk (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def sota_forward(x: torch.Tensor, k: int):
20: """Best-available top-k reference. x: (batch, n) fp32."""
21: return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
22:
23:
24: def is_available() -> bool:
25: return True
(End of file - total 25 lines)
</content>{
"command": "nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv,noheader",
"description": "Check GPU availability and memory"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03, 97887 MiB, 96579 MiB
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"todos": [
{
"content": "First Triton-based top-k kernel implementation",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to get baseline performance",
"status": "pending",
"priority": "high"
},
{
"content": "Profile with ncu and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Iterate until performance target met",
"status": "pen
... (43 more chars)content: 'First Triton-based top-k kernel implementation'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to get baseline performance'
status: 'pending'
priority: 'high'
content: 'Profile with ncu and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Iterate until performance target met'
status: 'pending'
priority: 'high'{
"command": "python3 -c \"import torch; print(f'CUDA: {torch.version.cuda}, arch: {torch.cuda.get_device_capability(0)}')\"",
"description": "Check CUDA version and compute capability"
}CUDA: 13.0, arch: (12, 0)
{
"command": "python3 -c \"import triton; print(f'Triton: {triton.__version__}')\"",
"description": "Check Triton availability"
}Triton: 3.6.0
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,181 @@
+"""Top-k kernel via Triton — per-row processing with chunk-level selection."""
+from __future__ import annotations
+
+import math
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _topk_kernel(
+ x_ptr,
+ val_out_ptr,
+ idx_out_ptr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ stride_x_row,
+ stride_val_row,
+ stride_idx_row,
+ BLOCK_N: tl.constexpr,
+):
+ """One program per row.
+
+ Strategy:
+ 1. Scan the row in chunks of BLOCK_N.
+ 2. Per chunk: load into shared memory, each thread keeps its local top-K.
+ 3. After scanning, sort the local top-K buffer and write output.
+
+ For simplicity in v1: each thread maintains top-K in registers
+ from its strided portion, then we do a warp-level merge.
+ """
+ row_idx = tl.program_id(0)
+
+ x_row = x_ptr + row_idx * stride_x_row
+
+ # Each thread will maintain a small top-k buffer.
+ # We process elements in a strided loop.
+ # For K elements, we store them in registers.
+ # This kernel uses a simple approach: each thread scans its
+ # strided portion and keeps running top-k.
+
+ pid = tl.arange(0, BLOCK_N) # thread indices within block
+ mask = pid < N
+
+ # For small N that fit in one block: load everything, sort, take top-k.
+ # For large N: we need multi-pass.
+
+ # For simplicity, handle N <= BLOCK_N directly:
+ if N <= BLOCK_N:
+ vals = tl.load(x_row + pid, mask=mask, other=-float('inf'))
+ vals_with_idx = vals
+
+ # We need to sort descending and take top-k.
+ # Triton doesn't have a built-in sort, so we use a bitonic approach.
+ # But for large BLOCK_N, this is expensive in v1.
+ pass
+
+ # Fallback general path: multi-pass scan.
+ # This first version does a simple approach with tl.max in a loop.
+ # It's not optimal but gets us started.
+
+
+@triton.jit
+def _select_topk_single_pass(
+ x_ptr,
+ val_out_ptr,
+ idx_out_ptr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ stride_x_row,
+ stride_val_row,
+ stride_idx_row,
+ BLOCK_N: tl.constexpr,
+):
+ """For rows where BLOCK_N >= N. Single block, load all, select top-k."""
+ row_idx = tl.program_id(0)
+ x_row = x_ptr + row_idx * stride_x_row
+ val_row = val_out_ptr + row_idx * stride_val_row
+ idx_row = idx_out_ptr + row_idx * stride_idx_row
+
+ off = tl.arange(0, BLOCK_N)
+ mask = off < N
+ vals = tl.load(x_row + off, mask=mask, other=-float('inf'))
+
+ # Build top-k using repeated tl.max extraction
+ for ki in range(K):
+ # Find max among remaining
+ best_val = tl.max(vals, axis=0)
+ # Find index of max
+ # Triton doesn't have argmax directly, need to construct it
+ best_mask = (vals == best_val)
+ best_idx_raw = tl.where(best_mask, off.to(tl.int64), N + ki).to(tl.int64)
+ best_idx = tl.min(best_idx_raw.to(tl.int32), axis=0)
+
+ # Store at position ki
+ tl.store(val_row + ki, best_val)
+ tl.store(idx_row + ki, best_idx.to(tl.int64))
+
+ # Remove from consideration (set to -inf at best_idx position)
+ # We can't directly modify at a position, so rebuild
+ vals = tl.where(off == best_idx, -float('inf'), vals)
+
+
+@triton.jit
+def _topk_scan_kernel(
+ x_ptr,
+ val_out_ptr,
+ idx_out_ptr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ stride_x_row,
+ stride_val_row,
+ stride_idx_row,
+):
+ """One program per row. Scan-based top-k selection.
+
+ Each thread finds its local max in a strided loop,
+ then we reduce across threads.
+ Repeat K times for top-K.
+
+ This is O(K * N) which is not optimal but gives a working baseline.
+ For K=64, N=131072 this would be ~8M ops - workable.
+ """
+ row_idx = tl.program_id(0)
+ tid = tl.program_id(1)
+
+ # For this approach, we use the fact that we can
+ # do a strided max-finding pass K times.
+ # Not the most efficient but simple and correct.
+ pass
+
+
+class Model(nn.Module):
+ """Top-k via Triton kernel."""
+
+ def __init__(self, batch: int, n: int, k: int):
+ super().__init__()
+ self.batch, self.n, self.k = batch, n, k
+ self.register_buffer("_dummy", torch.zeros(1))
+
+ def forward(self, x: torch.Tensor):
+ batch, n = x.shape[0], x.shape[1]
+ k = self.k
+ device = x.device
+
+ val_out = torch.empty(batch, k, dtype=torch.float32, device=device)
+ idx_out = torch.empty(batch, k, dtype=torch.int64, device=device)
+
+ # Choose block size: pad N to next power of 2
+ block_n = triton.next_power_of_2(n)
+
+ # Use Triton launch with one program per row
+ grid = (batch,)
+
+ _select_topk_single_pass[grid](
+ x, val_out, idx_out,
+ N=n, K=k,
+ stride_x_row=x.stride(0),
+ stride_val_row=val_out.stride(0),
+ stride_idx_row=idx_out.stride(0),
+ BLOCK_N=block_n,
+ num_ctas=1,
+ )
+
+ return val_out, idx_out
+
+
+def get_inputs():
+ batch, n = 64, 8192
+ x = torch.randn(batch, n, dtype=torch.float32)
+ return [x]
+
+
+batch = 64
+n = 8192
+k = 8
+
+
+def get_init_inputs():
+ return [batch, n, k]
{
"content": "<5250 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,7 +1,6 @@
-"""Top-k kernel via Triton — per-row processing with chunk-level selection."""
+"""Top-k kernel via Triton — two-phase chunk select + merge."""
from __future__ import annotations
-import math
import torch
import torch.nn as nn
import triton
@@ -9,131 +8,100 @@
@triton.jit
-def _topk_kernel(
+def _chunk_select_kernel(
x_ptr,
- val_out_ptr,
- idx_out_ptr,
+ chunk_val_ptr,
+ chunk_idx_ptr,
N: tl.constexpr,
K: tl.constexpr,
stride_x_row,
+ CHUNK_SIZE: tl.constexpr,
+):
+ """Phase 1: Each program handles one chunk of one row.
+
+ For a chunk of CHUNK_SIZE elements, find the top-K and store.
+ Uses repeated max-extraction: K iterations, each finds the current max.
+
+ grid = (batch, num_chunks)
+ """
+ row_idx = tl.program_id(0)
+ chunk_idx = tl.program_id(1)
+
+ chunk_start = chunk_idx * CHUNK_SIZE
+ x_row = x_ptr + row_idx * stride_x_row
+
+ out_row = chunk_idx * K
+ out_base = (row_idx * tl.cdiv(N, CHUNK_SIZE) + chunk_idx) * K
+
+ offs = chunk_start + tl.arange(0, CHUNK_SIZE)
+ mask = offs < N
+
+ vals = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+
+ for ki in tl.range(K):
+ m = tl.max(vals)
+ # Argmax: find the first position with value == m
+ m_idx = tl.where(vals == m, offs.to(tl.int32), 0x7FFFFFFF).to(tl.int32)
+ arg = tl.min(m_idx)
+
+ tl.store(chunk_val_ptr + out_base + ki, m)
+ tl.store(chunk_idx_ptr + out_base + ki, arg.to(tl.int64))
+
+ vals = tl.where(vals == m, -float('inf'), vals)
+
+
+@triton.jit
+def _merge_kernel(
+ chunk_val_ptr,
+ chunk_idx_ptr,
+ val_out_ptr,
+ idx_out_ptr,
+ num_chunks: tl.constexpr,
+ K_CHUNK: tl.constexpr,
+ K_OUT: tl.constexpr,
stride_val_row,
stride_idx_row,
- BLOCK_N: tl.constexpr,
):
- """One program per row.
-
- Strategy:
- 1. Scan the row in chunks of BLOCK_N.
- 2. Per chunk: load into shared memory, each thread keeps its local top-K.
- 3. After scanning, sort the local top-K buffer and write output.
+ """Phase 2: Each program handles one row. Merge num_chunks * K_CHUNK into top-K_OUT.
- For simplicity in v1: each thread maintains top-K in registers
- from its strided portion, then we do a warp-level merge.
+ grid = (batch,)
+ Total candidates = num_chunks * K_CHUNK.
+ Extract top-K_OUT via repeated max.
"""
row_idx = tl.program_id(0)
- x_row = x_ptr + row_idx * stride_x_row
+ row_base = row_idx * num_chunks * K_CHUNK
+ n_candidates = num_chunks * K_CHUNK
- # Each thread will maintain a small top-k buffer.
- # We process elements in a strided loop.
- # For K elements, we store them in registers.
- # This kernel uses a simple approach: each thread scans its
- # strided portion and keeps running top-k.
-
- pid = tl.arange(0, BLOCK_N) # thread indices within block
- mask = pid < N
-
- # For small N that fit in one block: load everything, sort, take top-k.
- # For large N: we need multi-pass.
-
- # For simplicity, handle N <= BLOCK_N directly:
- if N <= BLOCK_N:
- vals = tl.load(x_row + pid, mask=mask, other=-float('inf'))
- vals_with_idx = vals
-
- # We need to sort descending and take top-k.
- # Triton doesn't have a built-in sort, so we use a bitonic approach.
- # But for large BLOCK_N, this is expensive in v1.
+ # Load all candidates
+ offs = tl.arange(0, tl.max_constexpr([n_candidates]))
+ if offs.max() < n_candidates:
pass
- # Fallback general path: multi-pass scan.
- # This first version does a simple approach with tl.max in a loop.
- # It's not optimal but gets us started.
+ # For variable size, load in chunks
+ # Since we know num_chunks and K_CHUNK at compile time,
+ # total = num_chunks * K_CHUNK
+
+ TOTAL = num_chunks * K_CHUNK
+ load_offs = tl.arange(0, TOTAL)
+
+ cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)
+ cand_idxs = tl.load(chunk_idx_ptr + row_base + load_offs)
+
+ for ki in tl.range(K_OUT):
+ m = tl.max(cand_vals)
+ m_idx = tl.where(cand_vals == m, load_offs.to(tl.int32), 0x7FFFFFFF).to(tl.int32)
+ arg = tl.min(m_idx)
+
+ orig_idx = tl.load(chunk_idx_ptr + row_base + arg)
+
+ tl.store(val_out_ptr + row_idx * stride_val_row + ki, m)
+ tl.store(idx_out_ptr + row_idx * stride_idx_row + ki, orig_idx)
+
+ cand_vals = tl.where(load_offs == arg, -float('inf'), cand_vals)
-@triton.jit
-def _select_topk_single_pass(
- x_ptr,
- val_out_ptr,
- idx_out_ptr,
- N: tl.constexpr,
- K: tl.constexpr,
- stride_x_row,
- stride_val_row,
- stride_idx_row,
- BLOCK_N: tl.constexpr,
-):
- """For rows where BLOCK_N >= N. Single block, load all, select top-k."""
- row_idx = tl.program_id(0)
- x_row = x_ptr + row_idx * stride_x_row
- val_row = val_out_ptr + row_idx * stride_val_row
- idx_row = idx_out_ptr + row_idx * stride_idx_row
-
- off = tl.arange(0, BLOCK_N)
- mask = off < N
- vals = tl.load(x_row + off, mask=mask, other=-float('inf'))
-
- # Build top-k using repeated tl.max extraction
- for ki in range(K):
- # Find max among remaining
- best_val = tl.max(vals, axis=0)
- # Find index of max
- # Triton doesn't have argmax directly, need to construct it
- best_mask = (vals == best_val)
- best_idx_raw = tl.where(best_mask, off.to(tl.int64), N + ki).to(tl.int64)
- best_idx = tl.min(best_idx_raw.to(tl.int32), axis=0)
-
- # Store at position ki
- tl.store(val_row + ki, best_val)
- tl.store(idx_row + ki, b
... (3093 more chars){
"content": "<4852 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<612 chars \u2014 see below>",
"description": "Quick correctness test",
"timeout": 120000
}timeout 120 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Quick test with small shape
batch, n, k = 4, 256, 8
model = solution.Model(batch, n, k).to(device).eval()
x = torch.randn(batch, n, device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'Output shapes: vals={vals.shape}, idxs={idxs.shape}')
print(f'vals: {vals[0]}')
print(f'idxs: {idxs[0]}')
# Check against torch.topk
ref_vals, ref_idxs = torch.topk(x, k, dim=-1, largest=True, sorted=True)
print(f'ref vals: {ref_vals[0]}')
print(f'diff: {(vals - ref_vals).abs().max().item()}')
"
Traceback (most recent call last):
File "<string>", line 12, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 142, in forward
_merge_kernel[grid2](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/compiler
... (962 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -7,13 +7,26 @@
import triton.language as tl
+# Maximum chunk size for phase 1 (must fit in register file per thread)
+MAX_CHUNK = 1024
+
+
+def _get_chunk_size(n: int, k: int) -> int:
+ """Return chunk size for phase 1. Must be >= k and <= MAX_CHUNK, power of 2."""
+ chunk = min(MAX_CHUNK, triton.next_power_of_2(k))
+ # Ensure chunk >= k
+ if chunk < k:
+ chunk = triton.next_power_of_2(k)
+ return min(chunk, MAX_CHUNK)
+
+
@triton.jit
def _chunk_select_kernel(
x_ptr,
chunk_val_ptr,
chunk_idx_ptr,
- N: tl.constexpr,
- K: tl.constexpr,
+ N,
+ K,
stride_x_row,
CHUNK_SIZE: tl.constexpr,
):
@@ -30,24 +43,24 @@
chunk_start = chunk_idx * CHUNK_SIZE
x_row = x_ptr + row_idx * stride_x_row
- out_row = chunk_idx * K
- out_base = (row_idx * tl.cdiv(N, CHUNK_SIZE) + chunk_idx) * K
+ # Number of chunks per row (computed at compile time since CHUNK_SIZE is constexpr)
+ num_chunks = tl.cdiv(N.to(tl.int32), CHUNK_SIZE.to(tl.int32))
+ out_base = (row_idx * num_chunks + chunk_idx) * K
offs = chunk_start + tl.arange(0, CHUNK_SIZE)
mask = offs < N
- vals = tl.load(x_row + offs, mask=mask, other=-float('inf'))
+ vals = tl.load(x_row + offs, mask=mask, other=-1e30)
- for ki in tl.range(K):
+ for ki in range(K):
m = tl.max(vals)
- # Argmax: find the first position with value == m
- m_idx = tl.where(vals == m, offs.to(tl.int32), 0x7FFFFFFF).to(tl.int32)
+ m_idx = tl.where(vals == m, offs.to(tl.int32), 0x7FFFFFFF)
arg = tl.min(m_idx)
tl.store(chunk_val_ptr + out_base + ki, m)
tl.store(chunk_idx_ptr + out_base + ki, arg.to(tl.int64))
- vals = tl.where(vals == m, -float('inf'), vals)
+ vals = tl.where(offs == arg, -1e30, vals)
@triton.jit
@@ -56,41 +69,31 @@
chunk_idx_ptr,
val_out_ptr,
idx_out_ptr,
- num_chunks: tl.constexpr,
- K_CHUNK: tl.constexpr,
- K_OUT: tl.constexpr,
+ num_chunks,
+ K_CHUNK,
+ K_OUT,
stride_val_row,
stride_idx_row,
+ TOTAL_CANDIDATES: tl.constexpr,
):
- """Phase 2: Each program handles one row. Merge num_chunks * K_CHUNK into top-K_OUT.
+ """Phase 2: Each program handles one row. Merge num_chunks*K_CHUNK into top-K_OUT.
grid = (batch,)
- Total candidates = num_chunks * K_CHUNK.
- Extract top-K_OUT via repeated max.
+ Total candidates = num_chunks * K_CHUNK, loaded in blocks.
"""
row_idx = tl.program_id(0)
row_base = row_idx * num_chunks * K_CHUNK
- n_candidates = num_chunks * K_CHUNK
- # Load all candidates
- offs = tl.arange(0, tl.max_constexpr([n_candidates]))
- if offs.max() < n_candidates:
- pass
-
- # For variable size, load in chunks
- # Since we know num_chunks and K_CHUNK at compile time,
- # total = num_chunks * K_CHUNK
-
- TOTAL = num_chunks * K_CHUNK
- load_offs = tl.arange(0, TOTAL)
+ # Load all candidates at once using constexpr TOTAL_CANDIDATES
+ load_offs = tl.arange(0, TOTAL_CANDIDATES)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)
cand_idxs = tl.load(chunk_idx_ptr + row_base + load_offs)
- for ki in tl.range(K_OUT):
+ for ki in range(K_OUT):
m = tl.max(cand_vals)
- m_idx = tl.where(cand_vals == m, load_offs.to(tl.int32), 0x7FFFFFFF).to(tl.int32)
+ m_idx = tl.where(cand_vals == m, load_offs.to(tl.int32), 0x7FFFFFFF)
arg = tl.min(m_idx)
orig_idx = tl.load(chunk_idx_ptr + row_base + arg)
@@ -98,7 +101,7 @@
tl.store(val_out_ptr + row_idx * stride_val_row + ki, m)
tl.store(idx_out_ptr + row_idx * stride_idx_row + ki, orig_idx)
- cand_vals = tl.where(load_offs == arg, -float('inf'), cand_vals)
+ cand_vals = tl.where(load_offs == arg, -1e30, cand_vals)
class Model(nn.Module):
@@ -113,22 +116,21 @@
device = x.device
# Phase 1: chunk-level selection
- chunk_size = 1024 # power of 2, must be >= k
+ chunk_size = _get_chunk_size(n, k)
num_chunks = triton.cdiv(n, chunk_size)
- # If chunk_size < k, increase chunk_size
- if chunk_size < k:
- chunk_size = triton.next_power_of_2(k)
- num_chunks = triton.cdiv(n, chunk_size)
+ # Phase 2 candidate count (must be constexpr for merge kernel)
+ # We pad total_candidates to next power of 2 for tl.arange
+ total_candidates = num_chunks * k
+ total_candidates_padded = triton.next_power_of_2(total_candidates)
- # Temp storage for chunk-level top-k
- chunk_vals = torch.empty(batch, num_chunks, k, dtype=torch.float32, device=device)
- chunk_idxs = torch.empty(batch, num_chunks, k, dtype=torch.int64, device=device)
+ # Temp storage for chunk-level top-k (padded)
+ chunk_vals = torch.full((batch, total_candidates_padded,), -1e30, dtype=torch.float32, device=device)
+ chunk_idxs = torch.full((batch, total_candidates_padded,), -1, dtype=torch.int64, device=device)
grid1 = (batch, num_chunks)
-
_chunk_select_kernel[grid1](
- x, chunk_vals.view(-1), chunk_idxs.view(-1),
+ x, chunk_vals, chunk_idxs,
N=n, K=k,
stride_x_row=x.stride(0),
CHUNK_SIZE=chunk_size,
@@ -140,11 +142,12 @@
grid2 = (batch,)
_merge_kernel[grid2](
- chunk_vals.view(-1), chunk_idxs.view(-1),
+ chunk_vals, chunk_idxs,
val_out, idx_out,
num_chunks=num_chunks, K_CHUNK=k, K_OUT=k,
stride_val_row=val_out.stride(0),
stride_idx_row=idx_out.stride(0),
+
... (107 more chars){
"content": "<5034 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<628 chars \u2014 see below>",
"description": "Quick correctness test v2",
"timeout": 120000
}timeout 120 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Quick test with small shape
batch, n, k = 4, 256, 8
model = solution.Model(batch, n, k).to(device).eval()
x = torch.randn(batch, n, device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'Output shapes: vals={vals.shape}, idxs={idxs.shape}')
print(f'vals[0]: {vals[0].cpu()}')
# Check against torch.topk
ref_vals, ref_idxs = torch.topk(x, k, dim=-1, largest=True, sorted=True)
print(f'ref[0]: {ref_vals[0].cpu()}')
print(f'max_abs_diff: {(vals - ref_vals).abs().max().item():.2e}')
" 2>&1 | head -80
Traceback (most recent call last):
File "<string>", line 12, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 132, in forward
_chunk_select_kernel[grid1](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/c
... (1066 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -7,33 +7,20 @@
import triton.language as tl
-# Maximum chunk size for phase 1 (must fit in register file per thread)
-MAX_CHUNK = 1024
-
-
-def _get_chunk_size(n: int, k: int) -> int:
- """Return chunk size for phase 1. Must be >= k and <= MAX_CHUNK, power of 2."""
- chunk = min(MAX_CHUNK, triton.next_power_of_2(k))
- # Ensure chunk >= k
- if chunk < k:
- chunk = triton.next_power_of_2(k)
- return min(chunk, MAX_CHUNK)
-
-
@triton.jit
def _chunk_select_kernel(
x_ptr,
chunk_val_ptr,
chunk_idx_ptr,
N,
- K,
stride_x_row,
+ num_chunks,
+ K: tl.constexpr,
CHUNK_SIZE: tl.constexpr,
):
"""Phase 1: Each program handles one chunk of one row.
- For a chunk of CHUNK_SIZE elements, find the top-K and store.
- Uses repeated max-extraction: K iterations, each finds the current max.
+ For a chunk of CHUNK_SIZE elements, find the top-K via repeated max.
grid = (batch, num_chunks)
"""
@@ -42,9 +29,6 @@
chunk_start = chunk_idx * CHUNK_SIZE
x_row = x_ptr + row_idx * stride_x_row
-
- # Number of chunks per row (computed at compile time since CHUNK_SIZE is constexpr)
- num_chunks = tl.cdiv(N.to(tl.int32), CHUNK_SIZE.to(tl.int32))
out_base = (row_idx * num_chunks + chunk_idx) * K
offs = chunk_start + tl.arange(0, CHUNK_SIZE)
@@ -70,30 +54,26 @@
val_out_ptr,
idx_out_ptr,
num_chunks,
- K_CHUNK,
- K_OUT,
stride_val_row,
stride_idx_row,
+ K_OUT: tl.constexpr,
TOTAL_CANDIDATES: tl.constexpr,
):
- """Phase 2: Each program handles one row. Merge num_chunks*K_CHUNK into top-K_OUT.
+ """Phase 2: Each program merges num_chunks*K candidates into top-K_OUT.
grid = (batch,)
- Total candidates = num_chunks * K_CHUNK, loaded in blocks.
"""
row_idx = tl.program_id(0)
+ k_chunk = TOTAL_CANDIDATES // num_chunks
- row_base = row_idx * num_chunks * K_CHUNK
-
- # Load all candidates at once using constexpr TOTAL_CANDIDATES
+ row_base = row_idx * num_chunks * k_chunk
load_offs = tl.arange(0, TOTAL_CANDIDATES)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)
- cand_idxs = tl.load(chunk_idx_ptr + row_base + load_offs)
for ki in range(K_OUT):
m = tl.max(cand_vals)
- m_idx = tl.where(cand_vals == m, load_offs.to(tl.int32), 0x7FFFFFFF)
+ m_idx = tl.where(cand_vals == m, load_offs, 0x7FFFFFFF).to(tl.int32)
arg = tl.min(m_idx)
orig_idx = tl.load(chunk_idx_ptr + row_base + arg)
@@ -114,39 +94,42 @@
batch, n = x.shape[0], x.shape[1]
k = self.k
device = x.device
+ dtype = x.dtype
- # Phase 1: chunk-level selection
- chunk_size = _get_chunk_size(n, k)
- num_chunks = triton.cdiv(n, chunk_size)
+ # Phase 1: chunk-level top-k selection
+ CHUNK_SIZE = 1024
+ num_chunks = triton.cdiv(n, CHUNK_SIZE)
- # Phase 2 candidate count (must be constexpr for merge kernel)
- # We pad total_candidates to next power of 2 for tl.arange
+ # Pad total candidates to power of 2 for merge kernel
total_candidates = num_chunks * k
total_candidates_padded = triton.next_power_of_2(total_candidates)
- # Temp storage for chunk-level top-k (padded)
- chunk_vals = torch.full((batch, total_candidates_padded,), -1e30, dtype=torch.float32, device=device)
- chunk_idxs = torch.full((batch, total_candidates_padded,), -1, dtype=torch.int64, device=device)
+ chunk_vals = torch.full(
+ (batch, total_candidates_padded,), -1e30, dtype=dtype, device=device
+ )
+ chunk_idxs = torch.full(
+ (batch, total_candidates_padded,), -1, dtype=torch.int64, device=device
+ )
grid1 = (batch, num_chunks)
_chunk_select_kernel[grid1](
x, chunk_vals, chunk_idxs,
- N=n, K=k,
- stride_x_row=x.stride(0),
- CHUNK_SIZE=chunk_size,
+ N=n, stride_x_row=x.stride(0),
+ num_chunks=num_chunks,
+ K=k, CHUNK_SIZE=CHUNK_SIZE,
)
- # Phase 2: merge all chunk results
- val_out = torch.empty(batch, k, dtype=torch.float32, device=device)
+ # Phase 2: merge
+ val_out = torch.empty(batch, k, dtype=dtype, device=device)
idx_out = torch.empty(batch, k, dtype=torch.int64, device=device)
grid2 = (batch,)
_merge_kernel[grid2](
- chunk_vals, chunk_idxs,
- val_out, idx_out,
- num_chunks=num_chunks, K_CHUNK=k, K_OUT=k,
+ chunk_vals, chunk_idxs, val_out, idx_out,
+ num_chunks=num_chunks,
stride_val_row=val_out.stride(0),
stride_idx_row=idx_out.stride(0),
+ K_OUT=k,
TOTAL_CANDIDATES=total_candidates_padded,
)
{
"content": "<4187 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<1361 chars \u2014 see below>",
"description": "Test all shapes",
"timeout": 120000
}timeout 120 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Quick test with small shape
batch, n, k = 4, 256, 8
model = solution.Model(batch, n, k).to(device).eval()
x = torch.randn(batch, n, device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'Output shapes: vals={vals.shape}, idxs={idxs.shape}')
print(f'vals[0]: {vals[0].cpu()}')
ref_vals, ref_idxs = torch.topk(x, k, dim=-1, largest=True, sorted=True)
print(f'ref[0]: {ref_vals[0].cpu()}')
print(f'max_abs_diff: {(vals - ref_vals).abs().max().item():.2e}')
# Test all shapes from shapes.py
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if diff < 1e-4 else \"FAIL\"} ')
" 2>&1 | tail -20
Output shapes: vals=torch.Size([4, 8]), idxs=torch.Size([4, 8])
vals[0]: tensor([3.2682, 2.5568, 2.5005, 2.4025, 2.1860, 2.1808, 2.0991, 2.0974])
ref[0]: tensor([3.2682, 2.5568, 2.5005, 2.4025, 2.1860, 2.1808, 2.0991, 2.0974])
max_abs_diff: 0.00e+00
batch= 1 n=131072 k=64 max_diff=0.00e+00 OK
batch= 64 n= 8192 k= 8 max_diff=0.00e+00 OK
batch= 32 n= 16384 k=32 max_diff=0.00e+00 OK
batch= 16 n= 12000 k=16 max_diff=1.00e+30 FAIL
batch=128 n= 4096 k= 1 max_diff=0.00e+00 OK
{
"oldString": "<572 chars \u2014 see below>",
"newString": "<690 chars \u2014 see below>"
}@triton.jit
def _merge_kernel(
chunk_val_ptr,
chunk_idx_ptr,
val_out_ptr,
idx_out_ptr,
num_chunks,
stride_val_row,
stride_idx_row,
K_OUT: tl.constexpr,
TOTAL_CANDIDATES: tl.constexpr,
):
"""Phase 2: Each program merges num_chunks*K candidates into top-K_OUT.
grid = (batch,)
"""
row_idx = tl.program_id(0)
k_chunk = TOTAL_CANDIDATES // num_chunks
row_base = row_idx * num_chunks * k_chunk
load_offs = tl.arange(0, TOTAL_CANDIDATES)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)@triton.jit
def _merge_kernel(
chunk_val_ptr,
chunk_idx_ptr,
val_out_ptr,
idx_out_ptr,
num_chunks,
stride_val_row,
stride_idx_row,
K_OUT: tl.constexpr,
K_CHUNK: tl.constexpr,
LOAD_SIZE: tl.constexpr,
):
"""Phase 2: Each program merges num_chunks*K_CHUNK candidates into top-K_OUT.
grid = (batch,)
LOAD_SIZE must be >= num_chunks * K_CHUNK and a power of 2.
"""
row_idx = tl.program_id(0)
row_base = row_idx * num_chunks * K_CHUNK
load_offs = tl.arange(0, LOAD_SIZE)
load_mask = load_offs < (num_chunks * K_CHUNK)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs, mask=load_mask, other=-1e30)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -57,19 +57,21 @@
stride_val_row,
stride_idx_row,
K_OUT: tl.constexpr,
- TOTAL_CANDIDATES: tl.constexpr,
+ K_CHUNK: tl.constexpr,
+ LOAD_SIZE: tl.constexpr,
):
- """Phase 2: Each program merges num_chunks*K candidates into top-K_OUT.
+ """Phase 2: Each program merges num_chunks*K_CHUNK candidates into top-K_OUT.
grid = (batch,)
+ LOAD_SIZE must be >= num_chunks * K_CHUNK and a power of 2.
"""
row_idx = tl.program_id(0)
- k_chunk = TOTAL_CANDIDATES // num_chunks
- row_base = row_idx * num_chunks * k_chunk
- load_offs = tl.arange(0, TOTAL_CANDIDATES)
+ row_base = row_idx * num_chunks * K_CHUNK
+ load_offs = tl.arange(0, LOAD_SIZE)
+ load_mask = load_offs < (num_chunks * K_CHUNK)
- cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)
+ cand_vals = tl.load(chunk_val_ptr + row_base + load_offs, mask=load_mask, other=-1e30)
for ki in range(K_OUT):
m = tl.max(cand_vals)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<572 chars \u2014 see below>",
"newString": "<690 chars \u2014 see below>"
}@triton.jit
def _merge_kernel(
chunk_val_ptr,
chunk_idx_ptr,
val_out_ptr,
idx_out_ptr,
num_chunks,
stride_val_row,
stride_idx_row,
K_OUT: tl.constexpr,
TOTAL_CANDIDATES: tl.constexpr,
):
"""Phase 2: Each program merges num_chunks*K candidates into top-K_OUT.
grid = (batch,)
"""
row_idx = tl.program_id(0)
k_chunk = TOTAL_CANDIDATES // num_chunks
row_base = row_idx * num_chunks * k_chunk
load_offs = tl.arange(0, TOTAL_CANDIDATES)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs)@triton.jit
def _merge_kernel(
chunk_val_ptr,
chunk_idx_ptr,
val_out_ptr,
idx_out_ptr,
num_chunks,
stride_val_row,
stride_idx_row,
K_OUT: tl.constexpr,
K_CHUNK: tl.constexpr,
LOAD_SIZE: tl.constexpr,
):
"""Phase 2: Each program merges num_chunks*K_CHUNK candidates into top-K_OUT.
grid = (batch,)
LOAD_SIZE must be >= num_chunks * K_CHUNK and a power of 2.
"""
row_idx = tl.program_id(0)
row_base = row_idx * num_chunks * K_CHUNK
load_offs = tl.arange(0, LOAD_SIZE)
load_mask = load_offs < (num_chunks * K_CHUNK)
cand_vals = tl.load(chunk_val_ptr + row_base + load_offs, mask=load_mask, other=-1e30)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -84,6 +84,7 @@
tl.store(idx_out_ptr + row_idx * stride_idx_row + ki, orig_idx)
cand_vals = tl.where(load_offs == arg, -1e30, cand_vals)
+ cand_vals = tl.where(load_mask, cand_vals, -1e30)
class Model(nn.Module):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " cand_vals = tl.where(load_offs == arg, -1e30, cand_vals)",
"newString": "<122 chars \u2014 see below>"
} cand_vals = tl.where(load_offs == arg, -1e30, cand_vals)
cand_vals = tl.where(load_mask, cand_vals, -1e30)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -133,7 +133,8 @@
stride_val_row=val_out.stride(0),
stride_idx_row=idx_out.stride(0),
K_OUT=k,
- TOTAL_CANDIDATES=total_candidates_padded,
+ K_CHUNK=k,
+ LOAD_SIZE=total_candidates_padded,
)
return val_out, idx_out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<336 chars \u2014 see below>",
"oldString": "<320 chars \u2014 see below>"
} grid2 = (batch,)
_merge_kernel[grid2](
chunk_vals, chunk_idxs, val_out, idx_out,
num_chunks=num_chunks,
stride_val_row=val_out.stride(0),
stride_idx_row=idx_out.stride(0),
K_OUT=k,
K_CHUNK=k,
LOAD_SIZE=total_candidates_padded,
) grid2 = (batch,)
_merge_kernel[grid2](
chunk_vals, chunk_idxs, val_out, idx_out,
num_chunks=num_chunks,
stride_val_row=val_out.stride(0),
stride_idx_row=idx_out.stride(0),
K_OUT=k,
TOTAL_CANDIDATES=total_candidates_padded,
)Edit applied successfully.{
"command": "<904 chars \u2014 see below>",
"description": "Test all shapes after fix",
"timeout": 120000
}timeout 120 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 max_diff=0.00e+00 OK
batch= 64 n= 8192 k= 8 max_diff=0.00e+00 OK
batch= 32 n= 16384 k=32 max_diff=0.00e+00 OK
batch= 16 n= 12000 k=16 max_diff=0.00e+00 OK
batch=128 n= 4096 k= 1 max_diff=0.00e+00 OK
All passed: True
{
"todos": [
{
"content": "First Triton-based top-k kernel implementation",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run benchmark.py to get baseline performance",
"status": "pending",
"priority": "high"
},
{
"content": "Profile with ncu and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Iterate until performance target met",
"status": "p
... (45 more chars)content: 'First Triton-based top-k kernel implementation'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py to get baseline performance'
status: 'pending'
priority: 'high'
content: 'Profile with ncu and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Iterate until performance target met'
status: 'pending'
priority: 'high'{
"command": "timeout 120 python3 check.py 2>&1",
"description": "Run official correctness check",
"timeout": 120000
}PASS
{
"command": "timeout 300 python3 benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=0.011 gbps=11.049 ms=0.048
shape=0 variant=compiled tflops=0.007 gbps=7.132 ms=0.074
shape=0 variant=sota tflops=0.011 gbps=11.090 ms=0.047
shape=0 variant=solution tflops=0.006 gbps=5.866 ms=0.090
shape=0 solution_peak_fraction=0.0033
shape=1 variant=eager tflops=0.043 gbps=43.342 ms=0.049
shape=1 variant=compiled tflops=0.026 gbps=25.882 ms=0.081
shape=1 variant=sota tflops=0.044 gbps=43.980 ms=0.048
shape=1 variant=solution tflops=0.075 gbps=74.776 ms=0.028
shape=1 solution_peak_fraction=0.0415
shape=2 variant=eager tflops=0.045 gbps=44.966 ms=0.047
shape=2 variant=compiled tflops=0.028 gbps=28.420 ms=0.074
shape=2 variant=sota tflops=0.045 gbps=45.431 ms=0.046
shape=2 variant=solution tflops=0.058 gbps=58.414 ms=0.036
shape=2 solution_peak_fraction=0.0325
shape=3 variant=eager tflops=0.018 gbps=17.889 ms=0.043
shape=3 variant=compiled tflops=0.010 gbps=10.518 ms=0.073
shape=3 variant=sota tflops=0.018 gbps=18.268 ms=0.042
shape=3 variant=solution tflops=0.026 gbps=25.757 ms=0.030
shape=3 solution_peak_fraction=0.0143
shape=4 variant=eager tflops=0.098 gbps=98.253 ms=0.021
shape=4 variant=compiled tflops=0.036 gbps=35.936 ms=0.058
shape=4 variant=sota tflops=0.103 gbps=103.038 ms=0.020
shape=4 variant=solution tflops=0.074 gbps=74.023 ms=0.028
shape=4 solution_peak_fraction=0.0411
peak_fraction: 0.0192
RESULT: LOW
{
"todos": [
{
"content": "First Triton-based top-k kernel implementation",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py to get baseline performance",
"status": "completed",
"priority": "high"
},
{
"content": "Profile with ncu and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Implement CUDA C++ kernel with shared-memory merge tree",
... (190 more chars)content: 'First Triton-based top-k kernel implementation'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to get baseline performance'
status: 'completed'
priority: 'high'
content: 'Profile with ncu and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Implement CUDA C++ kernel with shared-memory merge tree'
status: 'in_progress'
priority: 'high'
content: 'Iterate until performance target met'
status: 'pending'
priority: 'high'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,143 +1,299 @@
-"""Top-k kernel via Triton — two-phase chunk select + merge."""
+"""Top-k kernel via raw CUDA C++ — shared-memory merge tree.
+
+Design:
+ 1. One block per row (batch dimension).
+ 2. Each thread scans its strided portion, maintaining a top-K buffer in
+ shared memory (insertion sort, rejecting most elements once the K-th
+ threshold stabilises).
+ 3. Block-wide merge tree: log2(BLOCK_SIZE) rounds of pairwise merging two
+ sorted K-element buffers into one, keeping the top K.
+ 4. Thread 0 writes the final K values + indices to global memory.
+
+Register / shared-memory pressure:
+ K=64 → 128 threads, 64 KiB per shared buffer (2 buffers = 128 KiB)
+ K=32 → 256 threads, 32 KiB per buffer
+ K=16 → 256 threads, 16 KiB per buffer
+ K= 8 → 256 threads, 8 KiB per buffer
+ K= 1 → 512 threads, 2 KiB per buffer (special-cased to plain max)
+"""
from __future__ import annotations
+
+import os
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def _chunk_select_kernel(
- x_ptr,
- chunk_val_ptr,
- chunk_idx_ptr,
- N,
- stride_x_row,
- num_chunks,
- K: tl.constexpr,
- CHUNK_SIZE: tl.constexpr,
-):
- """Phase 1: Each program handles one chunk of one row.
-
- For a chunk of CHUNK_SIZE elements, find the top-K via repeated max.
-
- grid = (batch, num_chunks)
- """
- row_idx = tl.program_id(0)
- chunk_idx = tl.program_id(1)
-
- chunk_start = chunk_idx * CHUNK_SIZE
- x_row = x_ptr + row_idx * stride_x_row
- out_base = (row_idx * num_chunks + chunk_idx) * K
-
- offs = chunk_start + tl.arange(0, CHUNK_SIZE)
- mask = offs < N
-
- vals = tl.load(x_row + offs, mask=mask, other=-1e30)
-
- for ki in range(K):
- m = tl.max(vals)
- m_idx = tl.where(vals == m, offs.to(tl.int32), 0x7FFFFFFF)
- arg = tl.min(m_idx)
-
- tl.store(chunk_val_ptr + out_base + ki, m)
- tl.store(chunk_idx_ptr + out_base + ki, arg.to(tl.int64))
-
- vals = tl.where(offs == arg, -1e30, vals)
-
-
-@triton.jit
-def _merge_kernel(
- chunk_val_ptr,
- chunk_idx_ptr,
- val_out_ptr,
- idx_out_ptr,
- num_chunks,
- stride_val_row,
- stride_idx_row,
- K_OUT: tl.constexpr,
- K_CHUNK: tl.constexpr,
- LOAD_SIZE: tl.constexpr,
-):
- """Phase 2: Each program merges num_chunks*K_CHUNK candidates into top-K_OUT.
-
- grid = (batch,)
- LOAD_SIZE must be >= num_chunks * K_CHUNK and a power of 2.
- """
- row_idx = tl.program_id(0)
-
- row_base = row_idx * num_chunks * K_CHUNK
- load_offs = tl.arange(0, LOAD_SIZE)
- load_mask = load_offs < (num_chunks * K_CHUNK)
-
- cand_vals = tl.load(chunk_val_ptr + row_base + load_offs, mask=load_mask, other=-1e30)
-
- for ki in range(K_OUT):
- m = tl.max(cand_vals)
- m_idx = tl.where(cand_vals == m, load_offs, 0x7FFFFFFF).to(tl.int32)
- arg = tl.min(m_idx)
-
- orig_idx = tl.load(chunk_idx_ptr + row_base + arg)
-
- tl.store(val_out_ptr + row_idx * stride_val_row + ki, m)
- tl.store(idx_out_ptr + row_idx * stride_idx_row + ki, orig_idx)
-
- cand_vals = tl.where(load_offs == arg, -1e30, cand_vals)
- cand_vals = tl.where(load_mask, cand_vals, -1e30)
-
-
-class Model(nn.Module):
+from torch.utils.cpp_extension import load_inline
+
+_CUDA_SOURCE = r"""
+#include <cuda_fp16.h>
+#include <algorithm>
+#include <cmath>
+
+static constexpr float NEG_INF = -1e30f;
+
+// --------------------------------------------------------------------------
+// Generic kernel: K >= 2
+// --------------------------------------------------------------------------
+template <int K>
+struct BlockSizeTraits;
+
+template <> struct BlockSizeTraits<64> { static constexpr int BLOCK = 128; };
+template <> struct BlockSizeTraits<32> { static constexpr int BLOCK = 256; };
+template <> struct BlockSizeTraits<16> { static constexpr int BLOCK = 256; };
+template <> struct BlockSizeTraits<8> { static constexpr int BLOCK = 256; };
+
+// Bank-conflict padding: shared arrays are laid out as [BLOCK][K+PAD].
+// For K a multiple of 32, padding=1 ensures consecutive threads in a warp
+// hit consecutive banks.
+template <int K>
+static constexpr int SPAD = (K % 32 == 0) ? 1 : 0;
+
+template <int K, int BLOCK, int PAD>
+__global__ void topk_merge_tree_kernel(
+ const float* __restrict__ x,
+ float* __restrict__ values,
+ int64_t* __restrict__ indices,
+ int64_t row_stride,
+ int n
+)
+{
+ extern __shared__ char s_buf[];
+ float* s_val = reinterpret_cast<float*>(s_buf);
+ // Indices stored as int32 in shared mem; cast to int64 on output.
+ int* s_idx = reinterpret_cast<int*>(s_buf + sizeof(float) * BLOCK * (K + PAD));
+
+ int tid = threadIdx.x;
+ int bid = blockIdx.x;
+
+ // -- Initialise local buffer to -inf / -1 ------------------------------------
+ int base = tid * (K + PAD);
+ #pragma unroll
+ for (int i = 0; i < K; ++i) {
+ s_val[base + i] = NEG_INF;
+ s_idx[base + i] = -1;
+ }
+ __syncthreads();
+
+ // -- Phase 1: strided scan with insertion-sort buffer --------------------------
+ const float* row = x + bid * row_stride;
+
+ for (int i = tid; i < n; i += BLOCK) {
+ float v = row[i];
+ float kth = s_val[base + K - 1];
+ if (v > kth) {
+ // Find insertion point by scanning from back.
+ // Unroll hint helps for moderate K.
+ int pos = K - 1;
+ while (pos > 0 && s_val[base + pos - 1] < v) {
+ s_val[base + pos] = s_val[base + pos - 1];
+ s_idx[base + pos] = s_idx[base + pos - 1];
+ --pos;
+ }
+ s_val[base
... (8308 more chars){
"content": "<9911 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<947 chars \u2014 see below>",
"description": "Test CUDA kernel on all shapes",
"timeout": 180000
}timeout 180 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
model._dummy = model._dummy.to(device)
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 291, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 260, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (44347 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -5,20 +5,11 @@
2. Each thread scans its strided portion, maintaining a top-K buffer in
shared memory (insertion sort, rejecting most elements once the K-th
threshold stabilises).
- 3. Block-wide merge tree: log2(BLOCK_SIZE) rounds of pairwise merging two
+ 3. Block-wide merge tree: log2(BLOCK) rounds of pairwise merging two
sorted K-element buffers into one, keeping the top K.
4. Thread 0 writes the final K values + indices to global memory.
-
-Register / shared-memory pressure:
- K=64 → 128 threads, 64 KiB per shared buffer (2 buffers = 128 KiB)
- K=32 → 256 threads, 32 KiB per buffer
- K=16 → 256 threads, 16 KiB per buffer
- K= 8 → 256 threads, 8 KiB per buffer
- K= 1 → 512 threads, 2 KiB per buffer (special-cased to plain max)
"""
from __future__ import annotations
-
-import os
import torch
import torch.nn as nn
@@ -26,28 +17,17 @@
_CUDA_SOURCE = r"""
#include <cuda_fp16.h>
-#include <algorithm>
-#include <cmath>
+#include <cstdint>
static constexpr float NEG_INF = -1e30f;
+// Bank-conflict padding for shared arrays.
+// For K = 64 = 2*32, pad=1 avoids bank conflicts; otherwise pad=0.
+template <int K> struct spad_v { static constexpr int value = (K % 32 == 0) ? 1 : 0; };
+
// --------------------------------------------------------------------------
// Generic kernel: K >= 2
// --------------------------------------------------------------------------
-template <int K>
-struct BlockSizeTraits;
-
-template <> struct BlockSizeTraits<64> { static constexpr int BLOCK = 128; };
-template <> struct BlockSizeTraits<32> { static constexpr int BLOCK = 256; };
-template <> struct BlockSizeTraits<16> { static constexpr int BLOCK = 256; };
-template <> struct BlockSizeTraits<8> { static constexpr int BLOCK = 256; };
-
-// Bank-conflict padding: shared arrays are laid out as [BLOCK][K+PAD].
-// For K a multiple of 32, padding=1 ensures consecutive threads in a warp
-// hit consecutive banks.
-template <int K>
-static constexpr int SPAD = (K % 32 == 0) ? 1 : 0;
-
template <int K, int BLOCK, int PAD>
__global__ void topk_merge_tree_kernel(
const float* __restrict__ x,
@@ -59,13 +39,13 @@
{
extern __shared__ char s_buf[];
float* s_val = reinterpret_cast<float*>(s_buf);
- // Indices stored as int32 in shared mem; cast to int64 on output.
- int* s_idx = reinterpret_cast<int*>(s_buf + sizeof(float) * BLOCK * (K + PAD));
+ int* s_idx = reinterpret_cast<int*>(s_buf
+ + sizeof(float) * BLOCK * (K + PAD));
int tid = threadIdx.x;
int bid = blockIdx.x;
- // -- Initialise local buffer to -inf / -1 ------------------------------------
+ // ---- Initialise local buffer to -inf / -1 --------------------------------
int base = tid * (K + PAD);
#pragma unroll
for (int i = 0; i < K; ++i) {
@@ -74,15 +54,13 @@
}
__syncthreads();
- // -- Phase 1: strided scan with insertion-sort buffer --------------------------
+ // ---- Phase 1: strided scan with insertion-sort buffer --------------------
const float* row = x + bid * row_stride;
for (int i = tid; i < n; i += BLOCK) {
float v = row[i];
float kth = s_val[base + K - 1];
if (v > kth) {
- // Find insertion point by scanning from back.
- // Unroll hint helps for moderate K.
int pos = K - 1;
while (pos > 0 && s_val[base + pos - 1] < v) {
s_val[base + pos] = s_val[base + pos - 1];
@@ -95,18 +73,18 @@
}
__syncthreads();
- // -- Phase 2: merge-tree reduction --------------------------------------------
+ // ---- Phase 2: merge-tree reduction ---------------------------------------
for (int stride = BLOCK / 2; stride > 0; stride /= 2) {
if (tid < stride) {
- int a_base = tid * (K + PAD);
- int b_base = (tid + stride) * (K + PAD);
-
- // Local merge into registers to avoid shared-mem races.
+ int a_base = tid * (K + PAD);
+ int b_base = (tid+stride) * (K + PAD);
+
+ // Local merge into registers
float mval[K];
int midx[K];
int ai = 0, bi = 0, oi = 0;
while (oi < K) {
- if (ai < K && (bi >= K || s_val[a_base + ai] >= s_val[b_base + bi])) {
+ if (ai < K && (bi >= K || s_val[a_base+ai] >= s_val[b_base+bi])) {
mval[oi] = s_val[a_base + ai];
midx[oi] = s_idx[a_base + ai];
++ai; ++oi;
@@ -125,10 +103,10 @@
__syncthreads();
}
- // -- Phase 3: thread 0 writes result ------------------------------------------
+ // ---- Phase 3: thread 0 writes result -------------------------------------
if (tid == 0) {
- float* out_val = values + bid * K;
- int64_t* out_idx = indices + bid * K;
+ float* out_val = values + (int64_t)bid * K;
+ int64_t* out_idx = indices + (int64_t)bid * K;
#pragma unroll
for (int i = 0; i < K; ++i) {
out_val[i] = s_val[i];
@@ -138,8 +116,9 @@
}
// --------------------------------------------------------------------------
-// K == 1 : special-case argmax kernel (no shared-sort buffer needed)
-// --------------------------------------------------------------------------
+// K == 1 : argmax kernel (no shared sort buffer)
+// --------------------------------------------------------------------------
+template <int BLOCK>
__global__ void topk_argmax_kernel(
const float* __restrict__ x,
float* __restrict__ values,
@@ -152,26 +131,24 @@
int bid = blockIdx.x;
const float* row = x + bid * row_stride;
- // Thread-local max
- float best = NEG_INF;
+ float best = NEG_INF;
int be
... (6724 more chars){
"content": "<9748 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<982 chars \u2014 see below>",
"description": "Compile and test CUDA kernel",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4 and vals.shape == (s['batch'], s['k']) and idxs.shape == (s['batch'], s['k'])
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 20, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 289, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 258, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (5636 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -257,7 +257,8 @@
def _compile_module():
return load_inline(
name='topk_cuda_ext',
- cpp_sources='',
+ cpp_sources='#include <torch/extension.h>\n'
+ 'std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);',
cuda_sources=_CUDA_SOURCE,
functions=['run_topk'],
extra_cuda_cflags=[
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<167 chars \u2014 see below>",
"newString": "<295 chars \u2014 see below>"
}def _compile_module():
return load_inline(
name='topk_cuda_ext',
cpp_sources='',
cuda_sources=_CUDA_SOURCE,
functions=['run_topk'],def _compile_module():
return load_inline(
name='topk_cuda_ext',
cpp_sources='#include <torch/extension.h>\n'
'std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);',
cuda_sources=_CUDA_SOURCE,
functions=['run_topk'],Edit applied successfully.{
"command": "<982 chars \u2014 see below>",
"description": "Compile and test CUDA kernel v2",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4 and vals.shape == (s['batch'], s['k']) and idxs.shape == (s['batch'], s['k'])
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 20, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 290, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 258, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (1681 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -248,9 +248,8 @@
return {values, indices};
}
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("run_topk", &run_topk, "Top-K selection on last dim");
-}
+std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);
+"""
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<78 chars \u2014 see below>",
"oldString": "<109 chars \u2014 see below>"
}std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);
"""PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run_topk", &run_topk, "Top-K selection on last dim");
}Edit applied successfully.{
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py312_cu130/topk_cuda_ext/",
"description": "Clear stale build cache"
}{
"command": "<904 chars \u2014 see below>",
"description": "Compile and test CUDA kernel v3",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "<string>", line 3, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 279
"""Top-k via CUDA shared-memory merge-tree kernel."""
^
SyntaxError: unterminated triple-quoted string literal (detected at line 308)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
1: """Top-k kernel via raw CUDA C++ — shared-memory merge tree.
2:
3: Design:
4: 1. One block per row (batch dimension).
5: 2. Each thread scans its strided portion, maintaining a top-K buffer in
6: shared memory (insertion sort, rejecting most elements once the K-th
7: threshold stabilises).
8: 3. Block-wide merge tree: log2(BLOCK) rounds of pairwise merging two
9: sorted K-element buffers into one, keeping the top K.
10: 4. Thread 0 writes the final K values + indices to global memory.
11: """
12: from __future__ import annotations
13:
14: import torch
15: import torch.nn as nn
16: from torch.utils.cpp_extension import load_inline
17:
18: _CUDA_SOURCE = r"""
19: #include <cuda_fp16.h>
20: #include <cstdint>
21:
22: static constexpr float NEG_INF = -1e30f;
23:
24: // Bank-conflict padding for shared arrays.
25: // For K = 64 = 2*32, pad=1 avoids bank conflicts; otherwise pad=0.
26: template <int K> struct spad_v { static constexpr int value = (K % 32 == 0) ? 1 : 0; };
27:
28: // --------------------------------------------------------------------------
29: // Generic kernel: K >= 2
30: // --------------------------------------------------------------------------
31: template <int K, int BLOCK, int PAD>
32: __global__ void topk_merge_tree_kernel(
33: const float* __restrict__ x,
34: float* __restrict__ values,
35: i
... (9936 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -247,9 +247,6 @@
return {values, indices};
}
-
-std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);
-"""
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<60 chars \u2014 see below>",
"oldString": "<140 chars \u2014 see below>"
} return {values, indices};
}
"""
def _compile_module(): return {values, indices};
}
std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);
"""
"""
def _compile_module():Edit applied successfully.{
"command": "<904 chars \u2014 see below>",
"description": "Compile and test CUDA kernel v4",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "<string>", line 21, in <module>
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"command": "<787 chars \u2014 see below>",
"description": "Debug launch error with blocking",
"timeout": 300000
}timeout 300 CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Test simplest shape first - k=1 argmax
s = {'batch': 128, 'n': 4096, 'k': 1}
torch.manual_seed(42)
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'k=1: vals shape = {vals.shape}, first val = {vals[0].item()}, ref_max = {x[0].max().item()}')
# Test k=8
s = {'batch': 4, 'n': 256, 'k': 8}
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'k=8: vals shape = {vals.shape}')
" 2>&1
timeout: failed to run command ‘CUDA_LAUNCH_BLOCKING=1’: No such file or directory
{
"command": "<514 chars \u2014 see below>",
"description": "Debug launch error with blocking v2",
"timeout": 300000
}timeout 300 env CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Test simplest shape first - k=1 argmax
s = {'batch': 128, 'n': 4096, 'k': 1}
torch.manual_seed(42)
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
with torch.no_grad():
vals, idxs = model(x)
print(f'k=1: vals shape = {vals.shape}, first val = {vals[0].item()}, ref_max = {x[0].max().item()}')
" 2>&1
k=1: vals shape = torch.Size([128, 1]), first val = 3.4430480003356934, ref_max = 3.4430480003356934
{
"command": "<621 chars \u2014 see below>",
"description": "Test k=8",
"timeout": 300000
}timeout 300 env CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Test k=8
s = {'batch': 4, 'n': 256, 'k': 8}
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
print(f'k=8: vals shape = {vals.shape}')
print(f'vals[0]: {vals[0]}')
ref_v, _ = torch.topk(x, 8, dim=-1, largest=True, sorted=True)
print(f'ref[0]: {ref_v[0]}')
print(f'diff: {(vals - ref_v).abs().max().item()}')
" 2>&1
k=8: vals shape = torch.Size([4, 8])
vals[0]: tensor([2.8544, 2.6124, 2.5454, 2.2313, 2.0809, 2.0615, 1.9923, 1.9598],
device='cuda:0')
ref[0]: tensor([2.8544, 2.6124, 2.5454, 2.2313, 2.0809, 2.0615, 1.9923, 1.9598],
device='cuda:0')
diff: 0.0
{
"command": "<776 chars \u2014 see below>",
"description": "Test k=16, 32, 64",
"timeout": 300000
}timeout 300 env CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
for s in [
{'batch': 4, 'n': 256, 'k': 16},
{'batch': 4, 'n': 256, 'k': 32},
{'batch': 4, 'n': 256, 'k': 64},
]:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
try:
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_v, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_v).abs().max().item()
print(f'k={s[\"k\"]:2d}: diff={diff:.2e} OK')
except Exception as e:
print(f'k={s[\"k\"]:2d}: {e}')
" 2>&1
k=16: diff=0.00e+00 OK
k=32: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
k=64: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"offset": 31,
"limit": 50
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
31: template <int K, int BLOCK, int PAD>
32: __global__ void topk_merge_tree_kernel(
33: const float* __restrict__ x,
34: float* __restrict__ values,
35: int64_t* __restrict__ indices,
36: int64_t row_stride,
37: int n
38: )
39: {
40: extern __shared__ char s_buf[];
41: float* s_val = reinterpret_cast<float*>(s_buf);
42: int* s_idx = reinterpret_cast<int*>(s_buf
43: + sizeof(float) * BLOCK * (K + PAD));
44:
45: int tid = threadIdx.x;
46: int bid = blockIdx.x;
47:
48: // ---- Initialise local buffer to -inf / -1 --------------------------------
49: int base = tid * (K + PAD);
50: #pragma unroll
51: for (int i = 0; i < K; ++i) {
52: s_val[base + i] = NEG_INF;
53: s_idx[base + i] = -1;
54: }
55: __syncthreads();
56:
57: // ---- Phase 1: strided scan with insertion-sort buffer --------------------
58: const float* row = x + bid * row_stride;
59:
60: for (int i = tid; i < n; i += BLOCK) {
61: float v = row[i];
62: float kth = s_val[base + K - 1];
63: if (v > kth) {
64: int pos = K - 1;
65: while (pos > 0 && s_val[base + pos - 1] < v) {
66: s_val[base + pos] = s_val[base + pos - 1];
67: s_idx[base + pos] = s_idx[base + pos - 1];
68: --pos;
69:
... (496 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -37,10 +37,9 @@
int n
)
{
- extern __shared__ char s_buf[];
- float* s_val = reinterpret_cast<float*>(s_buf);
- int* s_idx = reinterpret_cast<int*>(s_buf
- + sizeof(float) * BLOCK * (K + PAD));
+ static constexpr int STRIDE = K + PAD;
+ __shared__ float s_val[BLOCK * STRIDE];
+ __shared__ int s_idx[BLOCK * STRIDE];
int tid = threadIdx.x;
int bid = blockIdx.x;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<186 chars \u2014 see below>",
"oldString": "<245 chars \u2014 see below>"
}{
static constexpr int STRIDE = K + PAD;
__shared__ float s_val[BLOCK * STRIDE];
__shared__ int s_idx[BLOCK * STRIDE];
int tid = threadIdx.x;
int bid = blockIdx.x;{
extern __shared__ char s_buf[];
float* s_val = reinterpret_cast<float*>(s_buf);
int* s_idx = reinterpret_cast<int*>(s_buf
+ sizeof(float) * BLOCK * (K + PAD));
int tid = threadIdx.x;
int bid = blockIdx.x;Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 60,
"offset": 47
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
47: // ---- Initialise local buffer to -inf / -1 --------------------------------
48: int base = tid * (K + PAD);
49: #pragma unroll
50: for (int i = 0; i < K; ++i) {
51: s_val[base + i] = NEG_INF;
52: s_idx[base + i] = -1;
53: }
54: __syncthreads();
55:
56: // ---- Phase 1: strided scan with insertion-sort buffer --------------------
57: const float* row = x + bid * row_stride;
58:
59: for (int i = tid; i < n; i += BLOCK) {
60: float v = row[i];
61: float kth = s_val[base + K - 1];
62: if (v > kth) {
63: int pos = K - 1;
64: while (pos > 0 && s_val[base + pos - 1] < v) {
65: s_val[base + pos] = s_val[base + pos - 1];
66: s_idx[base + pos] = s_idx[base + pos - 1];
67: --pos;
68: }
69: s_val[base + pos] = v;
70: s_idx[base + pos] = i;
71: }
72: }
73: __syncthreads();
74:
75: // ---- Phase 2: merge-tree reduction ---------------------------------------
76: for (int stride = BLOCK / 2; stride > 0; stride /= 2) {
77: if (tid < stride) {
78: int a_base = tid * (K + PAD);
79: int b_base = (tid+stride) * (K + PAD);
80:
81: // Local merge into registers
82: float mval[K];
83: int midx[K
... (966 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -45,7 +45,7 @@
int bid = blockIdx.x;
// ---- Initialise local buffer to -inf / -1 --------------------------------
- int base = tid * (K + PAD);
+ int base = tid * STRIDE;
#pragma unroll
for (int i = 0; i < K; ++i) {
s_val[base + i] = NEG_INF;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<129 chars \u2014 see below>",
"oldString": "<132 chars \u2014 see below>"
} // ---- Initialise local buffer to -inf / -1 --------------------------------
int base = tid * STRIDE;
#pragma unroll // ---- Initialise local buffer to -inf / -1 --------------------------------
int base = tid * (K + PAD);
#pragma unrollEdit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -75,8 +75,8 @@
// ---- Phase 2: merge-tree reduction ---------------------------------------
for (int stride = BLOCK / 2; stride > 0; stride /= 2) {
if (tid < stride) {
- int a_base = tid * (K + PAD);
- int b_base = (tid+stride) * (K + PAD);
+ int a_base = tid * STRIDE;
+ int b_base = (tid+stride) * STRIDE;
// Local merge into registers
float mval[K];
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<93 chars \u2014 see below>",
"oldString": "<99 chars \u2014 see below>"
} int a_base = tid * STRIDE;
int b_base = (tid+stride) * STRIDE; int a_base = tid * (K + PAD);
int b_base = (tid+stride) * (K + PAD);Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 60,
"offset": 170
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
170: switch (k) {
171: case 1: {
172: topk_argmax_kernel<512><<<batch, 512>>>(x, values, indices, n, n);
173: break;
174: }
175: case 8: {
176: constexpr int KK=8, BL=256, PD=spad_v<8>::value;
177: size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
178: topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
179: break;
180: }
181: case 16: {
182: constexpr int KK=16, BL=256, PD=spad_v<16>::value;
183: size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
184: topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
185: break;
186: }
187: case 32: {
188: constexpr int KK=32, BL=256, PD=spad_v<32>::value;
189: size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
190: topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
191: break;
192: }
193: case 64: {
194: constexpr int KK=64, BL=128, PD=spad_v<64>::value;
195: size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
196: topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
197: break;
198: }
199: default: {
200: // Fallback: pick smallest template K0 >= k
201: if (k <= 8) {
... (1524 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -174,46 +174,37 @@
}
case 8: {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 16: {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 32: {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 64: {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
default: {
- // Fallback: pick smallest template K0 >= k
if (k <= 8) {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else if (k <= 16) {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else if (k <= 32) {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
- size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
- topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
+ topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
}
break;
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<1552 chars \u2014 see below>",
"oldString": "<2252 chars \u2014 see below>"
} switch (k) {
case 1: {
topk_argmax_kernel<512><<<batch, 512>>>(x, values, indices, n, n);
break;
}
case 8: {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 16: {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 32: {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
case 64: {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
break;
}
default: {
if (k <= 8) {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else if (k <= 16) {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else if (k <= 32) {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
} else {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL>>>(x,values,indices,n,n);
}
break;
}
} switch (k) {
case 1: {
topk_argmax_kernel<512><<<batch, 512>>>(x, values, indices, n, n);
break;
}
case 8: {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
break;
}
case 16: {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
break;
}
case 32: {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
break;
}
case 64: {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
break;
}
default: {
// Fallback: pick smallest template K0 >= k
if (k <= 8) {
constexpr int KK=8, BL=256, PD=spad_v<8>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
} else if (k <= 16) {
constexpr int KK=16, BL=256, PD=spad_v<16>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
} else if (k <= 32) {
constexpr int KK=32, BL=256, PD=spad_v<32>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
} else {
constexpr int KK=64, BL=128, PD=spad_v<64>::value;
size_t smem = sizeof(float)*BL*(KK+PD) + sizeof(int)*BL*(KK+PD);
topk_merge_tree_kernel<KK,BL,PD><<<batch, BL, smem>>>(x,values,indices,n,n);
}
break;
}
}Edit applied successfully.{
"command": "<962 chars \u2014 see below>",
"description": "Compile and test with static shared memory",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 276, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 244, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (5397 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,13 +1,25 @@
-"""Top-k kernel via raw CUDA C++ — shared-memory merge tree.
+"""Top-k kernel via raw CUDA C++ — merge-tree with register scan.
+
+Key constraint: SM120 has a 48 KB per-block shared memory limit.
+For K=64, storing all per-thread buffers in shared memory would need
+256 * 65 * 8 = ~130 KB — impossible.
+
+Solution: each thread accumulates its top-K in **registers** during the scan.
+Only at the end do threads dump to shared memory for the merge tree.
+Shared memory for the merge then only needs to hold the *final* candidates:
+ K=64, BL=128 → 128*64*8 = 64 KB. Still too big for 48 KB.
+ K=64, BL=64 → 64*65*8 = 33 KB. Fits.
+ K=32, BL=128 → 128*33*8 = 33 KB. Fits.
+ K=16, BL=256 → 256*16*8 = 32 KB. Fits.
+ K= 8, BL=256 → 256* 8*8 = 16 KB. Fits.
+ K= 1, BL=512 → 512* 1*8 = 4 KB. Fits.
Design:
- 1. One block per row (batch dimension).
- 2. Each thread scans its strided portion, maintaining a top-K buffer in
- shared memory (insertion sort, rejecting most elements once the K-th
- threshold stabilises).
- 3. Block-wide merge tree: log2(BLOCK) rounds of pairwise merging two
- sorted K-element buffers into one, keeping the top K.
- 4. Thread 0 writes the final K values + indices to global memory.
+ 1. One block per row.
+ 2. Each thread scans strided elements, keeping top-K in registers.
+ 3. Threads dump their register buffers to shared memory.
+ 4. Block-wide merge tree to extract final top-K.
+ 5. Thread 0 writes output.
"""
from __future__ import annotations
@@ -15,21 +27,24 @@
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
+_CPP_HEADER = (
+ '#include <torch/extension.h>\n'
+ 'std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k);'
+)
+
_CUDA_SOURCE = r"""
#include <cuda_fp16.h>
#include <cstdint>
static constexpr float NEG_INF = -1e30f;
-// Bank-conflict padding for shared arrays.
-// For K = 64 = 2*32, pad=1 avoids bank conflicts; otherwise pad=0.
template <int K> struct spad_v { static constexpr int value = (K % 32 == 0) ? 1 : 0; };
// --------------------------------------------------------------------------
-// Generic kernel: K >= 2
+// Generic kernel (K >= 2): register scan + shared-memory merge
// --------------------------------------------------------------------------
template <int K, int BLOCK, int PAD>
-__global__ void topk_merge_tree_kernel(
+__global__ void topk_kernel(
const float* __restrict__ x,
float* __restrict__ values,
int64_t* __restrict__ indices,
@@ -38,71 +53,74 @@
)
{
static constexpr int STRIDE = K + PAD;
- __shared__ float s_val[BLOCK * STRIDE];
- __shared__ int s_idx[BLOCK * STRIDE];
+
+ // ---- Register-allocated per-thread top-K buffer --------------------------
+ float tv[K];
+ int ti[K];
+ #pragma unroll
+ for (int i = 0; i < K; ++i) { tv[i] = NEG_INF; ti[i] = -1; }
int tid = threadIdx.x;
int bid = blockIdx.x;
-
- // ---- Initialise local buffer to -inf / -1 --------------------------------
+ const float* row = x + bid * row_stride;
+
+ // Strided scan: each thread processes n/BLOCK elements
+ for (int i = tid; i < n; i += BLOCK) {
+ float v = row[i];
+ if (v > tv[K - 1]) {
+ // Insertion: shift from back
+ for (int j = K - 1; j > 0; --j) {
+ if (tv[j-1] < v) { tv[j] = tv[j-1]; ti[j] = ti[j-1]; }
+ else { break; }
+ }
+ // Find the actual insertion point
+ int pos;
+ for (pos = 0; pos < K && tv[pos] >= v; ++pos) {}
+ // Shift [pos..K-2] to [pos+1..K-1]
+ for (int j = K - 1; j > pos; --j) {
+ tv[j] = tv[j-1];
+ ti[j] = ti[j-1];
+ }
+ tv[pos] = v;
+ ti[pos] = i;
+ }
+ }
+
+ // ---- Dump to shared memory for merge tree --------------------------------
+ static constexpr int SMEM_SIZE = BLOCK * STRIDE;
+ __shared__ float s_val[SMEM_SIZE];
+ __shared__ int s_idx[SMEM_SIZE];
+
int base = tid * STRIDE;
#pragma unroll
for (int i = 0; i < K; ++i) {
- s_val[base + i] = NEG_INF;
- s_idx[base + i] = -1;
+ s_val[base + i] = tv[i];
+ s_idx[base + i] = ti[i];
}
__syncthreads();
- // ---- Phase 1: strided scan with insertion-sort buffer --------------------
- const float* row = x + bid * row_stride;
-
- for (int i = tid; i < n; i += BLOCK) {
- float v = row[i];
- float kth = s_val[base + K - 1];
- if (v > kth) {
- int pos = K - 1;
- while (pos > 0 && s_val[base + pos - 1] < v) {
- s_val[base + pos] = s_val[base + pos - 1];
- s_idx[base + pos] = s_idx[base + pos - 1];
- --pos;
- }
- s_val[base + pos] = v;
- s_idx[base + pos] = i;
- }
- }
- __syncthreads();
-
- // ---- Phase 2: merge-tree reduction ---------------------------------------
+ // ---- Merge-tree reduction ------------------------------------------------
for (int stride = BLOCK / 2; stride > 0; stride /= 2) {
if (tid < stride) {
- int a_base = tid * STRIDE;
- int b_base = (tid+stride) * STRIDE;
-
- // Local merge into registers
- float mval[K];
- int midx[K];
+ int a = tid * STRIDE;
+ int b = (tid + stride) * STRIDE;
+ float mv[K];
+ int mi[K];
int ai = 0, bi = 0, oi = 0;
while (oi < K) {
- if (ai < K && (bi >= K || s_val[a_base+ai] >= s_val[b_base+bi])) {
- mval[oi] = s_val[a_base + ai];
- midx[oi] = s_idx[a_base + ai];
-
... (4844 more chars){
"content": "<9501 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<962 chars \u2014 see below>",
"description": "Compile and test register-scan version",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 max_diff=3.17e-02 FAIL
batch= 64 n= 8192 k= 8 max_diff=1.04e-01 FAIL
batch= 32 n= 16384 k=32 max_diff=2.32e-01 FAIL
batch= 16 n= 12000 k=16 max_diff=1.42e-01 FAIL
batch=128 n= 4096 k= 1 max_diff=0.00e+00 OK
All passed: False
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -68,18 +68,11 @@
for (int i = tid; i < n; i += BLOCK) {
float v = row[i];
if (v > tv[K - 1]) {
- // Insertion: shift from back
- for (int j = K - 1; j > 0; --j) {
- if (tv[j-1] < v) { tv[j] = tv[j-1]; ti[j] = ti[j-1]; }
- else { break; }
- }
- // Find the actual insertion point
- int pos;
- for (pos = 0; pos < K && tv[pos] >= v; ++pos) {}
- // Shift [pos..K-2] to [pos+1..K-1]
- for (int j = K - 1; j > pos; --j) {
- tv[j] = tv[j-1];
- ti[j] = ti[j-1];
+ int pos = K - 1;
+ while (pos > 0 && tv[pos - 1] < v) {
+ tv[pos] = tv[pos - 1];
+ ti[pos] = ti[pos - 1];
+ --pos;
}
tv[pos] = v;
ti[pos] = i;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<416 chars \u2014 see below>",
"oldString": "<733 chars \u2014 see below>"
} // Strided scan: each thread processes n/BLOCK elements
for (int i = tid; i < n; i += BLOCK) {
float v = row[i];
if (v > tv[K - 1]) {
int pos = K - 1;
while (pos > 0 && tv[pos - 1] < v) {
tv[pos] = tv[pos - 1];
ti[pos] = ti[pos - 1];
--pos;
}
tv[pos] = v;
ti[pos] = i;
}
} // Strided scan: each thread processes n/BLOCK elements
for (int i = tid; i < n; i += BLOCK) {
float v = row[i];
if (v > tv[K - 1]) {
// Insertion: shift from back
for (int j = K - 1; j > 0; --j) {
if (tv[j-1] < v) { tv[j] = tv[j-1]; ti[j] = ti[j-1]; }
else { break; }
}
// Find the actual insertion point
int pos;
for (pos = 0; pos < K && tv[pos] >= v; ++pos) {}
// Shift [pos..K-2] to [pos+1..K-1]
for (int j = K - 1; j > pos; --j) {
tv[j] = tv[j-1];
ti[j] = ti[j-1];
}
tv[pos] = v;
ti[pos] = i;
}
}Edit applied successfully.{
"command": "<962 chars \u2014 see below>",
"description": "Test with fixed insertion sort",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 max_diff=0.00e+00 OK
batch= 64 n= 8192 k= 8 max_diff=0.00e+00 OK
batch= 32 n= 16384 k=32 max_diff=0.00e+00 OK
batch= 16 n= 12000 k=16 max_diff=0.00e+00 OK
batch=128 n= 4096 k= 1 max_diff=0.00e+00 OK
All passed: True
{
"todos": [
{
"content": "First Triton-based top-k kernel implementation",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Implement CUDA C++ kernel with shared-memory merge tree",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for CUDA kernel correctness",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run benchmark.py for baseline CUDA per
... (317 more chars)content: 'First Triton-based top-k kernel implementation'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Implement CUDA C++ kernel with shared-memory merge tree'
status: 'completed'
priority: 'high'
content: 'Run check.py for CUDA kernel correctness'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py for baseline CUDA performance'
status: 'pending'
priority: 'high'
content: 'Optimize based on profiling results'
status: 'pending'
priority: 'high'
content: 'Iterate until performance target met'
status: 'pending'
priority: 'high'{
"command": "timeout 180 python3 check.py 2>&1",
"description": "Run official correctness check",
"timeout": 180000
}PASS
{
"command": "timeout 600 python3 benchmark.py 2>&1",
"description": "Run performance benchmark for CUDA kernel",
"timeout": 600000
}shape=0 variant=eager tflops=0.011 gbps=11.027 ms=0.048
shape=0 variant=compiled tflops=0.007 gbps=7.212 ms=0.073
shape=0 variant=sota tflops=0.011 gbps=11.135 ms=0.047
shape=0 variant=solution tflops=0.000 gbps=0.055 ms=9.569
shape=0 solution_peak_fraction=0.0000
shape=1 variant=eager tflops=0.043 gbps=43.356 ms=0.049
shape=1 variant=compiled tflops=0.026 gbps=26.354 ms=0.080
shape=1 variant=sota tflops=0.044 gbps=43.951 ms=0.048
shape=1 variant=solution tflops=0.060 gbps=60.246 ms=0.035
shape=1 solution_peak_fraction=0.0335
shape=2 variant=eager tflops=0.045 gbps=44.935 ms=0.047
shape=2 variant=compiled tflops=0.028 gbps=28.249 ms=0.075
shape=2 variant=sota tflops=0.045 gbps=45.043 ms=0.047
shape=2 variant=solution tflops=0.004 gbps=3.676 ms=0.574
shape=2 solution_peak_fraction=0.0020
shape=3 variant=eager tflops=0.018 gbps=18.036 ms=0.043
shape=3 variant=compiled tflops=0.010 gbps=10.511 ms=0.073
shape=3 variant=sota tflops=0.018 gbps=18.331 ms=0.042
shape=3 variant=solution tflops=0.006 gbps=6.275 ms=0.123
shape=3 solution_peak_fraction=0.0035
shape=4 variant=eager tflops=0.095 gbps=95.395 ms=0.022
shape=4 variant=compiled tflops=0.035 gbps=35.451 ms=0.059
shape=4 variant=sota tflops=0.103 gbps=103.282 ms=0.020
shape=4 variant=solution tflops=0.248 gbps=248.424 ms=0.008
shape=4 solution_peak_fraction=0.1380
peak_fraction: 0.0040
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,27 +1,19 @@
-"""Top-k kernel via raw CUDA C++ — merge-tree with register scan.
-
-Key constraint: SM120 has a 48 KB per-block shared memory limit.
-For K=64, storing all per-thread buffers in shared memory would need
-256 * 65 * 8 = ~130 KB — impossible.
-
-Solution: each thread accumulates its top-K in **registers** during the scan.
-Only at the end do threads dump to shared memory for the merge tree.
-Shared memory for the merge then only needs to hold the *final* candidates:
- K=64, BL=128 → 128*64*8 = 64 KB. Still too big for 48 KB.
- K=64, BL=64 → 64*65*8 = 33 KB. Fits.
- K=32, BL=128 → 128*33*8 = 33 KB. Fits.
- K=16, BL=256 → 256*16*8 = 32 KB. Fits.
- K= 8, BL=256 → 256* 8*8 = 16 KB. Fits.
- K= 1, BL=512 → 512* 1*8 = 4 KB. Fits.
-
-Design:
- 1. One block per row.
- 2. Each thread scans strided elements, keeping top-K in registers.
- 3. Threads dump their register buffers to shared memory.
- 4. Block-wide merge tree to extract final top-K.
- 5. Thread 0 writes output.
+"""Top-k via raw CUDA — two-phase segment-select + merge.
+
+Phase 1 (high parallelism): grid = (batch, num_segments)
+ Each block handles a segment of a row, keeps per-thread top-M in registers,
+ then thread 0 selects the segment-level top-K sequentially from shared memory.
+
+Phase 2 (merge): grid = (batch,)
+ Each block merges num_segments * K candidates into final top-K.
+ Candidates are loaded in chunks to stay within the 48 KB shared-memory cap.
+
+This design keeps good GPU utilisation even for batch=1 and n=131k:
+ 131072 / 2048 = 64 segments → 64 blocks of 256 threads = 16k threads.
"""
from __future__ import annotations
+
+import math
import torch
import torch.nn as nn
@@ -35,214 +27,203 @@
_CUDA_SOURCE = r"""
#include <cuda_fp16.h>
#include <cstdint>
+#include <algorithm>
static constexpr float NEG_INF = -1e30f;
-template <int K> struct spad_v { static constexpr int value = (K % 32 == 0) ? 1 : 0; };
-
-// --------------------------------------------------------------------------
-// Generic kernel (K >= 2): register scan + shared-memory merge
-// --------------------------------------------------------------------------
-template <int K, int BLOCK, int PAD>
-__global__ void topk_kernel(
+// ========== Phase 1: segment-level top-K ====================================
+// Each block handles SEG_SIZE elements and outputs the top-K from that segment.
+//
+// Strategy:
+// 1. Each thread scans its strided portion, keeping top-M in registers
+// where M = min(K, ceil(SEG_SIZE/BLOCK)).
+// 2. Threads dump their buffers to shared memory.
+// 3. Thread 0 linear-scans shared memory, picks the segment-wide top-K
+// (K sequential iterations of max-finding).
+//
+// For SEG_SIZE=2048, BLOCK=256 → M = 8. Shared = 2048*8 = 16 KB. Fits.
+
+template <int K, int BLOCK, int M>
+__global__ void segment_select_kernel(
const float* __restrict__ x,
+ float* __restrict__ seg_vals, // [batch, num_segs, K]
+ int* __restrict__ seg_idxs, // [batch, num_segs, K]
+ int64_t row_stride,
+ int n,
+ int seg_size
+)
+{
+ int bid_row = blockIdx.x; // which row
+ int bid_seg = blockIdx.y; // which segment in the row
+ int tid = threadIdx.x;
+
+ int seg_start = bid_seg * seg_size;
+ int seg_end = std::min(seg_start + seg_size, n);
+ int out_base = ((int)bid_row * (int)gridDim.y + (int)bid_seg) * K;
+
+ // Per-thread register buffer (M elements, sorted descending)
+ float tv[M];
+ int ti[M];
+ #pragma unroll
+ for (int i = 0; i < M; ++i) { tv[i] = NEG_INF; ti[i] = -1; }
+
+ // --- Strided scan -------------------------------------------------------
+ for (int i = seg_start + tid; i < seg_end; i += BLOCK) {
+ float v = x[(int64_t)bid_row * row_stride + i];
+ if (v > tv[M-1]) {
+ int pos = M - 1;
+ while (pos > 0 && tv[pos-1] < v) { tv[pos]=tv[pos-1]; ti[pos]=ti[pos-1]; --pos; }
+ tv[pos] = v;
+ ti[pos] = i;
+ }
+ }
+
+ // --- Dump to shared memory -----------------------------------------------
+ __shared__ float s_val[BLOCK * M];
+ __shared__ int s_idx[BLOCK * M];
+ int base = tid * M;
+ #pragma unroll
+ for (int i = 0; i < M; ++i) { s_val[base+i] = tv[i]; s_idx[base+i] = ti[i]; }
+ __syncthreads();
+
+ // --- Thread 0 selects top-K from BLOCK*M candidates ---------------------
+ if (tid == 0) {
+ int total = BLOCK * M;
+ for (int ki = 0; ki < K; ++ki) {
+ float best = NEG_INF;
+ int best_pos = -1;
+ #pragma unroll
+ for (int j = 0; j < total; ++j) {
+ if (s_val[j] > best) { best = s_val[j]; best_pos = j; }
+ }
+ seg_vals[out_base + ki] = best;
+ seg_idxs[out_base + ki] = (best_pos >= 0) ? s_idx[best_pos] : -1;
+ if (best_pos >= 0) s_val[best_pos] = NEG_INF; // mark consumed
+ }
+ }
+}
+
+// ========== Phase 2: merge across segments ===================================
+// Each block handles one row. Reads (num_segs*K) candidates and outputs top-K.
+// Candidates are loaded in blocks of LOAD_BLK elements to respect shared-mem cap.
+
+template <int K, int BLOCK, int LOAD_BLK>
+__global__ void merge_kernel(
+ const float* __restrict__ seg_vals,
+ const int* __restrict__ seg_idxs,
float* __restrict__ values,
int64_t* __restrict__ indices,
- int64_t row_stride,
- int n
+ int num_segs
)
{
- static constexpr int STRIDE = K + PAD;
-
- // ---- Register-allocated per-thread top-K buffer --------------------------
+ int bid = blockIdx.x;
+ int tid = threadIdx.x;
+ int row_base = bid * num_segs * K;
+ int n_cand = num_segs * K;
+
+ // Per-
... (10123 more chars){
"content": "<9323 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<1141 chars \u2014 see below>",
"description": "Compile and test two-phase kernel",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
try:
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, ref_idxs = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
except Exception as e:
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} ERROR: {e}')
all_ok = False
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 ERROR: Error building extension 'topk_cuda_ext': [1/2] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120 --use_fast_math -Xptxas=-O3,-v -maxrregcount=96 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/topk_cuda_ext/cuda.cu -o cuda.cuda.o
FAILED: cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120 --use_fast_math -Xptxas=-O3,-v -maxrregcount=96 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/topk_cuda_ext/cuda.cu -
... (5488 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"offset": 205,
"limit": 30
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
205: if (k_out <= 8) { PHASE1(8); }
206: else if (k_out <= 16) { PHASE1(16); }
207: else if (k_out <= 32) { PHASE1(32); }
208: else if (k_out <= 64) { PHASE1(64); }
209: else { PHASE1(64); }
210:
211: #undef PHASE1
212:
213: auto values = torch::empty({batch, k_out}, x.options());
214: auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
215:
216: #define PHASE2(K) merge_kernel<K, 256, 256> \
217: <<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
218: values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
219:
220: if (k_out <= 8) { PHASE2(8); }
221: else if (k_out <= 16) { PHASE2(16); }
222: else if (k_out <= 32) { PHASE2(32); }
223: else if (k_out <= 64) { PHASE2(64); }
224: else { PHASE2(64); }
225:
226: #undef PHASE2
227:
228: return {values, indices};
229: }
230: """
231:
232:
233: def _compile_module():
234: return load_inline(
(Showing lines 205-234 of 281. Use offset=235 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -213,17 +213,24 @@
auto values = torch::empty({batch, k_out}, x.options());
auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
- #define PHASE2(K) merge_kernel<K, 256, 256> \
+ #define PHASE2(K, BL) merge_kernel<(K), 256, 256> \
<<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
-
- if (k_out <= 8) { PHASE2(8); }
- else if (k_out <= 16) { PHASE2(16); }
- else if (k_out <= 32) { PHASE2(32); }
- else if (k_out <= 64) { PHASE2(64); }
- else { PHASE2(64); }
+ // For K>=16 use smaller blocks to fit smem cap:
+ // K=64 BL=64 → smem=32768 K=32 BL=128 → smem=32768
+ #define PHASE2_SMALL(K, BL) merge_kernel<(K), (BL), (BL)> \
+ <<<batch, (BL)>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
+ values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
+
+ if (k_out <= 4) { PHASE2(4, 256); }
+ else if (k_out <= 8) { PHASE2(8, 256); }
+ else if (k_out <= 16) { PHASE2_SMALL(16, 256); }
+ else if (k_out <= 32) { PHASE2_SMALL(32, 128); }
+ else if (k_out <= 64) { PHASE2_SMALL(64, 64); }
+ else { PHASE2_SMALL(64, 64); }
#undef PHASE2
+ #undef PHASE2_SMALL
return {values, indices};
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<928 chars \u2014 see below>",
"oldString": "<461 chars \u2014 see below>"
} #define PHASE2(K, BL) merge_kernel<(K), 256, 256> \
<<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
// For K>=16 use smaller blocks to fit smem cap:
// K=64 BL=64 → smem=32768 K=32 BL=128 → smem=32768
#define PHASE2_SMALL(K, BL) merge_kernel<(K), (BL), (BL)> \
<<<batch, (BL)>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
if (k_out <= 4) { PHASE2(4, 256); }
else if (k_out <= 8) { PHASE2(8, 256); }
else if (k_out <= 16) { PHASE2_SMALL(16, 256); }
else if (k_out <= 32) { PHASE2_SMALL(32, 128); }
else if (k_out <= 64) { PHASE2_SMALL(64, 64); }
else { PHASE2_SMALL(64, 64); }
#undef PHASE2
#undef PHASE2_SMALL #define PHASE2(K) merge_kernel<K, 256, 256> \
<<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
if (k_out <= 8) { PHASE2(8); }
else if (k_out <= 16) { PHASE2(16); }
else if (k_out <= 32) { PHASE2(32); }
else if (k_out <= 64) { PHASE2(64); }
else { PHASE2(64); }
#undef PHASE2Edit applied successfully.{
"command": "<1134 chars \u2014 see below>",
"description": "Compile and test with reduced merge shared memory",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
try:
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
except Exception as e:
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} ERROR: {e}')
all_ok = False
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 max_diff=0.00e+00 OK
batch= 64 n= 8192 k= 8 max_diff=0.00e+00 OK
batch= 32 n= 16384 k=32 max_diff=0.00e+00 OK
batch= 16 n= 12000 k=16 max_diff=0.00e+00 OK
batch=128 n= 4096 k= 1 max_diff=1.72e+00 FAIL
All passed: False
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 10,
"offset": 170
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
170: #pragma unroll
171: for (int i = 0; i < K; ++i) { out_v[i]=sm_v[i]; out_i[i]=(int64_t)sm_i[i]; }
172: }
173: }
174:
175: // ========== Dispatcher =======================================================
176: #include <torch/extension.h>
177:
178: std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k_out)
179: {
(Showing lines 170-179 of 288. Use offset=180 to continue.)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 60,
"offset": 175
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
175: // ========== Dispatcher =======================================================
176: #include <torch/extension.h>
177:
178: std::tuple<torch::Tensor, torch::Tensor> run_topk(torch::Tensor x, int k_out)
179: {
180: TORCH_CHECK(x.is_cuda() && x.is_contiguous());
181: TORCH_CHECK(x.dtype() == torch::kFloat32);
182: TORCH_CHECK(x.dim() == 2);
183:
184: int batch = x.size(0);
185: int n = x.size(1);
186:
187: // Segment size for phase 1 — tuned so BLOCK*M stays below 48 KB smem.
188: // M = ceil(SEG_SIZE / BLOCK). BLOCK=256, SEG_SIZE=2048 → M=8, smem=16KB.
189: const int SEG_SIZE = 2048;
190: const int BLOCK1 = 256;
191: const int num_segs = (n + SEG_SIZE - 1) / SEG_SIZE;
192: const int M = (SEG_SIZE + BLOCK1 - 1) / BLOCK1; // =8
193:
194: auto seg_vals = torch::full({batch, num_segs, k_out}, -1e30f, x.options());
195: auto seg_idxs = torch::full({batch, num_segs, k_out}, -1, x.options().dtype(torch::kInt32));
196:
197: dim3 grid2d(batch, num_segs);
198:
199: // Choose template based on k_out
200: #define PHASE1(K) segment_select_kernel<K, BLOCK1, (M < K ? M : K)> \
201: <<<grid2d, BLOCK1>>>(x.data_ptr<float>(), \
202: seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
203: x.stride(0), n, SEG_SIZE)
204:
... (1529 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -196,41 +196,47 @@
dim3 grid2d(batch, num_segs);
- // Choose template based on k_out
- #define PHASE1(K) segment_select_kernel<K, BLOCK1, (M < K ? M : K)> \
- <<<grid2d, BLOCK1>>>(x.data_ptr<float>(), \
- seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
+ // Select the template K (round up to next supported value)
+ int K;
+ if (k_out <= 1) K = 1;
+ else if (k_out <= 8) K = 8;
+ else if (k_out <= 16) K = 16;
+ else if (k_out <= 32) K = 32;
+ else K = 64;
+
+ // Phase 1
+ #define PHASE1(KK) segment_select_kernel<KK, BLOCK1, (M < KK ? M : KK)> \
+ <<<grid2d, BLOCK1>>>(x.data_ptr<float>(), \
+ seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
x.stride(0), n, SEG_SIZE)
- if (k_out <= 8) { PHASE1(8); }
- else if (k_out <= 16) { PHASE1(16); }
- else if (k_out <= 32) { PHASE1(32); }
- else if (k_out <= 64) { PHASE1(64); }
- else { PHASE1(64); }
+ switch (K) {
+ case 1: PHASE1(1); break;
+ case 8: PHASE1(8); break;
+ case 16: PHASE1(16); break;
+ case 32: PHASE1(32); break;
+ default: PHASE1(64); break;
+ }
#undef PHASE1
auto values = torch::empty({batch, k_out}, x.options());
auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
- #define PHASE2(K, BL) merge_kernel<(K), 256, 256> \
- <<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
- values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
- // For K>=16 use smaller blocks to fit smem cap:
- // K=64 BL=64 → smem=32768 K=32 BL=128 → smem=32768
- #define PHASE2_SMALL(K, BL) merge_kernel<(K), (BL), (BL)> \
+ // Phase 2 — smaller blocks for large K to fit 48 KB smem
+ #define PHASE2_S(KK, BL) merge_kernel<(KK), (BL), (BL)> \
<<<batch, (BL)>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
- if (k_out <= 4) { PHASE2(4, 256); }
- else if (k_out <= 8) { PHASE2(8, 256); }
- else if (k_out <= 16) { PHASE2_SMALL(16, 256); }
- else if (k_out <= 32) { PHASE2_SMALL(32, 128); }
- else if (k_out <= 64) { PHASE2_SMALL(64, 64); }
- else { PHASE2_SMALL(64, 64); }
-
- #undef PHASE2
- #undef PHASE2_SMALL
+ switch (K) {
+ case 1: PHASE2_S(1, 512); break;
+ case 8: PHASE2_S(8, 256); break;
+ case 16: PHASE2_S(16, 256); break;
+ case 32: PHASE2_S(32, 128); break;
+ default: PHASE2_S(64, 64); break;
+ }
+
+ #undef PHASE2_S
return {values, indices};
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<1491 chars \u2014 see below>",
"oldString": "<1644 chars \u2014 see below>"
} dim3 grid2d(batch, num_segs);
// Select the template K (round up to next supported value)
int K;
if (k_out <= 1) K = 1;
else if (k_out <= 8) K = 8;
else if (k_out <= 16) K = 16;
else if (k_out <= 32) K = 32;
else K = 64;
// Phase 1
#define PHASE1(KK) segment_select_kernel<KK, BLOCK1, (M < KK ? M : KK)> \
<<<grid2d, BLOCK1>>>(x.data_ptr<float>(), \
seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
x.stride(0), n, SEG_SIZE)
switch (K) {
case 1: PHASE1(1); break;
case 8: PHASE1(8); break;
case 16: PHASE1(16); break;
case 32: PHASE1(32); break;
default: PHASE1(64); break;
}
#undef PHASE1
auto values = torch::empty({batch, k_out}, x.options());
auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
// Phase 2 — smaller blocks for large K to fit 48 KB smem
#define PHASE2_S(KK, BL) merge_kernel<(KK), (BL), (BL)> \
<<<batch, (BL)>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
switch (K) {
case 1: PHASE2_S(1, 512); break;
case 8: PHASE2_S(8, 256); break;
case 16: PHASE2_S(16, 256); break;
case 32: PHASE2_S(32, 128); break;
default: PHASE2_S(64, 64); break;
}
#undef PHASE2_S dim3 grid2d(batch, num_segs);
// Choose template based on k_out
#define PHASE1(K) segment_select_kernel<K, BLOCK1, (M < K ? M : K)> \
<<<grid2d, BLOCK1>>>(x.data_ptr<float>(), \
seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(), \
x.stride(0), n, SEG_SIZE)
if (k_out <= 8) { PHASE1(8); }
else if (k_out <= 16) { PHASE1(16); }
else if (k_out <= 32) { PHASE1(32); }
else if (k_out <= 64) { PHASE1(64); }
else { PHASE1(64); }
#undef PHASE1
auto values = torch::empty({batch, k_out}, x.options());
auto indices = torch::empty({batch, k_out}, x.options().dtype(torch::kInt64));
#define PHASE2(K, BL) merge_kernel<(K), 256, 256> \
<<<batch, 256>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
// For K>=16 use smaller blocks to fit smem cap:
// K=64 BL=64 → smem=32768 K=32 BL=128 → smem=32768
#define PHASE2_SMALL(K, BL) merge_kernel<(K), (BL), (BL)> \
<<<batch, (BL)>>>(seg_vals.data_ptr<float>(), seg_idxs.data_ptr<int>(),\
values.data_ptr<float>(), indices.data_ptr<int64_t>(), num_segs)
if (k_out <= 4) { PHASE2(4, 256); }
else if (k_out <= 8) { PHASE2(8, 256); }
else if (k_out <= 16) { PHASE2_SMALL(16, 256); }
else if (k_out <= 32) { PHASE2_SMALL(32, 128); }
else if (k_out <= 64) { PHASE2_SMALL(64, 64); }
else { PHASE2_SMALL(64, 64); }
#undef PHASE2
#undef PHASE2_SMALLEdit applied successfully.{
"command": "<955 chars \u2014 see below>",
"description": "Compile and test with consistent K selection",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 max_diff=0.00e+00 OK
batch= 64 n= 8192 k= 8 max_diff=0.00e+00 OK
batch= 32 n= 16384 k=32 max_diff=0.00e+00 OK
batch= 16 n= 12000 k=16 max_diff=0.00e+00 OK
batch=128 n= 4096 k= 1 max_diff=0.00e+00 OK
All passed: True
{
"command": "python3 check.py 2>&1",
"description": "Official correctness check",
"timeout": 180000
}PASS
{
"command": "timeout 600 python3 benchmark.py 2>&1",
"description": "Performance benchmark",
"timeout": 600000
}shape=0 variant=eager tflops=0.011 gbps=10.983 ms=0.048
shape=0 variant=compiled tflops=0.007 gbps=7.215 ms=0.073
shape=0 variant=sota tflops=0.011 gbps=11.075 ms=0.047
shape=0 variant=solution tflops=0.001 gbps=0.541 ms=0.971
shape=0 solution_peak_fraction=0.0003
shape=1 variant=eager tflops=0.043 gbps=43.615 ms=0.048
shape=1 variant=compiled tflops=0.026 gbps=25.745 ms=0.082
shape=1 variant=sota tflops=0.044 gbps=44.172 ms=0.048
shape=1 variant=solution tflops=0.024 gbps=23.673 ms=0.089
shape=1 solution_peak_fraction=0.0132
shape=2 variant=eager tflops=0.045 gbps=45.213 ms=0.047
shape=2 variant=compiled tflops=0.028 gbps=28.353 ms=0.074
shape=2 variant=sota tflops=0.045 gbps=45.572 ms=0.046
shape=2 variant=solution tflops=0.007 gbps=7.361 ms=0.287
shape=2 solution_peak_fraction=0.0041
shape=3 variant=eager tflops=0.018 gbps=17.949 ms=0.043
shape=3 variant=compiled tflops=0.010 gbps=10.504 ms=0.073
shape=3 variant=sota tflops=0.018 gbps=18.408 ms=0.042
shape=3 variant=solution tflops=0.005 gbps=5.209 ms=0.148
shape=3 solution_peak_fraction=0.0029
shape=4 variant=eager tflops=0.095 gbps=94.980 ms=0.022
shape=4 variant=compiled tflops=0.036 gbps=35.789 ms=0.059
shape=4 variant=sota tflops=0.099 gbps=99.069 ms=0.021
shape=4 variant=solution tflops=0.135 gbps=135.504 ms=0.015
shape=4 solution_peak_fraction=0.0753
peak_fraction: 0.0051
RESULT: LOW
{
"command": "<1945 chars \u2014 see below>",
"description": "Profile kernel timings",
"timeout": 120000
}timeout 120 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
# Test shape 0 - the worst case
model = solution.Model(1, 131072, 64).to(device).eval()
x = torch.randn(1, 131072, device=device, dtype=torch.float32)
warmup = 10
iters = 100
for _ in range(warmup):
with torch.no_grad():
model(x)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
with torch.no_grad():
vals, idxs = model(x)
end.record()
torch.cuda.synchronize()
avg_ms = start.elapsed_time(end) / iters
print(f'Shape 0 avg: {avg_ms:.3f} ms')
# Time sota for comparison
start.record()
for _ in range(iters):
torch.topk(x, 64, dim=-1, largest=True, sorted=True)
end.record()
torch.cuda.synchronize()
sota_ms = start.elapsed_time(end) / iters
print(f'SOTA avg: {sota_ms:.3f} ms')
# Check shapes 0-4
for s in [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
warmup = 10
iters = 100
for _ in range(warmup):
with torch.no_grad():
model(x)
torch.cuda.synchronize()
start.record()
for _ in range(iters):
with torch.no_grad():
model(x)
end.record()
torch.cuda.synchronize()
sol = start.elapsed_time(end) / iters
start.record()
for _ in range(iters):
torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
end.record()
torch.cuda.synchronize()
ref = start.elapsed_time(end) / iters
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d}: sol={sol:.3f}ms ref={ref:.3f}ms ratio={sol/ref:.1f}x')
" 2>&1
Shape 0 avg: 0.906 ms
SOTA avg: 0.178 ms
batch= 1 n=131072 k=64: sol=0.931ms ref=0.041ms ratio=22.5x
batch= 64 n= 8192 k= 8: sol=0.081ms ref=0.040ms ratio=2.0x
batch= 32 n= 16384 k=32: sol=0.278ms ref=0.039ms ratio=7.1x
batch= 16 n= 12000 k=16: sol=0.136ms ref=0.056ms ratio=2.4x
batch=128 n= 4096 k= 1: sol=0.011ms ref=0.015ms ratio=0.7x
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,19 +1,19 @@
-"""Top-k via raw CUDA — two-phase segment-select + merge.
-
-Phase 1 (high parallelism): grid = (batch, num_segments)
- Each block handles a segment of a row, keeps per-thread top-M in registers,
- then thread 0 selects the segment-level top-K sequentially from shared memory.
+"""Top-k via raw CUDA — bitonic sort per segment, then K-way merge.
+
+Phase 1 (high parallelism): grid = (batch, num_segs)
+ 1. All threads cooperatively load a segment into shared memory (coalesced).
+ 2. Parallel bitonic sort on shared memory (descending).
+ 3. Output top-K to global memory.
Phase 2 (merge): grid = (batch,)
- Each block merges num_segments * K candidates into final top-K.
- Candidates are loaded in chunks to stay within the 48 KB shared-memory cap.
-
-This design keeps good GPU utilisation even for batch=1 and n=131k:
- 131072 / 2048 = 64 segments → 64 blocks of 256 threads = 16k threads.
+ 1. Load all segment top-K candidates into shared memory.
+ 2. Bitonic sort again, output final top-K.
+
+Bitonic sort on shared memory uses BLOCK threads for BLOCK*ELEM_PER_THREAD
+elements. Each sub-stage thread processes ELEM_PER_THREAD compare-swaps,
+giving near-perfect shared-memory throughput.
"""
from __future__ import annotations
-
-import math
import torch
import torch.nn as nn
@@ -31,149 +31,191 @@
static constexpr float NEG_INF = -1e30f;
-// ========== Phase 1: segment-level top-K ====================================
-// Each block handles SEG_SIZE elements and outputs the top-K from that segment.
-//
-// Strategy:
-// 1. Each thread scans its strided portion, keeping top-M in registers
-// where M = min(K, ceil(SEG_SIZE/BLOCK)).
-// 2. Threads dump their buffers to shared memory.
-// 3. Thread 0 linear-scans shared memory, picks the segment-wide top-K
-// (K sequential iterations of max-finding).
-//
-// For SEG_SIZE=2048, BLOCK=256 → M = 8. Shared = 2048*8 = 16 KB. Fits.
-
-template <int K, int BLOCK, int M>
-__global__ void segment_select_kernel(
+// ===== Bitonic compare-swap (descending, shared memory) =====================
+__device__ __forceinline__
+void cmp_swap(float* sv, int* si, int ixj_mask, int dir, int stage_mask) {
+ int ixj = __laneid() ^ ixj_mask;
+ int dir_x = (__laneid() & stage_mask) != 0;
+ bool descending = (dir == dir_x);
+ int i = __laneid();
+ if ((descending && sv[i] < sv[ixj]) || (!descending && sv[i] > sv[ixj])) {
+ float tmp_v = sv[i]; sv[i] = sv[ixj]; sv[ixj] = tmp_v;
+ int tmp_i = si[i]; si[i] = si[ixj]; si[ixj] = tmp_i;
+ }
+}
+
+// =========================================================================
+// Phase 1: bitonic sort per segment → segment-level top-K
+// =========================================================================
+// SEG_N = number of elements per segment (power-of-2, fits in smem)
+// BLOCK = number of threads per block
+// ELEM = SEG_N / BLOCK (elements per thread for loading)
+// K_OUT = output K
+template <int SEG_N, int BLOCK, int ELEM, int K_OUT>
+__global__ void segment_bitonic_kernel(
const float* __restrict__ x,
- float* __restrict__ seg_vals, // [batch, num_segs, K]
- int* __restrict__ seg_idxs, // [batch, num_segs, K]
+ float* __restrict__ seg_vals, // [batch, num_segs, K_OUT]
+ int* __restrict__ seg_idxs, // [batch, num_segs, K_OUT]
int64_t row_stride,
- int n,
- int seg_size
+ int n
)
{
- int bid_row = blockIdx.x; // which row
- int bid_seg = blockIdx.y; // which segment in the row
+ int bid_row = blockIdx.x;
+ int bid_seg = blockIdx.y;
int tid = threadIdx.x;
-
- int seg_start = bid_seg * seg_size;
- int seg_end = std::min(seg_start + seg_size, n);
- int out_base = ((int)bid_row * (int)gridDim.y + (int)bid_seg) * K;
-
- // Per-thread register buffer (M elements, sorted descending)
- float tv[M];
- int ti[M];
+ int seg_start = bid_seg * SEG_N;
+
+ // --- Load segment into shared memory (coalesced) ------------------------
+ __shared__ float sv[SEG_N];
+ __shared__ int si[SEG_N];
+
#pragma unroll
- for (int i = 0; i < M; ++i) { tv[i] = NEG_INF; ti[i] = -1; }
-
- // --- Strided scan -------------------------------------------------------
- for (int i = seg_start + tid; i < seg_end; i += BLOCK) {
- float v = x[(int64_t)bid_row * row_stride + i];
- if (v > tv[M-1]) {
- int pos = M - 1;
- while (pos > 0 && tv[pos-1] < v) { tv[pos]=tv[pos-1]; ti[pos]=ti[pos-1]; --pos; }
- tv[pos] = v;
- ti[pos] = i;
- }
- }
-
- // --- Dump to shared memory -----------------------------------------------
- __shared__ float s_val[BLOCK * M];
- __shared__ int s_idx[BLOCK * M];
- int base = tid * M;
+ for (int e = 0; e < ELEM; ++e) {
+ int gidx = seg_start + tid * ELEM + e;
+ int sidx = tid * ELEM + e;
+ if (gidx < n) {
+ sv[sidx] = x[(int64_t)bid_row * row_stride + gidx];
+ si[sidx] = gidx;
+ } else {
+ sv[sidx] = NEG_INF;
+ si[sidx] = -1;
+ }
+ }
+ __syncthreads();
+
+ // --- Bitonic sort (descending) on shared memory -------------------------
+ // SEG_N must be a power of 2.
+ for (int stage = 2; stage <= SEG_N; stage <<= 1) {
+ for (int d = stage >> 1; d > 0; d >>= 1) {
+ // Each thread handles ELEM pairs
+ #pragma unroll
+ for (int e = 0; e < ELEM; ++e) {
+ int my_i = tid * ELEM + e;
+ int jx = my_i ^ d;
+ if (jx > my_i && ((my_i & stage) == 0) != ((jx & stage) == 0)) {
+ // Compare-swap for descending
+ // direction:
... (12574 more chars){
"content": "<11975 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<1134 chars \u2014 see below>",
"description": "Compile and test bitonic sort version",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
try:
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
except Exception as e:
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} ERROR: {e}')
all_ok = False
print(f'All passed: {all_ok}')
" 2>&1
batch= 1 n=131072 k=64 ERROR: Error building extension 'topk_cuda_ext': [1/2] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120 --use_fast_math -Xptxas=-O3,-v -maxrregcount=64 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/topk_cuda_ext/cuda.cu -o cuda.cuda.o
FAILED: cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 -arch=sm_120 --use_fast_math -Xptxas=-O3,-v -maxrregcount=64 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/topk_cuda_ext/cuda.cu -
... (19016 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,17 +1,7 @@
"""Top-k via raw CUDA — bitonic sort per segment, then K-way merge.
-Phase 1 (high parallelism): grid = (batch, num_segs)
- 1. All threads cooperatively load a segment into shared memory (coalesced).
- 2. Parallel bitonic sort on shared memory (descending).
- 3. Output top-K to global memory.
-
-Phase 2 (merge): grid = (batch,)
- 1. Load all segment top-K candidates into shared memory.
- 2. Bitonic sort again, output final top-K.
-
-Bitonic sort on shared memory uses BLOCK threads for BLOCK*ELEM_PER_THREAD
-elements. Each sub-stage thread processes ELEM_PER_THREAD compare-swaps,
-giving near-perfect shared-memory throughput.
+Phase 1: grid = (batch, num_segs) — parallel bitonic sort per segment
+Phase 2: grid = (batch,) — parallel bitonic sort of merged candidates
"""
from __future__ import annotations
@@ -31,82 +21,46 @@
static constexpr float NEG_INF = -1e30f;
-// ===== Bitonic compare-swap (descending, shared memory) =====================
-__device__ __forceinline__
-void cmp_swap(float* sv, int* si, int ixj_mask, int dir, int stage_mask) {
- int ixj = __laneid() ^ ixj_mask;
- int dir_x = (__laneid() & stage_mask) != 0;
- bool descending = (dir == dir_x);
- int i = __laneid();
- if ((descending && sv[i] < sv[ixj]) || (!descending && sv[i] > sv[ixj])) {
- float tmp_v = sv[i]; sv[i] = sv[ixj]; sv[ixj] = tmp_v;
- int tmp_i = si[i]; si[i] = si[ixj]; si[ixj] = tmp_i;
- }
-}
-
-// =========================================================================
-// Phase 1: bitonic sort per segment → segment-level top-K
-// =========================================================================
-// SEG_N = number of elements per segment (power-of-2, fits in smem)
-// BLOCK = number of threads per block
-// ELEM = SEG_N / BLOCK (elements per thread for loading)
-// K_OUT = output K
-template <int SEG_N, int BLOCK, int ELEM, int K_OUT>
-__global__ void segment_bitonic_kernel(
- const float* __restrict__ x,
- float* __restrict__ seg_vals, // [batch, num_segs, K_OUT]
- int* __restrict__ seg_idxs, // [batch, num_segs, K_OUT]
- int64_t row_stride,
- int n
+// =========================================================================
+// Bitonic sort kernel: sorts N elements (power-of-2) in shared memory
+// =========================================================================
+// Generic: N elements, B threads, E=ELEM=B/N elements per thread, K_OUT output.
+template <int N, int B, int ELEM, int K_OUT>
+__global__ void bitonic_topk(
+ const float* __restrict__ in_v,
+ const int* __restrict__ in_i,
+ int64_t in_row_stride,
+ float* __restrict__ out_v,
+ int* __restrict__ out_i,
+ int row_idx
)
{
- int bid_row = blockIdx.x;
- int bid_seg = blockIdx.y;
- int tid = threadIdx.x;
- int seg_start = bid_seg * SEG_N;
-
- // --- Load segment into shared memory (coalesced) ------------------------
- __shared__ float sv[SEG_N];
- __shared__ int si[SEG_N];
-
+ int tid = threadIdx.x;
+
+ __shared__ float sv[N];
+ __shared__ int si[N];
+
+ // Load input
#pragma unroll
for (int e = 0; e < ELEM; ++e) {
- int gidx = seg_start + tid * ELEM + e;
- int sidx = tid * ELEM + e;
- if (gidx < n) {
- sv[sidx] = x[(int64_t)bid_row * row_stride + gidx];
- si[sidx] = gidx;
- } else {
- sv[sidx] = NEG_INF;
- si[sidx] = -1;
- }
+ int gidx = tid * ELEM + e;
+ sv[gidx] = in_v[in_row_stride * row_idx + gidx];
+ si[gidx] = in_i ? in_i[in_row_stride * row_idx + gidx] : gidx;
}
__syncthreads();
- // --- Bitonic sort (descending) on shared memory -------------------------
- // SEG_N must be a power of 2.
- for (int stage = 2; stage <= SEG_N; stage <<= 1) {
+ // Bitonic sort (descending)
+ for (int stage = 2; stage <= N; stage <<= 1) {
for (int d = stage >> 1; d > 0; d >>= 1) {
- // Each thread handles ELEM pairs
#pragma unroll
for (int e = 0; e < ELEM; ++e) {
- int my_i = tid * ELEM + e;
- int jx = my_i ^ d;
- if (jx > my_i && ((my_i & stage) == 0) != ((jx & stage) == 0)) {
- // Compare-swap for descending
- // direction: ascending within each half of the block
- bool dir_up = (my_i & (stage >> 1)) == 0;
- // For descending sort: swap if wrong direction
- if (dir_up) {
- if (sv[my_i] < sv[jx]) {
- float tv = sv[my_i]; sv[my_i] = sv[jx]; sv[jx] = tv;
- int ti2 = si[my_i]; si[my_i] = si[jx]; si[jx] = ti2;
- }
- } else {
- if (sv[my_i] > sv[jx]) {
- float tv = sv[my_i]; sv[my_i] = sv[jx]; sv[jx] = tv;
- int ti2 = si[my_i]; si[my_i] = si[jx]; si[jx] = ti2;
- }
+ int mi = tid * ELEM + e;
+ int jx = mi ^ d;
+ if (jx > mi && ((mi & stage) != 0) != ((jx & stage) != 0)) {
+ bool dir_up = (mi & (stage >> 1)) == 0;
+ if ((dir_up && sv[mi] < sv[jx]) || (!dir_up && sv[mi] > sv[jx])) {
+ float tv = sv[mi]; sv[mi] = sv[jx]; sv[jx] = tv;
+ int ti2= si[mi]; si[mi] = si[jx]; si[jx] = ti2;
}
}
}
@@ -114,36 +68,93 @@
}
}
- // Now sv[0..K_OUT-1] are the largest K_OUT elements in descending order
- // (assuming SEG_N >= K_OUT, which it is because SEG_N >= ELEM >= 1 and K
... (12471 more chars){
"content": "<10883 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<1177 chars \u2014 see below>",
"description": "Compile and test bitonic sort v2",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
try:
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
except Exception as e:
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} ERROR')
import traceback; traceback.print_exc()
all_ok = False
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 22, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 322, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 294, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (2498 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -24,60 +24,6 @@
// =========================================================================
// Bitonic sort kernel: sorts N elements (power-of-2) in shared memory
// =========================================================================
-// Generic: N elements, B threads, E=ELEM=B/N elements per thread, K_OUT output.
-template <int N, int B, int ELEM, int K_OUT>
-__global__ void bitonic_topk(
- const float* __restrict__ in_v,
- const int* __restrict__ in_i,
- int64_t in_row_stride,
- float* __restrict__ out_v,
- int* __restrict__ out_i,
- int row_idx
-)
-{
- int tid = threadIdx.x;
-
- __shared__ float sv[N];
- __shared__ int si[N];
-
- // Load input
- #pragma unroll
- for (int e = 0; e < ELEM; ++e) {
- int gidx = tid * ELEM + e;
- sv[gidx] = in_v[in_row_stride * row_idx + gidx];
- si[gidx] = in_i ? in_i[in_row_stride * row_idx + gidx] : gidx;
- }
- __syncthreads();
-
- // Bitonic sort (descending)
- for (int stage = 2; stage <= N; stage <<= 1) {
- for (int d = stage >> 1; d > 0; d >>= 1) {
- #pragma unroll
- for (int e = 0; e < ELEM; ++e) {
- int mi = tid * ELEM + e;
- int jx = mi ^ d;
- if (jx > mi && ((mi & stage) != 0) != ((jx & stage) != 0)) {
- bool dir_up = (mi & (stage >> 1)) == 0;
- if ((dir_up && sv[mi] < sv[jx]) || (!dir_up && sv[mi] > sv[jx])) {
- float tv = sv[mi]; sv[mi] = sv[jx]; sv[jx] = tv;
- int ti2= si[mi]; si[mi] = si[jx]; si[jx] = ti2;
- }
- }
- }
- __syncthreads();
- }
- }
-
- // Write top K_OUT
- if (tid < K_OUT) {
- out_v[out_base(row_idx) + tid] = sv[tid];
- out_i[out_base(row_idx) + tid] = si[tid];
- }
-}
-
-// I need a different approach since out_base depends on the kernel.
-// Let me just use direct parameters instead.
-
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<192 chars \u2014 see below>",
"oldString": "<1916 chars \u2014 see below>"
}// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================// Generic: N elements, B threads, E=ELEM=B/N elements per thread, K_OUT output.
// Generic bitonic top-K of one row.
// Launch: one block of B threads per call; grid/row selection is done by the
// caller via `row_idx`. Shared memory: N floats + N ints (static).
// Preconditions: N is a power of two, ELEM == N / B, K_OUT <= N, and the input
// row holds at least N valid elements (caller pads with -inf if shorter —
// NOTE(review): padding is not visible here, confirm at the call site).
// Output layout: assumes out_v/out_i are densely packed rows of K_OUT entries
// (row r starts at r * K_OUT) — replaces the previously undefined out_base().
template <int N, int B, int ELEM, int K_OUT>
__global__ void bitonic_topk(
    const float* __restrict__ in_v,
    const int*   __restrict__ in_i,
    int64_t in_row_stride,
    float* __restrict__ out_v,
    int*   __restrict__ out_i,
    int row_idx
)
{
    static_assert((N & (N - 1)) == 0, "N must be a power of two");
    static_assert(ELEM * B == N, "ELEM must equal N / B");
    static_assert(K_OUT <= N, "cannot emit more than N candidates");

    int tid = threadIdx.x;

    __shared__ float sv[N];
    __shared__ int   si[N];

    // Load N candidates; when in_i is null this is a first pass and the
    // element's own position serves as its index.
    #pragma unroll
    for (int e = 0; e < ELEM; ++e) {
        int gidx = tid * ELEM + e;
        sv[gidx] = in_v[in_row_stride * row_idx + gidx];
        si[gidx] = in_i ? in_i[in_row_stride * row_idx + gidx] : gidx;
    }
    __syncthreads();

    // Batcher bitonic sort, descending. The original guard
    // ((mi & stage) != 0) != ((jx & stage) != 0) was always false (jx = mi ^ d
    // with d < stage never changes bit `stage`), so no swap ever executed;
    // the standard network only needs jx > mi with direction (mi & stage).
    for (int stage = 2; stage <= N; stage <<= 1) {
        for (int d = stage >> 1; d > 0; d >>= 1) {
            #pragma unroll
            for (int e = 0; e < ELEM; ++e) {
                int mi = tid * ELEM + e;
                int jx = mi ^ d;
                if (jx > mi) {
                    // (mi & stage) == 0 → descending run; at the final stage
                    // (stage == N) this holds for every mi, so the full array
                    // ends up sorted descending.
                    bool desc = (mi & stage) == 0;
                    if ((desc && sv[mi] < sv[jx]) || (!desc && sv[mi] > sv[jx])) {
                        float tv = sv[mi]; sv[mi] = sv[jx]; sv[jx] = tv;
                        int   ti = si[mi]; si[mi] = si[jx]; si[jx] = ti;
                    }
                }
            }
            // Pairs within one (stage, d) step are disjoint per thread, so a
            // barrier is only required between steps.
            __syncthreads();
        }
    }

    // sv[0..K_OUT-1] now hold the K_OUT largest values in descending order.
    if (tid < K_OUT) {
        out_v[(int64_t)row_idx * K_OUT + tid] = sv[tid];
        out_i[(int64_t)row_idx * K_OUT + tid] = si[tid];
    }
}
// I need a different approach since out_base depends on the kernel.
// Let me just use direct parameters instead.
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,17 +1,4 @@
static constexpr float NEG_INF = -1e30f;
-
-// ===== Bitonic compare-swap (descending, shared memory) =====================
-__device__ __forceinline__
-void cmp_swap(float* sv, int* si, int ixj_mask, int dir, int stage_mask) {
- int ixj = __laneid() ^ ixj_mask;
- int dir_x = (__laneid() & stage_mask) != 0;
- bool descending = (dir == dir_x);
- int i = __laneid();
- if ((descending && sv[i] < sv[ixj]) || (!descending && sv[i] > sv[ixj])) {
- float tmp_v = sv[i]; sv[i] = sv[ixj]; sv[ixj] = tmp_v;
- int tmp_i = si[i]; si[i] = si[ixj]; si[ixj] = tmp_i;
- }
-}
// =========================================================================
// Phase 1: segment-level bitonic sort
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<234 chars \u2014 see below>",
"oldString": "<777 chars \u2014 see below>"
}static constexpr float NEG_INF = -1e30f;
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================static constexpr float NEG_INF = -1e30f;
// ===== Bitonic compare-swap (descending, shared memory) =====================
// One compare-exchange step of a warp-wide bitonic network over sv/si.
// `ixj_mask` selects the partner lane, `stage_mask`+`dir` select the sort
// direction for this lane's subsequence.
// FIX: __laneid() is not a public CUDA C++ intrinsic (the lane id is the PTX
// special register %laneid); compute it from threadIdx.x instead. Assumes
// blockDim.x is a multiple of 32 — TODO confirm at launch sites.
// NOTE(review): both lanes of a pair reach this function and may both swap;
// the caller must gate so only one lane of each (i, ixj) pair executes the
// exchange (e.g. require ixj > i), otherwise the two writes race.
__device__ __forceinline__
void cmp_swap(float* sv, int* si, int ixj_mask, int dir, int stage_mask) {
    const int i     = threadIdx.x & 31;   // lane id within the warp
    const int ixj   = i ^ ixj_mask;       // partner element for this step
    const int dir_x = (i & stage_mask) != 0;
    const bool descending = (dir == dir_x);
    if ((descending && sv[i] < sv[ixj]) || (!descending && sv[i] > sv[ixj])) {
        float tmp_v = sv[i]; sv[i] = sv[ixj]; sv[ixj] = tmp_v;
        int   tmp_i = si[i]; si[i] = si[ixj]; si[ixj] = tmp_i;
    }
}
// =========================================================================
// Phase 1: segment-level bitonic sort
// =========================================================================Edit applied successfully.{
"command": "<955 chars \u2014 see below>",
"description": "Compile and test bitonic sort v3",
"timeout": 300000
}timeout 300 python3 -c "
import torch
import solution
device = torch.device('cuda:0')
test_shapes = [
{'batch': 1, 'n': 131072, 'k': 64},
{'batch': 64, 'n': 8192, 'k': 8},
{'batch': 32, 'n': 16384, 'k': 32},
{'batch': 16, 'n': 12000, 'k': 16},
{'batch': 128, 'n': 4096, 'k': 1},
]
torch.manual_seed(42)
all_ok = True
for s in test_shapes:
model = solution.Model(s['batch'], s['n'], s['k']).to(device).eval()
x = torch.randn(s['batch'], s['n'], device=device, dtype=torch.float32)
torch.cuda.synchronize()
with torch.no_grad():
vals, idxs = model(x)
torch.cuda.synchronize()
ref_vals, _ = torch.topk(x, s['k'], dim=-1, largest=True, sorted=True)
diff = (vals - ref_vals).abs().max().item()
ok = diff < 1e-4
all_ok = all_ok and ok
print(f' batch={s[\"batch\"]:3d} n={s[\"n\"]:6d} k={s[\"k\"]:2d} max_diff={diff:.2e} {\"OK\" if ok else \"FAIL\"}')
print(f'All passed: {all_ok}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 270, in forward
_ext = _compile_module()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 242, in _compile_module
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/utils/cpp
... (9270 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -228,8 +228,7 @@
else if (nc2 <= 512) P2_DISPATCH(512);
else if (nc2 <= 1024) P2_DISPATCH(1024);
else if (nc2 <= 2048) P2_DISPATCH(2048);
- else if (nc2 <= 4096) P2_DISPATCH(4096);
- else P2_DISPATCH(8192);
+ else P2_DISPATCH(4096);
return {values, indices};
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"newString": "<309 chars \u2014 see below>",
"oldString": "<354 chars \u2014 see below>"
} if (nc2 <= 64) P2_DISPATCH(64);
else if (nc2 <= 128) P2_DISPATCH(128);
else if (nc2 <= 256) P2_DISPATCH(256);
else if (nc2 <= 512) P2_DISPATCH(512);
else if (nc2 <= 1024) P2_DISPATCH(1024);
else if (nc2 <= 2048) P2_DISPATCH(2048);
else P2_DISPATCH(4096); if (nc2 <= 64) P2_DISPATCH(64);
else if (nc2 <= 128) P2_DISPATCH(128);
else if (nc2 <= 256) P2_DISPATCH(256);
else if (nc2 <= 512) P2_DISPATCH(512);
else if (nc2 <= 1024) P2_DISPATCH(1024);
else if (nc2 <= 2048) P2_DISPATCH(2048);
else if (nc2 <= 4096) P2_DISPATCH(4096);
else P2_DISPATCH(8192);Edit applied successfully.