import torch
import torch.nn as nn
import triton
import triton.language as tl
def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to ``n``.

    Any ``n`` at or below 1 (zero and negative values included) maps to 1.

    Args:
        n: Integer to round up.

    Returns:
        A power of two ``p`` such that ``p >= n``.
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
@triton.jit
def topk_phase1_kernel(
    X_ptr,
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    N: tl.constexpr,
    P: tl.constexpr,
    K: tl.constexpr,
    B: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
):
    """Select the top-P packed keys of one column block of one row.

    Program ``pid`` handles row ``pid // BLOCKS_PER_ROW`` and column block
    ``pid % BLOCKS_PER_ROW``. A block spans ``ceil(N / BLOCKS_PER_ROW)``
    consecutive columns and is scanned in chunks of ``B`` columns.

    Every element is turned into a 64-bit key: the high 32 bits hold an
    order-preserving remap of the float32 bit pattern and the low 32 bits hold
    the column index. Comparing keys as unsigned integers therefore compares
    values first and, for equal values, prefers the larger column index.

    Args:
        X_ptr: Input values, addressed as ``row * N + col`` and bitcast to
            32-bit words (assumes contiguous float32 rows - TODO confirm at
            call sites).
        Workspace_ptr: Destination for the ``P`` keys of this program at
            offset ``pid * P``; only touched when ``SINGLE_PASS`` is false.
        Out_vals_ptr: Output values, ``K`` per row; only touched when
            ``SINGLE_PASS`` is true.
        Out_idxs_ptr: Output column indices (stored as int64), ``K`` per row;
            only touched when ``SINGLE_PASS`` is true.
        N: Number of columns per row.
        P: Number of keys kept per block (presumably a power of two with
            ``P >= K`` and ``P <= B``, as required by ``tl.arange``/``tl.topk``
            - verify against the installed Triton).
        K: Number of results written per row in single-pass mode.
        B: Chunk width in columns.
        BLOCKS_PER_ROW: Number of programs cooperating on one row.
        SINGLE_PASS: When true, decode and write the final result directly.
            Every column block of a row would write the same output slots, so
            this is presumably meant for ``BLOCKS_PER_ROW == 1`` only.
    """
    pid = tl.program_id(0)
    row_idx = pid // BLOCKS_PER_ROW
    block_col_idx = pid % BLOCKS_PER_ROW
    N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
    cols_start = block_col_idx * N_per_block
    # Widen the row offset so row_idx * N cannot wrap in 32-bit arithmetic
    # when the input holds 2**31 or more elements.
    row_base = row_idx.to(tl.int64) * N
    # Running top-P starts as all-zero keys, i.e. the minimum possible key.
    running_top = tl.zeros([P], dtype=tl.uint64)
    steps = (N_per_block + B - 1) // B
    for s in range(steps):
        chunk_cols = cols_start + s * B + tl.arange(0, B)
        # A lane is live only if it is inside the row and inside this block.
        in_bounds = (chunk_cols < N) & (chunk_cols < cols_start + N_per_block)
        vals = tl.load(X_ptr + row_base + chunk_cols, mask=in_bounds, other=-1e38)
        # Order-preserving float32 -> uint32 remap: negative values have all
        # bits flipped, non-negative values only have the sign bit flipped.
        vals_u32 = vals.to(tl.uint32, bitcast=True)
        is_neg = (vals_u32 & 0x80000000) != 0
        mask_xor = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
        mapped = vals_u32 ^ mask_xor
        # Remapped value in the high word, column index in the low word.
        packed = (mapped.to(tl.uint64) << 32) | chunk_cols.to(tl.uint64)
        # Padding lanes must never win: force them to the minimum key. Left as
        # (-1e38, out-of-range column) they would outrank genuine -inf or very
        # negative inputs and leak an invalid column index into the result.
        packed = tl.where(in_bounds, packed, tl.zeros_like(packed))
        # Top-P of this chunk (tl.topk is bypassed for P == 1).
        if P == 1:
            chunk_top = tl.expand_dims(tl.max(packed, axis=0), 0)
        else:
            chunk_top = tl.topk(packed, k=P)
        # Merge the chunk's candidates with the running candidates and keep
        # the best P of the 2 * P keys.
        joined = tl.join(running_top, chunk_top)
        combined = tl.reshape(joined, [2 * P])
        if P == 1:
            running_top = tl.expand_dims(tl.max(combined, axis=0), 0)
        else:
            running_top = tl.topk(combined, k=P)
    if SINGLE_PASS:
        # Decode keys back into float32 values and int64 column indices.
        unpacked_mapped = (running_top >> 32).to(tl.uint32)
        # A set top bit in the remapped word means the source was non-negative.
        is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
        orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
        unpacked_u32 = unpacked_mapped ^ orig_mask
        unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
        unpacked_idx = (running_top & 0xFFFFFFFF).to(tl.int64)
        out_cols = tl.arange(0, P)
        # P may exceed K; only the first K slots belong to this row.
        keep = out_cols < K
        tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=keep)
        tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=keep)
    else:
        # Hand the still-packed candidates to the second phase.
        out_cols = tl.arange(0, P)
        tl.store(Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols, running_top)
