"""Custom top-k via Triton.

Two execution strategies are selected per problem size by ``Model``:

* **simple** (``n <= SIMPLE_MAX_N``): one program per row packs each fp32
  value together with its column index into a sortable int64 key, sorts the
  whole (power-of-two padded) row in descending order and keeps the first K.
* **chunked** (larger ``n``):

  - Stage 1 (per chunk): ``tl.topk`` on the fp32 values yields the chunk's
    top-K values; the index of each value is then recovered by scanning the
    chunk for the first matching element.
  - Stage 2 (per row): the top-K candidates of all chunks are sorted to
    obtain the row's final top-K.

  Stage 1 is values-only: indices are not carried through the bitonic topk,
  which the original author expected to need less shared memory than the
  packed-int64 approach.

Known limitations:

* Stage 1 recovers indices by value equality, so duplicate values inside one
  chunk all map to the first matching position, and NaN never matches.
* Padding lanes hold ``-inf``. If a row's top-K reaches down to real ``-inf``
  entries, a padding lane may be selected instead; its value is still
  ``-inf`` but the reported index may be meaningless.
"""
from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "topk"
SUPPORTED_PRECISIONS = ["fp32"]


@triton.jit
def _topk_packed_kernel(
    x_ptr,
    v_ptr,
    i_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    N_REAL: tl.constexpr,
):
    """Per-row top-K via a full sort of packed (value, index) keys.

    Each program handles one row.

    Args:
        x_ptr: fp32 input, read as rows of ``N_REAL`` contiguous elements.
        v_ptr: output values, ``K`` per row.
        i_ptr: output indices (stored as int64), ``K`` per row.
        N: padded row length used for ``tl.arange`` / ``tl.sort``.
        K: number of results kept per row.
        N_REAL: true row length; lanes ``>= N_REAL`` are padding.
    """
    # Promote to int64 so the row offset cannot overflow 32 bits.
    pid = tl.program_id(0).to(tl.int64)
    row_start = pid * N_REAL

    offs = tl.arange(0, N)
    indices = offs.to(tl.int64)
    mask = offs < N_REAL
    # Padding lanes read as -inf so they sort to the end.
    x = tl.load(x_ptr + row_start + offs, mask=mask, other=-float('inf'))

    # Convert fp32 bits to a key whose integer order matches float order.
    u32 = x.to(tl.uint32, bitcast=True)
    sign = (u32 >> 31) & 1
    # Standard "sortable" representation (ordered under UNSIGNED comparison):
    # non-negative floats get the top bit set, negative floats are inverted.
    out_nonneg = u32 | 0x80000000
    out_neg = u32 ^ 0xFFFFFFFF
    s = tl.where(sign == 0, out_nonneg, out_neg)
    # Flip the top bit so SIGNED int64 ordering matches float ordering once
    # the key occupies the high 32 bits.
    s_xor = s ^ 0x80000000

    # Pack: high 32 bits = sortable value key, low 32 bits = column index.
    packed = (s_xor.to(tl.int64) << 32) | indices

    # Sort descending; the first K lanes are the top-K.
    sorted_packed = tl.sort(packed, descending=True)
    sorted_idx = (sorted_packed & 0xFFFFFFFF).to(tl.int32)

    out_mask = offs < K
    # Only gather lanes that are stored, and never dereference a padding
    # index (>= N_REAL): that address lies outside the row. A padding lane's
    # value is -inf by construction, which `other` reproduces.
    gather_mask = out_mask & (sorted_idx < N_REAL)
    sorted_v = tl.load(
        x_ptr + row_start + sorted_idx.to(tl.int64),
        mask=gather_mask,
        other=-float('inf'),
    )

    tl.store(v_ptr + pid * K + offs, sorted_v, mask=out_mask)
    tl.store(i_ptr + pid * K + offs, sorted_idx.to(tl.int64), mask=out_mask)


