"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cfloat>
#define WARP_SIZE 32
// Insert the candidate (val, idx) into the parallel arrays vals/idxs, which hold
// k entries sorted in descending order of value.
//
// A candidate that compares <= the current last entry is discarded, so ties
// never displace an entry that is already stored. Otherwise entries smaller
// than the candidate are shifted one slot towards the end (the last entry falls
// off) and the candidate is written into the freed slot.
__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
                                            float* vals, int64_t* idxs) {
    const int last = k - 1;
    if (val <= vals[last]) {
        return;
    }
    int slot = last;
    for (; slot > 0; --slot) {
        const float above = vals[slot - 1];
        // Stop at the first stored value that is not strictly smaller.
        if (!(val > above)) {
            break;
        }
        vals[slot] = above;
        idxs[slot] = idxs[slot - 1];
    }
    vals[slot] = val;
    idxs[slot] = idx;
}
// Compile-time-capacity overload: forwards to the runtime-k insert_desc with
// K as the list length.
template <int K>
__device__ __forceinline__ void insert_desc(float val, int64_t idx,
                                            float* vals, int64_t* idxs) {
    constexpr int kCapacity = K;
    insert_desc(kCapacity, val, idx, vals, idxs);
}
// Merge two descending-sorted lists of K (value, index) pairs into the K
// largest pairs, written in descending order to out_val/out_idx.
//
// An element of `a` is taken only when it is strictly greater than the current
// element of `b`; on equal values the element of `b` is emitted first.
// The output arrays must not alias either input.
template <int K>
__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
                           const float* b_val, const int64_t* b_idx,
                           float* out_val, int64_t* out_idx) {
    int ia = 0;
    int ib = 0;
    for (int slot = 0; slot < K; ++slot) {
        bool from_a = false;
        if (ia < K) {
            from_a = (ib >= K) || (a_val[ia] > b_val[ib]);
        }
        if (from_a) {
            out_val[slot] = a_val[ia];
            out_idx[slot] = a_idx[ia];
            ++ia;
        } else {
            out_val[slot] = b_val[ib];
            out_idx[slot] = b_idx[ib];
            ++ib;
        }
    }
}
// Threads per block used for a given K: 128 when K <= 16, otherwise 64.
template <int K>
__host__ __device__ constexpr int threads_for_k() {
    return (K <= 16) ? 128 : 64;
}
// Block-wide reduction of per-thread top-K lists.
//
// Each of the first `nthreads` threads publishes its descending-sorted
// local_val/local_idx list to shared memory (slot tid occupies
// s_val[tid*K .. tid*K+K-1], same layout for s_idx). The lists are then merged
// pairwise in a tree until slot 0 holds the top-K of the whole block.
//
// Requirements:
//   * s_val and s_idx each hold at least nthreads * K elements.
//   * Every thread of the block must call this function (it contains
//     __syncthreads()), with the same `nthreads`.
//
// Unlike a plain "stride = nthreads / 2" halving tree, the tree below also
// handles thread counts that are not a power of two: on every level the upper
// part [half, active) is folded onto the lower part, so no slot is dropped.
// For power-of-two thread counts the merge schedule is the usual halving tree.
template <int K>
__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
                                  float* s_val, int64_t* s_idx, int nthreads) {
    const int tid = threadIdx.x;
    if (tid < nthreads) {
#pragma unroll
        for (int j = 0; j < K; ++j) {
            s_val[tid * K + j] = local_val[j];
            s_idx[tid * K + j] = local_idx[j];
        }
    }
    __syncthreads();
    for (int active = nthreads; active > 1;) {
        const int half = (active + 1) / 2;
        // Slot tid (< half) absorbs slot tid + half (>= half). Writers only
        // touch slots below `half` and readers only read their own slot plus
        // one at or above `half`, so there is no intra-level race.
        if (tid + half < active) {
            float tmp[K];
            int64_t tidx[K];
            merge_topk<K>(s_val + tid * K, s_idx + tid * K,
                          s_val + (tid + half) * K, s_idx + (tid + half) * K,
                          tmp, tidx);
#pragma unroll
            for (int j = 0; j < K; ++j) {
                s_val[tid * K + j] = tmp[j];
                s_idx[tid * K + j] = tidx[j];
            }
        }
        __syncthreads();
        active = half;
    }
}
// Warp-level (value, index) max reduction using shuffle-down exchanges with a
// full-warp mask, so all 32 lanes of the warp must call it together.
//
// At each step a lane adopts the pair of the lane `delta` positions above it
// only when that value is strictly greater, so on ties a lane keeps its own
// pair. Each lane's final pair is returned through out_val/out_idx; the
// callers in this file only consume lane 0's result.
__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
                                                float& out_val, int64_t& out_idx) {
    constexpr unsigned kFullMask = 0xffffffffu;
    int delta = WARP_SIZE / 2;
    while (delta > 0) {
        const float peer_val = __shfl_down_sync(kFullMask, val, delta);
        const int64_t peer_idx = __shfl_down_sync(kFullMask, idx, delta);
        if (peer_val > val) {
            val = peer_val;
            idx = peer_idx;
        }
        delta >>= 1;
    }
    out_val = val;
    out_idx = idx;
}
// Row-wise argmax (top-1).
//
// Launch layout: one block per row (gridDim.x >= batch), 1-D block whose size
// is a multiple of 32 and at most 1024 (the second stage reduces up to 32
// per-warp results). Uses 32 floats + 32 int64 of static shared memory.
//
// in       : batch x n row-major float matrix.
// out_vals : batch floats, the maximum of each row.
// out_idxs : batch int64, a column index at which that maximum occurs.
//
// Values that never compare greater than -FLT_MAX leave the result at the
// initial (-FLT_MAX, 0) pair.
__global__ void topk1_kernel(const float* __restrict__ in,
                             float* __restrict__ out_vals,
                             int64_t* __restrict__ out_idxs,
                             int batch, int n) {
    const int row = blockIdx.x;
    if (row >= batch) return;
    const float* row_in = in + (int64_t)row * n;
    float best = -FLT_MAX;
    int64_t best_i = 0;

    // float4 loads need a 16-byte aligned address. Rows start at in + row * n,
    // so they are only guaranteed to be aligned when the base pointer is and
    // n is a multiple of 4; otherwise fall back to scalar loads for this row.
    const bool aligned16 = (reinterpret_cast<uintptr_t>(row_in) & 15u) == 0;
    // Vectorised prefix [0, n_vec); the scalar loop covers the rest [n_vec, n).
    const int n_vec = aligned16 ? (n & ~3) : 0;

    for (int i = threadIdx.x * 4; i < n_vec; i += blockDim.x * 4) {
        const float4 v = *reinterpret_cast<const float4*>(row_in + i);
        if (v.x > best) { best = v.x; best_i = i; }
        if (v.y > best) { best = v.y; best_i = i + 1; }
        if (v.z > best) { best = v.z; best_i = i + 2; }
        if (v.w > best) { best = v.w; best_i = i + 3; }
    }
    // Tail (and the whole row when unaligned): start right after the
    // vectorised prefix so that every remaining element is visited once.
    for (int i = n_vec + threadIdx.x; i < n; i += blockDim.x) {
        const float v = row_in[i];
        if (v > best) { best = v; best_i = i; }
    }

    __shared__ float sv[32];
    __shared__ int64_t si[32];
    const int lane = threadIdx.x & 31;
    const int wid = threadIdx.x >> 5;
    float wv = best;
    int64_t wi = best_i;
    // Stage 1: reduce within each warp; lane 0 publishes the warp's winner.
    warp_reduce_max(wv, wi, wv, wi);
    if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
    __syncthreads();
    // Stage 2: warp 0 reduces the per-warp winners.
    if (wid == 0) {
        const int num_warps = (blockDim.x + 31) / 32;
        wv = (lane < num_warps) ? sv[lane] : -FLT_MAX;
        wi = (lane < num_warps) ? si[lane] : 0;
        warp_reduce_max(wv, wi, wv, wi);
        if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
    }
}
// Per-thread streaming selection over one row.
//
// The threads of the block partition row_in[0..n) between them; each thread
// keeps the K largest (value, index) pairs it has seen in local_val/local_idx
// (descending order). Unfilled slots keep the initial (-FLT_MAX, 0) pair.
//
// row_in    : pointer to the first element of the row.
// n         : number of elements in the row.
// local_val : per-thread output array of K floats.
// local_idx : per-thread output array of K int64 column indices.
template <int K>
__device__ void scan_row_topk(const float* row_in, int n,
                              float* local_val, int64_t* local_idx) {
#pragma unroll
    for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }

    // float4 loads require a 16-byte aligned address; a row that does not
    // start on such a boundary (e.g. n not a multiple of 4) is read with
    // scalar loads instead.
    const bool aligned16 = (reinterpret_cast<uintptr_t>(row_in) & 15u) == 0;
    // Vectorised prefix [0, n_vec); the scalar loop covers the rest [n_vec, n).
    const int n_vec = aligned16 ? (n & ~3) : 0;

    for (int i = threadIdx.x * 4; i < n_vec; i += blockDim.x * 4) {
        const float4 v = *reinterpret_cast<const float4*>(row_in + i);
        insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
        insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
        insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
        insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
    }
    // Start the scalar pass right after the vectorised prefix so that no
    // trailing element (n % 4 != 0) is skipped.
    for (int i = n_vec + threadIdx.x; i < n; i += blockDim.x) {
        insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
    }
}
// Row-wise top-K, one block per row.
//
// Launch layout: gridDim.x >= batch, 1-D block. Dynamic shared memory must be
// blockDim.x * K * (sizeof(float) + sizeof(int64_t)) bytes: first the float
// values of every thread, then the int64 indices.
//
// Each thread scans its share of the row into a private top-K list, the lists
// are merged across the block, and thread 0 writes the block result to row
// `row` of out_vals/out_idxs (K entries per row).
template <int K>
__global__ void topk_kernel(const float* __restrict__ in,
                            float* __restrict__ out_vals,
                            int64_t* __restrict__ out_idxs,
                            int batch, int n) {
    const int row = blockIdx.x;
    if (row >= batch) {
        return;
    }

    float thread_val[K];
    int64_t thread_idx[K];
    scan_row_topk<K>(in + (int64_t)row * n, n, thread_val, thread_idx);

    extern __shared__ char smem[];
    float* const sh_val = reinterpret_cast<float*>(smem);
    int64_t* const sh_idx = reinterpret_cast<int64_t*>(sh_val + blockDim.x * K);
    block_reduce_topk<K>(thread_val, thread_idx, sh_val, sh_idx, blockDim.x);

    // The reduction ended with a barrier; only thread 0 has work left.
    if (threadIdx.x != 0) {
        return;
    }
    float* const dst_val = out_vals + (int64_t)row * K;
    int64_t* const dst_idx = out_idxs + (int64_t)row * K;
#pragma unroll
    for (int j = 0; j < K; ++j) {
        dst_val[j] = sh_val[j];
        dst_idx[j] = sh_idx[j];
    }
}
// Stage 1 of the split top-K for a single long row.
//
// Launch layout: gridDim.x >= num_blocks, 1-D block. Dynamic shared memory
// must be blockDim.x * K * (sizeof(float) + sizeof(int64_t)) bytes.
//
// Row `row` of `in` (n elements) is cut into num_blocks contiguous slices.
// Block `bid` computes the top-K of its slice and writes it to entries
// [bid*K, bid*K+K) of partial_vals/partial_idxs. Indices are column indices
// into the full row. A block whose slice is empty writes (-FLT_MAX, 0) pairs.
template <int K>
__global__ void topk_slice_kernel(const float* __restrict__ in,
                                  float* __restrict__ partial_vals,
                                  int64_t* __restrict__ partial_idxs,
                                  int row, int n, int num_blocks) {
    const int bid = blockIdx.x;
    if (bid >= num_blocks) return;

    // Slice length rounded up to a multiple of 4 so that every slice starts at
    // a multiple of 4 elements from the row base, which keeps float4 loads
    // aligned whenever the row base itself is 16-byte aligned.
    const int slice = (((n + num_blocks - 1) / num_blocks) + 3) & ~3;
    const int64_t start64 = (int64_t)bid * slice;
    const int start = (start64 < (int64_t)n) ? (int)start64 : n;
    const int end = (n - start < slice) ? n : start + slice;
    const float* row_in = in + (int64_t)row * n;

    float local_val[K];
    int64_t local_idx[K];
#pragma unroll
    for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }

    // Vectorise only from a 16-byte aligned address; otherwise read the whole
    // slice with scalar loads.
    const bool aligned16 =
        (reinterpret_cast<uintptr_t>(row_in + start) & 15u) == 0;
    const int vec_end = aligned16 ? start + ((end - start) & ~3) : start;

    for (int i = start + threadIdx.x * 4; i < vec_end; i += blockDim.x * 4) {
        const float4 v = *reinterpret_cast<const float4*>(row_in + i);
        insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
        insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
        insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
        insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
    }
    // Scalar pass over whatever the vector loop did not cover, starting right
    // after it so that no trailing element is skipped.
    for (int i = vec_end + threadIdx.x; i < end; i += blockDim.x) {
        insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
    }

    extern __shared__ char smem[];
    float* s_val = reinterpret_cast<float*>(smem);
    int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
    block_reduce_topk<K>(local_val, local_idx, s_val, s_idx, blockDim.x);
    if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < K; ++j) {
            partial_vals[(int64_t)bid * K + j] = s_val[j];
            partial_idxs[(int64_t)bid * K + j] = s_idx[j];
        }
    }
}
// Stage 2 of the split top-K: combine per-slice partial results.
//
// Launch layout: a single 1-D block. Dynamic shared memory must be
// blockDim.x * K * (sizeof(float) + sizeof(int64_t)) bytes.
//
// partial_vals/partial_idxs hold num_partials lists of K pairs each (list p at
// [p*K, p*K+K)). Threads take lists p = tid, tid + blockDim.x, ... and feed
// every pair into a private top-K list; the private lists are then merged
// across the block and thread 0 writes the final K pairs to row `row` of
// out_vals/out_idxs.
template <int K>
__global__ void topk_merge_kernel(const float* __restrict__ partial_vals,
                                  const int64_t* __restrict__ partial_idxs,
                                  float* __restrict__ out_vals,
                                  int64_t* __restrict__ out_idxs,
                                  int row, int num_partials) {
    float thread_val[K];
    int64_t thread_idx[K];
#pragma unroll
    for (int j = 0; j < K; ++j) {
        thread_val[j] = -FLT_MAX;
        thread_idx[j] = 0;
    }

    int p = threadIdx.x;
    while (p < num_partials) {
        const float* src_val = partial_vals + (int64_t)p * K;
        const int64_t* src_idx = partial_idxs + (int64_t)p * K;
#pragma unroll
        for (int j = 0; j < K; ++j) {
            insert_desc<K>(src_val[j], src_idx[j], thread_val, thread_idx);
        }
        p += blockDim.x;
    }

    extern __shared__ char smem[];
    float* const sh_val = reinterpret_cast<float*>(smem);
    int64_t* const sh_idx = reinterpret_cast<int64_t*>(sh_val + blockDim.x * K);
    block_reduce_topk<K>(thread_val, thread_idx, sh_val, sh_idx, blockDim.x);

    // The reduction ended with a barrier; only thread 0 has work left.
    if (threadIdx.x != 0) {
        return;
    }
    float* const dst_val = out_vals + (int64_t)row * K;
    int64_t* const dst_idx = out_idxs + (int64_t)row * K;
#pragma unroll
    for (int j = 0; j < K; ++j) {
        dst_val[j] = sh_val[j];
        dst_idx[j] = sh_idx[j];
    }
}
// Host-side launcher for the K > 1 kernels.
//
// in, out_vals, out_idxs : device pointers (batch x n input, batch x K outputs).
// stream                 : stream on which all kernels are enqueued.
// partial_vals/idxs      : optional device workspace holding partial_cap lists
//                          of K pairs each; pass nullptr / 0 when unavailable.
//
// A single long row (batch == 1, n >= 65536) with a usable workspace is split
// across several blocks (slice kernel + merge kernel); everything else uses
// one block per row. Each launch is followed by a launch-error check so that a
// bad configuration is reported here instead of at some later CUDA call.
template <int K>
void launch_topk(const float* in, float* out_vals, int64_t* out_idxs,
                 int batch, int n, cudaStream_t stream,
                 float* partial_vals, int64_t* partial_idxs, int partial_cap) {
    // A grid with zero blocks is an invalid launch configuration.
    if (batch <= 0) return;

    constexpr int T = threads_for_k<K>();
    const size_t smem = (size_t)T * K * (sizeof(float) + sizeof(int64_t));

    // Both workspace arrays are written by the slice kernel, so both must be
    // present and able to hold at least one partial list.
    const bool have_workspace =
        partial_vals != nullptr && partial_idxs != nullptr && partial_cap > 0;

    if (batch == 1 && n >= 65536 && have_workspace) {
        const int wanted = (n + T - 1) / T;
        const int num_blocks = (partial_cap < wanted) ? partial_cap : wanted;
        topk_slice_kernel<K><<<num_blocks, T, smem, stream>>>(
            in, partial_vals, partial_idxs, 0, n, num_blocks);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        topk_merge_kernel<K><<<1, T, smem, stream>>>(
            partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
        topk_kernel<K><<<batch, T, smem, stream>>>(in, out_vals, out_idxs, batch, n);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
}
// Select and launch the kernel for the requested k on the current CUDA stream.
//
// input        : contiguous 2-D float32 CUDA tensor (batch x n).
// out_vals     : float32 output, batch x k.
// out_idxs     : int64 output, batch x k.
// k            : one of 1, 8, 16, 32, 64; anything else raises.
// partial_vals : optional float32 workspace for the split path (may be empty).
// partial_idxs : optional int64 workspace for the split path (may be empty).
//
// The workspace is only used when both tensors are non-empty; its capacity is
// the number of complete K-element lists that fit in the smaller of the two.
void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals,
                   torch::Tensor& out_idxs, int k,
                   torch::Tensor& partial_vals, torch::Tensor& partial_idxs) {
    TORCH_CHECK(k == 1 || k == 8 || k == 16 || k == 32 || k == 64,
                "unsupported k=", k);

    const int batch = (int)input.size(0);
    const int n = (int)input.size(1);
    // Nothing to compute, and a zero-block grid would be an invalid launch.
    if (batch == 0) return;

    const float* in_ptr = input.data_ptr<float>();
    float* val_ptr = out_vals.data_ptr<float>();
    int64_t* idx_ptr = out_idxs.data_ptr<int64_t>();
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const bool have_workspace =
        partial_vals.defined() && partial_idxs.defined() &&
        partial_vals.numel() > 0 && partial_idxs.numel() > 0;
    float* pval = have_workspace ? partial_vals.data_ptr<float>() : nullptr;
    int64_t* pidx = have_workspace ? partial_idxs.data_ptr<int64_t>() : nullptr;
    int partial_cap = 0;
    if (have_workspace) {
        // The kernels index the workspace as flat lists of k entries, so the
        // usable capacity is bounded by the smaller of the two buffers.
        const int64_t elems = (partial_vals.numel() < partial_idxs.numel())
                                  ? partial_vals.numel()
                                  : partial_idxs.numel();
        const int64_t lists = elems / k;
        partial_cap = (lists > (int64_t)INT32_MAX) ? INT32_MAX : (int)lists;
    }

    switch (k) {
        case 1:
            topk1_kernel<<<batch, 256, 0, stream>>>(in_ptr, val_ptr, idx_ptr, batch, n);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
            break;
        case 8:
            launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
            break;
        case 16:
            launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
            break;
        case 32:
            launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
            break;
        case 64:
            launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
            break;
        default:
            TORCH_CHECK(false, "unsupported k=", k);
    }
}
// Python-facing entry point: top-k along the last dimension of a 2-D tensor.
//
// input        : 2-D float32 CUDA tensor (made contiguous internally).
// k            : number of entries to select per row, 1 <= k <= input.size(1).
// partial_vals : float32 workspace for the split path, or an empty tensor.
// partial_idxs : int64 workspace for the split path, or an empty tensor.
//
// Returns {values, indices}: values is float32 (batch x k), indices is int64
// (batch x k), both allocated on input's device.
// Raises (via TORCH_CHECK) on invalid arguments, including unsupported k.
std::vector<torch::Tensor> topk_cuda(torch::Tensor input, int k,
                                     torch::Tensor partial_vals,
                                     torch::Tensor partial_idxs) {
    TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32,
                "topk_cuda: input must be a CUDA float32 tensor");
    TORCH_CHECK(input.dim() == 2,
                "topk_cuda: input must be 2-D, got dim=", input.dim());
    const int64_t batch = input.size(0);
    const int64_t n = input.size(1);
    // The kernels take batch and n as 32-bit ints.
    TORCH_CHECK(batch <= (int64_t)INT32_MAX && n <= (int64_t)INT32_MAX,
                "topk_cuda: input dimensions exceed the 32-bit range");
    // With k > n the selection lists could never be filled and the output
    // would contain placeholder entries instead of real elements.
    TORCH_CHECK(k >= 1 && (int64_t)k <= n,
                "topk_cuda: k must satisfy 1 <= k <= n, got k=", k, ", n=", n);

    // A non-empty workspace is handed to the kernels as raw pointers, so make
    // sure it really is device memory of the expected element type.
    if (partial_vals.defined() && partial_vals.numel() > 0) {
        TORCH_CHECK(partial_vals.is_cuda() &&
                        partial_vals.device() == input.device() &&
                        partial_vals.dtype() == torch::kFloat32 &&
                        partial_vals.is_contiguous(),
                    "topk_cuda: partial_vals must be a contiguous float32 tensor on input's device");
    }
    if (partial_idxs.defined() && partial_idxs.numel() > 0) {
        TORCH_CHECK(partial_idxs.is_cuda() &&
                        partial_idxs.device() == input.device() &&
                        partial_idxs.dtype() == torch::kInt64 &&
                        partial_idxs.is_contiguous(),
                    "topk_cuda: partial_idxs must be a contiguous int64 tensor on input's device");
    }

    auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
    auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device());
    auto values = torch::empty({batch, (int64_t)k}, opts_f);
    auto indices = torch::empty({batch, (int64_t)k}, opts_i);
    dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs);
    return {values, indices};
}
// Python bindings: registers topk_cuda on the extension module under the name
// "topk_cuda". The module name comes from the TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_cuda", &topk_cuda, "Custom top-k (values, indices)");
}
"""
# Extra nvcc flags for the inline extension.
# --use_fast_math is deliberately NOT passed: it implies flush-to-zero, which
# also applies to float comparisons, so subnormal inputs would compare equal to
# zero during selection. The kernels only compare and move values, so there is
# no arithmetic for fast-math to speed up.
CUDA_FLAGS = [
    "-O3",
    "-lineinfo",
    "-gencode=arch=compute_120,code=sm_120",
]
# Compiled extension module; built on first use by _get_ext().
_topk_ext = None
# Workspace cache for the split path, keyed by (k, cap, str(device)); each
# value is a (float32 values, int64 indices) pair of shape (cap, k).
_partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
def _get_ext():
    """Return the compiled top-k extension, building it on the first call.

    The module is compiled from ``CUDA_SRC`` with ``CUDA_FLAGS`` via
    ``load_inline`` and then kept in the module-level ``_topk_ext`` so later
    calls return the same object.
    """
    global _topk_ext
    if _topk_ext is not None:
        return _topk_ext
    build_options = dict(
        name="topk_cuda_ext",
        cpp_sources="",
        cuda_sources=CUDA_SRC,
        # The CUDA source defines its own PYBIND11_MODULE, so no bindings
        # are generated here.
        functions=None,
        extra_cuda_cflags=CUDA_FLAGS,
        verbose=False,
    )
    _topk_ext = load_inline(**build_options)
    return _topk_ext
def _get_partial_workspace(k: int, cap: int, device: torch.device):
    """Return the cached ``(values, indices)`` workspace for the split path.

    Both tensors have shape ``(cap, k)`` on ``device``; values are float32 and
    indices are int64. They are allocated uninitialised on first request for a
    given ``(k, cap, str(device))`` and reused afterwards.
    """
    cache_key = (k, cap, str(device))
    workspace = _partial_cache.get(cache_key)
    if workspace is None:
        partial_vals = torch.empty(cap, k, dtype=torch.float32, device=device)
        partial_idxs = torch.empty(cap, k, dtype=torch.int64, device=device)
        workspace = (partial_vals, partial_idxs)
        _partial_cache[cache_key] = workspace
    return workspace
class Model(nn.Module):
    """Top-k over the last dim of a 2D tensor, backed by the CUDA extension."""

    def __init__(self, batch: int, n: int, k: int):
        """Record the problem shape and decide whether to use the split path.

        The split path (and its workspace) is used for a single row with
        ``n >= 65536`` and ``k > 1``; its capacity is ``min(1024, ceil(n / 64))``.
        """
        super().__init__()
        self.batch = batch
        self.n = n
        self.k = k
        self.register_buffer("_dummy", torch.zeros(1))
        use_slice = (batch == 1) and (n >= 65536) and (k > 1)
        self._use_slice = use_slice
        self._partial_cap = min(1024, (n + 63) // 64) if use_slice else 0

    def forward(self, x: torch.Tensor):
        """Return what ``topk_cuda`` produces for ``x`` with ``k = self.k``."""
        ext = _get_ext()
        if not self._use_slice:
            # No workspace needed: pass empty placeholders of the right dtype.
            pv = torch.empty(0, dtype=torch.float32, device=x.device)
            pi = torch.empty(0, dtype=torch.int64, device=x.device)
        else:
            pv, pi = _get_partial_workspace(self.k, self._partial_cap, x.device)
        return ext.topk_cuda(x.contiguous(), self.k, pv, pi)
# Module-level problem shape read by get_inputs() and get_init_inputs().
batch = 64
n = 8192
k = 8
def get_inputs():
    """Return a one-element list holding a random float32 tensor of shape (batch, n)."""
    return [torch.randn(batch, n, dtype=torch.float32)]
def get_init_inputs():
    """Return the constructor arguments ``[batch, n, k]`` from the module-level shape."""
    init_args = [batch, n, k]
    return init_args
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:33:36.334278+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:33:36.467775+00:00 elapsed_s=0.133 ms=1.852128
shape=0 variant=solution tflops=0.000 gbps=0.283 ms=1.852
shape=0 solution_peak_fraction=0.0002
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:33:36.469748+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:33:36.476497+00:00 elapsed_s=0.007 ms=0.038048
shape=1 variant=solution tflops=0.055 gbps=55.280 ms=0.038
shape=1 solution_peak_fraction=0.0307
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:33:36.478189+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:33:36.535566+00:00 elapsed_s=0.057 ms=0.882048
shape=2 variant=solution tflops=0.002 gbps=2.392 ms=0.882
shape=2 solution_peak_fraction=0.0013
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T18:33:36.536191+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T18:33:36.552124+00:00 elapsed_s=0.016 ms=0.190160
shape=3 variant=solution tflops=0.004 gbps=4.055 ms=0.190
shape=3 solution_peak_fraction=0.0023
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T18:33:36.553427+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T18:33:36.558671+00:00 elapsed_s=0.005 ms=0.011664
shape=4 variant=solution tflops=0.180 gbps=179.929 ms=0.012
shape=4 solution_peak_fraction=0.1000
peak_fraction: 0.0043
RESULT: LOW
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic/repo
Installed 56 packages in 66ms
PASS
{
"run_id": "20260615_120752_cursor_composer-2.5-fast_05_topk_bitonic",
"run_group": "",
"problem": "05_topk_bitonic",
"harness": "cursor",
"model": "composer-2.5-fast",
"reasoning_effort": "",
"started_at": "2026-06-15T12:07:52-06:00",
"harness_finished_at": "2026-06-15T12:32:14-06:00",
"finished_at": "2026-06-15T12:33:36-06:00",
"start_epoch": 1781546872,
"harness_end_epoch": 1781548334,
"end_epoch": 1781548416,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.0043,
"template_mutated": false,
"elapsed_seconds": 1462,
"total_elapsed_seconds": 1544,
"check_elapsed_seconds": 81,
"benchmark_elapsed_seconds": 1,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 50.56019151846785,
"usage": {"input_tokens": 234740, "output_tokens": 73919, "cache_read_tokens": 5229344, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"targetDirectory": "/workspace/problems/05_topk_bitonic",
"globPattern": "**/*"
}{
"path": "/workspace/problems/05_topk_bitonic/reference.py"
}null{
"path": "/workspace/problems/05_topk_bitonic/solution.py"
}"""Naive top-k reference: torch.topk over the last dim.
This is the correctness oracle. The agent's solution must produce the same
top-k values (and equivalent indices modulo ties) within the tolerance
declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
"""
import torch
import torch.nn as nn
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
class Model(nn.Module):
"""Top-k over the last dim of a 2D tensor.
Input:
x: (batch, n) fp32
Output:
values: (batch, k) fp32, sorted descending
indices: (batch, k) int64, into the last dim of x
"""
# Stores batch, n and k on the instance; only k is read by forward().
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch, self.n, self.k = batch, n, k
# No learned parameters, but declare a dummy buffer so state_dict
# is non-empty and load_state_dict(strict=True) is meaningful.
self.register_buffer("_dummy", torch.zeros(1))
# Delegates to torch.topk over the last dim (largest=True, sorted=True) and
# returns the (values, indices) pair unchanged.
def forward(self, x: torch.Tensor):
values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
return values, indices
# Module-level shims rebuilt by check.py / benchmark.py per shape.
batch = 64
n = 8192
k = 8
def get_inputs():
# fp32 input drawn from a roughly Gaussian distribution; ties unlikely
# but possible. Seed is set by the caller.
x = torch.randn(batch, n, dtype=torch.float32)
... (61 more chars){
"path": "/workspace/problems/05_topk_bitonic/check.py"
}errorMessage: 'File not found'{
"path": "/workspace/problems/05_topk_bitonic/benchmark.py"
}"""Correctness runner for TopK.
Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
each. Top-k correctness has two parts:
1. VALUES: sol_values must match ref_values within fp32 tol. Both are
returned sorted descending, so positional comparison is well-defined.
2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
ties in x can yield multiple valid index sets. Instead we gather x at
sol_indices and check those values match ref_values within tol. This
catches "wrong indices" without false-failing on legitimate tie-breaks.
Also rejects forbidden ops by grep.
"""
import re
import sys
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.correctness import check_correctness # noqa: E402
from src.eval.numeric_stress import ( # noqa: E402
numeric_stress_cases,
numeric_stress_context,
tolerance_for_case,
)
def main():
try:
import reference
import shapes
import solution
except Exception as e:
print(f"FAIL: import error: {e}")
sys.exit(1)
problem_yaml = Path("problem.yaml")
meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
# --- Forbidden-op check ------------------------------------------------
sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
for forbidden in meta.get
... (4786 more chars)"""Roofline benchmark for TopK.
For each shape: times eager reference (torch.topk), compiled reference, SOTA
(also torch.topk — see sota.py), and the agent's solution. Reports achieved
TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
Output lines the harness picks up:
shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
"""
import sys
from math import exp, log
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
from src.hardware import get as get_hw # noqa: E402
# Evaluate the arithmetic expression string `expr` with `vars` as its local
# names and return the result converted to float.
# NOTE(review): this uses eval(). Builtins are emptied, but that is not a
# sandbox; `expr` must only ever come from trusted configuration.
def _eval_formula(expr: str, vars: dict) -> float:
return float(eval(expr, {"__builtins__": {}}, vars))
def main():
import reference
import shapes
import solution
meta = yaml.safe_load(Path("problem.yaml").read_text())
hw = get_hw(meta["hardware"][0])
peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
peak_gbps = hw.peak_bandwidth_gb_s
regime = meta.get("regime", "memory")
flops_formula = meta["flops_formula"]
bytes_formula = meta["bytes_formula"]
num_perf_trials = int(meta.get("num_perf_trials", 50))
device = torch.device("cuda:0")
include_baselines = benchmark_baselines_enabled(
... (3576 more chars){
"path": "/workspace/problems/05_topk_bitonic/shapes.py"
}{
"path": "/workspace/problems/05_topk_bitonic/problem.yaml"
}{
"path": "/workspace/problems/05_topk_bitonic/sota.py"
}"""Canonical shape sweep for TopK.
Mix of:
- decoder vocab top-k (single sequence, very large n, moderate k) — pure
bandwidth test; the input read dominates everything.
- prefill / batched attention top-k (many rows, moderate n, small k) — tests
per-row parallelism and shared-memory bitonic networks.
- non-power-of-2 n stress case — bitonic sort networks naturally want
powers of two; this forces the agent to handle padding or partial sorts.
- small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
"""
SHAPES = [
{"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
{"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
{"batch": 32, "n": 16384, "k": 32}, # mid-size batched
{"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
{"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
]
name: 05_topk_bitonic
display_name: "TopK via Bitonic Sort"
precision: fp32
regime: memory
# Top-k is dominated by the input read (small output, no reduction over k).
# Comparator-network FLOPs are not the bottleneck on real hardware, so we
# track them but score on bandwidth.
flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
# bitonic network is so cheap relative
# to memory that a coarse 4*n estimate
# is fine for telemetry only.
bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
hardware: [RTX_PRO_6000]
peak_tflops_key: fp32
peak_bandwidth_key: dram
# Top-k correctness:
# - VALUES must match within fp32 atol/rtol (the kth largest value is
# well-defined modulo float-equal ties, so we use a loose-ish tol).
# - INDICES are checked leniently: for each row, the multiset of returned
# indices must select values that match ref values within tol. Direct
# index equality is NOT required (ties in x can yield different valid
# index sets).
tolerance:
float32: 1.0e-4
# Forbidden ops — using any of these in solution.py fails correctness post-hoc.
# This problem is about IMPLEMENTING the selection, not dispatching to PyTorch's
# tuned top-k. torch.sort is also banned because torch.topk falls back to it.
forbidden:
- "torch.topk"
- "torch.kthvalue"
... (697 more chars)"""SOTA reference for TopK: torch.topk itself.
torch.topk dispatches to a CUB-backed kernel that uses radix-select for
moderate k and a tuned bitonic sort for small n. It is the bar the agent's
hand-rolled bitonic kernel must beat. There is no obvious vendor library that
does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
sorts the full row (overkill for top-k).
This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
oracle, not the agent's submission. The agent's solution.py is forbidden from
using torch.topk (see problem.yaml.forbidden).
"""
from __future__ import annotations
import torch
# Calls torch.topk over the last dim with largest=True and sorted=True and
# returns its result unchanged.
def sota_forward(x: torch.Tensor, k: int):
"""Best-available top-k reference. x: (batch, n) fp32."""
return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
# Always reports the SOTA reference as available.
def is_available() -> bool:
return True
{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python -c \"import torch; print(torch.cuda.get_device_name(0)); print(torch.version.cuda); print(torch.__version__); props=torch.cuda.get_device_properties(0); print(f'SM={props.major}.{props.minor}, mem={props.total_memory/1e9:.1f}GB')\"",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_fba77c68-4347-4089-822a-8a9f4e6c450",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
... (1454 more chars)null{
"command": "<192 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_8529998f-8f51-4121-92e6-d6f2c647fd7",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1019 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
print('CUDA extension works')
# Check if cub is available via cuda
"error: 'Glob pattern "**/*" matches every file and is not allowed. Use a more specific glob or no glob.'{
"command": "nvcc --version 2>/dev/null; ls /usr/local/cuda/include/cub 2>/dev/null | head -5",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_a4e8ac19-e6a9-412f-a616-2a5da625f7a",
"simpleCommands": [
"nvcc",
"ls",
"head"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "nvcc",
"args": [
{
"type": "word",
"value": "--version"
}
],
"fullText": "nvcc --version"
},
{
... (1090 more chars)nullnull--- /dev/null
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,474 @@
+"""Custom CUDA top-k kernel for RTX PRO 6000 (SM120).
+
+Per-row streaming selection with block-local merge and bitonic sort for the
+final k elements. Uses load_inline CUDA — no torch.topk / sort.
+"""
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cfloat>
+
+#define WARP_SIZE 32
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+__device__ __forceinline__ int float_cmp_gt(float a, float b) {
+ return a > b;
+}
+
+__device__ __forceinline__ void swap_pair(float& va, int64_t& ia, float& vb, int64_t& ib) {
+ float tv = va; int64_t ti = ia;
+ va = vb; ia = ib;
+ vb = tv; ib = ti;
+}
+
+// Insert (val, idx) into sorted-descending arrays of length K (K <= 64).
+template <int K>
+__device__ __forceinline__ void insert_topk(float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ if (!float_cmp_gt(val, vals[K - 1])) return;
+ int pos = K - 1;
+ while (pos > 0 && float_cmp_gt(val, vals[pos - 1])) {
+ vals[pos] = vals[pos - 1];
+ idxs[pos] = idxs[pos - 1];
+ --pos;
+ }
+ vals[pos] = val;
+ idxs[pos] = idx;
+}
+
+// Bitonic sort (descending) for N power-of-two, vals/idxs length N in shared mem.
+template <int N>
+__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
+ for (int k = 2; k <= N; k <<= 1) {
+ for (int j = k >> 1; j > 0; j >>= 1) {
+ for (int i = threadIdx.x; i < N; i += blockDim.x) {
+ int ixj = i ^ j;
+ if (ixj > i) {
+ bool asc = ((i & k) == 0);
+ bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
+ if (swap_it) swap_pair(vals[i], idxs[i], vals[ixj], idxs[ixj]);
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
+// Merge two sorted-desc arrays of size K into one sorted-desc array of size K.
+template <int K>
+__device__ void merge_topk(float* a_val, int64_t* a_idx,
+ float* b_val, int64_t* b_idx,
+ float* out_val, int64_t* out_idx) {
+ int ai = 0, bi = 0, oi = 0;
+ while (oi < K) {
+ if (ai < K && (bi >= K || float_cmp_gt(a_val[ai], b_val[bi]))) {
+ out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
+ } else {
+ out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
+ }
+ ++oi;
+ }
+}
+
+// Warp reduce max for k=1
+__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
+ float& out_val, int64_t& out_idx) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
+ float other_val = __shfl_down_sync(0xffffffff, val, offset);
+ int64_t other_idx = __shfl_down_sync(0xffffffff, idx, offset);
+ if (other_val > val) { val = other_val; idx = other_idx; }
+ }
+ out_val = val; out_idx = idx;
+}
+
+// ---------------------------------------------------------------------------
+// k == 1 kernel: one block per row
+// ---------------------------------------------------------------------------
+__global__ void topk1_kernel(const float* __restrict__ input,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ const float* row_in = input + (int64_t)row * n;
+ float best_val = -FLT_MAX;
+ int64_t best_idx = 0;
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ float v = row_in[i];
+ if (v > best_val) { best_val = v; best_idx = i; }
+ }
+
+ __shared__ float s_val[32];
+ __shared__ int64_t s_idx[32];
+
+ int lane = threadIdx.x & 31;
+ int wid = threadIdx.x >> 5;
+ float wval = best_val;
+ int64_t widx = best_idx;
+ warp_reduce_max(wval, widx, wval, widx);
+
+ if (lane == 0) { s_val[wid] = wval; s_idx[wid] = widx; }
+ __syncthreads();
+
+ if (wid == 0) {
+ wval = (lane < (blockDim.x + 31) / 32) ? s_val[lane] : -FLT_MAX;
+ widx = (lane < (blockDim.x + 31) / 32) ? s_idx[lane] : 0;
+ warp_reduce_max(wval, widx, wval, widx);
+ if (lane == 0) {
+ out_vals[row] = wval;
+ out_idxs[row] = widx;
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// General top-k: one block per row, threads stream n with local top-k
+// ---------------------------------------------------------------------------
+template <int K, int MAX_K>
+__global__ void topk_kernel(const float* __restrict__ input,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ const float* row_in = input + (int64_t)row * n;
+
+ // Per-thread local top-k in registers
+ float local_val[MAX_K];
+ int64_t local_idx[MAX_K];
+ #pragma unroll
+ for (int i = 0; i < K; ++i) {
+ local_val[i] = -FLT_MAX;
+ local_idx[i] = 0;
+ }
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
+ }
+
+ // Shared memory for block reduction: each thread contributes K elements
+ extern __shared__ char smem[];
+ float* s_val = reinterpret_cast<float*>(smem);
+ int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
+
+ int tid
... (11021 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<16461 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_446c876b-7ec2-4b14-9c70-0ae63cfc20b",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bi
... (899 more chars)failure:
command: 'cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: forbidden op used: torch.topk
stderr: ''
executionTime: 854
interleavedOutput:
FAIL: forbidden op used: torch.topk
aborted: False
localExecutionTimeMs: 795
isBackground: False--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,474 +1,5 @@
"""Custom CUDA top-k kernel for RTX PRO 6000 (SM120).
Per-row streaming selection with block-local merge and bitonic sort for the
-final k elements. Uses load_inline CUDA — no torch.topk / sort.
-"""
-from __future__ import annotations
-
-import math
-
-import torch
-import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-
-CUDA_SRC = r"""
-#include <torch/extension.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cstdint>
-#include <cfloat>
-
-#define WARP_SIZE 32
-
-// ---------------------------------------------------------------------------
-// Helpers
-// ---------------------------------------------------------------------------
-
-__device__ __forceinline__ int float_cmp_gt(float a, float b) {
- return a > b;
-}
-
-__device__ __forceinline__ void swap_pair(float& va, int64_t& ia, float& vb, int64_t& ib) {
- float tv = va; int64_t ti = ia;
- va = vb; ia = ib;
- vb = tv; ib = ti;
-}
-
-// Insert (val, idx) into sorted-descending arrays of length K (K <= 64).
-template <int K>
-__device__ __forceinline__ void insert_topk(float val, int64_t idx,
- float* vals, int64_t* idxs) {
- if (!float_cmp_gt(val, vals[K - 1])) return;
- int pos = K - 1;
- while (pos > 0 && float_cmp_gt(val, vals[pos - 1])) {
- vals[pos] = vals[pos - 1];
- idxs[pos] = idxs[pos - 1];
- --pos;
- }
- vals[pos] = val;
- idxs[pos] = idx;
-}
-
-// Bitonic sort (descending) for N power-of-two, vals/idxs length N in shared mem.
-template <int N>
-__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
- for (int k = 2; k <= N; k <<= 1) {
- for (int j = k >> 1; j > 0; j >>= 1) {
- for (int i = threadIdx.x; i < N; i += blockDim.x) {
- int ixj = i ^ j;
- if (ixj > i) {
- bool asc = ((i & k) == 0);
- bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
- if (swap_it) swap_pair(vals[i], idxs[i], vals[ixj], idxs[ixj]);
- }
- }
- __syncthreads();
- }
- }
-}
-
-// Merge two sorted-desc arrays of size K into one sorted-desc array of size K.
-template <int K>
-__device__ void merge_topk(float* a_val, int64_t* a_idx,
- float* b_val, int64_t* b_idx,
- float* out_val, int64_t* out_idx) {
- int ai = 0, bi = 0, oi = 0;
- while (oi < K) {
- if (ai < K && (bi >= K || float_cmp_gt(a_val[ai], b_val[bi]))) {
- out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
- } else {
- out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
- }
- ++oi;
- }
-}
-
-// Warp reduce max for k=1
-__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
- float& out_val, int64_t& out_idx) {
- for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
- float other_val = __shfl_down_sync(0xffffffff, val, offset);
- int64_t other_idx = __shfl_down_sync(0xffffffff, idx, offset);
- if (other_val > val) { val = other_val; idx = other_idx; }
- }
- out_val = val; out_idx = idx;
-}
-
-// ---------------------------------------------------------------------------
-// k == 1 kernel: one block per row
-// ---------------------------------------------------------------------------
-__global__ void topk1_kernel(const float* __restrict__ input,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
-
- const float* row_in = input + (int64_t)row * n;
- float best_val = -FLT_MAX;
- int64_t best_idx = 0;
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- float v = row_in[i];
- if (v > best_val) { best_val = v; best_idx = i; }
- }
-
- __shared__ float s_val[32];
- __shared__ int64_t s_idx[32];
-
- int lane = threadIdx.x & 31;
- int wid = threadIdx.x >> 5;
- float wval = best_val;
- int64_t widx = best_idx;
- warp_reduce_max(wval, widx, wval, widx);
-
- if (lane == 0) { s_val[wid] = wval; s_idx[wid] = widx; }
- __syncthreads();
-
- if (wid == 0) {
- wval = (lane < (blockDim.x + 31) / 32) ? s_val[lane] : -FLT_MAX;
- widx = (lane < (blockDim.x + 31) / 32) ? s_idx[lane] : 0;
- warp_reduce_max(wval, widx, wval, widx);
- if (lane == 0) {
- out_vals[row] = wval;
- out_idxs[row] = widx;
- }
- }
-}
-
-// ---------------------------------------------------------------------------
-// General top-k: one block per row, threads stream n with local top-k
-// ---------------------------------------------------------------------------
-template <int K, int MAX_K>
-__global__ void topk_kernel(const float* __restrict__ input,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
-
- const float* row_in = input + (int64_t)row * n;
-
- // Per-thread local top-k in registers
- float local_val[MAX_K];
- int64_t local_idx[MAX_K];
- #pragma unroll
- for (int i = 0; i < K; ++i) {
- local_val[i] = -FLT_MAX;
- local_idx[i] = 0;
- }
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
- }
-
- // Shared memory for block reduction: each thread contributes K elements
- extern __shared__ char smem[];
- float* s_val = reinterpret_cast<float*>(smem);
- int64_t* s_idx = reinterpret_cast<int64_t*>
... (11107 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<175 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_fc2c2513-de3f-4cb6-a543-c3166ff97de",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bi
... (917 more chars)failure:
command: 'cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 82, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/05_topk_bitonic/solution.py", line 459, in forward
ext = _get_ext()
... (8095 more chars){
"pattern": "getCurrentCUDAStream",
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/include",
"caseInsensitive": false,
"headLimit": 10,
"multiline": false,
"toolCallId": "tool_261bf2c2-97f9-4e11-83dd-bd36545d26f",
"offset": 0
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,5 +1,3 @@
-"""Custom CUDA top-k kernel for RTX PRO 6000 (SM120).
-
-Per-row streaming selection with block-local merge and bitonic sort for the
-final k elements. Uses load_inline CUDA.
-"""+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<81 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_07f8a237-b4f9-4476-b8cd-455a086197e",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bi
... (917 more chars)failure:
command: 'cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 102, in main
ok, msg = check_correctness(
^^^^^^^^^^^^^^^^^^
File "/workspace/src/eval/correctness.py", line 71, in check_correctness
if torch.isnan(solution_out).any():
^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
stderr: ''
executionTime: 31823
interleavedOutput:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 102, in main
ok, msg = check_correctness(
^^^^^^^^^^^^^^^^^^
File "/workspace/src/eval/correctness.py", line 71, in check_correctness
if torch.isnan(solut
... (571 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 1,
"limit": 100
}"""Custom CUDA top-k kernel for RTX PRO 6000 (SM120).
Per-row streaming selection with block-local merge and bitonic sort for the
final k elements. Uses load_inline CUDA.
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cfloat>
#define WARP_SIZE 32
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
__device__ __forceinline__ int float_cmp_gt(float a, float b) {
return a > b;
}
__device__ __forceinline__ void swap_pair(float& va, int64_t& ia, float& vb, int64_t& ib) {
float tv = va; int64_t ti = ia;
va = vb; ia = ib;
vb = tv; ib = ti;
}
// Insert (val, idx) into sorted-descending arrays of length K (K <= 64).
template <int K>
__device__ __forceinline__ void insert_topk(float val, int64_t idx,
float* vals, int64_t* idxs) {
if (!float_cmp_gt(val, vals[K - 1])) return;
int pos = K - 1;
while (pos > 0 && float_cmp_gt(val, vals[pos - 1])) {
vals[pos] = vals[pos - 1];
idxs[pos] = idxs[pos - 1];
--pos;
}
vals[pos] = val;
idxs[pos] = idx;
}
// Bitonic sort (descending) for N power-of-two, vals/idxs length N in shared
... (1876 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 98
}// ---------------------------------------------------------------------------
// k == 1 kernel: one block per row
// ---------------------------------------------------------------------------
__global__ void topk1_kernel(const float* __restrict__ input,
float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,
int batch, int n) {
int row = blockIdx.x;
if (row >= batch) return;
const float* row_in = input + (int64_t)row * n;
float best_val = -FLT_MAX;
int64_t best_idx = 0;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
float v = row_in[i];
if (v > best_val) { best_val = v; best_idx = i; }
}
__shared__ float s_val[32];
__shared__ int64_t s_idx[32];
int lane = threadIdx.x & 31;
int wid = threadIdx.x >> 5;
float wval = best_val;
int64_t widx = best_idx;
warp_reduce_max(wval, widx, wval, widx);
if (lane == 0) { s_val[wid] = wval; s_idx[wid] = widx; }
__syncthreads();
if (wid == 0) {
wval = (lane < (blockDim.x + 31) / 32) ? s_val[lane] : -FLT_MAX;
widx = (lane < (blockDim.x + 31) / 32) ? s_idx[lane] : 0;
warp_reduce_max(wval, widx, wval, widx);
if (lane == 0) {
out_vals[row] = wval;
out_idxs[row] = widx;
}
}
}
// ---------------------------------------------------------------------------
// General top-k: one block per row, threa
... (11790 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,3 +1,51 @@
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda.h>+// Pick thread count so shared memory stays <= 48 KiB.
+template <int K>
+__host__ __device__ constexpr int threads_for_k() {
+ if (K <= 16) return 128;
+ if (K <= 32) return 96;
+ return 64;
+}
+
+template <int K>
+__host__ __device__ constexpr size_t smem_bytes(int threads) {
+ return (size_t)threads * K * (sizeof(float) + sizeof(int64_t));
+}
+
+// Tree-reduce per-thread top-k arrays in shared memory.
+template <int K>
+__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
+ float* s_val, int64_t* s_idx, int threads) {
+ int tid = threadIdx.x;
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ s_val[tid * K + j] = local_val[j];
+ s_idx[tid * K + j] = local_idx[j];
+ }
+ __syncthreads();
+
+ for (int stride = threads / 2; stride >= 1; stride >>= 1) {
+ if (tid < stride) {
+ float tmp_val[K];
+ int64_t tmp_idx[K];
+ merge_topk<K>(s_val + tid * K, s_idx + tid * K,
+ s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
+ tmp_val, tmp_idx);
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ s_val[tid * K + j] = tmp_val[j];
+ s_idx[tid * K + j] = tmp_idx[j];
+ }
+ }
+ __syncthreads();
+ }
+}
+
+__device__ __forceinline__ void write_topk(int row, int k,
+ const float* vals, const int64_t* idxs,
+ float* out_vals, int64_t* out_idxs) {
+ #pragma unroll
+ for (int j = 0; j < k; ++j) {
+ out_vals[(int64_t)row * k + j] = vals[j];
+ out_idxs[(int64_t)row * k + j] = idxs[j];
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<1718 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,51 +1,31 @@
-// Pick thread count so shared memory stays <= 48 KiB.
template <int K>
-__host__ __device__ constexpr int threads_for_k() {
- if (K <= 16) return 128;
- if (K <= 32) return 96;
- return 64;
-}
+__global__ void topk_kernel(const float* __restrict__ input,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
-template <int K>
-__host__ __device__ constexpr size_t smem_bytes(int threads) {
- return (size_t)threads * K * (sizeof(float) + sizeof(int64_t));
-}
+ const float* row_in = input + (int64_t)row * n;
+ float local_val[K];
+ int64_t local_idx[K];
+ #pragma unroll
+ for (int i = 0; i < K; ++i) {
+ local_val[i] = -FLT_MAX;
+ local_idx[i] = 0;
+ }
-// Tree-reduce per-thread top-k arrays in shared memory.
-template <int K>
-__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
- float* s_val, int64_t* s_idx, int threads) {
- int tid = threadIdx.x;
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = local_val[j];
- s_idx[tid * K + j] = local_idx[j];
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
}
- __syncthreads();
- for (int stride = threads / 2; stride >= 1; stride >>= 1) {
- if (tid < stride) {
- float tmp_val[K];
- int64_t tmp_idx[K];
- merge_topk<K>(s_val + tid * K, s_idx + tid * K,
- s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
- tmp_val, tmp_idx);
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = tmp_val[j];
- s_idx[tid * K + j] = tmp_idx[j];
- }
- }
- __syncthreads();
- }
-}
+ extern __shared__ char smem[];
+ float* s_val = reinterpret_cast<float*>(smem);
+ int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
-__device__ __forceinline__ void write_topk(int row, int k,
- const float* vals, const int64_t* idxs,
- float* out_vals, int64_t* out_idxs) {
- #pragma unroll
- for (int j = 0; j < k; ++j) {
- out_vals[(int64_t)row * k + j] = vals[j];
- out_idxs[(int64_t)row * k + j] = idxs[j];
+ block_reduce_topk<K>(local_val, local_idx, s_val, s_idx, blockDim.x);
+
+ if (threadIdx.x == 0) {
+ write_topk<K>(row, K, s_val, s_idx, out_vals, out_idxs);
}
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<989 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,22 +1,54 @@
template <int K>
-__global__ void topk_kernel(const float* __restrict__ input,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
+__global__ void topk_slice_kernel(const float* __restrict__ input,
+ float* __restrict__ partial_vals,
+ int64_t* __restrict__ partial_idxs,
+ int row, int n, int num_blocks) {
+ int bid = blockIdx.x;
+ if (bid >= num_blocks) return;
- const float* row_in = input + (int64_t)row * n;
+ int slice = (n + num_blocks - 1) / num_blocks;
+ int start = bid * slice;
+ int end = min(start + slice, n);
+
float local_val[K];
int64_t local_idx[K];
#pragma unroll
- for (int i = 0; i < K; ++i) {
- local_val[i] = -FLT_MAX;
- local_idx[i] = 0;
+ for (int i = 0; i < K; ++i) { local_val[i] = -FLT_MAX; local_idx[i] = 0; }
+
+ const float* row_in = input + (int64_t)row * n;
+ for (int i = start + threadIdx.x; i < end; i += blockDim.x) {
+ insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
}
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
+ extern __shared__ char smem[];
+ float* s_val = reinterpret_cast<float*>(smem);
+ int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
+
+ block_reduce_topk<K>(local_val, local_idx, s_val, s_idx, blockDim.x);
+
+ if (threadIdx.x == 0) {
+ write_topk<K>(bid, K, s_val, s_idx, partial_vals, partial_idxs);
+ }
+}
+
+template <int K>
+__global__ void topk_merge_kernel(const float* __restrict__ partial_vals,
+ const int64_t* __restrict__ partial_idxs,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int row, int num_partials) {
+ float local_val[K];
+ int64_t local_idx[K];
+ #pragma unroll
+ for (int i = 0; i < K; ++i) { local_val[i] = -FLT_MAX; local_idx[i] = 0; }
+
+ for (int p = threadIdx.x; p < num_partials; p += blockDim.x) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ insert_topk<K>(partial_vals[(int64_t)p * K + j],
+ partial_idxs[(int64_t)p * K + j],
+ local_val, local_idx);
+ }
}
extern __shared__ char smem[];
{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<2319 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,63 +1,60 @@
template <int K>
-__global__ void topk_slice_kernel(const float* __restrict__ input,
- float* __restrict__ partial_vals,
- int64_t* __restrict__ partial_idxs,
- int row, int n, int num_blocks) {
- int bid = blockIdx.x;
- if (bid >= num_blocks) return;
+void launch_topk(const float* input, float* out_vals, int64_t* out_idxs,
+ int batch, int n, cudaStream_t stream,
+ float* partial_vals, int64_t* partial_idxs, int partial_cap) {
+ constexpr int threads = threads_for_k<K>();
+ size_t smem = smem_bytes<K>(threads);
- int slice = (n + num_blocks - 1) / num_blocks;
- int start = bid * slice;
- int end = min(start + slice, n);
-
- float local_val[K];
- int64_t local_idx[K];
- #pragma unroll
- for (int i = 0; i < K; ++i) { local_val[i] = -FLT_MAX; local_idx[i] = 0; }
-
- const float* row_in = input + (int64_t)row * n;
- for (int i = start + threadIdx.x; i < end; i += blockDim.x) {
- insert_topk<K>(row_in[i], (int64_t)i, local_val, local_idx);
- }
-
- extern __shared__ char smem[];
- float* s_val = reinterpret_cast<float*>(smem);
- int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
-
- block_reduce_topk<K>(local_val, local_idx, s_val, s_idx, blockDim.x);
-
- if (threadIdx.x == 0) {
- write_topk<K>(bid, K, s_val, s_idx, partial_vals, partial_idxs);
+ if (batch == 1 && n >= 65536) {
+ int num_blocks = min(partial_cap, (n + threads - 1) / threads);
+ topk_slice_kernel<K><<<num_blocks, threads, smem, stream>>>(
+ input, partial_vals, partial_idxs, 0, n, num_blocks);
+ topk_merge_kernel<K><<<1, threads, smem, stream>>>(
+ partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks);
+ } else {
+ topk_kernel<K><<<batch, threads, smem, stream>>>(
+ input, out_vals, out_idxs, batch, n);
}
}
-template <int K>
-__global__ void topk_merge_kernel(const float* __restrict__ partial_vals,
- const int64_t* __restrict__ partial_idxs,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int row, int num_partials) {
- float local_val[K];
- int64_t local_idx[K];
- #pragma unroll
- for (int i = 0; i < K; ++i) { local_val[i] = -FLT_MAX; local_idx[i] = 0; }
+void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals,
+ torch::Tensor& out_idxs, int k,
+ torch::Tensor& partial_vals, torch::Tensor& partial_idxs) {
+ const float* in_ptr = input.data_ptr<float>();
+ float* val_ptr = out_vals.data_ptr<float>();
+ int64_t* idx_ptr = out_idxs.data_ptr<int64_t>();
+ int batch = (int)input.size(0);
+ int n = (int)input.size(1);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ float* pval = partial_vals.defined() ? partial_vals.data_ptr<float>() : nullptr;
+ int64_t* pidx = partial_idxs.defined() ? partial_idxs.data_ptr<int64_t>() : nullptr;
+ int partial_cap = partial_vals.defined() ? (int)partial_vals.size(0) : 0;
- for (int p = threadIdx.x; p < num_partials; p += blockDim.x) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- insert_topk<K>(partial_vals[(int64_t)p * K + j],
- partial_idxs[(int64_t)p * K + j],
- local_val, local_idx);
- }
+ if (k == 1) {
+ launch_topk1(in_ptr, val_ptr, idx_ptr, batch, n, stream);
+ } else if (k == 8) {
+ launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
+ } else if (k == 16) {
+ launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
+ } else if (k == 32) {
+ launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
+ } else if (k == 64) {
+ launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
+ } else {
+ TORCH_CHECK(false, "unsupported k=", k);
}
+}
- extern __shared__ char smem[];
- float* s_val = reinterpret_cast<float*>(smem);
- int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
-
- block_reduce_topk<K>(local_val, local_idx, s_val, s_idx, blockDim.x);
-
- if (threadIdx.x == 0) {
- write_topk<K>(row, K, s_val, s_idx, out_vals, out_idxs);
- }
+std::vector<torch::Tensor> topk_cuda(torch::Tensor input, int k,
+ torch::Tensor partial_vals,
+ torch::Tensor partial_idxs) {
+ TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32);
+ TORCH_CHECK(input.dim() == 2);
+ int batch = (int)input.size(0);
+ auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
+ auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device());
+ auto values = torch::empty({batch, k}, opts_f);
+ auto indices = torch::empty({batch, k}, opts_i);
+ dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs);
+ return {values, indices};
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<2934 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,60 +1,18 @@
-template <int K>
-void launch_topk(const float* input, float* out_vals, int64_t* out_idxs,
- int batch, int n, cudaStream_t stream,
- float* partial_vals, int64_t* partial_idxs, int partial_cap) {
- constexpr int threads = threads_for_k<K>();
- size_t smem = smem_bytes<K>(threads);
+ def __init__(self, batch: int, n: int, k: int):
+ super().__init__()
+ self.batch, self.n, self.k = batch, n, k
+ self.register_buffer("_dummy", torch.zeros(1))
+ # Workspace for large-n single-row slice path (k=64, n=131072).
+ if batch == 1 and n >= 65536 and k > 1:
+ cap = min(512, (n + 63) // 64)
+ self.register_buffer("_partial_vals", torch.empty(cap, k, dtype=torch.float32))
+ self.register_buffer("_partial_idxs", torch.empty(cap, k, dtype=torch.int64))
+ else:
+ self._partial_vals = None
+ self._partial_idxs = None
- if (batch == 1 && n >= 65536) {
- int num_blocks = min(partial_cap, (n + threads - 1) / threads);
- topk_slice_kernel<K><<<num_blocks, threads, smem, stream>>>(
- input, partial_vals, partial_idxs, 0, n, num_blocks);
- topk_merge_kernel<K><<<1, threads, smem, stream>>>(
- partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks);
- } else {
- topk_kernel<K><<<batch, threads, smem, stream>>>(
- input, out_vals, out_idxs, batch, n);
- }
-}
-
-void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals,
- torch::Tensor& out_idxs, int k,
- torch::Tensor& partial_vals, torch::Tensor& partial_idxs) {
- const float* in_ptr = input.data_ptr<float>();
- float* val_ptr = out_vals.data_ptr<float>();
- int64_t* idx_ptr = out_idxs.data_ptr<int64_t>();
- int batch = (int)input.size(0);
- int n = (int)input.size(1);
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- float* pval = partial_vals.defined() ? partial_vals.data_ptr<float>() : nullptr;
- int64_t* pidx = partial_idxs.defined() ? partial_idxs.data_ptr<int64_t>() : nullptr;
- int partial_cap = partial_vals.defined() ? (int)partial_vals.size(0) : 0;
-
- if (k == 1) {
- launch_topk1(in_ptr, val_ptr, idx_ptr, batch, n, stream);
- } else if (k == 8) {
- launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
- } else if (k == 16) {
- launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
- } else if (k == 32) {
- launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
- } else if (k == 64) {
- launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap);
- } else {
- TORCH_CHECK(false, "unsupported k=", k);
- }
-}
-
-std::vector<torch::Tensor> topk_cuda(torch::Tensor input, int k,
- torch::Tensor partial_vals,
- torch::Tensor partial_idxs) {
- TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32);
- TORCH_CHECK(input.dim() == 2);
- int batch = (int)input.size(0);
- auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
- auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device());
- auto values = torch::empty({batch, k}, opts_f);
- auto indices = torch::empty({batch, k}, opts_i);
- dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs);
- return {values, indices};
-}+ def forward(self, x: torch.Tensor):
+ ext = _get_ext()
+ pv = self._partial_vals if self._partial_vals is not None else torch.Tensor()
+ pi = self._partial_idxs if self._partial_idxs is not None else torch.Tensor()
+ return ext.topk_cuda(x.contiguous(), self.k, pv, pi){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<916 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,18 +1,10 @@
- def __init__(self, batch: int, n: int, k: int):
- super().__init__()
- self.batch, self.n, self.k = batch, n, k
- self.register_buffer("_dummy", torch.zeros(1))
- # Workspace for large-n single-row slice path (k=64, n=131072).
- if batch == 1 and n >= 65536 and k > 1:
- cap = min(512, (n + 63) // 64)
- self.register_buffer("_partial_vals", torch.empty(cap, k, dtype=torch.float32))
- self.register_buffer("_partial_idxs", torch.empty(cap, k, dtype=torch.int64))
- else:
- self._partial_vals = None
- self._partial_idxs = None
-
- def forward(self, x: torch.Tensor):
- ext = _get_ext()
- pv = self._partial_vals if self._partial_vals is not None else torch.Tensor()
- pi = self._partial_idxs if self._partial_idxs is not None else torch.Tensor()
- return ext.topk_cuda(x.contiguous(), self.k, pv, pi)+template <int K>
+__device__ __forceinline__ void write_topk(int row,
+ const float* vals, const int64_t* idxs,
+ float* out_vals, int64_t* out_idxs) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ out_vals[(int64_t)row * K + j] = vals[j];
+ out_idxs[(int64_t)row * K + j] = idxs[j];
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<393 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,10 +1,4 @@
-template <int K>
-__device__ __forceinline__ void write_topk(int row,
- const float* vals, const int64_t* idxs,
- float* out_vals, int64_t* out_idxs) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- out_vals[(int64_t)row * K + j] = vals[j];
- out_idxs[(int64_t)row * K + j] = idxs[j];
- }
-}+ self.register_buffer("_dummy", torch.zeros(1))
+ cap = min(512, (n + 63) // 64) if (batch == 1 and n >= 65536 and k > 1) else 0
+ self.register_buffer("_partial_vals", torch.empty(cap, max(k, 1), dtype=torch.float32))
+ self.register_buffer("_partial_idxs", torch.empty(cap, max(k, 1), dtype=torch.int64)){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<331 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,4 +1,2 @@
- self.register_buffer("_dummy", torch.zeros(1))
- cap = min(512, (n + 63) // 64) if (batch == 1 and n >= 65536 and k > 1) else 0
- self.register_buffer("_partial_vals", torch.empty(cap, max(k, 1), dtype=torch.float32))
- self.register_buffer("_partial_idxs", torch.empty(cap, max(k, 1), dtype=torch.int64))+ ext = _get_ext()
+ return ext.topk_cuda(x.contiguous(), self.k, self._partial_vals, self._partial_idxs){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<117 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,11 @@
- ext = _get_ext()
- return ext.topk_cuda(x.contiguous(), self.k, self._partial_vals, self._partial_idxs)+_partial_cache: dict[tuple[int, int], tuple[torch.Tensor, torch.Tensor]] = {}
+
+
+def _get_partial_workspace(k: int, cap: int, device: torch.device):
+ key = (k, cap)
+ if key not in _partial_cache:
+ _partial_cache[key] = (
+ torch.empty(cap, k, dtype=torch.float32, device=device),
+ torch.empty(cap, k, dtype=torch.int64, device=device),
+ )
+ return _partial_cache[key]{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<409 chars \u2014 see diff>"
}path: ''
error: 'The string to replace was not found in the file.'
modelVisibleError: 'The string to replace was not found in the file (even after relaxing whitespace).'{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 350
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_cuda", &topk_cuda, "Custom top-k (values, indices)");
}
"""
CUDA_FLAGS = [
"-O3",
"--use_fast_math",
"-lineinfo",
"-gencode=arch=compute_120,code=sm_120",
]
_topk_ext = None
def _get_ext():
global _topk_ext
if _topk_ext is None:
_topk_ext = load_inline(
name="topk_cuda_ext",
cpp_sources="",
cuda_sources=CUDA_SRC,
functions=None,
extra_cuda_cflags=CUDA_FLAGS,
verbose=False,
)
return _topk_ext
class Model(nn.Module):
"""Top-k over the last dim of a 2D tensor."""
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch, self.n, self.k = batch, n, k
self.register_buffer("_dummy", torch.zeros(1))
cap = min(512, (n + 63) // 64) if (batch == 1 and n >= 65536 and k > 1) else 0
self.register_buffer("_partial_vals", torch.empty(cap, max(k, 1), dtype=torch.float32))
self.register_buffer("_partial_idxs", torch.empty(cap, max(k, 1), dtype=torch.int64))
def forward(self, x: torch.Tensor):
ext = _get_ext()
return ext.topk_cuda(x.contiguous(), self.k, self._partial_vals, self._partial_idxs)
batch = 64
n = 8192
k = 8
def get_inputs():
x = torch.randn(batch, n, dtype=torch.float32)
return [x]
def get_init_inputs():
return [batch, n, k]
--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,8 +1,9 @@
-_partial_cache: dict[tuple[int, int], tuple[torch.Tensor, torch.Tensor]] = {}
+_topk_ext = None
+_partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
def _get_partial_workspace(k: int, cap: int, device: torch.device):
- key = (k, cap)
+ key = (k, cap, str(device))
if key not in _partial_cache:
_partial_cache[key] = (
torch.empty(cap, k, dtype=torch.float32, device=device),
{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<444 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,12 +1,14 @@
-_topk_ext = None
-_partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
+ def __init__(self, batch: int, n: int, k: int):
+ super().__init__()
+ self.batch, self.n, self.k = batch, n, k
+ self.register_buffer("_dummy", torch.zeros(1))
+ self._use_slice = batch == 1 and n >= 65536 and k > 1
+ self._partial_cap = min(512, (n + 63) // 64) if self._use_slice else 0
-
-def _get_partial_workspace(k: int, cap: int, device: torch.device):
- key = (k, cap, str(device))
- if key not in _partial_cache:
- _partial_cache[key] = (
- torch.empty(cap, k, dtype=torch.float32, device=device),
- torch.empty(cap, k, dtype=torch.int64, device=device),
- )
- return _partial_cache[key]+ def forward(self, x: torch.Tensor):
+ ext = _get_ext()
+ if self._use_slice:
+ pv, pi = _get_partial_workspace(self.k, self._partial_cap, x.device)
+ else:
+ pv = pi = torch.empty(0, device=x.device)
+ return ext.topk_cuda(x.contiguous(), self.k, pv, pi){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<627 chars \u2014 see diff>"
}null{
"pattern": "write_topk",
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_418b20c4-b8bd-4bdd-8662-13aa5a48e95",
"offset": 0
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,14 +1 @@
- def __init__(self, batch: int, n: int, k: int):
- super().__init__()
- self.batch, self.n, self.k = batch, n, k
- self.register_buffer("_dummy", torch.zeros(1))
- self._use_slice = batch == 1 and n >= 65536 and k > 1
- self._partial_cap = min(512, (n + 63) // 64) if self._use_slice else 0
-
- def forward(self, x: torch.Tensor):
- ext = _get_ext()
- if self._use_slice:
- pv, pi = _get_partial_workspace(self.k, self._partial_cap, x.device)
- else:
- pv = pi = torch.empty(0, device=x.device)
- return ext.topk_cuda(x.contiguous(), self.k, pv, pi)+ write_topk<K>(row, s_val, s_idx, out_vals, out_idxs);{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": " write_topk<K>(row, s_val, s_idx, out_vals, out_idxs);"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1 +1 @@
- write_topk<K>(row, s_val, s_idx, out_vals, out_idxs);+ write_topk<K>(bid, s_val, s_idx, partial_vals, partial_idxs);{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": " write_topk<K>(bid, s_val, s_idx, partial_vals, partial_idxs);"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_d613e871-e29e-4e0b-84de-158b9bc7260",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1379 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 82, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/05_topk_bitonic
... (8699 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 38,
"limit": 120
}// Insert (val, idx) into sorted-descending arrays of length K (K <= 64).
template <int K>
__device__ __forceinline__ void insert_topk(float val, int64_t idx,
float* vals, int64_t* idxs) {
if (!float_cmp_gt(val, vals[K - 1])) return;
int pos = K - 1;
while (pos > 0 && float_cmp_gt(val, vals[pos - 1])) {
vals[pos] = vals[pos - 1];
idxs[pos] = idxs[pos - 1];
--pos;
}
vals[pos] = val;
idxs[pos] = idx;
}
// Pick thread count so shared memory stays <= 48 KiB.
template <int K>
__host__ __device__ constexpr int threads_for_k() {
if (K <= 16) return 128;
if (K <= 32) return 96;
return 64;
}
template <int K>
__host__ __device__ constexpr size_t smem_bytes(int threads) {
return (size_t)threads * K * (sizeof(float) + sizeof(int64_t));
}
// Tree-reduce per-thread top-k arrays in shared memory.
template <int K>
__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
float* s_val, int64_t* s_idx, int threads) {
int tid = threadIdx.x;
#pragma unroll
for (int j = 0; j < K; ++j) {
s_val[tid * K + j] = local_val[j];
s_idx[tid * K + j] = local_idx[j];
}
__syncthreads();
for (int stride = threads / 2; stride >= 1; stride >>= 1) {
if (tid < stride) {
float tmp_val[K];
int64_t tmp_idx[K];
merge_topk<K>(s_val + tid * K, s_idx + tid * K,
... (2711 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1 +1,17 @@
- write_topk<K>(bid, s_val, s_idx, partial_vals, partial_idxs);+// Merge two sorted-desc arrays of size K into one sorted-desc array of size K.
+template <int K>
+__device__ void merge_topk(float* a_val, int64_t* a_idx,
+ float* b_val, int64_t* b_idx,
+ float* out_val, int64_t* out_idx) {
+ int ai = 0, bi = 0, oi = 0;
+ while (oi < K) {
+ if (ai < K && (bi >= K || float_cmp_gt(a_val[ai], b_val[bi]))) {
+ out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
+ } else {
+ out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
+ }
+ ++oi;
+ }
+}
+
+// Pick thread count so shared memory stays <= 48 KiB.{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<640 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,17 +1,3 @@
-// Merge two sorted-desc arrays of size K into one sorted-desc array of size K.
-template <int K>
-__device__ void merge_topk(float* a_val, int64_t* a_idx,
- float* b_val, int64_t* b_idx,
- float* out_val, int64_t* out_idx) {
- int ai = 0, bi = 0, oi = 0;
- while (oi < K) {
- if (ai < K && (bi >= K || float_cmp_gt(a_val[ai], b_val[bi]))) {
- out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
- } else {
- out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
- }
- ++oi;
- }
}
-// Pick thread count so shared memory stays <= 48 KiB.+// Warp reduce max for k=1{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<29 chars \u2014 see diff>"
}null{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 95,
"limit": 50
}if (tid < stride) {
float tmp_val[K];
int64_t tmp_idx[K];
merge_topk<K>(s_val + tid * K, s_idx + tid * K,
s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
tmp_val, tmp_idx);
#pragma unroll
for (int j = 0; j < K; ++j) {
s_val[tid * K + j] = tmp_val[j];
s_idx[tid * K + j] = tmp_idx[j];
}
}
__syncthreads();
}
}
template <int K>
__device__ __forceinline__ void write_topk(int row,
const float* vals, const int64_t* idxs,
float* out_vals, int64_t* out_idxs) {
#pragma unroll
for (int j = 0; j < K; ++j) {
out_vals[(int64_t)row * K + j] = vals[j];
out_idxs[(int64_t)row * K + j] = idxs[j];
}
}
// Warp reduce max for k=1
__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
float& out_val, int64_t& out_idx) {
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
float other_val = __shfl_down_sync(0xffffffff, val, offset);
int64_t other_idx = __shfl_down_sync(0xffffffff, idx, offset);
if (other_val > val) { val = other_val; idx = other_idx; }
}
out_val = val; out_idx = idx;
}
// ---------------------------------------------------------------------------
// k == 1 kernel: one block per ro
... (451 more chars){
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_ebf8c4f3-b57d-4e14-a1b5-d989b815a01",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1375 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 82, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/05_topk_bitonic/solution.py", line 406, in forward
return ext.topk_cuda(x.contiguous(), self.k, pv, pi)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Long but found Float
stderr: ''
executionTime: 29855
interleavedOutput:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 82, in main
sol_out = sol_model(*inputs)
... (783 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,3 +1,2 @@
-}
-
-// Warp reduce max for k=1+ pv = torch.empty(0, dtype=torch.float32, device=x.device)
+ pi = torch.empty(0, dtype=torch.int64, device=x.device){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<137 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_9ef1ed34-05b9-4c6a-a067-4f5c699b986",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bi
... (899 more chars)failure:
command: 'cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 2 {'batch': 32, 'n': 16384, 'k': 32} seed 42 case nominal values: tolerance exceeded: max_abs_diff=0.599911 max_rel_diff=0.137924 bad=983/1024 worst_idx=(9, 0) (atol=0.0001, rtol=0.0001)
stderr: ''
executionTime: 1890
interleavedOutput:
FAIL: shape 2 {'batch': 32, 'n': 16384, 'k': 32} seed 42 case nominal values: tolerance exceeded: max_abs_diff=0.599911 max_rel_diff=0.137924 bad=983/1024 worst_idx=(9, 0) (atol=0.0001, rtol=0.0001)
aborted: False
localExecutionTimeMs: 1819
isBackground: False{
"command": "<624 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_f709b215-25ff-4962-a736-c5f4510dcb9",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1911 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch
import reference
import solution
device = torch.device('cuda:0')
reference.batch, reference.n, reference.k = 32, 16384, 32
torch.manual_seed(42)
x = reference.get_inputs()[0].to(device)
ref = reference.Model(32, 16384, 32).to(device)
sol = solution.Model(32, 16384, 32).to(device)
rv, ri = ref(x)
sv, si = sol(x)
print('max val diff', (rv-sv).abs().max().item())
print('row 9 ref top5', rv[9,:5])
print('row 9 sol top5', sv[9,:5])
bad = (rv-sv).abs() > 1e-4
print('bad count', bad.sum().item())
print('worst rows', bad.any(dim=1).nonzero()[:10])
"
null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,6 @@
- pv = torch.empty(0, dtype=torch.float32, device=x.device)
- pi = torch.empty(0, dtype=torch.int64, device=x.device)+template <int K>
+__host__ __device__ constexpr int threads_for_k() {
+ if (K <= 16) return 128;
+ if (K <= 32) return 64;
+ return 64;
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<142 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_3cf08605-f9c4-4ea7-98cd-069150a0198",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1380 more chars)null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_3fcddeb5-a3b7-48e1-8741-cae6849c546",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_top
... (915 more chars){
"command": "cd /workspace/problems/05_topk_bitonic && BENCHMARK_BASELINES=1 uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_2d0b4d7a-8c88-4b79-886c-6d7199115d0",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /work
... (947 more chars)nullnull{
"command": "<649 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_a0bdb6ef-32d3-43a3-aa85-0f0364accc8",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1961 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, time
import shapes
device = torch.device('cuda:0')
for i, s in enumerate(shapes.SHAPES):
x = torch.randn(s['batch'], s['n'], device=device)
k = s['k']
# warmup
for _ in range(10):
torch.topk(x, k, dim=-1)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(50):
torch.topk(x, k, dim=-1)
torch.cuda.synchronize()
ms = (time.perf_counter()-t0)/50*1000
bytes_moved = s['batch']*s['n']*4 + s['batch']*s['k']*(4+8)
gbps = bytes_moved / ms / 1e6
print(f'shape {i} {s}: {ms:.3f} ms, {gbps:.1f} GB/s')
"
null{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"limit": 20
}"""Custom CUDA top-k kernel for RTX PRO 6000 (SM120).
Per-row streaming selection with block-local merge and bitonic sort for the
final k elements. Uses load_inline CUDA.
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cfloat>--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,6 +1,527 @@
-template <int K>
-__host__ __device__ constexpr int threads_for_k() {
- if (K <= 16) return 128;
- if (K <= 32) return 64;
- return 64;
-}+"""Custom CUDA top-k via radix selection for RTX PRO 6000 (SM120)."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cfloat>
+
+#define WARP_SIZE 32
+#define RADIX 256
+
+__device__ __forceinline__ uint32_t float_to_sortable(float f) {
+ uint32_t bits = __float_as_uint(f);
+ uint32_t mask = (-(int32_t)(bits >> 31)) | 0x80000000u;
+ return bits ^ mask;
+}
+
+__device__ __forceinline__ float sortable_to_float(uint32_t u) {
+ uint32_t mask = (((u >> 31) - 1) | 0x80000000u);
+ return __uint_as_float(u ^ mask);
+}
+
+__device__ __forceinline__ void swap_pair(float& va, int64_t& ia, float& vb, int64_t& ib) {
+ float tv = va; int64_t ti = ia;
+ va = vb; ia = ib;
+ vb = tv; ib = ti;
+}
+
+template <int K>
+__device__ __forceinline__ void insert_desc(float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ if (val <= vals[K - 1]) return;
+ int pos = K - 1;
+ while (pos > 0 && val > vals[pos - 1]) {
+ vals[pos] = vals[pos - 1];
+ idxs[pos] = idxs[pos - 1];
+ --pos;
+ }
+ vals[pos] = val;
+ idxs[pos] = idx;
+}
+
+template <int K>
+__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
+ for (int k2 = 2; k2 <= K; k2 <<= 1) {
+ for (int j = k2 >> 1; j > 0; j >>= 1) {
+ for (int i = 0; i < K; ++i) {
+ int ixj = i ^ j;
+ if (ixj > i) {
+ bool asc = ((i & k2) == 0);
+ bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
+ if (swap_it) swap_pair(vals[i], idxs[i], vals[ixj], idxs[ixj]);
+ }
+ }
+ }
+ }
+}
+
+__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
+ float& out_val, int64_t& out_idx) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
+ float ov = __shfl_down_sync(0xffffffff, val, offset);
+ int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
+ if (ov > val) { val = ov; idx = oi; }
+ }
+ out_val = val; out_idx = idx;
+}
+
+// ---------------------------------------------------------------------------
+// k == 1
+// ---------------------------------------------------------------------------
+__global__ void topk1_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+ float best = -FLT_MAX;
+ int64_t best_i = 0;
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ float v = row_in[i];
+ if (v > best) { best = v; best_i = i; }
+ }
+ __shared__ float sv[32];
+ __shared__ int64_t si[32];
+ int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
+ float wv = best; int64_t wi = best_i;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
+ __syncthreads();
+ if (wid == 0) {
+ wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
+ wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Radix top-k: one block per row
+// ---------------------------------------------------------------------------
+template <int K>
+__global__ void radix_topk_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+
+ __shared__ int hist[RADIX];
+ __shared__ uint32_t prefix_u;
+ __shared__ int need_k;
+
+ uint32_t prefix = 0;
+ int need = K;
+
+ for (int shift = 24; shift >= 0; shift -= 8) {
+ if (threadIdx.x < RADIX) hist[threadIdx.x] = 0;
+ __syncthreads();
+
+ uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ uint32_t u = float_to_sortable(row_in[i]);
+ if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
+ int bucket = (u >> shift) & 0xFF;
+ atomicAdd(&hist[bucket], 1);
+ }
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ int acc = 0;
+ uint32_t chosen = 0;
+ int new_need = need;
+ for (int b = 255; b >= 0; --b) {
+ acc += hist[b];
+ if (acc >= need) {
+ chosen = (uint32_t)b << shift;
+ new_need = need - (acc - hist[b]);
+ break;
+ }
+ }
+ prefix |= chosen;
+ prefix_u = prefix;
+ need_k = new_need;
+ }
+ __syncthreads();
+ prefix = prefix_u;
+ need = need_k;
+ }
+
+ uint32_t thresh_u = prefix;
+
+ float g_val[K];
+ int64_t g_idx[K];
+ float e_val[K];
+ int64_t e_idx[K];
+ int ng = 0, ne = 0;
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ g_val[j] = -FLT_MAX; g_idx[j]
... (12986 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<18185 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,170 +1,8 @@
-"""Custom CUDA top-k via radix selection for RTX PRO 6000 (SM120)."""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-
-CUDA_SRC = r"""
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cstdint>
-#include <cfloat>
-
-#define WARP_SIZE 32
-#define RADIX 256
-
-__device__ __forceinline__ uint32_t float_to_sortable(float f) {
- uint32_t bits = __float_as_uint(f);
- uint32_t mask = (-(int32_t)(bits >> 31)) | 0x80000000u;
- return bits ^ mask;
-}
-
-__device__ __forceinline__ float sortable_to_float(uint32_t u) {
- uint32_t mask = (((u >> 31) - 1) | 0x80000000u);
- return __uint_as_float(u ^ mask);
-}
-
-__device__ __forceinline__ void swap_pair(float& va, int64_t& ia, float& vb, int64_t& ib) {
- float tv = va; int64_t ti = ia;
- va = vb; ia = ib;
- vb = tv; ib = ti;
-}
-
template <int K>
-__device__ __forceinline__ void insert_desc(float val, int64_t idx,
- float* vals, int64_t* idxs) {
- if (val <= vals[K - 1]) return;
- int pos = K - 1;
- while (pos > 0 && val > vals[pos - 1]) {
- vals[pos] = vals[pos - 1];
- idxs[pos] = idxs[pos - 1];
- --pos;
- }
- vals[pos] = val;
- idxs[pos] = idx;
-}
-
-template <int K>
-__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
- for (int k2 = 2; k2 <= K; k2 <<= 1) {
- for (int j = k2 >> 1; j > 0; j >>= 1) {
- for (int i = 0; i < K; ++i) {
- int ixj = i ^ j;
- if (ixj > i) {
- bool asc = ((i & k2) == 0);
- bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
- if (swap_it) swap_pair(vals[i], idxs[i], vals[ixj], idxs[ixj]);
- }
- }
- }
- }
-}
-
-__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
- float& out_val, int64_t& out_idx) {
- for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
- float ov = __shfl_down_sync(0xffffffff, val, offset);
- int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
- if (ov > val) { val = ov; idx = oi; }
- }
- out_val = val; out_idx = idx;
-}
-
-// ---------------------------------------------------------------------------
-// k == 1
-// ---------------------------------------------------------------------------
-__global__ void topk1_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
- float best = -FLT_MAX;
- int64_t best_i = 0;
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- float v = row_in[i];
- if (v > best) { best = v; best_i = i; }
- }
- __shared__ float sv[32];
- __shared__ int64_t si[32];
- int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
- float wv = best; int64_t wi = best_i;
- warp_reduce_max(wv, wi, wv, wi);
- if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
- __syncthreads();
- if (wid == 0) {
- wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
- wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
- warp_reduce_max(wv, wi, wv, wi);
- if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
- }
-}
-
-// ---------------------------------------------------------------------------
-// Radix top-k: one block per row
-// ---------------------------------------------------------------------------
-template <int K>
-__global__ void radix_topk_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
-
- __shared__ int hist[RADIX];
- __shared__ uint32_t prefix_u;
- __shared__ int need_k;
-
- uint32_t prefix = 0;
- int need = K;
-
- for (int shift = 24; shift >= 0; shift -= 8) {
- if (threadIdx.x < RADIX) hist[threadIdx.x] = 0;
- __syncthreads();
-
- uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
- int bucket = (u >> shift) & 0xFF;
- atomicAdd(&hist[bucket], 1);
- }
- __syncthreads();
-
- if (threadIdx.x == 0) {
- int acc = 0;
- uint32_t chosen = 0;
- int new_need = need;
- for (int b = 255; b >= 0; --b) {
- acc += hist[b];
- if (acc >= need) {
- chosen = (uint32_t)b << shift;
- new_need = need - (acc - hist[b]);
- break;
- }
- }
- prefix |= chosen;
- prefix_u = prefix;
- need_k = new_need;
- }
- __syncthreads();
- prefix = prefix_u;
- need = need_k;
- }
-
- uint32_t thresh_u = prefix;
-
- float g_val[K];
- int64_t g_idx[K];
- float e_val[K];
- int64_t e_idx[K];
- int ng = 0, ne = 0;
+__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
+ float* out_vals, int64_t* out_idxs, int row) {
+ float g_val[K]; int64_t g_idx[K];
+ float e_val[K]; int64_t e_idx[K];
... (13964 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<2269 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,64 +1,2 @@
-template <int K>
-__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
- float* out_vals, int64_t* out_idxs, int row) {
- float g_val[K]; int64_t g_idx[K];
- float e_val[K]; int64_t e_idx[K];
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- g_val[j] = -FLT_MAX; g_idx[j] = 0;
- e_val[j] = -FLT_MAX; e_idx[j] = 0;
- }
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- float v = row_in[i];
- if (u > thresh_u) insert_desc<K>(v, (int64_t)i, g_val, g_idx);
- else if (u == thresh_u) insert_desc<K>(v, (int64_t)i, e_val, e_idx);
- }
-
- __shared__ float sb_g[K];
- __shared__ int64_t sb_gi[K];
- __shared__ float sb_e[K];
- __shared__ int64_t sb_ei[K];
-
- for (int t = 0; t < blockDim.x; ++t) {
- if (threadIdx.x == t) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- sb_g[j] = g_val[j]; sb_gi[j] = g_idx[j];
- sb_e[j] = e_val[j]; sb_ei[j] = e_idx[j];
- }
- }
- __syncthreads();
- if (threadIdx.x == 0) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- if (sb_g[j] > -FLT_MAX) insert_desc<K>(sb_g[j], sb_gi[j], g_val, g_idx);
- if (sb_e[j] > -FLT_MAX) insert_desc<K>(sb_e[j], sb_ei[j], e_val, e_idx);
- }
- }
- __syncthreads();
- }
-
- if (threadIdx.x == 0) {
- float out_v[K]; int64_t out_i[K];
- #pragma unroll
- for (int j = 0; j < K; ++j) { out_v[j] = -FLT_MAX; out_i[j] = 0; }
- int oi = 0;
- #pragma unroll
- for (int j = 0; j < K && oi < K; ++j) {
- if (g_val[j] > -FLT_MAX) { out_v[oi] = g_val[j]; out_i[oi] = g_idx[j]; ++oi; }
- }
- int need_eq = K - oi;
- #pragma unroll
- for (int j = 0; j < K && need_eq > 0; ++j) {
- if (e_val[j] > -FLT_MAX) { out_v[oi] = e_val[j]; out_i[oi] = e_idx[j]; ++oi; --need_eq; }
- }
- bitonic_sort_desc<K>(out_v, out_i);
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- out_vals[(int64_t)row * K + j] = out_v[j];
- out_idxs[(int64_t)row * K + j] = out_i[j];
- }
- }
+ collect_finalize<K>(row_in, n, prefix, out_vals, out_idxs, row);
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<70 chars \u2014 see diff>"
}path: ''
error: 'The string to replace was not found in the file.'
modelVisibleError:
The string to replace was not found in the file (even after relaxing whitespace). Found a possible fuzzy match, did you mean:
}
// ---------------------------------------------------------------------------
// k == 1
// ---------------------------------------------------------------------------
__global__ void topk1_kernel(const float* __restrict__ in,
80| float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 100,
"limit": 200
}wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
warp_reduce_max(wv, wi, wv, wi);
if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
}
}
// ---------------------------------------------------------------------------
// Radix top-k: one block per row
// ---------------------------------------------------------------------------
template <int K>
__global__ void radix_topk_kernel(const float* __restrict__ in,
float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,
int batch, int n) {
int row = blockIdx.x;
if (row >= batch) return;
const float* row_in = in + (int64_t)row * n;
__shared__ int hist[RADIX];
__shared__ uint32_t prefix_u;
__shared__ int need_k;
uint32_t prefix = 0;
int need = K;
for (int shift = 24; shift >= 0; shift -= 8) {
if (threadIdx.x < RADIX) hist[threadIdx.x] = 0;
__syncthreads();
uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
for (int i = threadIdx.x; i < n; i += blockDim.x) {
uint32_t u = float_to_sortable(row_in[i]);
if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
int bucket = (u >> shift) & 0xFF;
atomicAdd(&hist[bucket], 1);
}
__syncthreads();
if (threadIdx.x == 0) {
int
... (5898 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 258
}}
template <int K>
__global__ void radix_collect_large_kernel(const float* __restrict__ in,
float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,
int row, int n, uint32_t thresh_u) {
float g_val[K]; int64_t g_idx[K];
float e_val[K]; int64_t e_idx[K];
#pragma unroll
for (int j = 0; j < K; ++j) {
g_val[j] = -FLT_MAX; g_idx[j] = 0;
e_val[j] = -FLT_MAX; e_idx[j] = 0;
}
const float* row_in = in + (int64_t)row * n;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
uint32_t u = float_to_sortable(row_in[i]);
float v = row_in[i];
if (u > thresh_u) insert_desc<K>(v, (int64_t)i, g_val, g_idx);
else if (u == thresh_u) insert_desc<K>(v, (int64_t)i, e_val, e_idx);
}
__shared__ float sg[128 * 64];
__shared__ int64_t si[128 * 64];
__shared__ float se[128 * 64];
__shared__ int64_t sei[128 * 64];
int tid = threadIdx.x;
#pragma unroll
for (int j = 0; j < K; ++j) {
sg[tid * K + j] = g_val[j]; si[tid * K + j] = g_idx[j];
se[tid * K + j] = e_val[j]; sei[tid * K + j] = e_idx[j];
}
__syncthreads();
for (int stride = blockDim.x / 2; stride >= 1; stride >>= 1) {
if (tid < stride) {
float tmp[K]; int64_t tidx[K];
int ai = 0, bi = 0, oi = 0;
while (oi < K) {
float av = sg
... (6565 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,121 @@
+// ---------------------------------------------------------------------------
+// Radix top-k helpers
+// ---------------------------------------------------------------------------
+template <int K>
+__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
+ float* out_vals, int64_t* out_idxs, int row) {
+ float g_val[K]; int64_t g_idx[K];
+ float e_val[K]; int64_t e_idx[K];
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ g_val[j] = -FLT_MAX; g_idx[j] = 0;
+ e_val[j] = -FLT_MAX; e_idx[j] = 0;
+ }
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ uint32_t u = float_to_sortable(row_in[i]);
+ float v = row_in[i];
+ if (u > thresh_u) insert_desc<K>(v, (int64_t)i, g_val, g_idx);
+ else if (u == thresh_u) insert_desc<K>(v, (int64_t)i, e_val, e_idx);
+ }
+
+ __shared__ float sb_g[K];
+ __shared__ int64_t sb_gi[K];
+ __shared__ float sb_e[K];
+ __shared__ int64_t sb_ei[K];
+
+ for (int t = 0; t < blockDim.x; ++t) {
+ if (threadIdx.x == t) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ sb_g[j] = g_val[j]; sb_gi[j] = g_idx[j];
+ sb_e[j] = e_val[j]; sb_ei[j] = e_idx[j];
+ }
+ }
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ if (sb_g[j] > -FLT_MAX) insert_desc<K>(sb_g[j], sb_gi[j], g_val, g_idx);
+ if (sb_e[j] > -FLT_MAX) insert_desc<K>(sb_e[j], sb_ei[j], e_val, e_idx);
+ }
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x == 0) {
+ float out_v[K]; int64_t out_i[K];
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { out_v[j] = -FLT_MAX; out_i[j] = 0; }
+ int oi = 0;
+ #pragma unroll
+ for (int j = 0; j < K && oi < K; ++j) {
+ if (g_val[j] > -FLT_MAX) { out_v[oi] = g_val[j]; out_i[oi] = g_idx[j]; ++oi; }
+ }
+ int need_eq = K - oi;
+ #pragma unroll
+ for (int j = 0; j < K && need_eq > 0; ++j) {
+ if (e_val[j] > -FLT_MAX) { out_v[oi] = e_val[j]; out_i[oi] = e_idx[j]; ++oi; --need_eq; }
+ }
+ bitonic_sort_desc<K>(out_v, out_i);
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ out_vals[(int64_t)row * K + j] = out_v[j];
+ out_idxs[(int64_t)row * K + j] = out_i[j];
+ }
+ }
+}
+
+template <int K>
+__global__ void radix_topk_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+
+ __shared__ int hist[RADIX];
+ __shared__ uint32_t prefix_u;
+ __shared__ int need_k;
+
+ uint32_t prefix = 0;
+ int need = K;
+
+ for (int shift = 24; shift >= 0; shift -= 8) {
+ if (threadIdx.x < RADIX) hist[threadIdx.x] = 0;
+ __syncthreads();
+
+ uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ uint32_t u = float_to_sortable(row_in[i]);
+ if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
+ int bucket = (u >> shift) & 0xFF;
+ atomicAdd(&hist[bucket], 1);
+ }
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ int acc = 0;
+ uint32_t chosen = 0;
+ int new_need = need;
+ for (int b = 255; b >= 0; --b) {
+ acc += hist[b];
+ if (acc >= need) {
+ chosen = (uint32_t)b << shift;
+ new_need = need - (acc - hist[b]);
+ break;
+ }
+ }
+ prefix |= chosen;
+ prefix_u = prefix;
+ need_k = new_need;
+ }
+ __syncthreads();
+ prefix = prefix_u;
+ need = need_k;
+ }
+
collect_finalize<K>(row_in, n, prefix, out_vals, out_idxs, row);
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<4128 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,121 +1,8 @@
-// ---------------------------------------------------------------------------
-// Radix top-k helpers
-// ---------------------------------------------------------------------------
template <int K>
-__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
- float* out_vals, int64_t* out_idxs, int row) {
- float g_val[K]; int64_t g_idx[K];
- float e_val[K]; int64_t e_idx[K];
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- g_val[j] = -FLT_MAX; g_idx[j] = 0;
- e_val[j] = -FLT_MAX; e_idx[j] = 0;
- }
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- float v = row_in[i];
- if (u > thresh_u) insert_desc<K>(v, (int64_t)i, g_val, g_idx);
- else if (u == thresh_u) insert_desc<K>(v, (int64_t)i, e_val, e_idx);
- }
-
- __shared__ float sb_g[K];
- __shared__ int64_t sb_gi[K];
- __shared__ float sb_e[K];
- __shared__ int64_t sb_ei[K];
-
- for (int t = 0; t < blockDim.x; ++t) {
- if (threadIdx.x == t) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- sb_g[j] = g_val[j]; sb_gi[j] = g_idx[j];
- sb_e[j] = e_val[j]; sb_ei[j] = e_idx[j];
- }
- }
- __syncthreads();
- if (threadIdx.x == 0) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- if (sb_g[j] > -FLT_MAX) insert_desc<K>(sb_g[j], sb_gi[j], g_val, g_idx);
- if (sb_e[j] > -FLT_MAX) insert_desc<K>(sb_e[j], sb_ei[j], e_val, e_idx);
- }
- }
- __syncthreads();
- }
-
- if (threadIdx.x == 0) {
- float out_v[K]; int64_t out_i[K];
- #pragma unroll
- for (int j = 0; j < K; ++j) { out_v[j] = -FLT_MAX; out_i[j] = 0; }
- int oi = 0;
- #pragma unroll
- for (int j = 0; j < K && oi < K; ++j) {
- if (g_val[j] > -FLT_MAX) { out_v[oi] = g_val[j]; out_i[oi] = g_idx[j]; ++oi; }
- }
- int need_eq = K - oi;
- #pragma unroll
- for (int j = 0; j < K && need_eq > 0; ++j) {
- if (e_val[j] > -FLT_MAX) { out_v[oi] = e_val[j]; out_i[oi] = e_idx[j]; ++oi; --need_eq; }
- }
- bitonic_sort_desc<K>(out_v, out_i);
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- out_vals[(int64_t)row * K + j] = out_v[j];
- out_idxs[(int64_t)row * K + j] = out_i[j];
- }
- }
-}
-
-template <int K>
-__global__ void radix_topk_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
+__global__ void radix_collect_large_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int row, int n, uint32_t thresh_u) {
const float* row_in = in + (int64_t)row * n;
-
- __shared__ int hist[RADIX];
- __shared__ uint32_t prefix_u;
- __shared__ int need_k;
-
- uint32_t prefix = 0;
- int need = K;
-
- for (int shift = 24; shift >= 0; shift -= 8) {
- if (threadIdx.x < RADIX) hist[threadIdx.x] = 0;
- __syncthreads();
-
- uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
- int bucket = (u >> shift) & 0xFF;
- atomicAdd(&hist[bucket], 1);
- }
- __syncthreads();
-
- if (threadIdx.x == 0) {
- int acc = 0;
- uint32_t chosen = 0;
- int new_need = need;
- for (int b = 255; b >= 0; --b) {
- acc += hist[b];
- if (acc >= need) {
- chosen = (uint32_t)b << shift;
- new_need = need - (acc - hist[b]);
- break;
- }
- }
- prefix |= chosen;
- prefix_u = prefix;
- need_k = new_need;
- }
- __syncthreads();
- prefix = prefix_u;
- need = need_k;
- }
-
- collect_finalize<K>(row_in, n, prefix, out_vals, out_idxs, row);
+ collect_finalize<K>(row_in, n, thresh_u, out_vals, out_idxs, row);
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<439 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_a8746fb0-e603-4c4d-93a7-3f286c4aa7a",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1375 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 42 case nominal values: tolerance exceeded: max_abs_diff=3.40282e+38 max_rel_diff=1.04893e+38 bad=64/64 worst_idx=(0, 0) (atol=0.0001, rtol=0.0001)
stderr: ''
executionTime: 99916
interleavedOutput:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 42 case nominal values: tolerance exceeded: max_abs_diff=3.40282e+38 max_rel_diff=1.04893e+38 bad=64/64 worst_idx=(0, 0) (atol=0.0001, rtol=0.0001)
aborted: False
localExecutionTimeMs: 99845
isBackground: False{
"command": "<585 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_fe2852c6-d0df-47a0-b2d9-fe10a926328",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1823 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, importlib
import reference, solution
importlib.reload(solution)
device = torch.device('cuda:0')
for shape in [(64,8192,8),(1,131072,64)]:
b,n,k = shape
torch.manual_seed(42)
x = torch.randn(b,n,device=device)
ref = reference.Model(b,n,k).to(device)
sol = solution.Model(b,n,k).to(device)
rv,ri = ref(x)
sv,si = sol(x)
diff = (rv-sv).abs().max().item()
print(f'{shape}: max_diff={diff}')
if diff > 1e-3:
print(' ref', rv[0,:5])
print(' sol', sv[0,:5])
"
null{
"command": "<1396 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_8deac224-b7ae-4c60-a65f-ec6535a8870",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name":
... (3510 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch
import struct
def float_to_sortable(f):
bits = struct.unpack('I', struct.pack('f', float(f)))[0]
mask = ((-(bits >> 31)) & 0xFFFFFFFF) | 0x80000000
return bits ^ mask
def sortable_to_float(u):
mask = (((u >> 31) - 1) & 0xFFFFFFFF) | 0x80000000
bits = u ^ mask
return struct.unpack('f', struct.pack('I', bits))[0]
torch.manual_seed(42)
x = torch.randn(8192)
us = [float_to_sortable(v.item()) for v in x]
K=8
prefix=0
need=K
row = x
for shift in [24,16,8,0]:
pmask = 0 if shift==24 else ((~0) << (shift+8)) & 0xFFFFFFFF
hist = [0]*256
for i,v in enumerate(x):
u = float_to_sortable(v.item())
if shift < 24 and (u & pmask) != (prefix & pmask):
continue
bucket = (u >> shift) & 0xFF
hist[bucket] += 1
acc=0
for b in range(255,-1,-1):
acc += hist[b]
if acc >= need:
prefix |= b << shift
need = need - (acc - hist[b])
break
print(f'shift={shift} prefix={prefix:08x} need={need} chosen_byte={(prefix>>shift)&0xff}')
thresh_u = prefix
ref = torch.topk(x, K)
greater = sum(1 for u in us if u > thresh_u)
equal = sum(1 for u in us if u == thresh_u)
print('thresh float', sortable_to_float(thresh_u))
print('ref kth', ref.values[-1].item())
print('greater', greater, 'equal', equal)
"
null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,8 +1,2 @@
-template <int K>
-__global__ void radix_collect_large_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int row, int n, uint32_t thresh_u) {
- const float* row_in = in + (int64_t)row * n;
- collect_finalize<K>(row_in, n, thresh_u, out_vals, out_idxs, row);
-}+ for (int b = threadIdx.x; b < RADIX; b += blockDim.x) hist[b] = 0;
+ __syncthreads();{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<99 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,2 @@
- for (int b = threadIdx.x; b < RADIX; b += blockDim.x) hist[b] = 0;
- __syncthreads();+ for (int b = threadIdx.x; b < RADIX; b += blockDim.x) local_hist[b] = 0;
+ __syncthreads();{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<97 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_cc60cc66-f347-4c82-bcea-1021b2d83f6",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1372 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 123 case nominal values: tolerance exceeded: max_abs_diff=0.022629 max_rel_diff=0.00672554 bad=40/64 worst_idx=(0, 45) (atol=0.0001, rtol=0.0001)
stderr: ''
executionTime: 100582
interleavedOutput:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 123 case nominal values: tolerance exceeded: max_abs_diff=0.022629 max_rel_diff=0.00672554 bad=40/64 worst_idx=(0, 45) (atol=0.0001, rtol=0.0001)
aborted: False
localExecutionTimeMs: 100499
isBackground: False{
"command": "<844 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_d3097f06-3284-4bd6-8be0-497cc2f347b",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (2358 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, importlib, reference, solution
importlib.reload(solution)
device = torch.device('cuda:0')
for seed in [42,123,456]:
torch.manual_seed(seed)
x = torch.randn(1,131072,device=device)
ref = reference.Model(1,131072,64).to(device)
sol = solution.Model(1,131072,64).to(device)
rv,ri = ref(x); sv,si = sol(x)
diff = (rv-sv).abs().max().item()
bad = ((rv-sv).abs() > 1e-4).sum().item()
print(f'seed {seed}: max_diff={diff:.6f} bad={bad}')
for shape in [(64,8192,8),(32,16384,32),(16,12000,16),(128,4096,1)]:
b,n,k=shape
torch.manual_seed(42)
x=torch.randn(b,n,device=device)
rv,_=reference.Model(b,n,k).to(device)(x)
sv,_=solution.Model(b,n,k).to(device)(x)
print(f'{shape}: max_diff={(rv-sv).abs().max().item():.6f}')
"
null{
"command": "<1129 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_477c79dc-3d82-405d-ad1f-ecfb5a79d10",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name":
... (2949 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, struct, reference
def float_to_sortable(f):
bits = struct.unpack('I', struct.pack('f', float(f)))[0]
mask = ((-(bits >> 31)) & 0xFFFFFFFF) | 0x80000000
return bits ^ mask
def radix_thresh(x, K):
prefix=0; need=K
for shift in [24,16,8,0]:
pmask = 0 if shift==24 else ((~0) << (shift+8)) & 0xFFFFFFFF
hist=[0]*256
for v in x.tolist():
u=float_to_sortable(v)
if shift<24 and (u & pmask)!=(prefix & pmask): continue
hist[(u>>shift)&0xFF]+=1
acc=0
for b in range(255,-1,-1):
acc+=hist[b]
if acc>=need:
prefix |= b<<shift
need = need-(acc-hist[b])
break
return prefix
torch.manual_seed(42)
x = torch.randn(8192)
K=8
tu = radix_thresh(x,K)
us = [float_to_sortable(v) for v in x.tolist()]
greater = sum(1 for u in us if u > tu)
equal = sum(1 for u in us if u==tu)
print('greater', greater, 'equal', equal, 'K', K)
ref = torch.topk(torch.tensor(x), K)
print('ref', ref.values.tolist())
"
null{
"command": "<531 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_6641d5b5-8d2f-4bce-b0ad-62a3f9c07a9",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1710 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, reference, solution
device='cuda:0'
torch.manual_seed(42)
x=torch.randn(64,8192,device=device)
rv,_=reference.Model(64,8192,8).to(device)(x)
sv,_=solution.Model(64,8192,8).to(device)(x)
bad=(rv-sv).abs()>1e-4
rows=bad.any(dim=1).nonzero().flatten()[:5]
for r in rows:
print('row', int(r), 'ref', rv[r].tolist())
print(' sol', sv[r].tolist())
print(' gathered', torch.gather(x[r],0,torch.argsort(x[r],descending=True)[:8]).tolist())
"
null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,40 @@
- for (int b = threadIdx.x; b < RADIX; b += blockDim.x) local_hist[b] = 0;
- __syncthreads();+template <int K>
+__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
+ float* out_vals, int64_t* out_idxs, int row) {
+ float acc_val[K];
+ int64_t acc_idx[K];
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { acc_val[j] = -FLT_MAX; acc_idx[j] = 0; }
+
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ uint32_t u = float_to_sortable(row_in[i]);
+ if (u >= thresh_u) insert_desc<K>(row_in[i], (int64_t)i, acc_val, acc_idx);
+ }
+
+ __shared__ float sb_v[K];
+ __shared__ int64_t sb_i[K];
+
+ for (int t = 0; t < blockDim.x; ++t) {
+ if (threadIdx.x == t) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { sb_v[j] = acc_val[j]; sb_i[j] = acc_idx[j]; }
+ }
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ if (sb_v[j] > -FLT_MAX) insert_desc<K>(sb_v[j], sb_i[j], acc_val, acc_idx);
+ }
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x == 0) {
+ bitonic_sort_desc<K>(acc_val, acc_idx);
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ out_vals[(int64_t)row * K + j] = acc_val[j];
+ out_idxs[(int64_t)row * K + j] = acc_idx[j];
+ }
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<1321 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_a85253a9-5c53-4753-931a-e8f53684db0",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1378 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 123 case nominal values: tolerance exceeded: max_abs_diff=0.022629 max_rel_diff=0.00672554 bad=40/64 worst_idx=(0, 45) (atol=0.0001, rtol=0.0001)
stderr: ''
executionTime: 61545
interleavedOutput:
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 123 case nominal values: tolerance exceeded: max_abs_diff=0.022629 max_rel_diff=0.00672554 bad=40/64 worst_idx=(0, 45) (atol=0.0001, rtol=0.0001)
aborted: False
localExecutionTimeMs: 61467
isBackground: False{
"command": "<558 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_7ffc49a2-f792-45a9-9e57-200f7c5df75",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bitonic"
},
{
"name": "
... (1761 more chars)cd /workspace/problems/05_topk_bitonic && uv run python -c "
import torch, reference, solution, shapes
device='cuda:0'
for shape in shapes.SHAPES:
b,n,k = shape['batch'],shape['n'],shape['k']
for seed in [42,123,456]:
torch.manual_seed(seed)
x=torch.randn(b,n,device=device)
rv,_=reference.Model(b,n,k).to(device)(x)
sv,_=solution.Model(b,n,k).to(device)(x)
bad=((rv-sv).abs()>1e-4).sum().item()
if bad: print(f'FAIL {shape} seed {seed} bad={bad} max={(rv-sv).abs().max().item():.6f}')
print('done')
"
null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,40 +1,383 @@
-template <int K>
-__device__ void collect_finalize(const float* row_in, int n, uint32_t thresh_u,
- float* out_vals, int64_t* out_idxs, int row) {
- float acc_val[K];
- int64_t acc_idx[K];
+"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cfloat>
+
+#define WARP_SIZE 32
+
+__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ if (val <= vals[k - 1]) return;
+ int pos = k - 1;
+ while (pos > 0 && val > vals[pos - 1]) {
+ vals[pos] = vals[pos - 1];
+ idxs[pos] = idxs[pos - 1];
+ --pos;
+ }
+ vals[pos] = val;
+ idxs[pos] = idx;
+}
+
+template <int K>
+__device__ __forceinline__ void insert_desc(float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ insert_desc(K, val, idx, vals, idxs);
+}
+
+template <int K>
+__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
+ const float* b_val, const int64_t* b_idx,
+ float* out_val, int64_t* out_idx) {
+ int ai = 0, bi = 0, oi = 0;
+ while (oi < K) {
+ if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) {
+ out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
+ } else {
+ out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
+ }
+ ++oi;
+ }
+}
+
+template <int K>
+__host__ __device__ constexpr int threads_for_k() {
+ if (K <= 16) return 128;
+ return 64;
+}
+
+template <int K>
+__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
+ float* s_val, int64_t* s_idx, int nthreads) {
+ int tid = threadIdx.x;
#pragma unroll
- for (int j = 0; j < K; ++j) { acc_val[j] = -FLT_MAX; acc_idx[j] = 0; }
-
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- if (u >= thresh_u) insert_desc<K>(row_in[i], (int64_t)i, acc_val, acc_idx);
- }
-
- __shared__ float sb_v[K];
- __shared__ int64_t sb_i[K];
-
- for (int t = 0; t < blockDim.x; ++t) {
- if (threadIdx.x == t) {
- #pragma unroll
- for (int j = 0; j < K; ++j) { sb_v[j] = acc_val[j]; sb_i[j] = acc_idx[j]; }
- }
- __syncthreads();
- if (threadIdx.x == 0) {
+ for (int j = 0; j < K; ++j) {
+ s_val[tid * K + j] = local_val[j];
+ s_idx[tid * K + j] = local_idx[j];
+ }
+ __syncthreads();
+ for (int stride = nthreads / 2; stride >= 1; stride >>= 1) {
+ if (tid < stride) {
+ float tmp[K]; int64_t tidx[K];
+ merge_topk<K>(s_val + tid * K, s_idx + tid * K,
+ s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
+ tmp, tidx);
#pragma unroll
for (int j = 0; j < K; ++j) {
- if (sb_v[j] > -FLT_MAX) insert_desc<K>(sb_v[j], sb_i[j], acc_val, acc_idx);
+ s_val[tid * K + j] = tmp[j];
+ s_idx[tid * K + j] = tidx[j];
}
}
__syncthreads();
}
+}
+
+__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
+ float& out_val, int64_t& out_idx) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
+ float ov = __shfl_down_sync(0xffffffff, val, offset);
+ int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
+ if (ov > val) { val = ov; idx = oi; }
+ }
+ out_val = val; out_idx = idx;
+}
+
+__global__ void topk1_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+ float best = -FLT_MAX;
+ int64_t best_i = 0;
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
+ float4 v = *reinterpret_cast<const float4*>(row_in + i);
+ if (v.x > best) { best = v.x; best_i = i; }
+ if (v.y > best) { best = v.y; best_i = i + 1; }
+ if (v.z > best) { best = v.z; best_i = i + 2; }
+ if (v.w > best) { best = v.w; best_i = i + 3; }
+ }
+ for (; i < n; i += blockDim.x) {
+ float v = row_in[i];
+ if (v > best) { best = v; best_i = i; }
+ }
+
+ __shared__ float sv[32];
+ __shared__ int64_t si[32];
+ int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
+ float wv = best; int64_t wi = best_i;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
+ __syncthreads();
+ if (wid == 0) {
+ wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
+ wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
+ }
+}
+
+template <int K>
+__device__ void scan_row_topk(const float* row_in, int n,
+ float* local_val, int64_t* local_idx) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
+ float4 v = *reinterpret_cast<const float4*>(row_in + i);
+ insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
+ insert_desc<K>(v.y, (int64
... (9609 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<13982 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,383 +1,11 @@
-"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-
-CUDA_SRC = r"""
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
-#include <cstdint>
-#include <cfloat>
-
-#define WARP_SIZE 32
-
-__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
- float* vals, int64_t* idxs) {
- if (val <= vals[k - 1]) return;
- int pos = k - 1;
- while (pos > 0 && val > vals[pos - 1]) {
- vals[pos] = vals[pos - 1];
- idxs[pos] = idxs[pos - 1];
- --pos;
- }
- vals[pos] = val;
- idxs[pos] = idx;
-}
-
-template <int K>
-__device__ __forceinline__ void insert_desc(float val, int64_t idx,
- float* vals, int64_t* idxs) {
- insert_desc(K, val, idx, vals, idxs);
-}
-
-template <int K>
-__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
- const float* b_val, const int64_t* b_idx,
- float* out_val, int64_t* out_idx) {
- int ai = 0, bi = 0, oi = 0;
- while (oi < K) {
- if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) {
- out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
- } else {
- out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
- }
- ++oi;
- }
-}
-
-template <int K>
-__host__ __device__ constexpr int threads_for_k() {
- if (K <= 16) return 128;
- return 64;
-}
-
-template <int K>
-__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
- float* s_val, int64_t* s_idx, int nthreads) {
- int tid = threadIdx.x;
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = local_val[j];
- s_idx[tid * K + j] = local_idx[j];
- }
- __syncthreads();
- for (int stride = nthreads / 2; stride >= 1; stride >>= 1) {
- if (tid < stride) {
- float tmp[K]; int64_t tidx[K];
- merge_topk<K>(s_val + tid * K, s_idx + tid * K,
- s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
- tmp, tidx);
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = tmp[j];
- s_idx[tid * K + j] = tidx[j];
- }
- }
- __syncthreads();
- }
-}
-
-__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
- float& out_val, int64_t& out_idx) {
- for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
- float ov = __shfl_down_sync(0xffffffff, val, offset);
- int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
- if (ov > val) { val = ov; idx = oi; }
- }
- out_val = val; out_idx = idx;
-}
-
-__global__ void topk1_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
- float best = -FLT_MAX;
- int64_t best_i = 0;
-
- int i = threadIdx.x * 4;
- for (; i + 3 < n; i += blockDim.x * 4) {
- float4 v = *reinterpret_cast<const float4*>(row_in + i);
- if (v.x > best) { best = v.x; best_i = i; }
- if (v.y > best) { best = v.y; best_i = i + 1; }
- if (v.z > best) { best = v.z; best_i = i + 2; }
- if (v.w > best) { best = v.w; best_i = i + 3; }
- }
- for (; i < n; i += blockDim.x) {
- float v = row_in[i];
- if (v > best) { best = v; best_i = i; }
- }
-
- __shared__ float sv[32];
- __shared__ int64_t si[32];
- int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
- float wv = best; int64_t wi = best_i;
- warp_reduce_max(wv, wi, wv, wi);
- if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
- __syncthreads();
- if (wid == 0) {
- wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
- wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
- warp_reduce_max(wv, wi, wv, wi);
- if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
- }
-}
-
-template <int K>
-__device__ void scan_row_topk(const float* row_in, int n,
- float* local_val, int64_t* local_idx) {
- #pragma unroll
- for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
-
- int i = threadIdx.x * 4;
- for (; i + 3 < n; i += blockDim.x * 4) {
+ int i = start + threadIdx.x * 4;
+ for (; i + 3 < end; i += blockDim.x * 4) {
float4 v = *reinterpret_cast<const float4*>(row_in + i);
insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
}
- for (; i < n; i += blockDim.x) {
+ for (i = start + threadIdx.x; i < end; i += blockDim.x) {
insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
- }
-}
-
-template <int K>
-__global__ void topk_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
- float local_val[K]; int64_t local_idx[K];
- scan_row_topk<K>(row_in, n, local_val, local_idx);
-
- extern __shared__
... (8647 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<561 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -6,6 +6,6 @@
insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
}
- for (i = start + threadIdx.x; i < end; i += blockDim.x) {
+ for (; i < end; i += blockDim.x) {
insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<538 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_7a4bb4f5-b420-4e6f-8560-8ad03e16503",
"simpleCommands": [
"rm",
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
... (1934 more chars)null{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 1,
"limit": 50
}"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cfloat>
#define WARP_SIZE 32
__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
float* vals, int64_t* idxs) {
if (val <= vals[k - 1]) return;
int pos = k - 1;
while (pos > 0 && val > vals[pos - 1]) {
vals[pos] = vals[pos - 1];
idxs[pos] = idxs[pos - 1];
--pos;
}
vals[pos] = val;
idxs[pos] = idx;
}
template <int K>
__device__ __forceinline__ void insert_desc(float val, int64_t idx,
float* vals, int64_t* idxs) {
insert_desc(K, val, idx, vals, idxs);
}
template <int K>
__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
const float* b_val, const int64_t* b_idx,
float* out_val, int64_t* out_idx) {
int ai = 0, bi = 0, oi = 0;
while (oi < K) {
if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) {
out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
} else {
out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
}
++oi;
}
}
--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,11 +1,217 @@
- int i = start + threadIdx.x * 4;
- for (; i + 3 < end; i += blockDim.x * 4) {
+#define WARP_SIZE 32
+#define RADIX 256
+
+__device__ __forceinline__ uint32_t float_to_sortable(float f) {
+ uint32_t bits = __float_as_uint(f);
+ uint32_t mask = (-(int32_t)(bits >> 31)) | 0x80000000u;
+ return bits ^ mask;
+}
+
+__device__ __forceinline__ void bitonic_sort_desc(int k, float* vals, int64_t* idxs) {
+ for (int k2 = 2; k2 <= k; k2 <<= 1) {
+ for (int j = k2 >> 1; j > 0; j >>= 1) {
+ for (int i = 0; i < k; ++i) {
+ int ixj = i ^ j;
+ if (ixj > i) {
+ bool asc = ((i & k2) == 0);
+ bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
+ if (swap_it) {
+ float tv = vals[i]; int64_t ti = idxs[i];
+ vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
+ vals[ixj] = tv; idxs[ixj] = ti;
+ }
+ }
+ }
+ }
+ }
+}
+
+template <int K>
+__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
+ bitonic_sort_desc(K, vals, idxs);
+}
+
+__device__ __forceinline__ void radix_pass(const float* row_in, int n,
+ int* hist, uint32_t prefix, int shift) {
+ for (int b = threadIdx.x; b < RADIX; b += blockDim.x) hist[b] = 0;
+ __syncthreads();
+ uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ uint32_t u = float_to_sortable(row_in[i]);
+ if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
+ int bucket = (u >> shift) & 0xFF;
+ atomicAdd(&hist[bucket], 1);
+ }
+ __syncthreads();
+}
+
+__device__ __forceinline__ void radix_select_from_hist(int* hist, int& need, uint32_t& prefix, int shift) {
+ int acc = 0;
+ for (int b = 255; b >= 0; --b) {
+ acc += hist[b];
+ if (acc >= need) {
+ prefix |= (uint32_t)b << shift;
+ need = need - (acc - hist[b]);
+ break;
+ }
+ }
+}
+
+template <int K>
+__device__ void collect_ge_thresh(const float* row_in, int n, uint32_t thresh_u,
+ float* out_vals, int64_t* out_idxs, int row) {
+ float acc_val[K]; int64_t acc_idx[K];
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { acc_val[j] = -FLT_MAX; acc_idx[j] = 0; }
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
float4 v = *reinterpret_cast<const float4*>(row_in + i);
- insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
- insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
- insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
- insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
- }
- for (; i < end; i += blockDim.x) {
- insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
- }+ uint32_t u0 = float_to_sortable(v.x);
+ uint32_t u1 = float_to_sortable(v.y);
+ uint32_t u2 = float_to_sortable(v.z);
+ uint32_t u3 = float_to_sortable(v.w);
+ if (u0 >= thresh_u) insert_desc<K>(v.x, (int64_t)i, acc_val, acc_idx);
+ if (u1 >= thresh_u) insert_desc<K>(v.y, (int64_t)(i + 1), acc_val, acc_idx);
+ if (u2 >= thresh_u) insert_desc<K>(v.z, (int64_t)(i + 2), acc_val, acc_idx);
+ if (u3 >= thresh_u) insert_desc<K>(v.w, (int64_t)(i + 3), acc_val, acc_idx);
+ }
+ for (; i < n; i += blockDim.x) {
+ if (float_to_sortable(row_in[i]) >= thresh_u)
+ insert_desc<K>(row_in[i], (int64_t)i, acc_val, acc_idx);
+ }
+
+ __shared__ float sb_v[K];
+ __shared__ int64_t sb_i[K];
+ for (int t = 0; t < blockDim.x; ++t) {
+ if (threadIdx.x == t) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { sb_v[j] = acc_val[j]; sb_i[j] = acc_idx[j]; }
+ }
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ if (sb_v[j] > -FLT_MAX) insert_desc<K>(sb_v[j], sb_i[j], acc_val, acc_idx);
+ }
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x == 0) {
+ bitonic_sort_desc<K>(acc_val, acc_idx);
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ out_vals[(int64_t)row * K + j] = acc_val[j];
+ out_idxs[(int64_t)row * K + j] = acc_idx[j];
+ }
+ }
+}
+
+template <int K>
+__global__ void radix_topk_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+ __shared__ int hist[RADIX];
+ __shared__ uint32_t prefix_u;
+ __shared__ int need_k;
+
+ uint32_t prefix = 0;
+ int need = K;
+ for (int shift = 24; shift >= 0; shift -= 8) {
+ radix_pass(row_in, n, hist, prefix, shift);
+ if (threadIdx.x == 0) {
+ int nneed = need;
+ uint32_t pp = prefix;
+ radix_select_from_hist(hist, nneed, pp, shift);
+ prefix_u = pp;
+ need_k = nneed;
+ }
+ __syncthreads();
+ prefix = prefix_u;
+ need = need_k;
+ }
+ collect_ge_thresh<K>(row_in, n, prefix, out_vals, out_idxs, row);
+}
+
+__global__ void radix_hist_slice_kernel(const float* __restrict__ in,
+ int* __restrict__ global_hist,
+ int row, int n, int num_blocks,
+ uint32_t prefix, int shift) {
+ int bid = blockIdx
... (3010 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<8183 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,217 +1,4 @@
-#define WARP_SIZE 32
-#define RADIX 256
-
-__device__ __forceinline__ uint32_t float_to_sortable(float f) {
- uint32_t bits = __float_as_uint(f);
- uint32_t mask = (-(int32_t)(bits >> 31)) | 0x80000000u;
- return bits ^ mask;
-}
-
-__device__ __forceinline__ void bitonic_sort_desc(int k, float* vals, int64_t* idxs) {
- for (int k2 = 2; k2 <= k; k2 <<= 1) {
- for (int j = k2 >> 1; j > 0; j >>= 1) {
- for (int i = 0; i < k; ++i) {
- int ixj = i ^ j;
- if (ixj > i) {
- bool asc = ((i & k2) == 0);
- bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
- if (swap_it) {
- float tv = vals[i]; int64_t ti = idxs[i];
- vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
- vals[ixj] = tv; idxs[ixj] = ti;
- }
- }
- }
- }
- }
-}
-
-template <int K>
-__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
- bitonic_sort_desc(K, vals, idxs);
-}
-
-__device__ __forceinline__ void radix_pass(const float* row_in, int n,
- int* hist, uint32_t prefix, int shift) {
- for (int b = threadIdx.x; b < RADIX; b += blockDim.x) hist[b] = 0;
- __syncthreads();
- uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- uint32_t u = float_to_sortable(row_in[i]);
- if (shift < 24 && (u & pmask) != (prefix & pmask)) continue;
- int bucket = (u >> shift) & 0xFF;
- atomicAdd(&hist[bucket], 1);
- }
- __syncthreads();
-}
-
-__device__ __forceinline__ void radix_select_from_hist(int* hist, int& need, uint32_t& prefix, int shift) {
- int acc = 0;
- for (int b = 255; b >= 0; --b) {
- acc += hist[b];
- if (acc >= need) {
- prefix |= (uint32_t)b << shift;
- need = need - (acc - hist[b]);
- break;
- }
- }
-}
-
-template <int K>
-__device__ void collect_ge_thresh(const float* row_in, int n, uint32_t thresh_u,
- float* out_vals, int64_t* out_idxs, int row) {
- float acc_val[K]; int64_t acc_idx[K];
- #pragma unroll
- for (int j = 0; j < K; ++j) { acc_val[j] = -FLT_MAX; acc_idx[j] = 0; }
-
- int i = threadIdx.x * 4;
- for (; i + 3 < n; i += blockDim.x * 4) {
- float4 v = *reinterpret_cast<const float4*>(row_in + i);
- uint32_t u0 = float_to_sortable(v.x);
- uint32_t u1 = float_to_sortable(v.y);
- uint32_t u2 = float_to_sortable(v.z);
- uint32_t u3 = float_to_sortable(v.w);
- if (u0 >= thresh_u) insert_desc<K>(v.x, (int64_t)i, acc_val, acc_idx);
- if (u1 >= thresh_u) insert_desc<K>(v.y, (int64_t)(i + 1), acc_val, acc_idx);
- if (u2 >= thresh_u) insert_desc<K>(v.z, (int64_t)(i + 2), acc_val, acc_idx);
- if (u3 >= thresh_u) insert_desc<K>(v.w, (int64_t)(i + 3), acc_val, acc_idx);
- }
- for (; i < n; i += blockDim.x) {
- if (float_to_sortable(row_in[i]) >= thresh_u)
- insert_desc<K>(row_in[i], (int64_t)i, acc_val, acc_idx);
- }
-
- __shared__ float sb_v[K];
- __shared__ int64_t sb_i[K];
- for (int t = 0; t < blockDim.x; ++t) {
- if (threadIdx.x == t) {
- #pragma unroll
- for (int j = 0; j < K; ++j) { sb_v[j] = acc_val[j]; sb_i[j] = acc_idx[j]; }
- }
- __syncthreads();
- if (threadIdx.x == 0) {
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- if (sb_v[j] > -FLT_MAX) insert_desc<K>(sb_v[j], sb_i[j], acc_val, acc_idx);
- }
- }
- __syncthreads();
- }
-
- if (threadIdx.x == 0) {
- bitonic_sort_desc<K>(acc_val, acc_idx);
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- out_vals[(int64_t)row * K + j] = acc_val[j];
- out_idxs[(int64_t)row * K + j] = acc_idx[j];
- }
- }
-}
-
-template <int K>
-__global__ void radix_topk_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
- __shared__ int hist[RADIX];
- __shared__ uint32_t prefix_u;
- __shared__ int need_k;
-
- uint32_t prefix = 0;
- int need = K;
- for (int shift = 24; shift >= 0; shift -= 8) {
- radix_pass(row_in, n, hist, prefix, shift);
- if (threadIdx.x == 0) {
- int nneed = need;
- uint32_t pp = prefix;
- radix_select_from_hist(hist, nneed, pp, shift);
- prefix_u = pp;
- need_k = nneed;
- }
- __syncthreads();
- prefix = prefix_u;
- need = need_k;
- }
- collect_ge_thresh<K>(row_in, n, prefix, out_vals, out_idxs, row);
-}
-
-__global__ void radix_hist_slice_kernel(const float* __restrict__ in,
- int* __restrict__ global_hist,
- int row, int n, int num_blocks,
- uint32_t prefix, int shift) {
- int bid = blockIdx.x;
- if (bid >= num_blocks) return;
- int slice = (n + num_blocks - 1) / num_blocks;
- int start = bid * slice;
- int end = min(start + slice, n);
- const float* row_in = in + (int64_t)row * n;
- uint32_t pmask = (shift == 24) ? 0u : (~0u << (shift + 8));
-
- __shared__ int local_hist[RADIX];
- for (int b = threadIdx.x; b < RADIX; b += blockDim.x) local_hist[b] = 0;
- __syncthreads();
-
- for (int i = start + threadIdx.x; i < end; i += blockDim.x) {
... (2601 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<179 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,4 +1,15 @@
- radix_collect_large_kernel<K><<<1, threads, 0, stream>>>(in, out_vals, out_idxs, 0, n, h_prefix);
-}
-
-__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,+template <int K>
+void launch_topk(const float* in, float* out_vals, int64_t* out_idxs,
+ int batch, int n, cudaStream_t stream,
+ float* partial_vals, int64_t* partial_idxs, int partial_cap,
+ int* hist_buf, uint32_t* prefix_dev, int* need_dev) {
+ constexpr int T = 128;
+ if (batch == 1 && n >= 65536 && hist_buf != nullptr) {
+ cudaMemsetAsync(prefix_dev, 0, sizeof(uint32_t), stream);
+ int init_k = K;
+ cudaMemcpyAsync(need_dev, &init_k, sizeof(int), cudaMemcpyHostToDevice, stream);
+ launch_large_row_radix<K>(in, out_vals, out_idxs, n, hist_buf, prefix_dev, need_dev, stream);
+ } else {
+ radix_topk_kernel<K><<<batch, T, 0, stream>>>(in, out_vals, out_idxs, batch, n);
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<768 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,15 +1,59 @@
-template <int K>
-void launch_topk(const float* in, float* out_vals, int64_t* out_idxs,
- int batch, int n, cudaStream_t stream,
- float* partial_vals, int64_t* partial_idxs, int partial_cap,
- int* hist_buf, uint32_t* prefix_dev, int* need_dev) {
- constexpr int T = 128;
- if (batch == 1 && n >= 65536 && hist_buf != nullptr) {
- cudaMemsetAsync(prefix_dev, 0, sizeof(uint32_t), stream);
- int init_k = K;
- cudaMemcpyAsync(need_dev, &init_k, sizeof(int), cudaMemcpyHostToDevice, stream);
- launch_large_row_radix<K>(in, out_vals, out_idxs, n, hist_buf, prefix_dev, need_dev, stream);
+void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals,
+ torch::Tensor& out_idxs, int k,
+ torch::Tensor& partial_vals, torch::Tensor& partial_idxs,
+ torch::Tensor& hist_buf, torch::Tensor& prefix_dev,
+ torch::Tensor& need_dev) {
+ const float* in_ptr = input.data_ptr<float>();
+ float* val_ptr = out_vals.data_ptr<float>();
+ int64_t* idx_ptr = out_idxs.data_ptr<int64_t>();
+ int batch = (int)input.size(0);
+ int n = (int)input.size(1);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ float* pval = partial_vals.defined() && partial_vals.numel() > 0
+ ? partial_vals.data_ptr<float>() : nullptr;
+ int64_t* pidx = partial_idxs.defined() && partial_idxs.numel() > 0
+ ? partial_idxs.data_ptr<int64_t>() : nullptr;
+ int partial_cap = partial_vals.defined() ? (int)partial_vals.size(0) : 0;
+ int* hist_ptr = hist_buf.defined() && hist_buf.numel() > 0
+ ? hist_buf.data_ptr<int>() : nullptr;
+ uint32_t* prefix_ptr = prefix_dev.defined() && prefix_dev.numel() > 0
+ ? prefix_dev.data_ptr<uint32_t>() : nullptr;
+ int* need_ptr = need_dev.defined() && need_dev.numel() > 0
+ ? need_dev.data_ptr<int>() : nullptr;
+
+ if (k == 1) {
+ topk1_kernel<<<batch, 256, 0, stream>>>(in_ptr, val_ptr, idx_ptr, batch, n);
+ } else if (k == 8) {
+ launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
+ hist_ptr, prefix_ptr, need_ptr);
+ } else if (k == 16) {
+ launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
+ hist_ptr, prefix_ptr, need_ptr);
+ } else if (k == 32) {
+ launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
+ hist_ptr, prefix_ptr, need_ptr);
+ } else if (k == 64) {
+ launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
+ hist_ptr, prefix_ptr, need_ptr);
} else {
- radix_topk_kernel<K><<<batch, T, 0, stream>>>(in, out_vals, out_idxs, batch, n);
+ TORCH_CHECK(false, "unsupported k=", k);
}
+}
+
+std::vector<torch::Tensor> topk_cuda(torch::Tensor input, int k,
+ torch::Tensor partial_vals,
+ torch::Tensor partial_idxs,
+ torch::Tensor hist_buf,
+ torch::Tensor prefix_dev,
+ torch::Tensor need_dev) {
+ TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32);
+ TORCH_CHECK(input.dim() == 2);
+ int batch = (int)input.size(0);
+ auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
+ auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device());
+ auto values = torch::empty({batch, k}, opts_f);
+ auto indices = torch::empty({batch, k}, opts_i);
+ dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs,
+ hist_buf, prefix_dev, need_dev);
+ return {values, indices};
}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<3123 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,59 +1,13 @@
-void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals,
- torch::Tensor& out_idxs, int k,
- torch::Tensor& partial_vals, torch::Tensor& partial_idxs,
- torch::Tensor& hist_buf, torch::Tensor& prefix_dev,
- torch::Tensor& need_dev) {
- const float* in_ptr = input.data_ptr<float>();
- float* val_ptr = out_vals.data_ptr<float>();
- int64_t* idx_ptr = out_idxs.data_ptr<int64_t>();
- int batch = (int)input.size(0);
- int n = (int)input.size(1);
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- float* pval = partial_vals.defined() && partial_vals.numel() > 0
- ? partial_vals.data_ptr<float>() : nullptr;
- int64_t* pidx = partial_idxs.defined() && partial_idxs.numel() > 0
- ? partial_idxs.data_ptr<int64_t>() : nullptr;
- int partial_cap = partial_vals.defined() ? (int)partial_vals.size(0) : 0;
- int* hist_ptr = hist_buf.defined() && hist_buf.numel() > 0
- ? hist_buf.data_ptr<int>() : nullptr;
- uint32_t* prefix_ptr = prefix_dev.defined() && prefix_dev.numel() > 0
- ? prefix_dev.data_ptr<uint32_t>() : nullptr;
- int* need_ptr = need_dev.defined() && need_dev.numel() > 0
- ? need_dev.data_ptr<int>() : nullptr;
+_hist_cache: dict[str, torch.Tensor] = {}
+_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
- if (k == 1) {
- topk1_kernel<<<batch, 256, 0, stream>>>(in_ptr, val_ptr, idx_ptr, batch, n);
- } else if (k == 8) {
- launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
- hist_ptr, prefix_ptr, need_ptr);
- } else if (k == 16) {
- launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
- hist_ptr, prefix_ptr, need_ptr);
- } else if (k == 32) {
- launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
- hist_ptr, prefix_ptr, need_ptr);
- } else if (k == 64) {
- launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap,
- hist_ptr, prefix_ptr, need_ptr);
- } else {
- TORCH_CHECK(false, "unsupported k=", k);
- }
-}
-std::vector<torch::Tensor> topk_cuda(torch::Tensor input, int k,
- torch::Tensor partial_vals,
- torch::Tensor partial_idxs,
- torch::Tensor hist_buf,
- torch::Tensor prefix_dev,
- torch::Tensor need_dev) {
- TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32);
- TORCH_CHECK(input.dim() == 2);
- int batch = (int)input.size(0);
- auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
- auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device());
- auto values = torch::empty({batch, k}, opts_f);
- auto indices = torch::empty({batch, k}, opts_i);
- dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs,
- hist_buf, prefix_dev, need_dev);
- return {values, indices};
-}+def _get_radix_workspace(device: torch.device):
+ key = str(device)
+ if key not in _radix_state:
+ _radix_state[key] = (
+ torch.zeros(256, dtype=torch.int32, device=device),
+ torch.zeros(1, dtype=torch.int32, device=device),
+ torch.zeros(1, dtype=torch.int32, device=device),
+ )
+ return _radix_state[key]{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<480 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,13 +1,11 @@
-_hist_cache: dict[str, torch.Tensor] = {}
-_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
+ self._use_large_radix = batch == 1 and n >= 65536 and k > 1
-
-def _get_radix_workspace(device: torch.device):
- key = str(device)
- if key not in _radix_state:
- _radix_state[key] = (
- torch.zeros(256, dtype=torch.int32, device=device),
- torch.zeros(1, dtype=torch.int32, device=device),
- torch.zeros(1, dtype=torch.int32, device=device),
- )
- return _radix_state[key]+ def forward(self, x: torch.Tensor):
+ ext = _get_ext()
+ pv = torch.empty(0, dtype=torch.float32, device=x.device)
+ pi = torch.empty(0, dtype=torch.int64, device=x.device)
+ if self._use_large_radix:
+ hist, prefix, need = _get_radix_workspace(x.device)
+ else:
+ hist = prefix = need = torch.empty(0, dtype=torch.int32, device=x.device)
+ return ext.topk_cuda(x.contiguous(), self.k, pv, pi, hist, prefix, need){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<542 chars \u2014 see diff>"
}null{
"pattern": "_get_partial|_partial_cache|_hist_cache",
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_6fd3b3a3-9bec-435b-a354-db5028ec69c",
"offset": 0
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,11 +1 @@
- self._use_large_radix = batch == 1 and n >= 65536 and k > 1
-
- def forward(self, x: torch.Tensor):
- ext = _get_ext()
- pv = torch.empty(0, dtype=torch.float32, device=x.device)
- pi = torch.empty(0, dtype=torch.int64, device=x.device)
- if self._use_large_radix:
- hist, prefix, need = _get_radix_workspace(x.device)
- else:
- hist = prefix = need = torch.empty(0, dtype=torch.int32, device=x.device)
- return ext.topk_cuda(x.contiguous(), self.k, pv, pi, hist, prefix, need)+_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1 +1,4 @@
-_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}+ return _radix_state[key]
+
+
+class Model(nn.Module):{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<54 chars \u2014 see diff>"
}path: ''
error: 'The string to replace was not found in the file.'
modelVisibleError:
The string to replace was not found in the file (even after relaxing whitespace). Found a possible fuzzy match, did you mean:
580|
def _get_partial_workspace(k: int, cap: int, device: torch.device):
key = (k, cap, str(device))
if key not in _partial_cache:
_partial_cache[key] = (
torch.empty(cap, k, dtype=torch.float32, device=device),
torch.empty(cap, k, dtype=torch.int64, device=device),
)
return _partial_cache[key]
590|
class Model(nn.Module):
"""Top-k over the last dim of a 2D tensor."""{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 565
}def _get_ext():
global _topk_ext
if _topk_ext is None:
_topk_ext = load_inline(
name="topk_cuda_ext",
cpp_sources="",
cuda_sources=CUDA_SRC,
functions=None,
extra_cuda_cflags=CUDA_FLAGS,
verbose=False,
)
return _topk_ext
def _get_partial_workspace(k: int, cap: int, device: torch.device):
key = (k, cap, str(device))
if key not in _partial_cache:
_partial_cache[key] = (
torch.empty(cap, k, dtype=torch.float32, device=device),
torch.empty(cap, k, dtype=torch.int64, device=device),
)
return _partial_cache[key]
class Model(nn.Module):
"""Top-k over the last dim of a 2D tensor."""
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch, self.n, self.k = batch, n, k
self.register_buffer("_dummy", torch.zeros(1))
self._use_large_radix = batch == 1 and n >= 65536 and k > 1
def forward(self, x: torch.Tensor):
ext = _get_ext()
pv = torch.empty(0, dtype=torch.float32, device=x.device)
pi = torch.empty(0, dtype=torch.int64, device=x.device)
if self._use_large_radix:
hist, prefix, need = _get_radix_workspace(x.device)
else:
hist = prefix = need = torch.empty(0, dtype=torch.int32, device=x.device)
return ext.topk_cuda(x.contiguous(), self.k, pv, pi, hist, prefix, need)
batch = 64
n = 8192
k = 8
d
... (133 more chars){
"pattern": "_get_radix_workspace",
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_d2ac4638-f361-496c-9e2d-6f830498aa7",
"offset": 0
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,4 +1,2 @@
- return _radix_state[key]
-
class Model(nn.Module):{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<24 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,2 @@
-
-class Model(nn.Module):+ uint32_t* prefix_ptr = prefix_dev.defined() && prefix_dev.numel() > 0
+ ? reinterpret_cast<uint32_t*>(prefix_dev.data_ptr<int32_t>()) : nullptr;{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<154 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_8f9b0fd2-c62d-43e6-8d6e-656246ad92a",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1375 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/05_topk_bitonic/check.py", line 157, in <module>
main()
File "/workspace/problems/05_topk_bitonic/check.py", line 82, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/05_topk_bitonic
... (13451 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,2 +1,22 @@
- uint32_t* prefix_ptr = prefix_dev.defined() && prefix_dev.numel() > 0
- ? reinterpret_cast<uint32_t*>(prefix_dev.data_ptr<int32_t>()) : nullptr;+#define RADIX 256
+
+__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ if (val <= vals[k - 1]) return;
+ int pos = k - 1;
+ while (pos > 0 && val > vals[pos - 1]) {
+ vals[pos] = vals[pos - 1];
+ idxs[pos] = idxs[pos - 1];
+ --pos;
+ }
+ vals[pos] = val;
+ idxs[pos] = idx;
+}
+
+template <int K>
+__device__ __forceinline__ void insert_desc(float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ insert_desc(K, val, idx, vals, idxs);
+}
+
+__device__ __forceinline__ uint32_t float_to_sortable(float f) {{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<674 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,22 +1,4 @@
-#define RADIX 256
-
-__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
- float* vals, int64_t* idxs) {
- if (val <= vals[k - 1]) return;
- int pos = k - 1;
- while (pos > 0 && val > vals[pos - 1]) {
- vals[pos] = vals[pos - 1];
- idxs[pos] = idxs[pos - 1];
- --pos;
- }
- vals[pos] = val;
- idxs[pos] = idx;
}
template <int K>
-__device__ __forceinline__ void insert_desc(float val, int64_t idx,
- float* vals, int64_t* idxs) {
- insert_desc(K, val, idx, vals, idxs);
-}
-
-__device__ __forceinline__ uint32_t float_to_sortable(float f) {+__device__ void merge_topk(const float* a_val, const int64_t* a_idx,{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<88 chars \u2014 see diff>"
}null{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"offset": 1,
"limit": 30
}"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cfloat>
#define WARP_SIZE 32
#define RADIX 256
__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
float* vals, int64_t* idxs) {
if (val <= vals[k - 1]) return;
int pos = k - 1;
while (pos > 0 && val > vals[pos - 1]) {
vals[pos] = vals[pos - 1];
idxs[pos] = idxs[pos - 1];
--pos;
}
vals[pos] = val;
idxs[pos] = idx;
}
{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_df55bde8-4f08-4024-9a49-d1a78923082",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1357 more chars)null{
"taskId": "589665",
"blockUntilMs": 120000,
"regex": "PASS|FAIL"
}null{
"path": "/home/agent/.cursor/projects/workspace-problems-05-topk-bitonic/terminals/589665.txt"
}---
pid: 3291
cwd: "/workspace/problems/05_topk_bitonic"
command: "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1"
started_at: 2026-06-15T18:22:53.910Z
running_for_ms: 135001
---
FAIL: shape 0 {'batch': 1, 'n': 131072, 'k': 64} seed 42 case nominal values: tolerance exceeded: max_abs_diff=0.0318 max_rel_diff=0.00951986 bad=37/64 worst_idx=(0, 40) (atol=0.0001, rtol=0.0001)
---
exit_code: 1
elapsed_ms: 136539
ended_at: 2026-06-15T18:25:10.449Z
---
--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,4 +1,18 @@
-}
+template <int K>
+void launch_topk(const float* in, float* out_vals, int64_t* out_idxs,
+ int batch, int n, cudaStream_t stream,
+ float* partial_vals, int64_t* partial_idxs, int partial_cap,
+ int* hist_buf, uint32_t* prefix_dev, int* need_dev) {
+ constexpr int T = threads_for_k<K>();
+ size_t smem = (size_t)T * K * (sizeof(float) + sizeof(int64_t));
-template <int K>
-__device__ void merge_topk(const float* a_val, const int64_t* a_idx,+ if (batch == 1 && n >= 65536 && partial_vals != nullptr) {
+ int num_blocks = min(partial_cap, (n + T - 1) / T);
+ topk_slice_kernel<K><<<num_blocks, T, smem, stream>>>(
+ in, partial_vals, partial_idxs, 0, n, num_blocks);
+ topk_merge_kernel<K><<<1, T, smem, stream>>>(
+ partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks);
+ } else {
+ radix_topk_kernel<K><<<batch, 128, 0, stream>>>(in, out_vals, out_idxs, batch, n);
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<894 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,18 +1,12 @@
-template <int K>
-void launch_topk(const float* in, float* out_vals, int64_t* out_idxs,
- int batch, int n, cudaStream_t stream,
- float* partial_vals, int64_t* partial_idxs, int partial_cap,
- int* hist_buf, uint32_t* prefix_dev, int* need_dev) {
- constexpr int T = threads_for_k<K>();
- size_t smem = (size_t)T * K * (sizeof(float) + sizeof(int64_t));
+_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
+_partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
- if (batch == 1 && n >= 65536 && partial_vals != nullptr) {
- int num_blocks = min(partial_cap, (n + T - 1) / T);
- topk_slice_kernel<K><<<num_blocks, T, smem, stream>>>(
- in, partial_vals, partial_idxs, 0, n, num_blocks);
- topk_merge_kernel<K><<<1, T, smem, stream>>>(
- partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks);
- } else {
- radix_topk_kernel<K><<<batch, 128, 0, stream>>>(in, out_vals, out_idxs, batch, n);
- }
-}+
+def _get_partial_workspace(k: int, cap: int, device: torch.device):
+ key = (k, cap, str(device))
+ if key not in _partial_cache:
+ _partial_cache[key] = (
+ torch.empty(cap, k, dtype=torch.float32, device=device),
+ torch.empty(cap, k, dtype=torch.int64, device=device),
+ )
+ return _partial_cache[key]{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<505 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,12 +1,12 @@
-_radix_state: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
-_partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
+ self._use_slice = batch == 1 and n >= 65536 and k > 1
+ self._partial_cap = min(1024, (n + 63) // 64) if self._use_slice else 0
-
-def _get_partial_workspace(k: int, cap: int, device: torch.device):
- key = (k, cap, str(device))
- if key not in _partial_cache:
- _partial_cache[key] = (
- torch.empty(cap, k, dtype=torch.float32, device=device),
- torch.empty(cap, k, dtype=torch.int64, device=device),
- )
- return _partial_cache[key]+ def forward(self, x: torch.Tensor):
+ ext = _get_ext()
+ empty_i = torch.empty(0, dtype=torch.int32, device=x.device)
+ if self._use_slice:
+ pv, pi = _get_partial_workspace(self.k, self._partial_cap, x.device)
+ else:
+ pv = torch.empty(0, dtype=torch.float32, device=x.device)
+ pi = torch.empty(0, dtype=torch.int64, device=x.device)
+ return ext.topk_cuda(x.contiguous(), self.k, pv, pi, empty_i, empty_i, empty_i){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<625 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_3c1ad4f9-e919-4769-a0cc-732ca82fb1c",
"simpleCommands": [
"rm",
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
"value": "-rf"
},
{
... (1372 more chars)failure:
command: 'rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 1 {'batch': 64, 'n': 8192, 'k': 8} seed 42 case nominal values: tolerance exceeded: max_abs_diff=0.319239 max_rel_diff=0.104575 bad=13/512 worst_idx=(49, 3) (atol=0.0001, rtol=0.0001)
stderr: ''
executionTime: 50203
interleavedOutput:
FAIL: shape 1 {'batch': 64, 'n': 8192, 'k': 8} seed 42 case nominal values: tolerance exceeded: max_abs_diff=0.319239 max_rel_diff=0.104575 bad=13/512 worst_idx=(49, 3) (atol=0.0001, rtol=0.0001)
aborted: False
localExecutionTimeMs: 50126
isBackground: False--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,12 +1,382 @@
+"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+CUDA_SRC = r"""
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cfloat>
+
+#define WARP_SIZE 32
+
+__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ if (val <= vals[k - 1]) return;
+ int pos = k - 1;
+ while (pos > 0 && val > vals[pos - 1]) {
+ vals[pos] = vals[pos - 1];
+ idxs[pos] = idxs[pos - 1];
+ --pos;
+ }
+ vals[pos] = val;
+ idxs[pos] = idx;
+}
+
+template <int K>
+__device__ __forceinline__ void insert_desc(float val, int64_t idx,
+ float* vals, int64_t* idxs) {
+ insert_desc(K, val, idx, vals, idxs);
+}
+
+template <int K>
+__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
+ const float* b_val, const int64_t* b_idx,
+ float* out_val, int64_t* out_idx) {
+ int ai = 0, bi = 0, oi = 0;
+ while (oi < K) {
+ if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) {
+ out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
+ } else {
+ out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
+ }
+ ++oi;
+ }
+}
+
+template <int K>
+__host__ __device__ constexpr int threads_for_k() {
+ if (K <= 16) return 128;
+ return 64;
+}
+
+template <int K>
+__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
+ float* s_val, int64_t* s_idx, int nthreads) {
+ int tid = threadIdx.x;
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ s_val[tid * K + j] = local_val[j];
+ s_idx[tid * K + j] = local_idx[j];
+ }
+ __syncthreads();
+ for (int stride = nthreads / 2; stride >= 1; stride >>= 1) {
+ if (tid < stride) {
+ float tmp[K]; int64_t tidx[K];
+ merge_topk<K>(s_val + tid * K, s_idx + tid * K,
+ s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
+ tmp, tidx);
+ #pragma unroll
+ for (int j = 0; j < K; ++j) {
+ s_val[tid * K + j] = tmp[j];
+ s_idx[tid * K + j] = tidx[j];
+ }
+ }
+ __syncthreads();
+ }
+}
+
+__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
+ float& out_val, int64_t& out_idx) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
+ float ov = __shfl_down_sync(0xffffffff, val, offset);
+ int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
+ if (ov > val) { val = ov; idx = oi; }
+ }
+ out_val = val; out_idx = idx;
+}
+
+__global__ void topk1_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+ float best = -FLT_MAX;
+ int64_t best_i = 0;
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
+ float4 v = *reinterpret_cast<const float4*>(row_in + i);
+ if (v.x > best) { best = v.x; best_i = i; }
+ if (v.y > best) { best = v.y; best_i = i + 1; }
+ if (v.z > best) { best = v.z; best_i = i + 2; }
+ if (v.w > best) { best = v.w; best_i = i + 3; }
+ }
+ for (; i < n; i += blockDim.x) {
+ float v = row_in[i];
+ if (v > best) { best = v; best_i = i; }
+ }
+
+ __shared__ float sv[32];
+ __shared__ int64_t si[32];
+ int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
+ float wv = best; int64_t wi = best_i;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
+ __syncthreads();
+ if (wid == 0) {
+ wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX;
+ wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0;
+ warp_reduce_max(wv, wi, wv, wi);
+ if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; }
+ }
+}
+
+template <int K>
+__device__ void scan_row_topk(const float* row_in, int n,
+ float* local_val, int64_t* local_idx) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
+ float4 v = *reinterpret_cast<const float4*>(row_in + i);
+ insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
+ insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
+ insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
+ insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
+ }
+ for (; i < n; i += blockDim.x) {
+ insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
+ }
+}
+
+template <int K>
+__global__ void topk_kernel(const float* __restrict__ in,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idxs,
+ int batch, int n) {
+ int row = blockIdx.x;
+ if (row >= batch) return;
+ const float* row_in = in + (int64_t)row * n;
+ float local_val[K]; int64_t local_idx[K];
+ scan_row_topk<K>(row_in, n, local_val, local_idx);
+
+ extern __shared__ char smem[];
+ float* s_val = reinterpret_cast<float*>(smem);
+ int64_t* s_idx = reinterpret_cast<int64_t*>(s_val + blockDim.x * K);
+ bloc
... (8645 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<13897 chars \u2014 see diff>"
}null{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_e349d034-2a60-4454-b236-80f75e7b1fc",
"simpleCommands": [
"rm",
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
... (1936 more chars)null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,136 +1,62 @@
-"""Custom CUDA top-k via streaming selection + block merge (SM120)."""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-
-CUDA_SRC = r"""
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
-#include <cstdint>
-#include <cfloat>
-
-#define WARP_SIZE 32
-
-__device__ __forceinline__ void insert_desc(int k, float val, int64_t idx,
- float* vals, int64_t* idxs) {
- if (val <= vals[k - 1]) return;
- int pos = k - 1;
- while (pos > 0 && val > vals[pos - 1]) {
- vals[pos] = vals[pos - 1];
- idxs[pos] = idxs[pos - 1];
- --pos;
- }
- vals[pos] = val;
- idxs[pos] = idx;
-}
-
template <int K>
-__device__ __forceinline__ void insert_desc(float val, int64_t idx,
- float* vals, int64_t* idxs) {
- insert_desc(K, val, idx, vals, idxs);
-}
-
-template <int K>
-__device__ void merge_topk(const float* a_val, const int64_t* a_idx,
- const float* b_val, const int64_t* b_idx,
- float* out_val, int64_t* out_idx) {
- int ai = 0, bi = 0, oi = 0;
- while (oi < K) {
- if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) {
- out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai;
- } else {
- out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi;
+__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
+ for (int k2 = 2; k2 <= K; k2 <<= 1) {
+ for (int j = k2 >> 1; j > 0; j >>= 1) {
+ for (int i = 0; i < K; ++i) {
+ int ixj = i ^ j;
+ if (ixj > i) {
+ bool asc = ((i & k2) == 0);
+ bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
+ if (swap_it) {
+ float tv = vals[i]; int64_t ti = idxs[i];
+ vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
+ vals[ixj] = tv; idxs[ixj] = ti;
+ }
+ }
+ }
}
- ++oi;
}
}
-template <int K>
-__host__ __device__ constexpr int threads_for_k() {
- if (K <= 16) return 128;
- return 64;
-}
+template <int K, int TILE>
+__device__ void scan_row_tiled(const float* row_in, int n,
+ float* local_val, int64_t* local_idx) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
-template <int K>
-__device__ void block_reduce_topk(float* local_val, int64_t* local_idx,
- float* s_val, int64_t* s_idx, int nthreads) {
- int tid = threadIdx.x;
- #pragma unroll
- for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = local_val[j];
- s_idx[tid * K + j] = local_idx[j];
- }
- __syncthreads();
- for (int stride = nthreads / 2; stride >= 1; stride >>= 1) {
- if (tid < stride) {
- float tmp[K]; int64_t tidx[K];
- merge_topk<K>(s_val + tid * K, s_idx + tid * K,
- s_val + (tid + stride) * K, s_idx + (tid + stride) * K,
- tmp, tidx);
- #pragma unroll
+ __shared__ float tile_v[TILE];
+ __shared__ int64_t tile_i[TILE];
+
+ for (int base = 0; base < n; base += TILE) {
+ int len = min(TILE, n - base);
+ for (int t = threadIdx.x; t < TILE; t += blockDim.x) {
+ if (t < len) {
+ tile_v[t] = row_in[base + t];
+ tile_i[t] = (int64_t)(base + t);
+ } else {
+ tile_v[t] = -FLT_MAX;
+ tile_i[t] = 0;
+ }
+ }
+ __syncthreads();
+ bitonic_sort_desc<TILE>(tile_v, tile_i);
+ __syncthreads();
+ if (threadIdx.x == 0) {
for (int j = 0; j < K; ++j) {
- s_val[tid * K + j] = tmp[j];
- s_idx[tid * K + j] = tidx[j];
+ insert_desc<K>(tile_v[j], tile_i[j], local_val, local_idx);
}
}
__syncthreads();
}
}
-__device__ __forceinline__ void warp_reduce_max(float val, int64_t idx,
- float& out_val, int64_t& out_idx) {
- for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
- float ov = __shfl_down_sync(0xffffffff, val, offset);
- int64_t oi = __shfl_down_sync(0xffffffff, idx, offset);
- if (ov > val) { val = ov; idx = oi; }
- }
- out_val = val; out_idx = idx;
-}
-
-__global__ void topk1_kernel(const float* __restrict__ in,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idxs,
- int batch, int n) {
- int row = blockIdx.x;
- if (row >= batch) return;
- const float* row_in = in + (int64_t)row * n;
- float best = -FLT_MAX;
- int64_t best_i = 0;
-
- int i = threadIdx.x * 4;
- for (; i + 3 < n; i += blockDim.x * 4) {
- float4 v = *reinterpret_cast<const float4*>(row_in + i);
- if (v.x > best) { best = v.x; best_i = i; }
- if (v.y > best) { best = v.y; best_i = i + 1; }
- if (v.z > best) { best = v.z; best_i = i + 2; }
- if (v.w > best) { best = v.w; best_i = i + 3; }
- }
- for (; i < n; i += blockDim.x) {
- float v = row_in[i];
- if (v > best) { best = v; best_i = i; }
- }
-
- __shared__ float sv[32];
- __shared__ int64_t si[32];
- int lane = threadIdx.x & 31, wid = threadIdx.x >> 5;
- float wv = best; int64_t wi = best_i;
- warp_reduce_max(wv, wi, wv, wi);
- if (lane == 0) { sv[wid] = wv; si[wid] = wi; }
- __syncthreads();
- if (wid == 0) {
- wv = (lane < (blockDim.x + 31) /
... (9750 more chars){
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<2610 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,74 +1 @@
-template <int K>
-__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
- for (int k2 = 2; k2 <= K; k2 <<= 1) {
- for (int j = k2 >> 1; j > 0; j >>= 1) {
- for (int i = 0; i < K; ++i) {
- int ixj = i ^ j;
- if (ixj > i) {
- bool asc = ((i & k2) == 0);
- bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
- if (swap_it) {
- float tv = vals[i]; int64_t ti = idxs[i];
- vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
- vals[ixj] = tv; idxs[ixj] = ti;
- }
- }
- }
- }
- }
-}
-
-template <int K, int TILE>
-__device__ void scan_row_tiled(const float* row_in, int n,
- float* local_val, int64_t* local_idx) {
- #pragma unroll
- for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
-
- __shared__ float tile_v[TILE];
- __shared__ int64_t tile_i[TILE];
-
- for (int base = 0; base < n; base += TILE) {
- int len = min(TILE, n - base);
- for (int t = threadIdx.x; t < TILE; t += blockDim.x) {
- if (t < len) {
- tile_v[t] = row_in[base + t];
- tile_i[t] = (int64_t)(base + t);
- } else {
- tile_v[t] = -FLT_MAX;
- tile_i[t] = 0;
- }
- }
- __syncthreads();
- bitonic_sort_desc<TILE>(tile_v, tile_i);
- __syncthreads();
- if (threadIdx.x == 0) {
- for (int j = 0; j < K; ++j) {
- insert_desc<K>(tile_v[j], tile_i[j], local_val, local_idx);
- }
- }
- __syncthreads();
- }
-}
-
-template <int K>
-__device__ void scan_row_topk(const float* row_in, int n,
- float* local_val, int64_t* local_idx) {
- if (K >= 32) {
- scan_row_tiled<K, 256>(row_in, n, local_val, local_idx);
- return;
- }
- #pragma unroll
- for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
-
- int i = threadIdx.x * 4;
- for (; i + 3 < n; i += blockDim.x * 4) {
- float4 v = *reinterpret_cast<const float4*>(row_in + i);
- insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
- insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
- insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
- insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
- }
- for (; i < n; i += blockDim.x) {
- insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
- }
-}+ self._partial_cap = min(1024, (n + 63) // 64) if self._use_slice else 0{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": " self._partial_cap = min(1024, (n + 63) // 64) if self._use_slice else 0"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1 +1,20 @@
- self._partial_cap = min(1024, (n + 63) // 64) if self._use_slice else 0+template <int N>
+__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
+ for (int k2 = 2; k2 <= N; k2 <<= 1) {
+ for (int j = k2 >> 1; j > 0; j >>= 1) {
+ for (int i = threadIdx.x; i < N; i += blockDim.x) {
+ int ixj = i ^ j;
+ if (ixj > i) {
+ bool asc = ((i & k2) == 0);
+ bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
+ if (swap_it) {
+ float tv = vals[i]; int64_t ti = idxs[i];
+ vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
+ vals[ixj] = tv; idxs[ixj] = ti;
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<758 chars \u2014 see diff>"
}null--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1,20 +1 @@
-template <int N>
-__device__ void bitonic_sort_desc(float* vals, int64_t* idxs) {
- for (int k2 = 2; k2 <= N; k2 <<= 1) {
- for (int j = k2 >> 1; j > 0; j >>= 1) {
- for (int i = threadIdx.x; i < N; i += blockDim.x) {
- int ixj = i ^ j;
- if (ixj > i) {
- bool asc = ((i & k2) == 0);
- bool swap_it = asc ? (vals[i] < vals[ixj]) : (vals[i] > vals[ixj]);
- if (swap_it) {
- float tv = vals[i]; int64_t ti = idxs[i];
- vals[i] = vals[ixj]; idxs[i] = idxs[ixj];
- vals[ixj] = tv; idxs[ixj] = ti;
- }
- }
- }
- __syncthreads();
- }
- }
-}+ bitonic_sort_desc<TILE>(tile_v, tile_i);{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": " bitonic_sort_desc<TILE>(tile_v, tile_i);"
}path: ''
error: 'Incorrect tool arguments'
modelVisibleError: 'old_string and new_string are exactly the same'{
"command": "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_17eaf887-227c-456c-a4e6-4c22dff3a4b",
"simpleCommands": [
"rm",
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "rm",
"args": [
{
"type": "word",
... (1953 more chars)null{
"taskId": "237187",
"blockUntilMs": 120000,
"regex": "PASS|FAIL|peak_fraction"
}null{
"path": "/home/agent/.cursor/projects/workspace-problems-05-topk-bitonic/terminals/237187.txt"
}---
pid: 3738
cwd: "/workspace/problems/05_topk_bitonic"
command: "rm -rf /home/agent/.cache/torch_extensions/py311_cu130/topk_cuda_ext && cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1"
started_at: 2026-06-15T18:27:57.606Z
running_for_ms: 130003
---
PASS
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:30:10.216927+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:30:10.354084+00:00 elapsed_s=0.137 ms=1.855088
shape=0 variant=solution tflops=0.000 gbps=0.283 ms=1.855
shape=0 solution_peak_fraction=0.0002
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:30:10.355958+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:30:10.362659+00:00 elapsed_s=0.007 ms=0.037968
shape=1 variant=solution tflops=0.055 gbps=55.397 ms=0.038
shape=1 solution_peak_fraction=0.0308
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:30:10.364378+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:30:10.450654+00:00 elapsed_s=0.086 ms=1.363824
shape=2 variant=solution tflops=0.002 gbps=1.547 ms=1.364
shape=2 solution_peak_fraction=0.0009
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T18:30:10.451309+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T18:30:10.467247+00:00 elapsed_s=0.016 ms=0.190784
shape=3 varian
... (510 more chars)--- a//workspace/problems/05_topk_bitonic/solution.py
+++ b//workspace/problems/05_topk_bitonic/solution.py
@@ -1 +1,18 @@
- bitonic_sort_desc<TILE>(tile_v, tile_i);+template <int K>
+__device__ void scan_row_topk(const float* row_in, int n,
+ float* local_val, int64_t* local_idx) {
+ #pragma unroll
+ for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; }
+
+ int i = threadIdx.x * 4;
+ for (; i + 3 < n; i += blockDim.x * 4) {
+ float4 v = *reinterpret_cast<const float4*>(row_in + i);
+ insert_desc<K>(v.x, (int64_t)i, local_val, local_idx);
+ insert_desc<K>(v.y, (int64_t)(i + 1), local_val, local_idx);
+ insert_desc<K>(v.z, (int64_t)(i + 2), local_val, local_idx);
+ insert_desc<K>(v.w, (int64_t)(i + 3), local_val, local_idx);
+ }
+ for (; i < n; i += blockDim.x) {
+ insert_desc<K>(row_in[i], (int64_t)i, local_val, local_idx);
+ }
+}{
"path": "/workspace/problems/05_topk_bitonic/solution.py",
"streamContent": "<772 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/05_topk_bitonic && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_e00f7f13-4050-4fe4-9dbe-a6ad5ff15cf",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/05_topk_bitonic"
}
],
"fullText": "cd /workspace/problems/05_topk_bi
... (908 more chars)null