"""Single-kernel top-k for SM120 (RTX PRO 6000). Strategy: one fused CUDA kernel per forward. Each row is split across one or more blocks; each warp streams its span keeping a register-resident sorted top-K (Faiss WarpSelect-style threshold + per-lane candidate queue, bitonic networks via warp shuffles). Warps merge through shared memory. Multi-split rows publish packed (key|idx) candidates to a scratch buffer; the last block per row (device-scope acq-rel counter) merges them and writes the output. Values travel as a monotonic fp32->u32 key packed with the source index into a u64, so all comparisons are single integer compares and the exact fp32 bits round-trip on output. """ import os import torch from torch.utils.cpp_extension import load_inline _CPP_SRC = r""" #include int64_t topk_configure(int64_t batch, int64_t n, int64_t k, int64_t splits, int64_t threads, int64_t vals_ptr, int64_t idx_ptr); void topk_run(int64_t handle, int64_t x_ptr); void topk_run_cached(int64_t handle); """ _CUDA_SRC = r""" #include #include #include #include #include #include #include using u32 = unsigned int; using u64 = unsigned long long; #define FULL_MASK 0xffffffffu // Monotonic transform: descending float order == descending u32 key order. __device__ __forceinline__ u32 fkey(u32 bits) { return bits ^ (u32(int(bits) >> 31) | 0x80000000u); } __device__ __forceinline__ float unkey(u32 key) { u32 b = key ^ ((key & 0x80000000u) ? 0x80000000u : 0xFFFFFFFFu); return __uint_as_float(b); } __device__ __forceinline__ u64 umax(u64 a, u64 b) { return a > b ? a : b; } __device__ __forceinline__ u64 umin(u64 a, u64 b) { return a < b ? a : b; } // Descending bitonic merge of 32 elements held one-per-lane. __device__ __forceinline__ void bmerge32(u64 &v, int lane) { #pragma unroll for (int d = 16; d >= 1; d >>= 1) { u64 o = __shfl_xor_sync(FULL_MASK, v, d); v = ((lane & d) == 0) ? umax(v, o) : umin(v, o); } } // Descending bitonic merge of 64 elements: element e = j*32+lane in regs a0,a1. 
__device__ __forceinline__ void bmerge64(u64 &a0, u64 &a1, int lane) { u64 lo = umin(a0, a1); a0 = umax(a0, a1); a1 = lo; bmerge32(a0, lane); bmerge32(a1, lane); } // Full descending sort of 32 elements one-per-lane. __device__ __forceinline__ void sort32(u64 &v, int lane) { #pragma unroll for (int s = 2; s <= 32; s <<= 1) { #pragma unroll for (int d = s >> 1; d >= 1; d >>= 1) { u64 o = __shfl_xor_sync(FULL_MASK, v, d); bool i_low = (lane & d) == 0; bool asc = (lane & s) != 0; // gives overall descending bool take_max = i_low ^ asc; v = take_max ? umax(v, o) : umin(v, o); } } } // ---------- Warp-resident top-K buffer ------------------------------------- // R registers per lane; element e = j*32 + lane; sorted descending by e. template struct WarpBuf; template <> struct WarpBuf<1> { u64 b0; __device__ __forceinline__ void init() { b0 = 0; } // Merge a descending-sorted 32-vector (one per lane) into the buffer. __device__ __forceinline__ void merge_sorted32(u64 v, int lane) { u64 r = __shfl_xor_sync(FULL_MASK, v, 31); // reversed b0 = umax(b0, r); bmerge32(b0, lane); } __device__ __forceinline__ u32 kth_key(int k, int lane) { u64 e = __shfl_sync(FULL_MASK, b0, (k - 1) & 31); return (u32)(e >> 32); } __device__ __forceinline__ void store(u64 *dst, int lane) { dst[lane] = b0; } __device__ __forceinline__ void load(const u64 *src, int lane) { b0 = src[lane]; } // Merge another descending list of 32 (in shared, reversed addressing). __device__ __forceinline__ void merge_list(const u64 *src, int lane) { b0 = umax(b0, src[31 - lane]); bmerge32(b0, lane); } }; template <> struct WarpBuf<2> { u64 b0, b1; __device__ __forceinline__ void init() { b0 = 0; b1 = 0; } __device__ __forceinline__ void merge_sorted32(u64 v, int lane) { // positions 32..63 paired with reversed v u64 r = __shfl_xor_sync(FULL_MASK, v, 31); b1 = umax(b1, r); bmerge64(b0, b1, lane); } __device__ __forceinline__ u32 kth_key(int k, int lane) { u64 e = (k <= 32) ? 
__shfl_sync(FULL_MASK, b0, (k - 1) & 31) : __shfl_sync(FULL_MASK, b1, (k - 1) & 31); return (u32)(e >> 32); } __device__ __forceinline__ void store(u64 *dst, int lane) { dst[lane] = b0; dst[32 + lane] = b1; } __device__ __forceinline__ void load(const u64 *src, int lane) { b0 = src[lane]; b1 = src[32 + lane]; } __device__ __forceinline__ void merge_list(const u64 *src, int lane) { b0 = umax(b0, src[63 - lane]); b1 = umax(b1, src[31 - lane]); bmerge64(b0, b1, lane); } }; // ---------- Streaming select: queue + flush --------------------------------- template struct WarpSelect { WarpBuf buf; u64 q[T]; int cnt; u32 thresh; int k; int lane; __device__ __forceinline__ void init(int k_, int lane_) { buf.init(); cnt = 0; thresh = 0; k = k_; lane = lane_; } __device__ __forceinline__ void flush() { #pragma unroll for (int j = 0; j < T; ++j) { // skip slots empty across the whole warp if (!__ballot_sync(FULL_MASK, cnt > j)) break; u64 v = (j < cnt) ? q[j] : 0ULL; sort32(v, lane); buf.merge_sorted32(v, lane); } cnt = 0; thresh = buf.kth_key(k, lane); } __device__ __forceinline__ void push(u32 key, u32 idx) { if (key > thresh) { q[cnt++] = ((u64)key << 32) | (u64)idx; } if (__ballot_sync(FULL_MASK, cnt == T)) flush(); } // Fast path: check 4 candidates with one ballot when none pass. 
__device__ __forceinline__ void push4(u32 k0, u32 k1, u32 k2, u32 k3, u32 i0) { u32 m = max(max(k0, k1), max(k2, k3)); if (__ballot_sync(FULL_MASK, m > thresh)) { push(k0, i0); push(k1, i0 + 1); push(k2, i0 + 2); push(k3, i0 + 3); } } __device__ __forceinline__ void finish() { if (__ballot_sync(FULL_MASK, cnt > 0)) flush(); } }; // ---------- Kernel ----------------------------------------------------------- template __global__ void __launch_bounds__(W * 32) topk_kernel( const float *__restrict__ x, float *__restrict__ out_vals, long *__restrict__ out_idx, u64 *__restrict__ scratch, // batch * splits * k int *__restrict__ counters, // batch int n, int k, int splits, int chunk, int per_warp, int vec_ok) { constexpr int K32 = 32 * R; __shared__ u64 smem[W * K32]; __shared__ int s_last; const int row = blockIdx.y; const int split = blockIdx.x; const int w = threadIdx.x >> 5; const int lane = threadIdx.x & 31; const float *xr = x + (size_t)row * n; WarpSelect sel; sel.init(k, lane); // ---- Phase A: stream this block's chunk -------------------------------- { const int cbeg = split * chunk; const int cend = min(cbeg + chunk, n); const int wbeg = cbeg + w * per_warp; const int wend = min(wbeg + per_warp, cend); if (wbeg < wend) { const int span = wend - wbeg; const int nfull = vec_ok ? (span >> 7) : 0; // full 128-elem warp steps const float4 *base = reinterpret_cast(xr + wbeg); for (int s = 0; s < nfull; ++s) { float4 v = __ldcs(base + s * 32 + lane); u32 i0 = (u32)(wbeg + ((s * 32 + lane) << 2)); sel.push4(fkey(__float_as_uint(v.x)), fkey(__float_as_uint(v.y)), fkey(__float_as_uint(v.z)), fkey(__float_as_uint(v.w)), i0); } for (int ib = wbeg + (nfull << 7); ib < wend; ib += 32) { const int i = ib + lane; const bool act = i < wend; u32 key = act ? 
fkey(__float_as_uint(__ldg(xr + i))) : 0u; if (__ballot_sync(FULL_MASK, act && key > sel.thresh)) { if (act && key > sel.thresh) sel.q[sel.cnt++] = ((u64)key << 32) | (u64)(u32)i; if (__ballot_sync(FULL_MASK, sel.cnt == T)) sel.flush(); } } } sel.finish(); // block merge sel.buf.store(smem + w * K32, lane); __syncthreads(); if (w == 0) { #pragma unroll 1 for (int ow = 1; ow < W; ++ow) sel.buf.merge_list(smem + ow * K32, lane); } } if (splits == 1) { if (w == 0) { u64 e0 = sel.buf.b0; if (lane < k) { out_vals[(size_t)row * k + lane] = unkey((u32)(e0 >> 32)); out_idx[(size_t)row * k + lane] = (long)(u32)e0; } if constexpr (R == 2) { u64 e1 = sel.buf.b1; int e = 32 + lane; if (e < k) { out_vals[(size_t)row * k + e] = unkey((u32)(e1 >> 32)); out_idx[(size_t)row * k + e] = (long)(u32)e1; } } } return; } // ---- publish candidates to scratch -------------------------------------- u64 *srow = scratch + (size_t)row * splits * k; if (w == 0) { u64 e0 = sel.buf.b0; if (lane < k) srow[(size_t)split * k + lane] = e0; if constexpr (R == 2) { u64 e1 = sel.buf.b1; int e = 32 + lane; if (e < k) srow[(size_t)split * k + e] = e1; } } __syncthreads(); if (threadIdx.x == 0) { cuda::atomic_ref c(counters[row]); int prev = c.fetch_add(1, cuda::memory_order_acq_rel); s_last = (prev == splits - 1) ? 1 : 0; } __syncthreads(); if (!s_last) return; // ---- Phase B: last block merges all candidates -------------------------- { const int C = splits * k; sel.init(k, lane); const int per_w = (C + W - 1) / W; const int beg = w * per_w; const int end = min(beg + per_w, C); for (int i = beg + lane; ; i += 32) { bool act = i < end; u64 it = act ? 
__ldg(srow + i) : 0ULL; if (!__ballot_sync(FULL_MASK, act)) break; u32 key = (u32)(it >> 32); if (__ballot_sync(FULL_MASK, act && key > sel.thresh)) { if (act && key > sel.thresh) sel.q[sel.cnt++] = it; if (__ballot_sync(FULL_MASK, sel.cnt == T)) sel.flush(); } } sel.finish(); __syncthreads(); // smem reuse sel.buf.store(smem + w * K32, lane); __syncthreads(); if (w == 0) { #pragma unroll 1 for (int ow = 1; ow < W; ++ow) sel.buf.merge_list(smem + ow * K32, lane); u64 e0 = sel.buf.b0; if (lane < k) { out_vals[(size_t)row * k + lane] = unkey((u32)(e0 >> 32)); out_idx[(size_t)row * k + lane] = (long)(u32)e0; } if constexpr (R == 2) { u64 e1 = sel.buf.b1; int e = 32 + lane; if (e < k) { out_vals[(size_t)row * k + e] = unkey((u32)(e1 >> 32)); out_idx[(size_t)row * k + e] = (long)(u32)e1; } } } if (threadIdx.x == 0) counters[row] = 0; } } // ---------- Host side -------------------------------------------------------- struct Cfg { const float *x; float *vals; long *idx; u64 *scratch; int *counters; int batch, n, k, splits, chunk, per_warp, threads; int R; }; static std::vector g_cfgs; static void launch(const Cfg &c, cudaStream_t st) { dim3 grid(c.splits, c.batch); dim3 block(c.threads); const int W = c.threads / 32; const int vec_ok = (c.n % 4 == 0) ? 1 : 0; #define LAUNCH(RR, WW) topk_kernel<<>>( \ c.x, c.vals, c.idx, c.scratch, c.counters, c.n, c.k, c.splits, c.chunk, c.per_warp, vec_ok) if (c.R == 1) { switch (W) { case 4: LAUNCH(1, 4); break; case 8: LAUNCH(1, 8); break; case 16: LAUNCH(1, 16); break; default: TORCH_CHECK(false, "bad W"); } } else { switch (W) { case 4: LAUNCH(2, 4); break; case 8: LAUNCH(2, 8); break; case 16: LAUNCH(2, 16); break; default: TORCH_CHECK(false, "bad W"); } } #undef LAUNCH } int64_t topk_configure(int64_t batch, int64_t n, int64_t k, int64_t splits, int64_t threads, int64_t vals_ptr, int64_t idx_ptr) { Cfg c{}; c.batch = (int)batch; c.n = (int)n; c.k = (int)k; c.splits = (int)splits; c.threads = (int)threads; c.R = (k > 32) ? 
2 : 1; c.chunk = (int)(((n + splits - 1) / splits + 3) & ~3LL); // recompute splits so last chunk is non-empty c.splits = (c.n + c.chunk - 1) / c.chunk; int W = c.threads / 32; c.per_warp = ((c.chunk + W - 1) / W + 3) & ~3; c.vals = (float *)vals_ptr; c.idx = (long *)idx_ptr; c.scratch = nullptr; c.counters = nullptr; if (c.splits > 1) { cudaMalloc(&c.scratch, (size_t)c.batch * c.splits * c.k * sizeof(u64)); cudaMalloc(&c.counters, (size_t)c.batch * sizeof(int)); cudaMemset(c.counters, 0, (size_t)c.batch * sizeof(int)); } g_cfgs.push_back(c); return (int64_t)g_cfgs.size() - 1; } void topk_run(int64_t handle, int64_t x_ptr) { Cfg &c = g_cfgs[handle]; c.x = (const float *)x_ptr; launch(c, c10::cuda::getCurrentCUDAStream().stream()); } void topk_run_cached(int64_t handle) { const Cfg &c = g_cfgs[handle]; launch(c, c10::cuda::getCurrentCUDAStream().stream()); } """ def _build(): os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.0") return load_inline( name="topk_sm120_v1", cpp_sources=_CPP_SRC, cuda_sources=_CUDA_SRC, functions=["topk_configure", "topk_run", "topk_run_cached"], extra_cuda_cflags=["-O3", "--use_fast_math", "-gencode=arch=compute_120,code=sm_120"], verbose=False, ) _ext = _build() # (splits, threads) per (batch, n, k); fallback heuristic otherwise. 
_TUNE = {
    # (batch, n, k): (splits, threads)
    (1, 131072, 64): (128, 256),
    (64, 8192, 8): (2, 256),
    (32, 16384, 32): (4, 256),
    (16, 12000, 16): (8, 256),
    (128, 4096, 1): (1, 256),
}

