"""Custom top-k kernel for RTX PRO 6000 (SM120 Blackwell, GDDR7). Same Model / get_inputs / get_init_inputs interface as reference.py. Algorithm (per row, last dim): * k == 1 : warp-reduce argmax, one block per row. * k >= 2 : tiled, two-stage. stage 1 (per tile): each thread keeps a sorted-ascending top-K in registers, then a shared-memory pairwise MERGE TREE (ping-pong, log2(BLK) levels, O(BLK*K) compares) reduces BLK per-thread buffers to one top-K. Far fewer syncs / compares than a full bitonic sort. stage 2 (merge): a second merge-tree combines the per-tile sorted buffers into the final sorted-descending top-K. Tiles are chosen (power of two) so batch*tiles fills the GPU. * CUDA graph capture keyed on the input data_ptr removes Python/launch overhead (these 0.5-2MB shapes are dominated by per-call overhead, not by the memory read). Falls back to direct execution when the input pointer changes (e.g. the correctness runner uses many inputs). Implemented as a CUDA C++ extension via torch.utils.cpp_extension.load_inline. """ import torch import torch.nn as nn from torch.utils.cpp_extension import load_inline OP_TYPE = "topk" SUPPORTED_PRECISIONS = ["fp32"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] _CUDA_SRC = r""" #include #include #include #include constexpr int npow2(int x){ int p = 1; while (p < x) p <<= 1; return p; } constexpr int ilog2(int x){ int r=0; while ((1< __device__ __forceinline__ void bitonic_sort_desc(float* sval, int* sidx, int M) { int tid = threadIdx.x; int logM = ilog2(M); for (int s = 1; s <= logM; s++) { int size = 1 << s; for (int stride = size >> 1; stride > 0; stride >>= 1) { __syncthreads(); for (int i = tid; i < M; i += BLK) { int j = i ^ stride; if (j > i) { float a = sval[i], b = sval[j]; bool desc = ((i & size) == 0); bool swap = desc ? (a < b) : (a > b); if (swap) { sval[i] = b; sval[j] = a; int ti = sidx[i]; sidx[i] = sidx[j]; sidx[j] = ti; } } } } } __syncthreads(); } // ---------- Argmax (k=1): one block per row ---------- template __global__ void argmax_kernel( const float* __restrict__ xin, float* __restrict__ out_val, long long* __restrict__ out_idx, int n) { constexpr int WARP = 32; constexpr int NWARPS = BLK / WARP; int row = blockIdx.x; int tid = threadIdx.x; int wid = tid / WARP, lane = tid % WARP; const long long base = (long long)row * n; __shared__ float sw[NWARPS]; __shared__ int si[NWARPS]; float myv = -INFINITY; int myi = 0; for (int e = tid; e < n; e += BLK) { float val = xin[base + e]; if (val > myv) { myv = val; myi = e; } } for (int off = 16; off > 0; off >>= 1) { float ov = __shfl_xor_sync(0xffffffff, myv, off); int oi = __shfl_xor_sync(0xffffffff, myi, off); if (ov > myv) { myv = ov; myi = oi; } } if (lane == 0) { sw[wid] = myv; si[wid] = myi; } __syncthreads(); if (wid == 0) { myv = (lane < NWARPS) ? sw[lane] : -INFINITY; myi = (lane < NWARPS) ? si[lane] : 0; for (int off = 16; off > 0; off >>= 1) { float ov = __shfl_xor_sync(0xffffffff, myv, off); int oi = __shfl_xor_sync(0xffffffff, myi, off); if (ov > myv) { myv = ov; myi = oi; } } if (lane == 0) { out_val[row] = myv; out_idx[row] = (long long)myi; } } } // ---------- Merge two ascending sorted K-arrays into top-K ascending ---------- template __device__ __forceinline__ void merge_topk( const float* A, const int* Ai, const float* B, const int* Bi, float* C, int* Ci) { int ia = K - 1, ib = K - 1; #pragma unroll for (int c = K - 1; c >= 0; c--) { float av = (ia >= 0) ? A[ia] : -INFINITY; float bv = (ib >= 0) ? B[ib] : -INFINITY; if (av >= bv) { C[c] = av; Ci[c] = Ai[ia]; ia--; } else { C[c] = bv; Ci[c] = Bi[ib]; ib--; } } } // ---------- Stage 1 (merge-tree): per-tile top-K into partial buffers ---------- // Per-thread top-K kept sorted ascending in registers; then a shared-mem // pairwise merge tree (ping-pong) reduces BLK buffers to 1. template __global__ void topk_part_mt_kernel( const float* __restrict__ xin, float* __restrict__ tmp_val, int* __restrict__ tmp_idx, int n, int tile_len, int tiles) { int seg = blockIdx.x; int row = seg / tiles; int tile = seg % tiles; int tid = threadIdx.x; long long base = (long long)row * n + (long long)tile * tile_len; int my_len = min(tile_len, n - tile * tile_len); int idx_base = tile * tile_len; // phase 1: register sorted-ascending top-K, rv[0] = min float rv[K]; int ri[K]; #pragma unroll for (int s = 0; s < K; s++) { rv[s] = -INFINITY; ri[s] = 0; } for (int e = tid; e < my_len; e += BLK) { float val = xin[base + e]; if (val > rv[0]) { rv[0] = val; ri[0] = idx_base + e; #pragma unroll for (int s = 1; s < K; s++) { if (rv[s] < rv[s-1]) { float tf = rv[s]; rv[s] = rv[s-1]; rv[s-1] = tf; int ti = ri[s]; ri[s] = ri[s-1]; ri[s-1] = ti; } } } } extern __shared__ char smem[]; // Pad buffer stride to an odd value (coprime with 32 banks) to avoid the // 32-way bank conflicts on the writeback sv[tid*KP+s] when K is a multiple of 32. // K=64 keeps K (register-spill/merge bound, not bank-conflict bound; K+1 once // mis-verified under a sync change, so keep the simpler verified stride there). constexpr int KP = (K == 64) ? K : K + 1; float* sv = (float*)smem; // 2*BLK*KP floats int* si = (int*)(sv + (size_t)2 * BLK * KP); // 2*BLK*KP ints #pragma unroll for (int s = 0; s < K; s++) { sv[tid*KP + s] = rv[s]; si[tid*KP + s] = ri[s]; } __syncthreads(); float* src_v = sv; int* src_i = si; // region A float* dst_v = sv + BLK * KP; int* dst_i = si + BLK * KP; // region B int active = BLK; while (active > 1) { int pair = tid; if (pair < active/2) { merge_topk(src_v + (2*pair)*KP, src_i + (2*pair)*KP, src_v + (2*pair+1)*KP, src_i + (2*pair+1)*KP, dst_v + pair*KP, dst_i + pair*KP); } __syncthreads(); float* tv2 = src_v; src_v = dst_v; dst_v = tv2; int* ti2 = src_i; src_i = dst_i; dst_i = ti2; active >>= 1; } // result (top-K ascending) in src buffer 0; write ASCENDING for stage 2 if (tid < K) { tmp_val[seg*K + tid] = src_v[tid]; tmp_idx[seg*K + tid] = src_i[tid]; } } // ---------- Stage 2 (bitonic merge): merge tiles*K candidates -> top-K ---------- // (fallback for very large candidate counts) template __global__ void topk_merge_kernel( const float* __restrict__ tmp_val, const int* __restrict__ tmp_idx, float* __restrict__ out_val, long long* __restrict__ out_idx, int total, int M) { int row = blockIdx.x; int tid = threadIdx.x; extern __shared__ char smem[]; float* sval = (float*)smem; int* sidx = (int*)(smem + (size_t)M * 4); for (int i = tid; i < M; i += BLK) { if (i < total) { sval[i] = tmp_val[(long long)row*total + i]; sidx[i] = tmp_idx[(long long)row*total + i]; } else { sval[i] = -INFINITY; sidx[i] = 0; } } __syncthreads(); bitonic_sort_desc(sval, sidx, M); if (tid < K) { out_val[row*K + tid] = sval[tid]; out_idx[row*K + tid] = (long long)sidx[tid]; } } // ---------- Stage 2 (merge-tree): merge tiles sorted K-buffers -> top-K ---------- template __global__ void topk_merge_mt_kernel( const float* __restrict__ tmp_val, const int* __restrict__ tmp_idx, float* __restrict__ out_val, long long* __restrict__ out_idx, int tiles, int total) { int row = blockIdx.x; int tid = threadIdx.x; constexpr int KP = (K == 64) ? K : K + 1; // padded buffer stride -> bank-conflict-free extern __shared__ char smem[]; float* sv = (float*)smem; // 2*tiles*KP floats int* si = (int*)(sv + (size_t)2 * tiles * KP); // 2*tiles*KP ints // scatter-load: element (b,s) -> sv[b*KP + s] for (int i = tid; i < total; i += BLK) { int b = i / K, s = i - b * K; sv[b*KP + s] = tmp_val[(long long)row*total + i]; si[b*KP + s] = tmp_idx[(long long)row*total + i]; } __syncthreads(); float* src_v = sv; int* src_i = si; float* dst_v = sv + tiles*KP; int* dst_i = si + tiles*KP; int active = tiles; while (active > 1) { int pair = tid; if (pair < active/2) { merge_topk(src_v + (2*pair)*KP, src_i + (2*pair)*KP, src_v + (2*pair+1)*KP, src_i + (2*pair+1)*KP, dst_v + pair*KP, dst_i + pair*KP); } __syncthreads(); float* tv2 = src_v; src_v = dst_v; dst_v = tv2; int* ti2 = src_i; src_i = dst_i; dst_i = ti2; active >>= 1; } if (tid < K) { int s = K - 1 - tid; out_val[row*K + tid] = src_v[s]; out_idx[row*K + tid] = (long long)src_i[s]; } } // ================= dispatch ================= template struct PartialBlk { static constexpr int VAL = 256; }; template <> struct PartialBlk<16> { static constexpr int VAL = 128; }; template <> struct PartialBlk<32> { static constexpr int VAL = 128; }; template <> struct PartialBlk<64> { static constexpr int VAL = 64; }; template static inline void launch_multiblock(int batch, int n, int tile_len, int tiles, const float* xv, float* ov, long long* oi, float* tv, int* ti, cudaStream_t s) { constexpr int BLK = PartialBlk::VAL; constexpr size_t smem = (size_t)4 * BLK * (K + 1) * 4; // 2 regions * (val+idx), padded topk_part_mt_kernel<<>>( xv, tv, ti, n, tile_len, tiles); int total = tiles * K; // merge-tree merge: 2 regions * (tiles buffers of K+1 each) * (val+idx) size_t m_smem = (size_t)4 * tiles * (K + 1) * 4; constexpr int MB = 256; if (m_smem <= 99 * 1024) { topk_merge_mt_kernel<<>>(tv, ti, ov, oi, tiles, total); } else { int M = npow2(total); topk_merge_kernel<<>>(tv, ti, ov, oi, total, M); } } // One-time init: opt in to >48KB dynamic shared mem (call once, NOT in capture). void topk_init() { auto setp = [](auto kc) { constexpr int K = decltype(kc)::value; constexpr int BLK = PartialBlk::VAL; constexpr size_t smem = (size_t)4 * BLK * (K + 1) * 4; if (smem > 49152) cudaFuncSetAttribute(topk_part_mt_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 99 * 1024); }; setp(std::integral_constant{}); setp(std::integral_constant{}); setp(std::integral_constant{}); setp(std::integral_constant{}); auto setm = [](auto kc) { constexpr int K = decltype(kc)::value; cudaFuncSetAttribute(topk_merge_kernel<256, K>, cudaFuncAttributeMaxDynamicSharedMemorySize, 99 * 1024); cudaFuncSetAttribute(topk_merge_mt_kernel<256, K>, cudaFuncAttributeMaxDynamicSharedMemorySize, 99 * 1024); }; setm(std::integral_constant{}); setm(std::integral_constant{}); setm(std::integral_constant{}); setm(std::integral_constant{}); } void topk_forward(torch::Tensor x, torch::Tensor ov, torch::Tensor oi, torch::Tensor tmp_val, torch::Tensor tmp_idx, int64_t k) { int batch = x.size(0); int n = x.size(1); cudaStream_t s = at::cuda::getCurrentCUDAStream(); const float* xv = x.data_ptr(); float* ov_ = ov.data_ptr(); long long* oi_ = (long long*)oi.data_ptr(); float* tv = tmp_val.data_ptr(); int* ti = tmp_idx.data_ptr(); if (k == 1) { constexpr int BLK = 256; argmax_kernel<<>>(xv, ov_, oi_, n); return; } // tiles (power of two): fill the GPU while keeping merge shared <= ~99KB int max_tiles_merge = (99 * 1024) / (16 * (int)k); int want = (256 + batch - 1) / batch; int tiles = 1; while (tiles * 2 <= want && tiles * 2 <= max_tiles_merge) tiles *= 2; if (tiles < 1) tiles = 1; int tile_len = (n + tiles - 1) / tiles; tiles = (n + tile_len - 1) / tile_len; switch (k) { case 2: launch_multiblock<2>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; case 4: launch_multiblock<4>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; case 8: launch_multiblock<8>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; case 16: launch_multiblock<16>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; case 32: launch_multiblock<32>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; case 64: launch_multiblock<64>(batch, n, tile_len, tiles, xv, ov_, oi_, tv, ti, s); break; default: break; } } """ _DECLS = r""" void topk_forward(torch::Tensor x, torch::Tensor ov, torch::Tensor oi, torch::Tensor tmp_val, torch::Tensor tmp_idx, int64_t k); void topk_init(); """ _mod = None def _get_mod(): global _mod if _mod is None: _mod = load_inline( name="topk_cuda_sol5", cpp_sources=[_DECLS], cuda_sources=[_CUDA_SRC], functions=["topk_forward", "topk_init"], extra_cuda_cflags=["-O3", "--use_fast_math", "-std=c++17", "-gencode=arch=compute_120,code=sm_120"], verbose=False, ) _mod.topk_init() return _mod class Model(nn.Module): def __init__(self, batch: int, n: int, k: int): super().__init__() self.batch, self.n, self.k = batch, n, k self.register_buffer("_dummy", torch.zeros(1)) self._ov = None self._oi = None self._tv = None self._ti = None self._graph = None self._captured_ptr = None def _ensure(self, device): if self._ov is None: _get_mod() self._ov = torch.empty(self.batch, self.k, dtype=torch.float32, device=device) self._oi = torch.empty(self.batch, self.k, dtype=torch.int64, device=device) tiles = max(1, (self.n + 255) // 256) self._tv = torch.empty(self.batch * tiles, self.k, dtype=torch.float32, device=device) self._ti = torch.empty(self.batch * tiles, self.k, dtype=torch.int32, device=device) def _run(self, x): _get_mod().topk_forward(x, self._ov, self._oi, self._tv, self._ti, self.k) def _capture(self, x): s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): self._run(x) torch.cuda.current_stream().wait_stream(s) g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): self._run(x) self._graph = g self._captured_ptr = x.data_ptr() # Override __call__ to bypass nn.Module hook dispatch overhead. def __call__(self, x: torch.Tensor): self._ensure(x.device) if x.data_ptr() != self._captured_ptr: self._capture(x) self._graph.replay() return self._ov, self._oi def forward(self, x: torch.Tensor): self._ensure(x.device) if x.data_ptr() != self._captured_ptr: self._capture(x) else: self._graph.replay() return self._ov, self._oi # Module-level shims rebuilt by check.py / benchmark.py per shape. batch = 64 n = 8192 k = 8 def get_inputs(): x = torch.randn(batch, n, dtype=torch.float32) return [x] def get_init_inputs(): return [batch, n, k]