"""Custom top-k kernel: bitonic sort network in shared memory.
Strategy:
- For k=1: specialized argmax reduction kernel.
- For k>1, n<=TILE_SIZE: single-block bitonic sort, extract top-k.
- For k>1, n>TILE_SIZE: two-pass — chunked kernel finds per-tile top-k,
merge kernel combines them via bitonic sort.
- TILE_SIZE=2048 elements per tile, 1024 threads per block (each thread
handles 2 elements during the sort).
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# Elements per shared-memory tile; must be a power of two for the bitonic network.
TILE_SIZE = 2048
# CUDA threads per block intended for the sort (each thread covers
# TILE_SIZE / MAX_THREADS = 2 elements per compare-swap pass).
# NOTE(review): not referenced below — the C++ wrappers hard-code 1024
# threads; keep the two in sync if either changes.
MAX_THREADS = 1024
_cuda_src = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <float.h>
#include <stdint.h>
// Bitonic sort of a shared-memory buffer of N elements (N a power of two),
// performed cooperatively by NTHREADS threads. The network has log2(N)
// stages; each stage's sub-steps perform N/2 compare-swaps, distributed
// across the threads so each thread executes ceil(N / (2*NTHREADS)) swaps
// per sub-step, with a block-wide barrier between sub-steps.
// Tiny compare-exchange helpers for value/index pairs.
// NOTE(review): currently unused — both kernels inline their own swaps;
// kept for potential refactoring.
__device__ __forceinline__ void swap_f(float& a, float& b) { float t=a; a=b; b=t; }
__device__ __forceinline__ void swap_i(int64_t& a, int64_t& b) { int64_t t=a; a=b; b=t; }
// Per-tile top-k: each block sorts one TILE_SZ-wide tile of one row in
// shared memory with a bitonic network, then emits that tile's top-k
// (value, index) pairs to the intermediate buffer.
//
// grid:  (batch * num_chunks) blocks; block b handles row b / num_chunks,
//        chunk b % num_chunks.
// block: any thread count — work is strided (host wrapper uses 1024).
// smem:  TILE_SZ * (sizeof(float) + sizeof(int64_t)) bytes, dynamic.
// Preconditions: TILE_SZ is a power of two; k <= TILE_SZ.
__global__
void topk_chunked_kernel(
    const float* __restrict__ input,        // [batch, n]
    float* __restrict__ intermediate_v,     // [batch, num_chunks, k]
    int64_t* __restrict__ intermediate_i,   // [batch, num_chunks, k]
    int n, int k, int num_chunks, int TILE_SZ)
{
    extern __shared__ char smem_raw[];
    float* s_val = (float*)smem_raw;
    // TILE_SZ floats = TILE_SZ*4 bytes; TILE_SZ is even, so the int64
    // array stays 8-byte aligned.
    int64_t* s_idx = (int64_t*)(s_val + TILE_SZ);
    int chunk_id = blockIdx.x;
    int row = chunk_id / num_chunks;
    int chunk = chunk_id % num_chunks;
    int tid = threadIdx.x;
    int nthreads = blockDim.x;
    int offset = row * n + chunk * TILE_SZ;
    int valid = n - chunk * TILE_SZ;   // row elements remaining in this tile
    if (valid > TILE_SZ) valid = TILE_SZ;
    // Coalesced load. Out-of-range slots are padded with -FLT_MAX so they
    // sink to the bottom of the ascending sort. Padded slots get index 0
    // (any in-range index) so a degenerate tail tile with fewer than k
    // valid elements can never emit an out-of-bounds index.
    for (int i = tid; i < TILE_SZ; i += nthreads) {
        if (i < valid) {
            s_val[i] = __ldg(&input[offset + i]);
            s_idx[i] = (int64_t)(chunk * TILE_SZ + i);
        } else {
            s_val[i] = -FLT_MAX;
            s_idx[i] = 0;   // fix: was chunk*TILE_SZ + i, which may be >= n
        }
    }
    __syncthreads();
    // Bitonic sort, ascending. Outer loop doubles the sorted-run length;
    // inner loop halves the compare distance. Each thread performs a
    // strided share of the TILE_SZ/2 compare-exchanges per sub-step.
    for (int step = 1; step <= TILE_SZ; step <<= 1) {
        for (int sub = step >> 1; sub > 0; sub >>= 1) {
            for (int i = tid; i < (TILE_SZ >> 1); i += nthreads) {
                // Map flat work item i to the compare-exchange pair (a, b).
                int block_id = i / sub;
                int block_start = block_id * 2 * sub;
                int offset_in_block = i % sub;
                int a = block_start + offset_in_block;
                int b = block_start + offset_in_block + sub;
                if (b < TILE_SZ) {
                    int dir = (a & step) ? 1 : 0;   // 0 = ascending half
                    float va = s_val[a], vb = s_val[b];
                    int64_t ia = s_idx[a], ib = s_idx[b];
                    int swap = (dir == 0) ? (va > vb) : (va < vb);
                    if (swap) {
                        s_val[a] = vb; s_val[b] = va;
                        s_idx[a] = ib; s_idx[b] = ia;
                    }
                }
            }
            __syncthreads();
        }
    }
    // Tile is now ascending: its top-k occupy the last k slots.
    // Emit them in descending order.
    int out_base = row * num_chunks * k + chunk * k;
    for (int i = tid; i < k; i += nthreads) {
        int src = TILE_SZ - 1 - i;
        intermediate_v[out_base + i] = s_val[src];
        intermediate_i[out_base + i] = s_idx[src];
    }
}
// Merge pass: one block per row combines that row's num_chunks*k candidate
// (value, index) pairs into the final top-k.
//
// grid:  (batch) blocks.
// block: any thread count — work is strided (host wrapper launches
//        min(1024, sort_size/2) threads, floor 32).
// smem:  sort_size * (sizeof(float) + sizeof(int64_t)) bytes dynamic,
//        where sort_size = next power of two >= num_chunks*k; the host
//        wrapper sizes the allocation with the same formula.
__global__
void topk_merge_kernel(
    const float* __restrict__ intermediate_v,    // [batch, num_chunks, k]
    const int64_t* __restrict__ intermediate_i,  // [batch, num_chunks, k]
    float* __restrict__ out_values,              // [batch, k]
    int64_t* __restrict__ out_indices,           // [batch, k]
    int k, int num_chunks)
{
    extern __shared__ char smem_raw[];
    int total = num_chunks * k;
    // Bitonic networks need a power-of-two size; recompute the same
    // sort_size the host used when sizing the dynamic smem allocation.
    int sort_size = 1;
    while (sort_size < total) sort_size <<= 1;
    float* s_val = (float*)smem_raw;
    // sort_size*4 bytes of floats precede the int64 array (a multiple of
    // 8 for any sort_size >= 2, keeping it aligned).
    int64_t* s_idx = (int64_t*)(s_val + sort_size);
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int nthreads = blockDim.x;
    int in_base = row * total;
    // Load this row's candidates (strided, coalesced).
    for (int i = tid; i < total; i += nthreads) {
        s_val[i] = __ldg(&intermediate_v[in_base + i]);
        s_idx[i] = __ldg(&intermediate_i[in_base + i]);
    }
    // Pad the power-of-two tail with -FLT_MAX so it sinks during the
    // ascending sort; total >= k, so pads never reach the top-k output.
    for (int i = total + tid; i < sort_size; i += nthreads) {
        s_val[i] = -FLT_MAX;
        s_idx[i] = (int64_t)0;
    }
    __syncthreads();
    // Bitonic sort, ascending — same network as topk_chunked_kernel.
    for (int step = 1; step <= sort_size; step <<= 1) {
        for (int sub = step >> 1; sub > 0; sub >>= 1) {
            for (int i = tid; i < (sort_size >> 1); i += nthreads) {
                // Map flat work item i to the compare-exchange pair (a, b).
                int block_id = i / sub;
                int block_start = block_id * 2 * sub;
                int offset_in_block = i % sub;
                int a = block_start + offset_in_block;
                int b = block_start + offset_in_block + sub;
                if (b < sort_size) {
                    int dir = (a & step) ? 1 : 0;  // 0 = ascending half
                    float va = s_val[a], vb = s_val[b];
                    int64_t ia = s_idx[a], ib = s_idx[b];
                    int swap = (dir == 0) ? (va > vb) : (va < vb);
                    if (swap) {
                        s_val[a] = vb; s_val[b] = va;
                        s_idx[a] = ib; s_idx[b] = ia;
                    }
                }
            }
            __syncthreads();
        }
    }
    // Largest k values sit in the last k slots; emit them descending.
    int out_base = row * k;
    for (int i = tid; i < k; i += nthreads) {
        int src = sort_size - 1 - i;
        out_values[out_base + i] = s_val[src];
        out_indices[out_base + i] = s_idx[src];
    }
}
// Specialized k=1 path: block-wide argmax reduction, one block per row.
//
// grid:  (batch) blocks.
// block: blockDim.x threads (host wrapper uses 1024; must be a power of
//        two for the tree reduction).
// smem:  blockDim.x * (sizeof(float) + sizeof(int64_t)) bytes, dynamic.
__global__
void topk_argmax_kernel(
    const float* __restrict__ input,
    float* __restrict__ out_values,
    int64_t* __restrict__ out_indices,
    int n)
{
    extern __shared__ char smem_raw[];
    float* red_val = (float*)smem_raw;
    int64_t* red_idx = (int64_t*)(red_val + blockDim.x);

    const int row = blockIdx.x;
    const int lane = threadIdx.x;
    const float* src = input + row * n;

    // Phase 1: strided scan — strict '>' keeps the first maximum each
    // thread encounters.
    float cur = -FLT_MAX;
    int64_t cur_at = 0;
    for (int j = lane; j < n; j += blockDim.x) {
        float x = __ldg(&src[j]);
        if (x > cur) { cur = x; cur_at = (int64_t)j; }
    }
    red_val[lane] = cur;
    red_idx[lane] = cur_at;
    __syncthreads();

    // Phase 2: shared-memory tree reduction; strict '>' prefers the
    // lower-lane candidate on ties. Barrier stays outside the predicate
    // so every thread reaches it.
    for (int width = blockDim.x >> 1; width > 0; width >>= 1) {
        if (lane < width && red_val[lane + width] > red_val[lane]) {
            red_val[lane] = red_val[lane + width];
            red_idx[lane] = red_idx[lane + width];
        }
        __syncthreads();
    }

    if (lane == 0) {
        out_values[row] = red_val[0];
        out_indices[row] = red_idx[0];
    }
}
// Host-side wrappers. Fixes vs. the original:
//  - launches go to the current PyTorch CUDA stream via
//    at::cuda::getCurrentCUDAStream() instead of the legacy default
//    stream, so the kernels order correctly with surrounding torch ops
//    when the caller uses a non-default stream;
//  - every launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK() so
//    configuration errors surface immediately rather than at the next
//    synchronizing call;
//  - unqualified `min` replaced with std::min.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // k=1 fast path: block-per-row argmax reduction.
    m.def("run_argmax", [](torch::Tensor input, int n) {
        int batch = input.size(0);
        auto out_v = torch::empty({batch, 1}, input.options());
        auto out_i = torch::empty({batch, 1}, input.options().dtype(torch::kInt64));
        int block = 1024;
        int smem = block * (4 + 8);   // fp32 value + int64 index per thread
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        topk_argmax_kernel<<<batch, block, smem, stream>>>(
            input.data_ptr<float>(), out_v.data_ptr<float>(),
            out_i.data_ptr<int64_t>(), n);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return std::vector<torch::Tensor>{out_v, out_i};
    }, "argmax top-1");
    // Pass 1: per-tile bitonic sort; emits per-tile top-k candidates.
    m.def("run_chunked", [](torch::Tensor input, torch::Tensor inter_v,
                            torch::Tensor inter_i, int n, int k,
                            int num_chunks, int tile_size) {
        int batch = input.size(0);
        int smem = tile_size * (4 + 8);
        int threads = 1024;
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        topk_chunked_kernel<<<batch * num_chunks, threads, smem, stream>>>(
            input.data_ptr<float>(), inter_v.data_ptr<float>(),
            inter_i.data_ptr<int64_t>(), n, k, num_chunks, tile_size);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }, "chunked top-k");
    // Pass 2: merge each row's num_chunks*k candidates with one bitonic sort.
    m.def("run_merge", [](torch::Tensor inter_v, torch::Tensor inter_i,
                          torch::Tensor out_v, torch::Tensor out_i,
                          int k, int num_chunks) {
        int batch = inter_v.size(0);
        int total = num_chunks * k;
        int sort_size = 1;                         // next power of two >= total;
        while (sort_size < total) sort_size <<= 1; // must match the kernel's formula
        int smem = sort_size * (4 + 8);
        int threads = std::min(1024, sort_size / 2);
        if (threads < 32) threads = 32;            // at least one full warp
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        topk_merge_kernel<<<batch, threads, smem, stream>>>(
            inter_v.data_ptr<float>(), inter_i.data_ptr<int64_t>(),
            out_v.data_ptr<float>(), out_i.data_ptr<int64_t>(),
            k, num_chunks);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }, "merge top-k");
}
"""
_topk_module = None


