"""Single-kernel top-k for SM120 (RTX PRO 6000).
Strategy: one fused CUDA kernel per forward. Each row is split across one or
more blocks; each warp streams its span keeping a register-resident sorted
top-K (Faiss WarpSelect-style threshold + per-lane candidate queue, bitonic
networks via warp shuffles). Warps merge through shared memory. Multi-split
rows publish packed (key|idx) candidates to a scratch buffer; the last block
per row (device-scope acq-rel counter) merges them and writes the output.
Values travel as a monotonic fp32->u32 key packed with the source index into
a u64, so all comparisons are single integer compares and the exact fp32 bits
round-trip on output.
"""
import os
import torch
from torch.utils.cpp_extension import load_inline
# Declarations of the host entry points defined in _CUDA_SRC. They are passed
# to load_inline as `cpp_sources`; the names listed in `functions` in _build
# become callable from Python. Handles are the indices returned by
# topk_configure, and every pointer argument is a raw address passed as int64
# (Model passes tensor.data_ptr() values).
_CPP_SRC = r"""
#include <cstdint>
int64_t topk_configure(int64_t batch, int64_t n, int64_t k, int64_t splits,
int64_t threads, int64_t vals_ptr, int64_t idx_ptr);
void topk_run(int64_t handle, int64_t x_ptr);
void topk_run_cached(int64_t handle);
"""
_CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
// Candidates are packed into a u64: fkey(value) in the high 32 bits and the
// source index in the low 32 bits. Plain u64 comparison therefore orders by
// key first, with equal keys ordered by index (the larger index compares
// greater).
using u32 = unsigned int;
using u64 = unsigned long long;
// Every shuffle/ballot in this source passes FULL_MASK, i.e. all 32 lanes of
// the warp must reach each call together.
#define FULL_MASK 0xffffffffu
// Map raw fp32 bits to an unsigned key with the same ordering as the floats:
// a clear sign bit (non-negative value) gets the sign bit set, while a set
// sign bit (negative value) inverts every bit so larger magnitudes sort lower.
// Descending float order == descending key order.
__device__ __forceinline__ u32 fkey(u32 bits) {
  const u32 flip = (bits & 0x80000000u) ? 0xFFFFFFFFu : 0x80000000u;
  return bits ^ flip;
}
// Inverse of fkey: recover the original fp32 value from its ordering key.
// Keys with the top bit set came from non-negative floats (only the sign bit
// was flipped); the others came from negative floats (all bits flipped).
__device__ __forceinline__ float unkey(u32 key) {
  const u32 flip = (key >> 31) ? 0x80000000u : 0xFFFFFFFFu;
  return __uint_as_float(key ^ flip);
}
// Larger of two packed candidates.
__device__ __forceinline__ u64 umax(u64 a, u64 b) { return (a < b) ? b : a; }
// Smaller of two packed candidates.
__device__ __forceinline__ u64 umin(u64 a, u64 b) { return (b < a) ? b : a; }
// Descending bitonic merge across a warp: if the 32 values held one per lane
// form a bitonic sequence, afterwards lane i holds the i-th largest.
// All 32 lanes must call this together (full-mask shuffles).
__device__ __forceinline__ void bmerge32(u64 &v, int lane) {
#pragma unroll
  for (int stride = 16; stride > 0; stride >>= 1) {
    const u64 partner = __shfl_xor_sync(FULL_MASK, v, stride);
    const bool upper = (lane & stride) != 0;
    // The lower lane of each pair keeps the larger value.
    v = upper ? umin(v, partner) : umax(v, partner);
  }
}
// Descending bitonic merge of 64 values stored as element e = j*32 + lane
// (a0 holds e = lane, a1 holds e = 32 + lane). The input must be bitonic.
// One half-cleaner step moves the larger of each (e, e + 32) pair into a0;
// both halves are then bitonic and are merged independently.
__device__ __forceinline__ void bmerge64(u64 &a0, u64 &a1, int lane) {
  const u64 hi = umax(a0, a1);
  const u64 lo = umin(a0, a1);
  a0 = hi;
  a1 = lo;
  bmerge32(a0, lane);
  bmerge32(a1, lane);
}
// Full descending sort of 32 elements one-per-lane.
// Bitonic sort: stage s (2, 4, ..., 32) sorts runs of s consecutive lanes,
// alternating the direction by bit s of the lane index so that neighbouring
// runs form a bitonic sequence of length 2s for the next stage. Afterwards
// lane i holds the i-th largest value. All 32 lanes must participate.
__device__ __forceinline__ void sort32(u64 &v, int lane) {
#pragma unroll
for (int s = 2; s <= 32; s <<= 1) {
#pragma unroll
for (int d = s >> 1; d >= 1; d >>= 1) {
u64 o = __shfl_xor_sync(FULL_MASK, v, d);
bool i_low = (lane & d) == 0;
// Runs whose lanes have bit s set are sorted ascending in this stage; for
// s = 32 no lane has that bit, so the final run is descending.
bool asc = (lane & s) != 0; // gives overall descending
// Descending run: the lower lane of a pair keeps the max; ascending: the min.
bool take_max = i_low ^ asc;
v = take_max ? umax(v, o) : umin(v, o);
}
}
}
// ---------- Warp-resident top-K buffer -------------------------------------
// R registers per lane; element e = j*32 + lane; sorted descending by e.
// Holds the 32*R largest packed candidates seen so far by one warp. Every
// method must be called by all 32 lanes together (full-mask shuffles).
template <int R> struct WarpBuf;
// R = 1: 32 entries, entry e = lane lives in b0.
template <> struct WarpBuf<1> {
u64 b0;
// Fill with 0, the smallest packed value, so real candidates displace it.
__device__ __forceinline__ void init() { b0 = 0; }
// Merge a descending-sorted 32-vector (one per lane) into the buffer.
// Pairing buffer entry e with v[31 - e] (the xor-31 shuffle reverses v) and
// keeping the max yields the 32 largest of the 64 values as a bitonic
// sequence, which bmerge32 sorts back into descending order.
__device__ __forceinline__ void merge_sorted32(u64 v, int lane) {
u64 r = __shfl_xor_sync(FULL_MASK, v, 31); // reversed
b0 = umax(b0, r);
bmerge32(b0, lane);
}
// Key (high 32 bits) of entry k-1, broadcast to every lane; 1 <= k <= 32.
// `lane` is unused.
__device__ __forceinline__ u32 kth_key(int k, int lane) {
u64 e = __shfl_sync(FULL_MASK, b0, (k - 1) & 31);
return (u32)(e >> 32);
}
// Copy the 32 entries to/from dst[0..32) / src[0..32) in sorted order.
__device__ __forceinline__ void store(u64 *dst, int lane) { dst[lane] = b0; }
__device__ __forceinline__ void load(const u64 *src, int lane) { b0 = src[lane]; }
// Merge another descending list of 32 (in shared, reversed addressing).
// src must hold 32 entries sorted descending (the kernel passes shared
// memory); reading it reversed gives the same max-then-bitonic-merge step as
// merge_sorted32.
__device__ __forceinline__ void merge_list(const u64 *src, int lane) {
b0 = umax(b0, src[31 - lane]);
bmerge32(b0, lane);
}
};
// R = 2: 64 entries; b0 holds entries 0..31 (e = lane), b1 holds 32..63
// (e = 32 + lane). All methods must be called by all 32 lanes together.
template <> struct WarpBuf<2> {
u64 b0, b1;
// Fill with 0, the smallest packed value, so real candidates displace it.
__device__ __forceinline__ void init() { b0 = 0; b1 = 0; }
// Merge a descending-sorted 32-vector v (one per lane) into the buffer.
// Every b0 entry is >= every b1 entry, so the 32 values dropped from the
// 96-value union all come from b1 and v: keeping max(b1[e], v[31 - e])
// retains the 32 largest of b1 and v. b0 (descending) followed by that
// descending-then-ascending half is bitonic, and bmerge64 sorts it.
__device__ __forceinline__ void merge_sorted32(u64 v, int lane) {
// positions 32..63 paired with reversed v
u64 r = __shfl_xor_sync(FULL_MASK, v, 31);
b1 = umax(b1, r);
bmerge64(b0, b1, lane);
}
// Key (high 32 bits) of entry k-1, broadcast to every lane; 1 <= k <= 64.
// Assumes k is the same in every lane (the kernel passes its k argument), so
// the whole warp takes one branch and its full-mask shuffle. `lane` is unused.
__device__ __forceinline__ u32 kth_key(int k, int lane) {
u64 e = (k <= 32) ? __shfl_sync(FULL_MASK, b0, (k - 1) & 31)
: __shfl_sync(FULL_MASK, b1, (k - 1) & 31);
return (u32)(e >> 32);
}
// Copy the 64 entries to/from dst[0..64) / src[0..64) in sorted order.
__device__ __forceinline__ void store(u64 *dst, int lane) {
dst[lane] = b0; dst[32 + lane] = b1;
}
__device__ __forceinline__ void load(const u64 *src, int lane) {
b0 = src[lane]; b1 = src[32 + lane];
}
// Merge another descending list of 64 (src[0..64)): pair buffer entry e with
// src[63 - e] and keep the max. This gives the 64 largest of the union as a
// bitonic sequence, which bmerge64 sorts.
__device__ __forceinline__ void merge_list(const u64 *src, int lane) {
b0 = umax(b0, src[63 - lane]);
b1 = umax(b1, src[31 - lane]);
bmerge64(b0, b1, lane);
}
};
// ---------- Streaming select: queue + flush ---------------------------------
// Warp-cooperative streaming top-k (Faiss WarpSelect style). Each lane queues
// up to T candidates whose key beats `thresh`, the key of the buffer's
// current k-th entry. As soon as any lane's queue is full, the whole warp
// sorts each queue slot across lanes and merges it into `buf`.
// Every method must be called by all 32 lanes together (full-mask ballots and
// shuffles). Requires 1 <= k <= 32 * R.
template <int R, int T> struct WarpSelect {
  WarpBuf<R> buf;  // best 32*R packed (key|idx) candidates so far, descending
  u64 q[T];        // per-lane candidate queue; q[0..cnt) are valid
  int cnt;         // queued candidates in this lane (< T between calls)
  u32 thresh;      // key of buf's k-th entry, same in every lane; 0 initially
  int k;
  int lane;

  __device__ __forceinline__ void init(int k_, int lane_) {
    buf.init();
#pragma unroll
    for (int j = 0; j < T; ++j) q[j] = 0ULL;  // enqueue() shifts every slot
    cnt = 0;
    thresh = 0;
    k = k_;
    lane = lane_;
  }

  // Drain every lane's queue into buf and refresh thresh.
  __device__ __forceinline__ void flush() {
#pragma unroll
    for (int j = 0; j < T; ++j) {
      // skip slots empty across the whole warp (warp-uniform exit)
      if (!__ballot_sync(FULL_MASK, cnt > j)) break;
      // 0 is the smallest packed value (the same as buf's initial padding).
      u64 v = (j < cnt) ? q[j] : 0ULL;
      sort32(v, lane);
      buf.merge_sorted32(v, lane);
    }
    cnt = 0;
    thresh = buf.kth_key(k, lane);
  }

  // Append one candidate. Shifting the queue instead of writing q[cnt++]
  // keeps every array index a compile-time constant, so q can stay in
  // registers instead of local memory. The caller guarantees cnt < T.
  __device__ __forceinline__ void enqueue(u64 item) {
#pragma unroll
    for (int j = T - 1; j > 0; --j) q[j] = q[j - 1];
    q[0] = item;
    ++cnt;
  }

  // Offer one candidate; flushes the warp once any lane's queue is full.
  __device__ __forceinline__ void push(u32 key, u32 idx) {
    if (key > thresh) enqueue(((u64)key << 32) | (u64)idx);
    if (__ballot_sync(FULL_MASK, cnt == T)) flush();
  }

  // Fast path: check 4 candidates with one ballot when none pass.
  // Candidates are (k0..k3) at source indices i0 .. i0 + 3.
  __device__ __forceinline__ void push4(u32 k0, u32 k1, u32 k2, u32 k3, u32 i0) {
    u32 m = max(max(k0, k1), max(k2, k3));
    if (__ballot_sync(FULL_MASK, m > thresh)) {
      push(k0, i0);
      push(k1, i0 + 1);
      push(k2, i0 + 2);
      push(k3, i0 + 3);
    }
  }

  // Merge whatever is still queued so buf holds the final top-(32*R).
  __device__ __forceinline__ void finish() {
    if (__ballot_sync(FULL_MASK, cnt > 0)) flush();
  }
};
// ---------- Kernel -----------------------------------------------------------
// Write the first k (<= 32 * R) entries of a descending warp buffer as fp32
// values and (long) source indices to vals[0..k) / idx[0..k). Each lane writes
// the entries it holds (e = lane, and e = 32 + lane when R == 2).
template <int R>
__device__ __forceinline__ void write_topk(const WarpBuf<R> &buf, float *vals,
                                           long *idx, int k, int lane) {
  if (lane < k) {
    vals[lane] = unkey((u32)(buf.b0 >> 32));
    idx[lane] = (long)(u32)buf.b0;
  }
  if constexpr (R == 2) {
    const int e = 32 + lane;
    if (e < k) {
      vals[e] = unkey((u32)(buf.b1 >> 32));
      idx[e] = (long)(u32)buf.b1;
    }
  }
}

