"""Row-wise top-k on the GPU using Triton.

Every element is packed into one uint64 key: an order-preserving 32-bit
encoding of its float value in the high half and its column index in the low
half. A plain unsigned integer max / ``tl.topk`` over the keys therefore picks
the largest values and carries their column indices along.

When a row is split over several programs, phase 1 writes each program's
top-P keys to a workspace and phase 2 reduces that workspace to the final
top-k per row.
"""

import torch
import torch.nn as nn
import triton
import triton.language as tl


def next_power_of_2(n):
    """Return the smallest power of two that is >= ``n`` (1 when ``n <= 1``)."""
    return 1 if n <= 1 else 2 ** (n - 1).bit_length()


# Pack float32 `vals` and their column indices `cols` into uint64 keys whose
# unsigned order matches the float order (ties broken by the larger column).
@triton.jit
def _pack_keys(vals, cols):
    bits = vals.to(tl.uint32, bitcast=True)
    # Negative floats: flip every bit (reverses their order, places them below
    # all non-negative floats). Non-negative floats: just set the sign bit.
    is_neg = (bits & 0x80000000) != 0
    flip = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
    mapped = bits ^ flip
    return (mapped.to(tl.uint64) << 32) | cols.to(tl.uint64)


# Undo `_pack_keys` on a block of P keys and store the first K values/indices
# into row `row_idx` of the (rows, K) outputs. P is a power of two >= K
# (tl.arange requires a power-of-two extent), so the surplus lanes are masked.
@triton.jit
def _decode_and_store(packed, Out_vals_ptr, Out_idxs_ptr, row_idx,
                      P: tl.constexpr, K: tl.constexpr):
    mapped = (packed >> 32).to(tl.uint32)
    # A set high bit in the mapped value means the original float was >= 0.
    was_non_negative = (mapped & 0x80000000) != 0
    unflip = tl.where(was_non_negative, 0x80000000, 0xFFFFFFFF)
    vals = (mapped ^ unflip).to(tl.float32, bitcast=True)
    idxs = (packed & 0xFFFFFFFF).to(tl.int64)
    out_cols = tl.arange(0, P)
    mask = out_cols < K
    tl.store(Out_vals_ptr + row_idx * K + out_cols, vals, mask=mask)
    tl.store(Out_idxs_ptr + row_idx * K + out_cols, idxs, mask=mask)


@triton.jit
def topk_phase1_kernel(
    X_ptr,
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    N: tl.constexpr,
    P: tl.constexpr,
    K: tl.constexpr,
    B: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
):
    # One program per (row, column-block). It scans its slice of the row in
    # chunks of B elements and keeps a running top-P of packed keys. With
    # SINGLE_PASS it writes the final top-K directly; otherwise it writes its
    # P keys to Workspace[row, block, :] for topk_phase2_kernel.
    pid = tl.program_id(0)
    row_idx = pid // BLOCKS_PER_ROW
    block_col_idx = pid % BLOCKS_PER_ROW
    N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
    cols_start = block_col_idx * N_per_block
    cols_end = cols_start + N_per_block
    # 64-bit row offset so that batch * N beyond 2**31 cannot overflow.
    x_row_ptr = X_ptr + row_idx.to(tl.int64) * N

    # Key 0 is the smallest possible key, so it serves as an "empty" slot.
    running_top = tl.zeros([P], dtype=tl.uint64)

    steps = (N_per_block + B - 1) // B
    for s in range(steps):
        chunk_cols = cols_start + s * B + tl.arange(0, B)
        mask = (chunk_cols < N) & (chunk_cols < cols_end)
        vals = tl.load(x_row_ptr + chunk_cols, mask=mask, other=0.0).to(tl.float32)
        packed = _pack_keys(vals, chunk_cols)
        # Out-of-range lanes get the minimum key so they can never outrank a
        # real element (a float padding value could beat -inf and would carry
        # an out-of-range column index).
        packed = tl.where(mask, packed, tl.zeros_like(packed))

        # Top-P of this chunk.
        if P == 1:
            chunk_top = tl.expand_dims(tl.max(packed, axis=0), 0)
        else:
            chunk_top = tl.topk(packed, k=P)

        # Merge with the running top-P.
        combined = tl.reshape(tl.join(running_top, chunk_top), [2 * P])
        if P == 1:
            running_top = tl.expand_dims(tl.max(combined, axis=0), 0)
        else:
            running_top = tl.topk(combined, k=P)

    if SINGLE_PASS:
        _decode_and_store(running_top, Out_vals_ptr, Out_idxs_ptr, row_idx, P, K)
    else:
        out_cols = tl.arange(0, P)
        # The workspace tensor is int64; bitcast so the key bits are stored
        # unchanged (phase 2 bitcasts them back to uint64).
        tl.store(
            Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols,
            running_top.to(tl.int64, bitcast=True),
        )


