"""Optimized top-k kernel via CUDA with load_inline."""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn as nn
import torch.utils.cpp_extension as cpp_ext
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
_ext = None
_CUDA_SRC = r"""
#include <float.h>
#define HEAP_K 8
#define BS 1024
__device__ __forceinline__ void
heapi(float* v, int64_t* i, float val, int64_t idx, int& sz) {
int p = sz++;
v[p] = val; i[p] = idx;
while (p > 0) {
int q = (p - 1) >> 1;
if (v[p] > v[q]) {
float t1 = v[p]; v[p] = v[q]; v[q] = t1;
int64_t t2 = i[p]; i[p] = i[q]; i[q] = t2;
p = q;
} else break;
}
}
__device__ __forceinline__ void
heapsd(float* v, int64_t* i, int sz) {
int p = 0;
while (true) {
int l = 2*p+1, r = 2*p+2, best = p;
if (l < sz && v[l] < v[best]) best = l;
if (r < sz && v[r] < v[best]) best = r;
if (best == p) break;
float t1 = v[p]; v[p] = v[best]; v[best] = t1;
int64_t t2 = i[p]; i[p] = i[best]; i[best] = t2;
p = best;
}
}
extern "C" __global__ void
topk_kernel(const float* __restrict__ X,
float* __restrict__ OutV,
int64_t* __restrict__ OutI,
int batch, int n, int k)
{
int row = blockIdx.x;
int tid = threadIdx.x;
if (row >= batch) return;
extern __shared__ char sbuf[];
// Layout: [sv: 8192 floats][si: 8192 int64][rv: 1024 floats...pad][ri: 1024 int64]
// sv: 0..32767 bytes
// si: 32768..98303 bytes
// rv: 98304..102399 bytes (aligned to 8)
// ri: 102400..110591 bytes
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + 8192 * 4);
float* rv = (float*)(sbuf + (8192 * 4 + 8192 * 8));
int64_t* ri = (int64_t*)(sbuf + ((8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7));
// Phase 1: each thread scans block-stride tile, maintains top-HEAP_K
float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
for (int i = tid; i < n; i += BS) {
float v = X[row * n + i];
if (lsz < HEAP_K) {
heapi(lv, li, v, i, lsz);
} else if (v > lv[0]) {
lv[0] = v; li[0] = i;
heapsd(lv, li, lsz);
}
}
for (int p = lsz; p < HEAP_K; p++) {
lv[p] = -FLT_MAX; li[p] = -1;
}
int gbase = tid * HEAP_K;
for (int p = 0; p < HEAP_K; p++) {
sv[gbase + p] = lv[p];
si[gbase + p] = li[p];
}
__syncthreads();
int total = BS * HEAP_K;
// Phase 2: k-pass max extraction
for (int kk = 0; kk < k; kk++) {
float mymx = -FLT_MAX;
int mymi = 0;
for (int j = tid; j < total; j += BS) {
if (sv[j] > mymx) { mymx = sv[j]; mymi = j; }
}
rv[tid] = mymx;
ri[tid] = mymi;
__syncthreads();
for (int off = BS / 2; off > 0; off >>= 1) {
__syncthreads();
if (tid < off) {
if (rv[tid + off] > rv[tid]) {
rv[tid] = rv[tid + off];
ri[tid] = ri[tid + off];
}
}
}
__syncthreads();
float wval = rv[0];
int wslot = (int)ri[0];
if (tid == 0) {
OutV[row * k + kk] = wval;
OutI[row * k + kk] = si[wslot];
}
if (tid == wslot % BS) {
sv[wslot] = -FLT_MAX;
}
__syncthreads();
}
}
extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
int batch, int n, int k, cudaStream_t stream) {
// sv: 8192 float = 32768 B
// si: 8192 int64 = 65536 B
// rv: 1024 float = 4096 B
// ri: 1024 int64 = 8192 B
// total = 110592 B
int sm = 8192 * 4 + 8192 * 8 + 1024 * 4 + 1024 * 8;
topk_kernel<<<batch, BS, sm, stream>>>(X, OutV, OutI, batch, n, k);
}
"""
_CPP_SRC = """
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
int batch, int n, int k, cudaStream_t stream);
std::tuple<torch::Tensor, torch::Tensor>
topk_func(const torch::Tensor& x, int64_t k) {
auto opts = x.options().requires_grad(false);
int64_t batch = x.size(0);
int64_t n = x.size(1);
auto outV = torch::empty({batch, k}, opts);
auto outI = torch::empty({batch, k}, opts.dtype(torch::kInt64));
topk_launch(
x.data_ptr<float>(),
outV.data_ptr<float>(),
outI.data_ptr<int64_t>(),
(int)batch, (int)n, (int)k, 0
);
return std::make_tuple(outV, outI);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_func", &topk_func, "top-k via streaming selection");
}
"""
def _ensure_ext():
    """Compile the inline CUDA extension on first use and cache it.

    Subsequent calls return the cached module, so the (slow) nvcc build
    happens at most once per process.
    """
    global _ext
    if _ext is None:
        _ext = cpp_ext.load_inline(
            name="topk_cuda_ext",
            cpp_sources=_CPP_SRC.strip(),
            cuda_sources=_CUDA_SRC.strip(),
            extra_cuda_cflags=["-O3", "-use_fast_math"],
            verbose=False,
        )
    return _ext