// Copy the first k packed (key|idx) entries of a warp buffer to dst[0..k).
template <int R>
__device__ __forceinline__ void publish_topk(const WarpBuf<R> &buf, u64 *dst,
                                             int k, int lane) {
  if (lane < k) dst[lane] = buf.b0;
  if constexpr (R == 2) {
    const int e = 32 + lane;
    if (e < k) dst[e] = buf.b1;
  }
}

// Top-k of each row of a dense row-major (batch, n) fp32 matrix.
// Launch: grid = (splits, batch), block = 32 * W threads; static shared
// memory = W * 32 * R u64 plus one int.
// Block (split, row) scans x[row][split*chunk, min((split+1)*chunk, n)), and
// warp w takes the per_warp-element slice starting at w * per_warp.
// Preconditions:
//   - 1 <= k <= 32 * R.
//   - chunk and per_warp are multiples of 4.
//   - vec_ok != 0 only if every row start is 16-byte aligned (float4 loads).
//   - When splits > 1: scratch holds batch * splits * k entries and
//     counters[row] == 0 at launch.
// splits == 1: the block writes the row's output directly.
// splits > 1: every block publishes its top-k to scratch; the block whose
// arrival bumps counters[row] to `splits` merges the splits*k candidates,
// writes the output and resets counters[row] to 0 for the next launch.
template <int R, int T, int W>
__global__ void __launch_bounds__(W * 32) topk_kernel(
    const float *__restrict__ x,
    float *__restrict__ out_vals,
    long *__restrict__ out_idx,
    u64 *__restrict__ scratch,   // batch * splits * k
    int *__restrict__ counters,  // batch
    int n, int k, int splits, int chunk, int per_warp, int vec_ok)
{
  constexpr int K32 = 32 * R;
  __shared__ u64 smem[W * K32];
  __shared__ int s_last;
  const int row = blockIdx.y;
  const int split = blockIdx.x;
  const int w = threadIdx.x >> 5;
  const int lane = threadIdx.x & 31;
  const float *xr = x + (size_t)row * n;
  WarpSelect<R, T> sel;
  sel.init(k, lane);

  // ---- Phase A: stream this block's chunk --------------------------------
  {
    const int cbeg = split * chunk;
    const int cend = min(cbeg + chunk, n);
    const int wbeg = cbeg + w * per_warp;
    const int wend = min(wbeg + per_warp, cend);
    if (wbeg < wend) {
      const int span = wend - wbeg;
      const int nfull = vec_ok ? (span >> 7) : 0;  // full 128-elem warp steps
      const float4 *base = reinterpret_cast<const float4 *>(xr + wbeg);
      for (int s = 0; s < nfull; ++s) {
        float4 v = __ldcs(base + s * 32 + lane);
        u32 i0 = (u32)(wbeg + ((s * 32 + lane) << 2));
        sel.push4(fkey(__float_as_uint(v.x)), fkey(__float_as_uint(v.y)),
                  fkey(__float_as_uint(v.z)), fkey(__float_as_uint(v.w)), i0);
      }
      // Scalar tail. Lanes past wend use key 0, which never exceeds thresh,
      // so they never enqueue anything.
      for (int ib = wbeg + (nfull << 7); ib < wend; ib += 32) {
        const int i = ib + lane;
        const u32 key = (i < wend) ? fkey(__float_as_uint(__ldg(xr + i))) : 0u;
        // Warp-uniform branch: push() ballots over all 32 lanes.
        if (__ballot_sync(FULL_MASK, key > sel.thresh)) sel.push(key, (u32)i);
      }
    }
    sel.finish();
    // block merge: warp 0 folds every other warp's sorted list into its own
    sel.buf.store(smem + w * K32, lane);
    __syncthreads();
    if (w == 0) {
#pragma unroll 1
      for (int ow = 1; ow < W; ++ow) sel.buf.merge_list(smem + ow * K32, lane);
    }
  }

  if (splits == 1) {
    if (w == 0) {
      write_topk(sel.buf, out_vals + (size_t)row * k, out_idx + (size_t)row * k,
                 k, lane);
    }
    return;
  }

  // ---- publish candidates to scratch --------------------------------------
  u64 *srow = scratch + (size_t)row * splits * k;
  if (w == 0) publish_topk(sel.buf, srow + (size_t)split * k, k, lane);
  __syncthreads();
  if (threadIdx.x == 0) {
    // Release this block's published candidates and acquire the others'.
    cuda::atomic_ref<int, cuda::thread_scope_device> c(counters[row]);
    int prev = c.fetch_add(1, cuda::memory_order_acq_rel);
    s_last = (prev == splits - 1) ? 1 : 0;
  }
  __syncthreads();
  if (!s_last) return;

  // ---- Phase B: last block merges all candidates --------------------------
  {
    const int C = splits * k;
    sel.init(k, lane);
    const int per_w = (C + W - 1) / W;
    const int beg = w * per_w;
    const int end = min(beg + per_w, C);
    for (int i = beg + lane; ; i += 32) {
      const bool act = i < end;
      if (!__ballot_sync(FULL_MASK, act)) break;
      // __ldcg (cache in L2 only), not __ldg: scratch was written by other
      // blocks during this launch, and the non-coherent read-only path is
      // only valid for data that stays read-only for the whole kernel.
      const u64 it = act ? __ldcg(srow + i) : 0ULL;
      const u32 key = (u32)(it >> 32);
      // Inactive lanes carry key 0, which never exceeds thresh.
      if (__ballot_sync(FULL_MASK, key > sel.thresh)) sel.push(key, (u32)it);
    }
    sel.finish();
    __syncthreads();  // smem reuse
    sel.buf.store(smem + w * K32, lane);
    __syncthreads();
    if (w == 0) {
#pragma unroll 1
      for (int ow = 1; ow < W; ++ow) sel.buf.merge_list(smem + ow * K32, lane);
      write_topk(sel.buf, out_vals + (size_t)row * k, out_idx + (size_t)row * k,
                 k, lane);
    }
    // Re-arm the row counter for the next launch on this configuration.
    if (threadIdx.x == 0) counters[row] = 0;
  }
}
// ---------- Host side --------------------------------------------------------
// One registered top-k problem. topk_configure fills it in; the handle it
// returns is the index of this entry in g_cfgs.
struct Cfg {
const float *x;     // input bound by topk_run (nullptr until then)
float *vals;        // (batch, k) fp32 output values
long *idx;          // (batch, k) int64 output indices
u64 *scratch;       // splits > 1 only: batch * splits * k packed candidates
int *counters;      // splits > 1 only: one arrival counter per row
// chunk: elements per split; per_warp: elements per warp (both multiples of 4)
int batch, n, k, splits, chunk, per_warp, threads;
int R;              // 32-entry registers per lane: 1 if k <= 32, else 2
};
// Registered problems. Entries are never erased, and nothing in this source
// frees a registered Cfg's scratch/counters device buffers.
static std::vector<Cfg> g_cfgs;
// Launch topk_kernel for configuration `c` on stream `st`.
// Grid (splits, batch); block `threads` = 32 * W with W in {4, 8, 16}. The
// template gets R from c.R and uses T = 4 queue slots per lane.
// Raises (TORCH_CHECK) if no input is bound, W is unsupported, or the launch
// itself fails.
static void launch(const Cfg &c, cudaStream_t st) {
  // An empty batch has nothing to select, and a zero-sized grid is a launch
  // error.
  if (c.batch == 0) return;
  TORCH_CHECK(c.x != nullptr,
              "topk: no input bound; call topk_run before topk_run_cached");
  dim3 grid(c.splits, c.batch);
  dim3 block(c.threads);
  const int W = c.threads / 32;
  // float4 loads need 16-byte-aligned row starts. n % 4 == 0 keeps every row
  // offset aligned, and the base pointer must be aligned as well: a tensor
  // view with a storage offset may not be.
  const int vec_ok =
      (c.n % 4 == 0 && reinterpret_cast<size_t>(c.x) % 16 == 0) ? 1 : 0;
#define LAUNCH(RR, WW) topk_kernel<RR, 4, WW><<<grid, block, 0, st>>>( \
    c.x, c.vals, c.idx, c.scratch, c.counters, c.n, c.k, c.splits, c.chunk, c.per_warp, vec_ok)
  if (c.R == 1) {
    switch (W) {
      case 4: LAUNCH(1, 4); break;
      case 8: LAUNCH(1, 8); break;
      case 16: LAUNCH(1, 16); break;
      default: TORCH_CHECK(false, "bad W");
    }
  } else {
    switch (W) {
      case 4: LAUNCH(2, 4); break;
      case 8: LAUNCH(2, 8); break;
      case 16: LAUNCH(2, 16); break;
      default: TORCH_CHECK(false, "bad W");
    }
  }