@triton.jit
def _topk_chunk_values_kernel(
    x_ptr,
    v_ptr,
    i_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
):
    """Per-chunk top-K using values-only topk plus an index lookup.

    Each program handles one (row, chunk) pair and writes ``K`` candidates
    to ``v[pid_b, pid_c, :]`` and ``i[pid_b, pid_c, :]``.

    Args:
        x_ptr: fp32 input, read as rows of ``N`` contiguous elements.
        v_ptr: candidate values, laid out as (row, chunk, K).
        i_ptr: candidate row-relative indices (int64), same layout.
        N: row length.
        K: candidates produced per chunk.
        CHUNK_SIZE: elements examined per program.
        NUM_CHUNKS: chunks per row (second grid dimension).
    """
    # Promote to int64 so the row offset cannot overflow 32 bits.
    pid_b = tl.program_id(0).to(tl.int64)
    pid_c = tl.program_id(1)

    row_start = pid_b * N
    chunk_start = pid_c * CHUNK_SIZE
    offs = tl.arange(0, CHUNK_SIZE)
    # The last chunk may extend past the row; those lanes read as -inf.
    mask = (chunk_start + offs) < N
    x = tl.load(
        x_ptr + row_start + chunk_start + offs,
        mask=mask,
        other=-float('inf'),
    )

    # Top-K values of this chunk, shape (K,).
    topk_v = tl.topk(x, K)

    # For each top-K value find the first position in x that equals it:
    #   match[i, k] = (x[i] == topk_v[k])
    # argmin over (1 - match) returns the smallest i with match[i, k] True.
    # NOTE: duplicates of one value therefore share one index, and a value
    # with no match (e.g. NaN) resolves to position 0 of the chunk.
    match = (x[:, None] == topk_v[None, :])  # (CHUNK_SIZE, K) bool
    inv_match = 1 - match.to(tl.int32)
    first_idx = tl.argmin(inv_match, axis=0).to(tl.int64) + chunk_start

    out_offs = tl.arange(0, K)
    out_offset = pid_b * NUM_CHUNKS * K + pid_c * K
    tl.store(v_ptr + out_offset + out_offs, topk_v)
    tl.store(i_ptr + out_offset + out_offs, first_idx)


@triton.jit
def _topk_merge_kernel(
    cand_v_ptr,
    cand_i_ptr,
    v_ptr,
    i_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    N_REAL: tl.constexpr,
):
    """Per-row top-K over pre-computed (value, index) candidates.

    Each program handles one row.

    Args:
        cand_v_ptr: candidate values, rows of ``N_REAL`` contiguous elements.
        cand_i_ptr: candidate indices, same layout as ``cand_v_ptr``.
        v_ptr: output values, ``K`` per row.
        i_ptr: output indices, ``K`` per row.
        N: padded candidate count used for ``tl.arange`` / ``tl.sort``.
        K: number of results kept per row.
        N_REAL: true number of candidates per row.
    """
    # Promote to int64 so the row offset cannot overflow 32 bits.
    pid = tl.program_id(0).to(tl.int64)
    row_start = pid * N_REAL

    offs = tl.arange(0, N)
    mask = offs < N_REAL
    cand_v = tl.load(
        cand_v_ptr + row_start + offs,
        mask=mask,
        other=-float('inf'),
    )

    # Same sortable-key transform as the packed kernel, but the low 32 bits
    # carry the candidate's POSITION so that both its value and its original
    # index can be looked up after the sort.
    u32 = cand_v.to(tl.uint32, bitcast=True)
    sign = (u32 >> 31) & 1
    out_nonneg = u32 | 0x80000000
    out_neg = u32 ^ 0xFFFFFFFF
    s = tl.where(sign == 0, out_nonneg, out_neg)
    s_xor = s ^ 0x80000000
    pos_idx = offs.to(tl.int64)  # positional index in the candidate array
    packed = (s_xor.to(tl.int64) << 32) | pos_idx

    sorted_packed = tl.sort(packed, descending=True)
    sorted_pos = (sorted_packed & 0xFFFFFFFF).to(tl.int64)

    out_mask = offs < K
    # The mask must describe the gathered ADDRESS, not the output lane:
    # a padding position (>= N_REAL) lies outside this row's candidates.
    gather_mask = out_mask & (sorted_pos < N_REAL)
    sorted_v = tl.load(
        cand_v_ptr + row_start + sorted_pos,
        mask=gather_mask,
        other=-float('inf'),
    )
    sorted_i = tl.load(
        cand_i_ptr + row_start + sorted_pos,
        mask=gather_mask,
        other=0,
    )

    tl.store(v_ptr + pid * K + offs, sorted_v, mask=out_mask)
    tl.store(i_ptr + pid * K + offs, sorted_i, mask=out_mask)