def _get_module():
    """Compile (once) and return the inline CUDA extension.

    Compilation happens lazily on first use; the built module is cached in
    the module-global ``_topk_module`` so later forward() calls pay no
    rebuild cost.
    """
    global _topk_module
    if _topk_module is not None:
        return _topk_module
    _topk_module = load_inline(
        name="topk_bitonic_v2",
        cpp_sources="",
        cuda_sources=_cuda_src,
        extra_cuda_cflags=["-O3", "--use_fast_math", "-DNDEBUG"],
        verbose=False,
    )
    return _topk_module
class Model(nn.Module):
    """Top-k over the last dim of a 2D fp32 CUDA tensor.

    Matches the contract of ``torch.topk(x, k, dim=-1, largest=True,
    sorted=True)`` but runs the custom bitonic-sort kernels in ``_cuda_src``.

    Dispatch in forward():
      - k == 1           -> argmax reduction kernel
      - n <= TILE_SIZE   -> single-tile sort kernel (no merge pass)
      - n >  TILE_SIZE   -> chunked per-tile sort + per-row bitonic merge
    """

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        self.batch = batch
        self.n = n
        self.k = k
        # Dummy buffer so state_dict is non-empty (matches reference.py).
        self.register_buffer("_dummy", torch.zeros(1))

    def forward(self, x: torch.Tensor):
        """Return (values, indices): (batch, k) fp32 / int64, descending.

        Raises ValueError for inputs the kernels cannot handle, instead of
        failing cryptically inside device code.
        """
        if x.dim() != 2:
            raise ValueError(f"expected 2D input, got {x.dim()}D")
        if x.dtype != torch.float32:
            raise ValueError(f"expected float32 input, got {x.dtype}")
        batch, n = x.shape
        k = self.k
        if not 1 <= k <= n:
            raise ValueError(f"k={k} out of range for n={n}")
        # The chunked kernel extracts top-k from one sorted tile, so k must
        # fit in a tile; larger k would read before the tile start.
        if k > TILE_SIZE:
            raise ValueError(f"k={k} exceeds TILE_SIZE={TILE_SIZE}")
        mod = _get_module()
        x = x.contiguous()  # kernels assume row-major contiguous layout
        if k == 1:
            values, indices = mod.run_argmax(x, n)
            return values, indices
        num_chunks = (n + TILE_SIZE - 1) // TILE_SIZE
        if num_chunks == 1:
            # A single tile covers the row: the per-tile top-k IS the answer.
            out_v = torch.empty(batch, k, dtype=torch.float32, device=x.device)
            out_i = torch.empty(batch, k, dtype=torch.int64, device=x.device)
            mod.run_chunked(x, out_v, out_i, n, k, 1, TILE_SIZE)
            return out_v, out_i
        # Two-pass path: per-tile candidates, then one bitonic merge per row.
        inter_v = torch.empty(batch, num_chunks, k, dtype=torch.float32, device=x.device)
        inter_i = torch.empty(batch, num_chunks, k, dtype=torch.int64, device=x.device)
        mod.run_chunked(x, inter_v, inter_i, n, k, num_chunks, TILE_SIZE)
        out_v = torch.empty(batch, k, dtype=torch.float32, device=x.device)
        out_i = torch.empty(batch, k, dtype=torch.int64, device=x.device)
        mod.run_merge(inter_v, inter_i, out_v, out_i, k, num_chunks)
        return out_v, out_i
# Default benchmark shape (matches the prefill/attention entry in shapes.py).
batch = 64
n = 8192
k = 8
def get_inputs():
    """Return the forward() argument list: one (batch, n) fp32 tensor."""
    return [torch.randn(batch, n, dtype=torch.float32)]