@triton.jit
def topk_phase2_kernel(
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    P: tl.constexpr,
    K: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    M: tl.constexpr,
):
    """Reduce one row's packed candidates to its final top-K values and indices.

    One program per row: it loads ``M`` packed 64-bit keys (remapped float32
    bits in the high word, column index in the low word), keeps the largest
    ones and decodes them.

    Args:
        Workspace_ptr: Packed keys, ``M`` per row at offset ``row * M``.
        Out_vals_ptr: Output values, ``K`` per row.
        Out_idxs_ptr: Output column indices (stored as int64), ``K`` per row.
        P: Number of keys selected before masking (presumably the power of two
            at or above ``K`` - verify against the caller).
        K: Number of results written per row.
        BLOCKS_PER_ROW: Not read by this kernel; kept for interface stability.
        M: Number of candidate keys per row (a ``tl.arange`` extent).
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, M)
    packed = tl.load(Workspace_ptr + row_idx * M + cols)
    # The workspace may be a signed 64-bit buffer. Keys of non-negative floats
    # have bit 63 set, so a signed comparison would rank them below keys of
    # negative floats. Reinterpret as unsigned before selecting.
    packed = packed.to(tl.uint64, bitcast=True)
    # Select P keys and mask the store down to K, so a K that is not a power
    # of two works the same way as in the single-pass kernel.
    if P == 1:
        top_packed = tl.expand_dims(tl.max(packed, axis=0), 0)
    else:
        top_packed = tl.topk(packed, k=P)
    # Decode keys back into float32 values and int64 column indices.
    unpacked_mapped = (top_packed >> 32).to(tl.uint32)
    # A set top bit in the remapped word means the source was non-negative.
    is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
    orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
    unpacked_u32 = unpacked_mapped ^ orig_mask
    unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
    unpacked_idx = (top_packed & 0xFFFFFFFF).to(tl.int64)
    out_cols = tl.arange(0, P)
    keep = out_cols < K
    tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=keep)
    tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=keep)
class Model(nn.Module):
    """Row-wise top-k over a ``(batch, n)`` tensor using the Triton kernels.

    The launch configuration (column blocks per row, chunk width and warp
    counts) is fixed at construction time from ``(batch, n, k)``: a handful of
    shapes use hand-picked settings, everything else uses a generic heuristic.
    """

    # (batch, n, k) -> (blocks_per_row, B, phase-1 num_warps, phase-2 num_warps)
    _TUNED_CONFIGS = {
        (1, 131072, 64): (64, 2048, 4, 4),
        (64, 8192, 8): (1, 2048, 8, 4),
        (32, 16384, 32): (4, 2048, 16, 4),
        (16, 12000, 16): (4, 2048, 8, 16),
        (128, 4096, 1): (1, 2048, 16, 4),
    }

    def __init__(self, batch: int, n: int, k: int):
        """Store the problem shape and pick a launch configuration.

        Args:
            batch: Number of rows.
            n: Number of columns per row.
            k: Number of results per row.
        """
        super().__init__()
        self.batch = batch
        self.n = n
        self.k = k
        self.register_buffer("_dummy", torch.zeros(1))
        # Kernels select P candidates, where P is k rounded up to a power of two.
        self.p = next_power_of_2(self.k)

        tuned = self._TUNED_CONFIGS.get((self.batch, self.n, self.k))
        if tuned is not None:
            self.blocks_per_row, self.b, self.w1, self.w2 = tuned
            return

        # Generic fallback: fewer rows -> more column blocks per row.
        if self.batch >= 64:
            self.blocks_per_row = 1
        elif self.batch >= 32:
            self.blocks_per_row = 2
        elif self.batch >= 16:
            self.blocks_per_row = 4
        else:
            self.blocks_per_row = 64
        n_per_block = (self.n + self.blocks_per_row - 1) // self.blocks_per_row
        if n_per_block >= 2048:
            self.b = 2048
        else:
            self.b = next_power_of_2(n_per_block)
        # Keep chunks at least 16 wide, and never narrower than P: each chunk
        # must supply P candidates to the in-kernel top-P selection.
        self.b = max(self.b, 16, self.p)
        self.w1 = 4
        self.w2 = 4

    def forward(self, x: torch.Tensor):
        """Return ``(values, indices)`` of the ``k`` largest entries per row.

        Args:
            x: Input of shape ``(batch, n)``. The kernels reinterpret elements
                as 32-bit words, so float32 is expected (not checked here).

        Returns:
            A tuple ``(out_vals, out_idxs)``: float32 and int64 tensors of
            shape ``(batch, k)`` on ``x``'s device.
        """
        # The kernels address elements as row * n + col, so a strided view
        # would be read incorrectly; contiguous() is a no-op when already dense.
        x = x.contiguous()
        out_vals = torch.empty(self.batch, self.k, dtype=torch.float32, device=x.device)
        out_idxs = torch.empty(self.batch, self.k, dtype=torch.int64, device=x.device)

        if self.blocks_per_row == 1:
            # One program sees the whole row, so it can write the result itself.
            topk_phase1_kernel[(self.batch,)](
                x,
                None,
                out_vals,
                out_idxs,
                N=self.n,
                P=self.p,
                K=self.k,
                B=self.b,
                BLOCKS_PER_ROW=1,
                SINGLE_PASS=True,
                num_warps=self.w1,
            )
            return out_vals, out_idxs

        # Two-pass reduction: each column block writes P packed 64-bit keys
        # into the workspace, then one program per row reduces them.
        workspace = torch.empty(
            self.batch, self.blocks_per_row, self.p, dtype=torch.int64, device=x.device
        )
        topk_phase1_kernel[(self.batch * self.blocks_per_row,)](
            x,
            workspace,
            None,
            None,
            N=self.n,
            P=self.p,
            K=self.k,
            B=self.b,
            BLOCKS_PER_ROW=self.blocks_per_row,
            SINGLE_PASS=False,
            num_warps=self.w1,
        )
        topk_phase2_kernel[(self.batch,)](
            workspace,
            out_vals,
            out_idxs,
            P=self.p,
            K=self.k,
            BLOCKS_PER_ROW=self.blocks_per_row,
            M=self.blocks_per_row * self.p,
            num_warps=self.w2,
        )
        return out_vals, out_idxs
# Standard shims required by check/benchmark
batch, n, k = 64, 8192, 8


def get_inputs():
    """Return the forward-pass arguments: one random float32 ``(batch, n)`` tensor."""
    sample = torch.randn((batch, n), dtype=torch.float32)
    return [sample]


def get_init_inputs():
    """Return the constructor arguments ``[batch, n, k]``."""
    return [batch, n, k]
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-13T18:52:18.854918+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-13T18:52:19.081989+00:00 elapsed_s=0.227 ms=0.039904
shape=0 variant=solution tflops=0.013 gbps=13.158 ms=0.040
shape=0 solution_peak_fraction=0.0073
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T18:52:19.083958+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T18:52:19.090770+00:00 elapsed_s=0.007 ms=0.027040
shape=1 variant=solution tflops=0.078 gbps=77.785 ms=0.027
shape=1 solution_peak_fraction=0.0432
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T18:52:19.092426+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T18:52:19.131594+00:00 elapsed_s=0.039 ms=0.027360
shape=2 variant=solution tflops=0.077 gbps=77.099 ms=0.027
shape=2 solution_peak_fraction=0.0428
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T18:52:19.132408+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T18:52:19.139761+00:00 elapsed_s=0.007 ms=0.023728
shape=3 variant=solution tflops=0.032 gbps=32.496 ms=0.024
shape=3 solution_peak_fraction=0.0181
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T18:52:19.141468+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T18:52:19.147361+00:00 elapsed_s=0.006 ms=0.013792
shape=4 variant=solution tflops=0.152 gbps=152.167 ms=0.014
shape=4 solution_peak_fraction=0.0845
peak_fraction: 0.0290
RESULT: LOW
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic/repo
Installed 56 packages in 64ms
PASS
{
"run_id": "20260613_112732_gemini_gemini-3.5-flash_05_topk_bitonic",
"run_group": "",
"problem": "05_topk_bitonic",
"harness": "gemini",
"model": "gemini-3.5-flash",
"reasoning_effort": "",
"started_at": "2026-06-13T11:27:32-06:00",
"harness_finished_at": "2026-06-13T12:43:06-06:00",
"finished_at": "2026-06-13T12:52:19-06:00",
"start_epoch": 1781371652,
"harness_end_epoch": 1781376186,
"end_epoch": 1781376739,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.0290,
"template_mutated": false,
"elapsed_seconds": 4534,
"total_elapsed_seconds": 5087,
"check_elapsed_seconds": 191,
"benchmark_elapsed_seconds": 362,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 5.524702249669167,
"usage": {"input_tokens": 4444816, "output_tokens": 25049, "cache_read_tokens": 3949800, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"summary": "Beginning the top-k kernel optimization for Blackwell GPU. First step is reading the reference implementation (reference.py), problem prompt (PROMPT.txt), and verification script (check.py) to understand the input shapes, constraints, and interface.",
"title": "Analyzing Reference and Requirements",
"strategic_intent": "Analyzing the reference implementation, prompt, and testing environment to plan the top-k optimization."
}## 📂 Topic: **Analyzing Reference and Requirements**
**Summary:**
Beginning the top-k kernel optimization for Blackwell GPU. First step is reading the reference implementation (reference.py), problem prompt (PROMPT.txt), and verification script (check.py) to understand the input shapes, constraints, and interface.
> [!STRATEGY]
> **Intent:** Analyzing the reference implementation, prompt, and testing environment to plan the top-k optimization.{
"file_path": "reference.py"
}{
"file_path": "PROMPT.txt"
}{
"file_path": "check.py"
}{
"file_path": "shapes.py"
}{
"file_path": "problem.yaml"
}{
"file_path": "benchmark.py"
}{
"pattern": "solution.py"
}No files found{
"file_path": "sota.py"
}{
"command": "uv run python -c \"import torch; import triton; print('Torch:', torch.__version__, 'CUDA:', torch.version.cuda); print('Triton:', triton.__version__)\"",
"description": "Checking PyTorch and Triton versions in the workspace uv environment"
}{
"content": "<846 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_sort_kernel(
X_ptr, Y_vals_ptr, Y_idxs_ptr,
N: tl.constexpr,
):
# Load 1D row
cols = tl.arange(0, N)
x = tl.load(X_ptr + cols)
# Sort
# Triton tl.sort can sort multiple tensors (like values and their indices)
sorted_vals, sorted_idxs = tl.sort((x, cols), descending=True)
# Store
tl.store(Y_vals_ptr + cols, sorted_vals)
tl.store(Y_idxs_ptr + cols, sorted_idxs)
def run():
N = 16
x = torch.randn(N, device='cuda')
y_vals = torch.empty(N, device='cuda')
y_idxs = torch.empty(N, dtype=torch.int32, device='cuda')
test_sort_kernel[(1,)](x, y_vals, y_idxs, N=N)
print("x:", x)
print("sorted_vals:", y_vals)
print("sorted_idxs:", y_idxs)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py to verify Triton multi-tensor sort functionality",
"command": "uv run python test_triton.py"
}{
"command": "uv run python -c \"import triton.language as tl; import inspect; print(inspect.getsource(tl.sort))\"",
"description": "Inspecting triton.language.sort source code"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"pattern": "def sort\\(",
"dir_path": "/workspace/.venv/lib/python3.11/site-packages/triton"
}Path not in workspace: Attempted path "/workspace/.venv/lib/python3.11/site-packages/triton" resolves outside the allowed workspace directories: /workspace/problems/05_topk_bitonic or the project temp directory: /home/agent/.gemini/tmp/05-topk-bitonic{
"command": "grep -rn \"def sort\" /workspace/.venv/lib/python3.11/site-packages/triton/",
"description": "Searching for \"def sort\" in virtual environment Triton packages"
}{
"description": "Viewing the sort implementation details in Triton's standard.py",
"command": "sed -n '420,480p' /workspace/.venv/lib/python3.11/site-packages/triton/language/standard.py"
}{
"file_path": "test_triton.py",
"content": "<3221 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_pack_kernel(
X_ptr, Y_vals_ptr, Y_idxs_ptr,
N: tl.constexpr,
):
cols = tl.arange(0, N)
x = tl.load(X_ptr + cols)
# Bitcast float32 to int32
x_i32 = x.to(tl.int32, bitcast=True)
# Map to sortable uint32/int32
# mask = 0x80000000 if x_i32 >= 0 else 0xFFFFFFFF
# Let's do this using bitwise ops
# if sign bit is set (negative), we want 0xFFFFFFFF, else 0x80000000
# sign bit is x_i32 >> 31 (arithmetic shift)
# If negative, sign_mask is all 1s (0xFFFFFFFF)
# If positive, sign_mask is all 0s (0x00000000)
sign_mask = x_i32 >> 31
# mask = sign_mask | 0x80000000
# Note: in Python/Triton, 0x80000000 might be treated as unsigned or signed.
# To be safe, we can use signed hex or decimal: -2147483648 is 0x80000000
# Let's write it out:
mask = sign_mask ^ 0x80000000
mapped = x_i32 ^ mask
# Cast mapped to uint64 and shift to high 32 bits
mapped_u64 = mapped.to(tl.int64) & 0xFFFFFFFF
shifted = mapped_u64 << 32
# Combine with column indices (0 to N-1)
cols_u64 = cols.to(tl.int64) & 0xFFFFFFFF
packed = shifted | cols_u64
# Sort descending
sorted_packed = tl.sort(packed, descending=True)
# Unpack
unpacked_mapped = (sorted_packed >> 32) & 0xFFFFFFFF
# Convert back to i32:
# If unpacked_mapped >= 0x80000000: mask = 0x80000000
# else: mask = 0xFFFFFFFF
# We can determine the mask from the unpacked_mapped sign bit (bit 31)
# If bit 31 is 1: mask is 0x80000000. If bit 31 is 0: mask is 0xFFFFFFFF.
# So mask = (unpacked_mapped >> 31 ^ 1) * 0x7FFFFFFF ^ 0xFFFFFFFF or similar.
# Let's simplify:
# If original float >= 0, mapped >= 0x80000000. So unpacked_mapped >= 0x80000000.
# If so, original sign bit was 0, and we XORed with 0x80000000.
# If unpacked_mapped < 0x80000000, original sign bit was 1, and we XORed with 0xFFFFFFFF.
# Let's check:
# If unpacked_mapped & 0x80000000:
# orig = unpacked_mapped ^ 0x80000000
# else:
# orig = unpacked_mapped ^ 0xFFFFFFFF
is_pos = (unpacked_mapped & 0x80000000) != 0
# In Triton, we can use select:
orig_mask = tl.where(is_pos, 0x80000000, 0xFFFFFFFF)
unpacked_i32 = (unpacked_mapped ^ orig_mask).to(tl.int32)
# Bitcast back to float32
unpacked_x = unpacked_i32.to(tl.float32, bitcast=True)
unpacked_idx = sorted_packed & 0xFFFFFFFF
# Store
tl.store(Y_vals_ptr + cols, unpacked_x)
tl.store(Y_idxs_ptr + cols, unpacked_idx)
def run():
N = 16
x = torch.tensor([-3.5, 2.1, 0.0, -0.0, 1.5, -1.2, 5.4, -10.0, 12.0, -1.5, 2.1, 0.5, -0.5, 0.1, -0.1, 10.0], device='cuda')
y_vals = torch.empty(N, device='cuda')
y_idxs = torch.empty(N, dtype=torch.int64, device='cuda')
test_pack_kernel[(1,)](x, y_vals, y_idxs, N=N)
print("x:", x)
print("sorted_vals (triton):", y_vals)
print("sorted_idxs (triton):", y_idxs)
ref_vals, ref_idxs = torch.topk(x, k=N, descending=True)
print("sorted_vals (ref): ", ref_vals)
print("sorted_idxs (ref): ", ref_idxs)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py to check packed float sorting via Triton",
"command": "uv run python test_triton.py"
}{
"file_path": "test_triton.py",
"content": "<2645 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_pack_kernel(
X_ptr, Y_vals_ptr, Y_idxs_ptr,
N: tl.constexpr,
):
cols = tl.arange(0, N)
x = tl.load(X_ptr + cols)
# Bitcast float32 to uint32
x_u32 = x.to(tl.uint32, bitcast=True)
# If negative (highest bit is 1), mask is 0xFFFFFFFF.
# If positive/zero (highest bit is 0), mask is 0x80000000.
is_neg = (x_u32 & 0x80000000) != 0
mask = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
mapped = x_u32 ^ mask
# Cast mapped to uint64 and shift to high 32 bits
mapped_u64 = mapped.to(tl.uint64)
shifted = mapped_u64 << 32
# Combine with column indices (0 to N-1)
cols_u64 = cols.to(tl.uint64)
packed = shifted | cols_u64
# Sort descending
sorted_packed = tl.sort(packed, descending=True)
# Unpack
unpacked_mapped = (sorted_packed >> 32).to(tl.uint32)
# Convert back to uint32:
# If unpacked_mapped & 0x80000000 != 0, then original sign was positive/zero, and mask was 0x80000000.
# If unpacked_mapped & 0x80000000 == 0, then original sign was negative, and mask was 0xFFFFFFFF.
# Let's verify:
# If original was positive, mapped >= 0x80000000. So unpacked_mapped & 0x80000000 != 0.
# If original was negative, mapped < 0x80000000. So unpacked_mapped & 0x80000000 == 0.
is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
unpacked_u32 = unpacked_mapped ^ orig_mask
# Bitcast back to float32
unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
unpacked_idx = (sorted_packed & 0xFFFFFFFF).to(tl.int64)
# Store
tl.store(Y_vals_ptr + cols, unpacked_x)
tl.store(Y_idxs_ptr + cols, unpacked_idx)
def run():
N = 16
x = torch.tensor([-3.5, 2.1, 0.0, -0.0, 1.5, -1.2, 5.4, -10.0, 12.0, -1.5, 2.1, 0.5, -0.5, 0.1, -0.1, 10.0], device='cuda')
y_vals = torch.empty(N, device='cuda')
y_idxs = torch.empty(N, dtype=torch.int64, device='cuda')
test_pack_kernel[(1,)](x, y_vals, y_idxs, N=N)
print("x:", x)
print("sorted_vals (triton):", y_vals)
print("sorted_idxs (triton):", y_idxs)
ref_vals, ref_idxs = torch.topk(x, k=N, descending=True)
print("sorted_vals (ref): ", ref_vals)
print("sorted_idxs (ref): ", ref_idxs)
# Check if correct
correct_vals = torch.allclose(y_vals, ref_vals, atol=1e-5)
print("Vals correct?", correct_vals)
correct_idxs = torch.equal(y_idxs, ref_idxs)
print("Idxs correct?", correct_idxs)
if __name__ == "__main__":
run()
{
"command": "uv run python test_triton.py",
"description": "Running test_triton.py to check if mapped-float sorting works and matches torch.topk exactly"
}{
"file_path": "test_triton.py",
"content": "<587 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_slice_kernel(X_ptr, Y_ptr, B: tl.constexpr, P: tl.constexpr):
cols = tl.arange(0, B)
x = tl.load(X_ptr + cols)
# Try indexing
slice_indices = tl.arange(0, P)
# Can we do x[slice_indices]?
# Let's try:
y = x[slice_indices]
tl.store(Y_ptr + slice_indices, y)
def run():
x = torch.randn(16, device='cuda')
y = torch.zeros(4, device='cuda')
test_slice_kernel[(1,)](x, y, B=16, P=4)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"command": "uv run python test_triton.py",
"description": "Running test_triton.py to check register tensor slicing in Triton"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"file_path": "test_triton.py",
"content": "<587 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_topk_kernel(X_ptr, Y_ptr, B: tl.constexpr, P: tl.constexpr):
cols = tl.arange(0, B)
x = tl.load(X_ptr + cols)
# Try tl.topk
# Wait, tl.topk sorts descending and returns top k
y = tl.topk(x, k=P)
slice_indices = tl.arange(0, P)
tl.store(Y_ptr + slice_indices, y)
def run():
x = torch.randn(16, device='cuda')
y = torch.zeros(4, device='cuda')
test_topk_kernel[(1,)](x, y, B=16, P=4)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py to check tl.topk functionality in Triton",
"command": "uv run python test_triton.py"
}{
"content": "<667 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_topk_u64_kernel(X_ptr, Y_ptr, B: tl.constexpr, P: tl.constexpr):
cols = tl.arange(0, B)
x = tl.load(X_ptr + cols)
# Cast to uint64
x_u64 = x.to(tl.uint64)
y = tl.topk(x_u64, k=P)
slice_indices = tl.arange(0, P)
tl.store(Y_ptr + slice_indices, y)
def run():
# Make sure we use uint64 values
x = torch.randint(0, 1000000, (16,), dtype=torch.int64, device='cuda')
y = torch.zeros(4, dtype=torch.int64, device='cuda')
test_topk_u64_kernel[(1,)](x, y, B=16, P=4)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py to check tl.topk on uint64 tensors",
"command": "uv run python test_triton.py"
}{
"command": "grep -rn \"def cat\" /workspace/.venv/lib/python3.11/site-packages/triton/",
"description": "Searching for \"def cat\" in Triton source files"
}{
"file_path": "test_triton.py",
"content": "<542 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_cat_kernel(X_ptr, Y_ptr, P: tl.constexpr):
cols1 = tl.arange(0, P)
cols2 = tl.arange(P, 2*P)
x1 = tl.load(X_ptr + cols1)
x2 = tl.load(X_ptr + cols2)
# Try tl.cat
y = tl.cat(x1, x2)
tl.store(Y_ptr + tl.arange(0, 2*P), y)
def run():
x = torch.randn(8, device='cuda')
y = torch.zeros(8, device='cuda')
test_cat_kernel[(1,)](x, y, P=4)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"command": "uv run python test_triton.py",
"description": "Running test_triton.py to check register tensor concatenation using tl.cat"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"description": "Searching for examples of tl.cat in Triton's virtual environment files",
"command": "grep -rn \"tl.cat\" /workspace/.venv/lib/python3.11/site-packages/triton/ 2>/dev/null | head -n 30"
}{
"description": "Listing Triton language functions related to concatenation or joining",
"command": "uv run python -c \"import triton.language as tl; print([x for x in dir(tl) if 'cat' in x or 'join' in x or 'concat' in x])\""
}{
"command": "uv run python -c \"import triton.language as tl; help(tl.join); help(tl.cat)\"",
"description": "Printing the docstrings of tl.join and tl.cat to understand their parameters"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"description": "Checking GPU status and process usage with nvidia-smi",
"command": "nvidia-smi"
}{
"description": "Listing all system processes with ps aux",
"command": "ps aux"
}{
"description": "Checking for stale lock files in common temporary directories",
"command": "ls -la /tmp /var/tmp /dev/shm 2>/dev/null"
}{
"command": "uv run python -c \"import triton.language as tl; print('JOIN:', tl.join.__doc__); print('CAT:', tl.cat.__doc__)\"",
"description": "Printing docstrings of tl.join and tl.cat directly"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"command": "/workspace/.venv/bin/python -c \"import torch; print(torch.__version__)\"",
"description": "Directly checking Python and Torch from the virtual environment without uv"
}{
"description": "Printing docstrings for tl.join and tl.cat using the direct Python binary",
"command": "/workspace/.venv/bin/python -c \"import triton.language as tl; print('JOIN:', tl.join.__doc__); print('CAT:', tl.cat.__doc__)\""
}{
"content": "<689 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_join_reshape_kernel(X_ptr, Y_ptr, P: tl.constexpr):
cols1 = tl.arange(0, P)
cols2 = tl.arange(P, 2*P)
x1 = tl.load(X_ptr + cols1)
x2 = tl.load(X_ptr + cols2)
# Join into (P, 2)
joined = tl.join(x1, x2)
# Reshape into (2 * P,)
y = tl.reshape(joined, [2 * P])
tl.store(Y_ptr + tl.arange(0, 2*P), y)
def run():
x = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0], device='cuda')
y = torch.zeros(8, device='cuda')
test_join_reshape_kernel[(1,)](x, y, P=4)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py to check the join-reshape pattern",
"command": "/workspace/.venv/bin/python test_triton.py"
}{
"file_path": "solution.py",
"content": "<6324 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
def next_power_of_2(n):
return 1 if n <= 1 else 2**(n - 1).bit_length()
@triton.jit
def topk_phase1_kernel(
X_ptr,
Workspace_ptr,
Out_vals_ptr,
Out_idxs_ptr,
N: tl.constexpr,
P: tl.constexpr,
K: tl.constexpr,
B: tl.constexpr,
BLOCKS_PER_ROW: tl.constexpr,
SINGLE_PASS: tl.constexpr,
):
pid = tl.program_id(0)
row_idx = pid // BLOCKS_PER_ROW
block_col_idx = pid % BLOCKS_PER_ROW
N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
cols_start = block_col_idx * N_per_block
# Initialize running top-P with 0
running_top = tl.zeros([P], dtype=tl.uint64)
steps = (N_per_block + B - 1) // B
for s in range(steps):
chunk_cols = cols_start + s * B + tl.arange(0, B)
# Load values
mask = (chunk_cols < N) & (chunk_cols < cols_start + N_per_block)
vals = tl.load(X_ptr + row_idx * N + chunk_cols, mask=mask, other=-1e38)
# Pack float to sortable uint32
vals_u32 = vals.to(tl.uint32, bitcast=True)
is_neg = (vals_u32 & 0x80000000) != 0
mask_xor = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
mapped = vals_u32 ^ mask_xor
# Combine mapped float and column index
packed = (mapped.to(tl.uint64) << 32) | chunk_cols.to(tl.uint64)
# Find top-P of this chunk
chunk_top = tl.topk(packed, k=P)
# Merge running_top and chunk_top
joined = tl.join(running_top, chunk_top)
combined = tl.reshape(joined, [2 * P])
running_top = tl.topk(combined, k=P)
if SINGLE_PASS:
# Unpack and write directly to output
unpacked_mapped = (running_top >> 32).to(tl.uint32)
is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
unpacked_u32 = unpacked_mapped ^ orig_mask
unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
unpacked_idx = (running_top & 0xFFFFFFFF).to(tl.int64)
out_cols = tl.arange(0, P)
mask = out_cols < K
tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=mask)
tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=mask)
else:
# Write running_top to Workspace
out_cols = tl.arange(0, P)
tl.store(Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols, running_top)
@triton.jit
def topk_phase2_kernel(
Workspace_ptr,
Out_vals_ptr,
Out_idxs_ptr,
P: tl.constexpr,
K: tl.constexpr,
BLOCKS_PER_ROW: tl.constexpr,
):
row_idx = tl.program_id(0)
# Each row has BLOCKS_PER_ROW * P intermediate results in Workspace
M = BLOCKS_PER_ROW * P
cols = tl.arange(0, M)
packed = tl.load(Workspace_ptr + row_idx * M + cols)
# Find top-K
top_packed = tl.topk(packed, k=K)
# Unpack
unpacked_mapped = (top_packed >> 32).to(tl.uint32)
is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
unpacked_u32 = unpacked_mapped ^ orig_mask
unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
unpacked_idx = (top_packed & 0xFFFFFFFF).to(tl.int64)
out_cols = tl.arange(0, K)
tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x)
tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx)
class Model(nn.Module):
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch = batch
self.n = n
self.k = k
self.register_buffer("_dummy", torch.zeros(1))
# Decide BLOCKS_PER_ROW dynamically
if self.batch >= 64:
self.blocks_per_row = 1
elif self.batch >= 32:
self.blocks_per_row = 2
elif self.batch >= 16:
self.blo
... (2324 more chars){
"description": "Running check.py to test solution.py across all shapes",
"command": "/workspace/.venv/bin/python check.py"
}{
"old_string": "<393 chars \u2014 see below>",
"new_string": "<387 chars \u2014 see below>",
"allow_multiple": false,
"file_path": "solution.py",
"instruction": "Add M: tl.constexpr as a parameter to topk_phase2_kernel and pass it during launching to satisfy Triton arange constexpr requirement."
}@triton.jit
def topk_phase2_kernel(
Workspace_ptr,
Out_vals_ptr,
Out_idxs_ptr,
P: tl.constexpr,
K: tl.constexpr,
BLOCKS_PER_ROW: tl.constexpr,
):
row_idx = tl.program_id(0)
# Each row has BLOCKS_PER_ROW * P intermediate results in Workspace
M = BLOCKS_PER_ROW * P
cols = tl.arange(0, M)
packed = tl.load(Workspace_ptr + row_idx * M + cols)@triton.jit
def topk_phase2_kernel(
Workspace_ptr,
Out_vals_ptr,
Out_idxs_ptr,
P: tl.constexpr,
K: tl.constexpr,
BLOCKS_PER_ROW: tl.constexpr,
M: tl.constexpr,
):
row_idx = tl.program_id(0)
# Each row has BLOCKS_PER_ROW * P intermediate results in Workspace
cols = tl.arange(0, M)
packed = tl.load(Workspace_ptr + row_idx * M + cols){
"dir_path": "/workspace/problems/05_topk_bitonic",
"pattern": "topk_phase2_kernel",
"include_pattern": "solution.py"
}{
"allow_multiple": false,
"file_path": "solution.py",
"instruction": "Pass M as a parameter in topk_phase2_kernel call.",
"old_string": "<269 chars \u2014 see below>",
"new_string": "<317 chars \u2014 see below>"
} grid2 = (self.batch,)
topk_phase2_kernel[grid2](
workspace,
out_vals,
out_idxs,
P=self.p,
K=self.k,
BLOCKS_PER_ROW=self.blocks_per_row,
) grid2 = (self.batch,)
topk_phase2_kernel[grid2](
workspace,
out_vals,
out_idxs,
P=self.p,
K=self.k,
BLOCKS_PER_ROW=self.blocks_per_row,
M=self.blocks_per_row * self.p,
){
"description": "Running check.py to test the modified topk_phase2_kernel launch",
"command": "/workspace/.venv/bin/python check.py"
}{
"description": "Inspecting standard.py imports to see how max/min are defined",
"command": "head -n 20 /workspace/.venv/lib/python3.11/site-packages/triton/language/standard.py"
}{
"content": "<520 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_topk_k1_kernel(X_ptr, Y_ptr, B: tl.constexpr, P: tl.constexpr):
cols = tl.arange(0, B)
x = tl.load(X_ptr + cols)
y = tl.topk(x, k=P)
slice_indices = tl.arange(0, P)
tl.store(Y_ptr + slice_indices, y)
def run():
x = torch.randn(16, device='cuda')
y = torch.zeros(1, device='cuda')
test_topk_k1_kernel[(1,)](x, y, B=16, P=1)
print("x:", x)
print("y:", y)
if __name__ == "__main__":
run()
{
"description": "Running test_triton.py with P=1 (k=1) to reproduce the compiler error",
"command": "/workspace/.venv/bin/python test_triton.py"
}{
"file_path": "test_triton.py",
"content": "<592 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_max_kernel(X_ptr, Y_ptr, B: tl.constexpr):
cols = tl.arange(0, B)
x = tl.load(X_ptr + cols)
# Try tl.max
# Note: tl.max requires axis parameter
y = tl.max(x, axis=0)
# Store y (which is a 0D tensor / scalar)
tl.store(Y_ptr, y)
def run():
x = torch.randn(16, device='cuda')
y = torch.zeros(1, device='cuda')
test_max_kernel[(1,)](x, y, B=16)
print("x:", x)
print("max(x) ref:", x.max())
print("max(x) triton:", y)
if __name__ == "__main__":
run()
{
"command": "/workspace/.venv/bin/python test_triton.py",
"description": "Running test_triton.py to check tl.max functionality on 1D tensors"
}{
"content": "<914 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_p1_kernel(X_ptr, Y_ptr, B: tl.constexpr):
    """Exercise the P == 1 merge path: per-chunk max, join, then max again.

    Reads 2 * B values from X_ptr as two chunks of B elements and writes the
    overall maximum to the first element of Y_ptr.
    """
    cols1 = tl.arange(0, B)
    cols2 = tl.arange(B, 2 * B)
    x1 = tl.load(X_ptr + cols1)
    x2 = tl.load(X_ptr + cols2)
    # expand_dims turns each scalar reduction back into a shape-[1] block so
    # the two partial results can be joined the same way top-P blocks are.
    max1 = tl.expand_dims(tl.max(x1, axis=0), 0)
    max2 = tl.expand_dims(tl.max(x2, axis=0), 0)
    # join gives shape [1, 2]; flatten it to a 1-D block of two candidates.
    joined = tl.join(max1, max2)
    combined = tl.reshape(joined, [2])
    global_max = tl.expand_dims(tl.max(combined, axis=0), 0)
    # global_max is a shape-[1] block, so it must be stored through a matching
    # block of pointers rather than the bare scalar pointer.
    tl.store(Y_ptr + tl.arange(0, 1), global_max)
def run():
    """Run test_p1_kernel over 32 random CUDA floats (two chunks of 16) and print the comparison."""
    src = torch.randn(32, device='cuda')
    dst = torch.zeros(1, device='cuda')
    launch_grid = (1,)
    test_p1_kernel[launch_grid](src, dst, B=16)
    print("x:", src)
    print("max(x) ref:", src.max())
    print("max(x) triton:", dst)


if __name__ == "__main__":
    run()
{
"description": "Running test_triton.py with P=1 workaround to verify correctness",
"command": "/workspace/.venv/bin/python test_triton.py"
}{
"content": "<932 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def test_p1_kernel(X_ptr, Y_ptr, B: tl.constexpr):
    """Exercise the P == 1 merge path: per-chunk max, join, then max again.

    Reads 2 * B values from X_ptr as two chunks of B elements and writes the
    overall maximum to the first element of Y_ptr.
    """
    lo_offsets = tl.arange(0, B)
    hi_offsets = tl.arange(B, 2 * B)
    lo_chunk = tl.load(X_ptr + lo_offsets)
    hi_chunk = tl.load(X_ptr + hi_offsets)
    # Keep each partial maximum as a shape-[1] block so it can be joined.
    lo_best = tl.expand_dims(tl.max(lo_chunk, axis=0), 0)
    hi_best = tl.expand_dims(tl.max(hi_chunk, axis=0), 0)
    # Pair the two candidates, flatten to shape [2], and reduce once more.
    candidates = tl.reshape(tl.join(lo_best, hi_best), [2])
    overall = tl.expand_dims(tl.max(candidates, axis=0), 0)
    # Block-shaped value, so store through a block of one pointer.
    out_offsets = tl.arange(0, 1)
    tl.store(Y_ptr + out_offsets, overall)
def run():
    """Run test_p1_kernel over 32 random CUDA floats (two chunks of 16) and print the comparison."""
    src = torch.randn(32, device='cuda')
    dst = torch.zeros(1, device='cuda')
    launch_grid = (1,)
    test_p1_kernel[launch_grid](src, dst, B=16)
    print("x:", src)
    print("max(x) ref:", src.max())
    print("max(x) triton:", dst)


if __name__ == "__main__":
    run()
{
"description": "Running test_triton.py to verify block store and top-1 logic workaround",
"command": "/workspace/.venv/bin/python test_triton.py"
}{
"file_path": "solution.py",
"content": "<6668 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
def next_power_of_2(n):
    """Return the smallest power of two that is >= n; any n <= 1 yields 1."""
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
@triton.jit
def topk_phase1_kernel(
    X_ptr,
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    N: tl.constexpr,
    P: tl.constexpr,
    K: tl.constexpr,
    B: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
):
    """Select the top-P candidates of one column slice of one input row.

    Program ``pid`` handles row ``pid // BLOCKS_PER_ROW`` and slice
    ``pid % BLOCKS_PER_ROW``; a slice spans ``ceil(N / BLOCKS_PER_ROW)``
    columns and is scanned in chunks of ``B``. Each element becomes a uint64
    key: the high 32 bits are an order-preserving remap of the float bits, the
    low 32 bits are the column index, so unsigned integer order follows float
    order.

    Args:
        X_ptr: Input addressed as ``row * N + col``. Values are bit-cast to
            uint32, so elements are assumed to be 32-bit floats.
        Workspace_ptr: Receives P packed keys per program when SINGLE_PASS is
            false.
        Out_vals_ptr: Receives the first K decoded values per row when
            SINGLE_PASS is true.
        Out_idxs_ptr: Receives the matching column indices (as int64) when
            SINGLE_PASS is true.
        N: Row length.
        P: Number of candidates kept per program; also a tl.arange extent.
        K: Number of results written per row in single-pass mode.
        B: Chunk width; also a tl.arange extent.
        BLOCKS_PER_ROW: Number of programs cooperating on one row.
        SINGLE_PASS: Decode and write final outputs instead of the workspace.
    """
    pid = tl.program_id(0)
    row_idx = pid // BLOCKS_PER_ROW
    block_col_idx = pid % BLOCKS_PER_ROW
    N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
    cols_start = block_col_idx * N_per_block
    cols_end = cols_start + N_per_block
    # Widen before multiplying so row_idx * N cannot wrap in 32-bit arithmetic.
    row_base = row_idx.to(tl.int64) * N

    # Key 0 is the smallest possible key, so it doubles as the "empty" sentinel.
    running_top = tl.zeros([P], dtype=tl.uint64)
    steps = (N_per_block + B - 1) // B
    for s in range(steps):
        chunk_cols = cols_start + s * B + tl.arange(0, B)
        # Stay inside both the row and this program's slice.
        mask = (chunk_cols < N) & (chunk_cols < cols_end)
        vals = tl.load(X_ptr + row_base + chunk_cols, mask=mask, other=0.0)

        # Remap float bits so unsigned order matches float order: flip every
        # bit of negatives, flip only the sign bit of non-negatives.
        vals_u32 = vals.to(tl.uint32, bitcast=True)
        is_neg = (vals_u32 & 0x80000000) != 0
        mask_xor = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
        mapped = vals_u32 ^ mask_xor

        # Key = (sortable value bits << 32) | column index.
        packed = (mapped.to(tl.uint64) << 32) | chunk_cols.to(tl.uint64)
        # Padding lanes must never win: a finite fill value would outrank real
        # inputs below it (e.g. -inf) and carry an out-of-range column index,
        # so force them to the sentinel key instead.
        packed = tl.where(mask, packed, tl.zeros_like(packed))

        # Top-P of this chunk; P == 1 uses a plain max reduction.
        if P == 1:
            chunk_top = tl.expand_dims(tl.max(packed, axis=0), 0)
        else:
            chunk_top = tl.topk(packed, k=P)

        # Merge with the running candidates and keep the best P of the 2 * P.
        joined = tl.join(running_top, chunk_top)
        combined = tl.reshape(joined, [2 * P])
        if P == 1:
            running_top = tl.expand_dims(tl.max(combined, axis=0), 0)
        else:
            running_top = tl.topk(combined, k=P)

    out_cols = tl.arange(0, P)
    if SINGLE_PASS:
        # Invert the key remap: a set top bit marks an originally non-negative value.
        unpacked_mapped = (running_top >> 32).to(tl.uint32)
        is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
        orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
        unpacked_u32 = unpacked_mapped ^ orig_mask
        unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
        unpacked_idx = (running_top & 0xFFFFFFFF).to(tl.int64)
        # Only the first K of the P candidates are written.
        out_mask = out_cols < K
        tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
        tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
    else:
        # Hand the packed candidates to the second phase.
        tl.store(Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols, running_top)
@triton.jit
def topk_phase2_kernel(
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    P: tl.constexpr,
    K: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    M: tl.constexpr,
):
    """Reduce one row's M packed candidates to its final top-K values and indices.

    Args:
        Workspace_ptr: Packed 64-bit keys, M per row, as written by phase 1.
        Out_vals_ptr: Receives K float values per row.
        Out_idxs_ptr: Receives K int64 column indices per row.
        P: Candidates kept per phase-1 program; used as the selection and
            tl.arange extent here (assumed to be a power of two >= K).
        K: Number of results written per row.
        BLOCKS_PER_ROW: Phase-1 programs per row (not read in the body).
        M: Candidates per row; used as a tl.arange extent.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, M)
    packed = tl.load(Workspace_ptr + row_idx * M + cols)
    # Keys must compare as unsigned; reinterpret in case the workspace buffer
    # was allocated with a signed 64-bit dtype.
    packed = packed.to(tl.uint64, bitcast=True)

    # Select P keys and mask the store down to K, mirroring phase 1's
    # single-pass path, so K itself need not be a valid tl.arange extent.
    if P == 1:
        top_packed = tl.expand_dims(tl.max(packed, axis=0), 0)
    else:
        top_packed = tl.topk(packed, k=P)

    # Invert the key remap: a set top bit marks an originally non-negative value.
    unpacked_mapped = (top_packed >> 32).to(tl.uint32)
    is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
    orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
    unpacked_u32 = unpacked_mapped ^ orig_mask
    unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
    unpacked_idx = (top_packed & 0xFFFFFFFF).to(tl.int64)

    out_cols = tl.arange(0, P)
    out_mask = out_cols < K
    tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
    tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