def _next_pow2(n: int) -> int:
    """Return the smallest power of two that is >= ``n`` (1 for ``n <= 1``)."""
    p = 1
    while p < n:
        p *= 2
    return p


# Largest row length handled by the single-block packed sort. The packed
# kernel sorts next_pow2(n) int64 keys (8 bytes each); the original author
# budgeted against ~99KB of shared memory. This is the threshold Model uses.
SIMPLE_MAX_N = 8192

# Elements scanned per Stage-1 program. Must be a power of two (the original
# notes this as a tl.topk requirement); 8192 fp32 values is ~32KB.
_CHUNK_SIZE = 8192


class Model(nn.Module):
    """Row-wise top-k of a ``(batch, n)`` fp32 tensor using Triton kernels.

    The launch plan is fixed at construction time from ``(batch, n, k)``.

    Attributes set by ``__init__``:
        _strategy: ``"simple"`` or ``"chunked"``.
        _n_padded: (simple only) row length padded to a power of two.
        _chunk_size, _num_chunks: (chunked only) Stage-1 tiling of a row.
        _total_cands, _total_padded: (chunked only) Stage-2 candidate count
            per row and its power-of-two padding.

    NOTE(review): ``k`` is used as a ``tl.arange`` extent in the chunked
    path, so it presumably has to be a power of two there, and ``k <= n`` is
    assumed throughout. Neither is validated here.
    """

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        self.batch, self.n, self.k = batch, n, k
        self.register_buffer("_dummy", torch.zeros(1))

        if n <= SIMPLE_MAX_N:
            self._strategy = "simple"
            self._n_padded = _next_pow2(n)
        else:
            # The chunk size is fixed, and the chunk count is derived from
            # it so that num_chunks * chunk_size always covers the whole
            # row. (Growing the chunk to shrink the candidate set and then
            # capping it again must not leave a stale, too-small count.)
            chunk_size = _CHUNK_SIZE
            num_chunks = (n + chunk_size - 1) // chunk_size

            self._strategy = "chunked"
            self._chunk_size = chunk_size
            self._num_chunks = num_chunks
            # Total candidates per row, padded to a power of two for the
            # merge sort. This grows with n; an oversized merge fails at
            # kernel compile/launch instead of silently skipping data.
            self._total_cands = num_chunks * k
            self._total_padded = _next_pow2(self._total_cands)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(values, indices)`` of the top ``k`` entries per row.

        Args:
            x: input tensor; the kernels read it as ``batch`` rows of ``n``
                fp32 elements each.

        Returns:
            ``values`` of shape ``(batch, k)`` (float32) and ``indices`` of
            shape ``(batch, k)`` (int64).
        """
        # The kernels index rows with a dense stride of n.
        x = x.contiguous()
        v = torch.empty(
            self.batch, self.k, device=x.device, dtype=torch.float32
        )
        i = torch.empty(
            self.batch, self.k, device=x.device, dtype=torch.int64
        )

        if self._strategy == "simple":
            _topk_packed_kernel[(self.batch,)](
                x, v, i,
                N=self._n_padded,
                K=self.k,
                N_REAL=self.n,
            )
        else:
            # Stage 1: per-chunk top-K using values-only topk.
            cand_v = torch.empty(
                self.batch, self._num_chunks, self.k,
                device=x.device, dtype=torch.float32,
            )
            cand_i = torch.empty(
                self.batch, self._num_chunks, self.k,
                device=x.device, dtype=torch.int64,
            )
            _topk_chunk_values_kernel[(self.batch, self._num_chunks)](
                x, cand_v, cand_i,
                N=self.n,
                K=self.k,
                CHUNK_SIZE=self._chunk_size,
                NUM_CHUNKS=self._num_chunks,
            )
            # Stage 2: merge all chunk candidates of a row.
            cand_v_flat = cand_v.view(self.batch, self._total_cands)
            cand_i_flat = cand_i.view(self.batch, self._total_cands)
            _topk_merge_kernel[(self.batch,)](
                cand_v_flat, cand_i_flat, v, i,
                N=self._total_padded,
                K=self.k,
                N_REAL=self._total_cands,
            )
        return v, i