# Limits implied by the kernel / launch code in _CUDA_SRC.
# The host code picks R = (k > 32) ? 2 : 1 and the kernel only ever writes
# output entries `lane` and `32 + lane`, so at most 64 results per row exist.
_MAX_K = 64
# Rows are mapped to blockIdx.y (dim3 grid(splits, batch)); CUDA caps the
# y-dimension of a grid at 65535 and the launch result is not checked.
_MAX_BATCH = 65535


def _pick_cfg(batch: int, n: int, k: int):
    """Return ``(splits, threads)`` for a problem shape.

    A shape listed in ``_TUNE`` gets its hand-tuned entry. Any other shape
    uses 256 threads and a split count that aims at roughly 256 blocks in
    total, capped at one split per 256 elements and floored at 1.
    """
    tuned = _TUNE.get((batch, n, k))
    if tuned is not None:
        return tuned
    target_blocks = 256
    by_occupancy = target_blocks // max(batch, 1)
    by_row_length = (n + 255) // 256
    return max(1, min(by_occupancy, by_row_length)), 256


def _check_problem(batch: int, n: int, k: int) -> None:
    """Raise ``ValueError`` for shapes the extension cannot handle correctly."""
    if not 1 <= batch <= _MAX_BATCH:
        raise ValueError(
            f"batch must be in [1, {_MAX_BATCH}], got {batch}"
        )
    if n < 1:
        # n == 0 makes the host-side chunk size 0, and topk_configure then
        # divides by it when recomputing the split count.
        raise ValueError(f"n must be >= 1, got {n}")
    if not 0 <= k <= min(n, _MAX_K):
        # k > 64 would leave output entries unwritten; k > n would emit the
        # buffer's zero-initialised padding slots as results.
        raise ValueError(
            f"k must be in [0, min(n, {_MAX_K})] = [0, {min(n, _MAX_K)}], got {k}"
        )