@triton.jit
def topk_phase2_kernel(
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    P: tl.constexpr,
    K: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    M: tl.constexpr,
):
    # One program per row: reduce the row's M = BLOCKS_PER_ROW * P candidate
    # keys from phase 1 to the final top-K.
    row_idx = tl.program_id(0)
    cols = tl.arange(0, M)
    # The workspace is int64 on the host; reinterpret as uint64 so that keys of
    # non-negative floats (high bit set) are not ranked as negative numbers.
    packed = tl.load(Workspace_ptr + row_idx * M + cols).to(tl.uint64, bitcast=True)

    # Select P (a power of two >= K) keys; tl.arange / tl.topk need a
    # power-of-two extent, and the store masks the lanes beyond K.
    if P == 1:
        top_packed = tl.expand_dims(tl.max(packed, axis=0), 0)
    else:
        top_packed = tl.topk(packed, k=P)

    _decode_and_store(top_packed, Out_vals_ptr, Out_idxs_ptr, row_idx, P, K)


class Model(nn.Module):
    """Top-k along the last dim of a ``(batch, n)`` tensor via the Triton kernels above.

    ``forward`` returns ``(values, indices)`` with shapes ``(batch, k)`` and
    dtypes float32 / int64.

    Raises:
        ValueError: if ``k > n`` or ``n`` does not fit the 32-bit index field.
    """

    # (batch, n, k) -> (blocks_per_row, chunk size B, warps phase 1, warps phase 2)
    _TUNED_CONFIGS = {
        (1, 131072, 64): (64, 2048, 4, 4),
        (64, 8192, 8): (1, 2048, 8, 4),
        (32, 16384, 32): (4, 2048, 16, 4),
        (16, 12000, 16): (4, 2048, 8, 16),
        (128, 4096, 1): (1, 2048, 16, 4),
    }

    def __init__(self, batch: int, n: int, k: int):
        super().__init__()
        if k > n:
            raise ValueError(f"k ({k}) must not exceed n ({n})")
        if n > 2**32:
            # Column indices are packed into the low 32 bits of each key.
            raise ValueError(f"n ({n}) exceeds the 32-bit index range")

        self.batch = batch
        self.n = n
        self.k = k
        self.register_buffer("_dummy", torch.zeros(1))
        self.p = next_power_of_2(self.k)

        tuned = self._TUNED_CONFIGS.get((batch, n, k))
        if tuned is not None:
            self.blocks_per_row, self.b, self.w1, self.w2 = tuned
            return

        # General fallback: split rows across more programs when the batch is small.
        if self.batch >= 64:
            self.blocks_per_row = 1
        elif self.batch >= 32:
            self.blocks_per_row = 2
        elif self.batch >= 16:
            self.blocks_per_row = 4
        else:
            self.blocks_per_row = 64

        n_per_block = (self.n + self.blocks_per_row - 1) // self.blocks_per_row
        # Chunk size: power of two, at most 2048, at least 16, and never below
        # P because the kernel takes a top-P of each B-element chunk.
        self.b = max(min(2048, next_power_of_2(n_per_block)), 16, self.p)
        self.w1 = 4
        self.w2 = 4

    def forward(self, x: torch.Tensor):
        if tuple(x.shape) != (self.batch, self.n):
            raise ValueError(
                f"expected input of shape ({self.batch}, {self.n}), got {tuple(x.shape)}"
            )
        # The kernels address element (row, col) as row * n + col.
        x = x.contiguous()

        out_vals = torch.empty(self.batch, self.k, dtype=torch.float32, device=x.device)
        out_idxs = torch.empty(self.batch, self.k, dtype=torch.int64, device=x.device)

        if self.blocks_per_row == 1:
            # One program per row produces the final result directly.
            topk_phase1_kernel[(self.batch,)](
                x, None, out_vals, out_idxs,
                N=self.n, P=self.p, K=self.k, B=self.b,
                BLOCKS_PER_ROW=1, SINGLE_PASS=True,
                num_warps=self.w1,
            )
            return out_vals, out_idxs

        # Two-pass reduction. Workspace holds (batch, blocks_per_row, p) packed
        # keys; it is int64 storage and the kernels bitcast to/from uint64.
        workspace = torch.empty(
            self.batch, self.blocks_per_row, self.p, dtype=torch.int64, device=x.device
        )
        topk_phase1_kernel[(self.batch * self.blocks_per_row,)](
            x, workspace, None, None,
            N=self.n, P=self.p, K=self.k, B=self.b,
            BLOCKS_PER_ROW=self.blocks_per_row, SINGLE_PASS=False,
            num_warps=self.w1,
        )
        topk_phase2_kernel[(self.batch,)](
            workspace, out_vals, out_idxs,
            P=self.p, K=self.k,
            BLOCKS_PER_ROW=self.blocks_per_row,
            M=self.blocks_per_row * self.p,
            num_warps=self.w2,
        )
        return out_vals, out_idxs


# Standard shims required by check/benchmark
batch = 64
n = 8192
k = 8


def get_inputs():
    x = torch.randn(batch, n, dtype=torch.float32)
    return [x]


def get_init_inputs():
    return [batch, n, k]