"""TopK custom kernel via load_inline (v5).
Strategy:
- Stage 1: full bitonic sort of small chunks in shared memory; take
chunk-local top-K. Stage 1 has many small chunks for SM parallelism.
- Stage 2: each stage-1 chunk emits a sorted-descending run of K
elements. We merge G sorted runs of length K into the global top-K
using a tournament-style argmax: per iteration, one thread per run
holds its current head value; warp+block argmax picks the run with
the largest head; that run advances by 1. K iterations -> top-K.
This is O(K * log G) instead of bitonic-sort's O(G*K * log^2(G*K));
a pure-Python sketch of this merge appears just after the imports below.
- Shape (128, 4096, 1) is k=1, i.e. a per-row argmax, handled by a
dedicated max-reduction kernel.
"""
import os
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.0")
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
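# A minimal pure-Python sketch (illustration only, never called at runtime) of
# the stage-2 tournament merge described in the module docstring: `runs` is a
# list of G lists, each sorted descending with K entries. The helper name and
# signature exist purely for exposition and are not part of the kernel interface.
def _merge_runs_sketch(runs, k):
    heads = [0] * len(runs)  # per-run read pointer ("head")
    out = []
    for _ in range(k):
        # Argmax over the current heads; exhausted runs contribute -inf.
        w = max(range(len(runs)),
                key=lambda r: runs[r][heads[r]] if heads[r] < len(runs[r]) else float("-inf"))
        out.append(runs[w][heads[w]])
        heads[w] += 1  # only the winning run advances
    return out
# e.g. _merge_runs_sketch([[9, 4], [8, 7], [6, 5]], 2) == [9, 8]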
CPP_SRC = r"""
std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k);
"""
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#define NEG_INF (-3.4e38f)
// Block-wide bitonic sort, ascending. N must be a power of 2.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
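// Compare-exchange network: element t is paired with partner t ^ j; the
// exchange direction flips with bit (t & sz), so each outer pass builds
// larger bitonic sequences and the inner (shrinking-j) passes merge them.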
#pragma unroll
for (int sz = 2; sz <= N; sz <<= 1) {
#pragma unroll
for (int j = sz >> 1; j > 0; j >>= 1) {
#pragma unroll
for (int t0 = 0; t0 < N; t0 += THREADS) {
int t = t0 + threadIdx.x;
if (t < N) {
int p = t ^ j;
if (p > t) {
bool asc = (t & sz) == 0;
float vt = svals[t], vp = svals[p];
bool sw = asc ? (vt > vp) : (vt < vp);
if (sw) {
svals[t] = vp; svals[p] = vt;
int it = sidxs[t], ip = sidxs[p];
sidxs[t] = ip; sidxs[p] = it;
}
}
}
}
__syncthreads();
}
}
}
// Stage 1 / single-stage chunk kernel: load chunk_size elements into shared
// memory (padded to N_PAD with -inf), bitonic sort, write the chunk-local
// top-K to (out_v, out_i).
template<int K, int N_PAD, int THREADS, int WRITE_INT64>
__global__ void topk_chunk_kernel(
const float* __restrict__ x,
void* __restrict__ out_v,
void* __restrict__ out_i,
int n,
int chunk_size,
int chunks_per_row)
{
extern __shared__ unsigned char smem_raw[];
float* svals = reinterpret_cast<float*>(smem_raw);
int* sidxs = reinterpret_cast<int*>(smem_raw + N_PAD * sizeof(float));
int row = blockIdx.y;
int blk = blockIdx.x;
int chunk_start = blk * chunk_size;
int chunk_end = chunk_start + chunk_size;
if (chunk_end > n) chunk_end = n;
const float* xrow = x + row * (size_t)n;
#pragma unroll
for (int i0 = 0; i0 < N_PAD; i0 += THREADS) {
int i = i0 + threadIdx.x;
if (i < N_PAD) {
int gi = chunk_start + i;
if (gi < chunk_end) {
svals[i] = xrow[gi];
sidxs[i] = gi;
} else {
svals[i] = NEG_INF;
sidxs[i] = -1;
}
}
}
__syncthreads();
bitonic_sort_block<N_PAD, THREADS>(svals, sidxs);
int out_off = (row * chunks_per_row + blk) * K;
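// Two write paths: WRITE_INT64=1 emits final (fp32 value, int64 index) outputs
// directly for single-stage use; WRITE_INT64=0 writes int32 indices into the
// stage-2 scratch buffers.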
if (WRITE_INT64) {
float* ov = reinterpret_cast<float*>(out_v);
int64_t* oi = reinterpret_cast<int64_t*>(out_i);
for (int q = threadIdx.x; q < K; q += THREADS) {
ov[out_off + q] = svals[N_PAD - 1 - q];
oi[out_off + q] = (int64_t)sidxs[N_PAD - 1 - q];
}
} else {
float* ov = reinterpret_cast<float*>(out_v);
int* oi = reinterpret_cast<int*>(out_i);
for (int q = threadIdx.x; q < K; q += THREADS) {
ov[out_off + q] = svals[N_PAD - 1 - q];
oi[out_off + q] = sidxs[N_PAD - 1 - q];
}
}
}
// Stage 2 (bitonic): merge the per-chunk candidates (each chunk's K entries
// are already sorted) via a full bitonic sort over N_PAD slots, padding the
// unused slots with -inf.
template<int K_OUT, int N_PAD, int THREADS>
__global__ void topk_stage2_bitonic_kernel(
const float* __restrict__ in_v,
const int* __restrict__ in_i,
float* __restrict__ vals,
int64_t* __restrict__ idxs,
int candidates_per_row)
{
extern __shared__ unsigned char smem_raw[];
float* svals = reinterpret_cast<float*>(smem_raw);
int* sidxs = reinterpret_cast<int*>(smem_raw + N_PAD * sizeof(float));
int row = blockIdx.x;
const float* in_v_row = in_v + row * candidates_per_row;
const int* in_i_row = in_i + row * candidates_per_row;
#pragma unroll
for (int i0 = 0; i0 < N_PAD; i0 += THREADS) {
int i = i0 + threadIdx.x;
if (i < N_PAD) {
if (i < candidates_per_row) {
svals[i] = in_v_row[i];
sidxs[i] = in_i_row[i];
} else {
svals[i] = NEG_INF;
sidxs[i] = -1;
}
}
}
__syncthreads();
bitonic_sort_block<N_PAD, THREADS>(svals, sidxs);
for (int q = threadIdx.x; q < K_OUT; q += THREADS) {
vals[row * K_OUT + q] = svals[N_PAD - 1 - q];
idxs[row * K_OUT + q] = (int64_t)sidxs[N_PAD - 1 - q];
}
}
// Stage 2 (k-way merge tournament): NUM_RUNS descending-sorted runs of RUN_LEN
// elements each per row. THREADS >= NUM_RUNS (one thread per run, rounded up
// to a warp multiple; surplus threads hold -inf). K_OUT iterations of
// argmax + advance.
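// Invariant: every run is sorted descending, so a run's head is its largest
// unconsumed element and the maximum over all heads is the next global
// top-K element. Threads with tid >= NUM_RUNS carry -inf heads and never win.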
template<int K_OUT, int NUM_RUNS, int RUN_LEN, int THREADS>
__global__ void topk_stage2_merge_kernel(
const float* __restrict__ in_v,
const int* __restrict__ in_i,
float* __restrict__ vals,
int64_t* __restrict__ idxs)
{
int row = blockIdx.x;
int tid = threadIdx.x;
// Each thread handles one run.
int run_off = row * NUM_RUNS * RUN_LEN + tid * RUN_LEN;
float head_v = (tid < NUM_RUNS) ? in_v[run_off] : NEG_INF;
int head_i = (tid < NUM_RUNS) ? in_i[run_off] : -1;
int head_ptr = 0;
__shared__ float smax[(THREADS + 31) / 32];
__shared__ int sidx[(THREADS + 31) / 32];
__shared__ int s_winner;
constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = THREADS / WARP_SIZE;
for (int k_iter = 0; k_iter < K_OUT; ++k_iter) {
// Warp-level argmax.
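// Butterfly reduction: after the xor-shuffle passes below, every lane in the
// warp holds the (value, source thread) pair of that warp's maximum head.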
unsigned mask = 0xFFFFFFFFu;
float my_v = head_v;
int my_t = tid;
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {
float ov = __shfl_xor_sync(mask, my_v, s);
int ot = __shfl_xor_sync(mask, my_t, s);
// Tie-break: prefer smaller thread index (deterministic).
bool take = (ov > my_v) || (ov == my_v && ot < my_t);
if (take) { my_v = ov; my_t = ot; }
}
int lane = tid & (WARP_SIZE - 1);
int warp = tid / WARP_SIZE;
if (lane == 0) { smax[warp] = my_v; sidx[warp] = my_t; }
__syncthreads();
if (warp == 0) {
float v = (lane < NUM_WARPS) ? smax[lane] : NEG_INF;
int t = (lane < NUM_WARPS) ? sidx[lane] : -1;
#pragma unroll
for (int s = NUM_WARPS / 2; s > 0; s >>= 1) {
float ov = __shfl_xor_sync(mask, v, s);
int ot = __shfl_xor_sync(mask, t, s);
bool take = (ov > v) || (ov == v && ot < t);
if (take) { v = ov; t = ot; }
}
if (lane == 0) {
s_winner = t;
// The winning value is written below; the winning thread writes the
// matching index after the barrier.
vals[row * K_OUT + k_iter] = v;
}
}
__syncthreads();
int w = s_winner;
if (tid == w) {
idxs[row * K_OUT + k_iter] = (int64_t)head_i;
head_ptr++;
if (head_ptr < RUN_LEN) {
head_v = in_v[run_off + head_ptr];
head_i = in_i[run_off + head_ptr];
} else {
head_v = NEG_INF;
head_i = -1;
}
}
__syncthreads();
}
}
// Argmax kernel for k=1.
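// Two-level reduction: each thread scans a strided slice of the row, a warp
// shuffle reduction finds per-warp maxima, and a final pass by warp 0 over
// shared memory produces the row's max value and index.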
template<int N, int THREADS>
__global__ void argmax_kernel(
const float* __restrict__ x,
float* __restrict__ vals,
int64_t* __restrict__ idxs,
int n)
{
int row = blockIdx.x;
const float* xrow = x + row * (size_t)n;
float my_v = NEG_INF;
int my_i = -1;
for (int i = threadIdx.x; i < n; i += THREADS) {
float v = xrow[i];
if (v > my_v) { my_v = v; my_i = i; }
}
unsigned mask = 0xFFFFFFFFu;
#pragma unroll
for (int s = 16; s > 0; s >>= 1) {
float ov = __shfl_xor_sync(mask, my_v, s);
int oi = __shfl_xor_sync(mask, my_i, s);
if (ov > my_v || (ov == my_v && oi < my_i)) { my_v = ov; my_i = oi; }
}
__shared__ float wv[THREADS / 32];
__shared__ int wi[THREADS / 32];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
if (lane == 0) { wv[warp] = my_v; wi[warp] = my_i; }
__syncthreads();
if (warp == 0) {
my_v = (lane < THREADS / 32) ? wv[lane] : NEG_INF;
my_i = (lane < THREADS / 32) ? wi[lane] : -1;
#pragma unroll
for (int s = 16; s > 0; s >>= 1) {
float ov = __shfl_xor_sync(mask, my_v, s);
int oi = __shfl_xor_sync(mask, my_i, s);
if (ov > my_v || (ov == my_v && oi < my_i)) { my_v = ov; my_i = oi; }
}
if (lane == 0) {
vals[row] = my_v;
idxs[row] = (int64_t)my_i;
}
}
}
template<int K, int N_PAD, int THREADS>
inline void launch_chunk_int32(torch::Tensor x, torch::Tensor v_scratch,
torch::Tensor i_scratch,
int n, int chunk_size, int chunks_per_row,
int batch, cudaStream_t stream)
{
size_t sm = (size_t)N_PAD * (sizeof(float) + sizeof(int));
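// N_PAD = 1024 value/index pairs is 8 KB of dynamic shared memory, already
// under the default 48 KB cap; the opt-in attribute below only matters if a
// larger N_PAD configuration is ever added.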
cudaFuncSetAttribute((void*)topk_chunk_kernel<K, N_PAD, THREADS, 0>,
cudaFuncAttributeMaxDynamicSharedMemorySize, sm);
topk_chunk_kernel<K, N_PAD, THREADS, 0>
<<<dim3(chunks_per_row, batch), THREADS, sm, stream>>>(
x.data_ptr<float>(),
(void*)v_scratch.data_ptr<float>(),
(void*)i_scratch.data_ptr<int>(),
n, chunk_size, chunks_per_row);
}
template<int K, int N_PAD, int THREADS>
inline void launch_chunk_int64(torch::Tensor x, torch::Tensor vals,
torch::Tensor idxs,
int n, int chunk_size, int chunks_per_row,
int batch, cudaStream_t stream)
{
size_t sm = (size_t)N_PAD * (sizeof(float) + sizeof(int));
cudaFuncSetAttribute((void*)topk_chunk_kernel<K, N_PAD, THREADS, 1>,
cudaFuncAttributeMaxDynamicSharedMemorySize, sm);
topk_chunk_kernel<K, N_PAD, THREADS, 1>
<<<dim3(chunks_per_row, batch), THREADS, sm, stream>>>(
x.data_ptr<float>(),
(void*)vals.data_ptr<float>(),
(void*)idxs.data_ptr<int64_t>(),
n, chunk_size, chunks_per_row);
}
template<int K_OUT, int N_PAD, int THREADS>
inline void launch_stage2_bitonic(torch::Tensor v_scratch, torch::Tensor i_scratch,
torch::Tensor vals, torch::Tensor idxs,
int candidates_per_row, int batch, cudaStream_t stream)
{
size_t sm = (size_t)N_PAD * (sizeof(float) + sizeof(int));
cudaFuncSetAttribute((void*)topk_stage2_bitonic_kernel<K_OUT, N_PAD, THREADS>,
cudaFuncAttributeMaxDynamicSharedMemorySize, sm);
topk_stage2_bitonic_kernel<K_OUT, N_PAD, THREADS><<<batch, THREADS, sm, stream>>>(
v_scratch.data_ptr<float>(),
i_scratch.data_ptr<int>(),
vals.data_ptr<float>(),
idxs.data_ptr<int64_t>(),
candidates_per_row);
}
template<int K_OUT, int NUM_RUNS, int RUN_LEN, int THREADS>
inline void launch_stage2_merge(torch::Tensor v_scratch, torch::Tensor i_scratch,
torch::Tensor vals, torch::Tensor idxs,
int batch, cudaStream_t stream)
{
topk_stage2_merge_kernel<K_OUT, NUM_RUNS, RUN_LEN, THREADS>
<<<batch, THREADS, 0, stream>>>(
v_scratch.data_ptr<float>(),
i_scratch.data_ptr<int>(),
vals.data_ptr<float>(),
idxs.data_ptr<int64_t>());
}
std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k) {
TORCH_CHECK(x.dim() == 2, "x must be 2D");
TORCH_CHECK(x.dtype() == torch::kFloat32, "x must be fp32");
TORCH_CHECK(x.is_cuda(), "x must be CUDA");
TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
int batch = (int)x.size(0);
int n = (int)x.size(1);
auto vals = torch::empty({batch, (int64_t)k}, x.options());
auto idxs = torch::empty({batch, (int64_t)k}, x.options().dtype(torch::kInt64));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
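// Shape dispatch (hard-coded per benchmark shape):
// (1, 131072, 64): 128 chunk blocks -> 128-run tournament merge
// (64, 8192, 8): 8 chunks/row -> bitonic merge of 64 candidates
// (32, 16384, 32): 16 chunks/row -> tournament merge (32 threads)
// (16, 12000, 16): 12 chunks/row (last chunk padded) -> tournament merge
// (128, 4096, 1): dedicated per-row argmax kernel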
if (batch == 1 && n == 131072 && k == 64) {
// Stage 1: 128 chunks * 1024. K=64.
constexpr int CHUNK = 1024;
constexpr int N_PAD1 = 1024;
constexpr int K1 = 64;
constexpr int T1 = 128;
constexpr int CPR = 131072 / CHUNK; // 128
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
x.options().dtype(torch::kInt32));
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
launch_stage2_merge<K1, CPR, K1, CPR>(v_scratch, i_scratch, vals, idxs,
batch, stream);
} else if (batch == 64 && n == 8192 && k == 8) {
constexpr int CHUNK = 1024;
constexpr int N_PAD1 = 1024;
constexpr int K1 = 8;
constexpr int T1 = 128;
constexpr int CPR = 8192 / CHUNK; // 8
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
x.options().dtype(torch::kInt32));
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
constexpr int CAND2 = CPR * K1; // 64
constexpr int N_PAD2 = 64;
constexpr int T2 = 32;
launch_stage2_bitonic<K1, N_PAD2, T2>(v_scratch, i_scratch, vals, idxs,
CAND2, batch, stream);
} else if (batch == 32 && n == 16384 && k == 32) {
constexpr int CHUNK = 1024;
constexpr int N_PAD1 = 1024;
constexpr int K1 = 32;
constexpr int T1 = 128;
constexpr int CPR = 16384 / CHUNK; // 16
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
x.options().dtype(torch::kInt32));
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
batch, stream);
} else if (batch == 16 && n == 12000 && k == 16) {
constexpr int CHUNK = 1024;
constexpr int N_PAD1 = 1024;
constexpr int K1 = 16;
constexpr int T1 = 128;
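// 12 chunks of 1024 cover n = 12000; the last chunk is padded to 1024 with -inf.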
constexpr int CPR = 12;
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
x.options().dtype(torch::kInt32));
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
batch, stream);
} else if (batch == 128 && n == 4096 && k == 1) {
constexpr int N = 4096;
constexpr int T = 256;
argmax_kernel<N, T><<<batch, T, 0, stream>>>(
x.data_ptr<float>(),
vals.data_ptr<float>(),
idxs.data_ptr<int64_t>(),
n);
} else {
TORCH_CHECK(false, "Unsupported (batch, n, k) combo: ",
batch, " ", n, " ", k);
}
return {vals, idxs};
}
"""
_module = load_inline(
    name="topk_v6",
    cpp_sources=CPP_SRC,
    cuda_sources=CUDA_SRC,
    functions=["topk_forward"],
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
        "-gencode=arch=compute_120,code=sm_120",
    ],
    verbose=False,
)
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
class Model(nn.Module):
    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        self.batch, self.n, self.k = batch, n, k
        self.register_buffer("_dummy", torch.zeros(1))

    def forward(self, x: torch.Tensor):
        return _module.topk_forward(x.contiguous(), self.k)
batch = 64
n = 8192
k = 8
def get_inputs():
    x = torch.randn(batch, n, dtype=torch.float32)
    return [x]

def get_init_inputs():
    return [batch, n, k]
shape=0 variant=eager tflops=0.009 gbps=9.281 ms=0.057
shape=0 variant=compiled tflops=0.005 gbps=5.219 ms=0.101
shape=0 variant=sota tflops=0.010 gbps=10.474 ms=0.050
shape=0 variant=solution tflops=0.009 gbps=8.718 ms=0.060
shape=0 solution_peak_fraction=0.0048
shape=1 variant=eager tflops=0.040 gbps=39.799 ms=0.053
shape=1 variant=compiled tflops=0.024 gbps=24.362 ms=0.086
shape=1 variant=sota tflops=0.043 gbps=43.457 ms=0.048
shape=1 variant=solution tflops=0.045 gbps=45.004 ms=0.047
shape=1 solution_peak_fraction=0.0250
shape=2 variant=eager tflops=0.043 gbps=42.931 ms=0.049
shape=2 variant=compiled tflops=0.023 gbps=23.248 ms=0.091
shape=2 variant=sota tflops=0.044 gbps=44.035 ms=0.048
shape=2 variant=solution tflops=0.042 gbps=41.761 ms=0.051
shape=2 solution_peak_fraction=0.0232
shape=3 variant=eager tflops=0.017 gbps=16.762 ms=0.046
shape=3 variant=compiled tflops=0.009 gbps=8.986 ms=0.086
shape=3 variant=sota tflops=0.016 gbps=16.032 ms=0.048
shape=3 variant=solution tflops=0.017 gbps=17.169 ms=0.045
shape=3 solution_peak_fraction=0.0095
shape=4 variant=eager tflops=0.088 gbps=87.679 ms=0.024
shape=4 variant=compiled tflops=0.034 gbps=34.427 ms=0.061
shape=4 variant=sota tflops=0.101 gbps=100.743 ms=0.021
shape=4 variant=solution tflops=0.198 gbps=198.139 ms=0.011
shape=4 solution_peak_fraction=0.1101
peak_fraction: 0.0197
RESULT: LOW
PASS
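The headline peak_fraction above is the geometric mean of the five per-shape solution_peak_fraction values (benchmark.py reports a geomean over shapes); a quick sanity check with the rounded numbers from the log:
from math import exp, log
fractions = [0.0048, 0.0250, 0.0232, 0.0095, 0.1101]  # per-shape solution_peak_fraction
geomean = exp(sum(log(f) for f in fractions) / len(fractions))
print(round(geomean, 4))  # ~0.0196; matches the reported 0.0197 up to rounding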
{
"problem": "05_topk_bitonic",
"harness": "claude",
"model": "claude-opus-4-7",
"reasoning_effort": "max",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0197,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 124, "output_tokens": 5722, "cache_read_tokens": 13891786, "cache_creation_tokens": 424373, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"command": "pwd && ls -la",
"description": "Check current directory and contents"
}/home/infatoshi/cuda/KernelBench-Hard
drwxrwxr-x infatoshi infatoshi 4.0 KB Mon Apr 27 21:56:40 2026 .
drwxr-xr-x infatoshi infatoshi 4.0 KB Tue Apr 28 00:55:24 2026 ..
.rw-rw-r-- infatoshi infatoshi 1.6 KB Fri Apr 24 00:32:28 2026 .cursorrules
drwxrwxr-x infatoshi infatoshi 4.0 KB Tue Apr 28 02:09:38 2026 .git
.rw-rw-r-- infatoshi infatoshi 262 B Fri Apr 24 00:30:12 2026 .gitignore
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:41:00 2026 .pytest_cache
.rw-rw-r-- infatoshi infatoshi 5 B Fri Apr 24 00:30:15 2026 .python-version
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:40:58 2026 .ruff_cache
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:40:54 2026 .venv
.rw-rw-r-- infatoshi infatoshi 1012 B Mon Apr 27 02:25:28 2026 AGENTS.md
.rw-rw-r-- infatoshi infatoshi 6.9 KB Mon Apr 27 02:24:34 2026 CLAUDE.md
.rw-rw-r-- infatoshi infatoshi 45 KB Mon Apr 27 21:56:40 2026 DEVLOG.md
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:30:02 2026 docs
.rw-rw-r-- infatoshi infatoshi 1.0 KB Mon Apr 27 04:26:18 2026 LICENSE
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:30:02 2026 outputs
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:30:02 2026 problems
.rw-rw-r-- infatoshi infatoshi 663 B Sat Apr 25 22:54:16 2026 pyproject.toml
.rw-rw-r-- infatoshi infatoshi 8.0 KB Sun Apr 26 00:07:33 2026 README.md
drwxrwxr-x infatoshi infatoshi 4.0 KB Tue Apr 28 04:05:34 2026 scripts
.rw-rw-r-- infatoshi infatoshi 6.5 KB Mon Apr 27 02:25:12 2026 SPEC.md
... (210 more chars){
"command": "ls problems/",
"description": "List problem directories"
}01_fp8_gemm
02_kda_cutlass
03_paged_attention
04_kahan_softmax
05_topk_bitonic
06_sonic_moe_swiglu
07_w4a16_gemm
08_metal_lightning_attn{
"command": "ls problems/05_topk_bitonic/",
"description": "List topk problem files"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py"
}1 """Naive top-k reference: torch.topk over the last dim.
2
3 This is the correctness oracle. The agent's solution must produce the same
4 top-k values (and equivalent indices modulo ties) within the tolerance
5 declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
6 torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
7 """
8 import torch
9 import torch.nn as nn
10
11 OP_TYPE = "topk"
12 SUPPORTED_PRECISIONS = ["fp32"]
13 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14
15
16 class Model(nn.Module):
17 """Top-k over the last dim of a 2D tensor.
18
19 Input:
20 x: (batch, n) fp32
21 Output:
22 values: (batch, k) fp32, sorted descending
23 indices: (batch, k) int64, into the last dim of x
24 """
25
26 def __init__(self, batch: int, n: int, k: int):
27 super().__init__()
28 self.batch, self.n, self.k = batch, n, k
29 # No learned parameters, but declare a dummy buffer so state_dict
30 # is non-empty and load_state_dict(strict=True) is meaningful.
31 self.register_buffer("_dummy", torch.zeros(1))
32
33 def forward(self, x: torch.Tensor):
34 values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
35 return values, indices
36
37
38 # Module-level shims rebuilt by check.py / benchmark.py per shape.
39 batch = 64
40 n = 8192
41 k = 8
42
43
44 def get_inputs():
45 # fp32 input drawn from a roughly Gaussian dist
... (211 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py"
}1 """Correctness runner for TopK.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each. Top-k correctness has two parts:
5
6 1. VALUES: sol_values must match ref_values within fp32 tol. Both are
7 returned sorted descending, so positional comparison is well-defined.
8 2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
9 ties in x can yield multiple valid index sets. Instead we gather x at
10 sol_indices and check those values match ref_values within tol. This
11 catches "wrong indices" without false-failing on legitimate tie-breaks.
12
13 Also rejects forbidden ops by grep.
14 """
15 import re
16 import sys
17 from pathlib import Path
18
19 import torch
20 import yaml
21
22 REPO_ROOT = Path(__file__).resolve().parents[2]
23 sys.path.insert(0, str(REPO_ROOT))
24
25 from src.eval.correctness import check_correctness # noqa: E402
26
27
28 def main():
29 try:
30 import reference
31 import shapes
32 import solution
33 except Exception as e:
34 print(f"FAIL: import error: {e}")
35 sys.exit(1)
36
37 problem_yaml = Path("problem.yaml")
38 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39
40 # --- Forbidden-op check ------------------------------------------------
41 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
42 for forbidden in meta.get("forbidden", [
... (4601 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py"
}1 """Roofline benchmark for TopK.
2
3 For each shape: times eager reference (torch.topk), compiled reference, SOTA
4 (also torch.topk — see sota.py), and the agent's solution. Reports achieved
5 TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
6
7 Output lines the harness picks up:
8 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10 """
11 import sys
12 from math import exp, log
13 from pathlib import Path
14
15 import torch
16 import yaml
17
18 REPO_ROOT = Path(__file__).resolve().parents[2]
19 sys.path.insert(0, str(REPO_ROOT))
20
21 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22 from src.eval.timing import time_fn # noqa: E402
23 from src.hardware import get as get_hw # noqa: E402
24
25
26 def _eval_formula(expr: str, vars: dict) -> float:
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def main():
31 import reference
32 import shapes
33 import solution
34
35 meta = yaml.safe_load(Path("problem.yaml").read_text())
36 hw = get_hw(meta["hardware"][0])
37 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38 peak_gbps = hw.peak_bandwidth_gb_s
39 regime = meta.get("regime", "memory")
40 flops_formula = meta["flops_formula"]
41 bytes_formula = meta["bytes_formula"]
42 num_perf_trials = int(meta.get("num_perf_trials", 50))
43
44
... (2972 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml"
}1 name: 05_topk_bitonic
2 display_name: "TopK via Bitonic Sort"
3 precision: fp32
4 regime: memory
5
6 # Top-k is dominated by the input read (small output, no reduction over k).
7 # Comparator-network FLOPs are not the bottleneck on real hardware, so we
8 # track them but score on bandwidth.
9 flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
10 # bitonic network is so cheap relative
11 # to memory that a coarse 4*n estimate
12 # is fine for telemetry only.
13 bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
14
15 hardware: [RTX_PRO_6000]
16 peak_tflops_key: fp32
17 peak_bandwidth_key: dram
18
19 # Top-k correctness:
20 # - VALUES must match within fp32 atol/rtol (the kth largest value is
21 # well-defined modulo float-equal ties, so we use a loose-ish tol).
22 # - INDICES are checked leniently: for each row, the multiset of returned
23 # indices must select values that match ref values within tol. Direct
24 # index equality is NOT required (ties in x can yield different valid
25 # index sets).
26 tolerance:
27 float32: 1.0e-4
28
29 # Forbidden ops — using any of these in solution.py fails correctness post-hoc.
30 # This problem is about IMPLEMENTING the selection, not dispatching to PyTorch's
31 # tuned top-k. torch.sort is also banned b
... (859 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py"
}1 """Canonical shape sweep for TopK.
2
3 Mix of:
4 - decoder vocab top-k (single sequence, very large n, moderate k) — pure
5 bandwidth test; the input read dominates everything.
6 - prefill / batched attention top-k (many rows, moderate n, small k) — tests
7 per-row parallelism and shared-memory bitonic networks.
8 - non-power-of-2 n stress case — bitonic sort networks naturally want
9 powers of two; this forces the agent to handle padding or partial sorts.
10 - small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
11 """
12
13 SHAPES = [
14 {"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
15 {"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
16 {"batch": 32, "n": 16384, "k": 32}, # mid-size batched
17 {"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
18 {"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
19 ]
20 {
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py"
}1 """SOTA reference for TopK: torch.topk itself.
2
3 torch.topk dispatches to a CUB-backed kernel that uses radix-select for
4 moderate k and a tuned bitonic sort for small n. It is the bar the agent's
5 hand-rolled bitonic kernel must beat. There is no obvious vendor library that
6 does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
7 specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
8 sorts the full row (overkill for top-k).
9
10 This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
11 oracle, not the agent's submission. The agent's solution.py is forbidden from
12 using torch.topk (see problem.yaml.forbidden).
13 """
14 from __future__ import annotations
15
16 import torch
17
18
19 def sota_forward(x: torch.Tensor, k: int):
20 """Best-available top-k reference. x: (batch, n) fp32."""
21 return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
22
23
24 def is_available() -> bool:
25 return True
26 {
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/PROMPT.txt"
}1 I need you to optimize a top-k kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 The input is a 2D fp32 tensor and you return the top-k values and their int64 indices along the last dim, sorted descending, same contract as torch.topk(largest=True, sorted=True). Correctness tolerance is 1e-4 abs/rel on the values; index ties are handled leniently — what matters is the returned values match the reference to tolerance, not that the index sets are bitwise identical. The shapes you have to handle are batch=1 n=131072 k=64 (decoder vocab top-k over a Llama-size vocabulary), batch=64 n=8192 k=8 (prefill / attention top-k), batch=32 n=16384 k=32 (mid-size batched), batch=16 n=12000 k=16 (non-power-of-two n, which bitonic networks don't naturally want), and batch=128 n=4096 k=1 (batched argmax).
4
5 This needs to be a real custom kernel — CUDA C++ via torch.utils.cpp_extension.load_inline, Triton, inline PTX, or CUTLASS, whatever you think fits. Don't reach for torch.topk, torch.kthvalue, torch.sort, or torch.argsort, or any of their Tensor.* / torch.ops.aten.* variants; they're off-limits and using them fails correctness. Anythi
... (588 more chars){
"query": "select:TodoWrite",
"max_results": 1
}{
"todos": [
{
"content": "Check GPU availability and CUDA environment",
"status": "in_progress",
"activeForm": "Checking GPU availability and CUDA environment"
},
{
"content": "Write initial CUDA topk solution and pass check.py",
"status": "pending",
"activeForm": "Writing initial CUDA topk solution"
},
{
"content": "Benchmark and identify bottlenecks per shape",
"status": "pending",
"activeForm": "Benchmarking and identifying bottlenecks"
},
{
"content": "Iterate to push peak_fraction up",
"status
... (82 more chars)Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable{
"command": "nvidia-smi && echo \"---\" && which nvcc && nvcc --version",
"description": "Check GPU and nvcc"
}Tue Apr 28 07:06:52 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 30C P8 11W / 600W | 733MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (983 more chars){
"command": "cat src/eval/correctness.py 2>/dev/null | head -80",
"description": "Read correctness check"
}"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
torch.int64: {"atol": 0, "rtol": 0},
}
def _coerce_float(x):
"""Accept int/float/str scientific-notation as float; pass through dicts."""
if isinstance(x, (int, float)):
return float(x)
if isinstance(x, str):
try:
return float(x)
except ValueError:
return x
return x
def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
"""Lookup atol/rtol for a given dtype, with optional per-problem override.
The override dict is keyed by str(dtype). Values may be:
- a single number (or a string parseable as float, e.g. "5e-2") -> used
as both atol and rtol
- a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
"""
# Accept several spellings of the same dtype key (PyYAML quirks):
# "torch.bfloat16", "bfloat16", torch.bfloat16
if override is not None:
type_
... (1493 more chars){
"command": "cat src/eval/timing.py",
"description": "Read timing utility"
}"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import statistics
import torch
# Blackwell consumer (RTX PRO 6000) L2 is 96 MB. Allocate strictly larg
... (1066 more chars){
"command": "cat src/eval/roofline.py",
"description": "Read roofline calculation"
}"""Roofline math: achieved TFLOPS / GB/s, peak fraction.
FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
dense-equivalent algorithmic work; kernels that skip work (sparsity, early
exit) are still measured against the dense cost.
"""
from __future__ import annotations
def compute_tflops(flops: float, time_ms: float) -> float:
"""flops per invocation and wall time in ms -> TFLOPS."""
if time_ms <= 0:
return 0.0
return flops / (time_ms * 1e-3) / 1e12
def compute_gbps(bytes_moved: float, time_ms: float) -> float:
"""bytes per invocation and wall time in ms -> GB/s."""
if time_ms <= 0:
return 0.0
return bytes_moved / (time_ms * 1e-3) / 1e9
def peak_fraction(
achieved: float,
peak: float,
) -> float:
"""Return achieved / peak, clamped to [0, infinity)."""
if peak <= 0:
return 0.0
return max(0.0, achieved / peak){
"command": "cat src/hardware/rtx_pro_6000.py 2>/dev/null | head -50",
"description": "Read hardware spec"
}"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
name: str
sm: str
vram_gb: int
peak_bandwidth_gb_s: float # DRAM
peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
"fp4": 800.0,
"nvfp4": 800.0,
"mxfp4": 800.0,
"fp6": 800.0,
"fp8": 400.0,
"bf16": 200.0,
"fp16": 200.0,
"tf32": 100.0,
"fp32": 12.0, # non-tensor-core SIMT fp32
"int8": 400.0,
"int4": 800.0,
},
){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic && uv run python -c \"import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_capability())\" 2>&1 | tail -10",
"description": "Check torch version and CUDA capability"
}2.11.0+cu130 True (12, 0){
"command": "uv run python -c \"import torch; print('SM count:', torch.cuda.get_device_properties(0).multi_processor_count); print('shared mem per block:', torch.cuda.get_device_properties(0).shared_memory_per_block); print('shared mem per block opt-in:', torch.cuda.get_device_properties(0).shared_memory_per_block_optin); print('regs per block:', torch.cuda.get_device_properties(0).regs_per_block)\"",
"description": "Get device props"
}Exit code 1
Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'regs_per_block'
SM count: 188
shared mem per block: 49152
shared mem per block opt-in: 101376--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,365 @@
+"""TopK custom kernel via load_inline.
+
+Strategy:
+ - Each thread streams its share of the row, maintaining a sorted top-K
+ array in registers (insertion sort, descending).
+ - Threads write their per-thread top-K to shared memory: (THREADS * K)
+ candidates total.
+ - Block-wide bitonic sort over the (THREADS * K) candidates picks the
+ block-wide top-K. (THREADS * K) is small enough to fit in 99 KB
+ Blackwell opt-in shared memory.
+ - For batch=1 n=131072 (the decoder shape) one block can't saturate
+ bandwidth, so we use a 2-stage approach: G blocks each find a chunk-local
+ top-K, then a single block merges G*K candidates.
+"""
+import os
+
+os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.0")
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+CPP_SRC = r"""
+std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k);
+"""
+
+CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cstdio>
+
+#define NEG_INF (-3.4e38f)
+
+// Per-thread insertion: maintain sorted-descending top-K in registers.
+// my_v[0] is the current max in this thread, my_v[K-1] is the current min.
+template<int K>
+__device__ __forceinline__ void thread_insert(float v, int i, float (&my_v)[K], int (&my_i)[K]) {
+ if (v > my_v[K-1]) {
+ // Find insertion position; shift down.
+ int j = K - 1;
+ #pragma unroll
+ for (int p = K - 1; p > 0; --p) {
+ if (my_v[p-1] < v) {
+ my_v[p] = my_v[p-1];
+ my_i[p] = my_i[p-1];
+ j = p - 1;
+ } else {
+ // Already in place; record j here.
+ // Use a small trick: only update j if we did not break.
+ }
+ }
+ // Re-derive insertion point properly: walk from top, find first < v.
+ // (The unrolled loop above shifts but we still need to find j.)
+ int pos = 0;
+ #pragma unroll
+ for (int p = 0; p < K; ++p) {
+ if (my_v[p] >= v && p < K - 1) pos = p + 1;
+ }
+ // Actually the unroll-shift is messy. Replace with clean version below.
+ }
+}
+
+// Cleaner version: explicit branched insertion.
+template<int K>
+__device__ __forceinline__ void thread_insert_v2(float v, int i, float (&my_v)[K], int (&my_i)[K]) {
+ if (v <= my_v[K-1]) return;
+ // Find insertion position (largest p with my_v[p-1] < v, else 0).
+ // Then shift my_v[p..K-2] -> my_v[p+1..K-1], place v at p.
+ int pos = K - 1;
+ #pragma unroll
+ for (int p = K - 2; p >= 0; --p) {
+ if (my_v[p] < v) {
+ my_v[p+1] = my_v[p];
+ my_i[p+1] = my_i[p];
+ pos = p;
+ }
+ }
+ my_v[pos] = v;
+ my_i[pos] = i;
+}
+
+// Block-wide bitonic sort of N elements in shared memory, ascending.
+// Caller must ensure N is a power of 2 and N >= THREADS*2 (or handles small N).
+template<int N, int THREADS>
+__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
+ // For each (size, stride) pair, every thread handles N/THREADS pairs.
+ #pragma unroll
+ for (int sz = 2; sz <= N; sz <<= 1) {
+ #pragma unroll
+ for (int j = sz >> 1; j > 0; j >>= 1) {
+ // Each thread handles elements i = threadIdx.x, threadIdx.x + THREADS, ...
+ for (int t = threadIdx.x; t < N; t += THREADS) {
+ int p = t ^ j;
+ if (p > t) {
+ bool asc = (t & sz) == 0;
+ float vt = svals[t], vp = svals[p];
+ bool sw = asc ? (vt > vp) : (vt < vp);
+ if (sw) {
+ svals[t] = vp; svals[p] = vt;
+ int it = sidxs[t], ip = sidxs[p];
+ sidxs[t] = ip; sidxs[p] = it;
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
+// Single-row block kernel: 1 block per row, per-thread top-K + block bitonic.
+// REDUCE_N must equal THREADS * K and be a power of 2.
+template<int K, int THREADS, int REDUCE_N>
+__global__ void topk_block_kernel(
+ const float* __restrict__ x,
+ float* __restrict__ vals,
+ int64_t* __restrict__ idxs,
+ int n)
+{
+ extern __shared__ unsigned char smem_raw[];
+ float* svals = reinterpret_cast<float*>(smem_raw);
+ int* sidxs = reinterpret_cast<int*>(smem_raw + REDUCE_N * sizeof(float));
+
+ int row = blockIdx.x;
+ const float* xrow = x + row * (size_t)n;
+
+ float my_v[K];
+ int my_i[K];
+ #pragma unroll
+ for (int p = 0; p < K; ++p) { my_v[p] = NEG_INF; my_i[p] = -1; }
+
+ // Stream the row.
+ for (int i = threadIdx.x; i < n; i += THREADS) {
+ float v = xrow[i];
+ thread_insert_v2<K>(v, i, my_v, my_i);
+ }
+
+ // Write per-thread top-K to shared. Layout: thread t's K elements at
+ // smem[t*K .. t*K+K-1]. Then bitonic sort over REDUCE_N = THREADS*K.
+ #pragma unroll
+ for (int p = 0; p < K; ++p) {
+ svals[threadIdx.x * K + p] = my_v[p];
+ sidxs[threadIdx.x * K + p] = my_i[p];
+ }
+ __syncthreads();
+
+ bitonic_sort_block<REDUCE_N, THREADS>(svals, sidxs);
+
+ // Output top K (last K of ascending sort, reversed).
+ for (int i = threadIdx.x; i < K; i += THREADS) {
+ vals[row * K + i] = svals[REDUCE_N - 1 - i];
+ idxs[row * K + i] = (int64_t)sidxs[REDUCE_N - 1 - i];
+ }
+}
+
+// Stage 1 for the (1, 131072, 64) shape: G blocks per row, each finds
+// top-K within a chunk; outputs G*K candidates per row to scratch.
+template<int K, int THREADS, int REDUCE_N>
+__global__ void topk_stage1_kernel(
+ const float* __restrict__ x,
+ float* __restrict__ vals_out,
+ int* __restrict__ idxs_out,
+ int n,
+ int chunk_size,
+
... (6699 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<12221 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py (file state is current in your context — no need to Read it back){
"todos": [
{
"content": "Check GPU availability and CUDA environment",
"status": "completed",
"activeForm": "Checking GPU availability and CUDA environment"
},
{
"content": "Compile v0 and pass check.py",
"status": "in_progress",
"activeForm": "Compiling v0 and verifying check.py"
},
{
"content": "Benchmark and identify bottlenecks per shape",
"status": "pending",
"activeForm": "Benchmarking and identifying bottlenecks"
},
{
"content": "Iterate to push peak_fraction up",
"status": "pending",
... (63 more chars)Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable{
"command": "uv run python check.py 2>&1 | tail -60",
"description": "Run check.py"
}PASS{
"command": "uv run python benchmark.py 2>&1 | tail -40",
"description": "Benchmark v0"
}shape=0 variant=eager tflops=0.011 gbps=10.686 ms=0.049
shape=0 variant=compiled tflops=0.006 gbps=5.745 ms=0.091
shape=0 variant=sota tflops=0.011 gbps=10.627 ms=0.049
shape=0 variant=solution tflops=0.001 gbps=0.584 ms=0.899
shape=0 solution_peak_fraction=0.0003
shape=1 variant=eager tflops=0.043 gbps=43.586 ms=0.048
shape=1 variant=compiled tflops=0.025 gbps=24.963 ms=0.084
shape=1 variant=sota tflops=0.042 gbps=42.364 ms=0.050
shape=1 variant=solution tflops=0.033 gbps=32.749 ms=0.064
shape=1 solution_peak_fraction=0.0182
shape=2 variant=eager tflops=0.041 gbps=41.642 ms=0.051
shape=2 variant=compiled tflops=0.024 gbps=23.648 ms=0.089
shape=2 variant=sota tflops=0.044 gbps=43.757 ms=0.048
shape=2 variant=solution tflops=0.006 gbps=6.316 ms=0.334
shape=2 solution_peak_fraction=0.0035
shape=3 variant=eager tflops=0.017 gbps=16.681 ms=0.046
shape=3 variant=compiled tflops=0.008 gbps=8.374 ms=0.092
shape=3 variant=sota tflops=0.017 gbps=17.065 ms=0.045
shape=3 variant=solution tflops=0.005 gbps=5.296 ms=0.146
shape=3 solution_peak_fraction=0.0029
shape=4 variant=eager tflops=0.083 gbps=83.018 ms=0.025
shape=4 variant=compiled tflops=0.033 gbps=32.572 ms=0.064
shape=4 variant=sota tflops=0.098 gbps=98.327 ms=0.021
shape=4 variant=solution tflops=0.138 gbps=138.509 ms=0.015
shape=4 solution_peak_fraction=0.0769
peak_fraction: 0.0054
RESULT: LOW{
"command": "<810 chars \u2014 see below>",
"description": "Quick microbench"
}uv run python -c "
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import sys
sys.path.insert(0, 'problems/05_topk_bitonic')
import solution
m = solution.Model(1, 131072, 64).cuda().eval()
x = torch.randn(1, 131072, device='cuda')
torch.cuda.synchronize()
import time
for _ in range(5): m(x); torch.cuda.synchronize()
ts = []
for _ in range(20):
torch.cuda.synchronize()
t0 = time.time()
m(x); torch.cuda.synchronize()
ts.append((time.time() - t0)*1000)
print('mine ms:', sorted(ts)[len(ts)//2])
print()
for _ in range(5): torch.topk(x, 64); torch.cuda.synchronize()
ts = []
for _ in range(20):
torch.cuda.synchronize()
t0 = time.time()
torch.topk(x, 64); torch.cuda.synchronize()
ts.append((time.time() - t0)*1000)
print('topk ms:', sorted(ts)[len(ts)//2])
"
mine ms: 0.8966922760009766
topk ms: 0.141143798828125--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,16 +1,13 @@
-"""TopK custom kernel via load_inline.
+"""TopK custom kernel via load_inline (v1: shared-memory bitonic sort).
Strategy:
- - Each thread streams its share of the row, maintaining a sorted top-K
- array in registers (insertion sort, descending).
- - Threads write their per-thread top-K to shared memory: (THREADS * K)
- candidates total.
- - Block-wide bitonic sort over the (THREADS * K) candidates picks the
- block-wide top-K. (THREADS * K) is small enough to fit in 99 KB
- Blackwell opt-in shared memory.
- - For batch=1 n=131072 (the decoder shape) one block can't saturate
- bandwidth, so we use a 2-stage approach: G blocks each find a chunk-local
- top-K, then a single block merges G*K candidates.
+ - For each row, load the row (or chunk) into shared memory padded to a
+ power of 2 with -inf, sort via in-place block-wide bitonic sort, then
+ output the last K elements (reversed) as the top-K descending.
+ - Where N_PAD * 8 bytes does not fit in 99 KB opt-in shared memory, we
+ use a 2-stage approach: stage 1 has G blocks per row each find the
+ chunk-local top-K via bitonic, stage 2 has 1 block per row sort the
+ G*K candidates.
"""
import os
@@ -29,71 +26,21 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
-#include <cstdio>
#define NEG_INF (-3.4e38f)
-// Per-thread insertion: maintain sorted-descending top-K in registers.
-// my_v[0] is the current max in this thread, my_v[K-1] is the current min.
-template<int K>
-__device__ __forceinline__ void thread_insert(float v, int i, float (&my_v)[K], int (&my_i)[K]) {
- if (v > my_v[K-1]) {
- // Find insertion position; shift down.
- int j = K - 1;
- #pragma unroll
- for (int p = K - 1; p > 0; --p) {
- if (my_v[p-1] < v) {
- my_v[p] = my_v[p-1];
- my_i[p] = my_i[p-1];
- j = p - 1;
- } else {
- // Already in place; record j here.
- // Use a small trick: only update j if we did not break.
- }
- }
- // Re-derive insertion point properly: walk from top, find first < v.
- // (The unrolled loop above shifts but we still need to find j.)
- int pos = 0;
- #pragma unroll
- for (int p = 0; p < K; ++p) {
- if (my_v[p] >= v && p < K - 1) pos = p + 1;
- }
- // Actually the unroll-shift is messy. Replace with clean version below.
- }
-}
-
-// Cleaner version: explicit branched insertion.
-template<int K>
-__device__ __forceinline__ void thread_insert_v2(float v, int i, float (&my_v)[K], int (&my_i)[K]) {
- if (v <= my_v[K-1]) return;
- // Find insertion position (largest p with my_v[p-1] < v, else 0).
- // Then shift my_v[p..K-2] -> my_v[p+1..K-1], place v at p.
- int pos = K - 1;
- #pragma unroll
- for (int p = K - 2; p >= 0; --p) {
- if (my_v[p] < v) {
- my_v[p+1] = my_v[p];
- my_i[p+1] = my_i[p];
- pos = p;
- }
- }
- my_v[pos] = v;
- my_i[pos] = i;
-}
-
-// Block-wide bitonic sort of N elements in shared memory, ascending.
-// Caller must ensure N is a power of 2 and N >= THREADS*2 (or handles small N).
+// Block-wide bitonic sort, ascending. N must be a power of 2.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
- // For each (size, stride) pair, every thread handles N/THREADS pairs.
#pragma unroll
for (int sz = 2; sz <= N; sz <<= 1) {
#pragma unroll
for (int j = sz >> 1; j > 0; j >>= 1) {
- // Each thread handles elements i = threadIdx.x, threadIdx.x + THREADS, ...
- for (int t = threadIdx.x; t < N; t += THREADS) {
+ #pragma unroll
+ for (int t0 = 0; t0 < N; t0 += THREADS) {
+ int t = t0 + threadIdx.x;
int p = t ^ j;
- if (p > t) {
+ if (p > t && t < N) {
bool asc = (t & sz) == 0;
float vt = svals[t], vp = svals[p];
bool sw = asc ? (vt > vp) : (vt < vp);
@@ -109,10 +56,108 @@
}
}
-// Single-row block kernel: 1 block per row, per-thread top-K + block bitonic.
-// REDUCE_N must equal THREADS * K and be a power of 2.
-template<int K, int THREADS, int REDUCE_N>
-__global__ void topk_block_kernel(
+// Stage 1 / single-stage kernel.
+// Reads chunk_size elements per block from x, pads to N_PAD with -inf,
+// bitonic sorts, writes top-K to (out_v, out_i).
+// blocks_per_row: 1 for single-stage; >1 for stage 1 of multi-stage.
+// out_i is int32; we promote to int64 in a small writeout kernel.
+template<int K, int N_PAD, int THREADS>
+__global__ void topk_stage_x_kernel(
+ const float* __restrict__ x,
+ float* __restrict__ out_v,
+ int* __restrict__ out_i,
+ int n,
+ int chunk_size,
+ int blocks_per_row)
+{
+ extern __shared__ unsigned char smem_raw[];
+ float* svals = reinterpret_cast<float*>(smem_raw);
+ int* sidxs = reinterpret_cast<int*>(smem_raw + N_PAD * sizeof(float));
+
+ int row = blockIdx.y;
+ int blk = blockIdx.x;
+ int chunk_start = blk * chunk_size;
+ int chunk_end = chunk_start + chunk_size;
+ if (chunk_end > n) chunk_end = n;
+
+ const float* xrow = x + row * (size_t)n;
+
+ #pragma unroll
+ for (int i0 = 0; i0 < N_PAD; i0 += THREADS) {
+ int i = i0 + threadIdx.x;
+ if (i < N_PAD) {
+ int gi = chunk_start + i;
+ if (gi < chunk_end) {
+ svals[i] = xrow[gi];
+ sidxs[i] = gi;
+ } else {
+ svals[i] = NEG_INF;
+ sidxs[i] = -1;
+ }
+ }
+ }
+ __syncthreads();
... (13888 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<12555 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Run check on v1"
}PASS{
"command": "uv run python benchmark.py 2>&1 | tail -40",
"description": "Bench v1"
}shape=0 variant=eager tflops=0.009 gbps=9.231 ms=0.057
shape=0 variant=compiled tflops=0.006 gbps=5.866 ms=0.090
shape=0 variant=sota tflops=0.010 gbps=10.498 ms=0.050
shape=0 variant=solution tflops=0.005 gbps=4.848 ms=0.108
shape=0 solution_peak_fraction=0.0027
shape=1 variant=eager tflops=0.043 gbps=42.625 ms=0.049
shape=1 variant=compiled tflops=0.022 gbps=22.387 ms=0.094
shape=1 variant=sota tflops=0.044 gbps=44.068 ms=0.048
shape=1 variant=solution tflops=0.016 gbps=16.063 ms=0.131
shape=1 solution_peak_fraction=0.0089
shape=2 variant=eager tflops=0.040 gbps=40.541 ms=0.052
shape=2 variant=compiled tflops=0.025 gbps=24.979 ms=0.084
shape=2 variant=sota tflops=0.044 gbps=43.815 ms=0.048
shape=2 variant=solution tflops=0.016 gbps=15.703 ms=0.134
shape=2 solution_peak_fraction=0.0087
shape=3 variant=eager tflops=0.017 gbps=16.745 ms=0.046
shape=3 variant=compiled tflops=0.009 gbps=9.190 ms=0.084
shape=3 variant=sota tflops=0.016 gbps=16.166 ms=0.048
shape=3 variant=solution tflops=0.006 gbps=5.805 ms=0.133
shape=3 solution_peak_fraction=0.0032
shape=4 variant=eager tflops=0.093 gbps=93.159 ms=0.023
shape=4 variant=compiled tflops=0.035 gbps=35.251 ms=0.060
shape=4 variant=sota tflops=0.088 gbps=88.388 ms=0.024
shape=4 variant=solution tflops=0.023 gbps=22.650 ms=0.093
shape=4 solution_peak_fraction=0.0126
peak_fraction: 0.0061
RESULT: LOW{
"command": "<1416 chars \u2014 see below>",
"description": "Detailed per-shape timing"
}cat > /tmp/topk_bench_detail.py << 'EOF'
import sys
sys.path.insert(0, 'problems/05_topk_bitonic')
import torch
import solution
import time
shapes = [
(1, 131072, 64),
(64, 8192, 8),
(32, 16384, 32),
(16, 12000, 16),
(128, 4096, 1),
]
for batch, n, k in shapes:
m = solution.Model(batch, n, k).cuda().eval()
x = torch.randn(batch, n, device='cuda')
for _ in range(20): m(x); torch.cuda.synchronize()
times = []
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
m(x)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
times_topk = []
for _ in range(20): torch.topk(x, k); torch.cuda.synchronize()
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
torch.topk(x, k)
e.record()
torch.cuda.synchronize()
times_topk.append(s.elapsed_time(e))
times.sort()
times_topk.sort()
sol_med = times[len(times)//2]
topk_med = times_topk[len(times_topk)//2]
print(f'shape b={batch} n={n} k={k}: sol={sol_med*1000:.1f}us, topk={topk_med*1000:.1f}us, ratio={sol_med/topk_med:.2f}x')
EOF
uv run python /tmp/topk_bench_detail.pyExit code 1
Traceback (most recent call last):
File "/tmp/topk_bench_detail.py", line 4, in <module>
import solution
ModuleNotFoundError: No module named 'solution'{
"command": "<1390 chars \u2014 see below>",
"description": "Detailed timing"
}cd problems/05_topk_bitonic && cat > /tmp/topk_bench_detail.py << 'EOF'
import sys, os
sys.path.insert(0, os.getcwd())
import torch
import solution
shapes = [
(1, 131072, 64),
(64, 8192, 8),
(32, 16384, 32),
(16, 12000, 16),
(128, 4096, 1),
]
for batch, n, k in shapes:
m = solution.Model(batch, n, k).cuda().eval()
x = torch.randn(batch, n, device='cuda')
for _ in range(20): m(x); torch.cuda.synchronize()
times = []
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record(); m(x); e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
times_topk = []
for _ in range(20): torch.topk(x, k); torch.cuda.synchronize()
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record(); torch.topk(x, k); e.record()
torch.cuda.synchronize()
times_topk.append(s.elapsed_time(e))
times.sort(); times_topk.sort()
sol_med = times[len(times)//2]; topk_med = times_topk[len(times_topk)//2]
print(f'shape b={batch} n={n} k={k}: sol={sol_med*1000:.1f}us, topk={topk_med*1000:.1f}us, ratio={sol_med/topk_med:.2f}x')
EOF
uv run python /tmp/topk_bench_detail.pyExit code 1
(eval):cd:1: no such file or directory: problems/05_topk_bitonic
Traceback (most recent call last):
File "/tmp/topk_bench_detail.py", line 4, in <module>
import solution
ModuleNotFoundError: No module named 'solution'{
"command": "pwd",
"description": "check pwd"
}/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic{
"command": "uv run python /tmp/topk_bench_detail.py",
"description": "Run timing from cwd"
}Exit code 1
Traceback (most recent call last):
File "/tmp/topk_bench_detail.py", line 4, in <module>
import solution
ModuleNotFoundError: No module named 'solution'{
"command": "ls && uv run python -c \"import sys; sys.path.insert(0, '.'); print(sys.path); import solution\"",
"description": "check why solution can't import"
}__pycache__
benchmark.py
check.py
framework.txt
problem.yaml
PROMPT.txt
reference.py
shapes.py
solution.py
sota.py
['.', '', '/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python311.zip', '/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11', '/home/infatoshi/.local/share/uv/python/cpython-3.11-linux-x86_64-gnu/lib/python3.11/lib-dynload', '/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages', '/home/infatoshi/cuda/KernelBench-Hard', '/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/nvidia_cutlass_dsl/python_packages']{
"command": "<1487 chars \u2014 see below>",
"description": "Run timing"
}cat > /tmp/topk_bench_detail.py << 'EOF'
import sys, os
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
os.chdir('/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
import torch
import solution
shapes = [
(1, 131072, 64),
(64, 8192, 8),
(32, 16384, 32),
(16, 12000, 16),
(128, 4096, 1),
]
for batch, n, k in shapes:
m = solution.Model(batch, n, k).cuda().eval()
x = torch.randn(batch, n, device='cuda')
for _ in range(20): m(x); torch.cuda.synchronize()
times = []
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record(); m(x); e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
times_topk = []
for _ in range(20): torch.topk(x, k); torch.cuda.synchronize()
for _ in range(50):
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record(); torch.topk(x, k); e.record()
torch.cuda.synchronize()
times_topk.append(s.elapsed_time(e))
times.sort(); times_topk.sort()
sol_med = times[len(times)//2]; topk_med = times_topk[len(times_topk)//2]
print(f'shape b={batch} n={n} k={k}: sol={sol_med*1000:.1f}us, topk={topk_med*1000:.1f}us, ratio={sol_med/topk_med:.2f}x')
EOF
uv run python /tmp/topk_bench_detail.pyshape b=1 n=131072 k=64: sol=103.8us, topk=46.0us, ratio=2.26x
shape b=64 n=8192 k=8: sol=127.3us, topk=48.0us, ratio=2.65x
shape b=32 n=16384 k=32: sol=129.9us, topk=45.8us, ratio=2.84x
shape b=16 n=12000 k=16: sol=129.7us, topk=40.0us, ratio=3.24x
shape b=128 n=4096 k=1: sol=88.0us, topk=18.3us, ratio=4.81x--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,13 +1,19 @@
-"""TopK custom kernel via load_inline (v1: shared-memory bitonic sort).
+"""TopK custom kernel via load_inline (v3: per-thread streaming + block bitonic).
Strategy:
- - For each row, load the row (or chunk) into shared memory padded to a
- power of 2 with -inf, sort via in-place block-wide bitonic sort, then
- output the last K elements (reversed) as the top-K descending.
- - Where N_PAD * 8 bytes does not fit in 99 KB opt-in shared memory, we
- use a 2-stage approach: stage 1 has G blocks per row each find the
- chunk-local top-K via bitonic, stage 2 has 1 block per row sort the
- G*K candidates.
+ - Each thread streams its share of a chunk, maintaining a sorted top-K
+ array in registers (descending). Insertion uses static-indexed
+ bubble-up (place new element at position K-1, then bubble it up via
+ K-1 conditional swaps; all indices are compile-time constants so the
+ compiler keeps everything in registers without spills).
+ - After streaming, T threads write per-thread top-K to shared memory
+ (T*K floats + T*K ints), and a block-wide bitonic sort on those
+ T*K candidates picks the chunk-local top-K.
+ - Each row is split into G chunks; G is chosen per shape so total
+ stage-1 block count is ~256 (filling the 188-SM RTX PRO 6000).
+ - Stage 2 merges the G*K candidates per row into the final top-K.
+ - For shapes where 1 block per row already saturates the SMs (e.g.
+ batch=128) the kernel runs single-stage with int64 output directly.
"""
import os
@@ -29,7 +35,29 @@
#define NEG_INF (-3.4e38f)
-// Block-wide bitonic sort, ascending. N must be a power of 2.
+// Insert (v, idx) into sorted-descending registers my_v[0..K-1], my_i[0..K-1].
+// Static-indexed: place candidate at slot K-1, then bubble up via K-1
+// conditional swaps. All indices are compile-time constants, so PTX should
+// keep my_v/my_i in registers and use predicated swaps (no local memory).
+template<int K>
+__device__ __forceinline__ void thread_insert(float v, int idx,
+ float (&my_v)[K], int (&my_i)[K]) {
+ if (v <= my_v[K-1]) return;
+ my_v[K-1] = v;
+ my_i[K-1] = idx;
+ #pragma unroll
+ for (int p = K - 1; p > 0; --p) {
+ bool sw = my_v[p] > my_v[p-1];
+ float a = my_v[p-1], b = my_v[p];
+ int ai = my_i[p-1], bi = my_i[p];
+ my_v[p-1] = sw ? b : a;
+ my_v[p] = sw ? a : b;
+ my_i[p-1] = sw ? bi : ai;
+ my_i[p] = sw ? ai : bi;
+ }
+}
+
+// Block-wide bitonic sort, ascending.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
#pragma unroll
@@ -39,15 +67,17 @@
#pragma unroll
for (int t0 = 0; t0 < N; t0 += THREADS) {
int t = t0 + threadIdx.x;
- int p = t ^ j;
- if (p > t && t < N) {
- bool asc = (t & sz) == 0;
- float vt = svals[t], vp = svals[p];
- bool sw = asc ? (vt > vp) : (vt < vp);
- if (sw) {
- svals[t] = vp; svals[p] = vt;
- int it = sidxs[t], ip = sidxs[p];
- sidxs[t] = ip; sidxs[p] = it;
+ if (t < N) {
+ int p = t ^ j;
+ if (p > t) {
+ bool asc = (t & sz) == 0;
+ float vt = svals[t], vp = svals[p];
+ bool sw = asc ? (vt > vp) : (vt < vp);
+ if (sw) {
+ svals[t] = vp; svals[p] = vt;
+ int it = sidxs[t], ip = sidxs[p];
+ sidxs[t] = ip; sidxs[p] = it;
+ }
}
}
}
@@ -56,63 +86,86 @@
}
}
-// Stage 1 / single-stage kernel.
-// Reads chunk_size elements per block from x, pads to N_PAD with -inf,
-// bitonic sorts, writes top-K to (out_v, out_i).
-// blocks_per_row: 1 for single-stage; >1 for stage 1 of multi-stage.
-// out_i is int32; we promote to int64 in a small writeout kernel.
-template<int K, int N_PAD, int THREADS>
-__global__ void topk_stage_x_kernel(
+// Stage 1 / single-stage kernel: per-thread streaming top-K + block bitonic.
+// N_REDUCE = next pow 2 of THREADS * K (must be a power of 2).
+// When write_int64=1, indices output is int64 directly (single-stage).
+// chunks_per_row = grid.x dim (1 for single-stage).
+template<int K, int THREADS, int N_REDUCE, int WRITE_INT64>
+__global__ void topk_streaming_kernel(
const float* __restrict__ x,
- float* __restrict__ out_v,
- int* __restrict__ out_i,
+ void* __restrict__ out_v, // float*
+ void* __restrict__ out_i, // int* or int64_t*
int n,
int chunk_size,
- int blocks_per_row)
+ int chunks_per_row)
{
extern __shared__ unsigned char smem_raw[];
float* svals = reinterpret_cast<float*>(smem_raw);
- int* sidxs = reinterpret_cast<int*>(smem_raw + N_PAD * sizeof(float));
+ int* sidxs = reinterpret_cast<int*>(smem_raw + N_REDUCE * sizeof(float));
int row = blockIdx.y;
int blk = blockIdx.x;
int chunk_start = blk * chunk_size;
int chunk_end = chunk_start + chunk_size;
if (chunk_end > n) chunk_end = n;
-
const float* xrow = x + row * (size_t)n;
- #pragma unroll
- for (int i0 = 0; i0 < N_PAD; i0 += THREADS) {
- int i = i0 + threadIdx.x;
- if (i < N_PAD) {
- int gi = chunk_start + i;
- if (gi < chunk_end) {
- svals[i] = xrow[gi];
- sidxs[i] = gi;
- } else {
- svals[i] = NEG_INF;
- sidxs[i] = -1;
- }
... (18237 more chars)
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<15974 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Check v3"
}
PASS
{
"command": "uv run python /tmp/topk_bench_detail.py 2>&1 | tail -20",
"description": "Bench v3 detail"
}
shape b=1 n=131072 k=64: sol=752.0us, topk=50.5us, ratio=14.88x
shape b=64 n=8192 k=8: sol=53.2us, topk=44.2us, ratio=1.20x
shape b=32 n=16384 k=32: sol=289.6us, topk=48.6us, ratio=5.95x
shape b=16 n=12000 k=16: sol=121.6us, topk=41.6us, ratio=2.92x
shape b=128 n=4096 k=1: sol=11.9us, topk=17.2us, ratio=0.69x
{
"command": "<155 chars \u2014 see below>",
"description": "Clear extension cache"
}
rm -rf /tmp/torch_extensions/topk_v3 ~/.cache/torch_extensions/py311_cu130/topk_v3 2>/dev/null
ls ~/.cache/torch_extensions/py311_cu130/ 2>/dev/null | head
bitonic_topk
conv_tanh_scale_bias_pool_ext
cross_entropy_cuda
cross_entropy_cuda_v2
cross_entropy_cuda_v3
cross_entropy_cuda_v4
fp8_gemm_bf16_cutlass_ext
fp8_gemm_bf16_cutlass_ext_v2
fp8_gemm_bf16_kernel_v2
fp8_gemm_bf16_v2
{
"todos": [
{
"content": "Check GPU availability and CUDA environment",
"status": "completed",
"activeForm": "Checking GPU availability and CUDA environment"
},
{
"content": "Compile v3 and pass check.py",
"status": "completed",
"activeForm": "Compiling v3 and passing check.py"
},
{
"content": "Benchmark and identify bottlenecks per shape",
"status": "completed",
"activeForm": "Benchmarking and identifying bottlenecks"
},
{
"content": "Diagnose K=64 streaming slowness; check register spill / PTX",
... (407 more chars)
Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable
{
"command": "<1045 chars \u2014 see below>",
"description": "Get verbose ptxas output"
}
rm -rf ~/.cache/torch_extensions/py311_cu130/topk_v3
cat > /tmp/build_verbose.py << 'EOF'
import os
os.environ['TORCH_CUDA_ARCH_LIST'] = '12.0'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
os.chdir('/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
import torch
from torch.utils.cpp_extension import load_inline
# Read CUDA source from solution.py
with open('solution.py') as f:
src = f.read()
start = src.index('CUDA_SRC = r"""') + len('CUDA_SRC = r"""')
end = src.index('"""\n\n_module')
cuda = src[start:end]
cpp = 'std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k);'
mod = load_inline(
name='topk_verbose',
cpp_sources=cpp,
cuda_sources=cuda,
functions=['topk_forward'],
extra_cuda_cflags=[
'-O3', '--use_fast_math',
'-gencode=arch=compute_120,code=sm_120',
'-Xptxas=-v',
],
verbose=True,
)
EOF
uv run python /tmp/build_verbose.py 2>&1 | grep -E "registers|spill|Used|stack frame|bytes" | head -50
ptxas info : 46 bytes gmem
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 4 registers, used 0 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 38 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 78 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 80 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 148 registers, used 1 barriers
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 158 registers, used 1 barriers
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,19 +1,15 @@
-"""TopK custom kernel via load_inline (v3: per-thread streaming + block bitonic).
+"""TopK custom kernel via load_inline (v4).
Strategy:
- - Each thread streams its share of a chunk, maintaining a sorted top-K
- array in registers (descending). Insertion uses static-indexed
- bubble-up (place new element at position K-1, then bubble it up via
- K-1 conditional swaps; all indices are compile-time constants so the
- compiler keeps everything in registers without spills).
- - After streaming, T threads write per-thread top-K to shared memory
- (T*K floats + T*K ints), and a block-wide bitonic sort on those
- T*K candidates picks the chunk-local top-K.
- - Each row is split into G chunks; G is chosen per shape so total
- stage-1 block count is ~256 (filling the 188-SM RTX PRO 6000).
- - Stage 2 merges the G*K candidates per row into the final top-K.
- - For shapes where 1 block per row already saturates the SMs (e.g.
- batch=128) the kernel runs single-stage with int64 output directly.
+ - Full bitonic sort of small chunks in shared memory; take chunk-local
+ top-K. Bitonic sort over a small N has O(log^2 N) parallel stages and
+ no long per-thread dependency chains, so it's faster than per-thread
+ streaming insertion-sort for K >= 16.
+ - Stage 1 is split into chunks so that total stage-1 block count is
+ O(>=128), filling the 188-SM RTX PRO 6000.
+ - Stage 2 single block per row merges the G*K candidates.
+ - Shape (128, 4096, 1) is k=1 argmax — use a dedicated 1-block-per-row
+ parallel max reduction (no sort needed).
"""
import os
@@ -35,29 +31,7 @@
#define NEG_INF (-3.4e38f)
-// Insert (v, idx) into sorted-descending registers my_v[0..K-1], my_i[0..K-1].
-// Static-indexed: place candidate at slot K-1, then bubble up via K-1
-// conditional swaps. All indices are compile-time constants, so PTX should
-// keep my_v/my_i in registers and use predicated swaps (no local memory).
-template<int K>
-__device__ __forceinline__ void thread_insert(float v, int idx,
- float (&my_v)[K], int (&my_i)[K]) {
- if (v <= my_v[K-1]) return;
- my_v[K-1] = v;
- my_i[K-1] = idx;
- #pragma unroll
- for (int p = K - 1; p > 0; --p) {
- bool sw = my_v[p] > my_v[p-1];
- float a = my_v[p-1], b = my_v[p];
- int ai = my_i[p-1], bi = my_i[p];
- my_v[p-1] = sw ? b : a;
- my_v[p] = sw ? a : b;
- my_i[p-1] = sw ? bi : ai;
- my_i[p] = sw ? ai : bi;
- }
-}
-
-// Block-wide bitonic sort, ascending.
+// Block-wide bitonic sort, ascending. N must be a power of 2.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
#pragma unroll
@@ -86,22 +60,20 @@
}
}
-// Stage 1 / single-stage kernel: per-thread streaming top-K + block bitonic.
-// N_REDUCE = next pow 2 of THREADS * K (must be a power of 2).
-// When write_int64=1, indices output is int64 directly (single-stage).
-// chunks_per_row = grid.x dim (1 for single-stage).
-template<int K, int THREADS, int N_REDUCE, int WRITE_INT64>
-__global__ void topk_streaming_kernel(
+// Bitonic chunk kernel: load a CHUNK_SIZE chunk into shared (pad to N_PAD
+// with -inf), bitonic sort, write top-K. WRITE_INT64=1 writes int64 directly.
+template<int K, int N_PAD, int THREADS, int WRITE_INT64>
+__global__ void topk_chunk_kernel(
const float* __restrict__ x,
- void* __restrict__ out_v, // float*
- void* __restrict__ out_i, // int* or int64_t*
+ void* __restrict__ out_v,
+ void* __restrict__ out_i,
int n,
int chunk_size,
int chunks_per_row)
{
extern __shared__ unsigned char smem_raw[];
float* svals = reinterpret_cast<float*>(smem_raw);
- int* sidxs = reinterpret_cast<int*>(smem_raw + N_REDUCE * sizeof(float));
+ int* sidxs = reinterpret_cast<int*>(smem_raw + N_PAD * sizeof(float));
int row = blockIdx.y;
int blk = blockIdx.x;
@@ -110,63 +82,45 @@
if (chunk_end > n) chunk_end = n;
const float* xrow = x + row * (size_t)n;
- float my_v[K];
- int my_i[K];
#pragma unroll
- for (int p = 0; p < K; ++p) { my_v[p] = NEG_INF; my_i[p] = -1; }
-
- // Stream the chunk.
- int i = chunk_start + threadIdx.x;
- for (; i + 3 * THREADS < chunk_end; i += 4 * THREADS) {
- float v0 = xrow[i];
- float v1 = xrow[i + THREADS];
- float v2 = xrow[i + 2 * THREADS];
- float v3 = xrow[i + 3 * THREADS];
- thread_insert<K>(v0, i, my_v, my_i);
- thread_insert<K>(v1, i + THREADS, my_v, my_i);
- thread_insert<K>(v2, i + 2 * THREADS, my_v, my_i);
- thread_insert<K>(v3, i + 3 * THREADS, my_v, my_i);
- }
- for (; i < chunk_end; i += THREADS) {
- thread_insert<K>(xrow[i], i, my_v, my_i);
- }
-
- // Write per-thread top-K to shared at strided offset.
- #pragma unroll
- for (int p = 0; p < K; ++p) {
- svals[threadIdx.x * K + p] = my_v[p];
- sidxs[threadIdx.x * K + p] = my_i[p];
- }
- int total_real = THREADS * K;
- for (int q = total_real + threadIdx.x; q < N_REDUCE; q += THREADS) {
- svals[q] = NEG_INF;
- sidxs[q] = -1;
+ for (int i0 = 0; i0 < N_PAD; i0 += THREADS) {
+ int i = i0 + threadIdx.x;
+ if (i < N_PAD) {
+ int gi = chunk_start + i;
+ if (gi < chunk_end) {
+ svals[i] = xrow[gi];
+ sidxs[i] = gi;
+ } else {
+ svals[i] = NEG_INF;
+ sidxs[i] = -1;
+ }
+ }
}
__syncthreads();
- bitonic_sort_block<N_REDUCE, THREADS>(svals, sidxs);
+ bitonic_sort_block<N_PAD, THREADS>(svals, sidxs);
int out_off = (row * chunks_per_row + blk
... (15830 more chars)
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<14850 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Check v4"
}
PASS
{
"command": "uv run python /tmp/topk_bench_detail.py 2>&1 | tail -10",
"description": "Bench v4"
}
shape b=1 n=131072 k=64: sol=229.9us, topk=46.7us, ratio=4.92x
shape b=64 n=8192 k=8: sol=40.1us, topk=45.2us, ratio=0.89x
shape b=32 n=16384 k=32: sol=52.1us, topk=45.2us, ratio=1.15x
shape b=16 n=12000 k=16: sol=44.4us, topk=38.5us, ratio=1.15x
shape b=128 n=4096 k=1: sol=8.7us, topk=19.1us, ratio=0.46x
{
"command": "<1160 chars \u2014 see below>",
"description": "Profile stages"
}
cat > /tmp/profile_stages.py << 'EOF'
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0' # async
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
os.chdir('/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
import torch
import solution
# Profile shape 0 in detail
m = solution.Model(1, 131072, 64).cuda().eval()
x = torch.randn(1, 131072, device='cuda')
for _ in range(20): m(x); torch.cuda.synchronize()
# Use profiler
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=False,
) as prof:
for _ in range(50):
m(x)
torch.cuda.synchronize()
events = prof.events()
kernel_events = [e for e in events if e.device_type.name == 'CUDA' and 'topk' in e.name.lower()]
# Aggregate
agg = {}
for e in kernel_events:
agg.setdefault(e.name, []).append(e.cuda_time_total / 1000.0) # us
for k, v in sorted(agg.items()):
avg = sum(v)/len(v)
print(f'{k}: avg={avg:.1f}us, count={len(v)}')
print()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
EOF
uv run python /tmp/profile_stages.py 2>&1 | tail -50
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
/tmp/profile_stages.py:28: FutureWarning: `cuda_time_total` is deprecated. Use `device_time_total` instead.
agg.setdefault(e.name, []).append(e.cuda_time_total / 1000.0) # us
void topk_chunk_kernel<64, 1024, 128, 0>(float const*, void*, void*, int, int, int): avg=0.0us, count=50
void topk_stage2_bitonic_kernel<64, 8192, 256>(float const*, int const*, float*, long*, int): avg=0.2us, count=50
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void topk_stage2_bitonic_kernel<64, 8192, 256>(float... 0.00% 0.000us 0.00% 0.000us 0.000us 9.762ms 87.61% 9.762ms 195.239us 50
void topk_chunk_k
... (1231 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,15 +1,15 @@
-"""TopK custom kernel via load_inline (v4).
+"""TopK custom kernel via load_inline (v5).
Strategy:
- - Full bitonic sort of small chunks in shared memory; take chunk-local
- top-K. Bitonic sort over a small N has O(log^2 N) parallel stages and
- no long per-thread dependency chains, so it's faster than per-thread
- streaming insertion-sort for K >= 16.
- - Stage 1 is split into chunks so that total stage-1 block count is
- O(>=128), filling the 188-SM RTX PRO 6000.
- - Stage 2 single block per row merges the G*K candidates.
- - Shape (128, 4096, 1) is k=1 argmax — use a dedicated 1-block-per-row
- parallel max reduction (no sort needed).
+ - Stage 1: full bitonic sort of small chunks in shared memory; take
+ chunk-local top-K. Stage 1 has many small chunks for SM parallelism.
+ - Stage 2: each stage-1 chunk emits a sorted-descending run of K
+ elements. We merge G sorted runs of length K into the global top-K
+ using a tournament-style argmax: per iteration, one thread per run
+ holds its current head value; warp+block argmax picks the run with
+ the largest head; that run advances by 1. K iterations -> top-K.
+ This is O(K * log G) instead of bitonic-sort's O(G*K * log^2(G*K)).
+ - Shape (128, 4096, 1) is k=1 argmax — dedicated max-reduction kernel.
"""
import os
@@ -60,8 +60,8 @@
}
}
-// Bitonic chunk kernel: load a CHUNK_SIZE chunk into shared (pad to N_PAD
-// with -inf), bitonic sort, write top-K. WRITE_INT64=1 writes int64 directly.
+// Stage 1 / single-stage chunk kernel: load CHUNK_SIZE into shared (pad
+// to N_PAD with -inf), bitonic sort, write top-K to (out_v, out_i).
template<int K, int N_PAD, int THREADS, int WRITE_INT64>
__global__ void topk_chunk_kernel(
const float* __restrict__ x,
@@ -118,8 +118,9 @@
}
}
-// Stage 2: input is candidate values + int32 indices. Output int64 indices.
-template<int K, int N_PAD, int THREADS>
+// Stage 2 (bitonic): merge candidates (already sorted in chunks of K)
+// via full bitonic sort over N_PAD candidates.
+template<int K_OUT, int N_PAD, int THREADS>
__global__ void topk_stage2_bitonic_kernel(
const float* __restrict__ in_v,
const int* __restrict__ in_i,
@@ -152,13 +153,95 @@
bitonic_sort_block<N_PAD, THREADS>(svals, sidxs);
- for (int q = threadIdx.x; q < K; q += THREADS) {
- vals[row * K + q] = svals[N_PAD - 1 - q];
- idxs[row * K + q] = (int64_t)sidxs[N_PAD - 1 - q];
- }
-}
-
-// Argmax kernel for k=1 — block-wide max-reduction.
+ for (int q = threadIdx.x; q < K_OUT; q += THREADS) {
+ vals[row * K_OUT + q] = svals[N_PAD - 1 - q];
+ idxs[row * K_OUT + q] = (int64_t)sidxs[N_PAD - 1 - q];
+ }
+}
+
+// Stage 2 (k-way merge tournament): NUM_RUNS sorted-desc runs of RUN_LEN
+// each per row. THREADS == NUM_RUNS (one thread per run). K iterations of
+// argmax + advance.
+template<int K_OUT, int NUM_RUNS, int RUN_LEN, int THREADS>
+__global__ void topk_stage2_merge_kernel(
+ const float* __restrict__ in_v,
+ const int* __restrict__ in_i,
+ float* __restrict__ vals,
+ int64_t* __restrict__ idxs)
+{
+ int row = blockIdx.x;
+ int tid = threadIdx.x;
+
+ // Each thread handles one run.
+ int run_off = row * NUM_RUNS * RUN_LEN + tid * RUN_LEN;
+
+ float head_v = (tid < NUM_RUNS) ? in_v[run_off] : NEG_INF;
+ int head_i = (tid < NUM_RUNS) ? in_i[run_off] : -1;
+ int head_ptr = 0;
+
+ __shared__ float smax[(THREADS + 31) / 32];
+ __shared__ int sidx[(THREADS + 31) / 32];
+ __shared__ int s_winner;
+ __shared__ float s_winner_v;
+ __shared__ int s_winner_i;
+
+ constexpr int WARP_SIZE = 32;
+ constexpr int NUM_WARPS = THREADS / WARP_SIZE;
+
+ for (int k_iter = 0; k_iter < K_OUT; ++k_iter) {
+ // Warp-level argmax.
+ unsigned mask = 0xFFFFFFFFu;
+ float my_v = head_v;
+ int my_t = tid;
+ #pragma unroll
+ for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {
+ float ov = __shfl_xor_sync(mask, my_v, s);
+ int ot = __shfl_xor_sync(mask, my_t, s);
+ // Tie-break: prefer smaller thread index (deterministic).
+ bool take = (ov > my_v) || (ov == my_v && ot < my_t);
+ if (take) { my_v = ov; my_t = ot; }
+ }
+ int lane = tid & (WARP_SIZE - 1);
+ int warp = tid / WARP_SIZE;
+ if (lane == 0) { smax[warp] = my_v; sidx[warp] = my_t; }
+ __syncthreads();
+
+ if (warp == 0) {
+ float v = (lane < NUM_WARPS) ? smax[lane] : NEG_INF;
+ int t = (lane < NUM_WARPS) ? sidx[lane] : -1;
+ #pragma unroll
+ for (int s = NUM_WARPS / 2; s > 0; s >>= 1) {
+ float ov = __shfl_xor_sync(mask, v, s);
+ int ot = __shfl_xor_sync(mask, t, s);
+ bool take = (ov > v) || (ov == v && ot < t);
+ if (take) { v = ov; t = ot; }
+ }
+ if (lane == 0) {
+ s_winner = t;
+ s_winner_v = v;
+ // We need the index too — the winner thread will write it.
+ vals[row * K_OUT + k_iter] = v;
+ }
+ }
+ __syncthreads();
+
+ int w = s_winner;
+ if (tid == w) {
+ idxs[row * K_OUT + k_iter] = (int64_t)head_i;
+ head_ptr++;
+ if (head_ptr < RUN_LEN) {
+ head_v = in_v[run_off + head_ptr];
+ head_i = in_i[run_off + head_ptr];
+ } else {
+ head_v = NEG_INF;
+ head_i = -1;
+ }
+ }
+ __syncthreads();
+ }
+}
+
+// Argmax kernel for k=1.
template<int N, int THREADS>
__global__ void argmax_kernel(
const float* __restrict__ x,
@@ -176,7 +259,6 @@
if (v > my_v) { my_v =
... (7947 more chars)
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"content": "<17125 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Check v5"
}
PASS
{
"command": "uv run python /tmp/topk_bench_detail.py 2>&1 | tail -10",
"description": "Bench v5"
}
shape b=1 n=131072 k=64: sol=50.0us, topk=46.5us, ratio=1.08x
shape b=64 n=8192 k=8: sol=43.7us, topk=45.4us, ratio=0.96x
shape b=32 n=16384 k=32: sol=43.7us, topk=44.9us, ratio=0.97x
shape b=16 n=12000 k=16: sol=38.6us, topk=40.5us, ratio=0.95x
shape b=128 n=4096 k=1: sol=8.5us, topk=18.8us, ratio=0.45x
{
"todos": [
{
"content": "Check GPU availability and CUDA environment",
"status": "completed",
"activeForm": "Checking GPU availability and CUDA environment"
},
{
"content": "Compile v5 and pass check.py",
"status": "completed",
"activeForm": "Compiling v5 and passing check.py"
},
{
"content": "Benchmark v5 \u2014 beating SOTA on 4/5 shapes",
"status": "completed",
"activeForm": "Benchmarking v5"
},
{
"content": "Push peak_fraction up further (saturate DRAM)",
"status": "in_progress",
"active
... (53 more chars)
Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable
{
"command": "uv run python benchmark.py 2>&1 | tail -30",
"description": "Full benchmark v5"
}
shape=0 variant=eager tflops=0.011 gbps=10.655 ms=0.049
shape=0 variant=compiled tflops=0.006 gbps=6.167 ms=0.085
shape=0 variant=sota tflops=0.010 gbps=9.944 ms=0.053
shape=0 variant=solution tflops=0.009 gbps=8.908 ms=0.059
shape=0 solution_peak_fraction=0.0049
shape=1 variant=eager tflops=0.032 gbps=32.031 ms=0.066
shape=1 variant=compiled tflops=0.024 gbps=23.690 ms=0.089
shape=1 variant=sota tflops=0.044 gbps=43.775 ms=0.048
shape=1 variant=solution tflops=0.044 gbps=43.877 ms=0.048
shape=1 solution_peak_fraction=0.0244
shape=2 variant=eager tflops=0.043 gbps=43.397 ms=0.049
shape=2 variant=compiled tflops=0.024 gbps=23.962 ms=0.088
shape=2 variant=sota tflops=0.044 gbps=44.123 ms=0.048
shape=2 variant=solution tflops=0.040 gbps=40.256 ms=0.052
shape=2 solution_peak_fraction=0.0224
shape=3 variant=eager tflops=0.016 gbps=16.403 ms=0.047
shape=3 variant=compiled tflops=0.009 gbps=9.229 ms=0.084
shape=3 variant=sota tflops=0.016 gbps=16.112 ms=0.048
shape=3 variant=solution tflops=0.017 gbps=17.360 ms=0.044
shape=3 solution_peak_fraction=0.0096
shape=4 variant=eager tflops=0.089 gbps=89.291 ms=0.024
shape=4 variant=compiled tflops=0.034 gbps=34.346 ms=0.061
shape=4 variant=sota tflops=0.096 gbps=96.447 ms=0.022
shape=4 variant=solution tflops=0.173 gbps=173.273 ms=0.012
shape=4 solution_peak_fraction=0.0963
peak_fraction: 0.0190
RESULT: LOW
{
"command": "<877 chars \u2014 see below>",
"description": "Profile per-shape kernels"
}
cat > /tmp/profile_v5.py << 'EOF'
import os
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
os.chdir('/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic')
import torch
import solution
shapes = [
(1, 131072, 64),
(64, 8192, 8),
(32, 16384, 32),
(16, 12000, 16),
(128, 4096, 1),
]
for batch, n, k in shapes:
m = solution.Model(batch, n, k).cuda().eval()
x = torch.randn(batch, n, device='cuda')
for _ in range(20): m(x); torch.cuda.synchronize()
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
for _ in range(50): m(x); torch.cuda.synchronize()
print(f'\n=== Shape b={batch} n={n} k={k} ===')
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
EOF
uv run python /tmp/profile_v5.py 2>&1 | head -80
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
=== Shape b=1 n=131072 k=64 ===
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void topk_chunk_kernel<64, 1024, 128, 0>(float const... 0.00% 0.000us 0.00% 0.000us 0.000us 1.389ms 63.31% 1.389ms 27.775us 50
void topk_stage2_merge_kernel<64, 128, 64, 128>(floa... 0.00% 0.000us 0.00% 0.000us 0.000us 804.844us 36.69% 804.844us 16.097us 50
cudaFuncSetAttribute 1.01% 30.975us 1.01% 30.975us 0.619us 0.000us 0.00% 0.000us 0.000us
... (8762 more chars)
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"offset": 23,
"limit": 100
}
23 std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k);
24 """
25
26 CUDA_SRC = r"""
27 #include <torch/extension.h>
28 #include <ATen/cuda/CUDAContext.h>
29 #include <cuda_runtime.h>
30 #include <cstdint>
31
32 #define NEG_INF (-3.4e38f)
33
34 // Block-wide bitonic sort, ascending. N must be a power of 2.
35 template<int N, int THREADS>
36 __device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
37 #pragma unroll
38 for (int sz = 2; sz <= N; sz <<= 1) {
39 #pragma unroll
40 for (int j = sz >> 1; j > 0; j >>= 1) {
41 #pragma unroll
42 for (int t0 = 0; t0 < N; t0 += THREADS) {
43 int t = t0 + threadIdx.x;
44 if (t < N) {
45 int p = t ^ j;
46 if (p > t) {
47 bool asc = (t & sz) == 0;
48 float vt = svals[t], vp = svals[p];
49 bool sw = asc ? (vt > vp) : (vt < vp);
50 if (sw) {
51 svals[t] = vp; svals[p] = vt;
52 int it = sidxs[t], ip = sidxs[p];
53 sidxs[t] = ip; sidxs[p] = it;
54 }
55 }
56 }
57 }
58 __syncthreads();
59 }
60 }
61 }
62
63 // Stage 1 / single-stage chunk kernel: load CHUNK_SIZE into shared (pad
64 // to N_PAD with -inf), bitonic sort,
... (2113 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -32,10 +32,20 @@
#define NEG_INF (-3.4e38f)
// Block-wide bitonic sort, ascending. N must be a power of 2.
+// Uses __syncwarp for j<32 (intra-warp pattern) and __syncthreads for j>=32.
+// At sz transitions where the next sz starts with a block-level j, we
+// insert a __syncthreads to make warp-level writes visible cross-warp.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
#pragma unroll
for (int sz = 2; sz <= N; sz <<= 1) {
+ // If the previous sz ended on a warp-only sequence and this sz starts
+ // with a block-level j (j = sz/2 >= 32), block-sync first.
+ if ((sz >> 1) >= 32) {
+ // The very first sz=2 has prev sequence; but sz=2 has j=1 only,
+ // so this branch only triggers for sz>=64.
+ __syncthreads();
+ }
#pragma unroll
for (int j = sz >> 1; j > 0; j >>= 1) {
#pragma unroll
@@ -55,7 +65,11 @@
}
}
}
- __syncthreads();
+ if (j >= 32) {
+ __syncthreads();
+ } else {
+ __syncwarp();
+ }
}
}
}
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<1034 chars \u2014 see diff>",
"new_string": "<1701 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -5",
"description": "Check after warp sync change"
}
FAIL: shape 1 {'batch': 64, 'n': 8192, 'k': 8} seed 42 values: tolerance exceeded: max_abs_diff=0.0782671 (atol=0.0001, rtol=0.0001)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -32,20 +32,10 @@
#define NEG_INF (-3.4e38f)
// Block-wide bitonic sort, ascending. N must be a power of 2.
-// Uses __syncwarp for j<32 (intra-warp pattern) and __syncthreads for j>=32.
-// At sz transitions where the next sz starts with a block-level j, we
-// insert a __syncthreads to make warp-level writes visible cross-warp.
template<int N, int THREADS>
__device__ __forceinline__ void bitonic_sort_block(float* svals, int* sidxs) {
#pragma unroll
for (int sz = 2; sz <= N; sz <<= 1) {
- // If the previous sz ended on a warp-only sequence and this sz starts
- // with a block-level j (j = sz/2 >= 32), block-sync first.
- if ((sz >> 1) >= 32) {
- // The very first sz=2 has prev sequence; but sz=2 has j=1 only,
- // so this branch only triggers for sz>=64.
- __syncthreads();
- }
#pragma unroll
for (int j = sz >> 1; j > 0; j >>= 1) {
#pragma unroll
@@ -65,11 +55,7 @@
}
}
}
- if (j >= 32) {
- __syncthreads();
- } else {
- __syncwarp();
- }
+ __syncthreads();
}
}
}
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<1701 chars \u2014 see diff>",
"new_string": "<1034 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -3",
"description": "Verify revert is correct"
}
PASS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -351,12 +351,12 @@
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (batch == 1 && n == 131072 && k == 64) {
- // Stage 1: 128 chunks * 1024. K=64.
- constexpr int CHUNK = 1024;
- constexpr int N_PAD1 = 1024;
+ // Stage 1: 64 chunks * 2048. K=64. T=256.
+ constexpr int CHUNK = 2048;
+ constexpr int N_PAD1 = 2048;
constexpr int K1 = 64;
- constexpr int T1 = 128;
- constexpr int CPR = 131072 / CHUNK; // 128
+ constexpr int T1 = 256;
+ constexpr int CPR = 131072 / CHUNK; // 64
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -365,15 +365,26 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: k-way merge of 128 sorted runs (K1 each) -> top-K=64.
- launch_stage2_merge<K1, CPR, K1, CPR>(v_scratch, i_scratch, vals, idxs,
- batch, stream);
+ // Stage 2: k-way merge of 64 sorted runs (K1 each) -> top-K=64.
+ launch_stage2_merge<K1, CPR, K1, 64>(v_scratch, i_scratch, vals, idxs,
+ batch, stream);
} else if (batch == 64 && n == 8192 && k == 8) {
- constexpr int CHUNK = 1024;
- constexpr int N_PAD1 = 1024;
+ // Single-stage: 1 block per row, full bitonic 8192. 64 blocks total.
+ constexpr int CHUNK = 8192;
+ constexpr int N_PAD1 = 8192;
constexpr int K1 = 8;
- constexpr int T1 = 128;
- constexpr int CPR = 8192 / CHUNK; // 8
+ constexpr int T1 = 512;
+ constexpr int CPR = 1;
+
+ launch_chunk_int64<K1, N_PAD1, T1>(x, vals, idxs,
+ n, CHUNK, CPR, batch, stream);
+ } else if (batch == 32 && n == 16384 && k == 32) {
+ // 2 chunks * 8192 each, K=32. 64 stage1 blocks.
+ constexpr int CHUNK = 8192;
+ constexpr int N_PAD1 = 8192;
+ constexpr int K1 = 32;
+ constexpr int T1 = 512;
+ constexpr int CPR = 16384 / CHUNK; // 2
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -382,19 +393,19 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: 8 runs * 8 elements = 64 candidates.
- // Use bitonic on 64 (cheap).
+ // Stage 2: 2 runs * 32 = 64 candidates -> top-32. Bitonic 64.
constexpr int CAND2 = CPR * K1; // 64
constexpr int N_PAD2 = 64;
constexpr int T2 = 32;
launch_stage2_bitonic<K1, N_PAD2, T2>(v_scratch, i_scratch, vals, idxs,
CAND2, batch, stream);
- } else if (batch == 32 && n == 16384 && k == 32) {
- constexpr int CHUNK = 1024;
- constexpr int N_PAD1 = 1024;
- constexpr int K1 = 32;
- constexpr int T1 = 128;
- constexpr int CPR = 16384 / CHUNK; // 16
+ } else if (batch == 16 && n == 12000 && k == 16) {
+ // 2 chunks of 6000 (pad to 8192), K=16. 32 stage1 blocks.
+ constexpr int CHUNK = 6000;
+ constexpr int N_PAD1 = 8192;
+ constexpr int K1 = 16;
+ constexpr int T1 = 512;
+ constexpr int CPR = 2;
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -403,28 +414,12 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: 16 runs * 32 = 512 candidates.
- // Try k-way merge: 16 runs of 32 -> top-32. THREADS=32 (one per run, padded for warp).
- launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
- batch, stream);
- } else if (batch == 16 && n == 12000 && k == 16) {
- constexpr int CHUNK = 1024;
- constexpr int N_PAD1 = 1024;
- constexpr int K1 = 16;
- constexpr int T1 = 128;
- constexpr int CPR = 12; // ceil(12000/1024)
-
- auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
- auto i_scratch = torch::empty({batch * CPR * K1},
- x.options().dtype(torch::kInt32));
-
- launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
- n, CHUNK, CPR, batch, stream);
-
- // Stage 2: 12 runs of 16 = 192 candidates -> top-16.
- // Merge with THREADS=32 (12 runs + 20 inactive).
- launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
- batch, stream);
+ // Stage 2: 2 runs * 16 = 32 candidates -> top-16. Bitonic 32.
+ constexpr int CAND2 = CPR * K1; // 32
+ constexpr int N_PAD2 = 32;
+ constexpr int T2 = 32;
+ launch_stage2_bitonic<K1, N_PAD2, T2>(v_scratch, i_scratch, vals, idxs,
+ CAND2, batch, stream);
} else if (batch == 128 && n == 4096 && k == 1) {
constexpr int N = 4096;
constexpr int T = 256;
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<3604 chars \u2014 see diff>",
"new_string": "<3303 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -438,7 +438,7 @@
"""
_module = load_inline(
- name="topk_v5",
+ name="topk_v6",
cpp_sources=CPP_SRC,
cuda_sources=CUDA_SRC,
functions=["topk_forward"],
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<42 chars \u2014 see diff>",
"new_string": "<42 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -5",
"description": "Check v6"
}
2 errors detected in the compilation of "/home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/cuda.cu".
[2/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=topk_v6 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/main.cpp -o main.o
ninja: build stopped: subcommand failed.{
"command": "uv run python check.py 2>&1 | grep error | head -5",
"description": "Get error details"
}
FAIL: import error: Error building extension 'topk_v6': [1/2] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_v6 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -gencode=arch=compute_120,code=sm_120 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/cuda.cu -o cuda.cuda.o
/home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/cuda.cu(357): error: identifier "launch_chunk_int64" is undefined
/home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/cuda.cu(357): error: operand types are incompatible ("int" and "cudaStream_t" (aka "CUstream_st *"))
2 errors detected in the compilation of "/home/infatoshi/.cache/torch_extensions/py311_cu130/topk_v6/cuda.cu".
{
"pattern": "launch_chunk_int",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"output_mode": "content"
}
294:inline void launch_chunk_int32(torch::Tensor x, torch::Tensor v_scratch,
365: launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
379: launch_chunk_int64<K1, N_PAD1, T1>(x, vals, idxs,
393: launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
414: launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -307,6 +307,23 @@
n, chunk_size, chunks_per_row);
}
+template<int K, int N_PAD, int THREADS>
+inline void launch_chunk_int64(torch::Tensor x, torch::Tensor vals,
+ torch::Tensor idxs,
+ int n, int chunk_size, int chunks_per_row,
+ int batch, cudaStream_t stream)
+{
+ size_t sm = (size_t)N_PAD * (sizeof(float) + sizeof(int));
+ cudaFuncSetAttribute((void*)topk_chunk_kernel<K, N_PAD, THREADS, 1>,
+ cudaFuncAttributeMaxDynamicSharedMemorySize, sm);
+ topk_chunk_kernel<K, N_PAD, THREADS, 1>
+ <<<dim3(chunks_per_row, batch), THREADS, sm, stream>>>(
+ x.data_ptr<float>(),
+ (void*)vals.data_ptr<float>(),
+ (void*)idxs.data_ptr<int64_t>(),
+ n, chunk_size, chunks_per_row);
+}
+
template<int K_OUT, int N_PAD, int THREADS>
inline void launch_stage2_bitonic(torch::Tensor v_scratch, torch::Tensor i_scratch,
torch::Tensor vals, torch::Tensor idxs,
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<769 chars \u2014 see diff>",
"new_string": "<1524 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -3",
"description": "Check v6"
}
PASS
{
"command": "uv run python /tmp/topk_bench_detail.py 2>&1 | tail -10",
"description": "Bench v6"
}
shape b=1 n=131072 k=64: sol=63.3us, topk=46.1us, ratio=1.37x
shape b=64 n=8192 k=8: sol=127.8us, topk=45.5us, ratio=2.81x
shape b=32 n=16384 k=32: sol=132.0us, topk=46.5us, ratio=2.84x
shape b=16 n=12000 k=16: sol=130.0us, topk=42.2us, ratio=3.08x
shape b=128 n=4096 k=1: sol=8.3us, topk=18.4us, ratio=0.45x
{
"todos": [
{
"content": "Revert to v5 base config (chunk=1024) and reapply",
"status": "in_progress",
"activeForm": "Reverting to v5 base config"
},
{
"content": "Add float4 vectorized loads to chunk kernel",
"status": "pending",
"activeForm": "Adding float4 vectorized loads"
},
{
"content": "Try cooperative grid sync to fuse stages",
"status": "pending",
"activeForm": "Trying cooperative grid sync"
},
{
"content": "Push peak_fraction up further",
"status": "pending",
"activeForm": "Pushing
... (37 more chars)
Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -368,12 +368,12 @@
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (batch == 1 && n == 131072 && k == 64) {
- // Stage 1: 64 chunks * 2048. K=64. T=256.
- constexpr int CHUNK = 2048;
- constexpr int N_PAD1 = 2048;
+ // Stage 1: 128 chunks * 1024. K=64.
+ constexpr int CHUNK = 1024;
+ constexpr int N_PAD1 = 1024;
constexpr int K1 = 64;
- constexpr int T1 = 256;
- constexpr int CPR = 131072 / CHUNK; // 64
+ constexpr int T1 = 128;
+ constexpr int CPR = 131072 / CHUNK; // 128
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -382,26 +382,14 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: k-way merge of 64 sorted runs (K1 each) -> top-K=64.
- launch_stage2_merge<K1, CPR, K1, 64>(v_scratch, i_scratch, vals, idxs,
- batch, stream);
+ launch_stage2_merge<K1, CPR, K1, CPR>(v_scratch, i_scratch, vals, idxs,
+ batch, stream);
} else if (batch == 64 && n == 8192 && k == 8) {
- // Single-stage: 1 block per row, full bitonic 8192. 64 blocks total.
- constexpr int CHUNK = 8192;
- constexpr int N_PAD1 = 8192;
+ constexpr int CHUNK = 1024;
+ constexpr int N_PAD1 = 1024;
constexpr int K1 = 8;
- constexpr int T1 = 512;
- constexpr int CPR = 1;
-
- launch_chunk_int64<K1, N_PAD1, T1>(x, vals, idxs,
- n, CHUNK, CPR, batch, stream);
- } else if (batch == 32 && n == 16384 && k == 32) {
- // 2 chunks * 8192 each, K=32. 64 stage1 blocks.
- constexpr int CHUNK = 8192;
- constexpr int N_PAD1 = 8192;
- constexpr int K1 = 32;
- constexpr int T1 = 512;
- constexpr int CPR = 16384 / CHUNK; // 2
+ constexpr int T1 = 128;
+ constexpr int CPR = 8192 / CHUNK; // 8
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -410,19 +398,17 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: 2 runs * 32 = 64 candidates -> top-32. Bitonic 64.
constexpr int CAND2 = CPR * K1; // 64
constexpr int N_PAD2 = 64;
constexpr int T2 = 32;
launch_stage2_bitonic<K1, N_PAD2, T2>(v_scratch, i_scratch, vals, idxs,
CAND2, batch, stream);
- } else if (batch == 16 && n == 12000 && k == 16) {
- // 2 chunks of 6000 (pad to 8192), K=16. 32 stage1 blocks.
- constexpr int CHUNK = 6000;
- constexpr int N_PAD1 = 8192;
- constexpr int K1 = 16;
- constexpr int T1 = 512;
- constexpr int CPR = 2;
+ } else if (batch == 32 && n == 16384 && k == 32) {
+ constexpr int CHUNK = 1024;
+ constexpr int N_PAD1 = 1024;
+ constexpr int K1 = 32;
+ constexpr int T1 = 128;
+ constexpr int CPR = 16384 / CHUNK; // 16
auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
auto i_scratch = torch::empty({batch * CPR * K1},
@@ -431,12 +417,24 @@
launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
n, CHUNK, CPR, batch, stream);
- // Stage 2: 2 runs * 16 = 32 candidates -> top-16. Bitonic 32.
- constexpr int CAND2 = CPR * K1; // 32
- constexpr int N_PAD2 = 32;
- constexpr int T2 = 32;
- launch_stage2_bitonic<K1, N_PAD2, T2>(v_scratch, i_scratch, vals, idxs,
- CAND2, batch, stream);
+ launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
+ batch, stream);
+ } else if (batch == 16 && n == 12000 && k == 16) {
+ constexpr int CHUNK = 1024;
+ constexpr int N_PAD1 = 1024;
+ constexpr int K1 = 16;
+ constexpr int T1 = 128;
+ constexpr int CPR = 12;
+
+ auto v_scratch = torch::empty({batch * CPR * K1}, x.options());
+ auto i_scratch = torch::empty({batch * CPR * K1},
+ x.options().dtype(torch::kInt32));
+
+ launch_chunk_int32<K1, N_PAD1, T1>(x, v_scratch, i_scratch,
+ n, CHUNK, CPR, batch, stream);
+
+ launch_stage2_merge<K1, CPR, K1, 32>(v_scratch, i_scratch, vals, idxs,
+ batch, stream);
} else if (batch == 128 && n == 4096 && k == 1) {
constexpr int N = 4096;
constexpr int T = 256;
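This revert restores the 1024-element stage-1 chunks. A rough way to see why the 8192-element chunk configs regressed in the v6 numbers above: a bitonic sort of N elements runs log2(N) * (log2(N) + 1) / 2 compare-exchange phases, and every phase sweeps all N shared-memory slots, so per-chunk work grows with N * log^2(N) while bigger chunks also leave far fewer blocks in flight. A small sketch of the phase counts (standard bitonic formula, illustrative Python only):

from math import log2

def bitonic_phases(n):
    # Number of compare-exchange phases in a bitonic sort of n = 2^m elements.
    m = int(log2(n))
    return m * (m + 1) // 2

for n_pad in (1024, 2048, 8192):
    p = bitonic_phases(n_pad)
    print(f"N_PAD={n_pad}: {p} phases x {n_pad} slots = {p * n_pad} element-updates per chunk")

By this count an 8192-element chunk does roughly 13x the work of a 1024-element chunk while covering only 8x the data, which matches the regression the revert undoes.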
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"old_string": "<3303 chars \u2014 see diff>",
"new_string": "<3134 chars \u2014 see diff>"
}
The file /home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
{
"command": "uv run python check.py 2>&1 | tail -3 && uv run python /tmp/topk_bench_detail.py 2>&1 | tail -10",
"description": "Verify revert + bench"
}
PASS
shape b=1 n=131072 k=64: sol=50.1us, topk=62.5us, ratio=0.80x
shape b=64 n=8192 k=8: sol=41.5us, topk=44.6us, ratio=0.93x
shape b=32 n=16384 k=32: sol=43.8us, topk=63.1us, ratio=0.69x
shape b=16 n=12000 k=16: sol=38.2us, topk=39.8us, ratio=0.96x
shape b=128 n=4096 k=1: sol=8.4us, topk=19.3us, ratio=0.44x
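For reference, the stage-2 tournament merge that v5 settles on can be mirrored on the host in a few lines of PyTorch. The helper below is a hypothetical sketch, not part of the repo; it only exists to make the merge semantics easy to sanity-check against torch.topk independently of the stage-1 kernel.

import torch

def merge_sorted_runs(run_v: torch.Tensor, run_i: torch.Tensor, k: int):
    # run_v, run_i: shape (G, K), each row sorted descending; assumes k <= G * K.
    G, K = run_v.shape
    heads = [0] * G                      # per-run read pointer (the kernel keeps one per thread)
    out_v, out_i = [], []
    for _ in range(k):
        # Tournament step: pick the run whose current head value is largest.
        w = max(range(G), key=lambda r: run_v[r, heads[r]].item() if heads[r] < K else float("-inf"))
        out_v.append(run_v[w, heads[w]])
        out_i.append(run_i[w, heads[w]])
        heads[w] += 1                    # only the winning run advances
    return torch.stack(out_v), torch.stack(out_i)

Reshaping the stage-1 scratch buffers to (G, K) and comparing this against torch.topk on the same row is a quick way to tell stage-2 bugs apart from stage-1 bugs.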