class Model(torch.nn.Module):
    """Row-wise top-k backed by the fused CUDA kernel in ``_ext``.

    Output buffers are allocated once in ``__init__`` and reused by every
    call, so the tensors returned by ``forward`` are overwritten by the next
    call.
    """

    def __init__(self, batch: int, n: int, k: int):
        """Allocate output buffers and register the shape with the extension.

        Args:
            batch: number of rows; must be in ``[1, 65535]``.
            n: elements per row; must be at least 1.
            k: entries kept per row; must not exceed ``min(n, 64)``.

        Raises:
            ValueError: if the shape is outside the limits above. The check
                runs before any CUDA allocation.
        """
        _check_problem(batch, n, k)
        super().__init__()
        self.batch, self.n, self.k = batch, n, k
        self.register_buffer("_dummy", torch.zeros(1))
        dev = torch.device("cuda")
        self._vals = torch.empty(batch, k, dtype=torch.float32, device=dev)
        self._idx = torch.empty(batch, k, dtype=torch.int64, device=dev)
        splits, threads = _pick_cfg(batch, n, k)
        # The extension keeps the raw output pointers, so these two tensors
        # must stay alive for as long as the handle is used.
        self._handle = _ext.topk_configure(
            batch, n, k, splits, threads, self._vals.data_ptr(), self._idx.data_ptr()
        )
        self._ret = (self._vals, self._idx)
        self._last_ptr = -1
        # Bound once to skip the attribute lookup on `_ext` in forward().
        self._run = _ext.topk_run
        self._run_cached = _ext.topk_run_cached

    def forward(self, x: torch.Tensor):
        """Launch the kernel on ``x`` and return ``(values, indices)``.

        Only ``x.data_ptr()`` is handed to the extension; nothing here checks
        device, dtype, shape or contiguity. The kernel reads ``x`` as
        ``const float *`` and addresses row ``r`` at ``x + r * n``, so ``x``
        is expected to be a contiguous float32 CUDA tensor of shape
        ``(batch, n)``.

        Returns:
            The same ``(values, indices)`` tuple on every call; its tensors
            are the buffers allocated in ``__init__``.
        """
        p = x.data_ptr()
        if p == self._last_ptr:
            # Same input address as the previous call: the extension already
            # holds the pointer, so use the entry point that does not set it.
            self._run_cached(self._handle)
        else:
            self._last_ptr = p
            self._run(self._handle, p)
        return self._ret

    # Calling the module dispatches straight to forward(), bypassing
    # torch.nn.Module.__call__.
    __call__ = forward


# Module-level shims rebuilt by check.py / benchmark.py per shape.
batch = 64
n = 8192
k = 8


def get_inputs():
    """Return a one-element list holding a random float32 ``(batch, n)`` tensor."""
    x = torch.randn(batch, n, dtype=torch.float32)
    return [x]


def get_init_inputs():
    """Return the ``Model`` constructor arguments ``[batch, n, k]``."""
    return [batch, n, k]