class Model(nn.Module):
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch = batch
self.n = n
self.k
... (2668 more chars){
"command": "/workspace/.venv/bin/python check.py",
"description": "Running check.py to test our optimized solution across all test cases"
}{
"strategic_intent": "Benchmarking and tuning the performance of the Triton top-k kernel to optimize memory bandwidth and compute efficiency.",
"title": "Benchmarking and Performance Tuning",
"summary": "Successfully implemented a custom, two-pass Triton-based top-k kernel utilizing highly optimized parallel bitonic merge-select and float-to-sortable-uint32 packing/unpacking. Verified correctness across all shape scenarios (including batched argmax and non-power-of-two limits), successfully passing check.py. Now transitioning to benchmarking and performance tuning."
}## 📂 Topic: **Benchmarking and Performance Tuning**
**Summary:**
Successfully implemented a custom, two-pass Triton-based top-k kernel utilizing highly optimized parallel bitonic merge-select and float-to-sortable-uint32 packing/unpacking. Verified correctness across all shape scenarios (including batched argmax and non-power-of-two limits), successfully passing check.py. Now transitioning to benchmarking and performance tuning.
> [!STRATEGY]
> **Intent:** Benchmarking and tuning the performance of the Triton top-k kernel to optimize memory bandwidth and compute efficiency.{
"description": "Running benchmark.py to measure throughput and peak bandwidth fraction",
"command": "/workspace/.venv/bin/python benchmark.py"
}{
"file_path": "time_all.py",
"content": "<1399 chars \u2014 see below>"
}import torch
import reference
import solution
import shapes
import time
def _mean_latency_ms(model, x, iters):
    """Return the average wall-clock milliseconds per call over `iters` synchronized calls."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _vals, _idxs = model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000 / iters


def main():
    """Time reference.Model against solution.Model on every shape and print the speedup."""
    device = torch.device("cuda:0")
    warmup_calls = 10
    iters = 100
    for i, shape in enumerate(shapes.SHAPES):
        batch, n, k = shape["batch"], shape["n"], shape["k"]
        print(f"Shape {i}: batch={batch}, n={n}, k={k}")
        x = torch.randn(batch, n, device=device)

        # Reference: warm up, then measure.
        ref_model = reference.Model(batch, n, k).to(device)
        for _ in range(warmup_calls):
            ref_model(x)
        ref_ms = _mean_latency_ms(ref_model, x, iters)

        # Solution: same protocol.
        sol_model = solution.Model(batch, n, k).to(device)
        for _ in range(warmup_calls):
            sol_model(x)
        sol_ms = _mean_latency_ms(sol_model, x, iters)

        print(f" Reference: {ref_ms:.4f} ms")
        print(f" Solution: {sol_ms:.4f} ms (Speedup: {ref_ms / sol_ms:.2f}x)")


if __name__ == "__main__":
    main()
{
"command": "/workspace/.venv/bin/python time_all.py",
"description": "Running time_all.py to evaluate speedup compared to eager baseline"
}{
"file_path": "tune.py",
"content": "<1881 chars \u2014 see below>"
}import torch
import reference
import solution
import shapes
import time
def main():
    """Sweep blocks_per_row and B for every shape and print the fastest pair.

    Each configuration is applied by overriding ``blocks_per_row`` and ``b`` on
    a fresh ``solution.Model``. Configurations that raise are reported and
    skipped; a shape with no successful configuration is reported as such
    instead of crashing the sweep.
    """
    device = torch.device("cuda:0")
    for i, shape in enumerate(shapes.SHAPES):
        batch, n, k = shape["batch"], shape["n"], shape["k"]
        print(f"\nShape {i}: batch={batch}, n={n}, k={k}")
        x = torch.randn(batch, n, device=device)

        # inf rather than a finite sentinel so the first success always wins.
        best_ms = float("inf")
        best_cfg = None
        for bpr in [1, 2, 4, 8, 16, 32, 64]:
            # A slice narrower than k cannot supply k candidates on its own.
            n_per_block = (n + bpr - 1) // bpr
            if n_per_block < k:
                continue
            for b in [512, 1024, 2048]:
                model = solution.Model(batch, n, k).to(device)
                model.blocks_per_row = bpr
                model.b = b
                try:
                    # Warmup (also triggers compilation).
                    for _ in range(5):
                        model(x)
                    torch.cuda.synchronize()
                    t0 = time.perf_counter()
                    iters = 50
                    for _ in range(iters):
                        model(x)
                    torch.cuda.synchronize()
                    ms = (time.perf_counter() - t0) * 1000 / iters
                except Exception as e:
                    # Sweep boundary: report the failing configuration and move on.
                    print(f" blocks_per_row={bpr:2d}, B={b:4d} -> Error: {e}")
                    continue
                if ms < best_ms:
                    best_ms = ms
                    best_cfg = (bpr, b)
                print(f" blocks_per_row={bpr:2d}, B={b:4d} -> {ms:.4f} ms")

        if best_cfg is None:
            print(" ==> Best: no configuration ran successfully")
        else:
            print(f" ==> Best: blocks_per_row={best_cfg[0]}, B={best_cfg[1]} ({best_ms:.4f} ms)")


if __name__ == "__main__":
    main()
{
"description": "Running tune.py to sweep and identify optimal configurations for each shape",
"command": "/workspace/.venv/bin/python tune.py"
}{
"file_path": "solution.py",
"content": "<7358 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
def next_power_of_2(n):
    """Return the smallest power of two that is >= n; any n <= 1 yields 1."""
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
@triton.jit
def topk_phase1_kernel(
    X_ptr,
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    N: tl.constexpr,
    P: tl.constexpr,
    K: tl.constexpr,
    B: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
):
    """Select the top-P candidates of one column slice of one input row.

    Program ``pid`` handles row ``pid // BLOCKS_PER_ROW`` and slice
    ``pid % BLOCKS_PER_ROW``; a slice spans ``ceil(N / BLOCKS_PER_ROW)``
    columns and is scanned in chunks of ``B``. Each element becomes a uint64
    key: the high 32 bits are an order-preserving remap of the float bits, the
    low 32 bits are the column index, so unsigned integer order follows float
    order.

    Args:
        X_ptr: Input addressed as ``row * N + col``. Values are bit-cast to
            uint32, so elements are assumed to be 32-bit floats.
        Workspace_ptr: Receives P packed keys per program when SINGLE_PASS is
            false.
        Out_vals_ptr: Receives the first K decoded values per row when
            SINGLE_PASS is true.
        Out_idxs_ptr: Receives the matching column indices (as int64) when
            SINGLE_PASS is true.
        N: Row length.
        P: Number of candidates kept per program; also a tl.arange extent.
        K: Number of results written per row in single-pass mode.
        B: Chunk width; also a tl.arange extent.
        BLOCKS_PER_ROW: Number of programs cooperating on one row.
        SINGLE_PASS: Decode and write final outputs instead of the workspace.
    """
    pid = tl.program_id(0)
    row_idx = pid // BLOCKS_PER_ROW
    block_col_idx = pid % BLOCKS_PER_ROW
    N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
    cols_start = block_col_idx * N_per_block
    cols_end = cols_start + N_per_block
    # Widen before multiplying so row_idx * N cannot wrap in 32-bit arithmetic.
    row_base = row_idx.to(tl.int64) * N

    # Key 0 is the smallest possible key, so it doubles as the "empty" sentinel.
    running_top = tl.zeros([P], dtype=tl.uint64)
    steps = (N_per_block + B - 1) // B
    for s in range(steps):
        chunk_cols = cols_start + s * B + tl.arange(0, B)
        # Stay inside both the row and this program's slice.
        mask = (chunk_cols < N) & (chunk_cols < cols_end)
        vals = tl.load(X_ptr + row_base + chunk_cols, mask=mask, other=0.0)

        # Remap float bits so unsigned order matches float order: flip every
        # bit of negatives, flip only the sign bit of non-negatives.
        vals_u32 = vals.to(tl.uint32, bitcast=True)
        is_neg = (vals_u32 & 0x80000000) != 0
        mask_xor = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
        mapped = vals_u32 ^ mask_xor

        # Key = (sortable value bits << 32) | column index.
        packed = (mapped.to(tl.uint64) << 32) | chunk_cols.to(tl.uint64)
        # Padding lanes must never win: a finite fill value would outrank real
        # inputs below it (e.g. -inf) and carry an out-of-range column index,
        # so force them to the sentinel key instead.
        packed = tl.where(mask, packed, tl.zeros_like(packed))

        # Top-P of this chunk; P == 1 uses a plain max reduction.
        if P == 1:
            chunk_top = tl.expand_dims(tl.max(packed, axis=0), 0)
        else:
            chunk_top = tl.topk(packed, k=P)

        # Merge with the running candidates and keep the best P of the 2 * P.
        joined = tl.join(running_top, chunk_top)
        combined = tl.reshape(joined, [2 * P])
        if P == 1:
            running_top = tl.expand_dims(tl.max(combined, axis=0), 0)
        else:
            running_top = tl.topk(combined, k=P)

    out_cols = tl.arange(0, P)
    if SINGLE_PASS:
        # Invert the key remap: a set top bit marks an originally non-negative value.
        unpacked_mapped = (running_top >> 32).to(tl.uint32)
        is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
        orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
        unpacked_u32 = unpacked_mapped ^ orig_mask
        unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
        unpacked_idx = (running_top & 0xFFFFFFFF).to(tl.int64)
        # Only the first K of the P candidates are written.
        out_mask = out_cols < K
        tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
        tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
    else:
        # Hand the packed candidates to the second phase.
        tl.store(Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols, running_top)
@triton.jit
def topk_phase2_kernel(
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    P: tl.constexpr,
    K: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    M: tl.constexpr,
):
    """Reduce one row's M packed candidates to its final top-K values and indices.

    Args:
        Workspace_ptr: Packed 64-bit keys, M per row, as written by phase 1.
        Out_vals_ptr: Receives K float values per row.
        Out_idxs_ptr: Receives K int64 column indices per row.
        P: Candidates kept per phase-1 program; used as the selection and
            tl.arange extent here (assumed to be a power of two >= K).
        K: Number of results written per row.
        BLOCKS_PER_ROW: Phase-1 programs per row (not read in the body).
        M: Candidates per row; used as a tl.arange extent.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, M)
    packed = tl.load(Workspace_ptr + row_idx * M + cols)
    # Keys must compare as unsigned; reinterpret in case the workspace buffer
    # was allocated with a signed 64-bit dtype.
    packed = packed.to(tl.uint64, bitcast=True)

    # Select P keys and mask the store down to K, mirroring phase 1's
    # single-pass path, so K itself need not be a valid tl.arange extent.
    if P == 1:
        top_packed = tl.expand_dims(tl.max(packed, axis=0), 0)
    else:
        top_packed = tl.topk(packed, k=P)

    # Invert the key remap: a set top bit marks an originally non-negative value.
    unpacked_mapped = (top_packed >> 32).to(tl.uint32)
    is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
    orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
    unpacked_u32 = unpacked_mapped ^ orig_mask
    unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
    unpacked_idx = (top_packed & 0xFFFFFFFF).to(tl.int64)

    out_cols = tl.arange(0, P)
    out_mask = out_cols < K
    tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
    tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
