"""Custom top-k kernel for RTX PRO 6000 (SM120 Blackwell).

Tiny inputs (0.5-2MB) => latency bound, ~8us read floor.

Hybrid by k (this matches the dispatch in ``topk_cuda`` below):
  k == 1      : block argmax reduction.
  2 <= k <= 8 : per-thread register top-k (threshold-gated) + pairwise
                tree-merge in shared memory.
  k > 8       : load chunk to shared, bitonic sort descending, keep top-k
                (register arrays of size k spill to local mem for large k).

Rows are split across blocks (phase 1); phase 2 merges the per-block partials.

NOTE(review): this file was recovered from a copy in which every
``<...>`` run had been stripped (include targets, template arguments, launch
configurations, several loop headers).  Everything that could be derived from
the surrounding code was restored; the few values that left no trace are
marked ``NOTE(review)`` and must be confirmed against the original.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

_CUDA = r'''
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math_constants.h>
#include <cooperative_groups.h>
#include <cstdint>
#include <type_traits>
#include <utility>

// The kernels write int64 indices through `long*`; that is only valid where
// long is 64 bits (LP64).
static_assert(sizeof(long) == 8, "kernels assume a 64-bit long");

// Reads and clears the pending launch error and raises on failure.
static inline void check_launch(const char* what) {
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, what, " launch failed: ", cudaGetErrorString(err));
}

static inline bool is_pow2(int v) { return v > 0 && (v & (v - 1)) == 0; }

// ===================== register top-k (small k) =====================
// Inserts (v, idx) into the descending list rv[0..K-1] / ri[0..K-1].
// Values that do not beat the current K-th entry are rejected up front, so a
// value equal to -inf never displaces the initial (-inf, 0) sentinels.
template <int K>
__device__ __forceinline__ void insert(float* rv, int* ri, float v, int idx) {
    if (v <= rv[K-1]) return;
    #pragma unroll
    for (int i = K-1; i > 0; --i) {
        if (rv[i-1] < v) { rv[i] = rv[i-1]; ri[i] = ri[i-1]; }
        else { rv[i] = v; ri[i] = idx; return; }
    }
    rv[0] = v; ri[0] = idx;
}

// Grid: (bpr, rows). Block: 1-D, blockDim.x must be a power of two.
// Dynamic shared memory: blockDim.x * K * (sizeof(float) + sizeof(int)).
// Block b of a row scans elements [b*chunk, min((b+1)*chunk, nrow)) with
// chunk = ceil(nrow / bpr), keeps a per-thread top-K in registers, then
// tree-merges the per-thread lists in shared memory.
// srcidx == nullptr : the index of an element is its position in the row.
// srcidx != nullptr : the index is read from srcidx (merging partials).
// final_direct != 0 : write K results per row to outv/outi (int64 indices);
// otherwise write them to pv/pi at ((row*bpr + b)*K + i).
template <int K>
__global__ void regmerge(const float* __restrict__ src, const int* __restrict__ srcidx,
                         float* __restrict__ pv, int* __restrict__ pi,
                         float* __restrict__ outv, long* __restrict__ outi,
                         int nrow, int bpr, int final_direct) {
    extern __shared__ char smem[];
    float* sv = (float*)smem;
    int* si = (int*)(sv + blockDim.x * K);
    const int tid = threadIdx.x;
    const int row = blockIdx.y;
    const int b = blockIdx.x;
    const int BD = blockDim.x;
    int chunk = (nrow + bpr - 1) / bpr;
    long start = (long)b * chunk;
    long end = start + chunk;
    if (end > nrow) end = nrow;

    float rv[K]; int ri[K];
    #pragma unroll
    for (int i = 0; i < K; ++i) { rv[i] = -CUDART_INF_F; ri[i] = 0; }

    const float* rx = src + (long)row * nrow;
    if (srcidx == nullptr) {
        for (long i = start + tid; i < end; i += BD) insert<K>(rv, ri, rx[i], (int)i);
    } else {
        const int* rxi = srcidx + (long)row * nrow;
        for (long i = start + tid; i < end; i += BD) insert<K>(rv, ri, rx[i], rxi[i]);
    }

    #pragma unroll
    for (int i = 0; i < K; ++i) { sv[tid*K + i] = rv[i]; si[tid*K + i] = ri[i]; }
    __syncthreads();

    // Pairwise merge of two sorted K-lists; the result is staged in registers
    // (tv/ti) because it overwrites list `a` in place.
    for (int stride = BD >> 1; stride > 0; stride >>= 1) {
        if (tid < stride) {
            int a = tid * K, bb = (tid + stride) * K;
            float tv[K]; int ti[K];
            int ia = 0, ib = 0;
            #pragma unroll
            for (int o = 0; o < K; ++o) {
                float av = sv[a+ia], bv = sv[bb+ib];
                if (av >= bv) { tv[o] = av; ti[o] = si[a+ia]; ia++; }
                else { tv[o] = bv; ti[o] = si[bb+ib]; ib++; }
            }
            #pragma unroll
            for (int o = 0; o < K; ++o) { sv[a+o] = tv[o]; si[a+o] = ti[o]; }
        }
        __syncthreads();
    }

    if (final_direct) {
        for (int i = tid; i < K; i += BD) {
            outv[(long)row*K + i] = sv[i];
            outi[(long)row*K + i] = (long)si[i];
        }
    } else {
        for (int i = tid; i < K; i += BD) {
            long o = ((long)row*bpr + b)*K + i;
            pv[o] = sv[i]; pi[o] = si[i];
        }
    }
}

// ===== register top-k + scalar double-buffer tree-merge (large k, no spill) =====
// Same contract as regmerge, but merges between two shared buffers instead of
// staging through registers.
// Dynamic shared memory: 2 * blockDim.x * K * (sizeof(float) + sizeof(int)).
// Not launched by the host code in this file.
// NOTE(review): the loop headers, the shared-memory store, the merge operand
// offsets (a / bb / d) and the output tail of this kernel were lost in
// extraction and are reconstructed by analogy with regmerge; confirm against
// the original before instantiating it.
template <int K>
__global__ void regmerge_db(const float* __restrict__ src, const int* __restrict__ srcidx,
                            float* __restrict__ pv, int* __restrict__ pi,
                            float* __restrict__ outv, long* __restrict__ outi,
                            int nrow, int bpr, int final_direct) {
    extern __shared__ char smem[];
    const int BD = blockDim.x;
    float* av = (float*)smem;
    int* ai = (int*)(av + BD*K);
    float* bv = (float*)(ai + BD*K);
    int* bi = (int*)(bv + BD*K);
    const int tid = threadIdx.x;
    const int row = blockIdx.y, b = blockIdx.x;
    int chunk = (nrow + bpr - 1) / bpr;
    long start = (long)b*chunk, end = start+chunk;
    if (end > nrow) end = nrow;

    float rv[K]; int ri[K];
    #pragma unroll
    for (int i = 0; i < K; ++i) { rv[i] = -CUDART_INF_F; ri[i] = 0; }

    const float* rx = src + (long)row*nrow;
    if (srcidx == nullptr) {
        for (long i = start+tid; i < end; i += BD) insert<K>(rv, ri, rx[i], (int)i);
    } else {
        const int* rxi = srcidx + (long)row*nrow;
        for (long i = start+tid; i < end; i += BD) insert<K>(rv, ri, rx[i], rxi[i]);
    }

    #pragma unroll
    for (int i = 0; i < K; ++i) { av[tid*K + i] = rv[i]; ai[tid*K + i] = ri[i]; }
    __syncthreads();

    float* sv = av; int* si = ai;   // current source buffers
    float* dv = bv; int* di = bi;   // current destination buffers
    for (int nL = BD; nL > 1; nL >>= 1) {
        int half = nL >> 1;
        if (tid < half) {
            int a = tid * K, bb = (tid + half) * K, d = tid * K;
            int ia = 0, ib = 0;
            for (int o = 0; o < K; ++o) {
                float x = sv[a+ia], y = sv[bb+ib];
                if (x >= y) { dv[d+o] = x; di[d+o] = si[a+ia]; ia++; }
                else { dv[d+o] = y; di[d+o] = si[bb+ib]; ib++; }
            }
        }
        __syncthreads();
        float* t1 = sv; sv = dv; dv = t1;
        int* t2 = si; si = di; di = t2;
    }

    if (final_direct) {
        for (int i = tid; i < K; i += BD) {
            outv[(long)row*K + i] = sv[i];
            outi[(long)row*K + i] = (long)si[i];
        }
    } else {
        for (int i = tid; i < K; i += BD) {
            long o = ((long)row*bpr + b)*K + i;
            pv[o] = sv[i]; pi[o] = si[i];
        }
    }
}

// ===================== bitonic (large k) =====================
// Values are mapped to unsigned keys whose integer order matches float order
// (negative floats are bit-inverted, non-negative floats get the top bit set)
// and packed with the index into one 64-bit word, so a single integer
// compare sorts by value and, among equal values, by larger index first.
__device__ __forceinline__ unsigned f2ord(float f) {
    unsigned b = __float_as_uint(f);
    return (b & 0x80000000u) ? ~b : (b | 0x80000000u);
}
__device__ __forceinline__ float ord2f(unsigned o) {
    unsigned b = (o & 0x80000000u) ? (o & 0x7fffffffu) : ~o;
    return __uint_as_float(b);
}
__device__ __forceinline__ unsigned long long pack(float v, int idx) {
    return ((unsigned long long)f2ord(v) << 32) | (unsigned)idx;
}
// Padding key. 0 is below the key of every non-NaN float, -inf included
// (f2ord(-inf) == 0x007FFFFF), so padding always sorts last. It decodes to a
// NaN with index 0, which is only visible if a row has fewer than k elements.
#define NEG_KEY ((unsigned long long)0)

// In-place descending bitonic sort of s[0..N) by the whole block.
// N must be a power of two; every thread of the block must call this.
__device__ __forceinline__ void bitonic_desc(unsigned long long* s, int N) {
    for (int k = 2; k <= N; k <<= 1) {
        for (int j = k >> 1; j > 0; j >>= 1) {
            for (int i = threadIdx.x; i < N; i += blockDim.x) {
                int ixj = i ^ j;
                if (ixj > i) {
                    bool up = ((i & k) == 0);
                    unsigned long long a = s[i], b = s[ixj];
                    bool sw = up ? (a < b) : (a > b);
                    if (sw) { s[i] = b; s[ixj] = a; }
                }
            }
            __syncthreads();
        }
    }
}

// Grid: (bpr, rows). Block: 1-D. Dynamic shared memory: CHUNK * 8 bytes.
// CHUNK must be a power of two and >= k. Block b sorts elements
// [b*CHUNK, (b+1)*CHUNK) of its row (positions >= nrow are padded with
// NEG_KEY) and emits the first k. The srcidx / final_direct / pv / pi /
// outv / outi conventions are the same as for regmerge.
__global__ void bitonic_kernel(const float* __restrict__ src, const int* __restrict__ srcidx,
                               float* __restrict__ pv, int* __restrict__ pi,
                               float* __restrict__ outv, long* __restrict__ outi,
                               int nrow, int k, int CHUNK, int bpr, int final_direct) {
    extern __shared__ char smem[];
    unsigned long long* s = (unsigned long long*)smem;
    const int row = blockIdx.y, b = blockIdx.x;
    long start = (long)b * CHUNK;
    const float* rx = src + (long)row * nrow;
    if (srcidx == nullptr) {
        for (int i = threadIdx.x; i < CHUNK; i += blockDim.x) {
            long gi = start + i;
            s[i] = (gi < nrow) ? pack(rx[gi], (int)gi) : NEG_KEY;
        }
    } else {
        const int* rxi = srcidx + (long)row * nrow;
        for (int i = threadIdx.x; i < CHUNK; i += blockDim.x) {
            long gi = start + i;
            s[i] = (gi < nrow) ? pack(rx[gi], rxi[gi]) : NEG_KEY;
        }
    }
    __syncthreads();
    bitonic_desc(s, CHUNK);
    if (final_direct) {
        for (int i = threadIdx.x; i < k; i += blockDim.x) {
            outv[(long)row*k+i] = ord2f((unsigned)(s[i] >> 32));
            outi[(long)row*k+i] = (long)(unsigned)s[i];
        }
    } else {
        for (int i = threadIdx.x; i < k; i += blockDim.x) {
            long o = ((long)row*bpr + b)*k + i;
            pv[o] = ord2f((unsigned)(s[i] >> 32));
            pi[o] = (int)(unsigned)s[i];
        }
    }
}

// ===================== cooperative single-launch (single row, large k) =====
namespace cg = cooperative_groups;

// Smallest power of two >= v (returns 1 for v <= 1).
__device__ __forceinline__ int dpow2(int v) { int p = 1; while (p < v) p <<= 1; return p; }

// Phase A: each block bitonic-sorts its chunk (-> top k), then grid-sync'd
// reduce passes, all in-kernel. Row = blockIdx.y.
// Must be launched with cudaLaunchCooperativeKernel (uses grid.sync()).
// Dynamic shared memory must hold max(dpow2(ceil(n / gridDim.x)), RC) 64-bit
// keys. bufA / bufB must each hold rows * gridDim.x * k keys.
__global__ void coop_topk(const float* __restrict__ x, int n, int k, int RC,
                          unsigned long long* __restrict__ bufA,
                          unsigned long long* __restrict__ bufB,
                          float* __restrict__ outv, long* __restrict__ outi) {
    cg::grid_group grid = cg::this_grid();
    extern __shared__ unsigned long long sm[];
    const int G = gridDim.x;
    const int row = blockIdx.y;
    const int t = threadIdx.x, BD = blockDim.x;
    unsigned long long* A = bufA + (long)row * G * k;
    unsigned long long* B = bufB + (long)row * G * k;
    const float* rx = x + (long)row * n;

    // phase A
    int chunk = (n + G - 1) / G;
    int CP = dpow2(chunk);
    long start = (long)blockIdx.x * chunk;
    for (int i = t; i < CP; i += BD) {
        long gi = start + i;
        sm[i] = (i < chunk && gi < n) ? pack(rx[gi], (int)gi) : 0ULL;
    }
    __syncthreads();
    bitonic_desc(sm, CP);
    for (int i = t; i < k; i += BD) A[(long)blockIdx.x * k + i] = sm[i];
    grid.sync();

    // reduce passes
    int m = G * k;
    unsigned long long* cur = A;
    unsigned long long* alt = B;
    while (m > k) {
        int groups = (m + RC - 1) / RC;
        int cchunk = (m + groups - 1) / groups;
        int CP2 = dpow2(cchunk);
        // The branch is uniform per block, so the barriers inside are safe.
        if (blockIdx.x < groups) {
            long st = (long)blockIdx.x * cchunk;
            for (int i = t; i < CP2; i += BD) {
                long gi = st + i;
                sm[i] = (i < cchunk && gi < m) ? cur[gi] : 0ULL;
            }
            __syncthreads();
            bitonic_desc(sm, CP2);
            for (int i = t; i < k; i += BD) alt[(long)blockIdx.x * k + i] = sm[i];
        }
        grid.sync();
        m = groups * k;
        unsigned long long* tmp = cur; cur = alt; alt = tmp;
    }
    if (blockIdx.x == 0) {
        for (int i = t; i < k; i += BD) {
            outv[(long)row*k + i] = ord2f((unsigned)(cur[i] >> 32));
            outi[(long)row*k + i] = (long)(unsigned)cur[i];
        }
    }
}

// ===================== argmax (k=1) =====================
// NOTE(review): the launch configuration of the three argmax kernels was lost
// in extraction. Any power-of-two block size <= 1024 (the size of their static
// shared arrays) is functionally correct; 256 is a placeholder for the
// original tuning value.
constexpr int ARGMAX_BD = 256;
static_assert(ARGMAX_BD > 0 && ARGMAX_BD <= 1024 && (ARGMAX_BD & (ARGMAX_BD - 1)) == 0,
              "argmax block size must be a power of two <= 1024");

// Grid: (bpr, rows). Block: power of two <= 1024.
// Block b reduces elements [b*chunk, min((b+1)*chunk, n)) of its row and
// writes one (value, index) partial to pv/pi at row*bpr + b.
__global__ void argmax1(const float* __restrict__ x, float* __restrict__ pv,
                        int* __restrict__ pi, int n, int chunk, int bpr) {
    __shared__ float sv[1024];
    __shared__ int si[1024];
    const int row = blockIdx.y, b = blockIdx.x, t = threadIdx.x;
    long start = (long)b*chunk, end = start+chunk;
    if (end > n) end = n;
    const float* rx = x + (long)row*n;
    float best = -CUDART_INF_F; int bidx = 0;
    for (long i = start+t; i < end; i += blockDim.x) {
        float v = rx[i];
        if (v > best) { best = v; bidx = (int)i; }
    }
    sv[t] = best; si[t] = bidx;
    __syncthreads();
    for (int s = blockDim.x >> 1; s > 0; s >>= 1) {
        if (t < s && sv[t+s] > sv[t]) { sv[t] = sv[t+s]; si[t] = si[t+s]; }
        __syncthreads();
    }
    if (t == 0) { long o = (long)row*bpr + b; pv[o] = sv[0]; pi[o] = si[0]; }
}

// one block per row, full reduction, write int64 directly
// Grid: (rows). Block: power of two <= 1024.
__global__ void argmax_single(const float* __restrict__ x, float* __restrict__ outv,
                              long* __restrict__ outi, int n) {
    __shared__ float sv[1024];
    __shared__ int si[1024];
    const int row = blockIdx.x, t = threadIdx.x;
    const float* rx = x + (long)row*n;
    float best = -CUDART_INF_F; int bidx = 0;
    for (long i = t; i < n; i += blockDim.x) {
        float v = rx[i];
        if (v > best) { best = v; bidx = (int)i; }
    }
    sv[t] = best; si[t] = bidx;
    __syncthreads();
    for (int s = blockDim.x >> 1; s > 0; s >>= 1) {
        if (t < s && sv[t+s] > sv[t]) { sv[t] = sv[t+s]; si[t] = si[t+s]; }
        __syncthreads();
    }
    if (t == 0) { outv[row] = sv[0]; outi[row] = (long)si[0]; }
}

// Reduces the bpr partials written by argmax1 to one result per row.
// Grid: (rows). Block: power of two <= 1024.
__global__ void argmax2(const float* __restrict__ pv, const int* __restrict__ pi,
                        float* __restrict__ outv, long* __restrict__ outi, int bpr) {
    __shared__ float sv[1024];
    __shared__ int si[1024];
    const int row = blockIdx.x, t = threadIdx.x;
    const float* bv = pv + (long)row*bpr;
    const int* bi = pi + (long)row*bpr;
    float best = -CUDART_INF_F; int bidx = 0;
    for (int i = t; i < bpr; i += blockDim.x) {
        if (bv[i] > best) { best = bv[i]; bidx = bi[i]; }
    }
    sv[t] = best; si[t] = bidx;
    __syncthreads();
    for (int s = blockDim.x >> 1; s > 0; s >>= 1) {
        if (t < s && sv[t+s] > sv[t]) { sv[t] = sv[t+s]; si[t] = si[t+s]; }
        __syncthreads();
    }
    if (t == 0) { outv[row] = sv[0]; outi[row] = (long)si[0]; }
}

static int g_numSM = -1;

// Cooperative single-launch top-k for a single row (batch==1), large k.
// Returns 0 on success, nonzero (cuda error) so caller can fall back.
// A failed launch has its pending error cleared here, so the fallback path
// does not trip over it.
int topk_coop(torch::Tensor x, int k, int RC,
              torch::Tensor outv, torch::Tensor outi,
              torch::Tensor bufA, torch::Tensor bufB) {
    int batch = x.size(0);
    int n = x.size(1);
    auto stream = at::cuda::getCurrentCUDAStream();
    cudaError_t err = cudaSuccess;
    if (g_numSM < 0) {
        // Query the current device rather than assuming device 0.
        int dev = 0;
        err = cudaGetDevice(&dev);
        if (err == cudaSuccess)
            err = cudaDeviceGetAttribute(&g_numSM, cudaDevAttrMultiProcessorCount, dev);
        if (err != cudaSuccess) { g_numSM = -1; (void)cudaGetLastError(); return (int)err; }
    }
    const int BD = 256;
    auto np2 = [](int v){ int p = 1; while (p < v) p <<= 1; return p; };
    // NOTE(review): the expression for `target` was lost in extraction. The
    // Python side sizes bufA/bufB for ceil(n / 256) blocks per row, and the
    // kernel writes gridDim.x * k keys per row, so G must not exceed that.
    int target = (n + 255) / 256;
    int SHP = RC > 512 ? RC : 512;
    int maxbpsm = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxbpsm, coop_topk, BD, (size_t)SHP*8);
    int maxG = maxbpsm * g_numSM / batch;
    int G = target < maxG ? target : maxG;
    if (G < 1) G = 1;
    int chunkA = (n + G - 1) / G;
    int CP = np2(chunkA);
    SHP = CP > RC ? CP : RC;
    cudaFuncSetAttribute(coop_topk, cudaFuncAttributeMaxDynamicSharedMemorySize, (size_t)SHP*8);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxbpsm, coop_topk, BD, (size_t)SHP*8);
    maxG = maxbpsm * g_numSM / batch;
    if (G > maxG) {
        G = maxG;
        if (G < 1) G = 1;
        chunkA = (n + G - 1) / G;
        CP = np2(chunkA);
        SHP = CP > RC ? CP : RC;
        cudaFuncSetAttribute(coop_topk, cudaFuncAttributeMaxDynamicSharedMemorySize, (size_t)SHP*8);
    }
    // Drop any error left by the best-effort attribute/occupancy calls above;
    // a configuration that cannot run is reported by the launch itself.
    (void)cudaGetLastError();

    const float* X = x.data_ptr<float>();
    float* OV = outv.data_ptr<float>();
    long* OI = reinterpret_cast<long*>(outi.data_ptr<int64_t>());
    unsigned long long* A = (unsigned long long*)bufA.data_ptr();
    unsigned long long* B = (unsigned long long*)bufB.data_ptr();
    void* kargs[] = { (void*)&X, (void*)&n, (void*)&k, (void*)&RC,
                      (void*)&A, (void*)&B, (void*)&OV, (void*)&OI };
    dim3 grid(G, batch);
    dim3 block(BD);
    err = cudaLaunchCooperativeKernel((void*)coop_topk, grid, block, kargs,
                                      (size_t)SHP*8, stream);
    if (err != cudaSuccess) (void)cudaGetLastError();
    return (int)err;
}

// Top-k along dim 1 of a contiguous float32 CUDA tensor x of shape (rows, n).
//   k             : number of results per row, 1 <= k <= n.
//   bpr           : blocks per row in phase 1.
//   BD / BD2      : block sizes of phase 1 / phase 2 (register path: powers
//                   of two; bitonic path: BD only, 1..1024).
//   CHUNK / CHUNK2: bitonic path only - phase-1 chunk (power of two, >= k,
//                   bpr*CHUNK >= n) and reduce-chunk target (>= 2k whenever a
//                   reduce pass is needed).
//   outv / outi   : (rows, k) float32 / int64 outputs.
//   sA* / sB*     : float32 / int32 scratch, >= rows*bpr*k elements each.
// All work is enqueued on the current stream; nothing is synchronized.
void topk_cuda(torch::Tensor x, int k, int bpr, int BD, int BD2, int CHUNK, int CHUNK2,
               torch::Tensor outv, torch::Tensor outi,
               torch::Tensor sAv, torch::Tensor sAi,
               torch::Tensor sBv, torch::Tensor sBi) {
    // The kernels index raw pointers, so reject anything they would misread.
    TORCH_CHECK(x.is_cuda(), "topk_cuda: x must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 2, "topk_cuda: x must be 2-D (rows, n)");
    TORCH_CHECK(x.scalar_type() == at::kFloat, "topk_cuda: x must be float32");
    TORCH_CHECK(x.is_contiguous(), "topk_cuda: x must be contiguous");
    int batch = x.size(0);
    int n = x.size(1);
    TORCH_CHECK(k >= 1 && k <= n, "topk_cuda: need 1 <= k <= n, got k=", k, ", n=", n);
    TORCH_CHECK(bpr >= 1, "topk_cuda: bpr must be >= 1");
    TORCH_CHECK(outv.numel() >= (int64_t)batch * k && outi.numel() >= (int64_t)batch * k,
                "topk_cuda: output buffers too small");
    if (batch == 0) return;

    auto stream = at::cuda::getCurrentCUDAStream();
    float* OV = outv.data_ptr<float>();
    long* OI = reinterpret_cast<long*>(outi.data_ptr<int64_t>());
    float* Av = sAv.data_ptr<float>();
    int* Ai = sAi.data_ptr<int>();
    float* Bv = sBv.data_ptr<float>();
    int* Bi = sBi.data_ptr<int>();
    const float* X = x.data_ptr<float>();

    // Capacity check for the partial-result scratch buffers.
    auto need_scratch = [&](int64_t count, bool both) {
        TORCH_CHECK(sAv.numel() >= count && sAi.numel() >= count,
                    "topk_cuda: scratch A too small");
        TORCH_CHECK(!both || (sBv.numel() >= count && sBi.numel() >= count),
                    "topk_cuda: scratch B too small");
    };

    if (k == 1) {
        if (bpr == 1) {
            argmax_single<<<batch, ARGMAX_BD, 0, stream>>>(X, OV, OI, n);
            check_launch("argmax_single");
            return;
        }
        need_scratch((int64_t)batch * bpr, false);
        int chunk = (n + bpr - 1) / bpr;
        dim3 grid(bpr, batch);
        argmax1<<<grid, ARGMAX_BD, 0, stream>>>(X, Av, Ai, n, chunk, bpr);
        check_launch("argmax1");
        argmax2<<<batch, ARGMAX_BD, 0, stream>>>(Av, Ai, OV, OI, bpr);
        check_launch("argmax2");
        return;
    }

    if (k <= 8) {
        TORCH_CHECK(is_pow2(BD) && BD <= 1024, "topk_cuda: BD must be a power of two <= 1024");
        auto run = [&](auto kc) {
            constexpr int K = decltype(kc)::value;
            size_t sh1 = (size_t)BD * K * (sizeof(float)+sizeof(int));
            if (bpr == 1) {
                dim3 grid(1, batch);
                regmerge<K><<<grid, BD, sh1, stream>>>(X, nullptr, nullptr, nullptr,
                                                       OV, OI, n, 1, 1);
                check_launch("regmerge");
                return;
            }
            TORCH_CHECK(is_pow2(BD2) && BD2 <= 1024,
                        "topk_cuda: BD2 must be a power of two <= 1024");
            need_scratch((int64_t)batch * bpr * K, false);
            size_t sh2 = (size_t)BD2 * K * (sizeof(float)+sizeof(int));
            dim3 grid(bpr, batch);
            regmerge<K><<<grid, BD, sh1, stream>>>(X, nullptr, Av, Ai,
                                                   nullptr, nullptr, n, bpr, 0);
            check_launch("regmerge phase 1");
            int m = bpr*K;
            dim3 grid2(1, batch);
            regmerge<K><<<grid2, BD2, sh2, stream>>>(Av, Ai, nullptr, nullptr,
                                                     OV, OI, m, 1, 1);
            check_launch("regmerge phase 2");
        };
        // regmerge<K> emits exactly K results per row at stride K, so K has to
        // equal k; rounding k up to a larger K would corrupt the (rows, k)
        // output layout.
        switch (k) {
            case 2: run(std::integral_constant<int, 2>{}); break;
            case 3: run(std::integral_constant<int, 3>{}); break;
            case 4: run(std::integral_constant<int, 4>{}); break;
            case 5: run(std::integral_constant<int, 5>{}); break;
            case 6: run(std::integral_constant<int, 6>{}); break;
            case 7: run(std::integral_constant<int, 7>{}); break;
            default: run(std::integral_constant<int, 8>{}); break;
        }
        return;
    }

    // large k: bitonic load-to-shared, multi-pass reduction.
    TORCH_CHECK(is_pow2(CHUNK) && k <= CHUNK,
                "topk_cuda: CHUNK must be a power of two >= k");
    TORCH_CHECK((int64_t)bpr * CHUNK >= n, "topk_cuda: bpr * CHUNK must cover n");
    TORCH_CHECK(BD >= 1 && BD <= 1024, "topk_cuda: BD must be in [1, 1024]");
    auto np2 = [](int v){ int p = 1; while (p < v) p <<= 1; return p; };

    // One bitonic pass. Dynamic shared memory above the 48 KB default needs
    // an explicit opt-in, which also reports sizes the device cannot provide.
    auto launch = [&](dim3 grid, int bd, int chunk,
                      const float* s, const int* sidx, float* pv, int* pi,
                      float* ov, long* oi, int nrow, int blocks, int final_direct) {
        size_t sh = (size_t)chunk * sizeof(unsigned long long);
        if (sh > 48 * 1024) {
            cudaError_t e = cudaFuncSetAttribute(
                bitonic_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)sh);
            TORCH_CHECK(e == cudaSuccess, "topk_cuda: cannot reserve ", sh,
                        " bytes of shared memory: ", cudaGetErrorString(e));
        }
        bitonic_kernel<<<grid, bd, sh, stream>>>(s, sidx, pv, pi, ov, oi,
                                                 nrow, k, chunk, blocks, final_direct);
        check_launch("bitonic_kernel");
    };

    if (bpr == 1) {
        // NOTE(review): the body of this branch was lost in extraction; it is
        // reconstructed from the surviving argument list
        // (X, nullptr, nullptr, nullptr, OV, OI, n, k, CHUNK, 1, 1).
        launch(dim3(1, batch), BD, CHUNK, X, nullptr, nullptr, nullptr, OV, OI, n, 1, 1);
        return;
    }

    need_scratch((int64_t)batch * bpr * k, true);
    launch(dim3(bpr, batch), BD, CHUNK, X, nullptr, Av, Ai, nullptr, nullptr, n, bpr, 0);

    int m = bpr * k;
    const int RC = CHUNK2;
    // Each reduce pass maps m candidates to ceil(m / RC) * k; that only
    // shrinks (and each reduce chunk only holds k entries) when RC >= 2k.
    TORCH_CHECK(m <= RC || RC >= 2 * k, "topk_cuda: CHUNK2 must be >= 2*k");
    float *curV = Av; int *curI = Ai;
    float *altV = Bv; int *altI = Bi;
    while (m > RC) {
        int g = (m + RC - 1) / RC;
        int chunk = np2((m + g - 1) / g);
        int bd = chunk < 1024 ? chunk : 1024;
        launch(dim3(g, batch), bd, chunk, curV, curI, altV, altI,
               nullptr, nullptr, m, g, 0);
        m = g * k;
        std::swap(curV, altV);
        std::swap(curI, altI);
    }
    {
        int chunk = np2(m);
        int bd = chunk < 1024 ? chunk : 1024;
        launch(dim3(1, batch), bd, chunk, curV, curI, nullptr, nullptr, OV, OI, m, 1, 1);
    }
}
'''

_CPP = ("void topk_cuda(torch::Tensor x, int k, int bpr, int BD, int BD2, int CHUNK, int CHUNK2, "
        "torch::Tensor outv, torch::Tensor outi, torch::Tensor sAv, torch::Tensor sAi, "
        "torch::Tensor sBv, torch::Tensor sBi);\n"
        "int topk_coop(torch::Tensor x, int k, int RC, torch::Tensor outv, torch::Tensor outi, "
        "torch::Tensor bufA, torch::Tensor bufB);")

_mod = load_inline(
    name="topk_hybrid",
    cpp_sources=_CPP,
    cuda_sources=_CUDA,
    functions=["topk_cuda", "topk_coop"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

# Largest dynamic shared-memory size (bytes) the bitonic path may request; the
# original host code opted bitonic_kernel in to exactly this many bytes.
_MAX_DYN_SMEM = 101376


def _nextpow2(v):
    """Return the smallest power of two >= v (1 for v <= 1)."""
    p = 1
    while p < v:
        p <<= 1
    return p


def _floorpow2(v):
    """Return the largest power of two <= v (1 for v <= 1)."""
    p = 1
    while p * 2 <= v:
        p <<= 1
    return p


class Model(nn.Module):
    """Row-wise top-k over a (batch, n) float32 CUDA tensor.

    Launch parameters and all scratch/output buffers are fixed at construction
    for the given (batch, n, k).  ``forward`` returns the persistent output
    buffers themselves, so every call overwrites the result of the previous
    one; clone the outputs if they must outlive the next call.
    """

    def __init__(self, batch: int, n: int, k: int):
        """Pick the kernel family and launch parameters for (batch, n, k).

        Raises:
            ValueError: if not 1 <= k <= n, or if k is too large for the
                bitonic path's shared-memory budget.
        """
        super().__init__()
        if not 1 <= k <= n:
            raise ValueError(f"need 1 <= k <= n, got k={k}, n={n}")
        self.batch, self.n, self.k = batch, n, k
        self.register_buffer("_dummy", torch.zeros(1))
        self.BD = self.BD2 = self.CHUNK = self.CHUNK2 = 0
        self.use_coop = False

        if k == 1:
            # 188 is presumably the SM count of the target GPU - TODO confirm.
            self.bpr = max(1, min(round(188 / batch), (n + 1023) // 1024))
            self._alloc()
            return

        if k <= 8:
            # register top-k + tree merge; BD=128 sweet spot
            BD = 128
            bpr = max(1, round(n / 1536))
            # Keep the phase-2 merge input (bpr * k partials) at most 2048.
            while bpr * k > 2048 and bpr > 1:
                bpr -= 1
            # Every block should get at least one element per thread.
            while bpr > 1 and (n + bpr - 1) // bpr < BD:
                bpr -= 1
            m = bpr * k
            self.BD = BD
            self.bpr = bpr
            self.BD2 = _floorpow2(max(32, min(1024, (m + k - 1) // k)))
            self._alloc()
            return

        # bitonic, multi-pass reduction. Small phase1 chunk => many
        # blocks (fill SMs) + less total sort work; reduce passes shrink cands.
        # The phase-1 chunk must hold k entries and the reduce chunk 2k (else
        # the reduce loop cannot shrink), so both grow with k. For k <= 256
        # these are the original fixed values 256 / 512.
        kp2 = _nextpow2(k)
        CHUNK1 = max(256, kp2)
        CHUNK2 = max(512, 2 * kp2)
        if CHUNK2 * 8 > _MAX_DYN_SMEM:
            raise ValueError(
                f"k={k} needs {CHUNK2 * 8} bytes of shared memory per block, "
                f"more than the {_MAX_DYN_SMEM} available")
        bpr = max(1, (n + CHUNK1 - 1) // CHUNK1)
        self.bpr = bpr
        self.CHUNK = CHUNK1
        self.BD = min(1024, CHUNK1)
        self.CHUNK2 = CHUNK2  # reduce-chunk target (RC)
        self.BD2 = 0
        self._alloc()

        # single-row large-k: cooperative single-launch kernel was slower than
        # multi-pass relaunch (grid.sync barrier across many idle blocks), disabled.
        self.use_coop = False
        if self.use_coop:
            self.RC = 512
            Gmax = (n + 255) // 256
            cap = Gmax * k  # ull buffer per row (batch==1)
            self.coopA = torch.empty(cap, dtype=torch.int64, device="cuda")
            self.coopB = torch.empty(cap, dtype=torch.int64, device="cuda")
            self._coop_ok = True

    def _alloc(self):
        """Allocate the scratch and output buffers once, on the CUDA device."""
        # preallocate persistent scratch + output buffers (avoid per-call alloc)
        dev = torch.device("cuda")
        cap = max(self.batch * self.bpr * self.k, self.batch * self.k)
        self.sAv = torch.empty(cap, dtype=torch.float32, device=dev)
        self.sAi = torch.empty(cap, dtype=torch.int32, device=dev)
        self.sBv = torch.empty(cap, dtype=torch.float32, device=dev)
        self.sBi = torch.empty(cap, dtype=torch.int32, device=dev)
        self.outv = torch.empty(self.batch, self.k, dtype=torch.float32, device=dev)
        self.outi = torch.empty(self.batch, self.k, dtype=torch.int64, device=dev)

    def forward(self, x: torch.Tensor):
        """Return (values, indices), each (batch, k), for the top-k of each row.

        The returned tensors are this module's persistent buffers.
        """
        # The kernels index the raw pointer with a dense row stride of n.
        if not x.is_contiguous():
            x = x.contiguous()
        if self.use_coop and self._coop_ok:
            err = _mod.topk_coop(x, self.k, self.RC, self.outv, self.outi,
                                 self.coopA, self.coopB)
            if err == 0:
                return self.outv, self.outi
            self._coop_ok = False  # fall back permanently
        _mod.topk_cuda(x, self.k, self.bpr, self.BD, self.BD2, self.CHUNK, self.CHUNK2,
                       self.outv, self.outi, self.sAv, self.sAi, self.sBv, self.sBi)
        return self.outv, self.outi


def get_inputs():
    """Return one random (batch, n) float32 input, created on the CPU."""
    x = torch.randn(batch, n, dtype=torch.float32)
    return [x]


def get_init_inputs():
    """Return the constructor arguments for Model."""
    return [batch, n, k]


batch = 64
n = 8192
k = 8