#undef LAUNCH
  // Launch-configuration failures (bad grid, missing kernel image, ...) are
  // only reported through cudaGetLastError.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "topk: kernel launch failed: ",
              cudaGetErrorString(err));
}
// Register a top-k problem and return its handle (an index into g_cfgs).
//   batch             rows, 0 .. 65535 (mapped to grid.y)
//   n                 row length, 1 .. 2^30 (keeps 32-bit index math in range)
//   k                 1 .. min(n, 64) (warp buffers hold at most 64 entries)
//   splits            requested blocks per row (>= 1); reduced if needed so
//                     that no chunk is empty
//   threads           block size: 128, 256 or 512
//   vals_ptr/idx_ptr  device addresses of the (batch, k) fp32 / int64 outputs
// When more than one split is used, allocates the per-row candidate scratch
// and zero-initialised arrival counters.
int64_t topk_configure(int64_t batch, int64_t n, int64_t k, int64_t splits,
                       int64_t threads, int64_t vals_ptr, int64_t idx_ptr) {
  TORCH_CHECK(batch >= 0 && batch <= 65535,
              "topk: batch must be in [0, 65535], got ", batch);
  TORCH_CHECK(n >= 1 && n <= (int64_t(1) << 30),
              "topk: n must be in [1, 2^30], got ", n);
  TORCH_CHECK(k >= 1 && k <= 64 && k <= n,
              "topk: k must be in [1, min(n, 64)], got k=", k, ", n=", n);
  TORCH_CHECK(splits >= 1, "topk: splits must be >= 1, got ", splits);
  TORCH_CHECK(threads == 128 || threads == 256 || threads == 512,
              "topk: threads must be 128, 256 or 512, got ", threads);
  Cfg c{};
  c.batch = (int)batch;
  c.n = (int)n;
  c.k = (int)k;
  c.threads = (int)threads;
  c.R = (k > 32) ? 2 : 1;
  // Round each chunk up to a multiple of 4 elements so every chunk (and every
  // warp slice below) starts 16-byte aligned for the float4 path, then
  // recompute splits so the last chunk is non-empty.
  const int64_t chunk = ((n + splits - 1) / splits + 3) & ~int64_t(3);
  const int64_t eff_splits = (n + chunk - 1) / chunk;
  // The merge phase indexes splits * k candidates with a 32-bit int.
  TORCH_CHECK(eff_splits * k < (int64_t(1) << 31),
              "topk: too many splits (", eff_splits, ") for k=", k);
  c.chunk = (int)chunk;
  c.splits = (int)eff_splits;
  const int W = c.threads / 32;
  c.per_warp = ((c.chunk + W - 1) / W + 3) & ~3;
  c.vals = (float *)vals_ptr;
  c.idx = (long *)idx_ptr;
  c.scratch = nullptr;
  c.counters = nullptr;
  if (c.splits > 1 && c.batch > 0) {
    const size_t scratch_bytes = (size_t)c.batch * c.splits * c.k * sizeof(u64);
    const size_t counter_bytes = (size_t)c.batch * sizeof(int);
    cudaError_t err = cudaMalloc(&c.scratch, scratch_bytes);
    TORCH_CHECK(err == cudaSuccess, "topk: cudaMalloc(scratch) failed: ",
                cudaGetErrorString(err));
    err = cudaMalloc(&c.counters, counter_bytes);
    if (err != cudaSuccess) {
      cudaFree(c.scratch);
      TORCH_CHECK(false, "topk: cudaMalloc(counters) failed: ",
                  cudaGetErrorString(err));
    }
    // The kernel requires every counter to be 0 at launch; the merging block
    // resets its row's counter afterwards.
    err = cudaMemset(c.counters, 0, counter_bytes);
    if (err != cudaSuccess) {
      cudaFree(c.scratch);
      cudaFree(c.counters);
      TORCH_CHECK(false, "topk: cudaMemset(counters) failed: ",
                  cudaGetErrorString(err));
    }
  }
  g_cfgs.push_back(c);
  return (int64_t)g_cfgs.size() - 1;
}
// Resolve a handle returned by topk_configure, rejecting out-of-range values
// instead of indexing g_cfgs out of bounds.
static Cfg &cfg_at(int64_t handle) {
  TORCH_CHECK(handle >= 0 && handle < (int64_t)g_cfgs.size(),
              "topk: invalid handle ", handle);
  return g_cfgs[(size_t)handle];
}
// Bind x_ptr (device address of a dense row-major (batch, n) fp32 input) to
// the configuration and launch on the current PyTorch CUDA stream.
void topk_run(int64_t handle, int64_t x_ptr) {
  Cfg &c = cfg_at(handle);
  c.x = (const float *)x_ptr;
  launch(c, c10::cuda::getCurrentCUDAStream().stream());
}
// Relaunch on the current stream with the input last bound by topk_run.
void topk_run_cached(int64_t handle) {
  const Cfg &c = cfg_at(handle);
  launch(c, c10::cuda::getCurrentCUDAStream().stream());
}
"""
def _build():
    """Compile the inline C++/CUDA extension for SM120 and return its module.

    TORCH_CUDA_ARCH_LIST defaults to "12.0" unless the caller already set it.
    """
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.0")
    nvcc_flags = [
        "-O3",
        "--use_fast_math",
        "-gencode=arch=compute_120,code=sm_120",
    ]
    exported = ["topk_configure", "topk_run", "topk_run_cached"]
    return load_inline(
        name="topk_sm120_v1",
        cpp_sources=_CPP_SRC,
        cuda_sources=_CUDA_SRC,
        functions=exported,
        extra_cuda_cflags=nvcc_flags,
        verbose=False,
    )
# Build the extension once, when this module is imported.
_ext = _build()
# (splits, threads) per (batch, n, k); fallback heuristic otherwise.
# Consulted first by _pick_cfg. The keys are the benchmark's shape sweep; the
# values were presumably hand-tuned on the target GPU -- re-tune if the kernel
# changes. threads=256 means 8 warps per block.
_TUNE = {
(1, 131072, 64): (128, 256),
(64, 8192, 8): (2, 256),
(32, 16384, 32): (4, 256),
(16, 12000, 16): (8, 256),
(128, 4096, 1): (1, 256),
}
def _pick_cfg(batch: int, n: int, k: int):
    """Choose (splits, threads) for a problem shape.

    Shapes listed in _TUNE use their tuned entry. Otherwise aim for about 256
    blocks in total (256 // batch splits per row), never more than
    ceil(n / 256) splits and never fewer than one; threads is always 256.
    """
    shape = (batch, n, k)
    if shape in _TUNE:
        return _TUNE[shape]
    per_row_budget = 256 // max(batch, 1)
    split_cap = (n + 255) // 256
    return max(1, min(per_row_budget, split_cap)), 256
class Model(torch.nn.Module):
    """Top-k (largest, sorted descending) over the last dim of a (batch, n)
    float32 CUDA tensor, computed by the fused extension kernel.

    The problem shape is fixed at construction. The (batch, k) outputs are
    preallocated, and every call returns the SAME ``(values, indices)`` pair,
    overwritten in place. Clone the results if they must outlive the next
    call.
    """

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        # Kernel limits: at most 64 candidates per row (two 32-entry registers
        # per lane), one grid row per batch row (grid.y <= 65535), and 32-bit
        # element indices. Reject anything else before touching the GPU.
        if not (0 <= batch <= 65535 and 1 <= n <= 2**30 and 1 <= k <= min(n, 64)):
            raise ValueError(
                f"unsupported top-k problem batch={batch}, n={n}, k={k}: requires "
                "0 <= batch <= 65535, 1 <= n <= 2**30 and 1 <= k <= min(n, 64)"
            )
        self.batch, self.n, self.k = batch, n, k
        self.register_buffer("_dummy", torch.zeros(1))
        dev = torch.device("cuda")
        self._vals = torch.empty(batch, k, dtype=torch.float32, device=dev)
        self._idx = torch.empty(batch, k, dtype=torch.int64, device=dev)
        splits, threads = _pick_cfg(batch, n, k)
        self._handle = _ext.topk_configure(
            batch, n, k, splits, threads, self._vals.data_ptr(), self._idx.data_ptr()
        )
        self._ret = (self._vals, self._idx)
        # Address of the last validated input bound to the extension (-1: none).
        self._last_ptr = -1
        # Extension entry points cached as attributes for the hot path.
        self._run = _ext.topk_run
        self._run_cached = _ext.topk_run_cached

    def forward(self, x: torch.Tensor):
        """Return ``(values, indices)``, each of shape (batch, k): float32
        values sorted descending and int64 column indices into x.

        Raises ValueError if x is not a float32 (batch, n) tensor on the
        outputs' device. Non-contiguous inputs are packed into a temporary
        contiguous copy first.
        """
        p = x.data_ptr()
        if p == self._last_ptr:
            # Hot path: same address as the previous validated input, whose
            # pointer the extension already holds, so all checks are skipped.
            # NOTE(review): a different non-contiguous view at the same address
            # and shape would not be detected here.
            self._run_cached(self._handle)
            return self._ret
        if (
            x.dtype != torch.float32
            or x.device != self._vals.device
            or tuple(x.shape) != (self.batch, self.n)
        ):
            raise ValueError(
                f"expected a float32 tensor of shape ({self.batch}, {self.n}) on "
                f"{self._vals.device}, got {x.dtype} {tuple(x.shape)} on {x.device}"
            )
        if not x.is_contiguous():
            # The kernel reads x as dense row-major memory. Run on a packed copy
            # (kept alive across the launch) and do not cache its address,
            # because the copy is released when this call returns.
            packed = x.contiguous()
            self._last_ptr = -1
            self._run(self._handle, packed.data_ptr())
            return self._ret
        self._last_ptr = p
        self._run(self._handle, p)
        return self._ret

    # Calling forward directly skips nn.Module.__call__ overhead (and hooks).
    __call__ = forward
# Module-level shims rebuilt by check.py / benchmark.py per shape.
batch = 64
n = 8192
k = 8


def get_inputs():
    """Return one float32 standard-normal tensor of shape (batch, n), sized by
    the current module-level shims (created on torch's default device)."""
    return [torch.randn(batch, n, dtype=torch.float32)]


def get_init_inputs():
    """Return the Model constructor arguments from the current shims."""
    init_args = [batch, n, k]
    return init_args
warning: `VIRTUAL_ENV=/home/infatoshi/benchmarks/KernelBench-Hard/.venv` does not match the project environment path `/home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-11T00:34:42.891807+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-11T00:34:42.903325+00:00 elapsed_s=0.012 ms=0.024384
shape=0 variant=solution tflops=0.022 gbps=21.533 ms=0.024
shape=0 solution_peak_fraction=0.0120
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-11T00:34:42.905305+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-11T00:34:42.910618+00:00 elapsed_s=0.005 ms=0.014192
shape=1 variant=solution tflops=0.148 gbps=148.203 ms=0.014
shape=1 solution_peak_fraction=0.0823
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-11T00:34:42.912376+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-11T00:34:42.917678+00:00 elapsed_s=0.005 ms=0.015696
shape=2 variant=solution tflops=0.134 gbps=134.393 ms=0.016
shape=2 solution_peak_fraction=0.0747
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-11T00:34:42.918314+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-11T00:34:42.923479+00:00 elapsed_s=0.005 ms=0.011872
shape=3 variant=solution tflops=0.065 gbps=64.949 ms=0.012
shape=3 solution_peak_fraction=0.0361
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-11T00:34:42.924813+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-11T00:34:42.929925+00:00 elapsed_s=0.005 ms=0.010560
shape=4 variant=solution tflops=0.199 gbps=198.739 ms=0.011
shape=4 solution_peak_fraction=0.1104
peak_fraction: 0.0494
RESULT: LOW
warning: `VIRTUAL_ENV=/home/infatoshi/benchmarks/KernelBench-Hard/.venv` does not match the project environment path `/home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo/.venv
Creating virtual environment at: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo
Built kernelbench-hard @ file:///home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260610_174700_claude_claude-fable-5_05_topk_bitonic/repo
Installed 56 packages in 70ms
PASS
{
"run_id": "20260610_174700_claude_claude-fable-5_05_topk_bitonic",
"run_group": "",
"problem": "05_topk_bitonic",
"harness": "claude",
"model": "claude-fable-5",
"reasoning_effort": "max",
"started_at": "2026-06-10T17:47:00-06:00",
"harness_finished_at": "2026-06-10T18:32:01-06:00",
"finished_at": "2026-06-10T18:34:43-06:00",
"start_epoch": 1781135220,
"harness_end_epoch": 1781137921,
"end_epoch": 1781138083,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.0494,
"template_mutated": false,
"elapsed_seconds": 2701,
"total_elapsed_seconds": 2863,
"check_elapsed_seconds": 29,
"benchmark_elapsed_seconds": 133,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 124,
"session_complete": false,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 1.6467974824139209,
"usage": {"input_tokens": 4154, "output_tokens": 4448, "cache_read_tokens": 3204466, "cache_creation_tokens": 358376, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"command": "ls -la /workspace/problems/05_topk_bitonic/",
"description": "List files in problem workspace"
}{
"file_path": "/workspace/problems/05_topk_bitonic/reference.py"
}1 """Naive top-k reference: torch.topk over the last dim.
2
3 This is the correctness oracle. The agent's solution must produce the same
4 top-k values (and equivalent indices modulo ties) within the tolerance
5 declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
6 torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
7 """
8 import torch
9 import torch.nn as nn
10
11 OP_TYPE = "topk"
12 SUPPORTED_PRECISIONS = ["fp32"]
13 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14
15
16 class Model(nn.Module):
17 """Top-k over the last dim of a 2D tensor.
18
19 Input:
20 x: (batch, n) fp32
21 Output:
22 values: (batch, k) fp32, sorted descending
23 indices: (batch, k) int64, into the last dim of x
24 """
25
26 def __init__(self, batch: int, n: int, k: int):
27 super().__init__()
28 self.batch, self.n, self.k = batch, n, k
29 # No learned parameters, but declare a dummy buffer so state_dict
30 # is non-empty and load_state_dict(strict=True) is meaningful.
31 self.register_buffer("_dummy", torch.zeros(1))
32
33 def forward(self, x: torch.Tensor):
34 values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
35 return values, indices
36
37
38 # Module-level shims rebuilt by check.py / benchmark.py per shape.
39 batch = 64
40 n = 8192
41 k = 8
42
43
44 def get_inputs():
45 # fp32 input drawn from a roughly Gaussian dist
... (211 more chars)total 44
drwxrwxr-x 2 1002 1002 4096 Jun 10 23:47 .
drwxrwxr-x 3 1002 1002 4096 Jun 10 23:47 ..
-rw-rw-r-- 1 1002 1002 2084 Apr 27 09:21 PROMPT.txt
-rw-rw-r-- 1 1002 1002 5078 Jun 1 21:05 benchmark.py
-rw-rw-r-- 1 1002 1002 6288 Jun 2 01:09 check.py
-rw-rw-r-- 1 1002 1002 2199 Apr 26 04:20 problem.yaml
-rw-rw-r-- 1 1002 1002 1561 Apr 26 04:19 reference.py
-rw-rw-r-- 1 1002 1002 933 Apr 26 04:19 shapes.py
-rw-rw-r-- 1 1002 1002 945 Apr 26 04:20 sota.py{
"file_path": "/workspace/problems/05_topk_bitonic/check.py"
}1 """Correctness runner for TopK.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each. Top-k correctness has two parts:
5
6 1. VALUES: sol_values must match ref_values within fp32 tol. Both are
7 returned sorted descending, so positional comparison is well-defined.
8 2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
9 ties in x can yield multiple valid index sets. Instead we gather x at
10 sol_indices and check those values match ref_values within tol. This
11 catches "wrong indices" without false-failing on legitimate tie-breaks.
12
13 Also rejects forbidden ops by grep.
14 """
15 import re
16 import sys
17 from pathlib import Path
18
19 import torch
20 import yaml
21
22 REPO_ROOT = Path(__file__).resolve().parents[2]
23 sys.path.insert(0, str(REPO_ROOT))
24
25 from src.eval.correctness import check_correctness # noqa: E402
26 from src.eval.numeric_stress import ( # noqa: E402
27 numeric_stress_cases,
28 numeric_stress_context,
29 tolerance_for_case,
30 )
31
32
33 def main():
34 try:
35 import reference
36 import shapes
37 import solution
38 except Exception as e:
39 print(f"FAIL: import error: {e}")
40 sys.exit(1)
41
42 problem_yaml = Path("problem.yaml")
43 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
44
45 # --- Forbidden-op check --------------------------------------
... (5310 more chars){
"file_path": "/workspace/problems/05_topk_bitonic/benchmark.py"
}1 """Roofline benchmark for TopK.
2
3 For each shape: times eager reference (torch.topk), compiled reference, SOTA
4 (also torch.topk — see sota.py), and the agent's solution. Reports achieved
5 TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
6
7 Output lines the harness picks up:
8 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10 """
11 import sys
12 from math import exp, log
13 from pathlib import Path
14
15 import torch
16 import yaml
17
18 REPO_ROOT = Path(__file__).resolve().parents[2]
19 sys.path.insert(0, str(REPO_ROOT))
20
21 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22 from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
23 from src.hardware import get as get_hw # noqa: E402
24
25
26 def _eval_formula(expr: str, vars: dict) -> float:
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def main():
31 import reference
32 import shapes
33 import solution
34
35 meta = yaml.safe_load(Path("problem.yaml").read_text())
36 hw = get_hw(meta["hardware"][0])
37 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38 peak_gbps = hw.peak_bandwidth_gb_s
39 regime = meta.get("regime", "memory")
40 flops_formula = meta["flops_formula"]
41 bytes_formula = meta["bytes_formula"]
42 num_perf_trials = int(meta.
... (4052 more chars){
"file_path": "/workspace/problems/05_topk_bitonic/shapes.py"
}1 """Canonical shape sweep for TopK.
2
3 Mix of:
4 - decoder vocab top-k (single sequence, very large n, moderate k) — pure
5 bandwidth test; the input read dominates everything.
6 - prefill / batched attention top-k (many rows, moderate n, small k) — tests
7 per-row parallelism and shared-memory bitonic networks.
8 - non-power-of-2 n stress case — bitonic sort networks naturally want
9 powers of two; this forces the agent to handle padding or partial sorts.
10 - small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
11 """
12
13 SHAPES = [
14 {"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
15 {"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
16 {"batch": 32, "n": 16384, "k": 32}, # mid-size batched
17 {"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
18 {"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
19 ]
20 {
"file_path": "/workspace/problems/05_topk_bitonic/problem.yaml"
}1 name: 05_topk_bitonic
2 display_name: "TopK via Bitonic Sort"
3 precision: fp32
4 regime: memory
5
6 # Top-k is dominated by the input read (small output, no reduction over k).
7 # Comparator-network FLOPs are not the bottleneck on real hardware, so we
8 # track them but score on bandwidth.
9 flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
10 # bitonic network is so cheap relative
11 # to memory that a coarse 4*n estimate
12 # is fine for telemetry only.
13 bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
14
15 hardware: [RTX_PRO_6000]
16 peak_tflops_key: fp32
17 peak_bandwidth_key: dram
18
19 # Top-k correctness:
20 # - VALUES must match within fp32 atol/rtol (the kth largest value is
21 # well-defined modulo float-equal ties, so we use a loose-ish tol).
22 # - INDICES are checked leniently: for each row, the multiset of returned
23 # indices must select values that match ref values within tol. Direct
24 # index equality is NOT required (ties in x can yield different valid
25 # index sets).
26 tolerance:
27 float32: 1.0e-4
28
29 # Forbidden ops — using any of these in solution.py fails correctness post-hoc.
30 # This problem is about IMPLEMENTING the selection, not dispatching to PyTorch's
31 # tuned top-k. torch.sort is also banned b
... (859 more chars){
"file_path": "/workspace/problems/05_topk_bitonic/sota.py"
}1 """SOTA reference for TopK: torch.topk itself.
2
3 torch.topk dispatches to a CUB-backed kernel that uses radix-select for
4 moderate k and a tuned bitonic sort for small n. It is the bar the agent's
5 hand-rolled bitonic kernel must beat. There is no obvious vendor library that
6 does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
7 specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
8 sorts the full row (overkill for top-k).
9
10 This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
11 oracle, not the agent's submission. The agent's solution.py is forbidden from
12 using torch.topk (see problem.yaml.forbidden).
13 """
14 from __future__ import annotations
15
16 import torch
17
18
19 def sota_forward(x: torch.Tensor, k: int):
20 """Best-available top-k reference. x: (batch, n) fp32."""
21 return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
22
23
24 def is_available() -> bool:
25 return True
26 {
"command": "ls /workspace/ && ls /workspace/src/eval/ 2>/dev/null",
"description": "List workspace and eval source directories"
}problems
pyproject.toml
src
uv.lock
__init__.py
__pycache__
correctness.py
numeric_stress.py
report.py
roofline.py
shapes.py
timing.py{
"file_path": "/workspace/src/eval/timing.py"
}1 """Centralized GPU timing for benchmark.py files.
2
3 All problems' benchmark.py should call `time_fn` from here so we have one
4 implementation of warmup, L2 flush, and CUDA event capture to maintain.
5
6 Methodology:
7 - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8 torch.compile reduce-overhead CUDA-graph capture.
9 - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10 (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11 we measure HBM-load bandwidth, not L2-cached bandwidth).
12 - GPU timing via cuda.Event with synchronize() AFTER record() but
13 BEFORE elapsed_time() — the canonical NVIDIA pattern.
14 - Reported value is the median of `iters` trials, robust to outliers.
15
16 Notes / known biases not addressed here:
17 - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18 launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19 On small shapes where launch overhead matters, this gives compile an
20 artificial advantage. We accept this as the cost of using the agreed
21 "compiled" baseline policy.
22 - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23 - Median over a fairly small number of trials (default 30) is fine for
24 headline numbers but won't catch bimodal distributions. Use --extra
25 instrumentation if you ever care.
26 """
27 from __future__ import annotations
28
29 import os
30 import re
31
... (3182 more chars){
"file_path": "/workspace/src/eval/numeric_stress.py"
}1 """Numeric distribution stress cases for correctness checks.
2
3 These cases are not hidden shapes. They rerun the same shape/seed validation
4 under a few scale regimes that catch kernels that only work for the nominal
5 N(0, 1)-ish inputs.
6 """
7 from __future__ import annotations
8
9 import os
10 from collections.abc import Iterator, Mapping, Sequence
11 from contextlib import contextmanager
12 from dataclasses import dataclass, field
13
14 import torch
15
16 ToleranceOverride = Mapping[str, float | Mapping[str, float]]
17
18
19 @dataclass(frozen=True)
20 class NumericStressCase:
21 name: str
22 input_scales: Mapping[int, float] = field(default_factory=dict)
23 state_scales: Mapping[str, float] = field(default_factory=dict)
24 tolerance: ToleranceOverride | None = None
25
26
27 NOMINAL_CASE = NumericStressCase("nominal")
28
29 _SMALL_BF16 = {"bfloat16": {"atol": 1e-4, "rtol": 5e-2}}
30 _MED_BF16 = {"bfloat16": {"atol": 5e-4, "rtol": 5e-2}}
31 _LARGE_BF16 = {"bfloat16": {"atol": 2e-1, "rtol": 5e-2}}
32 _TINY_FP32 = {"float32": {"atol": 1e-7, "rtol": 1e-4}}
33 _LARGE_FP32 = {"float32": {"atol": 1e-1, "rtol": 1e-4}}
34
35 _CASES: dict[str, tuple[NumericStressCase, ...]] = {
36 "01_fp8_gemm": (
37 NumericStressCase("small_input", input_scales={0: 1e-3}, tolerance=_SMALL_BF16),
38 NumericStressCase("large_input", input_scales={0: 64.0}, tolerance=_LARGE_BF16),
39 NumericStressCase(
40 "small_weight",
41
... (5156 more chars){
"command": "nvidia-smi && uv run python -c \"import torch; print(torch.__version__, torch.version.cuda); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_properties(0))\"",
"description": "Check GPU and torch environment"
}Wed Jun 10 23:58:51 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 610.43.02 KMD Version: 610.43.02 CUDA UMD Version: 13.3 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 Off | Off |
| 32% 45C P0 90W / 600W | 15MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (1043 more chars){
"command": "cat /workspace/src/eval/roofline.py /workspace/src/eval/correctness.py | head -120 && grep -rn \"RTX_PRO_6000\" /workspace/src/hardware* 2>/dev/null | head; ls /workspace/src/",
"description": "Read roofline, correctness, hardware specs"
}"""Roofline math: achieved TFLOPS / GB/s, peak fraction.
FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
dense-equivalent algorithmic work; kernels that skip work (sparsity, early
exit) are still measured against the dense cost.
"""
from __future__ import annotations
def compute_tflops(flops: float, time_ms: float) -> float:
"""flops per invocation and wall time in ms -> TFLOPS."""
if time_ms <= 0:
return 0.0
return flops / (time_ms * 1e-3) / 1e12
def compute_gbps(bytes_moved: float, time_ms: float) -> float:
"""bytes per invocation and wall time in ms -> GB/s."""
if time_ms <= 0:
return 0.0
return bytes_moved / (time_ms * 1e-3) / 1e9
def peak_fraction(
achieved: float,
peak: float,
) -> float:
"""Return achieved / peak, clamped to [0, infinity)."""
if peak <= 0:
return 0.0
return max(0.0, achieved / peak)
"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
... (3095 more chars){
"file_path": "/workspace/src/hardware/rtx_pro_6000.py"
}1 """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2
3 Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4 will see 60-85% of peak on well-tuned code.
5 """
6 from dataclasses import dataclass
7
8
9 @dataclass(frozen=True)
10 class HardwareTarget:
11 name: str
12 sm: str
13 vram_gb: int
14 peak_bandwidth_gb_s: float # DRAM
15 peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16
17
18 RTX_PRO_6000 = HardwareTarget(
19 name="RTX PRO 6000 Blackwell Workstation",
20 sm="sm_120a",
21 vram_gb=96,
22 peak_bandwidth_gb_s=1800.0,
23 peak_tflops_dense={
24 "fp4": 800.0,
25 "nvfp4": 800.0,
26 "mxfp4": 800.0,
27 "fp6": 800.0,
28 "fp8": 400.0,
29 "bf16": 200.0,
30 "fp16": 200.0,
31 "tf32": 100.0,
32 "fp32": 12.0, # non-tensor-core SIMT fp32
33 "int8": 400.0,
34 "int4": 800.0,
35 },
36 )
37 {
"command": "which nvcc; nvcc --version 2>/dev/null | tail -2; echo \"---\"; ls /usr/local/ | head",
"description": "Check nvcc availability"
}Command running in background with ID: bjz2w3tbp. Output is being written to: /tmp/claude-1002/-workspace-problems-05-topk-bitonic/3797c0bb-21bf-4ce4-9897-446e9d3626a7/tasks/bjz2w3tbp.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"file_path": "/tmp/claude-1002/-workspace-problems-05-topk-bitonic/3797c0bb-21bf-4ce4-9897-446e9d3626a7/tasks/bjz2w3tbp.output"
}1 /kbh/bin/nvcc
2 --- /dev/null
+++ b//workspace/problems/05_topk_bitonic/scratch/floor.py
@@ -0,0 +1,78 @@
+"""Measure the timing floor under harness methodology: empty kernel, pure-read kernel."""
+import statistics
+import sys
+
+import torch
+from torch.utils.cpp_extension import load_inline
+
+src = r"""
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+__global__ void empty_kernel() {}
+
+__global__ void read_kernel(const float4* __restrict__ x, float* __restrict__ out, long n4) {
+ long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
+ float acc = 0.f;
+ long stride = (long)gridDim.x * blockDim.x;
+ for (; i < n4; i += stride) {
+ float4 v = x[i];
+ acc += v.x + v.y + v.z + v.w;
+ }
+ if (acc == 1234.5678f) out[0] = acc; // never true; keeps the load alive
+}
+
+void launch_empty() {
+ empty_kernel<<<1, 32>>>();
+}
+
+void launch_read(torch::Tensor x, torch::Tensor out, int blocks, int threads) {
+ long n4 = x.numel() / 4;
+ read_kernel<<<blocks, threads>>>(
+ reinterpret_cast<const float4*>(x.data_ptr<float>()), out.data_ptr<float>(), n4);
+}
+"""
+
+mod = load_inline(
+ name="floor_probe",
+ cpp_sources="void launch_empty(); void launch_read(torch::Tensor, torch::Tensor, int, int);",
+ cuda_sources=src,
+ functions=["launch_empty", "launch_read"],
+ extra_cuda_cflags=["-O3", "--use_fast_math", "-gencode=arch=compute_120,code=sm_120"],
+ verbose=False,
+)
+
+_L2 = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
+
+
+def time_fn(fn, iters=50, warmup=10):
+ for _ in range(warmup):
+ fn()
+ torch.cuda.synchronize()
+ times = []
+ for _ in range(iters):
+ _L2.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ fn()
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+out = torch.zeros(8, device="cuda")
+print(f"empty kernel: {time_fn(mod.launch_empty)*1000:.2f} us")
+
+for nbytes_mb, label in [(0.5, "0.5MB"), (2.0, "2MB"), (8.0, "8MB"), (64.0, "64MB")]:
+ n = int(nbytes_mb * 1024 * 1024 / 4)
+ x = torch.randn(n, device="cuda")
+ for blocks in [148, 188, 376, 752, 1504]:
+ ms = time_fn(lambda: mod.launch_read(x, out, blocks, 256))
+ gbps = n * 4 / (ms * 1e-3) / 1e9
+ print(f"read {label} blocks={blocks}: {ms*1000:7.2f} us {gbps:7.1f} GB/s")
+
+# Also measure: python-side null op (event pair only)
+print(f"event pair only: {time_fn(lambda: None)*1000:.2f} us")
{
"file_path": "/workspace/problems/05_topk_bitonic/scratch/floor.py",
"content": "<2479 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/05_topk_bitonic/scratch/floor.py (file state is current in your context — no need to Read it back){
"query": "select:TaskCreate,TaskUpdate",
"max_results": 2
}{
"subject": "Measure timing floor (empty/read kernels)",
"description": "Run scratch/floor.py to establish launch-overhead floor and achievable read bandwidth under harness timing methodology.",
"activeForm": "Measuring timing floor"
}Task #1 created successfully: Measure timing floor (empty/read kernels){
"subject": "Implement single-kernel top-k CUDA solution",
"description": "Write solution.py with a fused one-launch top-k kernel (per-warp register selection + block merge; multi-block + last-block-reduce for batch=1). Preallocate outputs in Model.__init__.",
"activeForm": "Implementing top-k kernel"
}Task #2 created successfully: Implement single-kernel top-k CUDA solution{
"subject": "Pass check.py and optimize benchmark",
"description": "Iterate: python check.py must print PASS; then push peak_fraction in benchmark.py via ncu/profiling.",
"activeForm": "Verifying and optimizing"
}Task #3 created successfully: Pass check.py and optimize benchmark{
"taskId": "1",
"status": "in_progress"
}Updated task #1 status{
"command": "uv run python scratch/floor.py 2>&1 | tail -30",
"description": "Run timing floor measurement",
"timeout": 600000
}empty kernel: 4.67 us
read 0.5MB blocks=148: 5.60 us 93.6 GB/s
read 0.5MB blocks=188: 5.73 us 91.5 GB/s
read 0.5MB blocks=376: 5.57 us 94.2 GB/s
read 0.5MB blocks=752: 5.39 us 97.2 GB/s
read 0.5MB blocks=1504: 5.62 us 93.4 GB/s
read 2MB blocks=148: 5.71 us 367.1 GB/s
read 2MB blocks=188: 5.60 us 374.5 GB/s
read 2MB blocks=376: 5.42 us 386.6 GB/s
read 2MB blocks=752: 5.60 us 374.5 GB/s
read 2MB blocks=1504: 5.76 us 364.1 GB/s
read 8MB blocks=148: 7.86 us 1067.8 GB/s
read 8MB blocks=188: 8.37 us 1002.5 GB/s
read 8MB blocks=376: 7.92 us 1059.2 GB/s
read 8MB blocks=752: 7.58 us 1106.1 GB/s
read 8MB blocks=1504: 7.82 us 1072.2 GB/s
read 64MB blocks=148: 47.90 us 1400.9 GB/s
read 64MB blocks=188: 47.49 us 1413.2 GB/s
read 64MB blocks=376: 46.94 us 1429.6 GB/s
read 64MB blocks=752: 46.74 us 1435.9 GB/s
read 64MB blocks=1504: 46.90 us 1431.0 GB/s
event pair only: 2.29 us{
"taskId": "1",
"status": "completed"
}Updated task #1 status{
"taskId": "2",
"status": "in_progress"
}Updated task #2 status--- /dev/null
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,497 @@
+"""Single-kernel top-k for SM120 (RTX PRO 6000).
+
+Strategy: one fused CUDA kernel per forward. Each row is split across one or
+more blocks; each warp streams its span keeping a register-resident sorted
+top-K (Faiss WarpSelect-style threshold + per-lane candidate queue, bitonic
+networks via warp shuffles). Warps merge through shared memory. Multi-split
+rows publish packed (key|idx) candidates to a scratch buffer; the last block
+per row (device-scope acq-rel counter) merges them and writes the output.
+
+Values travel as a monotonic fp32->u32 key packed with the source index into
+a u64, so all comparisons are single integer compares and the exact fp32 bits
+round-trip on output.
+"""
+import os
+
+import torch
+from torch.utils.cpp_extension import load_inline
+
+_CPP_SRC = r"""
+#include <cstdint>
+int64_t topk_configure(int64_t batch, int64_t n, int64_t k, int64_t splits,
+ int64_t threads, int64_t vals_ptr, int64_t idx_ptr);
+void topk_run(int64_t handle, int64_t x_ptr);
+void topk_run_cached(int64_t handle);
+"""
+
+_CUDA_SRC = r"""
+#include <cuda_runtime.h>
+#include <cuda/atomic>
+#include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <vector>
+
+using u32 = unsigned int;
+using u64 = unsigned long long;
+
+#define FULL_MASK 0xffffffffu
+
+// Monotonic transform: descending float order == descending u32 key order.
+__device__ __forceinline__ u32 fkey(u32 bits) {
+ return bits ^ (u32(int(bits) >> 31) | 0x80000000u);
+}
+__device__ __forceinline__ float unkey(u32 key) {
+ u32 b = key ^ ((key & 0x80000000u) ? 0x80000000u : 0xFFFFFFFFu);
+ return __uint_as_float(b);
+}
+__device__ __forceinline__ u64 umax(u64 a, u64 b) { return a > b ? a : b; }
+__device__ __forceinline__ u64 umin(u64 a, u64 b) { return a < b ? a : b; }
+
+// Descending bitonic merge of 32 elements held one-per-lane.
+__device__ __forceinline__ void bmerge32(u64 &v, int lane) {
+ #pragma unroll
+ for (int d = 16; d >= 1; d >>= 1) {
+ u64 o = __shfl_xor_sync(FULL_MASK, v, d);
+ v = ((lane & d) == 0) ? umax(v, o) : umin(v, o);
+ }
+}
+
+// Descending bitonic merge of 64 elements: element e = j*32+lane in regs a0,a1.
+__device__ __forceinline__ void bmerge64(u64 &a0, u64 &a1, int lane) {
+ u64 lo = umin(a0, a1);
+ a0 = umax(a0, a1);
+ a1 = lo;
+ bmerge32(a0, lane);
+ bmerge32(a1, lane);
+}
+
+// Full descending sort of 32 elements one-per-lane.
+__device__ __forceinline__ void sort32(u64 &v, int lane) {
+ #pragma unroll
+ for (int s = 2; s <= 32; s <<= 1) {
+ #pragma unroll
+ for (int d = s >> 1; d >= 1; d >>= 1) {
+ u64 o = __shfl_xor_sync(FULL_MASK, v, d);
+ bool i_low = (lane & d) == 0;
+ bool asc = (lane & s) != 0; // gives overall descending
+ bool take_max = i_low ^ asc;
+ v = take_max ? umax(v, o) : umin(v, o);
+ }
+ }
+}
+
+// ---------- Warp-resident top-K buffer -------------------------------------
+// R registers per lane; element e = j*32 + lane; sorted descending by e.
+
+template <int R> struct WarpBuf;
+
+template <> struct WarpBuf<1> {
+ u64 b0;
+ __device__ __forceinline__ void init() { b0 = 0; }
+ // Merge a descending-sorted 32-vector (one per lane) into the buffer.
+ __device__ __forceinline__ void merge_sorted32(u64 v, int lane) {
+ u64 r = __shfl_xor_sync(FULL_MASK, v, 31); // reversed
+ b0 = umax(b0, r);
+ bmerge32(b0, lane);
+ }
+ __device__ __forceinline__ u32 kth_key(int k, int lane) {
+ u64 e = __shfl_sync(FULL_MASK, b0, (k - 1) & 31);
+ return (u32)(e >> 32);
+ }
+ __device__ __forceinline__ void store(u64 *dst, int lane) { dst[lane] = b0; }
+ __device__ __forceinline__ void load(const u64 *src, int lane) { b0 = src[lane]; }
+ // Merge another descending list of 32 (in shared, reversed addressing).
+ __device__ __forceinline__ void merge_list(const u64 *src, int lane) {
+ b0 = umax(b0, src[31 - lane]);
+ bmerge32(b0, lane);
+ }
+};
+
+template <> struct WarpBuf<2> {
+ u64 b0, b1;
+ __device__ __forceinline__ void init() { b0 = 0; b1 = 0; }
+ __device__ __forceinline__ void merge_sorted32(u64 v, int lane) {
+ // positions 32..63 paired with reversed v
+ u64 r = __shfl_xor_sync(FULL_MASK, v, 31);
+ b1 = umax(b1, r);
+ bmerge64(b0, b1, lane);
+ }
+ __device__ __forceinline__ u32 kth_key(int k, int lane) {
+ u64 e = (k <= 32) ? __shfl_sync(FULL_MASK, b0, (k - 1) & 31)
+ : __shfl_sync(FULL_MASK, b1, (k - 1) & 31);
+ return (u32)(e >> 32);
+ }
+ __device__ __forceinline__ void store(u64 *dst, int lane) {
+ dst[lane] = b0; dst[32 + lane] = b1;
+ }
+ __device__ __forceinline__ void load(const u64 *src, int lane) {
+ b0 = src[lane]; b1 = src[32 + lane];
+ }
+ __device__ __forceinline__ void merge_list(const u64 *src, int lane) {
+ b0 = umax(b0, src[63 - lane]);
+ b1 = umax(b1, src[31 - lane]);
+ bmerge64(b0, b1, lane);
+ }
+};
+
+// ---------- Streaming select: queue + flush ---------------------------------
+
+template <int R, int T> struct WarpSelect {
+ WarpBuf<R> buf;
+ u64 q[T];
+ int cnt;
+ u32 thresh;
+ int k;
+ int lane;
+
+ __device__ __forceinline__ void init(int k_, int lane_) {
+ buf.init();
+ cnt = 0;
+ thresh = 0;
+ k = k_;
+ lane = lane_;
+ }
+
+ __device__ __forceinline__ void flush() {
+ #pragma unroll
+ for (int j = 0; j < T; ++j) {
+ // skip slots empty across the whole warp
+ if (!__ballot_sync(FULL_MASK, cnt > j)) break;
+ u64 v = (j < cnt) ? q[j] : 0ULL;
+ sort32(v, lane);
+ buf.merge_sorted32(v, lane)
... (11037 more chars){
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"content": "<16454 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/05_topk_bitonic/solution.py (file state is current in your context — no need to Read it back)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -231,9 +231,14 @@
sel.push4(fkey(__float_as_uint(v.x)), fkey(__float_as_uint(v.y)),
fkey(__float_as_uint(v.z)), fkey(__float_as_uint(v.w)), i0);
}
- for (int i = wbeg + (nfull << 7) + lane; i < wend; i += 32) {
- u32 key = fkey(__float_as_uint(__ldg(xr + i)));
- if (__ballot_sync(FULL_MASK, key > sel.thresh)) sel.push(key, (u32)i);
+ for (int ib = wbeg + (nfull << 7); ib < wend; ib += 32) {
+ const int i = ib + lane;
+ const bool act = i < wend;
+ u32 key = act ? fkey(__float_as_uint(__ldg(xr + i))) : 0u;
+ if (__ballot_sync(FULL_MASK, act && key > sel.thresh)) {
+ if (act && key > sel.thresh) sel.q[sel.cnt++] = ((u64)key << 32) | (u64)(u32)i;
+ if (__ballot_sync(FULL_MASK, sel.cnt == T)) sel.flush();
+ }
}
}
sel.finish();
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": "<270 chars \u2014 see diff>",
"new_string": "<542 chars \u2014 see diff>"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -254,27 +254,13 @@
if (splits == 1) {
if (w == 0) {
- #pragma unroll
- for (int j = 0; j < R; ++j) {
- int e = j * 32 + lane;
- if (e < k) {
- u64 p = (j == 0) ? sel.buf.b0 : sel.buf.kth_key(0, 0), q;
- }
- }
- // (specialized below; see store_out)
- }
- }
-
- // Unified output path helper is below via macro-free explicit code:
- if (splits == 1) {
- if (w == 0) {
u64 e0 = sel.buf.b0;
if (lane < k) {
out_vals[(size_t)row * k + lane] = unkey((u32)(e0 >> 32));
out_idx[(size_t)row * k + lane] = (long)(u32)e0;
}
- if (R == 2) {
- u64 e1 = ((WarpBuf<2> *)&sel.buf)->b1;
+ if constexpr (R == 2) {
+ u64 e1 = sel.buf.b1;
int e = 32 + lane;
if (e < k) {
out_vals[(size_t)row * k + e] = unkey((u32)(e1 >> 32));
@@ -290,8 +276,8 @@
if (w == 0) {
u64 e0 = sel.buf.b0;
if (lane < k) srow[(size_t)split * k + lane] = e0;
- if (R == 2) {
- u64 e1 = ((WarpBuf<2> *)&sel.buf)->b1;
+ if constexpr (R == 2) {
+ u64 e1 = sel.buf.b1;
int e = 32 + lane;
if (e < k) srow[(size_t)split * k + e] = e1;
}
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": "<1459 chars \u2014 see diff>",
"new_string": "<1011 chars \u2014 see diff>"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -321,8 +321,8 @@
out_vals[(size_t)row * k + lane] = unkey((u32)(e0 >> 32));
out_idx[(size_t)row * k + lane] = (long)(u32)e0;
}
- if (R == 2) {
- u64 e1 = ((WarpBuf<2> *)&sel.buf)->b1;
+ if constexpr (R == 2) {
+ u64 e1 = sel.buf.b1;
int e = 32 + lane;
if (e < k) {
out_vals[(size_t)row * k + e] = unkey((u32)(e1 >> 32));
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": "<318 chars \u2014 see diff>",
"new_string": "<310 chars \u2014 see diff>"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- /dev/null
+++ b//workspace/problems/05_topk_bitonic/scratch/test_sol.py
@@ -0,0 +1,47 @@
+import sys
+import torch
+
+sys.path.insert(0, "/workspace/problems/05_topk_bitonic")
+import solution
+
+SHAPES = [
+ (1, 131072, 64),
+ (64, 8192, 8),
+ (32, 16384, 32),
+ (16, 12000, 16),
+ (128, 4096, 1),
+ # extra edge shapes
+ (3, 700, 5),
+ (2, 64, 64),
+ (1, 65, 64),
+ (5, 4097, 33),
+ (7, 12345, 7),
+]
+
+dev = torch.device("cuda")
+fails = 0
+for (b, n, k) in SHAPES:
+ m = solution.Model(b, n, k).to(dev).eval()
+ for seed in (0, 1, 2):
+ torch.manual_seed(seed)
+ for scale in (1.0, 1e-4, 1e3):
+ x = torch.randn(b, n, device=dev) * scale
+ rv, ri = torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
+ sv, si = m(x)
+ torch.cuda.synchronize()
+ ok_v = torch.allclose(rv, sv, atol=1e-4 * max(scale, 1), rtol=1e-4)
+ g = torch.gather(x, -1, si)
+ ok_i = torch.allclose(rv, g, atol=1e-4 * max(scale, 1), rtol=1e-4)
+ in_range = (si.min() >= 0) and (si.max() < n)
+ if not (ok_v and ok_i and in_range):
+ fails += 1
+ bad = (rv - sv).abs().max().item()
+ print(f"FAIL b={b} n={n} k={k} seed={seed} scale={scale} "
+ f"ok_v={ok_v} ok_i={ok_i} in_range={in_range} maxdiff={bad}")
+ if fails < 3:
+ row = (rv - sv).abs().amax(-1).argmax().item()
+ print(" ref:", rv[row][:10].tolist())
+ print(" sol:", sv[row][:10].tolist())
+ print(f"done b={b} n={n} k={k}")
+
+print("ALL OK" if fails == 0 else f"{fails} FAILURES")
{
"file_path": "/workspace/problems/05_topk_bitonic/scratch/test_sol.py",
"content": "<1579 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/05_topk_bitonic/scratch/test_sol.py (file state is current in your context — no need to Read it back){
"command": "uv run python scratch/test_sol.py 2>&1 | tail -25",
"description": "Build and test solution against torch.topk",
"timeout": 600000
}done b=1 n=131072 k=64
done b=64 n=8192 k=8
done b=32 n=16384 k=32
done b=16 n=12000 k=16
done b=128 n=4096 k=1
done b=3 n=700 k=5
done b=2 n=64 k=64
done b=1 n=65 k=64
done b=5 n=4097 k=33
done b=7 n=12345 k=7
ALL OK--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -199,7 +199,7 @@
long *__restrict__ out_idx,
u64 *__restrict__ scratch, // batch * splits * k
int *__restrict__ counters, // batch
- int n, int k, int splits, int chunk, int per_warp)
+ int n, int k, int splits, int chunk, int per_warp, int vec_ok)
{
constexpr int K32 = 32 * R;
__shared__ u64 smem[W * K32];
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": " int n, int k, int splits, int chunk, int per_warp)",
"new_string": " int n, int k, int splits, int chunk, int per_warp, int vec_ok)"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -223,7 +223,7 @@
const int wend = min(wbeg + per_warp, cend);
if (wbeg < wend) {
const int span = wend - wbeg;
- const int nfull = span >> 7; // full 128-element warp steps
+ const int nfull = vec_ok ? (span >> 7) : 0; // full 128-elem warp steps
const float4 *base = reinterpret_cast<const float4 *>(xr + wbeg);
for (int s = 0; s < nfull; ++s) {
float4 v = __ldcs(base + s * 32 + lane);
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": "<114 chars \u2014 see diff>",
"new_string": "<126 chars \u2014 see diff>"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -352,8 +352,9 @@
dim3 grid(c.splits, c.batch);
dim3 block(c.threads);
const int W = c.threads / 32;
+ const int vec_ok = (c.n % 4 == 0) ? 1 : 0;
#define LAUNCH(RR, WW) topk_kernel<RR, 4, WW><<<grid, block, 0, st>>>( \
- c.x, c.vals, c.idx, c.scratch, c.counters, c.n, c.k, c.splits, c.chunk, c.per_warp)
+ c.x, c.vals, c.idx, c.scratch, c.counters, c.n, c.k, c.splits, c.chunk, c.per_warp, vec_ok)
if (c.R == 1) {
switch (W) {
case 4: LAUNCH(1, 4); break;
{
"replace_all": false,
"file_path": "/workspace/problems/05_topk_bitonic/solution.py",
"old_string": "<168 chars \u2014 see diff>",
"new_string": "<223 chars \u2014 see diff>"
}The file /workspace/problems/05_topk_bitonic/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)