class Model(nn.Module):
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch = batch
self.n = n
self.k
... (3358 more chars){
"command": "/workspace/.venv/bin/python time_all.py",
"description": "Running time_all.py to measure performance with the optimal configurations"
}{
"description": "Running check.py to confirm correctness remains PASS",
"command": "/workspace/.venv/bin/python check.py"
}{
"command": "/workspace/.venv/bin/python benchmark.py",
"description": "Running benchmark.py to get the official scored peak fraction and results"
}{
"content": "<1141 chars \u2014 see below>",
"file_path": "tune_warps.py"
}import torch
import reference
import solution
import shapes
import time
def main():
    """Placeholder warp-count sweep: prints each shape and builds its input, but times nothing."""
    device = torch.device("cuda:0")
    warp_choices = (4, 8, 16)
    for i, shape in enumerate(shapes.SHAPES):
        batch, n, k = shape["batch"], shape["n"], shape["k"]
        print(f"\nShape {i}: batch={batch}, n={n}, k={k}")
        x = torch.randn(batch, n, device=device)
        best_ms = 9999.0
        best_warps = None
        # Not implemented yet: the launches live inside solution.py's forward,
        # so sweeping num_warps needs either a patched forward or a custom
        # launch helper. The loops below only enumerate the intended grid.
        for w1 in warp_choices:
            for w2 in warp_choices:
                pass