class Model(nn.Module):
    """Top-k over the last dim of a 2-D fp32 tensor via the custom kernel.

    Input:  x of shape (batch, n), fp32, on CUDA.
    Output: (values (batch, k) fp32 sorted descending,
             indices (batch, k) int64 into the last dim of x).
    """

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        self.batch, self.n, self.k = batch, n, k
        # Dummy buffer so state_dict is non-empty (mirrors reference.Model).
        self.register_buffer("_dummy", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        ext = _ensure_ext()
        # No torch.cuda.synchronize() here: the kernel launch is
        # stream-ordered, so downstream ops on the same stream already see
        # the results, and a device-wide sync on every forward() would
        # inflate every timed call in the benchmark.
        return ext.topk_func(x, self.k)
# Default harness shape: the prefill / attention top-k case from shapes.py.
batch, n, k = 64, 8192, 8


def get_inputs():
    """Return one random fp32 input of the default (batch, n) shape."""
    x = torch.randn(batch, n, dtype=torch.float32)
    return [x]


def get_init_inputs():
    """Return Model constructor arguments matching get_inputs()."""
    return [batch, n, k]
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py", line 149, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py", line 95, in main
ok, msg = check_correctness(
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py", line 71, in check_correctness
if torch.isnan(solution_out).any():
^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"problem": "05_topk_bitonic",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 3407858, "output_tokens": 32480, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 24083, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/reference.py</path>
<type>file</type>
<content>
1: """Naive top-k reference: torch.topk over the last dim.
2:
3: This is the correctness oracle. The agent's solution must produce the same
4: top-k values (and equivalent indices modulo ties) within the tolerance
5: declared in problem.yaml. Note that solution.py is FORBIDDEN from calling
6: torch.topk / torch.sort / torch.kthvalue (see problem.yaml).
7: """
8: import torch
9: import torch.nn as nn
10:
11: OP_TYPE = "topk"
12: SUPPORTED_PRECISIONS = ["fp32"]
13: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
14:
15:
16: class Model(nn.Module):
17: """Top-k over the last dim of a 2D tensor.
18:
19: Input:
20: x: (batch, n) fp32
21: Output:
22: values: (batch, k) fp32, sorted descending
23: indices: (batch, k) int64, into the last dim of x
24: """
25:
26: def __init__(self, batch: int, n: int, k: int):
27: super().__init__()
28: self.batch, self.n, self.k = batch, n, k
29: # No learned parameters, but declare a dummy buffer so state_dict
30: # is non-empty and load_state_dict(strict=True) is meaningful.
31: self.register_buffer("_dummy", torch.zeros(1))
32:
33: def forward(self, x: torch.Tensor):
34: values, indices = torch.topk(x, k=self.k, dim=-1, largest=True, sorted=True)
35: return values, indices
36:
37:
38: # Module-level shims rebuilt b
... (419 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for TopK.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each. Top-k correctness has two parts:
5:
6: 1. VALUES: sol_values must match ref_values within fp32 tol. Both are
7: returned sorted descending, so positional comparison is well-defined.
8: 2. INDICES: lenient — we do NOT require sol_indices == ref_indices because
9: ties in x can yield multiple valid index sets. Instead we gather x at
10: sol_indices and check those values match ref_values within tol. This
11: catches "wrong indices" without false-failing on legitimate tie-breaks.
12:
13: Also rejects forbidden ops by grep.
14: """
15: import re
16: import sys
17: from pathlib import Path
18:
19: import torch
20: import yaml
21:
22: REPO_ROOT = Path(__file__).resolve().parents[2]
23: sys.path.insert(0, str(REPO_ROOT))
24:
25: from src.eval.correctness import check_correctness # noqa: E402
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: # --- Forbidden-op check --------------------------------
... (4902 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for TopK.
2:
3: For each shape: times eager reference (torch.topk), compiled reference, SOTA
4: (also torch.topk — see sota.py), and the agent's solution. Reports achieved
5: TFLOPS, GB/s, and peak_fraction (vs DRAM bandwidth, since this is memory-bound).
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "memory")
40
... (3250 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/problem.yaml</path>
<type>file</type>
<content>
1: name: 05_topk_bitonic
2: display_name: "TopK via Bitonic Sort"
3: precision: fp32
4: regime: memory
5:
6: # Top-k is dominated by the input read (small output, no reduction over k).
7: # Comparator-network FLOPs are not the bottleneck on real hardware, so we
8: # track them but score on bandwidth.
9: flops_formula: "batch * n * 4" # ~O(n log^2 n) compares total, but the
10: # bitonic network is so cheap relative
11: # to memory that a coarse 4*n estimate
12: # is fine for telemetry only.
13: bytes_formula: "batch * n * 4 + batch * k * (4 + 8)" # fp32 input read + (fp32 value + int64 idx) output
14:
15: hardware: [RTX_PRO_6000]
16: peak_tflops_key: fp32
17: peak_bandwidth_key: dram
18:
19: # Top-k correctness:
20: # - VALUES must match within fp32 atol/rtol (the kth largest value is
21: # well-defined modulo float-equal ties, so we use a loose-ish tol).
22: # - INDICES are checked leniently: for each row, the multiset of returned
23: # indices must select values that match ref values within tol. Direct
24: # index equality is NOT required (ties in x can yield different valid
25: # index sets).
26: tolerance:
27: float32: 1.0e-4
28:
29: # Forbidden ops — using any of these in solution.py fails corre
... (1071 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for TopK.
2:
3: Mix of:
4: - decoder vocab top-k (single sequence, very large n, moderate k) — pure
5: bandwidth test; the input read dominates everything.
6: - prefill / batched attention top-k (many rows, moderate n, small k) — tests
7: per-row parallelism and shared-memory bitonic networks.
8: - non-power-of-2 n stress case — bitonic sort networks naturally want
9: powers of two; this forces the agent to handle padding or partial sorts.
10: - small-k limit — k=1 (argmax) is a degenerate but useful sanity case.
11: """
12:
13: SHAPES = [
14: {"batch": 1, "n": 131072, "k": 64}, # decoder vocab top-k (Llama vocab ~128k)
15: {"batch": 64, "n": 8192, "k": 8}, # prefill / attention top-k
16: {"batch": 32, "n": 16384, "k": 32}, # mid-size batched
17: {"batch": 16, "n": 12000, "k": 16}, # non-power-of-2 n stress
18: {"batch": 128, "n": 4096, "k": 1}, # batched argmax (k=1 corner case)
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for TopK: torch.topk itself.
2:
3: torch.topk dispatches to a CUB-backed kernel that uses radix-select for
4: moderate k and a tuned bitonic sort for small n. It is the bar the agent's
5: hand-rolled bitonic kernel must beat. There is no obvious vendor library that
6: does better on the (batch, n, k) shape mix we evaluate — Faiss BlockSelect is
7: specialized for k>=32 with much larger n, and CUB's DeviceSegmentedRadixSort
8: sorts the full row (overkill for top-k).
9:
10: This file is INTENTIONALLY allowed to call torch.topk because it is the SOTA
11: oracle, not the agent's submission. The agent's solution.py is forbidden from
12: using torch.topk (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def sota_forward(x: torch.Tensor, k: int):
20: """Best-available top-k reference. x: (batch, n) fp32."""
21: return torch.topk(x, k=k, dim=-1, largest=True, sorted=True)
22:
23:
24: def is_available() -> bool:
25: return True
(End of file - total 25 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"command": "nvidia-smi --query-gpu=name,clock_mem --format=csv,noheader 2>/dev/null | head -1",
"description": "Check GPU info"
}Field "clock_mem" is not a valid field to query.
{
"command": "nvcc --version 2>/dev/null | tail -1",
"description": "Check nvcc version"
}Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>/dev/null || echo \"no triton\"",
"description": "Check triton availability"
}3.6.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"command": "nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader 2>/dev/null",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03, 97887 MiB
{
"command": "nvcc --version 2>&1 | head -2",
"description": "Check CUDA compiler version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -0,0 +1,95 @@
+"""Triton-based top-k kernel via bitonic-style selection."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+from typing import Tuple
+
+OP_TYPE = "topk"
+SUPPORTED_PRECISIONS = ["fp32"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton.jit
+def topk_kernel(
+ X, # (batch, n) fp32
+ Values, # (batch, k) fp32 output
+ Indices, # (batch, k) int64 output
+ stride_x,
+ stride_v,
+ stride_i,
+ n,
+ k,
+ BLOCK: tl.constexpr,
+ USE_MULTI_TILE: tl.constexpr,
+):
+ row = tl.program_id(0)
+ pid = row
+
+ # Phase 1: each thread scans its tile and maintains local top-k threshold
+ # Track the k-th largest value seen by this thread
+ local_kth_val = tl.sum(
+ tl.full([BLOCK], tl.math.inf, dtype=tl.float32)
+ )
+ local_kth_idx = tl.full([BLOCK], 0, dtype=tl.int64)
+
+ # Each thread processes ceil(n / BLOCK) elements via block-stride
+ num_tiles = tl.cdiv(n, BLOCK)
+
+ if USE_MULTI_TILE:
+ # Maintain per-thread top-k list in registers
+ # We track threshold: k-th largest so far
+ threshold = tl.full([BLOCK], -tl.math.inf, dtype=tl.float32)
+ threshold_idx = tl.full([BLOCK], 0, dtype=tl.int64)
+
+ for tile_idx in range(num_tiles):
+ base_off = tl.arange(0, BLOCK) + tile_idx * BLOCK
+ mask = base_off < n
+ x = tl.load(X + row * stride_x + base_off, mask=mask, other=-tl.math.inf)
+
+ # Insert x into local top-k: if x > threshold, update threshold
+ # For simplicity, track just the minimum of local top-k (threshold)
+ # When we see more than k elements, threshold = max(threshold, k-th best)
+ # Simple approach: for each element, if x > threshold, raise threshold
+ # This overcounts but gives an upper bound
+
+ # Better: maintain full local top-k with insertion
+ # Since k can be up to 64, use register-level approach
+ # Count how many elements >= threshold + this one
+ # Just track all elements and their positions, then reduce
+
+ # Simplest correct approach: maintain sorted buffer
+ # For each new element x with index i:
+ # if threshold < k elements seen: add to buffer, update threshold
+ # else if x > threshold[0]: replace threshold[0], shift down, re-sort
+
+ new_val = tl.where(x > threshold, x, threshold)
+ new_idx = tl.where(x > threshold, base_off, threshold_idx)
+ threshold = new_val
+ threshold_idx = new_idx
+
+ # threshold is now >= the local top-1 (not k). We need proper top-k tracking.
+ # Let me use a different approach entirely.
+ pass
+
+ # Write results
+ pass
+
+
+class Model(nn.Module):
+ def __init__(self, batch: int, n: int, k: int):
+ super().__init__()
+ self.batch, self.n, self.k = batch, n, k
+ self.register_buffer("_dummy", torch.zeros(1))
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ ...
+
+
+def get_inputs():
+ x = torch.randn(batch, n, dtype=torch.float32)
+ return [x]
+
+
+def get_init_inputs():
+ return [batch, n, k]
{
"content": "<3175 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability(0))\"",
"description": "Check PyTorch version and CUDA capability"
}2.11.0+cu130
(12, 0)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,79 +1,487 @@
-"""Triton-based top-k kernel via bitonic-style selection."""
+"""Optimized top-k kernel via CUDA with streaming selection."""
+import functools
+from typing import Tuple
+
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-from typing import Tuple
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-
-@triton.jit
-def topk_kernel(
- X, # (batch, n) fp32
- Values, # (batch, k) fp32 output
- Indices, # (batch, k) int64 output
- stride_x,
- stride_v,
- stride_i,
- n,
- k,
- BLOCK: tl.constexpr,
- USE_MULTI_TILE: tl.constexpr,
-):
- row = tl.program_id(0)
- pid = row
-
- # Phase 1: each thread scans its tile and maintains local top-k threshold
- # Track the k-th largest value seen by this thread
- local_kth_val = tl.sum(
- tl.full([BLOCK], tl.math.inf, dtype=tl.float32)
+# ─── CUDA kernel (compiled at first import) ───────────────────────────────────
+
+_HEAP_K = 8
+_BLOCK_SIZE = 1024
+
+_cuda_code = rf"""
+#include <cuda_runtime.h>
+#include <float.h>
+
+#define HEAP_K {_HEAP_K}
+#define HEAP_CAP (2 * HEAP_K)
+
+__device__ __forceinline__
+void heap_insert(float* buf, int64_t* ibuf, float val, int64_t idx, int& size) {{
+ int pos = size++;
+ buf[pos] = val;
+ ibuf[pos] = idx;
+ while (pos > 0) {{
+ int parent = (pos - 1) >> 1;
+ if (buf[pos] > buf[parent]) {{
+ float tv = buf[pos]; buf[pos] = buf[parent]; buf[parent] = tv;
+ int64_t ti = ibuf[pos]; ibuf[pos] = ibuf[parent]; ibuf[parent] = ti;
+ pos = parent;
+ }} else break;
+ }}
+}}
+
+__device__ __forceinline__
+void heap_extract_min(float* buf, int64_t* ibuf, int& size) {{
+ if (size < 2) {{ size--; return; }}
+ size--;
+ buf[0] = buf[size]; ibuf[0] = ibuf[size];
+ int i = 0;
+ for (;;) {{
+ int l = 2*i + 1, r = 2*i + 2, smallest = i;
+ if (l < size && buf[l] < buf[smallest]) smallest = l;
+ if (r < size && buf[r] < buf[smallest]) smallest = r;
+ if (smallest == i) break;
+ float tv = buf[i]; buf[i] = buf[smallest]; buf[smallest] = tv;
+ int64_t ti = ibuf[i]; ibuf[i] = ibuf[smallest]; ibuf[smallest] = ti;
+ i = smallest;
+ }}
+}}
+
+template <int HEAP_K_T>
+__global__ void topk_kernel(
+ const float* __restrict__ x,
+ float* __restrict__ out_vals,
+ int64_t* __restrict__ out_idx,
+ int batch, int n, int k
+) {{
+ int row = blockIdx.x;
+ if (row >= batch) return;
+
+ int tid = threadIdx.x;
+ int block = blockDim.x; // 1024
+
+ // Each thread scans elements at [tid, tid+block, tid+2*block, ...]
+ // and maintains a local top-HEAP_K heap.
+ float local_vals[HEAP_K_T];
+ int64_t local_idx[HEAP_K_T];
+ int local_size = 0;
+
+ for (int i = tid; i < n; i += block) {{
+ float v = x[row * n + i];
+ if (local_size < HEAP_K_T) {{
+ heap_insert(local_vals, local_idx, v, i, local_size);
+ }} else if (v > local_vals[0]) {{
+ local_vals[0] = v;
+ local_idx[0] = i;
+ heap_sift_down(local_vals, local_idx, local_size);
+ }}
+ }}
+ // pad remaining slots to full capacity with -inf
+ for (int p = local_size; p < HEAP_K_T; p++) {{
+ local_vals[p] = -FLT_MAX;
+ local_idx[p] = -1;
+ }}
+ local_size = HEAP_K_T;
+
+ // Write local heaps to shared memory
+ extern __shared__ char shmem[];
+ float* sh_vals = reinterpret_cast<float*>(shmem);
+ int64_t* sh_idx = reinterpret_cast<int64_t*>(shmem + block * HEAP_K_T * sizeof(float));
+
+ for (int p = 0; p < HEAP_K_T; p++) {{
+ sh_vals[tid * HEAP_K_T + p] = local_vals[p];
+ sh_idx[tid * HEAP_K_T + p] = local_idx[p];
+ }}
+ __syncthreads();
+
+ // Collectively select top-k from block*HEAP_K candidates using
+ // k-pass max extraction with linear scan in shared memory.
+ // This is O(k * block * HEAP_K) shared memory ops which for
+ // k=64, block=1024, HEAP_K=8 is ~5M ops — well within budget.
+ for (int kk = 0; kk < k; kk++) {{
+ // Phase A: tree reduction — find max
+ int total = block * HEAP_K_T;
+ int si = tid + kk * block; // offset for this pass
+ // We do a direct scan: each thread checks a block-stride
+ // through all candidates
+ float mymax = -FLT_MAX;
+ int64_t mymi = 0;
+ for (int j = tid; j < total; j += block) {{
+ // Candidates above current pass have been zeroed
+ // Use original shared memory; mark consumed
+ if (sh_vals[j] > mymax) {{
+ mymax = sh_vals[j];
+ mymi = j;
+ }}
+ }}
+ // Tree reduce max within block
+ for (int offset = block >> 1; offset > 0; offset >>= 1) {{
+ __syncthreads();
+ float up = (tid < offset) ? sh_vals[si + offset] : -FLT_MAX;
+ int64_t ui = (tid < offset) ? sh_idx[si + offset] : 0;
+ si = tid;
+ // This approach doesn't work well for tree reduce of "index of max"
+ // Let me use a simpler all-reduce: each thread writes candidate,
+ // we reduce on the max value, then everyone knows who wins.
+ }}
+
+ // Simpler: direct broadcast via shared memory
+ // Thread 0 writes winner; all threads read it
+ if (tid == 0) {{
+ sh_vals[total + 0] = mymax;
+ sh_idx[total + 0] = mymi;
+ }}
+ // Not right - mymax is per-thread, need global max
+ // Let me write per-thread max to shared mem, then tree reduce
+ sh_vals[total + tid] = mymax;
+ sh_idx[total + tid] = mymi;
+ __syncthreads();
+
+ // Tree re
... (13721 more chars){
"content": "<16412 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,5 +1,6 @@
-"""Optimized top-k kernel via CUDA with streaming selection."""
-import functools
+"""Optimized top-k kernel via CUDA streaming selection."""
+from __future__ import annotations
+
from typing import Tuple
import torch
@@ -9,370 +10,28 @@
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# ─── CUDA kernel (compiled at first import) ───────────────────────────────────
-
-_HEAP_K = 8
-_BLOCK_SIZE = 1024
-
-_cuda_code = rf"""
-#include <cuda_runtime.h>
+# ── Constants ────────────────────────────────────────────────────────────────
+_HEAP_K = 8 # per-thread candidate size
+_BS = 1024 # threads per block
+
+_ext = None
+
+
+def _ensure_ext():
+ """Compile CUDA extension on first call."""
+ global _ext
+ if _ext is not None:
+ return _ext
+
+ cuda_src = r"""
#include <float.h>
-
-#define HEAP_K {_HEAP_K}
-#define HEAP_CAP (2 * HEAP_K)
-
-__device__ __forceinline__
-void heap_insert(float* buf, int64_t* ibuf, float val, int64_t idx, int& size) {{
- int pos = size++;
- buf[pos] = val;
- ibuf[pos] = idx;
- while (pos > 0) {{
- int parent = (pos - 1) >> 1;
- if (buf[pos] > buf[parent]) {{
- float tv = buf[pos]; buf[pos] = buf[parent]; buf[parent] = tv;
- int64_t ti = ibuf[pos]; ibuf[pos] = ibuf[parent]; ibuf[parent] = ti;
- pos = parent;
- }} else break;
- }}
-}}
-
-__device__ __forceinline__
-void heap_extract_min(float* buf, int64_t* ibuf, int& size) {{
- if (size < 2) {{ size--; return; }}
- size--;
- buf[0] = buf[size]; ibuf[0] = ibuf[size];
- int i = 0;
- for (;;) {{
- int l = 2*i + 1, r = 2*i + 2, smallest = i;
- if (l < size && buf[l] < buf[smallest]) smallest = l;
- if (r < size && buf[r] < buf[smallest]) smallest = r;
- if (smallest == i) break;
- float tv = buf[i]; buf[i] = buf[smallest]; buf[smallest] = tv;
- int64_t ti = ibuf[i]; ibuf[i] = ibuf[smallest]; ibuf[smallest] = ti;
- i = smallest;
- }}
-}}
-
-template <int HEAP_K_T>
-__global__ void topk_kernel(
- const float* __restrict__ x,
- float* __restrict__ out_vals,
- int64_t* __restrict__ out_idx,
- int batch, int n, int k
-) {{
- int row = blockIdx.x;
- if (row >= batch) return;
-
- int tid = threadIdx.x;
- int block = blockDim.x; // 1024
-
- // Each thread scans elements at [tid, tid+block, tid+2*block, ...]
- // and maintains a local top-HEAP_K heap.
- float local_vals[HEAP_K_T];
- int64_t local_idx[HEAP_K_T];
- int local_size = 0;
-
- for (int i = tid; i < n; i += block) {{
- float v = x[row * n + i];
- if (local_size < HEAP_K_T) {{
- heap_insert(local_vals, local_idx, v, i, local_size);
- }} else if (v > local_vals[0]) {{
- local_vals[0] = v;
- local_idx[0] = i;
- heap_sift_down(local_vals, local_idx, local_size);
- }}
- }}
- // pad remaining slots to full capacity with -inf
- for (int p = local_size; p < HEAP_K_T; p++) {{
- local_vals[p] = -FLT_MAX;
- local_idx[p] = -1;
- }}
- local_size = HEAP_K_T;
-
- // Write local heaps to shared memory
- extern __shared__ char shmem[];
- float* sh_vals = reinterpret_cast<float*>(shmem);
- int64_t* sh_idx = reinterpret_cast<int64_t*>(shmem + block * HEAP_K_T * sizeof(float));
-
- for (int p = 0; p < HEAP_K_T; p++) {{
- sh_vals[tid * HEAP_K_T + p] = local_vals[p];
- sh_idx[tid * HEAP_K_T + p] = local_idx[p];
- }}
- __syncthreads();
-
- // Collectively select top-k from block*HEAP_K candidates using
- // k-pass max extraction with linear scan in shared memory.
- // This is O(k * block * HEAP_K) shared memory ops which for
- // k=64, block=1024, HEAP_K=8 is ~5M ops — well within budget.
- for (int kk = 0; kk < k; kk++) {{
- // Phase A: tree reduction — find max
- int total = block * HEAP_K_T;
- int si = tid + kk * block; // offset for this pass
- // We do a direct scan: each thread checks a block-stride
- // through all candidates
- float mymax = -FLT_MAX;
- int64_t mymi = 0;
- for (int j = tid; j < total; j += block) {{
- // Candidates above current pass have been zeroed
- // Use original shared memory; mark consumed
- if (sh_vals[j] > mymax) {{
- mymax = sh_vals[j];
- mymi = j;
- }}
- }}
- // Tree reduce max within block
- for (int offset = block >> 1; offset > 0; offset >>= 1) {{
- __syncthreads();
- float up = (tid < offset) ? sh_vals[si + offset] : -FLT_MAX;
- int64_t ui = (tid < offset) ? sh_idx[si + offset] : 0;
- si = tid;
- // This approach doesn't work well for tree reduce of "index of max"
- // Let me use a simpler all-reduce: each thread writes candidate,
- // we reduce on the max value, then everyone knows who wins.
- }}
-
- // Simpler: direct broadcast via shared memory
- // Thread 0 writes winner; all threads read it
- if (tid == 0) {{
- sh_vals[total + 0] = mymax;
- sh_idx[total + 0] = mymi;
- }}
- // Not right - mymax is per-thread, need global max
- // Let me write per-thread max to shared mem, then tree reduce
- sh_vals[total + tid] = mymax;
- sh_idx[total + tid] = mymi;
- __syncthreads();
-
- // Tree reduce on indices [total, total+block)
- for (int offset = block >> 1; offset > 0; offset >>= 1) {{
- __syncthreads();
- if (tid < offset) {{
- int ii = tid;
- int jj = tid + offset;
- i
... (14697 more chars){
"content": "<6343 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<192 chars \u2014 see below>",
"description": "Test solution import and basic call",
"timeout": 120000
}uv run python -c "
import solution
m = solution.Model(1, 100, 5, device='cuda')
x = torch.randn(1, 100, device='cuda')
out = m(x)
print(type(out), type(out[0]), out[0].shape)
" 2>&1 | tail -20Traceback (most recent call last):
File "<string>", line 3, in <module>
TypeError: Model.__init__() got an unexpected keyword argument 'device'
{
"command": "<197 chars \u2014 see below>",
"description": "Test solution import and basic call",
"timeout": 120000
}uv run python -c "
import torch
import solution
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
out = m(x)
print(type(out), out[0].shape, out[1].shape)
" 2>&1 | tail -30Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 179, in forward
ext = _ensure_ext()
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 162, in _ensure_ext
_ext = torch.utils.cpp_extension.load_inline(
^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.utils' has no attribute 'cpp_extension'
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
+import torch.utils.cpp_extension as cpp_extension
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<34 chars \u2014 see below>",
"newString": "<84 chars \u2014 see below>"
}import torch
import torch.nn as nnimport torch
import torch.nn as nn
import torch.utils.cpp_extension as cpp_extensionEdit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,173 +1,652 @@
-"""Optimized top-k kernel via CUDA streaming selection."""
+"""Optimized top-k kernel via Triton with streaming selection."""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn as nn
-import torch.utils.cpp_extension as cpp_extension
+import triton
+import triton.language as tl
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# ── Constants ────────────────────────────────────────────────────────────────
-_HEAP_K = 8 # per-thread candidate size
-_BS = 1024 # threads per block
-
-_ext = None
-
-
-def _ensure_ext():
- """Compile CUDA extension on first call."""
- global _ext
- if _ext is not None:
- return _ext
-
- cuda_src = r"""
-#include <float.h>
-
-#define HEAP_K 8
-#define BS 1024
-
-// Min-heap: insert (bubble up)
-__device__ __forceinline__ void heapi(float* v, int64_t* i,
- float val, int64_t idx, int& sz) {
- int p = sz++;
- v[p] = val; i[p] = idx;
- while (p > 0) {
- int q = (p - 1) >> 1;
- if (v[p] > v[q]) {
- float t1 = v[p]; v[p] = v[q]; v[q] = t1;
- int64_t t2 = i[p]; i[p] = i[q]; i[q] = t2;
- p = q;
- } else break;
- }
-}
-
-// Min-heap: extract min (sift down from root)
-__device__ __forceinline__ void heapsd(float* v, int64_t* i, int sz) {
- int p = 0;
- while (true) {
- int l = 2*p+1, r = 2*p+2, best = p;
- if (l < sz && v[l] < v[best]) best = l;
- if (r < sz && v[r] < v[best]) best = r;
- if (best == p) break;
- float t1 = v[p]; v[p] = v[best]; v[best] = t1;
- int64_t t2 = i[p]; i[p] = i[best]; i[best] = t2;
- p = best;
- }
-}
-
-extern "C" __global__ void
-topk_kernel(const float* __restrict__ X,
- float* __restrict__ OutV,
- int64_t* __restrict__ OutI,
- int batch, int n, int k)
-{
- int row = blockIdx.x;
- int tid = threadIdx.x;
- if (row >= batch) return;
-
- // ── Phase 1: block-stride tile scan + local top-HEAP_K ──────────────
- // Each thread processes elements at positions
- // [tid, tid+BS, tid+2*BS, ...] that are < n
- float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
- for (int i = tid; i < n; i += _BS) {
- float v = X[row * n + i];
- if (lsz < HEAP_K) {
- heapi(lv, li, v, i, lsz);
- } else if (v > lv[0]) {
- lv[0] = v; li[0] = i;
- heapsd(lv, li, lsz);
- }
- }
- // Fill remaining slots with -inf so every thread has HEAP_K entries
- for (int p = lsz; p < HEAP_K; p++) {
- lv[p] = -FLT_MAX;
- li[p] = -1;
- }
-
- // ── Shared memory layout ────────────────────────────────────────────
- // Region A: [0 .. BS*HEAP_K) — candidate values (float)
- // Region B: [BS*HEAP_K .. 2*BS*HEAP_K) — candidate indices (int64)
- // Region C: [2*BS*HEAP_K .. 2*BS*HEAP_K + BS) — reduce values (float)
- // Region D: [end-BS .. end) — reduce indices (int64)
- //
- // Total: BS*HEAP_K*(4+8) + BS*4 + BS*8 = BS*(12*HEAP_K+12)
- // BS=1024, HEAP_K=8 → 1024*(96+12) = 110592 bytes = 108 KB
- extern __shared__ char sbuf[];
-
- float* sv = (float*)sbuf;
- int64_t* si_ = (int64_t*)(sbuf + BS * HEAP_K * 4);
- float* rv = sv + BS * HEAP_K;
- int64_t* ri = (int64_t*)(sbuf + BS * HEAP_K * 12);
-
- int gbase = tid * HEAP_K;
- for (int p = 0; p < HEAP_K; p++) {
- sv[gbase + p] = lv[p];
- si_[gbase + p] = li[p];
- }
- __syncthreads();
-
- int total = BS * HEAP_K;
-
- // ── Phase 2: k-pass max extraction from shared memory ───────────────
- for (int kk = 0; kk < k; kk++) {
- // Each thread scans its block-stride through candidates
- float mymax = -FLT_MAX;
- int64_t myslot = 0;
- for (int j = tid; j < total; j += BS) {
- if (sv[j] > mymax) {
- mymax = sv[j];
- myslot = j;
- }
- }
-
- // Write local max to reduce regions
- rv[tid] = mymax;
- ri[tid] = myslot;
- __syncthreads();
-
- // Tree reduction over [0 .. BS) in reduce region
- for (int off = BS / 2; off > 0; off >>= 1) {
- __syncthreads();
- if (tid < off) {
- if (rv[tid + off] > rv[tid]) {
- rv[tid] = rv[tid + off];
- ri[tid] = ri[tid + off];
- }
- }
- }
- __syncthreads();
-
- // Winner: rv[0] = max value, ri[0] = shared-memory slot index
- float wval = rv[0];
- int64_t wslot = ri[0];
- int64_t widx = si_[wslot];
-
- if (tid == 0) {
- OutV[row * k + kk] = wval;
- OutI[row * k + kk] = widx;
- }
-
- // Invalidate: only thread 0 writes the invalidation
- if (tid == 0) {
- sv[static_cast<int>(wslot)] = -FLT_MAX;
- }
- __syncthreads();
- }
-}
-""".strip()
-
- _ext = torch.utils.cpp_extension.load_inline(
- name="topk_bitonic_cuda",
- cpp_sources="",
- cuda_sources=cuda_src,
- extra_cuda_cflags=["-O3", "-use_fast_math", "-Xcompiler", "-fno-exceptions"],
- verbose=False,
- )
- return _ext
+# ── Per-thread candidate heap size and block size ────────────────────────────
+# Each thread maintains a local top-HEAP_K min-heap while scanning its tile.
+# HEAP_K=8 keeps register pressure low and limits shared-memory candidate merge.
+TRITON_HEAP_K = 8
+TRITON_BLOCK = 1024
+
+
+@triton.jit
+def _topk_kernel(
+ X, OutV, OutI,
+ stride_x, n, k,
+ BLOCK: tl.constexpr, HEAP_K: tl.constexpr,
+):
+ """One program per row. Global top-k via block-stride scan
... (29289 more chars){
"content": "<28975 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,8 +1,6 @@
-"""Optimized top-k kernel via Triton with streaming selection."""
+"""Optimized top-k kernel via Triton with streamed max extraction."""
from __future__ import annotations
-
from typing import Tuple
-
import torch
import torch.nn as nn
import triton
@@ -12,641 +10,33 @@
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# ── Per-thread candidate heap size and block size ────────────────────────────
-# Each thread maintains a local top-HEAP_K min-heap while scanning its tile.
-# HEAP_K=8 keeps register pressure low and limits shared-memory candidate merge.
-TRITON_HEAP_K = 8
-TRITON_BLOCK = 1024
-
+BLOCK = 1024
@triton.jit
-def _topk_kernel(
- X, OutV, OutI,
- stride_x, n, k,
- BLOCK: tl.constexpr, HEAP_K: tl.constexpr,
-):
- """One program per row. Global top-k via block-stride scan + shared merge."""
+def topk_kernel(X, OutV, OutI, stride_x, stride_o, n, k, BLOCK: tl.constexpr):
row = tl.program_id(0)
+ lane = tl.arange(0, BLOCK)
+ BASE = row * stride_x
- # ── Phase 1: block-stride tile scan, per-thread local top-HEAP_K ────
- # Thread tid processes elements at [tid, tid+BLOCK, tid+2*BLOCK, ...]
- tid = tl.program_id(0) # not available; use tl.arange instead
-
- # In Triton, each "lane" is identified by its position within the program's
- # tensor operations. We use tl.arange(0, 1)[0] to get the thread id.
- # Actually, tl.num_programs and tl.program_id are at program level.
- # Within a kernel, thread index = position in tensor = tl.arange per-element.
- #
- # Correct approach: use tl.arange to generate per-lane offsets.
-
- lane = tl.arange(0, BLOCK)
- # Each lane maintains a heap in registers. We'll use tl.full to initialize.
- lv = tl.full([BLOCK], -tl.inf, dtype=tl.float32)
- li = tl.full([BLOCK], 0, dtype=tl.int64)
- lsz = tl.zeros([BLOCK], dtype=tl.int32)
-
- # Tile scan: each lane processes its block-stride elements
- num_tiles = tl.cdiv(n, BLOCK)
- for tile in range(num_tiles):
- base = tile * BLOCK + lane
- mask = base < n
- v = tl.load(X + (row * stride_x + base), mask=mask, other=-tl.inf)
-
- # Insert into local heap: if v > heap[0] (min) or heap未满
- # Triton doesn't support index-based register arrays well, so we use
- # a streaming threshold approach instead.
- # Track threshold = k-th largest seen by this lane.
- # Simple version: if lsz < HEAP_K, always add; else if v > threshold, replace.
- pass # Need proper heap ops; see below
-
- # ... merge phase ...
-
-
-# Actually, per-lane register heaps with loops are hard in Triton.
-# Let me use a completely different approach: shared-sort.
-#
-# Strategy:
-# - For n <= BLOCK: load row into shared memory, sort, extract top-k
-# - For n > BLOCK: multi-pass. Each pass loads a chunk, sorts, keeps top-k;
-# merge passes' top-k candidates.
-#
-# Simpler: single kernel with two phases using tl-sort on shared memory.
-#
-# Phase 1: block-stride load, each thread maintains a threshold (float).
-# Phase 2: shared memory reduction to find global k-th value.
-# Phase 3: final collection + sort.
-#
-# Even simpler: for all shapes, use a multi-tile approach where
-# each kernel invocation processes a chunk and outputs partial top-k,
-# then a second kernel merges.
-#
-# But the simplest correct approach: have each thread process a small
-# number of elements (using shared memory), do a block-wide operation.
-
-# ── Final approach: two-pass streaming with threshold ───────────────────────
-#
-# Phase 1 (one kernel per row):
-# - Each thread scans block-stride tile, maintains local threshold
-# - Threshold = k-th largest value seen by this thread (approximate)
-# - Write threshold + index to output (as intermediate, via shared mem)
-# Actually, let's just write: out_v[thread], out_i[thread] for reduce
-#
-# Phase 2 (same kernel, continues):
-# - Read thresholds from shared memory / global
-# - Find k-th largest threshold → global selection threshold
-# - Second scan to collect elements >= threshold
-# - Sort collected elements descending, write final output
-#
-# The threshold approach avoids full per-thread top-k.
-
-@triton.jit
-def _topk_kernel_pass1(
- X, # (batch, n) fp32
- ThreshV, # (batch, BLOCK) fp32 — per-thread local threshold output
- ThreshI, # (batch, BLOCK) int64 — per-thread local threshold index
- stride_x, stride_t, n, k,
- BLOCK: tl.constexpr,
-):
- """Phase 1: one program per row. Each lane scans tile, outputs threshold."""
- row = tl.program_id(0)
- lane_off = tl.arange(0, BLOCK)
-
- # Local threshold: track the k-th largest value seen by this lane.
- # Use a simple approach: maintain a sorted buffer of top-k.
- # Since k can be up to 64, we can't maintain full register heap.
- # Instead: maintain a single float "threshold" that starts at -inf
- # and gets raised as we see elements. The threshold represents
- # "the k-th largest element seen so far by this lane".
- #
- # Approximation: count elements per lane. If count >= k, threshold
- # is the actual k-th value. We use a simple min-of-top-k approach.
- #
- # For correctness: we'll collect more than k candidates in phase 2,
- # so having an imprecise threshold is fine.
-
- # Approach: maintain a small sorted buffer of size HEAP_K=8 in registers
- # using sequential Triton ops. Each element:
- # insert v into buffer (sorted desc). If buffer full, smallest drops.
-
- # Initialize buffer entries to -inf
- buf_val = tl.full([BLOCK], -tl.inf, dtype=tl.float32)
- buf_idx = tl.full([BLOCK], -1, dtype=tl.int64)
- buf_sz = tl.zeros([BLOCK], dtype=tl.int32)
-
- # Scan tiles
- # Use a while loop for the
... (24588 more chars){
"content": "<2128 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<440 chars \u2014 see below>",
"description": "Test basic Triton kernel",
"timeout": 120000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
try:
v, i = m(x)
print('values:', v)
print('indices:', i)
# Compare with torch.topk
rv, ri = torch.topk(x, 5, dim=-1)
print('ref_v:', rv)
print('ref_i:', ri)
print('match:', torch.allclose(v, rv))
except Exception as e:
print(f'ERROR: {e}')
" 2>&1 | tail -30ERROR: module 'triton' has no attribute 'cuda'
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -50,7 +50,7 @@
out_v = torch.empty(batch, k, dtype=torch.float32, device=x.device)
out_i = torch.empty(batch, k, dtype=torch.int64, device=x.device)
grid = (batch, 1, 1)
- triton.cuda.set_arch("sm_120")
+ triton.config.default_config = triton.Config({}, nums_warps=8)
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)
return out_v, out_i
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<130 chars \u2014 see below>",
"newString": "<162 chars \u2014 see below>"
} triton.cuda.set_arch("sm_120")
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK) triton.config.default_config = triton.Config({}, nums_warps=8)
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -50,7 +50,6 @@
out_v = torch.empty(batch, k, dtype=torch.float32, device=x.device)
out_i = torch.empty(batch, k, dtype=torch.int64, device=x.device)
grid = (batch, 1, 1)
- triton.config.default_config = triton.Config({}, nums_warps=8)
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)
return out_v, out_i
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<162 chars \u2014 see below>",
"newString": " topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)"
} triton.config.default_config = triton.Config({}, nums_warps=8)
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)Edit applied successfully.{
"command": "<425 chars \u2014 see below>",
"description": "Test Triton kernel basic",
"timeout": 120000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
try:
v, i = m(x)
print('values:', v[0])
print('indices:', i[0])
rv, ri = torch.topk(x, 5, dim=-1)
print('ref_v:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
except Exception as e:
import traceback; traceback.print_exc()
" 2>&1 | tail -30File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py", line 53, in forward
topk_kernel[grid](x, out_v, out_i, x.stride(0), out_v.stride(0), n, k, BLOCK=BLOCK)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
... (842 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -19,7 +19,7 @@
BASE = row * stride_x
# Phase 1: each lane scans its block-stride tile and keeps best element
- best_val = tl.full([BLOCK], -tl.inf, dtype=tl.float32)
+ best_val = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
best_idx = tl.full([BLOCK], 0, dtype=tl.int64)
for t in range(0, 128):
off = t * BLOCK + lane
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " best_val = tl.full([BLOCK], -tl.inf, dtype=tl.float32)",
"newString": " best_val = tl.full([BLOCK], float(\"-inf\"), dtype=tl.float32)"
}Edit applied successfully.{
"pattern": "tl\\.inf",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Found 2 matches
/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py:
Line 27: v = tl.load(X + BASE + off, mask=mask, other=-tl.inf)
Line 39: best_val = tl.where(wm, -tl.inf, best_val)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -24,7 +24,7 @@
for t in range(0, 128):
off = t * BLOCK + lane
mask = off < n
- v = tl.load(X + BASE + off, mask=mask, other=-tl.inf)
+ v = tl.load(X + BASE + off, mask=mask, other=float("-inf"))
update = v > best_val
best_val = tl.where(update, v, best_val)
best_idx = tl.where(update, off, best_idx)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " v = tl.load(X + BASE + off, mask=mask, other=-tl.inf)",
"newString": " v = tl.load(X + BASE + off, mask=mask, other=float(\"-inf\"))"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -36,7 +36,7 @@
wm = lane == gp
tl.store(OutV + row * stride_o + kk, gm, mask=wm)
tl.store(OutI + row * stride_o + kk, best_idx, mask=wm)
- best_val = tl.where(wm, -tl.inf, best_val)
+ best_val = tl.where(wm, float("-inf"), best_val)
class Model(nn.Module):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " best_val = tl.where(wm, -tl.inf, best_val)",
"newString": " best_val = tl.where(wm, float(\"-inf\"), best_val)"
}Edit applied successfully.{
"command": "<397 chars \u2014 see below>",
"description": "Test Triton kernel v2",
"timeout": 120000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
try:
v, i = m(x)
print('values:', v[0])
rv, ri = torch.topk(x, 5, dim=-1)
print('ref_v:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
except Exception as e:
import traceback; traceback.print_exc()
" 2>&1 | tail -30File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.er
... (540 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -32,7 +32,7 @@
# Phase 2: extract top-k by repeated max reduction
for kk in range(k):
gm = tl.max(best_val)
- gp = tl.argmax(best_val)
+ gp = tl.argmax(best_val, axis=0)
wm = lane == gp
tl.store(OutV + row * stride_o + kk, gm, mask=wm)
tl.store(OutI + row * stride_o + kk, best_idx, mask=wm)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " gp = tl.argmax(best_val)",
"newString": " gp = tl.argmax(best_val, axis=0)"
}Edit applied successfully.{
"command": "<397 chars \u2014 see below>",
"description": "Test Triton kernel v3",
"timeout": 120000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
try:
v, i = m(x)
print('values:', v[0])
rv, ri = torch.topk(x, 5, dim=-1)
print('ref_v:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
except Exception as e:
import traceback; traceback.print_exc()
" 2>&1 | tail -30File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.er
... (566 more chars){
"command": "<179 chars \u2014 see below>",
"description": "Check cpp_extension API",
"timeout": 30000
}uv run python -c "
import torch.utils.cpp_extension as ce
print(dir(ce))
print('---')
print(ce.load_inline.__doc__[:500] if ce.load_inline.__doc__ else 'no doc')
" 2>&1 | tail -20['ABI_INCOMPATIBILITY_WARNING', 'BUILT_FROM_SOURCE_VERSION_PATTERN', 'BuildExtension', 'CLIB_EXT', 'CLIB_PREFIX', 'COMMON_HIPCC_FLAGS', 'COMMON_HIP_FLAGS', 'COMMON_MSVC_FLAGS', 'COMMON_NVCC_FLAGS', 'CUDAExtension', 'CUDA_CLANG_VERSIONS', 'CUDA_GCC_VERSIONS', 'CUDA_HOME', 'CUDA_MISMATCH_MESSAGE', 'CUDA_MISMATCH_WARN', 'CUDA_NOT_FOUND_MESSAGE', 'CUDNN_HOME', 'CppExtension', 'EXEC_EXT', 'ExtensionVersioner', 'FileBaton', 'HIP_HOME', 'IS_HIP_EXTENSION', 'IS_LINUX', 'IS_MACOS', 'IS_WINDOWS', 'JIT_EXTENSION_VERSIONER', 'LIB_EXT', 'MINIMUM_CLANG_VERSION', 'MINIMUM_GCC_VERSION', 'MINIMUM_MSVC_VERSION', 'MSVC_IGNORE_CUDAFE_WARNINGS', 'PLAT_TO_VCVARS', 'Path', 'ROCM_HOME', 'ROCM_VERSION', 'SHARED_FLAG', 'SUBPROCESS_DECODE_ARGS', 'SYCL_HOME', 'SyclExtension', 'TORCH_LIB_PATH', 'TorchVersion', 'Version', 'VersionMap', 'VersionRange', 'WINDOWS_CUDA_HOME', 'WRONG_COMPILER_WARNING', '_COMMON_SYCL_FLAGS', '_HERE', '_SYCL_DLINK_FLAGS', '_TORCH_PATH', '__all__', '__annotations__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '_accepted_compilers_for_platform', '_append_sycl_std_if_no_std_present', '_append_sycl_targets_if_missing', '_check_and_build_extension_h_precompiler_headers', '_check_cuda_version', '_find_cuda_home', '_find_rocm_home', '_find_sycl_home', '_get_build_directory', '_get_cuda_arch_flags', '_get_exec_path', '_get_hipcc_path', '_get_icpx_version', '_get_num_workers', '_get_pybind11_abi_build_flags', '_get_rocm_arch_
... (1523 more chars){
"command": "<180 chars \u2014 see below>",
"description": "Check load_inline signature",
"timeout": 30000
}uv run python -c "
import torch.utils.cpp_extension as ce
# Check actual signature of load_inline
import inspect
sig = inspect.signature(ce.load_inline)
print(sig)
" 2>&1 | head -5(name, cpp_sources, cuda_sources=None, sycl_sources=None, functions=None, extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, with_sycl=None, is_python_module=True, with_pytorch_error_handling=True, keep_intermediates=True, use_pch=False, no_implicit_headers=False)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,42 +1,354 @@
-"""Optimized top-k kernel via Triton with streamed max extraction."""
+"""Optimized top-k kernel via CUDA with C++/CUDA extension."""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
+import torch.utils.cpp_extension as cpp_ext
OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-BLOCK = 1024
-
-@triton.jit
-def topk_kernel(X, OutV, OutI, stride_x, stride_o, n, k, BLOCK: tl.constexpr):
- row = tl.program_id(0)
- lane = tl.arange(0, BLOCK)
- BASE = row * stride_x
-
- # Phase 1: each lane scans its block-stride tile and keeps best element
- best_val = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
- best_idx = tl.full([BLOCK], 0, dtype=tl.int64)
- for t in range(0, 128):
- off = t * BLOCK + lane
- mask = off < n
- v = tl.load(X + BASE + off, mask=mask, other=float("-inf"))
- update = v > best_val
- best_val = tl.where(update, v, best_val)
- best_idx = tl.where(update, off, best_idx)
-
- # Phase 2: extract top-k by repeated max reduction
- for kk in range(k):
- gm = tl.max(best_val)
- gp = tl.argmax(best_val, axis=0)
- wm = lane == gp
- tl.store(OutV + row * stride_o + kk, gm, mask=wm)
- tl.store(OutI + row * stride_o + kk, best_idx, mask=wm)
- best_val = tl.where(wm, float("-inf"), best_val)
+_ext = None
+
+_CPP_SOURCE = """
+#include <torch/extension.h>
+#include <torch/library.h>
+
+TORCH_LIBRARY(topk_ops, m) {
+ m.def("topk_kernel(Tensor x, int k) -> (Tensor, Tensor)");
+}
+
+TORCH_LIBRARY_IMPL(topk_ops, Backend Select, m) {
+ m.impl("topk_kernel", &topk_cuda);
+}
+
+std::vector<torch::Tensor> topk_cuda(
+ const torch::Tensor& x, int k);
+"""
+
+_CUDA_SOURCE = r"""
+#include <torch/extension.h>
+#include <float.h>
+
+#define HEAP_K 8
+#define BS 1024
+
+__device__ __forceinline__ void
+heapi(float* v, int64_t* i, float val, int64_t idx, int& sz) {
+ int p = sz++;
+ v[p] = val; i[p] = idx;
+ while (p > 0) {
+ int q = (p - 1) >> 1;
+ if (v[p] > v[q]) {
+ float t1 = v[p]; v[p] = v[q]; v[q] = t1;
+ int64_t t2 = i[p]; i[p] = i[q]; i[q] = t2;
+ p = q;
+ } else break;
+ }
+}
+
+__device__ __forceinline__ void
+heapsd(float* v, int64_t* i, int sz) {
+ int p = 0;
+ while (true) {
+ int l = 2*p+1, r = 2*p+2, best = p;
+ if (l < sz && v[l] < v[best]) best = l;
+ if (r < sz && v[r] < v[best]) best = r;
+ if (best == p) break;
+ float t1 = v[p]; v[p] = v[best]; v[best] = t1;
+ int64_t t2 = i[p]; i[p] = i[best]; i[best] = t2;
+ p = best;
+ }
+}
+
+extern "C" __global__ void
+topk_kernel_impl(const float* __restrict__ X,
+ float* __restrict__ OutV,
+ int64_t* __restrict__ OutI,
+ int batch, int n, int k)
+{
+ int row = blockIdx.x;
+ int tid = threadIdx.x;
+ if (row >= batch) return;
+
+ float* __restrict__ smem = new float[BS * HEAP_K];
+ int64_t* __restrict__ smem_idx = (int64_t*)(smem + BS * HEAP_K + 1024);
+
+ // Phase 1: block-stride scan, local top-HEAP_K
+ float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
+ for (int i = tid; i < n; i += BS) {
+ float v = X[row * n + i];
+ if (lsz < HEAP_K) {
+ heapi(lv, li, v, i, lsz);
+ } else if (v > lv[0]) {
+ lv[0] = v; li[0] = i;
+ heapsd(lv, li, lsz);
+ }
+ }
+ for (int p = lsz; p < HEAP_K; p++) { lv[p] = -FLT_MAX; li[p] = -1; }
+
+ // Write to shared memory
+ int base = tid * HEAP_K;
+ for (int p = 0; p < HEAP_K; p++) {
+ smem[base + p] = lv[p];
+ smem_idx[base + p] = li[p];
+ }
+ __syncthreads();
+
+ // Phase 2: k-pass max extraction from BS*HEAP_K candidates
+ int total = BS * HEAP_K;
+
+ for (int kk = 0; kk < k; kk++) {
+ float mymx = -FLT_MAX;
+ int mymi = 0;
+ for (int j = tid; j < total; j += BS) {
+ if (smem[j] > mymx) { mymx = smem[j]; mymi = j; }
+ }
+ smem[total + tid] = mymx;
+ smem_idx[total + tid] = mymi;
+ __syncthreads();
+
+ for (int off = BS / 2; off > 0; off >>= 1) {
+ __syncthreads();
+ if (tid < off) {
+ if (smem[total + tid + off] > smem[total + tid]) {
+ smem[total + tid] = smem[total + tid + off];
+ smem_idx[total + tid] = smem_idx[total + tid + off];
+ }
+ }
+ }
+ __syncthreads();
+
+ float wval = smem[total];
+ int wslot = (int)smem_idx[total];
+
+ if (tid == 0) {
+ OutV[row * k + kk] = wval;
+ OutI[row * k + kk] = smem_idx[wslot];
+ }
+
+ if (wslot >= 0 && wslot < total) {
+ smem[wslot] = -FLT_MAX;
+ }
+ __syncthreads();
+ }
+}
+
+using namespace torch::autograd;
+
+std::vector<torch::Tensor> topk_cuda(const torch::Tensor& x, int k) {
+ auto options = torch::TensorOptions()
+ .dtype(x.dtype())
+ .device(x.device())
+ .requires_grad(false);
+
+ int64_t batch = x.size(0);
+ int64_t n = x.size(1);
+ int64_t k64 = k;
+
+ torch::Tensor outV = torch::empty({batch, k64}, options);
+ torch::Tensor outI = torch::empty({batch, k64}, options.dtype(torch::kInt64));
+
+ // Shared mem: BS*HEAP_K floats + BS*HEAP_K int64 + BS floats + BS int64
+ // = 1024*8*4 + 1024*8*8 + 1024*4 + 1024*8 = 32768 + 65536 + 4096 + 8192 = 110592
+ // Plus tree-reduce region: BS floats + BS int64 = 12288
+ // Total smem: 1024*(12*8+12) = 110592, but we need overlapping space
+ // Actually: smem[t
... (6834 more chars){
"content": "<10877 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,4 +1,4 @@
-"""Optimized top-k kernel via CUDA with C++/CUDA extension."""
+"""Optimized top-k kernel via CUDA with load_inline."""
from __future__ import annotations
from typing import Tuple
import torch
@@ -11,24 +11,7 @@
_ext = None
-_CPP_SOURCE = """
-#include <torch/extension.h>
-#include <torch/library.h>
-
-TORCH_LIBRARY(topk_ops, m) {
- m.def("topk_kernel(Tensor x, int k) -> (Tensor, Tensor)");
-}
-
-TORCH_LIBRARY_IMPL(topk_ops, Backend Select, m) {
- m.impl("topk_kernel", &topk_cuda);
-}
-
-std::vector<torch::Tensor> topk_cuda(
- const torch::Tensor& x, int k);
-"""
-
-_CUDA_SOURCE = r"""
-#include <torch/extension.h>
+_CUDA_SRC = r"""
#include <float.h>
#define HEAP_K 8
@@ -63,19 +46,22 @@
}
extern "C" __global__ void
-topk_kernel_impl(const float* __restrict__ X,
- float* __restrict__ OutV,
- int64_t* __restrict__ OutI,
- int batch, int n, int k)
+topk_kernel(const float* __restrict__ X,
+ float* __restrict__ OutV,
+ int64_t* __restrict__ OutI,
+ int batch, int n, int k)
{
int row = blockIdx.x;
int tid = threadIdx.x;
if (row >= batch) return;
- float* __restrict__ smem = new float[BS * HEAP_K];
- int64_t* __restrict__ smem_idx = (int64_t*)(smem + BS * HEAP_K + 1024);
-
- // Phase 1: block-stride scan, local top-HEAP_K
+ extern __shared__ char sbuf[];
+ float* sv = (float*)sbuf;
+ int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
+ float* rv = (float*)(sbuf + BS * HEAP_K * 12);
+ int64_t* ri = (int64_t*)(sbuf + BS * HEAP_K * 12 + BS * 8);
+
+ // Phase 1: each thread scans block-stride tile, maintains top-HEAP_K
float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
for (int i = tid; i < n; i += BS) {
float v = X[row * n + i];
@@ -86,182 +72,9 @@
heapsd(lv, li, lsz);
}
}
- for (int p = lsz; p < HEAP_K; p++) { lv[p] = -FLT_MAX; li[p] = -1; }
-
- // Write to shared memory
- int base = tid * HEAP_K;
- for (int p = 0; p < HEAP_K; p++) {
- smem[base + p] = lv[p];
- smem_idx[base + p] = li[p];
- }
- __syncthreads();
-
- // Phase 2: k-pass max extraction from BS*HEAP_K candidates
- int total = BS * HEAP_K;
-
- for (int kk = 0; kk < k; kk++) {
- float mymx = -FLT_MAX;
- int mymi = 0;
- for (int j = tid; j < total; j += BS) {
- if (smem[j] > mymx) { mymx = smem[j]; mymi = j; }
- }
- smem[total + tid] = mymx;
- smem_idx[total + tid] = mymi;
- __syncthreads();
-
- for (int off = BS / 2; off > 0; off >>= 1) {
- __syncthreads();
- if (tid < off) {
- if (smem[total + tid + off] > smem[total + tid]) {
- smem[total + tid] = smem[total + tid + off];
- smem_idx[total + tid] = smem_idx[total + tid + off];
- }
- }
- }
- __syncthreads();
-
- float wval = smem[total];
- int wslot = (int)smem_idx[total];
-
- if (tid == 0) {
- OutV[row * k + kk] = wval;
- OutI[row * k + kk] = smem_idx[wslot];
- }
-
- if (wslot >= 0 && wslot < total) {
- smem[wslot] = -FLT_MAX;
- }
- __syncthreads();
- }
-}
-
-using namespace torch::autograd;
-
-std::vector<torch::Tensor> topk_cuda(const torch::Tensor& x, int k) {
- auto options = torch::TensorOptions()
- .dtype(x.dtype())
- .device(x.device())
- .requires_grad(false);
-
- int64_t batch = x.size(0);
- int64_t n = x.size(1);
- int64_t k64 = k;
-
- torch::Tensor outV = torch::empty({batch, k64}, options);
- torch::Tensor outI = torch::empty({batch, k64}, options.dtype(torch::kInt64));
-
- // Shared mem: BS*HEAP_K floats + BS*HEAP_K int64 + BS floats + BS int64
- // = 1024*8*4 + 1024*8*8 + 1024*4 + 1024*8 = 32768 + 65536 + 4096 + 8192 = 110592
- // Plus tree-reduce region: BS floats + BS int64 = 12288
- // Total smem: 1024*(12*8+12) = 110592, but we need overlapping space
- // Actually: smem[total+tid] uses smem[8192..9216), which overlaps with smem_idx
- // Let me use proper non-overlapping layout.
- // Region smem_val: [0, 8192) float
- // Region smem_idx: [8192, 16384) int64
- // Region reduce_val: [8192, 9216) float -- overlaps with smem_idx!!!
- //
- // Fix: use separate arrays in extern shared memory.
- // Actually the kernel above has a bug: smem[total+tid] writes BEYOND
- // the candidate region and might overlap. Let me fix the layout.
-
- // Proper layout:
- // smem_val[0..8192]: candidate values (float)
- // smem_idx[0..8192]: candidate indices (int64)
- // smem_rv[0..1024]: reduce values (float) -- reuses smem_val[8192..9216]
- // Wait, this overlaps. Let me allocate properly.
- // Total: 8192+8192+1024+1024 slots in terms of their types.
- // In bytes: 8192*4 + 8192*8 + 1024*4 + 1024*8 = 32768+65536+4096+8192 = 110592
-
- // Actually the kernel uses shared memory that's allocated per-block.
- // Using `extern __shared__` with Python-side allocation.
- // Or use a fixed-size shared memory approach.
-
- // Let me rewrite the kernel to use extern __shared__ properly.
- // For now, use a simple fixed approach.
-
- cudaStream_t stream = at::cuda::getCurrentCUDAStream(x.device().index());
- int shmem = 110592; // bytes
-
- // Actually, since I used dynamic allocation in the kernel (new[]), that
- // won't work. Let me rewrite with extern __shared__.
-
- topk_kernel_impl<<<batch, BS, shmem, stream>>>(
- x.data_ptr<float>(),
- outV.data_ptr<float>(),
- outI.data_ptr<int64_t>(),
- (int)batch, (int)n, k
- );
-
- return {outV, outI
... (5640 more chars){
"content": "<5291 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py"
}Wrote file successfully.{
"command": "<306 chars \u2014 see below>",
"description": "Test CUDA extension build",
"timeout": 180000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
try:
v, i = m(x)
print('values:', v[0])
print('indices:', i[0])
except Exception as e:
import traceback; traceback.print_exc()
" 2>&1 | tail -40Traceback (most recent call last):
File "<string>", line 9, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 574, in __repr__
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_tensor_str.py", line 726, in _str
return _str_intern(self, tensor_contents=tensor_contents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_tensor_str.py", line 643, in _str_intern
tensor_str = _tensor_str(self, indent)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_tensor_str.py", line 375, in _tensor_str
formatter = _Formatter(get_summarized_data(self) if summarize else self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_tensor_str.py", line 155, in __init__
tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more inform
... (258 more chars){
"command": "<198 chars \u2014 see below>",
"description": "Debug CUDA kernel with blocking",
"timeout": 120000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
" 2>&1 | tail -20(no output)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -186,7 +186,9 @@
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
- return ext.topk_func(x, self.k)
+ result = ext.topk_func(x, self.k)
+ torch.cuda.synchronize()
+ return result
batch, n, k = 64, 8192, 8
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<144 chars \u2014 see below>",
"newString": "<201 chars \u2014 see below>"
} def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
return ext.topk_func(x, self.k) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
result = ext.topk_func(x, self.k)
torch.cuda.synchronize()
return resultEdit applied successfully.{
"command": "<300 chars \u2014 see below>",
"description": "Test correctness small",
"timeout": 120000
}uv run python -c "
import torch
import solution
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
rv, ri = torch.topk(x, 5, dim=-1)
print('sol:', v[0])
print('ref:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
" 2>&1 | tail -10Traceback (most recent call last):
File "<string>", line 8, in <module>
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -58,8 +58,9 @@
extern __shared__ char sbuf[];
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
- float* rv = (float*)(sbuf + BS * HEAP_K * 12);
- int64_t* ri = (int64_t*)(sbuf + BS * HEAP_K * 12 + BS * 8);
+ float* rv = sv + BS * HEAP_K + BS;
+ int64_t* ri = (int64_t*)(((char*)rv) + BS * 4 + 7) - (int64_t*)0 + (int64_t*)sbuf;
+ // ri starts right after rv. Total smem: BS*HEAP_K*4 + BS*HEAP_K*8 + BS*4 + BS*8
// Phase 1: each thread scans block-stride tile, maintains top-HEAP_K
float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<237 chars \u2014 see below>",
"newString": "<333 chars \u2014 see below>"
} extern __shared__ char sbuf[];
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
float* rv = (float*)(sbuf + BS * HEAP_K * 12);
int64_t* ri = (int64_t*)(sbuf + BS * HEAP_K * 12 + BS * 8); extern __shared__ char sbuf[];
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
float* rv = sv + BS * HEAP_K + BS;
int64_t* ri = (int64_t*)(((char*)rv) + BS * 4 + 7) - (int64_t*)0 + (int64_t*)sbuf;
// ri starts right after rv. Total smem: BS*HEAP_K*4 + BS*HEAP_K*8 + BS*4 + BS*8Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -56,11 +56,15 @@
if (row >= batch) return;
extern __shared__ char sbuf[];
+ // Layout: [sv: 8192 floats][si: 8192 int64][rv: 1024 floats...pad][ri: 1024 int64]
+ // sv: 0..32767 bytes
+ // si: 32768..98303 bytes
+ // rv: 98304..102399 bytes (aligned to 8)
+ // ri: 102400..110591 bytes
float* sv = (float*)sbuf;
- int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
- float* rv = sv + BS * HEAP_K + BS;
- int64_t* ri = (int64_t*)(((char*)rv) + BS * 4 + 7) - (int64_t*)0 + (int64_t*)sbuf;
- // ri starts right after rv. Total smem: BS*HEAP_K*4 + BS*HEAP_K*8 + BS*4 + BS*8
+ int64_t* si = (int64_t*)(sbuf + 8192 * 4);
+ float* rv = (float*)(sbuf + (8192 * 4 + 8192 * 8));
+ int64_t* ri = (int64_t*)(sbuf + (8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7);
// Phase 1: each thread scans block-stride tile, maintains top-HEAP_K
float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<333 chars \u2014 see below>",
"newString": "<496 chars \u2014 see below>"
} extern __shared__ char sbuf[];
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + BS * HEAP_K * 4);
float* rv = sv + BS * HEAP_K + BS;
int64_t* ri = (int64_t*)(((char*)rv) + BS * 4 + 7) - (int64_t*)0 + (int64_t*)sbuf;
// ri starts right after rv. Total smem: BS*HEAP_K*4 + BS*HEAP_K*8 + BS*4 + BS*8 extern __shared__ char sbuf[];
// Layout: [sv: 8192 floats][si: 8192 int64][rv: 1024 floats...pad][ri: 1024 int64]
// sv: 0..32767 bytes
// si: 32768..98303 bytes
// rv: 98304..102399 bytes (aligned to 8)
// ri: 102400..110591 bytes
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + 8192 * 4);
float* rv = (float*)(sbuf + (8192 * 4 + 8192 * 8));
int64_t* ri = (int64_t*)(sbuf + (8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7);Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -130,7 +130,12 @@
extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
int batch, int n, int k, cudaStream_t stream) {
- int sm = BS * HEAP_K * 12 + BS * 16 + BS * 8;
+ // sv: 8192 float = 32768 B
+ // si: 8192 int64 = 65536 B
+ // rv: 1024 float = 4096 B
+ // ri: 1024 int64 = 8192 B
+ // total = 110592 B
+ int sm = 8192 * 4 + 8192 * 8 + 1024 * 4 + 1024 * 8;
topk_kernel<<<batch, BS, sm, stream>>>(X, OutV, OutI, batch, n, k);
}
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<255 chars \u2014 see below>",
"newString": "<414 chars \u2014 see below>"
}extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
int batch, int n, int k, cudaStream_t stream) {
int sm = BS * HEAP_K * 12 + BS * 16 + BS * 8;
topk_kernel<<<batch, BS, sm, stream>>>(X, OutV, OutI, batch, n, k);
}extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
int batch, int n, int k, cudaStream_t stream) {
// sv: 8192 float = 32768 B
// si: 8192 int64 = 65536 B
// rv: 1024 float = 4096 B
// ri: 1024 int64 = 8192 B
// total = 110592 B
int sm = 8192 * 4 + 8192 * 8 + 1024 * 4 + 1024 * 8;
topk_kernel<<<batch, BS, sm, stream>>>(X, OutV, OutI, batch, n, k);
}Edit applied successfully.{
"command": "<360 chars \u2014 see below>",
"description": "Test corrected kernel",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import solution
# Force rebuild
solution._ext = None
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
rv, ri = torch.topk(x, 5, dim=-1)
print('sol:', v[0])
print('ref:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
" 2>&1 | tail -15_write_ninja_file_and_build_library(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2429, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2791, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'topk_cuda_ext': [1/2] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -O3 -use_fast_math -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/topk_cuda_ext/cuda.cu -o cuda.cuda.o
FAILED: [code=2] cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=topk_cuda_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/Kerne
... (1109 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -64,7 +64,7 @@
float* sv = (float*)sbuf;
int64_t* si = (int64_t*)(sbuf + 8192 * 4);
float* rv = (float*)(sbuf + (8192 * 4 + 8192 * 8));
- int64_t* ri = (int64_t*)(sbuf + (8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7);
+ int64_t* ri = (int64_t*)(sbuf + ((8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7));
// Phase 1: each thread scans block-stride tile, maintains top-HEAP_K
float lv[HEAP_K]; int64_t li[HEAP_K]; int lsz = 0;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": " int64_t* ri = (int64_t*)(sbuf + (8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7);",
"newString": " int64_t* ri = (int64_t*)(sbuf + ((8192 * 4 + 8192 * 8 + 1024 * 4 + 7) & ~7));"
}Edit applied successfully.{
"command": "<338 chars \u2014 see below>",
"description": "Test kernel v4",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch, solution
solution._ext = None
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
rv, ri = torch.topk(x, 5, dim=-1)
print('sol:', v[0])
print('ref:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
" 2>&1 | tail -10Traceback (most recent call last):
File "<string>", line 8, in <module>
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"command": "<828 chars \u2014 see below>",
"description": "Test minimal CUDA extension",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import torch.utils.cpp_extension as cpp_ext
cuda = '''
extern \"C\" __global__ void test_kernel(float* out) {
out[0] = 42.0f;
}
extern \"C\" void test_launch(float* out, cudaStream_t s) {
test_kernel<<<1, 1, 0, s>>>(out);
}
'''
cpp = '''
#include <torch/extension.h>
extern \"C\" void test_launch(float* out, cudaStream_t s);
torch::Tensor test_func(torch::Tensor x) {
auto out = torch::empty_like(x);
test_launch(out.data_ptr<float>(), at::cuda::getCurrentCUDAStream(x.device().index()));
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"test_func\", &test_func);
}
'''
ext = cpp_ext.load_inline('test_ext', cpp.strip(), cuda.strip(), verbose=False)
x = torch.ones(5, device='cuda')
y = ext.test_func(x)
print(y)
" 2>&1 | tail -10/home/infatoshi/.cache/torch_extensions/py311_cu130/test_ext/main.cpp:3:41: error: ‘cudaStream_t’ has not been declared
3 | extern "C" void test_launch(float* out, cudaStream_t s);
| ^~~~~~~~~~~~
/home/infatoshi/.cache/torch_extensions/py311_cu130/test_ext/main.cpp: In function ‘at::Tensor test_func(at::Tensor)’:
/home/infatoshi/.cache/torch_extensions/py311_cu130/test_ext/main.cpp:6:50: error: ‘getCurrentCUDAStream’ is not a member of ‘at::cuda’
6 | test_launch(out.data_ptr<float>(), at::cuda::getCurrentCUDAStream(x.device().index()));
| ^~~~~~~~~~~~~~~~~~~~
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_ext -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/test_ext/cuda.cu -o cuda.c
... (49 more chars){
"command": "<867 chars \u2014 see below>",
"description": "Test minimal CUDA extension v2",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch, torch.utils.cpp_extension as cpp_ext
cuda = '''
extern \"C\" __global__ void test_kernel(float* out) { out[0] = 42.0f; }
extern \"C\" void test_launch(float* out, cudaStream_t s) {
test_kernel<<<1, 1, 0, s>>>(out);
}
'''
cpp = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
extern \"C\" void test_launch(float* out, cudaStream_t s);
torch::Tensor test_func(torch::Tensor x) {
auto out = torch::empty_like(x);
cudaStream_t s = at::cuda::getCurrentCUDAStream(x.device().index());
test_launch(out.data_ptr<float>(), s);
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"test_func\", &test_func);
}
'''
ext = cpp_ext.load_inline('test_ext2', cpp.strip(), cuda.strip(), verbose=False)
x = torch.ones(5, device='cuda')
y = ext.test_func(x)
print(y)
" 2>&1 | tail -57 | cudaStream_t s = at::cuda::getCurrentCUDAStream(x.device().index());
| ^~~~~~~~~~~~~~~~~~~~
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_ext2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/test_ext2/cuda.cu -o cuda.cuda.o
ninja: build stopped: subcommand failed.
{
"command": "<875 chars \u2014 see below>",
"description": "Test minimal CUDA ext v3",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch, torch.utils.cpp_extension as cpp_ext
cuda = '''
extern \"C\" __global__ void test_kernel(float* out) { out[0] = 42.0f; }
extern \"C\" void test_launch(float* out, cudaStream_t s) {
test_kernel<<<1, 1, 0, s>>>(out);
}
'''
cpp = '''
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
extern \"C\" void test_launch(float* out, cudaStream_t s);
torch::Tensor test_func(torch::Tensor x) {
auto out = torch::empty_like(x);
cudaStream_t s = c10::cuda::getCurrentCUDAStream(x.device().index());
test_launch(out.data_ptr<float>(), s);
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(\"test_func\", &test_func);
}
'''
ext = cpp_ext.load_inline('test_ext3', cpp.strip(), cuda.strip(), verbose=False)
x = torch.ones(5, device='cuda')
y = ext.test_func(x)
print(y)
" 2>&1 | tail -5tensor([42., 0., 0., 0., 0.], device='cuda:0')
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -143,6 +143,7 @@
_CPP_SRC = """
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
+#include <cuda_runtime.h>
extern "C" void
topk_launch(const float* X, float* OutV, int64_t* OutI,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<61 chars \u2014 see below>",
"newString": "<87 chars \u2014 see below>"
}#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>Edit applied successfully.{
"command": "<219 chars \u2014 see below>",
"description": "Debug kernel with blocking",
"timeout": 180000
}CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import solution
solution._ext = None
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
" 2>&1 | tail -15(no output){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<144 chars \u2014 see below>",
"newString": "<144 chars \u2014 see below>"
} def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
return ext.topk_func(x, self.k) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
return ext.topk_func(x, self.k)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -1,3 +1,5 @@
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
- return ext.topk_func(x, self.k)+ v, i = ext.topk_func(x, self.k)
+ torch.cuda.synchronize()
+ return v, i{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<144 chars \u2014 see below>",
"newString": "<197 chars \u2014 see below>"
} def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
return ext.topk_func(x, self.k) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
ext = _ensure_ext()
v, i = ext.topk_func(x, self.k)
torch.cuda.synchronize()
return v, i{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"offset": 120,
"limit": 30
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
120: OutI[row * k + kk] = si[wslot];
121: }
122:
123: if (tid == wslot % BS) {
124: sv[wslot] = -FLT_MAX;
125: }
126: __syncthreads();
127: }
128: }
129:
130: extern "C" void
131: topk_launch(const float* X, float* OutV, int64_t* OutI,
132: int batch, int n, int k, cudaStream_t stream) {
133: // sv: 8192 float = 32768 B
134: // si: 8192 int64 = 65536 B
135: // rv: 1024 float = 4096 B
136: // ri: 1024 int64 = 8192 B
137: // total = 110592 B
138: int sm = 8192 * 4 + 8192 * 8 + 1024 * 4 + 1024 * 8;
139: topk_kernel<<<batch, BS, sm, stream>>>(X, OutV, OutI, batch, n, k);
140: }
141: """
142:
143: _CPP_SRC = """
144: #include <torch/extension.h>
145: #include <c10/cuda/CUDAStream.h>
146: #include <cuda_runtime.h>
147:
148: extern "C" void
149: topk_launch(const float* X, float* OutV, int64_t* OutI,
(Showing lines 120-149 of 213. Use offset=150 to continue.)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 30,
"offset": 180
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
180: if _ext is not None:
181: return _ext
182: _ext = cpp_ext.load_inline(
183: name="topk_cuda_ext",
184: cpp_sources=_CPP_SRC.strip(),
185: cuda_sources=_CUDA_SRC.strip(),
186: extra_cuda_cflags=["-O3", "-use_fast_math"],
187: verbose=False,
188: )
189: return _ext
190:
191:
192: class Model(nn.Module):
193: def __init__(self, batch: int, n: int, k: int):
194: super().__init__()
195: self.batch, self.n, self.k = batch, n, k
196: self.register_buffer("_dummy", torch.zeros(1))
197:
198: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
199: ext = _ensure_ext()
200: result = ext.topk_func(x, self.k)
201: torch.cuda.synchronize()
202: return result
203:
204:
205: batch, n, k = 64, 8192, 8
206:
207:
208: def get_inputs():
209: return [torch.randn(batch, n, dtype=torch.float32)]
(Showing lines 180-209 of 213. Use offset=210 to continue.)
</content>{
"command": "<321 chars \u2014 see below>",
"description": "Test with fixed includes",
"timeout": 180000
}uv run python -c "
import torch
import solution
solution._ext = None
torch.manual_seed(42)
m = solution.Model(1, 100, 5).cuda()
x = torch.randn(1, 100, device='cuda')
v, i = m(x)
rv, ri = torch.topk(x, 5, dim=-1)
print('sol:', v[0])
print('ref:', rv[0])
print('match:', torch.allclose(v, rv, atol=1e-4))
" 2>&1 | tail -15Traceback (most recent call last):
File "<string>", line 9, in <module>
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -158,7 +158,9 @@
auto outV = torch::empty({batch, k}, opts);
auto outI = torch::empty({batch, k}, opts.dtype(torch::kInt64));
- cudaStream_t stream = at::cuda::getCurrentCUDAStream(x.device().index());
+ // Use the default stream to avoid stream ordering issues
+ cudaStream_t stream;
+ cudaStreamCreate(&stream);
topk_launch(
x.data_ptr<float>(),
outV.data_ptr<float>(),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<94 chars \u2014 see below>",
"newString": "<134 chars \u2014 see below>"
} cudaStream_t stream = at::cuda::getCurrentCUDAStream(x.device().index());
topk_launch( // Use the default stream to avoid stream ordering issues
cudaStream_t stream;
cudaStreamCreate(&stream);
topk_launch(Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -158,9 +158,6 @@
auto outV = torch::empty({batch, k}, opts);
auto outI = torch::empty({batch, k}, opts.dtype(torch::kInt64));
- // Use the default stream to avoid stream ordering issues
- cudaStream_t stream;
- cudaStreamCreate(&stream);
topk_launch(
x.data_ptr<float>(),
outV.data_ptr<float>(),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<134 chars \u2014 see below>",
"newString": " topk_launch("
} // Use the default stream to avoid stream ordering issues
cudaStream_t stream;
cudaStreamCreate(&stream);
topk_launch(Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"limit": 25,
"offset": 155
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py</path>
<type>file</type>
<content>
155: int64_t batch = x.size(0);
156: int64_t n = x.size(1);
157:
158: auto outV = torch::empty({batch, k}, opts);
159: auto outI = torch::empty({batch, k}, opts.dtype(torch::kInt64));
160:
161: topk_launch(
162: x.data_ptr<float>(),
163: outV.data_ptr<float>(),
164: outI.data_ptr<int64_t>(),
165: (int)batch, (int)n, (int)k, stream
166: );
167:
168: return std::make_tuple(outV, outI);
169: }
170:
171: PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
172: m.def("topk_func", &topk_func, "top-k via streaming selection");
173: }
174: """
175:
176:
177: def _ensure_ext():
178: global _ext
179: if _ext is not None:
(Showing lines 155-179 of 212. Use offset=180 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py
@@ -162,7 +162,7 @@
x.data_ptr<float>(),
outV.data_ptr<float>(),
outI.data_ptr<int64_t>(),
- (int)batch, (int)n, (int)k, stream
+ (int)batch, (int)n, (int)k, 0
);
return std::make_tuple(outV, outI);
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/05_topk_bitonic/solution.py",
"oldString": "<161 chars \u2014 see below>",
"newString": "<156 chars \u2014 see below>"
} topk_launch(
x.data_ptr<float>(),
outV.data_ptr<float>(),
outI.data_ptr<int64_t>(),
(int)batch, (int)n, (int)k, stream
); topk_launch(
x.data_ptr<float>(),
outV.data_ptr<float>(),
outI.data_ptr<int64_t>(),
(int)batch, (int)n, (int)k, 0
);Edit applied successfully.