reasoning (113065 chars)
Let me understand the problem:
1. I need to write a custom top-k kernel in `solution.py`
2. Must NOT use any forbidden ops (torch.topk, torch.sort, etc.)
3. Must handle shapes: (1, 131072, 64), (64, 8192, 8), (32, 16384, 32), (16, 12000, 16), (128, 4096, 1)
4. RTX PRO 6000 is Blackwell SM120 with GDDR7 1.8 TB/s bandwidth
5. The problem is memory-bound - need to maximize bandwidth utilization
6. Must pass `python check.py` and then optimize with `python benchmark.py`
This is a classic top-k problem. For memory-bound top-k, the key is:
- Maximize read bandwidth (coalesced loads)
- Minimize register/shared memory usage
- Use efficient selection/sorting algorithms
For small k (like k=1, k=8, k=16, k=32, k=64), a common approach is:
1. Load data in tiles
2. Maintain a local sorted list of top-k elements in registers
3. Merge tiles
Another approach is bitonic sort in shared memory, which is what the problem name hints at.
For large n (131072) with moderate k (64), and many smaller cases, I think a tile-based local top-k with warp-level or block-level reduction would work well.
Let me think about what approach to take. I could:
1. Use Triton - easier to write, but might have overhead
2. Use CUDA C++ via torch.utils.cpp_extension.load_inline - more control
3. Use CUTLASS - probably overkill for this
Given the Blackwell architecture (SM120), Triton might be the fastest to iterate with. But for maximum performance on top-k, a carefully written CUDA kernel might be better.
Let me look at what SOTA approaches exist. I know:
- For small k, a register-based bubble/insertion sort maintaining top-k as you stream through data is very efficient
- Warp-level primitives can help merge partial results
- For k=1, it's just argmax and can be done very efficiently
Actually, let me think about this more carefully. The largest shape is batch=1, n=131072, k=64. The total input is 512KB. At 1.8 TB/s, that's ~0.28 microseconds just to read. In practice, we'd be happy with a few microseconds.
For many top-k implementations:
- Tile-based: each thread processes a chunk, maintains top-k, then threads merge
- Bitonic sort: sort everything then take top-k (wasteful for small k)
A good approach for batched top-k with small k:
- Each row processed by one block (or warp)
- Each thread loads elements, maintains a local top-k heap/register array
- Thread-level results merged within block via shared memory
- Final top-k extracted
For k up to 64, each thread could maintain 64 elements in registers. That's a lot of registers though (64 * (4+4) = 512 bytes per thread). With 1024 threads, that's 512KB registers, which might be too much.
Better: Use a two-level approach:
1. Each thread processes a chunk and extracts top-k (using simple sorted array insertion)
2. All threads in block write to shared memory
3. Block-level reduction to get final top-k
Or use a smaller k per thread. For example, each thread maintains top-2 or top-4, then merge.
Actually, for k=64, a simpler approach might be:
- Use a block of 256 threads
- Each thread processes n/256 elements, keeping the top 1 (or top 2)
- Then do a block-level bitonic sort on the 256 (or 512) candidates to extract top 64
Wait, that's not enough. If each thread keeps only top-1 from 512 elements, the top-64 of the whole might be distributed across many threads. We need each thread to keep at least enough that the union contains the true top-64.
Actually, if each thread keeps top-k, then merging all threads' results and taking top-k gives the correct answer: any element of the true global top-k is, within whatever chunk it landed in, among that chunk's top-k, so it survives the per-thread step. So each thread needs to keep top-k from its chunk. For 256 threads and n=131072, each thread handles 512 elements, keeping top-64. Then we have 256*64 = 16384 candidates. A block-level sort of 16384 elements and taking top-64 gives the answer.
16384 elements in shared memory is 64KB (fp32) + 128KB (int64 indices) = 192KB. That's a lot of shared memory, and whether it fits in one block depends on the per-block limit (Hopper allows up to ~228KB; I'm not sure what SM120 allows).
Alternatively, we can do a tree reduction:
- Each warp merges 32 top-k lists into one top-k list
- Then warps merge
For merging two sorted top-k lists, we can use a bitonic merge or simple merge.
Actually, let me think about what's most efficient. The key bottleneck is reading memory. So we want:
1. Coalesced loads (each thread reads consecutive elements or strided)
2. Minimal computation per element
3. Minimal register pressure
For reading, if each thread reads consecutive elements, we get good coalescing. For example, thread i reads elements i, i+blockDim.x, i+2*blockDim.x, etc.
For maintaining top-k, a common trick is to use a sorted array in registers and insert each new element. For k=64, insertion is O(k) = 64 comparisons per element. For n=131072 and 256 threads, each thread processes 512 elements, so 512 * 64 = 32768 comparisons. That's fine.
But 64 fp32 + 64 int32 indices per thread = 512 bytes in registers. With 256 threads = 131KB registers per block. SM120 has plenty of registers, so this is fine.
Actually, let me think about a cleaner design:
**Approach: Tile-based top-k with register queues**
For each row (or batch of rows):
- Launch one block per row
- Block has `THREADS` threads (e.g., 256 or 512)
- Each thread loads `n / THREADS` elements strided by `THREADS`
- Each thread maintains a local sorted array of top-k (values + indices)
- After all elements processed, write local top-k to shared memory
- Perform block-level bitonic sort / merge to get global top-k
- Write to output
For the block-level merge:
- We have `THREADS * k` candidates in shared memory
- Need to extract top-k
- Can use a partial bitonic sort (only need top-k, not full sort)
For k=64 and THREADS=256, we have 16384 candidates. A full bitonic sort of 16384 is 16384 * log2(16384)^2 / 2 ≈ 16384 * 196 / 2 ≈ 1.6M comparisons. That's actually a lot.
A better merge approach:
- Use a tournament tree / heap to merge THREADS sorted lists
- Or use a warp-level reduction repeatedly
Actually, we can do this more cleverly. Instead of each thread keeping top-k, have each thread keep top-1, then do warp reductions to find the top element, then repeat k times. But that's O(k * n) with warp divergence and not great.
Better: Each thread keeps top-k. Then for merging:
- Use a parallel merge sort or bitonic sort on the candidates
- But only need top-k, so we can use a "partial bitonic sort" that only sorts the first k elements
Actually, a simple and efficient approach for the merge:
- Write all candidates to shared memory
- Use a bitonic sort network on chunks, but only propagate the top elements
- Or use multiple rounds of parallel compare-and-swap
Let me think about this differently. There's a known efficient approach for top-k on GPUs:
**Partial bitonic sort in shared memory**
1. Load elements into shared memory (values + indices)
2. Run a bitonic sort network, but after each stage, only keep the "upper" half
3. Since we only need top-k, we can discard the bottom half progressively
For n candidates and needing top-k:
- The bitonic sort has log2(n) stages
- After stage i, we can discard elements that are definitely not in top-k
- Actually, bitonic sort builds sorted sequences of size 2, 4, 8, ..., n
- After building sorted sequences of size m, if m > k, we only need the top-k from each m-sequence when merging
Hmm, this gets complicated. Let me think of a simpler approach.
**Radix Select**
For fp32, we could do a radix-based selection; this is roughly what PyTorch's own CUDA topk does. But implementing a full radix select in CUDA is complex.
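For reference, the usual building block is a monotonic bit transform so that unsigned radix comparisons agree with float ordering; a small sketch (helper names mine, NaNs ignored):

```cuda
#include <cuda_runtime.h>

// Map float bits to an unsigned key whose unsigned ordering matches float
// ordering: flip all bits for negatives, flip only the sign bit otherwise.
// Radix select/sort can then operate on these keys directly.
__device__ __forceinline__ unsigned int float_to_ordered_key(float f) {
    unsigned int u = __float_as_uint(f);
    return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

// Inverse transform, to recover the float value from a key if needed.
__device__ __forceinline__ float ordered_key_to_float(unsigned int k) {
    unsigned int u = (k & 0x80000000u) ? (k & 0x7FFFFFFFu) : ~k;
    return __uint_as_float(u);
}
```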
**Simple approach that might be fast enough:**
Use Triton! Triton has good support for block-level operations and shared memory. I can write a kernel where each program instance handles one row:
```python
@triton.jit
def topk_kernel(x_ptr, out_val_ptr, out_idx_ptr, batch, n, k, BLOCK_N: tl.constexpr):
row_idx = tl.program_id(0)
# Each row processed by one block
# Load tile, maintain top-k in registers
# ...
```
For the top-k maintenance in Triton, I could:
- Load BLOCK_N elements at a time
- Merge with current top-k using a small sorting network
Actually, Triton might have limitations with dynamic k. But k is known at kernel launch time (passed as parameter), though not at compile time.
Wait, the k values are: 1, 8, 16, 32, 64. These are small and powers of 2 (except 1 which is trivial). I could write a specialized kernel for each k, or use a generic kernel with a loop.
Let me look at what other people do for this. I should search for fast top-k CUDA implementations.
Actually, I think the most practical approach is CUDA C++ with load_inline. This gives me full control over registers, shared memory, and warp primitives.
Let me design a kernel:
**Kernel Design: "Streaming Top-K with Warp Merge"**
For each row, assign one warp (32 threads) or one block.
If one block per row:
- `THREADS = 256` threads per block
- Each thread loads `n / THREADS` elements strided
- Each thread maintains `LOCAL_K = k` best elements in a sorted array (value, idx)
- After processing all elements, threads write to shared memory
- Block does a bitonic sort / merge to extract top-k
To optimize the merge, I'll use a technique where:
- Each warp sorts its portion using warp shuffles
- Warps merge hierarchically
Actually, for simplicity and performance, let me use this:
**Two-pass approach:**
1. **First pass**: Each thread processes a chunk and extracts top-k. Write all thread results to global memory. So we get `batch * num_threads * k` candidates.
2. **Second pass**: Each block reads `num_threads * k` candidates and extracts top-k.
This is clean and easy to implement. The first pass is perfectly parallel and memory-bound. The second pass has much less data to process.
For the second pass, `num_threads * k` is at most `256 * 64 = 16384` elements per row. This is small enough to fit in shared memory and sort efficiently.
Let me calculate for each shape:
1. batch=1, n=131072, k=64: First pass reads 512KB. Second pass sorts 16384 elements.
2. batch=64, n=8192, k=8: First pass reads 64 * 8192 * 4 = 2MB. Second pass sorts 256 * 8 = 2048 elements per row (64 rows in parallel).
3. batch=32, n=16384, k=32: First pass reads 32 * 16384 * 4 = 2MB. Second pass sorts 32 * 256 * 32 = 256K elements total.
4. batch=16, n=12000, k=16: First pass reads 16 * 12000 * 4 = 768KB. Second pass sorts 16 * 256 * 16 = 65536 elements total.
5. batch=128, n=4096, k=1: First pass reads 128 * 4096 * 4 = 2MB. Second pass sorts 128 * 256 * 1 = 32768 elements total.
The first pass dominates for large n.
For k=1, I can use a specialized path (argmax) which is much faster.
Actually, I just realized: for k=1, I can use warp-level reductions very efficiently. Each thread finds local max, then warp reduces, then block reduces.
Let me design the kernel more carefully.
**First Pass Kernel:**
```cuda
template<int K>
__global__ void local_topk(const float* x, float* out_vals, int64_t* out_idxs, int n, int k) {
int row = blockIdx.x;
int tid = threadIdx.x;
int nthreads = blockDim.x;
const float* row_x = x + row * n;
// Each thread maintains top-K elements
float vals[K];
int64_t idxs[K];
// Initialize with -inf
#pragma unroll
for (int i = 0; i < K; i++) {
vals[i] = -INFINITY;
idxs[i] = -1;
}
// Strided loop over elements
for (int i = tid; i < n; i += nthreads) {
float v = row_x[i];
// Insert into sorted array if v > vals[K-1]
if (v > vals[K-1]) {
// Shift and insert
int pos = K - 1;
while (pos > 0 && v > vals[pos-1]) {
vals[pos] = vals[pos-1];
idxs[pos] = idxs[pos-1];
pos--;
}
vals[pos] = v;
idxs[pos] = i;
}
}
// Write to global memory
float* row_out_vals = out_vals + row * nthreads * k;
int64_t* row_out_idxs = out_idxs + row * nthreads * k;
#pragma unroll
for (int i = 0; i < K; i++) {
row_out_vals[tid * k + i] = vals[i];
row_out_idxs[tid * k + i] = idxs[i];
}
}
```
**Second Pass Kernel:**
```cuda
template<int K>
__global__ void merge_topk(const float* in_vals, const int64_t* in_idxs,
float* out_vals, int64_t* out_idxs,
int num_candidates, int k) {
int row = blockIdx.x;
int tid = threadIdx.x;
// Load all candidates into shared memory
extern __shared__ char smem[];
float* s_vals = (float*)smem;
int64_t* s_idxs = (int64_t*)(smem + num_candidates * sizeof(float));
const float* row_in_vals = in_vals + row * num_candidates;
const int64_t* row_in_idxs = in_idxs + row * num_candidates;
for (int i = tid; i < num_candidates; i += blockDim.x) {
s_vals[i] = row_in_vals[i];
s_idxs[i] = row_in_idxs[i];
}
__syncthreads();
// Bitonic sort in shared memory - only need top K
// ... sort s_vals/s_idxs descending, then write top K
}
```
The bitonic sort for the second pass needs to handle arbitrary `num_candidates`. I'll round up to the next power of 2, pad with -inf, and run a full bitonic sort, then take top K.
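A minimal sketch of that padding step (helper names mine; assumes the shared arrays were sized for the rounded-up count):

```cuda
#include <math.h>  // INFINITY

// Round up to the next power of two (done on the host when sizing buffers).
__host__ __device__ inline int next_pow2(int x) {
    int p = 1;
    while (p < x) p <<= 1;
    return p;
}

// Fill the tail of the candidate arrays with -inf sentinels so the bitonic
// network can run over a power-of-two length. All threads in the block must
// call this because of the barrier at the end.
__device__ void pad_with_neg_inf(float* s_vals, int* s_idxs,
                                 int n_candidates, int n_sorted,
                                 int tid, int nthreads) {
    for (int i = n_candidates + tid; i < n_sorted; i += nthreads) {
        s_vals[i] = -INFINITY;
        s_idxs[i] = -1;
    }
    __syncthreads();
}
```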
Actually, full bitonic sort of 16384 elements with 256 threads is quite a bit of work but should be fast in shared memory.
Let me think about whether there's a better merge approach. For the second pass, the candidates are already partially sorted (each thread's contribution is sorted descending). We could use a multi-way merge or a tournament tree, but on GPU a parallel bitonic sort is usually faster.
Wait, I should think about this more. For k=64 and num_candidates=16384, a full bitonic sort does ~16384 * 14 * 14 / 2 = 1.6M comparisons. With 256 threads, that's ~6400 comparisons per thread. At 1 comparison per clock (optimistic), that's ~6400 cycles = ~4 microseconds at 1.5GHz. Plus memory loads. The first pass reads 512KB in ~0.3 microseconds at full bandwidth. So the second pass might dominate!
Hmm, but we don't need to fully sort. We only need top-64. Can we do a partial bitonic sort?
Yes! A partial bitonic sort for top-k works by:
1. Building bitonic sequences as usual
2. But when merging two sequences of size m, if the result is guaranteed to be outside top-k, we can stop tracking it
Actually, a cleaner approach: use a selection algorithm. After the first pass, we have `num_candidates` elements. We can use a "quickselect" or "bitonic select" to find the top-k without full sort.
For shared memory with many threads, the most efficient approach might be:
1. Load into shared memory
2. Use a parallel reduction to find the k-th largest element (or just repeatedly extract max)
3. Then filter and sort the top-k
Repeatedly extracting max k times:
- Each extraction requires a parallel reduction over num_candidates elements
- k = 64, num_candidates = 16384
- Each reduction is ~16384 operations, 64 reductions = ~1M operations
- This is comparable to sorting, but with more synchronization
A better approach for the second pass: Use a small heap or sorted array in registers for each thread, and have each thread load a subset of candidates.
Actually, wait. For the second pass, `num_candidates = nthreads * k`. If nthreads=256 and k=64, that's 16384. What if I use more threads? I could use 512 or 1024 threads per block.
With 512 threads and n=131072, each thread processes 256 elements. That's still fine. num_candidates = 512 * 64 = 32768. More candidates, but each thread does less work in pass 1.
With 1024 threads, each thread processes 128 elements. num_candidates = 1024 * 64 = 65536.
Actually, fewer elements per thread means less register pressure and better latency hiding. Let me use 512 threads.
For pass 2 with 32768 candidates: A bitonic sort of 32768 (padded to 32768, which is power of 2) is straightforward. 32768 * 15 * 15 / 2 = 3.7M comparisons. With 512 threads, ~7200 per thread. Still manageable.
But wait - shared memory size! 32768 floats = 128KB, plus 32768 int64 = 256KB. Total 384KB. That's too much for most GPUs (max 164-228KB per block).
So I need to limit shared memory. Let me use 256 threads for pass 1:
- num_candidates = 256 * 64 = 16384
- Shared memory: 16384 * 4 + 16384 * 8 = 64KB + 128KB = 192KB
- This is probably too much: Ampere tops out around 164KB per block and Hopper around 228KB (both need an explicit opt-in above 48KB), and I'm not sure the consumer-oriented SM120 even reaches those figures
I should check the actual per-block limit on SM120 rather than assume; 192KB does not look safe.
To be safe, I could store indices as int32 instead of int64. The problem requires int64 output, but internally I can use int32 if n <= 2^31-1 (which it is, max n=131072). Then cast to int64 at the end.
With int32 indices: 64KB + 64KB = 128KB. Better, but that could still be over the per-block limit, especially on consumer parts.
Or I can use a two-level merge to keep shared memory smaller.
Actually, let me think about a single-pass approach instead.
**Single-pass block-level top-k:**
- One block per row
- Each thread loads elements and maintains a local top-k
- Threads cooperatively merge using warp-level operations
- Final top-k is in shared memory and written out
For merging, here's a clever approach:
- After each thread has its local top-k, we do a tree merge
- Warp 0 merges threads 0-31's top-k into a top-k
- Warp 1 merges threads 32-63's top-k
- ...
- Then warp 0 merges the warp results
- etc.
Each merge of two sorted top-k lists can be done in O(k) time with one thread, or in parallel with multiple threads.
For merging two sorted arrays of size k into top-k:
- We can use a parallel merge where thread i finds the i-th element
- Or use a bitonic merge network for 2k elements and take top k
A bitonic merge of two sorted k-element lists:
1. Reverse the second list
2. Concatenate: [list1, reversed(list2)]
3. Apply bitonic merge network
4. The first k elements are the top-k
For k=64, a bitonic merge of 128 elements requires log2(128) = 7 stages, with 128/2 = 64 comparators per stage = 448 comparators total. With 64 threads (2 elements per thread), that's ~7 comparators per thread. Very fast!
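A minimal single-thread register sketch of that recipe (helper name mine, K a power of two); in the warp- or block-level versions the same comparators just get spread across threads:

```cuda
// Merge two descending-sorted K-element lists (values + indices) and keep the
// combined top K in (av, ai). Reversing b makes [a, reversed b] bitonic, so a
// single descending bitonic merge of 2K elements sorts it; positions 0..K-1
// then hold the top K.
template <int K>
__device__ void bitonic_merge_topk(float av[K], int ai[K],
                                   const float bv[K], const int bi[K]) {
    float v[2 * K];
    int   x[2 * K];
    #pragma unroll
    for (int i = 0; i < K; ++i) {
        v[i]     = av[i];          x[i]     = ai[i];
        v[K + i] = bv[K - 1 - i];  x[K + i] = bi[K - 1 - i];
    }
    #pragma unroll
    for (int j = K; j > 0; j >>= 1) {        // log2(2K) compare-exchange stages
        #pragma unroll
        for (int i = 0; i < 2 * K; ++i) {
            int p = i ^ j;
            if (p > i && v[i] < v[p]) {      // descending: larger value to lower index
                float tv = v[i]; v[i] = v[p]; v[p] = tv;
                int   tx = x[i]; x[i] = x[p]; x[p] = tx;
            }
        }
    }
    #pragma unroll
    for (int i = 0; i < K; ++i) { av[i] = v[i]; ai[i] = x[i]; }
}
```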
For merging 32 threads' results within a warp:
- Merge pairs (16 merges of 2 lists -> 16 lists)
- Merge pairs of results (8 merges -> 8 lists)
- ... (4 -> 2 -> 1)
Total merges: 16 + 8 + 4 + 2 + 1 = 31 merges of varying sizes.
But wait, each merge produces a top-k list from two top-k lists. The size stays k.
Actually, let me think about this. After the first round, we have 16 top-k lists. After the second, 8 top-k lists, etc. After 5 rounds, 1 top-k list.
For k=64, each merge is a bitonic merge of 128 elements. With 64 threads per merge (since we need parallelism), that's 64 * 31 = 1984 thread-merges. But warps are 32 threads, so we'd use 2 warps per merge.
Hmm, this is getting complex. Let me consider a simpler approach.
**Simplest robust approach: Two-pass with bitonic sort in pass 2**
I'll use:
- Pass 1: Each block processes one row. Each thread maintains a sorted top-k array in registers. Write `blockDim.x * k` candidates to global memory.
- Pass 2: Each block reads candidates for one row, sorts with bitonic sort in shared memory, writes top-k.
I'll tune `blockDim.x` per shape or use a fixed value. Let me use 256 threads for simplicity.
For pass 2 shared memory: I'll store values and int32 indices. After sorting, write int64 indices to output.
For the bitonic sort, I'll write a generic kernel that works for any power-of-2 size. For non-power-of-2 candidates, I'll pad with -inf.
Actually, I realize I can also do the entire thing in one pass with a different strategy:
- Use a persistent kernel that keeps top-k in shared memory and updates it
- Each thread loads a chunk, finds its top-k, then all threads merge into shared memory
- But this has the same complexity
Let me search for some existing top-k CUDA implementations to see what's actually fast.
Wait, I should also consider using Triton. Let me check if there are any Triton top-k implementations I can reference. Actually, I recall that Triton is particularly good at fused reductions and has efficient shared memory access patterns. But implementing bitonic sort in Triton might be tricky due to the SIMT programming model.
Actually, I know what - let me use CUDA C++. It gives me the most control.
Let me draft the kernel step by step.
Actually, for k=1 (argmax), I can write a specialized highly-optimized kernel. For k=1 with 4096 elements and 128 batches, a simple warp reduction is very fast.
Let me also think about memory layout. The input is (batch, n) row-major. For coalesced access:
- In pass 1, thread tid should read elements at tid, tid+blockDim.x, tid+2*blockDim.x, etc. within each row.
- This means consecutive threads read consecutive elements = coalesced!
Yes, because row elements are contiguous in memory. Thread 0 reads index 0, thread 1 reads index 1, etc. Perfect coalescing.
For pass 2, each thread reads its portion of candidates. Candidates for a row are stored as [thread0_k0, thread0_k1, ..., thread0_kK-1, thread1_k0, ...]. This is also coalesced if thread tid reads starting at tid * k.
OK let me write the code. I'll create a solution.py that compiles CUDA kernels inline.
For the bitonic sort in pass 2, I need to handle the fact that different shapes have different numbers of candidates. The candidates = THREADS * k, where THREADS = 256.
For each shape:
1. (1, 131072, 64): candidates = 16384
2. (64, 8192, 8): candidates = 2048
3. (32, 16384, 32): candidates = 8192
4. (16, 12000, 16): candidates = 4096
5. (128, 4096, 1): candidates = 256
Since every k here (1, 8, 16, 32, 64) is a power of 2, candidates = THREADS * k is always a power of 2 with THREADS=256:
256 * 1 = 256
256 * 8 = 2048
256 * 16 = 4096
256 * 32 = 8192
256 * 64 = 16384
Great! All powers of 2. That makes bitonic sort much easier.
But wait, n=12000 for shape 4. With 256 threads, 12000/256 = 46.875, so threads 0-223 handle 47 elements each and threads 224-255 handle 46. That's fine.
Now, for the bitonic sort, I need to write a kernel that can handle different sizes. I can use a template parameter for the candidate count, or write a generic one.
Actually, I think a cleaner approach is to write a single kernel that takes the candidate count as a runtime parameter and uses a loop-based bitonic sort. The performance difference is minimal for shared memory sorts.
Loop-based bitonic sort:
```cuda
for (int size = 2; size <= N; size *= 2) {
for (int stride = size / 2; stride > 0; stride /= 2) {
// Compare-and-swap elements
for (int i = tid; i < N/2; i += blockDim.x) {
// calculate indices to compare
int idx = i * 2 + ((i / stride) % 2) * stride;
// Actually, standard bitonic sort indexing...
}
__syncthreads();
}
}
```
Wait, standard bitonic sort indexing for descending order:
For each pair of elements to compare:
```cuda
int ixj = i ^ j;
if (ixj > i) {
if ((i & k) == 0) {  // k = current stage size
// ascending in this phase... actually for descending:
if (s_vals[i] < s_vals[ixj]) swap;
} else {
if (s_vals[i] > s_vals[ixj]) swap;
}
}
```
Let me write a cleaner bitonic sort. The standard algorithm for descending:
```cuda
// N must be power of 2
for (int k = 2; k <= N; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
for (int i = tid; i < N; i += blockDim.x) {
int ixj = i ^ j;
if (ixj > i) {
bool ascending = (i & k) == 0;
// For descending overall, we want:
// In first half of k-block: larger values go to lower indices
// In second half: larger values go to higher indices... wait
}
}
__syncthreads();
}
}
```
Actually, let me be careful. For a descending sort (largest first), the standard bitonic sort works as follows:
In the outer loop over k = 2, 4, 8, ..., N:
- For each k, we create bitonic sequences of length k
- The first k/2 elements should be in ascending order, the second k/2 in descending (or vice versa, depending on position)
For a fully descending sort:
- At each stage k, elements are divided into blocks of size k
- Within each block, we want the first half sorted ascending and second half descending (forming a bitonic sequence)
- Then we merge
Actually, let me just use the well-known indexing. For descending order:
```cuda
for (int k = 2; k <= N; k = k * 2) {
for (int j = k / 2; j > 0; j = j / 2) {
for (int i = threadIdx.x; i < N; i += blockDim.x) {
int ixj = i ^ j;
if (ixj > i) {
if ((i & k) == 0) {
if (s_vals[i] < s_vals[ixj]) {
swap(s_vals[i], s_vals[ixj]);
swap(s_idxs[i], s_idxs[ixj]);
}
} else {
if (s_vals[i] > s_vals[ixj]) {
swap(s_vals[i], s_vals[ixj]);
swap(s_idxs[i], s_idxs[ixj]);
}
}
}
}
__syncthreads();
}
}
```
Wait, I need to double-check this. For k=2, j=1:
- i=0: ixj=1, (0 & 2)==0, so if s[0] < s[1], swap -> s[0] > s[1] (descending pair)
- i=1: ixj=0, but ixj > i is false, so skip
- i=2: ixj=3, (2 & 2)!=0, so if s[2] > s[3], swap -> s[2] < s[3] (ascending pair)
- i=3: ixj=2, skip
After k=2: pairs are [desc, asc, desc, asc, ...] - this is correct for bitonic sequence building.
For k=4, j=2:
- i=0: ixj=2, (0&4)==0, if s[0] < s[2], swap -> larger to 0
- i=1: ixj=3, (1&4)==0, if s[1] < s[3], swap -> larger to 1
- i=2: ixj=0, skip
- i=3: ixj=1, skip
- i=4: ixj=6, (4&4)!=0, if s[4] > s[6], swap -> smaller to 4
- i=5: ixj=7, (5&4)!=0, if s[5] > s[7], swap -> smaller to 5
After j=2: in each descending-direction 4-block the two larger elements have moved to the first half; in the ascending-direction blocks, to the second half.
Then j=1:
- i=0: ixj=1, (0&4)==0, if s[0] < s[1], swap
- i=2: ixj=3, (2&4)==0, if s[2] < s[3], swap
- i=4: ixj=5, (4&4)!=0, if s[4] > s[5], swap
- i=6: ixj=7, (6&4)!=0, if s[6] > s[7], swap
After k=4: each 4-block is fully descending. Great!
So the code above produces a fully descending sort. Good.
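As a sanity check, here is a tiny standalone harness (test values are my own) that runs exactly this descending network on one block:

```cuda
// Compile with: nvcc -o bitonic_check bitonic_check.cu && ./bitonic_check
#include <cstdio>
#include <cuda_runtime.h>

#define N 16  // power of two

__global__ void bitonic_desc(float* data) {
    __shared__ float s[N];
    int tid = threadIdx.x;
    if (tid < N) s[tid] = data[tid];
    __syncthreads();
    for (int k = 2; k <= N; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int i = tid; i < N; i += blockDim.x) {
                int ixj = i ^ j;
                if (ixj > i) {
                    bool desc = (i & k) == 0;
                    bool need_swap = desc ? (s[i] < s[ixj]) : (s[i] > s[ixj]);
                    if (need_swap) { float t = s[i]; s[i] = s[ixj]; s[ixj] = t; }
                }
            }
            __syncthreads();
        }
    }
    if (tid < N) data[tid] = s[tid];
}

int main() {
    float h[N] = {3, 7, 1, 9, 4, 0, 8, 2, 6, 5, 11, 10, 15, 13, 12, 14};
    float* d;
    cudaMalloc(&d, N * sizeof(float));
    cudaMemcpy(d, h, N * sizeof(float), cudaMemcpyHostToDevice);
    bitonic_desc<<<1, N>>>(d);
    cudaMemcpy(h, d, N * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < N; i++) printf("%g ", h[i]);  // expect 15 14 13 ... 0
    printf("\n");
    cudaFree(d);
    return 0;
}
```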
Now, each thread accesses elements i and ixj. Shared memory has 32 banks, each 4 bytes wide, so a float at index i sits in bank (i % 32), and conflicts only arise between threads within a single load instruction. With this XOR indexing the low 5 bits of i (and of i^j) are distinct across the warp, so severe intra-instruction conflicts look unlikely.
What does happen is that for j >= 32 the two addresses a given thread touches fall in the same bank (but those are separate loads, so it's latency rather than a classic conflict), and at every step only the threads passing the `ixj > i` test do useful work. For small sorts in shared memory this should not be terrible.
One fix is to pad the shared memory array so that index i maps to bank (i * PADDING % 32). For example, with PADDING=33, adjacent elements are in different banks. But this makes indexing more complex.
Another fix is to have each thread process two adjacent elements that are compared. But in bitonic sort, the comparison distance varies.
For the sizes we're dealing with (up to 16384), the steps with j >= 32 are 45 of the 105 compare-exchange passes, and shared memory is fast anyway, so this is probably not a first-order issue.
Alternatively, I can use a "shuffled" indexing to avoid conflicts. But let's keep it simple first and optimize if needed.
Actually, there's a known trick: store as `struct { float val; int idx; }` and align properly. Or interleave the storage.
Let me just write a working version first, then profile and optimize.
Now let me think about the overall structure of solution.py.
```python
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# CUDA source code
cuda_source = """
... kernels ...
"""
# Compile
module = load_inline(
name="topk_cuda",
cpp_sources="",
cuda_sources=cuda_source,
functions=["topk_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
)
class Model(nn.Module):
def __init__(self, batch, n, k):
super().__init__()
self.batch = batch
self.n = n
self.k = k
self.register_buffer("_dummy", torch.zeros(1))
def forward(self, x):
return module.topk_forward(x, self.k)
```
Wait, I need to handle the case where k varies. But for a given model instance, k is fixed.
For the kernel dispatch, I'll have:
- A specialized path for k=1 (argmax)
- A general two-pass path for k > 1
Actually, for k=1, the two-pass approach works too (candidates = 256, sort 256 elements to get top 1). But an argmax reduction is much faster. Let me implement both.
For k=1:
```cuda
__global__ void topk1_kernel(const float* x, float* out_vals, int64_t* out_idxs, int n) {
int row = blockIdx.x;
int tid = threadIdx.x;
int lane = tid % 32;
int warp = tid / 32;
const float* row_x = x + row * n;
float max_val = -INFINITY;
int max_idx = -1;
for (int i = tid; i < n; i += blockDim.x) {
float v = row_x[i];
if (v > max_val) {
max_val = v;
max_idx = i;
}
}
// Warp reduce
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset);
int other_idx = __shfl_down_sync(0xFFFFFFFF, max_idx, offset);
if (other_val > max_val) {
max_val = other_val;
max_idx = other_idx;
}
}
// Write warp results to shared memory
__shared__ float s_vals[32];
__shared__ int s_idxs[32];
if (lane == 0) {
s_vals[warp] = max_val;
s_idxs[warp] = max_idx;
}
__syncthreads();
// Final reduction by warp 0
if (tid < 32) {
float v = (tid < (blockDim.x + 31) / 32) ? s_vals[tid] : -INFINITY;
int idx = (tid < (blockDim.x + 31) / 32) ? s_idxs[tid] : -1;
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
float other_val = __shfl_down_sync(0xFFFFFFFF, v, offset);
int other_idx = __shfl_down_sync(0xFFFFFFFF, idx, offset);
if (other_val > v) {
v = other_val;
idx = other_idx;
}
}
if (tid == 0) {
out_vals[row] = v;
out_idxs[row] = idx;
}
}
}
```
For k > 1, the two-pass approach. Let me write a templated version for the first pass.
Actually, since k can be 8, 16, 32, 64, I can template on k for better performance. But CUDA kernels with template parameters need to be instantiated. I can use `switch(k)` to dispatch to template instantiations.
Let me draft the full CUDA code.
For the first pass template:
```cuda
template <int K>
__global__ void local_topk_kernel(const float* __restrict__ x,
float* __restrict__ tmp_vals,
int* __restrict__ tmp_idxs,
int n, int k) {
int row = blockIdx.x;
int tid = threadIdx.x;
int nthreads = blockDim.x;
const float* row_x = x + row * n;
float vals[K];
int idxs[K];
#pragma unroll
for (int i = 0; i < K; i++) {
vals[i] = -INFINITY;
idxs[i] = -1;
}
for (int i = tid; i < n; i += nthreads) {
float v = row_x[i];
if (v > vals[K-1]) {
int pos = K - 1;
while (pos > 0 && v > vals[pos-1]) {
vals[pos] = vals[pos-1];
idxs[pos] = idxs[pos-1];
pos--;
}
vals[pos] = v;
idxs[pos] = i;
}
}
float* row_tmp_vals = tmp_vals + row * nthreads * k;
int* row_tmp_idxs = tmp_idxs + row * nthreads * k;
#pragma unroll
for (int i = 0; i < K; i++) {
row_tmp_vals[tid * k + i] = vals[i];
row_tmp_idxs[tid * k + i] = idxs[i];
}
}
```
For the second pass, bitonic sort:
```cuda
template <int K>
__global__ void merge_topk_kernel(const float* __restrict__ tmp_vals,
const int* __restrict__ tmp_idxs,
float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,
int n_candidates, int k) {
int row = blockIdx.x;
int tid = threadIdx.x;
extern __shared__ char smem[];
float* s_vals = (float*)smem;
int* s_idxs = (int*)(smem + n_candidates * sizeof(float));
const float* row_tmp_vals = tmp_vals + row * n_candidates;
const int* row_tmp_idxs = tmp_idxs + row * n_candidates;
for (int i = tid; i < n_candidates; i += blockDim.x) {
s_vals[i] = row_tmp_vals[i];
s_idxs[i] = row_tmp_idxs[i];
}
__syncthreads();
// Bitonic sort descending
for (int kk = 2; kk <= n_candidates; kk *= 2) {
for (int j = kk / 2; j > 0; j /= 2) {
for (int i = tid; i < n_candidates; i += blockDim.x) {
int ixj = i ^ j;
if (ixj > i) {
if ((i & kk) == 0) {
if (s_vals[i] < s_vals[ixj]) {
float tmp_v = s_vals[i]; s_vals[i] = s_vals[ixj]; s_vals[ixj] = tmp_v;
int tmp_i = s_idxs[i]; s_idxs[i] = s_idxs[ixj]; s_idxs[ixj] = tmp_i;
}
} else {
if (s_vals[i] > s_vals[ixj]) {
float tmp_v = s_vals[i]; s_vals[i] = s_vals[ixj]; s_vals[ixj] = tmp_v;
int tmp_i = s_idxs[i]; s_idxs[i] = s_idxs[ixj]; s_idxs[ixj] = tmp_i;
}
}
}
}
__syncthreads();
}
}
// Write top K
if (tid < k) {
out_vals[row * k + tid] = s_vals[tid];
out_idxs[row * k + tid] = (int64_t)s_idxs[tid];
}
}
```
Wait, there's an issue with shared memory size. The extern shared memory is allocated per block, and the amount has to be passed as the third parameter of the `<<<grid, block, smem_bytes>>>` launch; anything above the default 48KB per block also needs an explicit opt-in on the kernel (cudaFuncSetAttribute).
With `load_inline` that's not a problem, because I write the C++ wrapper that launches the kernels myself, just as with any `cpp_extension` custom op.
`torch.utils.cpp_extension.load_inline` loads a module exposing the functions I declare, so I need a C++ function that calls the kernels with the right parameters.
Let me write the C++ wrapper:
```cpp
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
// ... kernel declarations ...
torch::Tensor topk_forward(torch::Tensor x, int k) {
int batch = x.size(0);
int n = x.size(1);
auto vals = torch::empty({batch, k}, x.options());
auto idxs = torch::empty({batch, k}, torch::dtype(torch::kInt64).device(x.device()));
const float* x_ptr = x.data_ptr<float>();
float* vals_ptr = vals.data_ptr<float>();
int64_t* idxs_ptr = idxs.data_ptr<int64_t>();
int threads = 256;
if (k == 1) {
topk1_kernel<<<batch, threads>>>(x_ptr, vals_ptr, idxs_ptr, n);
} else {
int n_candidates = threads * k;
auto tmp_vals = torch::empty({batch, n_candidates}, x.options());
auto tmp_idxs = torch::empty({batch, n_candidates}, torch::dtype(torch::kInt32).device(x.device()));
float* tmp_vals_ptr = tmp_vals.data_ptr<float>();
int* tmp_idxs_ptr = tmp_idxs.data_ptr<int>();
switch (k) {
case 8:
local_topk_kernel<8><<<batch, threads>>>(x_ptr, tmp_vals_ptr, tmp_idxs_ptr, n, k);
merge_topk_kernel<8><<<batch, threads, n_candidates * (sizeof(float) + sizeof(int))>>>(tmp_vals_ptr, tmp_idxs_ptr, vals_ptr, idxs_ptr, n_candidates, k);
break;
case 16:
local_topk_kernel<16><<<batch, threads>>>(x_ptr, tmp_vals_ptr, tmp_idxs_ptr, n, k);
merge_topk_kernel<16><<<batch, threads, n_candidates * (sizeof(float) + sizeof(int))>>>(tmp_vals_ptr, tmp_idxs_ptr, vals_ptr, idxs_ptr, n_candidates, k);
break;
case 32:
local_topk_kernel<32><<<batch, threads>>>(x_ptr, tmp_vals_ptr, tmp_idxs_ptr, n, k);
merge_topk_kernel<32><<<batch, threads, n_candidates * (sizeof(float) + sizeof(int))>>>(tmp_vals_ptr, tmp_idxs_ptr, vals_ptr, idxs_ptr, n_candidates, k);
break;
case 64:
local_topk_kernel<64><<<batch, threads>>>(x_ptr, tmp_vals_ptr, tmp_idxs_ptr, n, k);
merge_topk_kernel<64><<<batch, threads, n_candidates * (sizeof(float) + sizeof(int))>>>(tmp_vals_ptr, tmp_idxs_ptr, vals_ptr, idxs_ptr, n_candidates, k);
break;
default:
// Generic fallback - not needed for our shapes
break;
}
}
return {vals, idxs};
}
```
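One thing the draft above omits: for k=64 the requested dynamic shared memory (16384 * 8 bytes = 128KB) exceeds the 48KB default, so the kernel has to be opted in before launch. A sketch of the extra call for that case, using the names from the draft (error checking omitted; whether 128KB actually fits on SM120 still needs to be checked against the device's opt-in limit):

```cuda
// Opt the merge kernel into a large dynamic shared-memory carve-out; without
// this, launches requesting more than 48KB fail. Compare the value against
// cudaDevAttrMaxSharedMemoryPerBlockOptin for the target device.
int smem_bytes = n_candidates * (int)(sizeof(float) + sizeof(int));
cudaFuncSetAttribute(merge_topk_kernel<64>,
                     cudaFuncAttributeMaxDynamicSharedMemorySize,
                     smem_bytes);
merge_topk_kernel<64><<<batch, threads, smem_bytes>>>(
    tmp_vals_ptr, tmp_idxs_ptr, vals_ptr, idxs_ptr, n_candidates, k);
```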
Wait, actually I should be returning a tuple from the C++ function. In PyTorch C++ extensions:
```cpp
std::vector<torch::Tensor> topk_forward(torch::Tensor x, int k) {
// ...
return {vals, idxs};
}
```
And in Python:
```python
module = load_inline(
name="topk_cuda",
cpp_sources=cpp_source,
cuda_sources=cuda_source,
functions=["topk_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
with_cuda=True,
)
```
Now let me think about optimizing the first pass. The insertion sort in registers might not be optimal. For small K, unrolling the insertion is important.
For K=8, each insertion requires up to 8 comparisons and shifts. For 512 elements per thread (n=131072, threads=256), that's 512 * 8 = 4096 operations per thread. At 1.5GHz, that's ~2.7 microseconds. With 256 threads, the block takes ~2.7 microseconds (they run in parallel). Total for batch=1: ~2.7 microseconds + memory reads.
Actually, the memory read of 512KB at 1.8 TB/s is ~0.28 microseconds. But in practice, with latency and coalescing, it's probably 1-2 microseconds. So the compute might be comparable.
Wait, can 256 threads even saturate memory bandwidth? With perfect coalescing each warp load moves 128 bytes, so 512KB is 4096 warp-loads spread over the block's 8 warps. And for batch=1 only one block (one SM) is resident, which cannot pull anywhere near the full 1.8 TB/s on its own, so a few microseconds for the read is more realistic than the 0.3 microsecond ideal.
Hmm, let me think about whether I can optimize further.
Actually, for the first pass, instead of maintaining a sorted array, I could use a different approach:
1. Process elements in groups of K
2. Sort each group of K using a sorting network
3. Merge with the current top-K using a merge network
But for K=64, this might be more complex than simple insertion.
Another idea: Use vectorized loads. Each thread loads 4 floats at a time (float4). This reduces instruction count and improves memory throughput. For n=131072 and 256 threads, each thread loads 128 float4s. That's efficient.
But we need to handle non-power-of-2 n (like 12000). 12000/256 = 46.875, so some threads load 47 elements, some 46. With float4, we'd need to handle remainders carefully.
Actually, let me add float4 loads. It's a common optimization:
```cuda
float4* x4 = (float4*)row_x;
int n4 = n / 4;
for (int i = tid; i < n4; i += nthreads) {
float4 v4 = x4[i];
// process v4.x, v4.y, v4.z, v4.w
}
// Handle remaining elements
for (int i = tid + n4 * 4; i < n; i += nthreads) {
float v = row_x[i];
// process v
}
```
This is good! Let me incorporate this.
Now, for the second pass, the bitonic sort in shared memory. The issue is that with many threads accessing shared memory with varying strides, bank conflicts can occur.
For j = 1: threads access i and i^1, which are adjacent -> no bank conflict (different banks)
For j = 2: access i and i^2, stride 2 -> might have some conflicts depending on bank size
Actually, shared memory has 32 banks, each 4 bytes wide, so consecutive floats map to consecutive banks, and bank conflicts only matter between threads within a single load instruction.
For j = 1: the active threads read s_vals[i] at even i (distinct banks) and s_vals[i^1] at odd i (distinct banks) in two separate loads, so no conflict.
For j = 2 and other strides below 32 the same argument applies: XOR with a constant permutes the low 5 bits, so each load instruction still touches distinct banks.
For j >= 32, a given thread's two addresses do share a bank, but they are separate loads, so I think serious conflicts are unlikely with this indexing; the bigger inefficiency is that only the threads passing the `ixj > i` test do useful work each step. Worth confirming with the profiler.
Bank conflicts are still a commonly raised concern for shared-memory bitonic sort (other indexing schemes do hit them), and the standard mitigation is to pad the array so that index i maps to a different bank than index i^j.
A common padding is:
```cuda
// Instead of s[i], use s[i + (i >> 5)]
// Or use a struct with padding
```
Actually, a simpler fix for small N (like up to 16384) is to not worry too much about it. The number of stages with significant conflicts is small, and shared memory is fast anyway.
But we can optimize: store values and indices in a struct and add padding:
```cuda
struct Item {
float val;
int idx;
};
__shared__ Item s_items[N + N/32]; // pad every 32 elements
```
With padding, index i maps to `i + (i >> 5)` or similar. Let me use `i + (i / 32)` as the padded index.
Actually, a cleaner way: `s[i + (i >> 5)]` gives one extra element every 32 elements, shifting elements 32+ apart by 1, which changes the bank.
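A minimal sketch of that padded indexing (helper name mine):

```cuda
// Logical index -> padded physical index: one extra slot every 32 elements
// shifts indices that are 32 (or 64, ...) apart into different banks. The
// shared arrays must be allocated with N + (N >> 5) slots.
__device__ __forceinline__ int padded(int i) {
    return i + (i >> 5);
}
// usage: s_vals[padded(i)] and s_vals[padded(i ^ j)] instead of s_vals[i], s_vals[i ^ j]
```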
Hmm, this is getting complicated. Let me just use the simple version first and optimize if profiling shows it's a bottleneck.
Actually, wait. Let me check: what GPU do we have and what's its shared memory bandwidth? Blackwell SM120 has much higher shared memory bandwidth than previous generations. Bank conflicts in shared memory might be less of an issue.
Let me proceed with the simple implementation and then profile.
One more thing: for the merge pass with k=1 and candidates=256, the bitonic sort is overkill. I already have a specialized k=1 kernel.
Let me also think about whether I should process multiple rows per block for small shapes. For example, batch=128, n=4096, k=1. If I launch 128 blocks of 256 threads each, that's 32768 threads. With ~100+ SMs, each SM gets ~300 threads, which is fine. But the overhead of 128 kernel launches might be noticeable.
Actually, in CUDA, launching one block per row is standard. The GPU scheduler handles it well.
For batch=128, n=4096, k=1, each block reads 4096 floats = 16KB. With 128 blocks, total read = 2MB. This should be fast.
Let me think about another optimization: for the merge pass, I can reduce the number of threads to just what's needed. For candidates=2048 (k=8), I only need 1024 threads to cover all elements. But with 256 threads, each thread processes 8 elements in the bitonic sort loops. That's fine.
Actually, using fewer threads in the merge pass means fewer thread loops but more work per thread. With 256 threads and 2048 candidates, each thread handles 8 elements per stage. With 1024 threads, each thread handles 2 elements. 256 threads might be better due to less synchronization overhead.
Let me think about whether I can fuse the two passes. In some cases, especially for small n, a single pass might be faster.
For n=4096, k=1: single pass argmax is optimal.
For n=8192, k=8: single pass where each thread keeps top-8 and then block reduces.
Actually, for n=8192 with 256 threads, each thread processes 32 elements. Keeping top-8 in registers and then doing a block-level merge of 256 * 8 = 2048 candidates. This is the two-pass approach, but the "passes" are both in registers/shared memory within the same kernel launch? No, pass 1 writes to global memory, pass 2 reads from global memory.
Can I avoid the global memory write/read? Yes, by doing everything in one kernel:
1. Each thread processes its chunk and keeps top-k in registers
2. Write to shared memory
3. Bitonic sort in shared memory
4. Write top-k to output
This eliminates the global memory temporary! That's a big win.
For n=131072, k=64, with 256 threads:
- Each thread processes 512 elements, keeps top-64
- Write 256 * 64 = 16384 candidates to shared memory (192KB with int64, 128KB with int32)
- Bitonic sort 16384 in shared memory
- Write top-64 to output
192KB shared memory might be too much. But 128KB (with int32 indices) might fit. Actually, I don't know the exact shared memory limit per block on SM120: Ampere allowed up to ~164KB and Hopper up to ~228KB (both via an explicit opt-in above 48KB), and I shouldn't assume Blackwell raised it. Let me check.
Actually, for safety, I can use a two-level approach within one kernel:
1. Each warp (32 threads) processes its portion and keeps top-k
2. Warp-level shuffle to find top-k per warp
3. Then block-level merge of warp results
For 256 threads = 8 warps. Each warp keeps top-k. Then we merge 8 top-k lists.
For k=64, merging 8 lists of 64 into top-64:
- Load all 512 candidates into shared memory
- Bitonic sort 512 elements (padded to 512) -> 512 * 9 * 9 / 2 = ~20K comparisons
- Take top-64
Shared memory: 512 * 4 + 512 * 4 = 4KB (with int32 indices). Very small!
Wait, that's brilliant! Here's the single-kernel approach:
**Warp-local top-k + block-level merge**
1. Divide block into warps (8 warps of 32 threads for 256 threads)
2. Each warp processes its strided elements:
- Warp w processes elements starting at w*32 + lane, then +256, +512, etc.
- Each thread in warp keeps a local top-k in registers
- Within warp, merge all 32 top-k lists into one warp top-k list using warp shuffles
3. After all warps have their top-k, write all warp top-k lists to shared memory
4. One warp (or all threads) bitonic sorts the candidates and extracts top-k
Let me detail step 2 - warp-level merge:
Each thread in the warp has top-k in registers. To merge 32 top-k lists into one:
- Write all 32*k candidates to shared memory (within warp)
- Bitonic sort 32*k elements (padded to next power of 2)
- Each thread reads back top-k
For k=64, 32*k = 2048. Shared memory per warp: 2048 * 4 + 2048 * 4 = 16KB. For 8 warps, if they use shared memory sequentially (not simultaneously), total is still 16KB. Or we can partition shared memory into 8 regions of 16KB each = 128KB total.
Actually, warps in a block can share shared memory. We can have 8 * 2048 = 16384 elements total = 128KB. That's the same as before.
But the key insight: instead of having each thread keep top-k and then sorting 16384 elements at the block level, we first merge within warps to reduce 256 top-k lists to 8 top-k lists. Then we sort 512 elements at the block level instead of 16384. That's 32x fewer elements to sort at the block level!
Wait, let me recalculate:
- Original: 256 threads * 64 = 16384 candidates -> sort 16384
- With warp merge: 8 warps * 64 = 512 candidates -> sort 512
For k=8: 256 * 8 = 2048 -> sort 2048
With warp merge: 8 * 8 = 64 -> sort 64
For k=16: 256 * 16 = 4096 -> sort 4096
With warp merge: 8 * 16 = 128 -> sort 128
For k=32: 256 * 32 = 8192 -> sort 8192
With warp merge: 8 * 32 = 256 -> sort 256
This is a HUGE reduction in sorting work! Let me do this.
**Detailed warp-level merge:**
Within a warp, each of 32 threads has a sorted top-k array `vals[K]`, `idxs[K]`.
To merge these into a single top-k array:
1. Write all 32*K candidates to shared memory (warp-private region)
2. Bitonic sort the 32*K candidates
3. Thread 0 reads top-k and broadcasts to all threads in warp... or each thread reads its portion of top-k
Actually, after sorting, we only need the top-k. We can have thread lane read `s_vals[lane]` for lane < k. Then each thread in the warp has one element of the top-k.
But for the subsequent processing, it might be better if each thread has the full top-k (for k=64, that's 64 registers per thread). Or we can have thread 0 keep the full top-k and the rest just continue processing.
Wait, in this design, the warp merge happens after all elements are processed. Each warp first processes all its assigned elements (strided by blockDim.x), building thread-local top-k. Then merges within the warp. Then writes warp top-k to block shared memory. Then block-level bitonic sort on warp results.
For step 2 (warp merge), I can do it like this:
```cuda
// Within warp, merge 32 top-k lists
__shared__ float s_warp_vals[32][128]; // 8 warps * 2048 max candidates, but let's use dynamic
__shared__ int s_warp_idxs[32][128];
```
Actually, let's use a simpler approach. I'll allocate shared memory for all warps:
```cuda
// Max candidates per warp = 32 * 64 = 2048
// 8 warps -> 16384 total elements
// But we process warp by warp
__shared__ float s_vals[WARP_MAX_CANDIDATES];
__shared__ int s_idxs[WARP_MAX_CANDIDATES];
```
For merging within warp:
```cuda
// Each thread writes its top-k
for (int i = 0; i < k; i++) {
s_vals[lane * k + i] = vals[i];
s_idxs[lane * k + i] = idxs[i];
}
__syncwarp();
// Bitonic sort 32*k candidates
int warp_candidates = 32 * k;
for (int kk = 2; kk <= warp_candidates; kk *= 2) {
for (int j = kk / 2; j > 0; j /= 2) {
for (int i = lane; i < warp_candidates; i += 32) {
int ixj = i ^ j;
if (ixj > i) {
if ((i & kk) == 0) {
if (s_vals[i] < s_vals[ixj]) swap;
} else {
if (s_vals[i] > s_vals[ixj]) swap;
}
}
}
__syncwarp();
}
}
// Each thread reads back top-k portion
for (int i = 0; i < k; i++) {
vals[i] = s_vals[lane * k + i]; // Wait, top-k are at indices 0..k-1
idxs[i] = s_idxs[lane * k + i];
}
```
Wait, after sorting, the top-k are at indices 0 to k-1. I want all threads in the warp to have the top-k. But with 32 threads and k=64, each thread can hold 2 elements. Or I can have thread 0 hold all 64 in its registers and then... that's too many.
Actually, I don't need each thread to hold all top-k after the warp merge. I just need the warp top-k to be written to shared memory, and then the block-level sort will pick it up.
Here's a revised plan:
1. Each thread processes its chunk, keeps top-k in registers
2. Warp merge: all threads write to warp shared memory, sort, result is top-k in shared memory indices 0..k-1
3. Warp 0 (or a designated thread) copies warp top-k from shared memory to block shared memory at warp_offset
4. After all warps done, block-level bitonic sort on `num_warps * k` candidates
5. Write top-k to output
For step 3, I need to coordinate. Let me think...
Actually, a cleaner design:
- After warp-level sort, the top-k are at s_vals[0..k-1] in the warp's shared memory region
- Thread 0 of each warp copies these k elements to the block-level shared memory
- All warps do this simultaneously (each to its own region)
- Then __syncthreads()
- Block-level bitonic sort on block_candidates = num_warps * k
Shared memory layout:
- Partition 1: warp work areas; each warp needs 32*k elements, so for k=64 that's 2048 floats + 2048 int32s = 16KB per warp, 128KB for 8 concurrent warps
- Partition 2: block candidate area (8 * 64 = 512 floats + 512 int32s) = 4KB
Total: 132KB for k=64, and much less for smaller k.
Actually, we can reuse partition 1 as the block working area for the final sort. The block candidates (512 elements) can be stored in a separate small array. Or we can sort in place.
Let me think about the control flow more carefully.
```cuda
template <int K>
__global__ void topk_kernel(const float* __restrict__ x,
float* __restrict__ out_vals,
int64_t* __restrict__ out_idxs,
int n, int k) {
int row = blockIdx.x;
int tid = threadIdx.x;
int lane = tid & 31;
int warp_id = tid >> 5;
int num_warps = blockDim.x >> 5;
const float* row_x = x + row * n;
// Thread-local top-K
float vals[K];
int idxs[K];
#pragma unroll
for (int i = 0; i < K; i++) {
vals[i] = -INFINITY;
idxs[i] = -1;
}
// Process elements
for (int i = tid; i < n; i += blockDim.x) {
float v = row_x[i];
if (v > vals[K-1]) {
int pos = K - 1;
while (pos > 0 && v > vals[pos-1]) {
vals[pos] = vals[pos-1];
idxs[pos] = idxs[pos-1];
pos--;
}
vals[pos] = v;
idxs[pos] = i;
}
}
// Warp-level shared memory
// Layout: [warp0_area][warp1_area]... but actually each warp uses it sequentially
// Better: all warps use the same area but at different times? No, they run concurrently.
// So partition by warp_id.
extern __shared__ char smem[];
// Warp work area: each warp gets 32*K elements
float* s_warp_vals = (float*)smem;
int* s_warp_idxs = (int*)(smem + num_warps * 32 * K * sizeof(float));
float* my_warp_vals = s_warp_vals + warp_id * 32 * K;
int* my_warp_idxs = s_warp_idxs + warp_id * 32 * K;
// Write thread top-k to warp shared memory
#pragma unroll
for (int i = 0; i < K; i++) {
my_warp_vals[lane * K + i] = vals[i];
my_warp_idxs[lane * K + i] = idxs[i];
}
__syncwarp();
// Bitonic sort within warp
int warp_cand = 32 * K;
for (int kk = 2; kk <= warp_cand; kk *= 2) {
for (int j = kk / 2; j > 0; j /= 2) {
for (int i = lane; i < warp_cand; i += 32) {
int ixj = i ^ j;
if (ixj > i) {
bool dir = (i & kk) == 0; // true for ascending in bitonic, but we want...
// For descending sort:
// In first half of k-block: larger at lower index
// In second half: larger at higher index
if (dir) { // i in first half
if (my_warp_vals[i] < my_warp_vals[ixj]) {
swap(my_warp_vals[i], my_warp_vals[ixj]);
swap(my_warp_idxs[i], my_warp_idxs[ixj]);
}
} else {
if (my_warp_vals[i] > my_warp_vals[ixj]) {
swap(my_warp_vals[i], my_warp_vals[ixj]);
swap(my_warp_idxs[i], my_warp_idxs[ixj]);
}
}
}
}
__syncwarp();
}
}
// Now top-K of warp are at my_warp_vals[0..K-1]
// Copy warp top-K to block candidate area
// Block candidate area is after warp work areas
float* s_block_vals = (float*)(smem + num_warps * 32 * K * (sizeof(float) + sizeof(int)));
int* s_block_idxs = (int*)(s_block_vals + num_warps * K);
for (int i = lane; i < K; i += 32) {  // K can exceed the warp size (e.g. K=64)
s_block_vals[warp_id * K + i] = my_warp_vals[i];
s_block_idxs[warp_id * K + i] = my_warp_idxs[i];
}
__syncthreads();
// Block-level bitonic sort on num_warps * K candidates
int block_cand = num_warps * K;
if (tid < block_cand) { // Only need block_cand threads, but all can participate
for (int kk = 2; kk <= block_cand; kk *= 2) {
for (int j = kk / 2; j > 0; j /= 2) {
for (int i = tid; i < block_cand; i += blockDim.x) {
int ixj = i ^ j;
if (ixj > i) {
bool dir = (i & kk) == 0;
if (dir) {
if (s_block_vals[i] < s_block_vals[ixj]) {
swap(s_block_vals[i], s_block_vals[ixj]);
swap(s_block_idxs[i], s_block_idxs[ixj]);
}
} else {
if (s_block_vals[i] > s_block_vals[ixj]) {
swap(s_block_vals[i], s_block_vals[ixj]);
swap(s_block_idxs[i], s_block_idxs[ixj]);
}
}
}
}
__syncthreads();
}
}
}
// Write output
if (tid < k) {
out_vals[row * k + tid] = s_block_vals[tid];
out_idxs[row * k + tid] = (int64_t)s_block_idxs[tid];
}
}
```
Wait, there's a subtle issue. In the block-level sort, I'm using `tid < block_cand` as a condition inside the loop, but the `__syncthreads()` requires all threads in the block to reach it. So I can't have an `if (tid < block_cand)` around the `__syncthreads()`.
I need all threads to participate in the block-level sort. For small block_cand (like 64), some threads will have `i >= block_cand` and do nothing, but they'll still hit `__syncthreads()`. That's fine.
Actually, let me restructure:
```cuda
for (int kk = 2; kk <= block_cand; kk *= 2) {
for (int j = kk / 2; j > 0; j /= 2) {
for (int i = tid; i < block_cand; i += blockDim.x) {
// compare and maybe swap
}
__syncthreads();
}
}
```
This works because all threads participate in the loop (even if their `i` is out of bounds, the loop just doesn't execute for them), and all hit `__syncthreads()`.
Now let me think about shared memory size:
- For K=64, num_warps=8:
- Warp work area: 8 * 32 * 64 * 4 = 64KB (vals) + 64KB (idxs) = 128KB
- Block candidate area: 8 * 64 * 4 = 2KB (vals) + 2KB (idxs) = 4KB
- Total: 132KB
This is on the high side. For K=32: 32KB + 32KB + 1KB + 1KB = 66KB. For K=16: 16KB + 16KB + 0.5KB + 0.5KB = 33KB. For K=8: 8KB + 8KB + 0.25KB + 0.25KB = 16.5KB.
Hmm, 132KB might exceed the per-block limit. Ampere allowed up to ~164KB per block (the consumer parts only ~100KB), Hopper up to ~228KB, and I haven't confirmed what SM120 allows, so I shouldn't count on 132KB fitting.
To be safe, I could reduce threads per block for K=64. With 128 threads (4 warps):
- Warp work area: 4 * 32 * 64 * 4 = 32KB + 32KB = 64KB
- Block candidates: 4 * 64 * 4 = 1KB + 1KB = 2KB
- Total: 66KB
Each thread then processes 2x more elements (1024 instead of 512 for n=131072). That's fine - insertion sort of 1024 elements into top-64 is still reasonable.
Alternatively, I can keep 256 threads but reduce the warp work area by having the warp merge happen in stages. For example, first merge pairs of threads (2*k), then pairs of pairs (4*k), etc. This uses less shared memory at any one time.
Actually, the pairwise merge approach:
- Step 1: 16 pairs of threads merge their top-k -> 16 lists of 2k candidates, sort, keep top-k. Shared: 2k per pair = 32k total.
- Step 2: 8 pairs of pairs merge -> 8 lists. Shared: 2k per merge = 16k? No, each merge takes 2k from previous step and produces k. We need 2k workspace.
This is essentially a tree. At each level, the workspace needed is `2 * k * (num_threads_at_level / 2)`.
Level 0 (thread pairs): 2 * k * 16 = 32k elements
Level 1 (4 threads): 2 * k * 8 = 16k elements
...
Actually, we can reuse the same workspace. Max needed at any level is 32k elements = 128KB for k=64. That's still a lot.
Let me just use fewer threads for larger k. I'll dispatch:
- k <= 16: 256 threads
- k == 32: 256 threads (66KB shared mem)
- k == 64: 128 threads (66KB shared mem)
Or better yet, I can make the block size a runtime parameter.
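A minimal sketch of that runtime choice on the host side (the numbers are just the plan above, not tuned):

```cuda
// Pick the block size per k so the fused kernel's shared-memory footprint
// stays comfortable: ~66KB at 256 threads for k<=32, ~66KB at 128 threads
// for k=64 (it would be ~132KB at 256 threads).
static inline int pick_block_size(int k) {
    if (k == 1)  return 256;   // specialized argmax path
    if (k <= 32) return 256;
    return 128;                // k == 64: halve the warps to halve the warp work area
}
```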
Actually, let me also consider another approach: instead of doing warp-level merge via shared memory bitonic sort, I can use warp shuffles to do a merge of two sorted lists.
Merge of two sorted k-element lists using warp shuffles:
- This is complex. There are papers on parallel merge, but it's not trivial.
Let me stick with the shared memory approach and just use 128 threads for k=64.
For n=131072 and 128 threads, each thread processes 1024 elements. With top-64 insertion sort, the worst case is 1024 * 64 = 65536 comparisons per thread (typical data fails the initial v > vals[K-1] check most of the time, but worst-case inputs pay the full price). At 1.5GHz, ~44 microseconds. Wait, that's a lot!
Actually, 65536 comparisons at 1 clock each = 65536 clocks = 43 microseconds per thread. But we're running 128 threads in parallel. The issue is that this is per block, and for batch=1, we have 1 block. So 43 microseconds just for comparisons. Plus memory reads.
Hmm, that's slow. With 256 threads, each thread does 512 * 64 = 32768 comparisons = 22 microseconds. Still not great.
Can I make the insertion faster?
Yes! For maintaining top-k, instead of shifting elements one by one, I can use a different data structure. But in registers, we're limited.
Actually, for k=64, if each thread processes 512 elements, we do 512 insertions into a size-64 sorted array. Each insertion is up to 64 shifts.
What if we use a tournament/tree approach within the thread?
1. Load 64 elements
2. Sort them using a bitonic sort network (fast, no branches)
3. Load next 64 elements, sort them
4. Merge the two sorted 64-element lists, keep top 64
5. Repeat
For n=512, we do 8 loads of 64 elements. Each 64-element sort: a bitonic sorting network for 64 elements has 64 * 6 * 7 / 4 = 672 comparators. Each merge of two sorted 64-element lists into top-64: we can use a bitonic merge of 128 elements and take the top 64, which is log2(128) = 7 stages with 64 comparators each = 448 comparators.
Total: 8 sorts + 7 merges = 8 * 672 + 7 * 448 = 5376 + 3136 = 8512 comparisons.
vs. original: 512 * 64 = 32768 comparisons in the worst case.
That's roughly 4x fewer comparisons! And bitonic sorting networks have no data-dependent branches, so they execute faster.
But implementing bitonic sort for 64 elements in registers is complex. We'd need to do compare-and-swap of register pairs.
For 64 elements per thread, we need 64 registers for values and 64 for indices. That's 128 registers. Plus temp registers for swapping. A GPU has 256 registers per thread (typically), so this fits but leaves little room for other variables.
Actually, let me count:
- 64 float registers
- 64 int registers
- Loop counters, pointers: ~10 registers
Total: ~138 registers. Should fit in 255.
The bitonic sort of 64 elements in registers:
```cuda
// 64-element bitonic sort (descending)
// Using CUDA's ability to swap registers
#define CMP_SWAP(i, j) \
if (vals[i] < vals[j]) { \
float tmp_v = vals[i]; vals[i] = vals[j]; vals[j] = tmp_v; \
int tmp_i = idxs[i]; idxs[i] = idxs[j]; idxs[j] = tmp_i; \
}
// Stage 1: pairs
CMP_SWAP(0, 1); CMP_SWAP(2, 3); ... CMP_SWAP(62, 63);
// Stage 2: 4-element sequences
CMP_SWAP(0, 2); CMP_SWAP(1, 3); ...
// ...
```
This is a lot of code but can be fully unrolled. For 64 elements there are log2(64) = 6 stages, i.e. 6 * 7 / 2 = 21 passes of N/2 = 32 compare-swaps each, so the total is N * log2(N) * (log2(N)+1) / 4 = 64 * 6 * 7 / 4 = 672 comparators.
Each comparator is a compare and possible swap. With registers, this is very fast.
But writing all this out is tedious. We can use a loop-based approach too:
```cuda
for (int k = 2; k <= 64; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
for (int i = 0; i < 64; i++) {
int ixj = i ^ j;
if (ixj > i) {
bool dir = (i & k) == 0;
if (dir) { if (vals[i] < vals[ixj]) swap; }
else { if (vals[i] > vals[ixj]) swap; }
}
}
}
}
```
This is much cleaner, and the compiler will unroll it since the bounds are compile-time constants (the literal 64 here, or K once it becomes a template parameter).
There are branches inside the unrolled body, but they depend only on `i` and `j`, so every thread takes the same path and there is no warp divergence. Better still, once the loops are unrolled, `ixj`, the `ixj > i` test, and `dir` are compile-time constants for each iteration, so the compiler can fold the branches away. Let me check:
For k=2, j=1:
- i=0: ixj=1, dir=(0&2)==0=true. Compare vals[0] < vals[1]
- i=1: ixj=0, ixj>i false, skip
- i=2: ixj=3, dir=(2&2)==0=false. Compare vals[2] > vals[3]
- etc.
The compiler can see all of this at compile time and generate straight-line code. So the loop-based version should compile to essentially the same code as fully unrolled CMP_SWAP macros.
Great! Let me use the loop-based bitonic sort for register arrays.
**Revised thread-level processing:**
Instead of streaming insertion sort, I'll do:
1. Process elements in chunks of K
2. For each chunk, copy to a temporary buffer, bitonic sort it
3. Merge with current top-K using a bitonic merge
4. Keep only top-K
Wait, merging two sorted K-element lists to get top-K:
- Concatenate: [current_topK, new_chunk_sorted]
- This is a bitonic sequence? No. [descending, descending] is not bitonic.
- But [descending, ascending_reversed] is bitonic.
- If I reverse the new chunk, I get a bitonic sequence of size 2K.
- Then I can bitonic merge it (log2(2K) stages) and take top K.
So:
1. Load K elements into `chunk_vals[K]`, `chunk_idxs[K]`
2. Bitonic sort `chunk` descending
3. Reverse `chunk` (or load in reverse order)
4. Concatenate with `topK`: [topK_vals, reversed_chunk_vals]
5. Bitonic merge of size 2K
6. First K elements are the new topK
Steps 3-5 can be optimized. Actually, [descending, descending] can be turned into a bitonic sequence by reversing the second half. Then bitonic merge.
For K=64, merging two sorted lists:
- Create bitonic sequence: topK[0..63] + chunk[63..0] (reversed)
- Bitonic merge of 128 elements: 7 stages
- Result: first 64 are the merged top-64
This is efficient!
But wait, we need to handle the fact that chunk might have fewer than K elements (last chunk). We can pad with -inf.
Also, for loading elements, we can process them in groups of K. For n=1024 and K=64, that's 16 chunks. Each chunk is sorted (672 comparators), then merged (448 comparators). Total: 16 * 672 + 15 * 448 = 10752 + 6720 = 17472 comparators.
vs. insertion sort: 1024 * 64 = 65536 comparisons.
So the chunk+merge approach is ~3.7x fewer comparisons. And it's branchless (or mostly branchless). This should be significantly faster!
Let me implement this.
```cuda
template <int K>
__device__ void bitonic_sort(float* vals, int* idxs, int n) {
for (int k = 2; k <= n; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
#pragma unroll
for (int i = 0; i < n; i++) {
int ixj = i ^ j;
if (ixj > i) {
bool dir = (i & k) == 0;
if (dir) {
if (vals[i] < vals[ixj]) {
float tmp_v = vals[i]; vals[i] = vals[ixj]; vals[ixj] = tmp_v;
int tmp_i = idxs[i]; idxs[i] = idxs[ixj]; idxs[ixj] = tmp_i;
}
} else {
if (vals[i] > vals[ixj]) {
float tmp_v = vals[i]; vals[i] = vals[ixj]; vals[ixj] = tmp_v;
int tmp_i = idxs[i]; idxs[i] = idxs[ixj]; idxs[ixj] = tmp_i;
}
}
}
}
}
}
}
template <int K>
__device__ void bitonic_merge_topk(float* topk_vals, int* topk_idxs,
float* chunk_vals, int* chunk_idxs) {
// chunk is already sorted descending
// Reverse chunk to create bitonic sequence: [topk_desc, chunk_asc]
float seq_vals[2*K];
int seq_idxs[2*K];
#pragma unroll
for (int i = 0; i < K; i++) {
seq_vals[i] = topk_vals[i];
seq_idxs[i] = topk_idxs[i];
seq_vals[K + i] = chunk_vals[K - 1 - i];
seq_idxs[K + i] = chunk_idxs[K - 1 - i];
}
// Bitonic merge of size 2K
for (int j = K; j > 0; j /= 2) {
#pragma unroll
for (int i = 0; i < 2*K; i++) {
int ixj = i ^ j;
if (ixj > i) {
// For descending merge: first half of each 2j block keeps larger
bool is_first_half = (i & (2*j)) == 0;
if (is_first_half) {
if (seq_vals[i] < seq_vals[ixj]) {
float tv = seq_vals[i]; seq_vals[i] = seq_vals[ixj]; seq_vals[ixj] = tv;
int ti = seq_idxs[i]; seq_idxs[i] = seq_idxs[ixj]; seq_idxs[ixj] = ti;
}
} else {
if (seq_vals[i] > seq_vals[ixj]) {
float tv = seq_vals[i]; seq_vals[i] = seq_vals[ixj]; seq_vals[ixj] = tv;
int ti = seq_idxs[i]; seq_idxs[i] = seq_idxs[ixj]; seq_idxs[ixj] = ti;
}
}
}
}
}
// Copy top K back
#pragma unroll
for (int i = 0; i < K; i++) {
topk_vals[i] = seq_vals[i];
topk_idxs[i] = seq_idxs[i];
}
}
```
Wait, for the bitonic merge, I'm not sure about the direction logic. Let me think carefully.
A bitonic sequence is one that first increases then decreases (or vice versa). [descending, ascending] is bitonic.
The bitonic merge algorithm for a bitonic sequence of size N:
1. Compare elements i and i + N/2 for all i < N/2
2. If the sequence is [descending, ascending] and we want descending:
- In the first half, keep the larger of each pair
- In the second half, keep the smaller of each pair
3. Recursively merge each half
Actually, the standard bitonic merge for descending:
```cuda
for (int j = N/2; j > 0; j /= 2) {
for (int i = 0; i < N; i++) {
int ixj = i ^ j;
if (ixj > i) {
if (i < (i | j)) { // or some condition
// keep larger in lower half
}
}
}
}
```
Hmm, let me think about this more carefully. In a bitonic merge:
- We compare i and i + j where j starts at N/2 and halves each iteration
- For a descending sort, we want the larger elements in the first half after each comparison stage
Actually, for the merge of a bitonic sequence into descending order:
```cuda
for (int j = N/2; j > 0; j /= 2) {
for (int i = 0; i < N; i++) {
int ixj = i ^ j;
if (ixj > i) {
// Determine direction based on position in the current sub-problem
bool up = ((i / (2*j)) % 2) == 0; // first half of 2j block
if (up) {
if (vals[i] < vals[ixj]) swap; // larger goes to first half
} else {
if (vals[i] > vals[ixj]) swap; // smaller goes to second half
}
}
}
}
```
Wait, in the first iteration j = N/2. The 2j block is the whole N. So `up = ((i / N) % 2) == 0` is always true for i < N. So we always do `if (vals[i] < vals[ixj]) swap`. This means after the first stage, the larger of each pair (i, i+N/2) goes to position i.
After this stage, we have two bitonic sequences of size N/2, each with larger elements in the first half. We recursively merge each half.
For the second iteration j = N/4:
- For i in [0, N/2): 2j block is [0, N/2) or [N/2, N).
- `up = ((i / (N/2)) % 2) == 0`. For i in [0, N/2), this is true. For i in [N/2, N), this is false.
- So for first half: keep larger. For second half: keep smaller.
This is correct!
So the merge code is:
```cuda
for (int j = K; j > 0; j /= 2) { // K is N/2 where N=2K
for (int i = 0; i < 2*K; i++) {
int ixj = i ^ j;
if (ixj > i) {
bool up = ((i / (2*j)) % 2) == 0;
if (up) {
if (seq_vals[i] < seq_vals[ixj]) swap;
} else {
if (seq_vals[i] > seq_vals[ixj]) swap;
}
}
}
}
```
Wait, for j=K, 2*j = 2*K = N. For all i in [0, N), i/N = 0, so up is always true. Good.
For j=K/2, 2*j = K. For i in [0, K), i/K = 0, up=true. For i in [K, 2K), i/K = 1, up=false. Good.
This looks correct!
Now, for the chunk processing, I need to load elements in chunks of K, sort each chunk, then merge with topk.
But what about the last chunk which might be smaller than K? I can just load the remaining elements and fill the rest with -inf.
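Concretely, the per-chunk load could look something like this (a sketch; `chunk_start` stands for wherever this thread's current chunk begins, and the exact addressing scheme is still up in the air):
```cuda
// Load one chunk of K candidates for this thread, padding out-of-range slots with
// -inf so the bitonic network always operates on exactly K entries.
#pragma unroll
for (int c = 0; c < K; c++) {
    int src = chunk_start + c * blockDim.x;   // strided across the block for coalescing
    bool in_range = (src < n);
    chunk_vals[c] = in_range ? row_x[src] : -INFINITY;
    chunk_idxs[c] = in_range ? src : -1;      // -1 marks a padding slot
}
```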
Actually, an even simpler approach: process elements one at a time but use a more efficient insertion.
Hmm, I think the chunk+merge approach is good but might use too many registers (2K for the merge sequence). For K=64, that's 128 float registers + 128 int registers = 256 registers. That's a lot and might spill to local memory.
Let me check: for K=64, each thread needs:
- topk_vals[64], topk_idxs[64] = 128 registers
- chunk_vals[64], chunk_idxs[64] = 128 registers
- seq_vals[128], seq_idxs[128] = 256 registers
Total: 512 registers. That's way too much. GPU limit is 255 per thread. This will spill badly.
I need to be more register-efficient.
Alternative: instead of keeping chunk and seq in separate arrays, I can merge directly:
1. chunk_vals and chunk_idxs are loaded from global memory into registers
2. Sort chunk in place (using chunk_vals, chunk_idxs)
3. Merge with topk by comparing elements directly without a full seq array
A merge of two sorted descending arrays A[0..K-1] and B[0..K-1] into C[0..K-1]:
- We can do this with a parallel merge. Thread i finds the i-th largest element.
- But with one thread, a sequential merge is O(K).
Sequential merge:
```cuda
int a = 0, b = 0;
for (int i = 0; i < K; i++) {
if (a < K && (b >= K || A[a] >= B[b])) {
C[i] = A[a]; c_idx[i] = A_idx[a]; a++;
} else {
C[i] = B[b]; c_idx[i] = B_idx[b]; b++;
}
}
```
This is K comparisons and K assignments. For K=64, that's 64 operations. And it uses only A, B, C arrays (3 * 64 * 2 = 384 values, but they can overlap).
Wait, we can do this in place or with a temporary:
```cuda
float new_vals[K];
int new_idxs[K];
int a = 0, b = 0;
for (int i = 0; i < K; i++) {
if (a < K && (b >= K || topk_vals[a] >= chunk_vals[b])) {
new_vals[i] = topk_vals[a];
new_idxs[i] = topk_idxs[a];
a++;
} else {
new_vals[i] = chunk_vals[b];
new_idxs[i] = chunk_idxs[b];
b++;
}
}
// Copy back
for (int i = 0; i < K; i++) {
topk_vals[i] = new_vals[i];
topk_idxs[i] = new_idxs[i];
}
```
This uses 2K extra registers (new_vals, new_idxs). Together with topk (2K) and chunk (2K), that's 6K = 384 values in flight for K=64, which is over the 255-register limit unless the compiler can reuse chunk's registers for new (chunk is dead once the merge finishes).
Actually, we can optimize: instead of a separate new array, we can write directly to topk if we're careful. But since we read from topk and write to topk, we'd overwrite elements before reading them.
One trick: since both arrays are sorted descending, and we want the top K, we can use a double-buffered approach with just 2 arrays that swap roles.
Actually, let me just use the sequential merge. It's simple and only needs one extra array pair of size K. Counting everything live at once (topk + chunk + new, values and indices) it's 6K values: 384 for K=64, too many registers unless the compiler overlaps chunk with new.
For K=32: 192 values. Comfortable.
For K=16: 96 values.
For K=8: 48 values.
So the sequential merge approach works well for smaller K but is questionable for K=64.
Let me also think about the chunk size. For sequential merge, the cost per chunk is:
- Sort chunk of size C: bitonic sort = C * log2(C)^2 / 2 comparators
- Merge with topk: K comparisons
If C = K, sort cost = K * log2(K)^2 / 2, merge cost = K.
Total per chunk: K * (log2(K)^2 / 2 + 1).
For n elements, number of chunks = n / K.
Total cost: n * (log2(K)^2 / 2 + 1).
For K=64: log2(64) = 6, so cost per element = 6^2/2 + 1 = 19 comparisons.
For n=1024: 1024 * 19 = 19456 comparisons.
vs. simple insertion: 1024 * 64 = 65536 comparisons.
Improvement: 3.4x.
For n=512: 512 * 19 = 9728 comparisons. At 1.5GHz, ~6.5 microseconds. Much better!
And register usage: chunk_vals[64], chunk_idxs[64], topk_vals[64], topk_idxs[64], new_vals[64], new_idxs[64] = 384 values.
384 registers is too many. But some of these can be optimized:
- chunk is loaded from memory, sorted, then merged. We can reuse chunk storage as temp during merge.
- Actually, during merge we need topk_vals, chunk_vals, and new_vals. That's 3 * 64 = 192 float registers + 192 int registers = 384 registers.
That's still a lot. But we don't need to keep chunk_vals after merge. So:
1. Load 64 elements into chunk_vals, chunk_idxs
2. Sort chunk_vals, chunk_idxs in place (using chunk arrays)
3. Merge topk_vals and chunk_vals into new_vals
4. Copy new_vals to topk_vals
After step 2, we still need chunk_vals for step 3. After step 3, we can discard chunk_vals. But registers are allocated statically, so the compiler will allocate space for all of them.
Actually, the compiler might optimize this. If we write it carefully, the compiler might reuse registers. But it's hard to guarantee.
Alternative: use a smaller chunk size, like C=32.
- Sort chunk of 32: 32 * 5 * 5 / 2 = 400 comparators
- Merge with topk (64): 64 comparisons
- Per chunk: 464 operations
- For n=1024: 32 chunks * 464 = 14848 operations.
- Registers: chunk[32], topk[64], new[64] = 160 values plus 160 matching indices = 320 registers. Better than the 384 above, though still tight.
For K=64, chunk size 32:
- Number of chunks = ceil(1024 / 32) = 32
- Total ops = 32 * (400 + 64) = 14848
- Still much better than insertion (65536)
For K=32, chunk size 32:
- Sort chunk: 32 * 5 * 5 / 2 = 400
- Merge: 32 comparisons
- Per chunk: 432
- For n=512: 16 * 432 = 6912
This is great! Let me use chunk size = min(K, 32) or something. Actually, let me just use a fixed chunk size of 32 for all K >= 8. For K=8, chunk size 8.
Wait, if K=8, chunk size should be 8. Sorting 8 elements is 8 * 3 * 3 / 2 = 36 comparators. Merging with topk of 8: 8 comparisons. For n=1024: 128 chunks * 44 = 5632 operations. vs insertion: 1024 * 8 = 8192. Still better, but not as dramatic.
For K=1, we don't need any of this.
OK, I think using chunk-based processing with bitonic sort of chunks and sequential merge is a good approach. Let me formalize it.
Actually, let me step back. Maybe I'm overcomplicating this. Let me first try a simple insertion sort approach and see how it performs. If it's fast enough, I don't need the complexity.
For the largest shape (1, 131072, 64) with 256 threads:
- Each thread: 512 insertions into top-64
- 512 * 64 = 32768 comparisons
- At ~1 comparison per clock (optimistic), 32768 / 1.5e9 = 22 microseconds
- Memory read: 512KB at 1.8 TB/s = 0.28 microseconds (theoretical), probably 2-5 microseconds in practice
- So compute dominates at ~22 microseconds
Total: ~25-30 microseconds. Is that good? Let me check what PyTorch's topk does. For similar sizes, torch.topk might take 10-20 microseconds on H100, so at 25-30 microseconds we'd likely be 1.5-3x slower. We want to beat or match torch.topk.
If I optimize to 3x fewer comparisons with chunk+merge, I could get down to ~10-15 microseconds. That would be excellent.
Let me try the chunk+merge approach with chunk size = K (no extra register overhead for chunk arrays beyond what's needed for sorting).
Actually, I just realized: for bitonic sort of K elements in registers, I can do it in-place on the topk array if I'm processing the first chunk. Then for subsequent chunks, I need a temporary chunk array.
But there's another approach: instead of maintaining a sorted topk array and doing insertion/merge, I can:
1. Load all n/thread elements into a local buffer
2. Sort the entire buffer using a fast sorting network or quicksort
3. Take top-k
For n/thread = 512, sorting 512 elements is 512 * 9 * 9 / 2 = 20736 comparisons, then take the top 64. That's actually fewer comparisons than insertion sort (20736 vs 32768).
But sorting 512 elements in registers requires 512 float registers + 512 int registers = 1024 registers. That's way too many.
So we can't sort the entire local buffer.
What if we use a partial sort? We only need top-64 out of 512. A partial heap or tournament tree would be O(512 * log(64)) = 512 * 6 = 3072 operations. That's much better!
A tournament tree / heap of size 64:
- Insert 512 elements one by one
- Each insertion: compare with root (smallest in heap), if larger, replace root and sift down
- Sift down: at most log2(64) = 6 comparisons
Total: 512 * 6 = 3072 comparisons. vs. 32768 for insertion sort. That's 10x better!
But a heap in registers is tricky. The sift-down requires accessing elements at indices that depend on the tree structure. In a binary heap stored in an array:
- Root at 0
- Children of i at 2i+1, 2i+2
- Parent of i at (i-1)/2
For sift-down:
```cuda
int pos = 0;
while (2*pos+1 < K) {
int child = 2*pos+1;
if (child+1 < K && vals[child+1] > vals[child]) child++;
if (vals[pos] >= vals[child]) break;
swap(vals[pos], vals[child]);
swap(idxs[pos], idxs[child]);
pos = child;
}
```
This is up to log2(K) = 6 swaps, but each swap is 3 assignments, and there are branches.
For K=64 and 512 elements: 512 * 6 * (some constant) = ~3000 operations.
But wait, for heap-based top-k, we actually want a MIN-heap of the current top-k. The root is the smallest of the top-k. When we see a new element:
- If new element > root, replace root and sift-down
- If new element <= root, ignore
After processing all elements, the heap contains the top-k but not sorted. We need to extract them in sorted order. We can do this by repeatedly extracting the max... but a min-heap doesn't support efficient max extraction.
Alternatively, after building the min-heap of top-k, we can sort the heap array using a sort (64 * 6 * 6 / 2 = 1152 comparisons) to get them in descending order.
Total: 512 * 6 + 1152 = 4224 operations for heap + sort. vs. insertion: 32768. That's 7.8x better!
And register usage: just the heap array of 64 + temp variables. Very efficient!
Let me implement the heap approach.
```cuda
template <int K>
__device__ void heap_sift_down(float* vals, int* idxs, int pos) {
while (true) {
int left = 2 * pos + 1;
int right = 2 * pos + 2;
int smallest = pos;
if (left < K && vals[left] < vals[smallest]) smallest = left;
if (right < K && vals[right] < vals[smallest]) smallest = right;
if (smallest == pos) break;
float tmp_v = vals[pos]; vals[pos] = vals[smallest]; vals[smallest] = tmp_v;
int tmp_i = idxs[pos]; idxs[pos] = idxs[smallest]; idxs[smallest] = tmp_i;
pos = smallest;
}
}
template <int K>
__device__ void build_min_heap(float* vals, int* idxs) {
for (int i = K / 2 - 1; i >= 0; i--) {
heap_sift_down<K>(vals, idxs, i);
}
}
template <int K>
__device__ void heap_sort_descending(float* vals, int* idxs) {
// Convert min-heap to sorted descending
for (int i = K - 1; i > 0; i--) {
// Move max (actually, for min-heap, the max is somewhere in the leaves)
// Hmm, heap sort from min-heap gives ascending order, not descending.
// We need max-heap for descending order.
}
}
```
Wait, I want the final output sorted descending. If I use a min-heap, at the end I have the top-k elements but unsorted. To get descending order:
- Build a max-heap instead? No, for streaming top-k we need a min-heap to quickly check if a new element belongs.
- After processing, we can just sort the K elements using bitonic sort.
So:
1. Initialize heap with -inf
2. For each element: if > root, replace root and sift-down
3. After all elements, sort the K elements descending
For step 2, the heap is a min-heap (root is smallest of top-k).
For step 3, bitonic sort of K elements descending.
This is clean! Let me write it.
```cuda
template <int K>
__device__ void heap_replace_root(float* vals, int* idxs, float new_val, int new_idx) {
vals[0] = new_val;
idxs[0] = new_idx;
int pos = 0;
while (true) {
int left = 2 * pos + 1;
int right = 2 * pos + 2;
int smallest = pos;
if (left < K && vals[left] < vals[smallest]) smallest = left;
if (right < K && vals[right] < vals[smallest]) smallest = right;
if (smallest == pos) break;
float tmp_v = vals[pos]; vals[pos] = vals[smallest]; vals[smallest] = tmp_v;
int tmp_i = idxs[pos]; idxs[pos] = idxs[smallest]; idxs[smallest] = tmp_i;
pos = smallest;
}
}
template <int K>
__device__ void bitonic_sort_desc(float* vals, int* idxs, int n) {
for (int k = 2; k <= n; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
for (int i = 0; i < n; i++) {
int ixj = i ^ j;
if (ixj > i) {
bool dir = (i & k) == 0;
if (dir) {
if (vals[i] < vals[ixj]) {
float tv = vals[i]; vals[i] = vals[ixj]; vals[ixj] = tv;
int ti = idxs[i]; idxs[i] = idxs[ixj]; idxs[ixj] = ti;
}
} else {
if (vals[i] > vals[ixj]) {
float tv = vals[i]; vals[i] = vals[ixj]; vals[ixj] = tv;
int ti = idxs[i]; idxs[i] = idxs[ixj]; idxs[ixj] = ti;
}
}
}
}
}
}
}
```
Now the thread-level processing:
```cuda
float vals[K];
int idxs[K];
#pragma unroll
for (int i = 0; i < K; i++) {
vals[i] = -INFINITY;
idxs[i] = -1;
}
// Optional: build initial heap with first K elements
// But since we stream, we can just process all elements
for (int i = tid; i < n; i += blockDim.x) {
float v = row_x[i];
if (v > vals[0]) {
heap_replace_root<K>(vals, idxs, v, i);
}
}
// Sort descending
bitonic_sort_desc<K>(vals, idxs, K);
```
For K=64, 512 elements per thread:
- Heap operations: 512 sift-downs, each up to 6 levels
- Each level: 2 comparisons + 3 swaps (9 ops)
- Total: 512 * 6 * 9 = 27648 operations
- Plus sort: 1152 comparisons
- Total: ~28800 operations
Hmm, that's actually similar to insertion sort (32768). The heap has less comparison but more memory movement (swaps). In registers, swaps are fast but still take instructions.
Wait, I can optimize the heap by not doing full swaps. Instead, track the position and do a single assignment at the end:
```cuda
template <int K>
__device__ void heap_replace_root_fast(float* vals, int* idxs, float new_val, int new_idx) {
int pos = 0;
while (true) {
int left = 2 * pos + 1;
int right = 2 * pos + 2;
int smallest = pos;
if (left < K && vals[left] < vals[smallest]) smallest = left;
if (right < K && vals[right] < vals[smallest]) smallest = right;
if (smallest == pos) break;
pos = smallest;
}
// Shift elements along the path and insert at pos
// Wait, we need to shift from root to pos, but in a heap the parent is always <= children.
// Actually, if we just put new_val at root and sift down, we compare with children and move
// the smaller child up. But in our case, we want to maintain min-heap property.
// Standard sift down:
// Start at root with new_val
// While current node > smaller child:
// Move smaller child up
// Go to that child position
// Insert new_val at final position
}
```
Ah, the standard sift-down is:
```cuda
float v = new_val;
int idx = new_idx;
int pos = 0;
while (true) {
int left = 2 * pos + 1;
int right = 2 * pos + 2;
int smallest = pos;
if (left < K && vals[left] < v) smallest = left;
if (right < K && vals[right] < v) smallest = right;
if (smallest == pos) {
vals[pos] = v;
idxs[pos] = idx;
break;
}
vals[pos] = vals[smallest];
idxs[pos] = idxs[smallest];
pos = smallest;
}
```
This does at most log2(K) iterations, each with 2 comparisons and 2 assignments (moving child up). Then 2 assignments at the end. Total: 6 * (2+2) + 2 = 26 operations per sift-down.
For 512 elements: 512 * 26 = 13312 operations. Plus sort: 1152. Total: ~14500.
vs. insertion sort: 512 * 64 = 32768 comparisons + 32768 shifts = ~65536 operations.
Heap approach: ~14500 operations. 4.5x better!
And register usage: just vals[K] and idxs[K]. No extra temp arrays needed for sift-down (just a few scalar variables).
This is excellent! Let me use the heap approach.
Wait, there's a subtle issue. The heap requires `vals[0]` to be the smallest of the top-k. But initially, all vals are -inf. After replacing root with the first K elements, the heap property might not hold.
Actually, if we start with all -inf, and we replace root with element 0, then root is still -inf (if element 0 is -inf, which it isn't... wait, element 0 is some value > -inf).
Let me trace: initial vals = [-inf, -inf, ..., -inf].
Process element 0 (v=1.0): v > vals[0] (-inf), replace root with 1.0, sift down. Since children are -inf, 1.0 stays at root. vals = [1.0, -inf, ..., -inf].
Process element 1 (v=2.0): v > vals[0] (1.0), replace root with 2.0, sift down. vals = [2.0, -inf, ..., -inf].
...
Process element K-1 (v=K): vals = [K, -inf, ..., -inf].
At this point, the heap property is violated! The root is K, but many children are -inf. A min-heap requires parent <= children. Here parent > children.
But wait, the sift-down algorithm I wrote compares the new value with children and moves smaller children up. If children are -inf, they'll move up, and the new value will go down.
Let me trace more carefully:
- vals = [1.0, -inf, -inf, ...]
- Replace root with 2.0, sift down:
- pos=0, left=1 (-inf), right=2 (-inf). smallest=1 (or 2). Move vals[1] (-inf) to pos 0. pos=1.
- pos=1, left=3 (-inf), right=4 (-inf). Move vals[3] to pos 1. pos=3.
- ... continue down to a leaf.
- Insert 2.0 at leaf.
After this, the heap is almost all -inf except one leaf has 2.0. That's not right either.
Hmm, I see the problem. In a min-heap, when we replace the root (minimum), we should put the new element at the root and sift it down to maintain the heap property. But in our case, the new element is LARGER than the root (that's why we're replacing). So it will sift down until it finds a position where it's <= its children. Since many children might be -inf, it might go all the way to a leaf.
This means after processing the first few elements, the large elements end up at the leaves, and -inf values bubble up to the root. Then subsequent elements are compared with -inf at the root, so almost all of them replace the root and sift down.
This is actually correct behavior! After processing all elements, the heap will contain the K largest elements, with the smallest of those K at the root. The large elements will be scattered throughout the heap.
Wait, let me verify. After processing many elements:
- The root should be the K-th largest (smallest among top-K)
- All elements in the heap should be >= root
- This is exactly the min-heap property!
Yes, this works. When we replace root with a new element that's larger, it sifts down to its correct position. When we replace with an element that's smaller, it might not sift down at all and stay at root.
Actually, in our algorithm, we only replace if new element > root. So the new element is always > current root. It will sift down to some position. The old root is discarded.
After processing all elements, the heap contains the K largest elements seen so far, with the smallest of them at the root. Perfect!
Then we sort the heap in descending order using bitonic sort.
One more thing: the sift-down algorithm I wrote compares `v` with children, not `vals[pos]` with children. This is the standard sift-down where we start with a "hole" at the root and move it down:
```cuda
float v = new_val;
int idx = new_idx;
int pos = 0;
while (true) {
int left = 2 * pos + 1;
int right = 2 * pos + 2;
int smallest = pos;
if (left < K && vals[left] < v) smallest = left;
if (right < K && vals[right] < v) smallest = right;
if (smallest == pos) break;
vals[pos] = vals[smallest];
idxs[pos] = idxs[smallest];
pos = smallest;
}
vals[pos] = v;
idxs[pos] = idx;
```
Yes! This is correct and efficient.
Now, what about when there are duplicate values or ties? The heap handles this fine. The final sort is stable enough for our purposes (we don't care about tie-breaking for indices).
Great, I think the heap approach is the way to go. Let me now design the full kernel with warp-level and block-level merging.
Actually, wait. I was going to do warp merge + block merge. But with the heap approach, each thread processes its chunk and produces a sorted top-k. Then we still need to merge.
For warp merge, I can use the same heap approach within the warp! Each thread has sorted top-k. We want to merge 32 sorted top-k lists into one top-k list.
One approach:
1. Write all 32*k candidates to shared memory
2. One thread builds a heap of size k from all candidates
3. But that's O(32*k) operations for one thread, which is slow.
Better: use a tournament tree or parallel heap.
Actually, for warp merge, a bitonic sort of 32*k elements in shared memory is probably the most parallel and efficient. Let me stick with that.
For block-level merge of warp results, also use bitonic sort.
OK here's the final kernel design:
**Kernel: `topk_kernel<K, THREADS>`**
Parameters:
- `K`: compile-time k
- `THREADS`: threads per block (128 or 256)
Shared memory layout:
- Warp work area: `num_warps * 32 * K * (sizeof(float) + sizeof(int))`
- Block candidate area: `num_warps * K * (sizeof(float) + sizeof(int))`
Thread processing:
1. Initialize heap with -inf
2. Strided loop over elements, heap replace if larger
3. Bitonic sort thread's top-k in registers
4. Write to warp shared memory
5. Warp-level bitonic sort of 32*K candidates
6. Warp's thread 0 writes top-k to block candidate area
7. __syncthreads()
8. Block-level bitonic sort of num_warps*K candidates
9. Write top-k to output
Wait, in step 6, I said "warp's thread 0 writes top-k". But we can have all threads in the warp participate. After the warp-level sort, top-k are at indices 0..K-1 in warp shared memory. I can have all threads in the block participate in writing to the block candidate area.
Actually, since each warp sorts independently in its own shared memory region, after sort the top-k are at the start of each warp's region. Then I need to copy them to the block candidate area.
Let me use a dedicated step:
- After the warp sort, each lane copies the warp's top-K into the block candidate area with a strided loop: `for (int i = lane; i < K; i += 32) s_block_vals[warp_id * K + i] = my_warp_vals[i];` (and likewise for indices). A plain `if (lane < K)` only works for K <= 32; the strided loop also covers K = 64, where each lane copies two entries.
This way, all warps write simultaneously and no thread is idle.
Then `__syncthreads()` and block sort.
For the block sort, all threads participate. After sort, threads 0..K-1 write to output.
This is clean!
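To make steps 4-9 concrete, here is a rough sketch of how the merge phase could look, assuming `warp_bitonic_sort_desc` / `block_bitonic_sort_desc` helpers along the lines of the warp-level loop discussed just below. All names and the exact shared-memory layout are placeholders, not the final kernel:
```cuda
template <int K, int THREADS>
__device__ void merge_and_write(float (&vals)[K], int (&idxs)[K],
                                float* s_warp_vals, int* s_warp_idxs,   // WARPS * 32 * K entries
                                float* s_blk_vals,  int* s_blk_idxs,    // WARPS * K entries
                                float* out_vals, int64_t* out_idxs, int row) {
    constexpr int WARPS = THREADS / 32;
    const int tid  = threadIdx.x;
    const int lane = tid & 31;
    const int warp = tid >> 5;

    // Step 4: dump this thread's sorted top-K into its warp's region.
    float* wv = s_warp_vals + warp * 32 * K;
    int*   wi = s_warp_idxs + warp * 32 * K;
    for (int i = 0; i < K; i++) { wv[lane * K + i] = vals[i]; wi[lane * K + i] = idxs[i]; }
    __syncwarp();

    // Step 5: the warp cooperatively sorts its 32*K candidates descending.
    warp_bitonic_sort_desc(wv, wi, 32 * K, lane);

    // Step 6: copy the warp's top-K into the block candidate area (strided so K=64 works too).
    for (int i = lane; i < K; i += 32) {
        s_blk_vals[warp * K + i] = wv[i];
        s_blk_idxs[warp * K + i] = wi[i];
    }
    __syncthreads();                               // Step 7

    // Step 8: all threads sort the WARPS*K candidates (same loop, __syncthreads between passes).
    block_bitonic_sort_desc(s_blk_vals, s_blk_idxs, WARPS * K, tid, THREADS);
    __syncthreads();

    // Step 9: the first K threads write this row's answer.
    if (tid < K) {
        out_vals[row * K + tid] = s_blk_vals[tid];
        out_idxs[row * K + tid] = s_blk_idxs[tid];
    }
}
```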
Now let me think about the bitonic sort in shared memory for the warp level. With 32 threads in a warp, each thread can process multiple elements.
For warp_cand = 32 * K:
- Number of threads = 32
- Elements per thread = K
In each stage, thread `lane` processes elements `lane, lane+32, lane+64, ...` up to `warp_cand`.
For K=64, warp_cand = 2048. Each thread processes 64 elements. Each pass over the data is followed by a `__syncwarp()` so the lanes see each other's swaps before the next pass.
For j=1: thread lane compares `lane*64 + i` and `lane*64 + i ^ 1` for various i.
Wait, the indexing is `i` and `i ^ j`, not `lane*64 + i`.
The standard loop is:
```cuda
for (int i = lane; i < warp_cand; i += 32) {
int ixj = i ^ j;
if (ixj > i) {
// compare and swap
}
}
```
For j=1: i=lane, lane+32, lane+64, ...
ixj = i ^ 1. For even i, ixj = i+1. For odd i, ixj = i-1.
So thread lane compares pairs (i, i^1) where i = lane, lane+32, etc.
For even lane: compares (lane, lane+1), (lane+32, lane+33), etc.
For odd lane: compares (lane-1, lane), (lane+31, lane+32), etc.
But the `ixj > i` check means odd lanes skip. So only even lanes do the compare-and-swap. That's fine - half the threads are active.
For j=2: i = lane, lane+32, ...
ixj = i ^ 2. For lane=0: compares (0,2), (32,34), etc. For lane=1: (1,3), (33,35), etc.
For lane=2: ixj=0, but 0 < 2, so skip. For lane=3: ixj=1, 1 < 3, skip.
So lanes 0 and 1 are active, lanes 2 and 3 skip.
This pattern continues. In each stage, half the threads are active. This is actually quite efficient.
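Putting that indexing together with the direction logic from the register version, the warp-level sort could look roughly like this (a sketch: `s_vals`/`s_idxs` point at this warp's region, `n` = 32*K is a power of two, `lane` is the lane id):
```cuda
// One warp cooperatively bitonic-sorts n value/index pairs in shared memory, descending.
__device__ void warp_bitonic_sort_desc(float* s_vals, int* s_idxs, int n, int lane) {
    for (int k = 2; k <= n; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int i = lane; i < n; i += 32) {
                int ixj = i ^ j;
                if (ixj > i) {                       // each comparator is owned by exactly one thread
                    bool dir = ((i & k) == 0);       // true: this sub-block sorts descending
                    bool need_swap = dir ? (s_vals[i] < s_vals[ixj])
                                         : (s_vals[i] > s_vals[ixj]);
                    if (need_swap) {
                        float tv = s_vals[i]; s_vals[i] = s_vals[ixj]; s_vals[ixj] = tv;
                        int   ti = s_idxs[i]; s_idxs[i] = s_idxs[ixj]; s_idxs[ixj] = ti;
                    }
                }
            }
            __syncwarp();                            // all lanes finish a pass before the next one
        }
    }
}
```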
But there's still the bank conflict issue. For large j (like j=1024), thread lane accesses i=lane and ixj=lane^1024. Both map to the same bank if 1024 is a multiple of the number of banks. With 32-byte banks (8 floats per bank), 1024 / 8 = 128, so i and i^1024 are in the same bank. But different threads access different pairs, so there might not be conflicts within a warp.
Wait, for j=1024 and lane=0: accesses 0 and 1024. Bank(0) = 0, Bank(1024) = 1024 % 8 = 0. Same bank!
For lane=1: accesses 1 and 1025. Bank(1) = 1, Bank(1025) = 1. Same bank!
For lane=2: accesses 2 and 1026. Bank(2) = 2, Bank(1026) = 2. Same bank!
So ALL threads in the warp access two addresses in the SAME bank. That's a 32-way bank conflict! But wait, modern GPUs (since Volta) have memory access multicast/broadcast. If all threads in a warp access the same bank but different addresses, it's serialized. If they access the same address, it's broadcast.
For j=1024, each thread accesses two DIFFERENT addresses in the same bank. So the warp needs 2 * 32 = 64 accesses to that bank, serialized into 64 transactions. That's terrible!
However, the bank width is 4 bytes on older GPUs and 8 bytes on newer ones. Let me check Blackwell... I think it's 8 bytes (2 floats per bank) or maybe 32 bytes.
Actually, for shared memory banks:
- Pre-Volta: 4 bytes per bank, 32 banks
- Volta+: 4 bytes per bank, 32 banks (but independent thread scheduling)
- Some sources say Ampere/Hopper have 32-byte banks for certain access patterns
Actually, I think for NVIDIA GPUs, shared memory has 32 banks, each 4 bytes wide. So consecutive 4-byte words are in consecutive banks. Two addresses in the same bank if they differ by a multiple of 128 bytes (32 * 4).
For j=1024, addresses are i*4 and (i^1024)*4 bytes. The difference is 1024*4 = 4096 bytes, which is a multiple of 128. So yes, same bank.
For j=512: diff = 512*4 = 2048, multiple of 128. Same bank.
For j=256: diff = 1024, multiple of 128. Same bank.
For j=128: diff = 512, multiple of 128. Same bank.
For j=64: diff = 256, not multiple of 128. Different banks!
For j=32: diff = 128, multiple of 128. Same bank.
Wait, 32*4 = 128. Yes, j=32 also causes same bank.
So for j = 32, 64, 128, 256, 512, 1024, most of these cause same-bank accesses! That's many stages with bank conflicts.
To fix this, we can use padded indexing. Instead of `s[i]`, use `s[i + (i >> 5)]` or similar. But then `i^j` needs to map correctly too.
A common padding is:
```cuda
#define IDX(i) ((i) + ((i) >> 5))
```
This adds one extra element every 32 elements, shifting elements 32+ apart by 1. So `IDX(0)=0, IDX(1)=1, ..., IDX(31)=31, IDX(32)=33, IDX(33)=34, ...`.
Now, for j=32: i and i^32.
- i=0: IDX(0)=0, IDX(32)=33. Banks: 0 % 32 = 0, 33 % 32 = 1. Different!
- i=1: IDX(1)=1, IDX(33)=34. Banks: 1, 2. Different!
For j=64: i=0: IDX(0)=0, IDX(64)=66. Banks: 0, 2. Different!
i=32: IDX(32)=33, IDX(96)=99. Banks: 1, 3. Different!
For j=1024: i=0: IDX(0)=0, IDX(1024)=1056. Banks: 0, 1056 % 32 = 0. Still same bank!
Hmm, j=1024 corresponds to offset of 1024 elements = 4096 bytes. Even with padding, 4096 / 128 = 32 padding slots inserted, so the offset in memory is 4096 + 32 * something.
Actually, for i=0: IDX(0)=0. For i=1024: there are 1024/32 = 32 padding slots inserted before index 1024. So IDX(1024) = 1024 + 32 = 1056. The byte offset is 1056 * 4 = 4224. 4224 / 128 = 33. So bank = 4224 % 128 / 4 = ... wait, bank = (byte_offset / 4) % 32 = 1056 % 32 = 0. Same as IDX(0) = 0 % 32 = 0.
So j=1024 still conflicts. But j=1024 is large, and there are only log2(N) - 10 stages with j >= 1024 (for N=2048, only j=1024). For N=2048, the stages with j=1024 is just one stage. The conflict is 2-way (each thread accesses 2 addresses in the same bank), so 2 transactions per thread. With 32 threads, that's 64 transactions total. It's bad but only for one stage.
For j=512: i=0, IDX(0)=0, IDX(512)=512+16=528. Bank: 528 % 32 = 16. Different from 0! Great.
For j=256: i=0, IDX(0)=0, IDX(256)=256+8=264. Bank: 264 % 32 = 8. Different!
For j=128: i=0, IDX(0)=0, IDX(128)=128+4=132. Bank: 132 % 32 = 4. Different!
For j=64: i=0, IDX(0)=0, IDX(64)=64+2=66. Bank: 66 % 32 = 2. Different!
For j=32: i=0, IDX(0)=0, IDX(32)=32+1=33. Bank: 33 % 32 = 1. Different!
So with `IDX(i) = i + (i >> 5)`, the only remaining conflict is j=1024 (and j=2048, etc., powers of 2 >= 1024). For N <= 2048, only j=1024 conflicts. For N=16384, j=8192, 16384 would also conflict.
This is a huge improvement! Let me use this padding.
The total shared memory size with padding: for N elements, we need N + N/32 elements. For 2048: 2048 + 64 = 2112. For 16384: 16384 + 512 = 16896.
Wait, for the warp-level sort, warp_cand = 32 * K. For K=64, that's 2048 elements. With padding: 2112 floats + 2112 ints = 8.25KB + 8.25KB ≈ 16.5KB per warp.
Each warp needs its own region, and all the warps in a block are live at the same time, so the block needs num_warps * 16.5KB. For 8 warps that's about 132KB.
That's at or above the per-block limit, the same problem as the unpadded 256-thread case.
I can't use this much shared memory. Let me reconsider.
Options:
1. Don't use padding, accept bank conflicts
2. Use fewer warps (fewer threads)
3. Process warp merges sequentially instead of in parallel
For option 3: have one warp do all the warp merges one at a time. But then we're not using the full block parallelism.
For option 2: use 128 threads (4 warps). Shared memory: 4 * 16.5KB = 66KB. That fits.
For option 1: no padding. Let me see how bad the conflicts really are.
Wait, let me recount how conflicts are actually counted. Bank conflicts happen between the lanes of a single warp-wide memory instruction, and the loads of `s[i]` and `s[i^j]` are two separate instructions. For a fixed inner-loop offset and large j, lane L reads word base+L in one instruction and word base+L+j in the other; each of those is unit-stride across the lanes, so each instruction touches 32 distinct banks and shouldn't conflict at all. The worry above came from treating a thread's two loads as if they were one request.
Either way, for a warp-level sort of 2048 elements there are log2(2048) * (log2(2048)+1) / 2 = 11 * 12 / 2 = 66 passes, and the sort time is small next to the streaming phase.
So maybe bank conflicts aren't that bad. Let me just use un-padded shared memory and see.
Actually, let me think about this from a different angle. For the warp-level sort, with 32 threads and 2048 elements:
- Every pass (fixed j) executes exactly 1024 comparators, spread across the warp.
- For j >= 32, all 32 lanes own comparators: 32 per lane per pass.
- For j < 32, only the 16 lanes whose lane bit j is clear own comparators: 64 per lane per pass, the other lanes idle.
The total "work" in terms of comparator executions is 2048 * 11 * 12 / 4 = 67584 comparators. At 1 comparator per thread per clock, with 32 threads, that's ~2112 clocks = 1.4 microseconds. Even with some slowdown from bank conflicts, it's still very fast.
Similarly for block-level sort of 512 elements with 128 threads:
- 512 * 9 * 10 / 4 = 11520 comparators
- With 128 threads, ~90 clocks = 60 nanoseconds. Very fast!
So the sorting time is negligible compared to the heap processing time (~10+ microseconds). Bank conflicts in shared memory sort are not the bottleneck.
Let me not worry about padding and just use simple indexing.
Now, let me also consider vectorized loads for the element processing. Each thread can load float4 (4 floats at a time) from global memory. This is especially beneficial for large n.
For n=131072 and 128 threads, each thread reads 1024 elements = 256 float4 loads.
For n=8192 and 256 threads, each thread reads 32 elements = 8 float4 loads.
For n=4096 and 256 threads, each thread reads 16 elements = 4 float4 loads.
Vectorized loads can significantly improve memory bandwidth. Let me add them.
```cuda
// Process 4 elements at a time using float4
int n4 = n / 4;
float4* x4 = (float4*)(row_x);
for (int i = tid; i < n4; i += blockDim.x) {
float4 v4 = x4[i];
if (v4.x > vals[0]) heap_replace_root<K>(vals, idxs, v4.x, i * 4);
if (v4.y > vals[0]) heap_replace_root<K>(vals, idxs, v4.y, i * 4 + 1);
if (v4.z > vals[0]) heap_replace_root<K>(vals, idxs, v4.z, i * 4 + 2);
if (v4.w > vals[0]) heap_replace_root<K>(vals, idxs, v4.w, i * 4 + 3);
}
// Process remaining elements
for (int i = tid + n4 * 4; i < n; i += blockDim.x) {
float v = row_x[i];
if (v > vals[0]) heap_replace_root<K>(vals, idxs, v, i);
}
```
Wait, there's an issue: the float4 load requires 16-byte alignment. For a row starting at `row * n`, if `n` is not a multiple of 4... actually, `row * n` is a multiple of 4 bytes (since each float is 4 bytes), so `row_x` is always 4-byte aligned. But float4 requires 16-byte alignment!
For row-major layout, row r starts at byte offset r * n * 4, so float4 loads need that offset to be a multiple of 16, i.e. n must be a multiple of 4.
For n=12000: 12000 = 16 * 750, so 12000 * 4 = 48000 = 16 * 3000, and every row offset is divisible by 16.
For n=131072: 131072 * 4 = 524288 = 16 * 32768. Divisible by 16.
For n=8192: 8192 * 4 = 32768 = 16 * 2048. Divisible by 16.
For n=16384: 16384 * 4 = 65536 = 16 * 4096. Divisible by 16.
For n=4096: 4096 * 4 = 16384 = 16 * 1024. Divisible by 16.
All n values are multiples of 4, and since the tensor starts at a 16-byte aligned address (PyTorch allocates aligned memory), all rows are 16-byte aligned. So float4 loads are safe!
Actually, let me double-check. PyTorch tensors are typically allocated with at least 16-byte alignment for CUDA. And all our n values are multiples of 4. So yes, float4 is safe.
For n=12000, the last row starts at 15 * 12000 * 4 = 720000, which is 16 * 45000. Still aligned.
Great, I'll use float4 loads.
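Just to be safe, the kernel can also verify this at runtime and keep a scalar fallback (a small sketch):
```cuda
// Decide at runtime whether the float4 fast path is safe for this row:
// the base pointer must be 16-byte aligned and n a multiple of 4.
const float* row_x = x + (size_t)row * n;
const bool can_vec4 = ((reinterpret_cast<size_t>(row_x) & 15) == 0) && (n % 4 == 0);
// can_vec4 selects between the float4 loop above and the plain scalar loop.
```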
Now let me also think about the k=1 case. For k=1, the heap approach degenerates to a simple max reduction. I already wrote a specialized argmax kernel. But I can also just use the heap approach with K=1:
- vals[0] starts at -inf
- For each element, if > vals[0], replace
- After all elements, sort 1 element (trivial)
- Then warp-level and block-level sorts are also trivial
For K=1, warp_cand = 32 * 1 = 32. Block_cand = 8 * 1 = 8. The sorts are tiny.
But the argmax reduction using warp shuffles is much faster. I'll keep the specialized k=1 kernel.
Actually, I can even use warp shuffles for the warp-level merge of K=1. But let me just keep the specialized kernel - it's simple and fast.
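For reference, the reduction pattern that specialized k=1 path relies on is just a shuffle-based argmax; a minimal sketch of the idea (not the kernel itself):
```cuda
// Shuffle-down reduction: after the loop, lane 0 holds the warp's max value and the
// index it came from. Assumes a full 32-lane warp.
__device__ void warp_argmax(float& v, int& idx) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        float ov = __shfl_down_sync(0xffffffffu, v,   offset);
        int   oi = __shfl_down_sync(0xffffffffu, idx, offset);
        if (ov > v) { v = ov; idx = oi; }
    }
}
```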
Let me now think about warp divergence in the heap replace. The condition `if (v > vals[0])` causes divergence within the warp. Some threads might replace, others might not. The `heap_replace` function has loops with dynamic iteration counts, which causes more divergence.
However, for random data, roughly half the elements will be > root (initially) and then fewer as the heap fills. Actually, for random Gaussian data, the probability that a new element is > the K-th largest of previously seen elements decreases. But with K=64 and n=131072, after seeing enough elements, only about K/n = 0.05% of new elements will be > root.
Wait, that's for the final stages. But initially, almost all elements are > -inf. So for the first 64 elements, all threads replace. Then the heap fills and replacements become rare.
The divergence is not terrible because:
1. Threads in a warp see similar data patterns (they're processing nearby elements)
2. When replacements are rare, most threads skip quickly
Also, the `if (v > vals[0])` is very cheap. The expensive part is `heap_replace` which only runs for a small fraction of elements after the heap is full.
For the first K elements per thread, we can avoid the heap entirely and just fill the array. But since each thread processes strided elements (tid, tid+THREADS, tid+2*THREADS, ...), the "first K elements" are spread out. It's simpler to just use the heap from the start.
One optimization: initialize vals with the first K elements (or as many as available), then build the heap. This avoids the initial -inf replacements.
```cuda
// Load first K elements (or as many as this thread sees)
int count = 0;
for (int i = tid; i < n && count < K; i += blockDim.x) {
vals[count] = row_x[i];
idxs[count] = i;
count++;
}
// If we have fewer than K elements, fill rest with -inf
for (int i = count; i < K; i++) {
vals[i] = -INFINITY;
idxs[i] = -1;
}
// Build min-heap
for (int i = K / 2 - 1; i >= 0; i--) {
heap_sift_down<K>(vals, idxs, i);
}
// Process remaining elements
for (int i = tid + count * blockDim.x; i < n; i += blockDim.x) {
float v = row_x[i];
if (v > vals[0]) {
heap_replace_root<K>(vals, idxs, v, i);
}
}
```
But wait, this prefill only loads up to K strided elements per thread and then heapifies them, so all it saves is the first ~K replace-root calls per thread. For the big shapes that's a small slice of the work, for the small shapes a thread may not even see K elements, and it adds a second code path. Not very helpful.
I'll just stick with initializing to -inf.
Actually, there's another optimization for the case where each thread sees only a few elements. For n=4096 and 256 threads, each thread sees 16 elements. For K=64, 16 < 64, so each thread's local top-k is just its 16 elements (sorted) plus 48 -infs. In this case, the heap is overkill - we can just load all elements into an array and sort it.
I can add a branch:
```cuda
if (elements_per_thread <= K) {
// Load all elements, sort, done
} else {
// Use heap
}
```
For n=4096 and 256 threads: elements_per_thread = 16. For K=1, 8, 16, 32, 64, only K=1 and K=8 have elements_per_thread > K.
For n=8192 and 256 threads: elements_per_thread = 32. K=1, 8, 16 have elements_per_thread > K.
For n=12000 and 256 threads: elements_per_thread = 47. K=1, 8, 16, 32 have elements_per_thread > K.
For n=16384 and 256 threads: elements_per_thread = 64. K=1, 8, 16, 32 have elements_per_thread > K. K=64 has elements_per_thread == K.
For n=131072 and 128 threads: elements_per_thread = 1024. All K have elements_per_thread > K.
This branch could help for small element counts. But adding branches per thread might complicate the code. Let me keep it simple and just use the heap for all cases. The overhead of processing 16 elements with a heap of size 64 is minimal.
Actually, let me do a quick calculation. For 16 elements and K=64:
- Initialize 64 elements to -inf: trivial
- Process 16 elements: each does `if (v > vals[0])` (always true initially), then heap_replace
- After first element, vals[0] is no longer -inf. After a few elements, sift-down happens.
- Total operations: ~16 * 26 = 416 operations.
- Then bitonic sort 64 elements: 1152 comparisons.
- Total: ~1568 operations.
vs. loading 16 elements, filling rest with -inf, sorting 64: still 1152 comparisons. So the heap adds ~416 operations, which is small.
OK, I'll just use the heap universally. It's simpler.
Now let me think about another thing: the `__syncwarp()` in the warp-level bitonic sort. Modern GPUs (Volta+) have independent thread scheduling, so `__syncwarp()` is needed to ensure all threads in the warp have written to shared memory before others read.
Similarly, `__syncthreads()` in the block-level sort.
Now let me plan the actual code structure.
I'll write a single CUDA source file with:
1. `topk1_kernel` for k=1
2. `topk_kernel<K>` template for k > 1
3. C++ wrapper function `topk_forward` that dispatches based on k
And in Python, I'll compile it with `load_inline`.
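A sketch of what that wrapper might look like, reusing the hypothetical `launch_topk<K>` helper from earlier (the real file will fill in the kernels; names and error messages are placeholders):
```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>

std::vector<torch::Tensor> topk_forward(torch::Tensor x, int64_t k) {
    TORCH_CHECK(x.is_cuda() && x.is_contiguous() && x.scalar_type() == torch::kFloat32 && x.dim() == 2,
                "expected a contiguous 2D float32 CUDA tensor");
    const int64_t batch = x.size(0), n = x.size(1);
    auto vals = torch::empty({batch, k}, x.options());
    auto idxs = torch::empty({batch, k}, x.options().dtype(torch::kInt64));
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const float* xp = x.data_ptr<float>();
    float*   vp = vals.data_ptr<float>();
    int64_t* ip = idxs.data_ptr<int64_t>();

    switch (k) {
        case 1:  launch_topk<1>(xp, vp, ip, (int)batch, (int)n, stream); break;  // or the dedicated argmax kernel
        case 8:  launch_topk<8>(xp, vp, ip, (int)batch, (int)n, stream); break;
        case 16: launch_topk<16>(xp, vp, ip, (int)batch, (int)n, stream); break;
        case 32: launch_topk<32>(xp, vp, ip, (int)batch, (int)n, stream); break;
        case 64: launch_topk<64>(xp, vp, ip, (int)batch, (int)n, stream); break;
        default: TORCH_CHECK(false, "unsupported k: ", k);
    }
    return {vals, idxs};
}
```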
Let me draft the full code. I'll start with a simpler version and iterate.
Actually, I realize I should check if there are any existing files in the directory that might help, like sota.py. Let me also check what PyTorch version is available and the CUDA architecture.
Let me also read the problem prompt to see if there are any hints.