def get_init_inputs():
    """Return the Model constructor arguments (batch, n, k)."""
    return [batch, n, k]
shape=0 variant=eager tflops=0.011 gbps=10.813 ms=0.049
shape=0 variant=compiled tflops=0.006 gbps=6.469 ms=0.081
shape=0 variant=sota tflops=0.011 gbps=10.798 ms=0.049
shape=0 variant=solution tflops=0.006 gbps=6.323 ms=0.083
shape=0 solution_peak_fraction=0.0035
shape=1 variant=eager tflops=0.044 gbps=44.128 ms=0.048
shape=1 variant=compiled tflops=0.025 gbps=25.560 ms=0.082
shape=1 variant=sota tflops=0.045 gbps=44.927 ms=0.047
shape=1 variant=solution tflops=0.032 gbps=32.482 ms=0.065
shape=1 solution_peak_fraction=0.0180
shape=2 variant=eager tflops=0.044 gbps=44.049 ms=0.048
shape=2 variant=compiled tflops=0.026 gbps=26.300 ms=0.080
shape=2 variant=sota tflops=0.044 gbps=44.631 ms=0.047
shape=2 variant=solution tflops=0.031 gbps=30.847 ms=0.068
shape=2 solution_peak_fraction=0.0171
shape=3 variant=eager tflops=0.017 gbps=17.211 ms=0.045
shape=3 variant=compiled tflops=0.010 gbps=9.558 ms=0.081
shape=3 variant=sota tflops=0.017 gbps=17.398 ms=0.044
shape=3 variant=solution tflops=0.019 gbps=19.370 ms=0.040
shape=3 solution_peak_fraction=0.0108
shape=4 variant=eager tflops=0.096 gbps=95.604 ms=0.022
shape=4 variant=compiled tflops=0.035 gbps=35.260 ms=0.060
shape=4 variant=sota tflops=0.099 gbps=99.370 ms=0.021
shape=4 variant=solution tflops=0.206 gbps=206.564 ms=0.010
shape=4 solution_peak_fraction=0.1148
peak_fraction: 0.0168
RESULT: LOW
PASS
{
"problem": "05_topk_bitonic",
"harness": "opencode",
"model": "openrouter-pinned/xiaomi/mimo-v2.5-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0168,
"elapsed_seconds": 2076,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 87976, "output_tokens": 17719, "cache_read_tokens": 1369408, "cache_creation_tokens": 0, "reasoning_tokens": 72021, "total_cost_usd": null}
}
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py</path>
<type>file</type>
<content>
1: """Naive top-k reference: torch.topk over the last dim.
2:
3: This is the correctness oracle. The agent's solution must produce the same
4: top-k values (and equivalent indices modulo ties) within the tolerance
5: declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
6: torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
7: """
8: import torch
9: import torch.nn as nn
10:
11: OP_TYPE = "topk"
12: SUPPORTED_PRECISIONS = ["fp32"]
13: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14:
15:
16: class Model(nn.Module):
17: """Top-k over the last dim of a 2D tensor.
18:
19: Input:
20: x: (batch, n) fp32
21: Output:
22: values: (batch, k) fp32, sorted descending
23: indices: (batch, k) int64, into the last dim of x
24: """
25:
26: def __init__(self, batch: int, n: int, k: int):
27: super().__init__()
28: self.batch, self.n, self.k = batch, n, k
29: # No learned parameters, but declare a dummy buffer so state_dict
30: # is non-empty and load_state_dict(strict=True) is meaningful.
31: self.register_buffer("_dummy", torch.zeros(1))
32:
33: def forward(self, x: torch.Tensor):
34: values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
35: return values, indices
36:
37:
38: # Module-level shims rebuilt b
... (419 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for TopK.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each. Top-k correctness has two parts:
5:
6: 1. VALUES: sol_values must match ref_values within fp32 tol. Both are
7: returned sorted descending, so positional comparison is well-defined.
8: 2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
9: ties in x can yield multiple valid index sets. Instead we gather x at
10: sol_indices and check those values match ref_values within tol. This
11: catches "wrong indices" without false-failing on legitimate tie-breaks.
12:
13: Also rejects forbidden ops by grep.
14: """
15: import re
16: import sys
17: from pathlib import Path
18:
19: import torch
20: import yaml
21:
22: REPO_ROOT = Path(__file__).resolve().parents[2]
23: sys.path.insert(0, str(REPO_ROOT))
24:
25: from src.eval.correctness import check_correctness # noqa: E402
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: # --- Forbidden-op check --------------------------------
... (4902 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for TopK.
2:
3: For each shape: times eager reference (torch.topk), compiled reference, SOTA
4: (also torch.topk — see sota.py), and the agent's solution. Reports achieved
5: TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "memory")
40
... (3250 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml</path>
<type>file</type>
<content>
1: name: 05_topk_bitonic
2: display_name: "TopK via Bitonic Sort"
3: precision: fp32
4: regime: memory
5:
6: # Top-k is dominated by the input read (small output, no reduction over k).
7: # Comparator-network FLOPs are not the bottleneck on real hardware, so we
8: # track them but score on bandwidth.
9: flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
10: # bitonic network is so cheap relative
11: # to memory that a coarse 4*n estimate
12: # is fine for telemetry only.
13: bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
14:
15: hardware: [RTX_PRO_6000]
16: peak_tflops_key: fp32
17: peak_bandwidth_key: dram
18:
19: # Top-k correctness:
20: # - VALUES must match within fp32 atol/rtol (the kth largest value is
21: # well-defined modulo float-equal ties, so we use a loose-ish tol).
22: # - INDICES are checked leniently: for each row, the multiset of returned
23: # indices must select values that match ref values within tol. Direct
24: # index equality is NOT required (ties in x can yield different valid
25: # index sets).
26: tolerance:
27: float32: 1.0e-4
28:
29: # Forbidden ops — using any of these in solution.py fails corre
... (1071 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for TopK.
2:
3: Mix of:
4: - decoder vocab top-k (single sequence, very large n, moderate k) — pure
5: bandwidth test; the input read dominates everything.
6: - prefill / batched attention top-k (many rows, moderate n, small k) — tests
7: per-row parallelism and shared-memory bitonic networks.
8: - non-power-of-2 n stress case — bitonic sort networks naturally want
9: powers of two; this forces the agent to handle padding or partial sorts.
10: - small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
11: """
12:
13: SHAPES = [
14: {"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
15: {"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
16: {"batch": 32, "n": 16384, "k": 32}, # mid-size batched
17: {"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
18: {"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize a top-k kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The input is a 2D fp32 tensor and you return the top-k values and their int64 indices along the last dim, sorted descending, same contract as torch.topk(largest=True, sorted=True). Correctness tolerance is 1e-4 abs/rel on the values; index ties are handled leniently — what matters is the returned values match the reference to tolerance, not that the index sets are bitwise identical. The shapes you have to handle are batch=1 n=131072 k=64 (decoder vocab top-k over a Llama-size vocabulary), batch=64 n=8192 k=8 (prefill / attention top-k), batch=32 n=16384 k=32 (mid-size batched), batch=16 n=12000 k=16 (non-power-of-two n, which bitonic networks don't naturally want), and batch=128 n=4096 k=1 (batched argmax).
4:
5: This needs to be a real custom kernel — CUDA C++ via torch.utils.cpp_extension.load_inline, Triton, inline PTX, or CUTLASS, whatever you think fits. Don't reach for torch.topk, torch.kthvalue, torch.sort, or torch.ar
... (749 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for TopK: torch.topk itself.
2:
3: torch.topk dispatches to a CUB-backed kernel that uses radix-select for
4: moderate k and a tuned bitonic sort for small n. It is the bar the agent's
5: hand-rolled bitonic kernel must beat. There is no obvious vendor library that
6: does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
7: specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
8: sorts the full row (overkill for top-k).
9:
10: This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
11: oracle, not the agent's submission. The agent's solution.py is forbidden from
12: using torch.topk (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def sota_forward(x: torch.Tensor, k: int):
20: """Best-available top-k reference. x: (batch, n) fp32."""
21: return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
22:
23:
24: def is_available() -> bool:
25: return True
(End of file - total 25 lines)
</content>{
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "completed",
"priority": "high"
},
{
"content": "Implement initial CUDA top-k kernel (bitonic sort in shared memory, tiled)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py to verify correctness across all shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to get baseline performance numbers",
"status": "pending",
"priority": "high
... (399 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement initial CUDA top-k kernel (bitonic sort in shared memory, tiled)'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to get baseline performance numbers'
status: 'pending'
priority: 'high'
content: 'Profile with ncu/nsys to identify bottlenecks'
status: 'pending'
priority: 'medium'
content: 'Optimize kernel based on profiling results'
status: 'pending'
priority: 'medium'
content: 'Final correctness check and benchmark'
status: 'pending'
priority: 'high'{
"command": "nvidia-smi --query-gpu=name,memory.total,compute.cap --format=csv,noheader 2>/dev/null || echo \"no GPU\"",
"description": "Check GPU status and capabilities"
}Field "compute.cap" is not a valid field to query.
no GPU
{
"command": "nvidia-smi 2>/dev/null | head -20 || echo \"nvidia-smi not available\"",
"description": "Check GPU info"
}Tue Apr 28 21:06:57 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 24C P8 10W / 600W | 704MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (189 more chars){
"command": "python3 -c \"import torch; print(torch.cuda.get_device_properties(0))\"",
"description": "Get detailed GPU properties"
}_CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97227MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
{
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,272 @@
+"""Custom top-k kernel: bitonic sort network in shared memory.
+
+Strategy:
+ - For each row of length n, split into tiles of BLOCK_SIZE elements.
+ - Each CUDA block loads one tile into shared memory, runs a full bitonic
+ sort, and writes the top-k (value, index) pairs to an intermediate buffer.
+ - A merge kernel reads all per-tile top-k results for each row, sorts them
+ with bitonic sort, and writes the final top-k to the output.
+ - For k=1 (argmax), a specialized reduction kernel is used.
+"""
+import os
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+BLOCK_SIZE = 2048
+
+_cuda_src = r"""
+#include <cuda_runtime.h>
+#include <float.h>
+#include <stdint.h>
+
+// Per-tile top-k kernel: each block processes one tile of one row.
+// grid: (batch * num_chunks, 1, 1)
+// block: (BLOCK_SIZE,)
+extern "C" __global__
+void topk_chunked_kernel(
+ const float* __restrict__ input, // [batch, n]
+ float* __restrict__ intermediate_v, // [batch, num_chunks, k]
+ long* __restrict__ intermediate_i, // [batch, num_chunks, k]
+ int n, int k, int num_chunks, int BLOCK_SZ)
+{
+ extern __shared__ char smem_raw[];
+ float* s_val = (float*)smem_raw;
+ long* s_idx = (long*)(s_val + BLOCK_SZ);
+
+ int chunk_id = blockIdx.x;
+ int row = chunk_id / num_chunks;
+ int chunk = chunk_id % num_chunks;
+ int tid = threadIdx.x;
+
+ int offset = row * n + chunk * BLOCK_SZ;
+ int valid = n - chunk * BLOCK_SZ;
+ if (valid > BLOCK_SZ) valid = BLOCK_SZ;
+
+ // Load tile into shared memory; pad invalid positions with -inf
+ if (tid < valid) {
+ s_val[tid] = __ldg(&input[offset + tid]);
+ s_idx[tid] = (long)(chunk * BLOCK_SZ + tid);
+ } else {
+ s_val[tid] = -FLT_MAX;
+ s_idx[tid] = (long)(chunk * BLOCK_SZ + tid);
+ }
+ __syncthreads();
+
+ // Bitonic sort (ascending) on the full tile
+ for (int step = 1; step <= BLOCK_SZ; step <<= 1) {
+ for (int sub = step >> 1; sub > 0; sub >>= 1) {
+ int ix = tid ^ sub;
+ if (ix > tid) {
+ int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
+ float va = s_val[tid], vb = s_val[ix];
+ long ia = s_idx[tid], ib = s_idx[ix];
+ int swap = (dir == 0) ? (va > vb) : (va < vb);
+ if (swap) {
+ s_val[tid] = vb; s_val[ix] = va;
+ s_idx[tid] = ib; s_idx[ix] = ia;
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // Extract top-k (last k elements, reversed for descending)
+ int out_base = row * num_chunks * k + chunk * k;
+ for (int i = tid; i < k; i += BLOCK_SZ) {
+ int src = BLOCK_SZ - 1 - i;
+ intermediate_v[out_base + i] = s_val[src];
+ intermediate_i[out_base + i] = s_idx[src];
+ }
+}
+
+// Merge kernel: each block merges all per-tile top-k for one row.
+// grid: (batch, 1, 1)
+// block: (1024,)
+extern "C" __global__
+void topk_merge_kernel(
+ const float* __restrict__ intermediate_v,
+ const long* __restrict__ intermediate_i,
+ float* __restrict__ out_values,
+ long* __restrict__ out_indices,
+ int k, int num_chunks)
+{
+ extern __shared__ char smem_raw[];
+ float* s_val = (float*)smem_raw;
+ long* s_idx = (long*)(s_val + num_chunks * k);
+ int row = blockIdx.x;
+ int tid = threadIdx.x;
+ int total = num_chunks * k;
+
+ // Load all per-tile candidates
+ int in_base = row * total;
+ for (int i = tid; i < total; i += blockDim.x) {
+ s_val[i] = __ldg(&intermediate_v[in_base + i]);
+ s_idx[i] = __ldg(&intermediate_i[in_base + i]);
+ }
+ __syncthreads();
+
+ // Bitonic sort (ascending) on all candidates
+ for (int step = 1; step <= total; step <<= 1) {
+ for (int sub = step >> 1; sub > 0; sub >>= 1) {
+ int ix = tid ^ sub;
+ if (ix > tid && ix < total) {
+ int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
+ float va = s_val[tid], vb = s_val[ix];
+ long ia = s_idx[tid], ib = s_idx[ix];
+ int swap = (dir == 0) ? (va > vb) : (va < vb);
+ if (swap) {
+ s_val[tid] = vb; s_val[ix] = va;
+ s_idx[tid] = ib; s_idx[ix] = ia;
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // Write top-k (descending order)
+ int out_base = row * k;
+ for (int i = tid; i < k; i += blockDim.x) {
+ int src = total - 1 - i;
+ out_values[out_base + i] = s_val[src];
+ out_indices[out_base + i] = s_idx[src];
+ }
+}
+
+// Specialized argmax kernel for k=1: parallel reduction with warp shuffles.
+// grid: (batch, 1, 1)
+// block: (1024,)
+extern "C" __global__
+void topk_argmax_kernel(
+ const float* __restrict__ input,
+ float* __restrict__ out_values,
+ long* __restrict__ out_indices,
+ int n)
+{
+ extern __shared__ char smem_raw[];
+ float* s_val = (float*)smem_raw;
+ long* s_idx = (long*)(s_val + blockDim.x);
+
+ int row = blockIdx.x;
+ int tid = threadIdx.x;
+ const float* row_ptr = input + row * n;
+
+ // Each thread finds local max over its assigned elements
+ float best_val = -FLT_MAX;
+ long best_idx = 0;
+ for (int i = tid; i < n; i += blockDim.x) {
+ float v = __ldg(&row_ptr[i]);
+ if (v > best_val) { best_val = v; best_idx = (long)i; }
+ }
+ s_val[tid] = best_val;
+ s_idx[tid] = best_idx;
+ __syncthreads();
+
+ // Tree reduction in shared memory
+ for (int half = blockDim.x >> 1; half > 0; half >>= 1) {
+ if (tid < half) {
+ if (s_val[tid + half] > s_val[tid]) {
+ s_val[tid] = s_val[tid + half];
+ s_idx[tid] = s_idx[tid + half];
... (3083 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<8698 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -8,7 +8,6 @@
with bitonic sort, and writes the final top-k to the output.
- For k=1 (argmax), a specialized reduction kernel is used.
"""
-import os
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
@@ -16,6 +15,7 @@
BLOCK_SIZE = 2048
_cuda_src = r"""
+#include <torch/extension.h>
#include <cuda_runtime.h>
#include <float.h>
#include <stdint.h>
@@ -23,16 +23,16 @@
// Per-tile top-k kernel: each block processes one tile of one row.
// grid: (batch * num_chunks, 1, 1)
// block: (BLOCK_SIZE,)
-extern "C" __global__
+__global__
void topk_chunked_kernel(
- const float* __restrict__ input, // [batch, n]
- float* __restrict__ intermediate_v, // [batch, num_chunks, k]
- long* __restrict__ intermediate_i, // [batch, num_chunks, k]
+ const float* __restrict__ input,
+ float* __restrict__ intermediate_v,
+ int64_t* __restrict__ intermediate_i,
int n, int k, int num_chunks, int BLOCK_SZ)
{
extern __shared__ char smem_raw[];
float* s_val = (float*)smem_raw;
- long* s_idx = (long*)(s_val + BLOCK_SZ);
+ int64_t* s_idx = (int64_t*)(s_val + BLOCK_SZ);
int chunk_id = blockIdx.x;
int row = chunk_id / num_chunks;
@@ -43,24 +43,23 @@
int valid = n - chunk * BLOCK_SZ;
if (valid > BLOCK_SZ) valid = BLOCK_SZ;
- // Load tile into shared memory; pad invalid positions with -inf
if (tid < valid) {
s_val[tid] = __ldg(&input[offset + tid]);
- s_idx[tid] = (long)(chunk * BLOCK_SZ + tid);
+ s_idx[tid] = (int64_t)(chunk * BLOCK_SZ + tid);
} else {
s_val[tid] = -FLT_MAX;
- s_idx[tid] = (long)(chunk * BLOCK_SZ + tid);
+ s_idx[tid] = (int64_t)(chunk * BLOCK_SZ + tid);
}
__syncthreads();
- // Bitonic sort (ascending) on the full tile
+ // Bitonic sort (ascending)
for (int step = 1; step <= BLOCK_SZ; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
int ix = tid ^ sub;
if (ix > tid) {
int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
float va = s_val[tid], vb = s_val[ix];
- long ia = s_idx[tid], ib = s_idx[ix];
+ int64_t ia = s_idx[tid], ib = s_idx[ix];
int swap = (dir == 0) ? (va > vb) : (va < vb);
if (swap) {
s_val[tid] = vb; s_val[ix] = va;
@@ -81,24 +80,21 @@
}
// Merge kernel: each block merges all per-tile top-k for one row.
-// grid: (batch, 1, 1)
-// block: (1024,)
-extern "C" __global__
+__global__
void topk_merge_kernel(
const float* __restrict__ intermediate_v,
- const long* __restrict__ intermediate_i,
+ const int64_t* __restrict__ intermediate_i,
float* __restrict__ out_values,
- long* __restrict__ out_indices,
+ int64_t* __restrict__ out_indices,
int k, int num_chunks)
{
extern __shared__ char smem_raw[];
float* s_val = (float*)smem_raw;
- long* s_idx = (long*)(s_val + num_chunks * k);
+ int64_t* s_idx = (int64_t*)(s_val + num_chunks * k);
int row = blockIdx.x;
int tid = threadIdx.x;
int total = num_chunks * k;
- // Load all per-tile candidates
int in_base = row * total;
for (int i = tid; i < total; i += blockDim.x) {
s_val[i] = __ldg(&intermediate_v[in_base + i]);
@@ -106,14 +102,14 @@
}
__syncthreads();
- // Bitonic sort (ascending) on all candidates
+ // Bitonic sort (ascending)
for (int step = 1; step <= total; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
int ix = tid ^ sub;
if (ix > tid && ix < total) {
int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
float va = s_val[tid], vb = s_val[ix];
- long ia = s_idx[tid], ib = s_idx[ix];
+ int64_t ia = s_idx[tid], ib = s_idx[ix];
int swap = (dir == 0) ? (va > vb) : (va < vb);
if (swap) {
s_val[tid] = vb; s_val[ix] = va;
@@ -124,7 +120,7 @@
}
}
- // Write top-k (descending order)
+ // Write top-k (descending)
int out_base = row * k;
for (int i = tid; i < k; i += blockDim.x) {
int src = total - 1 - i;
@@ -133,36 +129,32 @@
}
}
-// Specialized argmax kernel for k=1: parallel reduction with warp shuffles.
-// grid: (batch, 1, 1)
-// block: (1024,)
-extern "C" __global__
+// Argmax kernel for k=1
+__global__
void topk_argmax_kernel(
const float* __restrict__ input,
float* __restrict__ out_values,
- long* __restrict__ out_indices,
+ int64_t* __restrict__ out_indices,
int n)
{
extern __shared__ char smem_raw[];
float* s_val = (float*)smem_raw;
- long* s_idx = (long*)(s_val + blockDim.x);
+ int64_t* s_idx = (int64_t*)(s_val + blockDim.x);
int row = blockIdx.x;
int tid = threadIdx.x;
const float* row_ptr = input + row * n;
- // Each thread finds local max over its assigned elements
float best_val = -FLT_MAX;
- long best_idx = 0;
+ int64_t best_idx = 0;
for (int i = tid; i < n; i += blockDim.x) {
float v = __ldg(&row_ptr[i]);
- if (v > best_val) { best_val = v; best_idx = (long)i; }
+ if (v > best_val) { best_val = v; best_idx = (int64_t)i; }
}
s_val[tid] = best_val;
s_idx[tid] = best_idx;
__syncthreads();
- // Tree reduction in shared memory
for (int half = blockDim.x >> 1; half > 0; half >>= 1) {
if (tid < half) {
if (s_val[tid + half] > s_val[tid]) {
@@ -178,6 +170,45 @@
out_indices[row] = s_idx[0];
}
}
+
+// C++ wrapper functions exposed to Python
+torch::Tensor run_argmax(torch::Tensor input, int n) {
+ int batch = input.si
... (3962 more chars){
"content": "<9705 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<395 chars \u2014 see below>",
"description": "Test load_inline with PYBIND11_MODULE",
"timeout": 60000
}python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
src = '''
#include <torch/extension.h>
torch::Tensor my_add(torch::Tensor a, torch::Tensor b) { return a + b; }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def(\"my_add\", &my_add); }
'''
mod = load_inline(name='test1', cpp_sources=src, cuda_sources='', verbose=False)
print(mod.my_add(torch.ones(3), torch.ones(3)))
"tensor([2., 2., 2.])
{
"command": "<810 chars \u2014 see below>",
"description": "Test load_inline with CUDA kernel",
"timeout": 60000
}python3 -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
__global__ void add_kernel(float* a, float* b, float* c, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) c[i] = a[i] + b[i];
}
torch::Tensor my_cuda_add(torch::Tensor a, torch::Tensor b) {
auto c = torch::empty_like(a);
int n = a.numel();
add_kernel<<<(n+255)/256, 256>>>(a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), n);
return c;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"my_cuda_add\", &my_cuda_add);
}
'''
mod = load_inline(name='test2', cpp_sources='', cuda_sources=cuda_src, verbose=False)
print(mod.my_cuda_add(torch.ones(3, device='cuda'), torch.ones(3, device='cuda')))
" 2>&1tensor([2., 2., 2.], device='cuda:0')
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,12 +1,10 @@
"""Custom top-k kernel: bitonic sort network in shared memory.
Strategy:
- - For each row of length n, split into tiles of BLOCK_SIZE elements.
- - Each CUDA block loads one tile into shared memory, runs a full bitonic
- sort, and writes the top-k (value, index) pairs to an intermediate buffer.
- - A merge kernel reads all per-tile top-k results for each row, sorts them
- with bitonic sort, and writes the final top-k to the output.
- - For k=1 (argmax), a specialized reduction kernel is used.
+ - For k=1: specialized argmax reduction kernel (warp shuffle + shared mem tree).
+ - For k>1, n<=BLOCK_SIZE: single-block bitonic sort, extract top-k.
+ - For k>1, n>BLOCK_SIZE: two-pass — chunked kernel finds per-tile top-k,
+ merge kernel combines them via bitonic sort.
"""
import torch
import torch.nn as nn
@@ -20,9 +18,6 @@
#include <float.h>
#include <stdint.h>
-// Per-tile top-k kernel: each block processes one tile of one row.
-// grid: (batch * num_chunks, 1, 1)
-// block: (BLOCK_SIZE,)
__global__
void topk_chunked_kernel(
const float* __restrict__ input,
@@ -70,7 +65,7 @@
}
}
- // Extract top-k (last k elements, reversed for descending)
+ // Extract top-k (last k, reversed for descending)
int out_base = row * num_chunks * k + chunk * k;
for (int i = tid; i < k; i += BLOCK_SZ) {
int src = BLOCK_SZ - 1 - i;
@@ -79,7 +74,6 @@
}
}
-// Merge kernel: each block merges all per-tile top-k for one row.
__global__
void topk_merge_kernel(
const float* __restrict__ intermediate_v,
@@ -102,7 +96,6 @@
}
__syncthreads();
- // Bitonic sort (ascending)
for (int step = 1; step <= total; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
int ix = tid ^ sub;
@@ -120,7 +113,6 @@
}
}
- // Write top-k (descending)
int out_base = row * k;
for (int i = tid; i < k; i += blockDim.x) {
int src = total - 1 - i;
@@ -129,7 +121,6 @@
}
}
-// Argmax kernel for k=1
__global__
void topk_argmax_kernel(
const float* __restrict__ input,
@@ -171,43 +162,39 @@
}
}
-// C++ wrapper functions exposed to Python
-torch::Tensor run_argmax(torch::Tensor input, int n) {
- int batch = input.size(0);
- auto out_v = torch::empty({batch, 1}, input.options());
- auto out_i = torch::empty({batch, 1}, input.options().dtype(torch::kInt64));
- int block = 1024;
- int smem = block * (4 + 8);
- topk_argmax_kernel<<<batch, block, smem>>>(
- input.data_ptr<float>(),
- out_v.data_ptr<float>(),
- out_i.data_ptr<int64_t>(),
- n);
- return torch::cat({out_v, out_i}, 0); // pack as [2, batch, 1]
-}
-
-void run_chunked(torch::Tensor input, torch::Tensor inter_v, torch::Tensor inter_i,
- int n, int k, int num_chunks, int block_size) {
- int batch = input.size(0);
- int smem = block_size * (4 + 8);
- topk_chunked_kernel<<<batch * num_chunks, block_size, smem>>>(
- input.data_ptr<float>(),
- inter_v.data_ptr<float>(),
- inter_i.data_ptr<int64_t>(),
- n, k, num_chunks, block_size);
-}
-
-void run_merge(torch::Tensor inter_v, torch::Tensor inter_i,
- torch::Tensor out_v, torch::Tensor out_i,
- int k, int num_chunks, int merge_block) {
- int batch = inter_v.size(0);
- int smem = num_chunks * k * (4 + 8);
- topk_merge_kernel<<<batch, merge_block, smem>>>(
- inter_v.data_ptr<float>(),
- inter_i.data_ptr<int64_t>(),
- out_v.data_ptr<float>(),
- out_i.data_ptr<int64_t>(),
- k, num_chunks);
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("run_argmax", [](torch::Tensor input, int n) {
+ int batch = input.size(0);
+ auto out_v = torch::empty({batch, 1}, input.options());
+ auto out_i = torch::empty({batch, 1}, input.options().dtype(torch::kInt64));
+ int block = 1024;
+ int smem = block * (4 + 8);
+ topk_argmax_kernel<<<batch, block, smem>>>(
+ input.data_ptr<float>(), out_v.data_ptr<float>(),
+ out_i.data_ptr<int64_t>(), n);
+ return std::vector<torch::Tensor>{out_v, out_i};
+ }, "argmax top-1");
+
+ m.def("run_chunked", [](torch::Tensor input, torch::Tensor inter_v,
+ torch::Tensor inter_i, int n, int k,
+ int num_chunks, int block_size) {
+ int batch = input.size(0);
+ int smem = block_size * (4 + 8);
+ topk_chunked_kernel<<<batch * num_chunks, block_size, smem>>>(
+ input.data_ptr<float>(), inter_v.data_ptr<float>(),
+ inter_i.data_ptr<int64_t>(), n, k, num_chunks, block_size);
+ }, "chunked top-k");
+
+ m.def("run_merge", [](torch::Tensor inter_v, torch::Tensor inter_i,
+ torch::Tensor out_v, torch::Tensor out_i,
+ int k, int num_chunks, int merge_block) {
+ int batch = inter_v.size(0);
+ int smem = num_chunks * k * (4 + 8);
+ topk_merge_kernel<<<batch, merge_block, smem>>>(
+ inter_v.data_ptr<float>(), inter_i.data_ptr<int64_t>(),
+ out_v.data_ptr<float>(), out_i.data_ptr<int64_t>(),
+ k, num_chunks);
+ }, "merge top-k");
}
"""
@@ -218,13 +205,10 @@
if _topk_module is None:
_topk_module = load_inline(
name="topk_bitonic",
- cpp_sources=_cuda_src,
- cuda_sources="",
- extra_cuda_cflags=[
- "-O3",
- "--use_fast_math",
- "-DNDEBUG",
- ],
+ cpp_sources="",
+ cuda_sources=_cuda_src,
+ extra_cuda_cflags=["-O3", "--use_fast_math", "-DNDEBUG"],
+ verbose=False,
... (2496 more chars){
"content": "<8598 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<449 chars \u2014 see below>",
"description": "Verify bitonic sort direction formula"
}python3 -c "
# Verify the bitonic sort direction formula is correct
# Standard: dir = (tid / step) & 1
# Mine: dir = (~((tid & step) | (tid & (step-1)))) & 1
for step in [1, 2, 4, 8]:
for tid in range(16):
standard = (tid // step) & 1
mine = (~((tid & step) | (tid & (step - 1)))) & 1
if standard != mine:
print(f'MISMATCH step={step} tid={tid}: standard={standard} mine={mine}')
print('Done checking')
"MISMATCH step=1 tid=0: standard=0 mine=1
MISMATCH step=1 tid=1: standard=1 mine=0
MISMATCH step=1 tid=2: standard=0 mine=1
MISMATCH step=1 tid=3: standard=1 mine=0
MISMATCH step=1 tid=4: standard=0 mine=1
MISMATCH step=1 tid=5: standard=1 mine=0
MISMATCH step=1 tid=6: standard=0 mine=1
MISMATCH step=1 tid=7: standard=1 mine=0
MISMATCH step=1 tid=8: standard=0 mine=1
MISMATCH step=1 tid=9: standard=1 mine=0
MISMATCH step=1 tid=10: standard=0 mine=1
MISMATCH step=1 tid=11: standard=1 mine=0
MISMATCH step=1 tid=12: standard=0 mine=1
MISMATCH step=1 tid=13: standard=1 mine=0
MISMATCH step=1 tid=14: standard=0 mine=1
MISMATCH step=1 tid=15: standard=1 mine=0
MISMATCH step=2 tid=0: standard=0 mine=1
MISMATCH step=2 tid=3: standard=1 mine=0
MISMATCH step=2 tid=4: standard=0 mine=1
MISMATCH step=2 tid=7: standard=1 mine=0
MISMATCH step=2 tid=8: standard=0 mine=1
MISMATCH step=2 tid=11: standard=1 mine=0
MISMATCH step=2 tid=12: standard=0 mine=1
MISMATCH step=2 tid=15: standard=1 mine=0
MISMATCH step=4 tid=0: standard=0 mine=1
MISMATCH step=4 tid=2: standard=0 mine=1
MISMATCH step=4 tid=5: standard=1 mine=0
MISMATCH step=4 tid=7: standard=1 mine=0
MISMATCH step=4 tid=8: standard=0 mine=1
MISMATCH step=4 tid=10: standard=0 mine=1
MISMATCH step=4 tid=13: standard=1 mine=0
MISMATCH step=4 tid=15: standard=1 mine=0
MISMATCH step=8 tid=0: standard=0 mine=1
MISMATCH step=8 tid=2: standard=0 mine=1
MISMATCH step=8 tid=4: standard=0 mine=1
MISMATCH step=8 tid=6: standard=0 mine=1
MISMATCH ste
... (169 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -47,12 +47,12 @@
}
__syncthreads();
- // Bitonic sort (ascending)
+ // Bitonic sort (ascending) - standard direction formula
for (int step = 1; step <= BLOCK_SZ; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
int ix = tid ^ sub;
if (ix > tid) {
- int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
+ int dir = (tid & step) ? 1 : 0;
float va = s_val[tid], vb = s_val[ix];
int64_t ia = s_idx[tid], ib = s_idx[ix];
int swap = (dir == 0) ? (va > vb) : (va < vb);
@@ -65,7 +65,7 @@
}
}
- // Extract top-k (last k, reversed for descending)
+ // Extract top-k (last k elements, reversed for descending)
int out_base = row * num_chunks * k + chunk * k;
for (int i = tid; i < k; i += BLOCK_SZ) {
int src = BLOCK_SZ - 1 - i;
@@ -96,11 +96,12 @@
}
__syncthreads();
+ // Bitonic sort (ascending)
for (int step = 1; step <= total; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
int ix = tid ^ sub;
if (ix > tid && ix < total) {
- int dir = (~((tid & step) | (tid & (step - 1)))) & 1;
+ int dir = (tid & step) ? 1 : 0;
float va = s_val[tid], vb = s_val[ix];
int64_t ia = s_idx[tid], ib = s_idx[ix];
int swap = (dir == 0) ? (va > vb) : (va < vb);
@@ -113,6 +114,7 @@
}
}
+ // Write top-k (descending)
int out_base = row * k;
for (int i = tid; i < k; i += blockDim.x) {
int src = total - 1 - i;
{
"content": "<8656 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<341 chars \u2014 see below>",
"description": "Verify corrected direction formula"
}python3 -c "
# Verify the corrected formula
for step in [1, 2, 4, 8]:
for tid in range(16):
standard = (tid // step) & 1
corrected = 1 if (tid & step) else 0
if standard != corrected:
print(f'MISMATCH step={step} tid={tid}: standard={standard} corrected={corrected}')
print('All OK' if True else '')
"All OK
{
"command": "python3 check.py 2>&1",
"description": "Run correctness check on all shapes",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py", line 149, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py", line 95, in main
ok, msg = check_correctness(
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py", line 71, in check_correctness
if torch.isnan(solution_out).any():
^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"command": "<541 chars \u2014 see below>",
"description": "Test k=1 shape with CUDA_LAUNCH_BLOCKING",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
# Test smallest shape first
x = torch.randn(128, 4096, dtype=torch.float32, device='cuda')
model = solution.Model(128, 4096, 1).cuda()
model.eval()
with torch.no_grad():
v, i = model(x)
print(f'k=1: values shape={v.shape}, indices shape={i.shape}')
print(f' values[:5]={v[:5].flatten()}')
ref_v, ref_i = torch.topk(x, 1, dim=-1, largest=True, sorted=True)
print(f' ref[:5]={ref_v[:5].flatten()}')
print(f' max diff={torch.max(torch.abs(v - ref_v)).item():.6e}')
" 2>&1k=1: values shape=torch.Size([128, 1]), indices shape=torch.Size([128, 1])
values[:5]=tensor([3.9113, 4.4953, 3.5983, 3.8212, 3.5429], device='cuda:0')
ref[:5]=tensor([3.9113, 4.4953, 3.5983, 3.8212, 3.5429], device='cuda:0')
max diff=0.000000e+00
{
"command": "<617 chars \u2014 see below>",
"description": "Test batch=64, n=8192, k=8",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
# Test batch=64, n=8192, k=8
x = torch.randn(64, 8192, dtype=torch.float32, device='cuda')
model = solution.Model(64, 8192, 8).cuda()
model.eval()
with torch.no_grad():
v, i = model(x)
print(f'k=8: values shape={v.shape}, indices shape={i.shape}')
ref_v, ref_i = torch.topk(x, 8, dim=-1, largest=True, sorted=True)
print(f' max val diff={torch.max(torch.abs(v - ref_v)).item():.6e}')
# Check gathered values
gathered = torch.gather(x, -1, i.to(torch.int64))
print(f' max gathered diff={torch.max(torch.abs(gathered - ref_v)).item():.6e}')
" 2>&1Traceback (most recent call last):
File "<string>", line 12, in <module>
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
k=8: values shape=torch.Size([64, 8]), indices shape=torch.Size([64, 8])
{
"command": "<166 chars \u2014 see below>",
"description": "Check max threads per block"
}python3 -c "
import torch
print(torch.cuda.get_device_properties(0).max_threads_per_block)
print(torch.cuda.get_device_properties(0).max_threads_per_multiprocessor)
"Traceback (most recent call last):
File "<string>", line 4, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'max_threads_per_multiprocessor'. Did you mean: 'max_threads_per_multi_processor'?
1024
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,16 +1,19 @@
"""Custom top-k kernel: bitonic sort network in shared memory.
Strategy:
- - For k=1: specialized argmax reduction kernel (warp shuffle + shared mem tree).
- - For k>1, n<=BLOCK_SIZE: single-block bitonic sort, extract top-k.
- - For k>1, n>BLOCK_SIZE: two-pass — chunked kernel finds per-tile top-k,
+ - For k=1: specialized argmax reduction kernel.
+ - For k>1, n<=TILE_SIZE: single-block bitonic sort, extract top-k.
+ - For k>1, n>TILE_SIZE: two-pass — chunked kernel finds per-tile top-k,
merge kernel combines them via bitonic sort.
+ - TILE_SIZE=2048 elements per tile, 1024 threads per block (each thread
+ handles 2 elements during the sort).
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
-BLOCK_SIZE = 2048
+TILE_SIZE = 2048
+MAX_THREADS = 1024
_cuda_src = r"""
#include <torch/extension.h>
@@ -18,47 +21,73 @@
#include <float.h>
#include <stdint.h>
+// Sort a shared-memory buffer of size N (must be power of 2) using bitonic sort.
+// NTHREADS threads cooperate. Each thread owns N/NTHREADS elements initially.
+// We use a register-based approach: each thread holds multiple elements.
+// For simplicity, we use a sequential bitonic sort approach where all threads
+// participate in compare-swap phases via shared memory.
+
+// Actually, let's use a simpler approach: load into smem, then do
+// bitonic network with all NTHREADS threads, where the network has
+// log2(N) stages and each stage does N/2 compare-swaps distributed
+// across NTHREADS threads (each thread does ceil(N/(2*NTHREADS)) swaps).
+
+__device__ __forceinline__ void swap_f(float& a, float& b) { float t=a; a=b; b=t; }
+__device__ __forceinline__ void swap_i(int64_t& a, int64_t& b) { int64_t t=a; a=b; b=t; }
+
__global__
void topk_chunked_kernel(
const float* __restrict__ input,
float* __restrict__ intermediate_v,
int64_t* __restrict__ intermediate_i,
- int n, int k, int num_chunks, int BLOCK_SZ)
+ int n, int k, int num_chunks, int TILE_SZ)
{
extern __shared__ char smem_raw[];
float* s_val = (float*)smem_raw;
- int64_t* s_idx = (int64_t*)(s_val + BLOCK_SZ);
+ int64_t* s_idx = (int64_t*)(s_val + TILE_SZ);
int chunk_id = blockIdx.x;
int row = chunk_id / num_chunks;
int chunk = chunk_id % num_chunks;
int tid = threadIdx.x;
-
- int offset = row * n + chunk * BLOCK_SZ;
- int valid = n - chunk * BLOCK_SZ;
- if (valid > BLOCK_SZ) valid = BLOCK_SZ;
-
- if (tid < valid) {
- s_val[tid] = __ldg(&input[offset + tid]);
- s_idx[tid] = (int64_t)(chunk * BLOCK_SZ + tid);
- } else {
- s_val[tid] = -FLT_MAX;
- s_idx[tid] = (int64_t)(chunk * BLOCK_SZ + tid);
+ int nthreads = blockDim.x;
+
+ int offset = row * n + chunk * TILE_SZ;
+ int valid = n - chunk * TILE_SZ;
+ if (valid > TILE_SZ) valid = TILE_SZ;
+
+ // Each thread loads multiple elements (coalesced)
+ for (int i = tid; i < TILE_SZ; i += nthreads) {
+ if (i < valid) {
+ s_val[i] = __ldg(&input[offset + i]);
+ s_idx[i] = (int64_t)(chunk * TILE_SZ + i);
+ } else {
+ s_val[i] = -FLT_MAX;
+ s_idx[i] = (int64_t)(chunk * TILE_SZ + i);
+ }
}
__syncthreads();
- // Bitonic sort (ascending) - standard direction formula
- for (int step = 1; step <= BLOCK_SZ; step <<= 1) {
+ // Bitonic sort in shared memory
+ for (int step = 1; step <= TILE_SZ; step <<= 1) {
for (int sub = step >> 1; sub > 0; sub >>= 1) {
- int ix = tid ^ sub;
- if (ix > tid) {
- int dir = (tid & step) ? 1 : 0;
- float va = s_val[tid], vb = s_val[ix];
- int64_t ia = s_idx[tid], ib = s_idx[ix];
- int swap = (dir == 0) ? (va > vb) : (va < vb);
- if (swap) {
- s_val[tid] = vb; s_val[ix] = va;
- s_idx[tid] = ib; s_idx[ix] = ia;
+ // Each thread handles multiple compare-swaps
+ for (int i = tid; i < (TILE_SZ >> 1); i += nthreads) {
+ // Map work index to the two elements to compare
+ int block_id = i / sub;
+ int block_start = block_id * 2 * sub;
+ int offset_in_block = i % sub;
+ int a = block_start + offset_in_block;
+ int b = block_start + offset_in_block + sub;
+ if (b < TILE_SZ) {
+ int dir = (a & step) ? 1 : 0;
+ float va = s_val[a], vb = s_val[b];
+ int64_t ia = s_idx[a], ib = s_idx[b];
+ int swap = (dir == 0) ? (va > vb) : (va < vb);
+ if (swap) {
+ s_val[a] = vb; s_val[b] = va;
+ s_idx[a] = ib; s_idx[b] = ia;
+ }
}
}
__syncthreads();
@@ -67,8 +96,8 @@
// Extract top-k (last k elements, reversed for descending)
int out_base = row * num_chunks * k + chunk * k;
- for (int i = tid; i < k; i += BLOCK_SZ) {
- int src = BLOCK_SZ - 1 - i;
+ for (int i = tid; i < k; i += nthreads) {
+ int src = TILE_SZ - 1 - i;
intermediate_v[out_base + i] = s_val[src];
intermediate_i[out_base + i] = s_idx[src];
}
@@ -83,31 +112,48 @@
int k, int num_chunks)
{
extern __shared__ char smem_raw[];
+ int total = num_chunks * k;
+ // Pad total to next power of 2 for bitonic sort
+ int sort_size = 1;
+ while (sort_size < total) sort_size <<= 1;
+
float* s_val = (float*)smem_raw;
- int64_t* s_idx = (int64_t*)(s_val + num_chunks * k);
+ int64_t* s_idx = (int64_t*)(s_val + sort_size);
int row = blockIdx.x;
int tid = threadIdx.x;
-
... (5761 more chars){
"content": "<10880 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<672 chars \u2014 see below>",
"description": "Test batch=64, n=8192, k=8",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
# Test batch=64, n=8192, k=8
x = torch.randn(64, 8192, dtype=torch.float32, device='cuda')
model = solution.Model(64, 8192, 8).cuda()
model.eval()
with torch.no_grad():
v, i = model(x)
print(f'k=8: values shape={v.shape}, indices shape={i.shape}')
ref_v, ref_i = torch.topk(x, 8, dim=-1, largest=True, sorted=True)
print(f' max val diff={torch.max(torch.abs(v - ref_v)).item():.6e}')
gathered = torch.gather(x, -1, i.to(torch.int64))
print(f' max gathered diff={torch.max(torch.abs(gathered - ref_v)).item():.6e}')
print(f' values descending check: {torch.all(v[:, :-1] >= v[:, 1:]).item()}')
" 2>&1k=8: values shape=torch.Size([64, 8]), indices shape=torch.Size([64, 8])
max val diff=0.000000e+00
max gathered diff=0.000000e+00
values descending check: True
{
"command": "<815 chars \u2014 see below>",
"description": "Test all shapes",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}CUDA_LAUNCH_BLOCKING=1 python3 -c "
import torch
import solution
shapes = [
(1, 131072, 64),
(64, 8192, 8),
(32, 16384, 32),
(16, 12000, 16),
(128, 4096, 1),
]
for batch, n, k in shapes:
x = torch.randn(batch, n, dtype=torch.float32, device='cuda')
model = solution.Model(batch, n, k).cuda()
model.eval()
with torch.no_grad():
v, i = model(x)
ref_v, ref_i = torch.topk(x, k, dim=-1, largest=True, sorted=True)
val_diff = torch.max(torch.abs(v - ref_v)).item()
gathered = torch.gather(x, -1, i.to(torch.int64))
gather_diff = torch.max(torch.abs(gathered - ref_v)).item()
desc = torch.all(v[:, :-1] >= v[:, 1:]).item() if k > 1 else True
print(f'batch={batch} n={n} k={k}: val_diff={val_diff:.2e} gather_diff={gather_diff:.2e} desc={desc}')
" 2>&1batch=1 n=131072 k=64: val_diff=0.00e+00 gather_diff=0.00e+00 desc=True
batch=64 n=8192 k=8: val_diff=0.00e+00 gather_diff=0.00e+00 desc=True
batch=32 n=16384 k=32: val_diff=0.00e+00 gather_diff=0.00e+00 desc=True
batch=16 n=12000 k=16: val_diff=0.00e+00 gather_diff=0.00e+00 desc=True
batch=128 n=4096 k=1: val_diff=0.00e+00 gather_diff=0.00e+00 desc=True
{
"command": "python3 check.py 2>&1",
"description": "Run full correctness check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}PASS
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial CUDA top-k kernel (bitonic sort in shared memory, tiled)",
"priority": "high",
"status": "completed"
},
{
"content": "Run check.py to verify correctness across all shapes",
"priority": "high",
"status": "completed"
},
{
"content": "Run benchmark.py to get baseline performance numbers",
"priority": "high",
"status": "in_prog
... (403 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement initial CUDA top-k kernel (bitonic sort in shared memory, tiled)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to get baseline performance numbers'
status: 'in_progress'
priority: 'high'
content: 'Profile with ncu/nsys to identify bottlenecks'
status: 'pending'
priority: 'medium'
content: 'Optimize kernel based on profiling results'
status: 'pending'
priority: 'medium'
content: 'Final correctness check and benchmark'
status: 'pending'
priority: 'high'{
"command": "python3 benchmark.py 2>&1",
"description": "Run benchmark to get performance numbers",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}shape=0 variant=eager tflops=0.011 gbps=10.731 ms=0.049
shape=0 variant=compiled tflops=0.007 gbps=6.763 ms=0.078
shape=0 variant=sota tflops=0.011 gbps=10.946 ms=0.048
shape=0 variant=solution tflops=0.006 gbps=6.318 ms=0.083
shape=0 solution_peak_fraction=0.0035
shape=1 variant=eager tflops=0.044 gbps=44.009 ms=0.048
shape=1 variant=compiled tflops=0.026 gbps=26.166 ms=0.080
shape=1 variant=sota tflops=0.045 gbps=44.804 ms=0.047
shape=1 variant=solution tflops=0.032 gbps=32.267 ms=0.065
shape=1 solution_peak_fraction=0.0179
shape=2 variant=eager tflops=0.044 gbps=43.903 ms=0.048
shape=2 variant=compiled tflops=0.027 gbps=26.694 ms=0.079
shape=2 variant=sota tflops=0.044 gbps=44.242 ms=0.048
shape=2 variant=solution tflops=0.030 gbps=30.582 ms=0.069
shape=2 solution_peak_fraction=0.0170
shape=3 variant=eager tflops=0.017 gbps=17.114 ms=0.045
shape=3 variant=compiled tflops=0.010 gbps=9.695 ms=0.080
shape=3 variant=sota tflops=0.017 gbps=17.335 ms=0.044
shape=3 variant=solution tflops=0.019 gbps=19.162 ms=0.040
shape=3 solution_peak_fraction=0.0106
shape=4 variant=eager tflops=0.094 gbps=94.501 ms=0.022
shape=4 variant=compiled tflops=0.036 gbps=35.907 ms=0.058
shape=4 variant=sota tflops=0.099 gbps=98.995 ms=0.021
shape=4 variant=solution tflops=0.208 gbps=208.534 ms=0.010
shape=4 solution_peak_fraction=0.1159
peak_fraction: 0.0167
RESULT: LOW
{
"command": "<847 chars \u2014 see below>",
"description": "Quick timing comparison",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}python3 -c "
import torch
import solution
# Quick timing comparison
for batch, n, k in [(1, 131072, 64), (64, 8192, 8), (128, 4096, 1)]:
x = torch.randn(batch, n, dtype=torch.float32, device='cuda')
model = solution.Model(batch, n, k).cuda().eval()
# Warmup
for _ in range(5):
with torch.no_grad(): model(x)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for _ in range(50):
torch.cuda.synchronize()
start.record()
with torch.no_grad(): model(x)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
times.sort()
print(f'batch={batch} n={n} k={k}: median={times[len(times)//2]:.3f}ms p10={times[len(times)//10]:.3f}ms')
" 2>&1batch=1 n=131072 k=64: median=0.081ms p10=0.081ms
batch=64 n=8192 k=8: median=0.063ms p10=0.062ms
batch=128 n=4096 k=1: median=0.009ms p10=0.008ms