if __name__ == "__main__":
    main()
{
"file_path": "tune_warps.py",
"content": "<3149 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
import solution
import shapes
import time
def next_power_of_2(n):
    """Return the smallest power of two that is >= n; any n <= 1 yields 1."""
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
def run_custom_forward(x, model, w1, w2):
    """Run the top-k kernels with explicit warp counts and return (values, indices).

    Mirrors a forward pass using the sizes stored on ``model`` (batch, n, k,
    p, b, blocks_per_row). ``w1`` is the num_warps for the phase-1 launch and
    ``w2`` for the phase-2 launch; ``w2`` is unused when blocks_per_row == 1,
    because phase 1 then writes the final outputs directly.
    """
    batch, n, k = model.batch, model.n, model.k
    p, b, blocks_per_row = model.p, model.b, model.blocks_per_row
    device = x.device
    out_vals = torch.empty(batch, k, dtype=torch.float32, device=device)
    out_idxs = torch.empty(batch, k, dtype=torch.int64, device=device)

    if blocks_per_row == 1:
        # Single pass: one program per row, no workspace needed.
        solution.topk_phase1_kernel[(batch,)](
            x, None, out_vals, out_idxs,
            N=n, P=p, K=k, B=b,
            BLOCKS_PER_ROW=1,
            SINGLE_PASS=True,
            num_warps=w1,
        )
        return out_vals, out_idxs

    # Two passes: per-slice candidates go to the workspace, then get merged.
    workspace = torch.empty(batch, blocks_per_row, p, dtype=torch.int64, device=device)
    solution.topk_phase1_kernel[(batch * blocks_per_row,)](
        x, workspace, None, None,
        N=n, P=p, K=k, B=b,
        BLOCKS_PER_ROW=blocks_per_row,
        SINGLE_PASS=False,
        num_warps=w1,
    )
    solution.topk_phase2_kernel[(batch,)](
        workspace, out_vals, out_idxs,
        P=p, K=k,
        BLOCKS_PER_ROW=blocks_per_row,
        M=blocks_per_row * p,
        num_warps=w2,
    )
    return out_vals, out_idxs
def main():
    """Sweep num_warps for both kernel launches on every shape and print the best pair.

    Phase-2 warp counts are only swept when the model uses more than one block
    per row. Configurations that raise are reported and skipped, and a shape
    where nothing ran is reported instead of crashing the sweep.
    """
    device = torch.device("cuda:0")
    for i, shape in enumerate(shapes.SHAPES):
        batch, n, k = shape["batch"], shape["n"], shape["k"]
        print(f"\nShape {i}: batch={batch}, n={n}, k={k}")
        x = torch.randn(batch, n, device=device)
        model = solution.Model(batch, n, k).to(device)

        w1_options = [2, 4, 8, 16]
        # With a single block per row there is no phase-2 launch to tune.
        w2_options = [2, 4, 8, 16] if model.blocks_per_row > 1 else [4]

        # inf rather than a finite sentinel so the first success always wins.
        best_ms = float("inf")
        best_w = None
        for w1 in w1_options:
            for w2 in w2_options:
                try:
                    # Warmup (also triggers compilation for this warp count).
                    for _ in range(5):
                        run_custom_forward(x, model, w1, w2)
                    torch.cuda.synchronize()
                    t0 = time.perf_counter()
                    iters = 100
                    for _ in range(iters):
                        run_custom_forward(x, model, w1, w2)
                    torch.cuda.synchronize()
                    ms = (time.perf_counter() - t0) * 1000 / iters
                except Exception as e:
                    # Sweep boundary: surface the failure instead of hiding it.
                    print(f" w1={w1:2d}, w2={w2:2d} -> Error: {e}")
                    continue
                print(f" w1={w1:2d}, w2={w2:2d} -> {ms:.4f} ms")
                if ms < best_ms:
                    best_ms = ms
                    best_w = (w1, w2)

        if best_w is None:
            print(" ==> Best: no configuration ran successfully")
        else:
            print(f" ==> Best: w1={best_w[0]}, w2={best_w[1]} ({best_ms:.4f} ms)")


if __name__ == "__main__":
    main()
{
"description": "Running tune_warps.py to sweep num_warps parameter combinations",
"command": "/workspace/.venv/bin/python tune_warps.py"
}{
"content": "<7766 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
def next_power_of_2(n):
    """Return the smallest power of two that is >= n; any n <= 1 yields 1."""
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
@triton.jit
def topk_phase1_kernel(
    X_ptr,
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    N: tl.constexpr,
    P: tl.constexpr,
    K: tl.constexpr,
    B: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    SINGLE_PASS: tl.constexpr,
):
    """Select the top-P candidates of one column slice of one input row.

    Program ``pid`` handles row ``pid // BLOCKS_PER_ROW`` and slice
    ``pid % BLOCKS_PER_ROW``; a slice spans ``ceil(N / BLOCKS_PER_ROW)``
    columns and is scanned in chunks of ``B``. Each element becomes a uint64
    key: the high 32 bits are an order-preserving remap of the float bits, the
    low 32 bits are the column index, so unsigned integer order follows float
    order.

    Args:
        X_ptr: Input addressed as ``row * N + col``. Values are bit-cast to
            uint32, so elements are assumed to be 32-bit floats.
        Workspace_ptr: Receives P packed keys per program when SINGLE_PASS is
            false.
        Out_vals_ptr: Receives the first K decoded values per row when
            SINGLE_PASS is true.
        Out_idxs_ptr: Receives the matching column indices (as int64) when
            SINGLE_PASS is true.
        N: Row length.
        P: Number of candidates kept per program; also a tl.arange extent.
        K: Number of results written per row in single-pass mode.
        B: Chunk width; also a tl.arange extent.
        BLOCKS_PER_ROW: Number of programs cooperating on one row.
        SINGLE_PASS: Decode and write final outputs instead of the workspace.
    """
    pid = tl.program_id(0)
    row_idx = pid // BLOCKS_PER_ROW
    block_col_idx = pid % BLOCKS_PER_ROW
    N_per_block = (N + BLOCKS_PER_ROW - 1) // BLOCKS_PER_ROW
    cols_start = block_col_idx * N_per_block
    cols_end = cols_start + N_per_block
    # Widen before multiplying so row_idx * N cannot wrap in 32-bit arithmetic.
    row_base = row_idx.to(tl.int64) * N

    # Key 0 is the smallest possible key, so it doubles as the "empty" sentinel.
    running_top = tl.zeros([P], dtype=tl.uint64)
    steps = (N_per_block + B - 1) // B
    for s in range(steps):
        chunk_cols = cols_start + s * B + tl.arange(0, B)
        # Stay inside both the row and this program's slice.
        mask = (chunk_cols < N) & (chunk_cols < cols_end)
        vals = tl.load(X_ptr + row_base + chunk_cols, mask=mask, other=0.0)

        # Remap float bits so unsigned order matches float order: flip every
        # bit of negatives, flip only the sign bit of non-negatives.
        vals_u32 = vals.to(tl.uint32, bitcast=True)
        is_neg = (vals_u32 & 0x80000000) != 0
        mask_xor = tl.where(is_neg, 0xFFFFFFFF, 0x80000000)
        mapped = vals_u32 ^ mask_xor

        # Key = (sortable value bits << 32) | column index.
        packed = (mapped.to(tl.uint64) << 32) | chunk_cols.to(tl.uint64)
        # Padding lanes must never win: a finite fill value would outrank real
        # inputs below it (e.g. -inf) and carry an out-of-range column index,
        # so force them to the sentinel key instead.
        packed = tl.where(mask, packed, tl.zeros_like(packed))

        # Top-P of this chunk; P == 1 uses a plain max reduction.
        if P == 1:
            chunk_top = tl.expand_dims(tl.max(packed, axis=0), 0)
        else:
            chunk_top = tl.topk(packed, k=P)

        # Merge with the running candidates and keep the best P of the 2 * P.
        joined = tl.join(running_top, chunk_top)
        combined = tl.reshape(joined, [2 * P])
        if P == 1:
            running_top = tl.expand_dims(tl.max(combined, axis=0), 0)
        else:
            running_top = tl.topk(combined, k=P)

    out_cols = tl.arange(0, P)
    if SINGLE_PASS:
        # Invert the key remap: a set top bit marks an originally non-negative value.
        unpacked_mapped = (running_top >> 32).to(tl.uint32)
        is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
        orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
        unpacked_u32 = unpacked_mapped ^ orig_mask
        unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
        unpacked_idx = (running_top & 0xFFFFFFFF).to(tl.int64)
        # Only the first K of the P candidates are written.
        out_mask = out_cols < K
        tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
        tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
    else:
        # Hand the packed candidates to the second phase.
        tl.store(Workspace_ptr + (row_idx * BLOCKS_PER_ROW + block_col_idx) * P + out_cols, running_top)
@triton.jit
def topk_phase2_kernel(
    Workspace_ptr,
    Out_vals_ptr,
    Out_idxs_ptr,
    P: tl.constexpr,
    K: tl.constexpr,
    BLOCKS_PER_ROW: tl.constexpr,
    M: tl.constexpr,
):
    """Reduce one row's M packed candidates to its final top-K values and indices.

    Args:
        Workspace_ptr: Packed 64-bit keys, M per row, as written by phase 1.
        Out_vals_ptr: Receives K float values per row.
        Out_idxs_ptr: Receives K int64 column indices per row.
        P: Candidates kept per phase-1 program; used as the selection and
            tl.arange extent here (assumed to be a power of two >= K).
        K: Number of results written per row.
        BLOCKS_PER_ROW: Phase-1 programs per row (not read in the body).
        M: Candidates per row; used as a tl.arange extent.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, M)
    packed = tl.load(Workspace_ptr + row_idx * M + cols)
    # Keys must compare as unsigned; reinterpret in case the workspace buffer
    # was allocated with a signed 64-bit dtype.
    packed = packed.to(tl.uint64, bitcast=True)

    # Select P keys and mask the store down to K, mirroring phase 1's
    # single-pass path, so K itself need not be a valid tl.arange extent.
    if P == 1:
        top_packed = tl.expand_dims(tl.max(packed, axis=0), 0)
    else:
        top_packed = tl.topk(packed, k=P)

    # Invert the key remap: a set top bit marks an originally non-negative value.
    unpacked_mapped = (top_packed >> 32).to(tl.uint32)
    is_pos_or_zero = (unpacked_mapped & 0x80000000) != 0
    orig_mask = tl.where(is_pos_or_zero, 0x80000000, 0xFFFFFFFF)
    unpacked_u32 = unpacked_mapped ^ orig_mask
    unpacked_x = unpacked_u32.to(tl.float32, bitcast=True)
    unpacked_idx = (top_packed & 0xFFFFFFFF).to(tl.int64)

    out_cols = tl.arange(0, P)
    out_mask = out_cols < K
    tl.store(Out_vals_ptr + row_idx * K + out_cols, unpacked_x, mask=out_mask)
    tl.store(Out_idxs_ptr + row_idx * K + out_cols, unpacked_idx, mask=out_mask)
class Model(nn.Module):
def __init__(self, batch: int, n: int, k: int):
super().__init__()
self.batch = batch
self.n = n
self.k
... (3766 more chars){
"description": "Running check.py to confirm the correctness of our final solution",
"command": "/workspace/.venv/bin/python check.py"
}{
"description": "Running benchmark.py to get the official scored peak fraction",
"command": "/workspace/.venv/bin/python benchmark.py"
}{
"file_path": "framework.txt"
}{
"description": "Cleaning up temporary scratch files",
"command": "rm test_triton.py tune.py tune_warps.py time_all.py"
}{
"title": "Final Verification and Summary",
"summary": "Successfully designed, implemented, and optimized a custom top-k selection kernel in Triton for the RTX PRO 6000 (SM120 Blackwell). Utilized a mathematically rigorous float-to-sortable-uint32 bitwise mapping to pack floating-point values and their original column indices into 64-bit unsigned integers. Performed a two-pass parallel reduction utilizing Triton's highly-optimized bitonic top-k select primitives. Conducted automated sweeps to fine-tune `BLOCKS_PER_ROW`, chunk sizes, and warp counts (`num_warps`) for each target shape. Ach
... (298 more chars)## 📂 Topic: **Final Verification and Summary**
**Summary:**
Successfully designed, implemented, and optimized a custom top-k selection kernel in Triton for the RTX PRO 6000 (SM120 Blackwell). Utilized a mathematically rigorous float-to-sortable-uint32 bitwise mapping to pack floating-point values and their original column indices into 64-bit unsigned integers. Performed a two-pass parallel reduction utilizing Triton's highly-optimized bitonic top-k select primitives. Conducted automated sweeps to fine-tune `BLOCKS_PER_ROW`, chunk sizes, and warp counts (`num_warps`) for each target shape. Achieved a significant performance speedup of up to 2.60x over PyTorch's native CUB-backed implementation while maintaining 100% correctness across all target scenarios and numeric stress cases.
> [!STRATEGY]
> **Intent:** Optimizations completed and verified. Finalizing and summarizing the solution.