"""Custom CUDA top-k via streaming selection + block merge (SM120).""" from __future__ import annotations import torch import torch.nn as nn from torch.utils.cpp_extension import load_inline CUDA_SRC = r""" #include #include #include #include #include #define WARP_SIZE 32 __device__ __forceinline__ void insert_desc(int k, float val, int64_t idx, float* vals, int64_t* idxs) { if (val <= vals[k - 1]) return; int pos = k - 1; while (pos > 0 && val > vals[pos - 1]) { vals[pos] = vals[pos - 1]; idxs[pos] = idxs[pos - 1]; --pos; } vals[pos] = val; idxs[pos] = idx; } template __device__ __forceinline__ void insert_desc(float val, int64_t idx, float* vals, int64_t* idxs) { insert_desc(K, val, idx, vals, idxs); } template __device__ void merge_topk(const float* a_val, const int64_t* a_idx, const float* b_val, const int64_t* b_idx, float* out_val, int64_t* out_idx) { int ai = 0, bi = 0, oi = 0; while (oi < K) { if (ai < K && (bi >= K || a_val[ai] > b_val[bi])) { out_val[oi] = a_val[ai]; out_idx[oi] = a_idx[ai]; ++ai; } else { out_val[oi] = b_val[bi]; out_idx[oi] = b_idx[bi]; ++bi; } ++oi; } } template __host__ __device__ constexpr int threads_for_k() { if (K <= 16) return 128; return 64; } template __device__ void block_reduce_topk(float* local_val, int64_t* local_idx, float* s_val, int64_t* s_idx, int nthreads) { int tid = threadIdx.x; #pragma unroll for (int j = 0; j < K; ++j) { s_val[tid * K + j] = local_val[j]; s_idx[tid * K + j] = local_idx[j]; } __syncthreads(); for (int stride = nthreads / 2; stride >= 1; stride >>= 1) { if (tid < stride) { float tmp[K]; int64_t tidx[K]; merge_topk(s_val + tid * K, s_idx + tid * K, s_val + (tid + stride) * K, s_idx + (tid + stride) * K, tmp, tidx); #pragma unroll for (int j = 0; j < K; ++j) { s_val[tid * K + j] = tmp[j]; s_idx[tid * K + j] = tidx[j]; } } __syncthreads(); } } __device__ __forceinline__ void warp_reduce_max(float val, int64_t idx, float& out_val, int64_t& out_idx) { for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { float ov = __shfl_down_sync(0xffffffff, val, offset); int64_t oi = __shfl_down_sync(0xffffffff, idx, offset); if (ov > val) { val = ov; idx = oi; } } out_val = val; out_idx = idx; } __global__ void topk1_kernel(const float* __restrict__ in, float* __restrict__ out_vals, int64_t* __restrict__ out_idxs, int batch, int n) { int row = blockIdx.x; if (row >= batch) return; const float* row_in = in + (int64_t)row * n; float best = -FLT_MAX; int64_t best_i = 0; int i = threadIdx.x * 4; for (; i + 3 < n; i += blockDim.x * 4) { float4 v = *reinterpret_cast(row_in + i); if (v.x > best) { best = v.x; best_i = i; } if (v.y > best) { best = v.y; best_i = i + 1; } if (v.z > best) { best = v.z; best_i = i + 2; } if (v.w > best) { best = v.w; best_i = i + 3; } } for (; i < n; i += blockDim.x) { float v = row_in[i]; if (v > best) { best = v; best_i = i; } } __shared__ float sv[32]; __shared__ int64_t si[32]; int lane = threadIdx.x & 31, wid = threadIdx.x >> 5; float wv = best; int64_t wi = best_i; warp_reduce_max(wv, wi, wv, wi); if (lane == 0) { sv[wid] = wv; si[wid] = wi; } __syncthreads(); if (wid == 0) { wv = (lane < (blockDim.x + 31) / 32) ? sv[lane] : -FLT_MAX; wi = (lane < (blockDim.x + 31) / 32) ? si[lane] : 0; warp_reduce_max(wv, wi, wv, wi); if (lane == 0) { out_vals[row] = wv; out_idxs[row] = wi; } } } template __device__ void scan_row_topk(const float* row_in, int n, float* local_val, int64_t* local_idx) { #pragma unroll for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; } int i = threadIdx.x * 4; for (; i + 3 < n; i += blockDim.x * 4) { float4 v = *reinterpret_cast(row_in + i); insert_desc(v.x, (int64_t)i, local_val, local_idx); insert_desc(v.y, (int64_t)(i + 1), local_val, local_idx); insert_desc(v.z, (int64_t)(i + 2), local_val, local_idx); insert_desc(v.w, (int64_t)(i + 3), local_val, local_idx); } for (; i < n; i += blockDim.x) { insert_desc(row_in[i], (int64_t)i, local_val, local_idx); } } template __global__ void topk_kernel(const float* __restrict__ in, float* __restrict__ out_vals, int64_t* __restrict__ out_idxs, int batch, int n) { int row = blockIdx.x; if (row >= batch) return; const float* row_in = in + (int64_t)row * n; float local_val[K]; int64_t local_idx[K]; scan_row_topk(row_in, n, local_val, local_idx); extern __shared__ char smem[]; float* s_val = reinterpret_cast(smem); int64_t* s_idx = reinterpret_cast(s_val + blockDim.x * K); block_reduce_topk(local_val, local_idx, s_val, s_idx, blockDim.x); if (threadIdx.x == 0) { #pragma unroll for (int j = 0; j < K; ++j) { out_vals[(int64_t)row * K + j] = s_val[j]; out_idxs[(int64_t)row * K + j] = s_idx[j]; } } } template __global__ void topk_slice_kernel(const float* __restrict__ in, float* __restrict__ partial_vals, int64_t* __restrict__ partial_idxs, int row, int n, int num_blocks) { int bid = blockIdx.x; if (bid >= num_blocks) return; int slice = (n + num_blocks - 1) / num_blocks; int start = bid * slice; int end = min(start + slice, n); const float* row_in = in + (int64_t)row * n; float local_val[K]; int64_t local_idx[K]; #pragma unroll for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; } int i = start + threadIdx.x * 4; for (; i + 3 < end; i += blockDim.x * 4) { float4 v = *reinterpret_cast(row_in + i); insert_desc(v.x, (int64_t)i, local_val, local_idx); insert_desc(v.y, (int64_t)(i + 1), local_val, local_idx); insert_desc(v.z, (int64_t)(i + 2), local_val, local_idx); insert_desc(v.w, (int64_t)(i + 3), local_val, local_idx); } for (; i < end; i += blockDim.x) { insert_desc(row_in[i], (int64_t)i, local_val, local_idx); } extern __shared__ char smem[]; float* s_val = reinterpret_cast(smem); int64_t* s_idx = reinterpret_cast(s_val + blockDim.x * K); block_reduce_topk(local_val, local_idx, s_val, s_idx, blockDim.x); if (threadIdx.x == 0) { #pragma unroll for (int j = 0; j < K; ++j) { partial_vals[(int64_t)bid * K + j] = s_val[j]; partial_idxs[(int64_t)bid * K + j] = s_idx[j]; } } } template __global__ void topk_merge_kernel(const float* __restrict__ partial_vals, const int64_t* __restrict__ partial_idxs, float* __restrict__ out_vals, int64_t* __restrict__ out_idxs, int row, int num_partials) { float local_val[K]; int64_t local_idx[K]; #pragma unroll for (int j = 0; j < K; ++j) { local_val[j] = -FLT_MAX; local_idx[j] = 0; } for (int p = threadIdx.x; p < num_partials; p += blockDim.x) { #pragma unroll for (int j = 0; j < K; ++j) { insert_desc(partial_vals[(int64_t)p * K + j], partial_idxs[(int64_t)p * K + j], local_val, local_idx); } } extern __shared__ char smem[]; float* s_val = reinterpret_cast(smem); int64_t* s_idx = reinterpret_cast(s_val + blockDim.x * K); block_reduce_topk(local_val, local_idx, s_val, s_idx, blockDim.x); if (threadIdx.x == 0) { #pragma unroll for (int j = 0; j < K; ++j) { out_vals[(int64_t)row * K + j] = s_val[j]; out_idxs[(int64_t)row * K + j] = s_idx[j]; } } } template void launch_topk(const float* in, float* out_vals, int64_t* out_idxs, int batch, int n, cudaStream_t stream, float* partial_vals, int64_t* partial_idxs, int partial_cap) { constexpr int T = threads_for_k(); size_t smem = (size_t)T * K * (sizeof(float) + sizeof(int64_t)); if (batch == 1 && n >= 65536 && partial_vals != nullptr) { int num_blocks = min(partial_cap, (n + T - 1) / T); topk_slice_kernel<<>>( in, partial_vals, partial_idxs, 0, n, num_blocks); topk_merge_kernel<<<1, T, smem, stream>>>( partial_vals, partial_idxs, out_vals, out_idxs, 0, num_blocks); } else { topk_kernel<<>>(in, out_vals, out_idxs, batch, n); } } void dispatch_topk(const torch::Tensor& input, torch::Tensor& out_vals, torch::Tensor& out_idxs, int k, torch::Tensor& partial_vals, torch::Tensor& partial_idxs) { const float* in_ptr = input.data_ptr(); float* val_ptr = out_vals.data_ptr(); int64_t* idx_ptr = out_idxs.data_ptr(); int batch = (int)input.size(0); int n = (int)input.size(1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); float* pval = partial_vals.defined() && partial_vals.numel() > 0 ? partial_vals.data_ptr() : nullptr; int64_t* pidx = partial_idxs.defined() && partial_idxs.numel() > 0 ? partial_idxs.data_ptr() : nullptr; int partial_cap = partial_vals.defined() ? (int)partial_vals.size(0) : 0; if (k == 1) { topk1_kernel<<>>(in_ptr, val_ptr, idx_ptr, batch, n); } else if (k == 8) { launch_topk<8>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap); } else if (k == 16) { launch_topk<16>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap); } else if (k == 32) { launch_topk<32>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap); } else if (k == 64) { launch_topk<64>(in_ptr, val_ptr, idx_ptr, batch, n, stream, pval, pidx, partial_cap); } else { TORCH_CHECK(false, "unsupported k=", k); } } std::vector topk_cuda(torch::Tensor input, int k, torch::Tensor partial_vals, torch::Tensor partial_idxs) { TORCH_CHECK(input.is_cuda() && input.dtype() == torch::kFloat32); TORCH_CHECK(input.dim() == 2); auto opts_f = torch::TensorOptions().dtype(torch::kFloat32).device(input.device()); auto opts_i = torch::TensorOptions().dtype(torch::kInt64).device(input.device()); auto values = torch::empty({(int)input.size(0), k}, opts_f); auto indices = torch::empty({(int)input.size(0), k}, opts_i); dispatch_topk(input.contiguous(), values, indices, k, partial_vals, partial_idxs); return {values, indices}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("topk_cuda", &topk_cuda, "Custom top-k (values, indices)"); } """ CUDA_FLAGS = [ "-O3", "--use_fast_math", "-lineinfo", "-gencode=arch=compute_120,code=sm_120", ] _topk_ext = None _partial_cache: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {} def _get_ext(): global _topk_ext if _topk_ext is None: _topk_ext = load_inline( name="topk_cuda_ext", cpp_sources="", cuda_sources=CUDA_SRC, functions=None, extra_cuda_cflags=CUDA_FLAGS, verbose=False, ) return _topk_ext def _get_partial_workspace(k: int, cap: int, device: torch.device): key = (k, cap, str(device)) if key not in _partial_cache: _partial_cache[key] = ( torch.empty(cap, k, dtype=torch.float32, device=device), torch.empty(cap, k, dtype=torch.int64, device=device), ) return _partial_cache[key] class Model(nn.Module): """Top-k over the last dim of a 2D tensor.""" def __init__(self, batch: int, n: int, k: int): super().__init__() self.batch, self.n, self.k = batch, n, k self.register_buffer("_dummy", torch.zeros(1)) self._use_slice = batch == 1 and n >= 65536 and k > 1 self._partial_cap = min(1024, (n + 63) // 64) if self._use_slice else 0 def forward(self, x: torch.Tensor): ext = _get_ext() if self._use_slice: pv, pi = _get_partial_workspace(self.k, self._partial_cap, x.device) else: pv = torch.empty(0, dtype=torch.float32, device=x.device) pi = torch.empty(0, dtype=torch.int64, device=x.device) return ext.topk_cuda(x.contiguous(), self.k, pv, pi) batch = 64 n = 8192 k = 8 def get_inputs(): x = torch.randn(batch, n, dtype=torch.float32) return [x] def get_init_inputs(): return [batch, n, k]