"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
Triton-based decode kernel with online softmax and paged KV-cache gathering.
Splits work across (batch, kv_head, seq_chunk) for SM occupancy,
then reduces partial results with a second kernel.
"""
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Benchmark-harness metadata describing this operator.
OP_TYPE: str = "attention"
SUPPORTED_PRECISIONS: list[str] = ["bf16"]
HARDWARE_REQUIRED: list[str] = ["RTX_PRO_6000", "H100", "B200"]

# Default problem shape (the Llama-3 8B-style decode case).
# get_inputs() and get_init_inputs() read these names at call time.
BATCH: int = 8
NUM_HEADS: int = 32
NUM_KV_HEADS: int = 8  # grouped-query attention: 32 / 8 = 4 query heads per KV head
HEAD_DIM: int = 128
SEQ_LEN: int = 1024
PAGE_SIZE: int = 16  # tokens per KV-cache page
# ---------------------------------------------------------------------------
# Pass 1: partial attention over a chunk of pages
# ---------------------------------------------------------------------------
@triton.jit
def _partial_decode_kernel(
    query_ptr,
    kv_cache_ptr,
    block_table_ptr,
    seq_lens_ptr,
    partial_out_ptr,
    B,
    H,
    Hkv,
    D,
    max_blocks,
    P,
    pages_per_chunk,
    stride_q_b,
    stride_q_h,
    stride_kv_blk,
    stride_kv_pos,
    stride_kv_h,
    stride_bt_b,
    stride_po_c,
    stride_po_b,
    stride_po_h,
    BLOCK_D: tl.constexpr,
    group_size: tl.constexpr,
):
    """Compute partial attention for one chunk of pages.

    Grid: (num_chunks, B, Hkv)

    Program (pid_c, pid_b, pid_kv) handles the ``group_size`` query heads
    ``[pid_kv * group_size, (pid_kv + 1) * group_size)`` of batch element
    ``pid_b``. These heads share KV head ``pid_kv``. The program covers the
    logical pages ``[pid_c * pages_per_chunk, pid_c * pages_per_chunk +
    pages_per_chunk)`` of that sequence, clipped to its page count, and runs
    an online softmax over them. For each such query head ``h`` it writes
    row ``partial[pid_c, pid_b, h]``:

    * columns ``0:D``  -- unnormalised accumulator ``sum_j exp(s_j - m) * V_j``
    * column ``D``     -- running max ``m`` of the scaled scores ``s_j``
    * column ``D + 1`` -- running denominator ``l = sum_j exp(s_j - m)``

    A program whose chunk starts at or past the sequence's last page returns
    without writing anything. Its rows keep whatever the host stored there.

    Layout assumptions (read off the indexing below):

    * The last dimension of the query, kv_cache and partial tensors has unit
      stride.
    * V sits at column offset ``D`` of each packed ``[K | V]`` kv_cache row.
    * ``BLOCK_D`` must be a power of two >= ``D``. Columns >= ``D`` are
      masked.

    ``B``, ``H``, ``Hkv`` and ``max_blocks`` are accepted but not read.

    NOTE: each page is loaded as a single ``BLOCK_L = 16`` token tile. Only
    the first 16 tokens of a page are attended, so page sizes above 16 are
    not supported.
    """
    pid_c = tl.program_id(0)
    pid_b = tl.program_id(1)
    pid_kv = tl.program_id(2)
    seq_len = tl.load(seq_lens_ptr + pid_b)
    # Pages used by this sequence: ceil(seq_len / P).
    num_pages = (seq_len + P - 1) // P
    page_start = pid_c * pages_per_chunk
    page_end = tl.minimum(page_start + pages_per_chunk, num_pages)
    # Chunk lies entirely past the end of this sequence: nothing to do.
    if page_start >= num_pages:
        return
    # Token tile per page; it covers a whole page only when P <= 16.
    BLOCK_L: tl.constexpr = 16
    offs_l = tl.arange(0, BLOCK_L)
    offs_d = tl.arange(0, BLOCK_D)
    offs_g = tl.arange(0, group_size)
    # GQA: query heads [q_head_base, q_head_base + group_size) share KV head pid_kv.
    q_head_base = pid_kv * group_size
    # Load all Q heads in this group as a (group_size, BLOCK_D) fp32 tile.
    q_offs = (
        pid_b * stride_q_b
        + (q_head_base + offs_g[:, None]) * stride_q_h
        + offs_d[None, :]
    )
    mask_q = offs_d[None, :] < D
    q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
    # Online-softmax state per query head: running max, running sum, accumulator.
    m = tl.full([group_size], float("-inf"), dtype=tl.float32)
    l_sum = tl.zeros([group_size], dtype=tl.float32)
    acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
    # Softmax scale 1 / sqrt(head_dim).
    scale = 1.0 / tl.sqrt(D.to(tl.float32))
    for page_idx in range(page_start, page_end):
        # Physical page in the pool that holds logical page page_idx.
        blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
        # Only the last page can be partially filled. It holds seq_len % P
        # tokens, or P when seq_len is a multiple of P.
        if page_idx == num_pages - 1:
            rem = seq_len % P
            tokens_this_page = tl.where(rem == 0, P, rem)
        else:
            tokens_this_page = P
        valid_l = offs_l < tokens_this_page
        # Page base offset computed in int64 (block index * block stride).
        k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
        # K tile: (BLOCK_L, BLOCK_D), columns 0:D of the packed row.
        k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
        mask_kv = valid_l[:, None] & (offs_d[None, :] < D)
        k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_kv, other=0.0).to(tl.float32)
        scores = tl.dot(q, tl.trans(k_tile)) * scale
        # Padding slots get -inf so they contribute exp(-inf) = 0 below.
        scores = tl.where(valid_l[None, :], scores, float("-inf"))
        # Online softmax: rescale previous acc / l_sum to the new running max.
        m_new = tl.maximum(m, tl.max(scores, axis=1))
        rescale = tl.exp(m - m_new)
        acc = acc * rescale[:, None]
        l_sum = l_sum * rescale
        p = tl.exp(scores - m_new[:, None])
        p = tl.where(valid_l[None, :], p, 0.0)
        l_sum = l_sum + tl.sum(p, axis=1)
        # V tile: columns D:2D of the packed [K | V] row, same token mask as K.
        v_offs = k_base + D + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
        v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_kv, other=0.0).to(tl.float32)
        acc = acc + tl.dot(p, v_tile)
        m = m_new
    # Write partial results — vectorised across all Q heads in the group.
    base_off = pid_c * stride_po_c + pid_b * stride_po_b
    # m: (group_size,) -> store at [c, b, q_head, D]
    m_offs = base_off + (q_head_base + offs_g) * stride_po_h + D
    tl.store(partial_out_ptr + m_offs, m)
    # l_sum: (group_size,) -> store at [c, b, q_head, D+1]
    ls_offs = base_off + (q_head_base + offs_g) * stride_po_h + D + 1
    tl.store(partial_out_ptr + ls_offs, l_sum)
    # acc: (group_size, BLOCK_D) -> store at [c, b, q_head, 0:D]
    acc_offs = (
        base_off
        + (q_head_base + offs_g[:, None]) * stride_po_h
        + offs_d[None, :]
    )
    mask_acc = offs_d[None, :] < D
    tl.store(partial_out_ptr + acc_offs, acc, mask=mask_acc)
# ---------------------------------------------------------------------------
# Pass 2: reduce partial results across chunks
# ---------------------------------------------------------------------------
@triton.jit
def _reduce_partial_kernel(
    partial_in_ptr,
    output_ptr,
    num_chunks,
    B,
    H,
    D,
    stride_pi_c,
    stride_pi_b,
    stride_pi_h,
    stride_q_b,
    stride_q_h,
    BLOCK_D: tl.constexpr,
):
    """Reduce partial results across chunks into final output.

    Grid: (H, B) — one program per (query head, batch element).

    Each chunk row of ``partial_in`` holds:

    * ``acc`` in columns ``0:D``
    * the running max ``m`` in column ``D``
    * the running sum ``l`` in column ``D + 1``

    The program walks the chunks in order and merges them with log-sum-exp
    rescaling, then stores ``acc_total / l_total`` as bfloat16. Chunks with
    ``l == 0`` (no tokens) are skipped. If no chunk is valid, the output row
    is written as zeros.

    ``stride_q_b`` / ``stride_q_h`` are the strides of the *output* tensor;
    the host passes ``out.stride(0)`` / ``out.stride(1)``. ``B`` and ``H``
    are accepted but not read. The last dimension of both tensors is
    indexed with unit stride.
    """
    pid_q = tl.program_id(0)
    pid_b = tl.program_id(1)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D
    # Merged state across the chunks processed so far.
    m_global = float("-inf")
    l_global = 0.0
    o_global = tl.zeros([BLOCK_D], dtype=tl.float32)
    for c in range(num_chunks):
        # Load m_c, l_c (columns D and D + 1 of this chunk's row).
        m_c = tl.load(
            partial_in_ptr
            + c * stride_pi_c
            + pid_b * stride_pi_b
            + pid_q * stride_pi_h
            + D
        )
        l_c = tl.load(
            partial_in_ptr
            + c * stride_pi_c
            + pid_b * stride_pi_b
            + pid_q * stride_pi_h
            + D + 1
        )
        # A chunk with no attended tokens has l_c == 0 and must not take part.
        chunk_valid = l_c > 0.0
        m_new = tl.maximum(m_global, m_c)
        # Only rescale if chunk is valid; otherwise keep current state.
        # tl.where selects element-wise — both branches are evaluated.
        rescale_old = tl.exp(m_global - m_new)
        rescale_c = tl.exp(m_c - m_new)
        o_global = tl.where(chunk_valid, o_global * rescale_old, o_global)
        l_global = tl.where(chunk_valid, l_global * rescale_old, l_global)
        # Load acc_c for this chunk
        acc_c = tl.load(
            partial_in_ptr
            + c * stride_pi_c
            + pid_b * stride_pi_b
            + pid_q * stride_pi_h
            + offs_d,
            mask=mask_d,
            other=0.0,
        )
        # acc_c is already the exp-weighted sum: Σ exp(s-m_c)·V
        o_global = tl.where(
            chunk_valid,
            o_global + rescale_c * acc_c,
            o_global,
        )
        l_global = tl.where(
            chunk_valid,
            l_global + rescale_c * l_c,
            l_global,
        )
        m_global = tl.where(chunk_valid, m_new, m_global)
    # Guard against 0 / 0 when no chunk was valid; o_global is then all zeros.
    l_safe = tl.where(l_global == 0.0, 1.0, l_global)
    out_vals = o_global / l_safe
    out_offs = pid_b * stride_q_b + pid_q * stride_q_h + offs_d
    tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_d)
# ---------------------------------------------------------------------------
# Host-side dispatch
# ---------------------------------------------------------------------------
def _paged_attention_decode(
query: torch.Tensor,
kv_cache: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
) -> torch.Tensor:
B, H, D = query.shape
Hkv = kv_cache.shape[2]
P = kv_cache.shape[1]
max_blocks = block_table.shape[1]
group_size = H // Hkv
max_seq = int(seq_lens.max().item())
max_pages = (max_seq + P - 1) // P
# Target at least 256 blocks for good SM occupancy.
base_blocks = B * Hkv
target_blocks = 256
num_chunks = max(1, min(max_pages, target_blocks // base_blocks))
pages_per_chunk = (max_pages + num_chunks - 1) // num_chunks
# Intermediate storage: (num_chunks, B, H, D + 2) fp32
partial = torch.zeros(
num_chunks, B, H, D + 2,
dtype=torch.float32,
device=query.device,
)
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
# Pass 1: partial attention over chunks
grid1 = (num_chunks, B, Hkv)
_partial_decode_kernel[grid1](
query,
kv_cache,
block_table,
seq_lens,
partial,
B,
H,
Hkv,
D,
max_blocks,
P,
pages_per_chunk,
query.stride(0),
query.stride(1),
kv_cache.stride(0),
kv_cache.stride(1),
kv_cache.stride(2),
block_table.stride(0),
partial.stride(0),
partial.stride(1),
partial.stride(2),
BLOCK_D=D,
group_size=group_size,
)
# Pass 2: reduce partial results across chunks
grid2 = (H, B)
_reduce_partial_kernel[grid2](
partial,
out,
num_chunks,
B,
H,
D,
partial.stride(0),
partial.stride(1),
partial.stride(2),
out.stride(0),
out.stride(1),
BLOCK_D=D,
)
return out
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
    """Single-query paged attention decode — Triton two-pass kernel.

    The constructor only records the problem configuration. ``forward`` makes
    its inputs contiguous and dispatches to ``_paged_attention_decode``.
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        # The ``<= 0`` test runs first, which also avoids a ZeroDivisionError
        # in the modulo.
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by a positive "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        # Kept for interface parity. The kernels compute 1 / sqrt(head_dim)
        # themselves.
        self.scale = 1.0 / math.sqrt(head_dim)
        # Non-persistent placeholder buffer (excluded from state_dict); it is
        # not read by forward.
        self.register_buffer(
            "_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False
        )

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the ``(B, H, D)`` bf16 decode-attention output.

        All inputs are made contiguous first, because the kernels index the
        last dimension with unit stride.
        """
        return _paged_attention_decode(
            query.contiguous(),
            kv_cache.contiguous(),
            block_table.contiguous(),
            seq_lens.contiguous(),
        )
def get_inputs():
    """Create one random decode problem from the module-level shape constants.

    Returns ``[query, kv_cache, block_table, seq_lens]`` as CPU tensors:

    * query       -- ``(BATCH, NUM_HEADS, HEAD_DIM)`` bf16, scaled by 0.1
    * kv_cache    -- ``(pool, PAGE_SIZE, NUM_KV_HEADS, 2 * HEAD_DIM)`` bf16,
      scaled by 0.1, where ``pool = max(BATCH * pages_per_seq + 8, 64)``
    * block_table -- ``(BATCH, pages_per_seq)`` int32 of distinct page ids
      taken from a random permutation, so no two sequences share a page
    * seq_lens    -- ``(BATCH,)`` int32 with every entry equal to ``SEQ_LEN``
    """
    batch, seq_len, page = BATCH, SEQ_LEN, PAGE_SIZE
    pages_per_seq = -(-seq_len // page)  # ceil(seq_len / page)
    pages_needed = batch * pages_per_seq
    pool_pages = max(pages_needed + 8, 64)

    # The draw order (query, kv_cache, permutation) fixes the RNG stream.
    query = torch.randn(batch, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) * 0.1
    kv_cache = (
        torch.randn(
            pool_pages, page, NUM_KV_HEADS, 2 * HEAD_DIM, dtype=torch.bfloat16
        )
        * 0.1
    )
    page_ids = torch.randperm(pool_pages)[:pages_needed]
    block_table = page_ids.reshape(batch, pages_per_seq).int().contiguous()
    seq_lens = torch.full((batch,), seq_len, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]
def get_init_inputs():
    """Return the positional ``Model(...)`` arguments for the current shape."""
    return list((BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE))
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:22:26.481263+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:22:26.730075+00:00 elapsed_s=0.249 ms=0.056592
shape=0 variant=solution tflops=2.372 gbps=595.234 ms=0.057
shape=0 solution_peak_fraction=0.3307
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:22:28.358446+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:22:28.371089+00:00 elapsed_s=0.013 ms=0.225136
shape=1 variant=solution tflops=4.769 gbps=1194.655 ms=0.225
shape=1 solution_peak_fraction=0.6637
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:22:28.779247+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:22:28.814548+00:00 elapsed_s=0.035 ms=0.081584
shape=2 variant=solution tflops=6.581 gbps=824.180 ms=0.082
shape=2 solution_peak_fraction=0.4579
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:22:29.425655+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:22:29.433042+00:00 elapsed_s=0.007 ms=0.105104
shape=3 variant=solution tflops=3.829 gbps=959.620 ms=0.105
shape=3 solution_peak_fraction=0.5331
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:22:29.532485+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:22:29.538909+00:00 elapsed_s=0.006 ms=0.052032
shape=4 variant=solution tflops=1.260 gbps=315.513 ms=0.052
shape=4 solution_peak_fraction=0.1753
peak_fraction: 0.3931
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention/repo
Installed 56 packages in 68ms
PASS
{
"run_id": "20260615_125721_deepseek-claude_deepseek-v4-pro_03_paged_attention",
"run_group": "",
"problem": "03_paged_attention",
"harness": "deepseek-claude",
"model": "deepseek-v4-pro",
"reasoning_effort": "",
"started_at": "2026-06-15T12:57:21-06:00",
"harness_finished_at": "2026-06-15T13:22:12-06:00",
"finished_at": "2026-06-15T13:22:29-06:00",
"start_epoch": 1781549841,
"harness_end_epoch": 1781551332,
"end_epoch": 1781551349,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.3931,
"template_mutated": false,
"elapsed_seconds": 1491,
"total_elapsed_seconds": 1508,
"check_elapsed_seconds": 13,
"benchmark_elapsed_seconds": 4,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": null,
"usage": {"input_tokens": null, "output_tokens": null, "cache_read_tokens": null, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"file_path": "/workspace/problems/03_paged_attention/reference.py"
}1 """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2
3 Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4 and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5 a global pool. Pages for batch element b are listed in `block_table[b]`.
6
7 The reference performs the slow path:
8 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11
12 This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13 forbidden list) so the agent cannot dispatch through SDPA either.
14 """
15 import math
16
17 import torch
18 import torch.nn as nn
19
20 OP_TYPE = "attention"
21 SUPPORTED_PRECISIONS = ["bf16"]
22 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23
24
25 # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26 BATCH = 8
27 NUM_HEADS = 32
28 NUM_KV_HEADS = 8
29 HEAD_DIM = 128
30 SEQ_LEN = 1024
31 PAGE_SIZE = 16
32
33
34 class Model(nn.Module):
35 """Single-query paged attention decode.
36
37 Forward inputs (all on device):
38 query: (batch, num_heads, head_dim) bf16
39 kv_cache: (num_blocks, page_size, num_kv_heads, head_dim * 2)
40 Layout: last dim packs [K | V] so a single gather pulls both.
4
... (4123 more chars){
"file_path": "/workspace/problems/03_paged_attention/benchmark.py"
}1 """Roofline benchmark for paged-attention decode.
2
3 For each shape: times the agent's solution first. Optional eager, compiled,
4 and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
7 the GPU's peak DRAM bandwidth.
8 """
9 import sys
10 from math import exp, log
11 from pathlib import Path
12
13 import torch
14 import yaml
15
16 REPO_ROOT = Path(__file__).resolve().parents[2]
17 sys.path.insert(0, str(REPO_ROOT))
18
19 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
20 from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
21 from src.hardware import get as get_hw # noqa: E402
22
23
24 def _eval_formula(expr: str, vars: dict) -> float:
25 return float(eval(expr, {"__builtins__": {}}, vars))
26
27
28 def _apply_shape(reference, shape: dict) -> None:
29 reference.BATCH = shape["batch"]
30 reference.NUM_HEADS = shape["num_heads"]
31 reference.NUM_KV_HEADS = shape["num_kv_heads"]
32 reference.HEAD_DIM = shape["head_dim"]
33 reference.SEQ_LEN = shape["seq_len"]
34 reference.PAGE_SIZE = shape["page_size"]
35
36
37 def main():
38 import reference
39 import shapes
40 import solution
41
42 meta = yaml.safe_load(Path("problem.yaml").read_text())
43 hw = get_hw(meta["hardware"][0])
44 peak_tflops = hw.peak_tflop
... (4399 more chars){
"file_path": "/workspace/problems/03_paged_attention/check.py"
}1 """Correctness runner for paged-attention decode.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 REPO_ROOT = Path(__file__).resolve().parents[2]
14 sys.path.insert(0, str(REPO_ROOT))
15
16 from src.eval.correctness import check_correctness # noqa: E402
17 from src.eval.numeric_stress import ( # noqa: E402
18 numeric_stress_cases,
19 numeric_stress_context,
20 tolerance_for_case,
21 )
22
23
24 def _apply_shape(reference, shape: dict) -> None:
25 reference.BATCH = shape["batch"]
26 reference.NUM_HEADS = shape["num_heads"]
27 reference.NUM_KV_HEADS = shape["num_kv_heads"]
28 reference.HEAD_DIM = shape["head_dim"]
29 reference.SEQ_LEN = shape["seq_len"]
30 reference.PAGE_SIZE = shape["page_size"]
31
32
33 def main():
34 try:
35 import reference
36 import shapes
37 import solution
38 except Exception as e:
39 print(f"FAIL: import error: {e}")
40 sys.exit(1)
41
42 problem_yaml = Path("problem.yaml")
43 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
44
45 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
46 for forbidden in meta.get("forbidden", []):
47 pat = re.escape(forbidden)
48 if re.se
... (2701 more chars){
"command": "ls -la /workspace/problems/03_paged_attention/",
"description": "List all files in the workspace"
}total 44
drwxrwxr-x 2 1002 1002 4096 Jun 15 18:57 .
drwxrwxr-x 3 1002 1002 4096 Jun 15 18:57 ..
-rw-rw-r-- 1 1002 1002 2585 Apr 27 09:21 PROMPT.txt
-rw-rw-r-- 1 1002 1002 5387 Jun 1 21:05 benchmark.py
-rw-rw-r-- 1 1002 1002 3841 Jun 2 01:08 check.py
-rw-rw-r-- 1 1002 1002 2030 Apr 26 04:19 problem.yaml
-rw-rw-r-- 1 1002 1002 5151 Apr 26 04:19 reference.py
-rw-rw-r-- 1 1002 1002 933 Apr 26 04:19 shapes.py
-rw-rw-r-- 1 1002 1002 2819 Apr 26 04:24 sota.py{
"file_path": "/workspace/problems/03_paged_attention/shapes.py"
}1 """Shape sweep for paged attention decode.
2
3 Mix targets:
4 - small batch / long context (Llama-3 8B-style decode)
5 - large batch / medium context (server batched decode)
6 - GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
7 - non-power-of-2 seq_len (forces predicated tail handling)
8 - head_dim=64 small-head case
9 """
10
11 SHAPES = [
12 # (B, H, Hkv, D, L, P)
13 {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
14 {"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
15 {"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
16 {"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
17 {"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
18 ]
19 {
"file_path": "/workspace/problems/03_paged_attention/problem.yaml"
}1 name: 03_paged_attention
2 display_name: "Paged Attention Decode"
3 precision: bf16
4 regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
5
6 # Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
7 # Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
8 # 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
9 flops_formula: "4 * batch * num_heads * seq_len * head_dim"
10
11 # Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
12 # K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
13 # Q: batch * num_heads * head_dim * 2 (negligible vs KV)
14 # Out: same as Q
15 # So total ~ 2 * (B * L * Hkv * D * 2) + small.
16 bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
17
18 hardware: [RTX_PRO_6000]
19 peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20 peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
21
22 tolerance:
23 bfloat16: 0.02
24 float16: 0.02
25
26 # Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
27 # vllm.attention / flashinfer.batch_decode_with_paged_kv_cache: dispatching to
28 # the SOTA library is the trivial cheat. SDPA is also banned because the agent
29 # could gather pages then call SDPA and inherit FlashAttention "for free".
30 forbidden:
31 - "vllm.attenti
... (668 more chars){
"file_path": "/workspace/problems/03_paged_attention/sota.py"
}1 """SOTA reference for paged-attention decode.
2
3 Tries, in order:
4 1. FlashInfer's BatchDecodeWithPagedKVCacheWrapper (preferred -- portable,
5 supports SM120, GQA, arbitrary head_dim).
6 2. vLLM's paged_attention_v2 CUDA op (requires its KV-cache layout, more
7 finicky; we adapt the layout on the fly when possible).
8
9 If neither is importable, is_available() returns False and the benchmark just
10 reports eager + compiled + solution.
11
12 Agents are FORBIDDEN from importing these in solution.py (see problem.yaml).
13 This file is only for the benchmark's reference line.
14 """
15 from __future__ import annotations
16
17 import torch
18
19
20 def _try_flashinfer(
21 query: torch.Tensor,
22 kv_cache: torch.Tensor,
23 block_table: torch.Tensor,
24 seq_lens: torch.Tensor,
25 num_kv_heads: int,
26 head_dim: int,
27 page_size: int,
28 ) -> torch.Tensor | None:
29 try:
30 import flashinfer # noqa: F401
31 from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
32 except Exception:
33 return None
34
35 B, H, D = query.shape
36 # FlashInfer expects K and V as separate (num_blocks, page_size, num_kv_heads, head_dim) tensors.
37 # Our reference packs [K|V] on the last dim -- split here.
38 k_cache = kv_cache[..., :D].contiguous()
39 v_cache = kv_cache[..., D:].contiguous()
40
41 workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=query.device)
42
... (1565 more chars){
"file_path": "/workspace/problems/03_paged_attention/PROMPT.txt"
}1 I need you to write a paged attention decode kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 This is single-query decode: each batch element has a query of shape (num_heads, head_dim) and attends over a KV cache of seq_len[b] tokens stored as fixed-size pages in a global pool, with block_table[b] listing which pages belong to batch element b. The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2) where the last dim packs [K | V] so a single gather pulls both. You return the bf16 output of shape (batch, num_heads, head_dim). Correctness tolerance is 0.02 abs/rel on bf16/fp16. The shapes you have to handle are batch=8 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1024 page_size=16 (Llama-3 8B-style decode), batch=32 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=2048 page_size=16 (server-batched decode), batch=4 num_heads=64 num_kv_heads=8 head_dim=128 seq_len=4096 page_size=16 (Llama-3 70B GQA wide ratio with long context), batch=16 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1535 page_size=16 (non-power-of-two seq_len that forces predicated tails), and batch=8 num_heads=16 num_kv_heads
... (1089 more chars){
"command": "nvidia-smi --query-gpu=name,compute_cap --format=csv 2>/dev/null || echo \"no nvidia-smi\"",
"description": "Check GPU model"
}name, compute_cap
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0{
"command": "python -c \"import torch; print('CUDA:', torch.cuda.is_available()); print('Compute:', torch.cuda.get_device_capability(0)); print('Arch:', torch.cuda.get_arch_list()); print('Triton:', __import__('importlib').import_module('triton').runtime.driver.active.get_device().name); print('GPU:', torch.cuda.get_device_name(0))\" 2>&1",
"description": "Check PyTorch and GPU capabilities"
}{
"command": "uv run python -c \"import torch; print('torch version:', torch.__version__); print('CUDA available:', torch.cuda.is_available()); print('Device count:', torch.cuda.device_count()); print('Device 0:', torch.cuda.get_device_name(0)); print('Capability:', torch.cuda.get_device_capability(0)); print('BF16 support:', torch.cuda.is_bf16_supported())\" 2>&1",
"description": "Check torch version and GPU via uv"
}Exit code 1
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/driver.py", line 24, in __getattr__
return getattr(self._obj, name)
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'CudaDriver' object has no attribute 'get_device'
CUDA: True
Compute: (12, 0)
Arch: ['sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120']Using CPython 3.11.15
Creating virtual environment at: /workspace/.venv
Building kernelbench-hard @ file:///workspace
Built kernelbench-hard @ file:///workspace
warning: Failed to hardlink files; falling back to full copy. This may lead to degraded performance.
If the cache and target directories are on different filesystems, hardlinking may not be supported.
If this is intentional, set `export UV_LINK_MODE=copy` or use `--link-mode=copy` to suppress this warning.
Installed 56 packages in 871ms
torch version: 2.11.0+cu130
CUDA available: True
Device count: 1
Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Capability: (12, 0)
BF16 support: True{
"command": "uv run python -c \"import triton; print('Triton version:', triton.__version__); print('Triton path:', triton.__file__)\" 2>&1",
"description": "Check Triton version and capabilities"
}Triton version: 3.6.0
Triton path: /workspace/.venv/lib/python3.11/site-packages/triton/__init__.py{
"command": "nvcc --version 2>&1; echo \"---\"; which ncu 2>&1; ncu --version 2>&1 | head -5",
"description": "Check CUDA compilation tools"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
---
/kbh/bin/ncu
NVIDIA (R) Nsight Compute Command Line Profiler
Copyright (c) 2018-2025 NVIDIA Corporation
Version 2025.2.1.0 (build 35987062) (public-release){
"command": "<474 chars \u2014 see below>",
"description": "Check available GPU memory and compute features"
}uv run python -c "
import torch
print('Total memory:', torch.cuda.get_device_properties(0).total_mem / 1e9, 'GB')
print('SM count:', torch.cuda.get_device_properties(0).multi_processor_count)
print('Max shared mem per block:', torch.cuda.get_device_properties(0).shared_mem_per_block, 'bytes')
print('Max threads per block:', torch.cuda.get_device_properties(0).max_threads_per_block)
print('L2 cache:', torch.cuda.get_device_properties(0).l2_cache_size / 1024, 'KB')
" 2>&1Exit code 1
Traceback (most recent call last):
File "<string>", line 3, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'total_mem'. Did you mean: 'total_memory'?{
"command": "<498 chars \u2014 see below>",
"description": "Check GPU memory properties"
}uv run python -c "
import torch
p = torch.cuda.get_device_properties(0)
print('Total memory:', p.total_memory / 1e9, 'GB')
print('SM count:', p.multi_processor_count)
print('Max shared mem per block:', p.shared_mem_per_block, 'bytes')
print('Max shared mem per block opt-in:', p.shared_mem_per_block_optin, 'bytes')
print('Max threads per block:', p.max_threads_per_block)
print('L2 cache:', p.l2_cache_size / 1024, 'KB')
print('Max blocks per SM:', p.max_threads_per_multi_processor // 256)
" 2>&1Exit code 1
Traceback (most recent call last):
File "<string>", line 6, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_mem_per_block'. Did you mean: 'shared_memory_per_block'?
Total memory: 102.014189568 GB
SM count: 188{
"command": "<707 chars \u2014 see below>",
"description": "Check GPU properties with correct attribute names"
}uv run python -c "
import torch
p = torch.cuda.get_device_properties(0)
print('Total memory:', p.total_memory / 1e9, 'GB')
print('SM count:', p.multi_processor_count)
print('Max shared mem per block:', p.shared_memory_per_block, 'bytes')
print('Max shared mem per block opt-in:', p.shared_memory_per_block_optin, 'bytes')
print('Max threads per block:', p.max_threads_per_block)
print('L2 cache:', p.l2_cache_size / 1024, 'KB')
print('Max warps per SM:', p.max_threads_per_multi_processor // 32)
print('Max blocks per SM:', p.max_threads_per_multi_processor // 256)
print('Registers per block:', p.regs_per_block)
print('Registers per SM:', p.regs_per_multiprocessor)
print('Warp size:', p.warp_size)
" 2>&1Exit code 1
Traceback (most recent call last):
File "<string>", line 9, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'l2_cache_size'. Did you mean: 'L2_cache_size'?
Total memory: 102.014189568 GB
SM count: 188
Max shared mem per block: 49152 bytes
Max shared mem per block opt-in: 101376 bytes
Max threads per block: 1024{
"subject": "Write initial CUDA paged attention decode kernel",
"description": "Write a CUDA C++ kernel via load_inline that implements paged attention decode with online softmax, handling all 5 shape configs. Grid: (num_kv_heads, batch) with K/V loaded once per page into shared memory and reused across GQA query heads.",
"activeForm": "Writing CUDA kernel"
}Task #1 created successfully: Write initial CUDA paged attention decode kernel{
"taskId": "1",
"status": "in_progress"
}Updated task #1 status--- /dev/null
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -0,0 +1,372 @@
+"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
+
+CUDA C++ kernel with online softmax, shared-memory tiling, and vectorized loads.
+"""
+
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+# ---------------------------------------------------------------------------
+# CUDA kernel
+# ---------------------------------------------------------------------------
+
+_cuda_kernel_src = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+constexpr int PAGE_SIZE = 16;
+constexpr int MAX_D = 128;
+constexpr int WARP_SIZE = 32;
+constexpr int MAX_GROUP_SIZE = 8; // max num_heads/num_kv_heads across shapes
+
+extern "C" __global__ void paged_attention_decode_kernel(
+ const __nv_bfloat16* __restrict__ query, // (B, H, D)
+ const __nv_bfloat16* __restrict__ kv_cache, // (num_blocks, P, Hkv, 2*D)
+ const int* __restrict__ block_table, // (B, max_blocks)
+ const int* __restrict__ seq_lens, // (B,)
+ __nv_bfloat16* __restrict__ output, // (B, H, D)
+ int B,
+ int H,
+ int Hkv,
+ int D,
+ int max_blocks,
+ int kv_cache_stride_block, // stride of block dim in bf16 elements
+ int kv_cache_stride_pos, // stride of position dim
+ int kv_cache_stride_head // stride of kv_head dim
+) {
+ int b = blockIdx.y;
+ int kv_head = blockIdx.x;
+ int group_size = H / Hkv;
+ int seq_len = seq_lens[b];
+ if (seq_len <= 0) return;
+
+ int num_pages = (seq_len + PAGE_SIZE - 1) / PAGE_SIZE;
+
+ __shared__ __nv_bfloat16 K_smem[PAGE_SIZE * MAX_D];
+ __shared__ __nv_bfloat16 V_smem[PAGE_SIZE * MAX_D];
+
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ int tid = threadIdx.x;
+ int num_threads = blockDim.x;
+
+ // Map each warp to a Q head in the group.
+ // With up to 8 warps (256 threads) and max group_size=8, each Q head
+ // gets one dedicated warp. If group_size < 8, extra warps idle.
+ int q_head_in_group = warp_id;
+ bool active = (q_head_in_group < group_size);
+ int q_head = kv_head * group_size + q_head_in_group;
+
+ // --- Pre-load Q for this warp's head into registers ---
+ float q_reg[MAX_D / WARP_SIZE];
+ if (active) {
+ #pragma unroll
+ for (int i = 0; i < D / WARP_SIZE; i++) {
+ int d = lane_id + i * WARP_SIZE;
+ if (d < D) {
+ q_reg[i] = __bfloat162float(query[b * H * D + q_head * D + d]);
+ } else {
+ q_reg[i] = 0.0f;
+ }
+ }
+ }
+
+ // --- Online-softmax state (fp32, per thread's stripe of D) ---
+ float m_global = -1e30f;
+ float l_global = 0.0f;
+ float acc[MAX_D / WARP_SIZE];
+ if (active) {
+ #pragma unroll
+ for (int i = 0; i < D / WARP_SIZE; i++) {
+ acc[i] = 0.0f;
+ }
+ }
+
+ float scale = rsqrtf(static_cast<float>(D));
+
+ // --- Page loop ---
+ for (int page_idx = 0; page_idx < num_pages; page_idx++) {
+ int block_idx = block_table[b * max_blocks + page_idx];
+ int tokens_this_page;
+ if (page_idx == num_pages - 1) {
+ int rem = seq_len % PAGE_SIZE;
+ tokens_this_page = (rem == 0) ? PAGE_SIZE : rem;
+ } else {
+ tokens_this_page = PAGE_SIZE;
+ }
+
+ // ----- Cooperative load of K and V into shared memory -----
+ // All threads (including idle warps) participate so we fill smem
+ // quickly with coalesced accesses.
+ int total_k_elems = PAGE_SIZE * D;
+ int elems_per_thread = (total_k_elems + num_threads - 1) / num_threads;
+
+ // Load K
+ for (int i = 0; i < elems_per_thread; i++) {
+ int idx = tid + i * num_threads;
+ if (idx < total_k_elems) {
+ int pos = idx / D;
+ int d = idx % D;
+ int gmem_offset = block_idx * kv_cache_stride_block
+ + pos * kv_cache_stride_pos
+ + kv_head * kv_cache_stride_head
+ + d;
+ K_smem[idx] = kv_cache[gmem_offset];
+ }
+ }
+
+ // Load V
+ for (int i = 0; i < elems_per_thread; i++) {
+ int idx = tid + i * num_threads;
+ if (idx < total_k_elems) {
+ int pos = idx / D;
+ int d = idx % D;
+ int gmem_offset = block_idx * kv_cache_stride_block
+ + pos * kv_cache_stride_pos
+ + kv_head * kv_cache_stride_head
+ + D + d;
+ V_smem[idx] = kv_cache[gmem_offset];
+ }
+ }
+
+ __syncthreads();
+
+ // ----- Per-warp attention computation -----
+ if (active) {
+ // Compute scores[t] = dot(Q, K[t]) * scale for every valid
+ // position in this tile, using warp shuffles for the reduction.
+ float scores[PAGE_SIZE];
+
+ float m_local = m_global;
+ float max_score = m_global;
+
+ #pragma unroll
+ for (int t = 0; t < PAGE_SIZE; t++) {
+ float dot = 0.0f;
+ if (t < tokens_this_page) {
+ #pragma unroll
+ for (int i = 0; i < D / WARP_SIZE; i++) {
+ int d = lane_id + i * WARP_SIZE;
+ if (d < D) {
+ dot += q_reg[i] * __bfloat162float(K_smem[t * D + d]);
+ }
+ }
+ }
+ // Full-warp reduction for
... (6117 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<11656 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/03_paged_attention/solution.py (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}Exit code 1
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 317, in forward
mod = _get_cuda_module()
^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 242, in _get_cuda_module
mod = cpp_ext.load_inline(
^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site
... (4269 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -33,7 +33,7 @@
constexpr int WARP_SIZE = 32;
constexpr int MAX_GROUP_SIZE = 8; // max num_heads/num_kv_heads across shapes
-extern "C" __global__ void paged_attention_decode_kernel(
+__global__ void paged_attention_decode_kernel(
const __nv_bfloat16* __restrict__ query, // (B, H, D)
const __nv_bfloat16* __restrict__ kv_cache, // (num_blocks, P, Hkv, 2*D)
const int* __restrict__ block_table, // (B, max_blocks)
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "extern \"C\" __global__ void paged_attention_decode_kernel(",
"new_string": "__global__ void paged_attention_decode_kernel("
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check after fixing linkage",
"timeout": 180000
}Exit code 1
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 317, in forward
mod = _get_cuda_module()
^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 242, in _get_cuda_module
mod = cpp_ext.load_inline(
^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site
... (4351 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py"
}1 """Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
2
3 CUDA C++ kernel with online softmax, shared-memory tiling, and vectorized loads.
4 """
5
6 import math
7 from typing import List
8
9 import torch
10 import torch.nn as nn
11
12 OP_TYPE = "attention"
13 SUPPORTED_PRECISIONS = ["bf16"]
14 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
15
16 BATCH = 8
17 NUM_HEADS = 32
18 NUM_KV_HEADS = 8
19 HEAD_DIM = 128
20 SEQ_LEN = 1024
21 PAGE_SIZE = 16
22
23 # ---------------------------------------------------------------------------
24 # CUDA kernel
25 # ---------------------------------------------------------------------------
26
27 _cuda_kernel_src = r"""
28 #include <cuda_fp16.h>
29 #include <cuda_bf16.h>
30
31 constexpr int PAGE_SIZE = 16;
32 constexpr int MAX_D = 128;
33 constexpr int WARP_SIZE = 32;
34 constexpr int MAX_GROUP_SIZE = 8; // max num_heads/num_kv_heads across shapes
35
36 __global__ void paged_attention_decode_kernel(
37 const __nv_bfloat16* __restrict__ query, // (B, H, D)
38 const __nv_bfloat16* __restrict__ kv_cache, // (num_blocks, P, Hkv, 2*D)
39 const int* __restrict__ block_table, // (B, max_blocks)
40 const int* __restrict__ seq_lens, // (B,)
41 __nv_bfloat16* __restrict__ output, // (B, H, D)
42 int B,
43 int H,
44 int Hkv,
45 int D,
46 int max_blocks,
47 int kv_cache_stride_block, // stride of block dim in bf16 elements
48 int
... (11529 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,13 +1,14 @@
"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
-CUDA C++ kernel with online softmax, shared-memory tiling, and vectorized loads.
+Triton-based decode kernel with online softmax and paged KV-cache gathering.
"""
import math
-from typing import List
import torch
import torch.nn as nn
+import triton
+import triton.language as tl
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
@@ -21,240 +22,194 @@
PAGE_SIZE = 16
# ---------------------------------------------------------------------------
-# CUDA kernel
-# ---------------------------------------------------------------------------
-
-_cuda_kernel_src = r"""
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
-constexpr int PAGE_SIZE = 16;
-constexpr int MAX_D = 128;
-constexpr int WARP_SIZE = 32;
-constexpr int MAX_GROUP_SIZE = 8; // max num_heads/num_kv_heads across shapes
-
-__global__ void paged_attention_decode_kernel(
- const __nv_bfloat16* __restrict__ query, // (B, H, D)
- const __nv_bfloat16* __restrict__ kv_cache, // (num_blocks, P, Hkv, 2*D)
- const int* __restrict__ block_table, // (B, max_blocks)
- const int* __restrict__ seq_lens, // (B,)
- __nv_bfloat16* __restrict__ output, // (B, H, D)
- int B,
- int H,
- int Hkv,
- int D,
- int max_blocks,
- int kv_cache_stride_block, // stride of block dim in bf16 elements
- int kv_cache_stride_pos, // stride of position dim
- int kv_cache_stride_head // stride of kv_head dim
-) {
- int b = blockIdx.y;
- int kv_head = blockIdx.x;
- int group_size = H / Hkv;
- int seq_len = seq_lens[b];
- if (seq_len <= 0) return;
-
- int num_pages = (seq_len + PAGE_SIZE - 1) / PAGE_SIZE;
-
- __shared__ __nv_bfloat16 K_smem[PAGE_SIZE * MAX_D];
- __shared__ __nv_bfloat16 V_smem[PAGE_SIZE * MAX_D];
-
- int warp_id = threadIdx.x / WARP_SIZE;
- int lane_id = threadIdx.x % WARP_SIZE;
- int tid = threadIdx.x;
- int num_threads = blockDim.x;
-
- // Map each warp to a Q head in the group.
- // With up to 8 warps (256 threads) and max group_size=8, each Q head
- // gets one dedicated warp. If group_size < 8, extra warps idle.
- int q_head_in_group = warp_id;
- bool active = (q_head_in_group < group_size);
- int q_head = kv_head * group_size + q_head_in_group;
-
- // --- Pre-load Q for this warp's head into registers ---
- float q_reg[MAX_D / WARP_SIZE];
- if (active) {
- #pragma unroll
- for (int i = 0; i < D / WARP_SIZE; i++) {
- int d = lane_id + i * WARP_SIZE;
- if (d < D) {
- q_reg[i] = __bfloat162float(query[b * H * D + q_head * D + d]);
- } else {
- q_reg[i] = 0.0f;
- }
- }
- }
-
- // --- Online-softmax state (fp32, per thread's stripe of D) ---
- float m_global = -1e30f;
- float l_global = 0.0f;
- float acc[MAX_D / WARP_SIZE];
- if (active) {
- #pragma unroll
- for (int i = 0; i < D / WARP_SIZE; i++) {
- acc[i] = 0.0f;
- }
- }
-
- float scale = rsqrtf(static_cast<float>(D));
-
- // --- Page loop ---
- for (int page_idx = 0; page_idx < num_pages; page_idx++) {
- int block_idx = block_table[b * max_blocks + page_idx];
- int tokens_this_page;
- if (page_idx == num_pages - 1) {
- int rem = seq_len % PAGE_SIZE;
- tokens_this_page = (rem == 0) ? PAGE_SIZE : rem;
- } else {
- tokens_this_page = PAGE_SIZE;
- }
-
- // ----- Cooperative load of K and V into shared memory -----
- // All threads (including idle warps) participate so we fill smem
- // quickly with coalesced accesses.
- int total_k_elems = PAGE_SIZE * D;
- int elems_per_thread = (total_k_elems + num_threads - 1) / num_threads;
-
- // Load K
- for (int i = 0; i < elems_per_thread; i++) {
- int idx = tid + i * num_threads;
- if (idx < total_k_elems) {
- int pos = idx / D;
- int d = idx % D;
- int gmem_offset = block_idx * kv_cache_stride_block
- + pos * kv_cache_stride_pos
- + kv_head * kv_cache_stride_head
- + d;
- K_smem[idx] = kv_cache[gmem_offset];
- }
- }
-
- // Load V
- for (int i = 0; i < elems_per_thread; i++) {
- int idx = tid + i * num_threads;
- if (idx < total_k_elems) {
- int pos = idx / D;
- int d = idx % D;
- int gmem_offset = block_idx * kv_cache_stride_block
- + pos * kv_cache_stride_pos
- + kv_head * kv_cache_stride_head
- + D + d;
- V_smem[idx] = kv_cache[gmem_offset];
- }
- }
-
- __syncthreads();
-
- // ----- Per-warp attention computation -----
- if (active) {
- // Compute scores[t] = dot(Q, K[t]) * scale for every valid
- // position in this tile, using warp shuffles for the reduction.
- float scores[PAGE_SIZE];
-
- float m_local = m_global;
- float max_score = m_global;
-
- #pragma unroll
- for (int t = 0; t < PAGE_SIZE; t++) {
- float dot = 0.0f;
- if (t < tokens_this_page) {
- #pragma unroll
- for (int i = 0; i < D / WARP_SIZE; i++) {
- int d = lane_id + i * WARP_SIZE;
- if (d < D) {
- dot += q_reg[i] * __bfloat162float(K_smem[t * D + d]);
- }
- }
- }
... (12836 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<10012 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check on Triton kernel",
"timeout": 300000
}Exit code 1
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 253, in forward
return _paged_attention_decode(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 199, in _paged_attention_decode
_paged_attention_decode_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 240, in run
benchmark()
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
timings = {config: self._bench(*a
... (2618 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -25,8 +25,7 @@
# Triton kernel
# ---------------------------------------------------------------------------
-# Autotune over tile sizes along the sequence-length dimension.
-# Decode is memory-bound so the primary goal is saturating DRAM bandwidth.
+
@triton.autotune(
configs=[
triton.Config({"BLOCK_L": 16, "BLOCK_D": 64}, num_warps=4),
@@ -54,129 +53,105 @@
group_size: tl.constexpr,
):
"""
- Each program instance handles one (batch_element, kv_head) pair and
- iterates over the sequence in BLOCK_L-sized tiles.
-
- Within each tile, Q @ K^T is computed as a dot-product between the query
- rows of the group and the K tile, using online softmax rescaling, then
- the V tile is accumulated.
+ Each program handles one (batch, kv_head) pair.
+
+ Loads all Q heads in the GQA group together in one shot, then iterates
+ over KV-cache pages in BLOCK_L-sized tiles. Online softmax (fp32)
+ rescaling keeps accumulator state per Q head.
"""
- pid_b = tl.program_id(0) # batch index
- pid_kv = tl.program_id(1) # KV head index
-
- # Number of programs along the sequence-length dimension
+ pid_b = tl.program_id(0)
+ pid_kv = tl.program_id(1)
+
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
- # Pre-compute offsets that are constant across the tile loop
- offs_d = tl.arange(0, BLOCK_D) # dimension offsets (may exceed D)
- offs_l = tl.arange(0, BLOCK_L) # length offsets within a tile
-
- # Online softmax state — one per Q head in the group (group_size rows)
- # We store them as a [group_size, BLOCK_D] accumulator and per-row m/l.
- # Use fp32 for accumulation.
+ offs_d = tl.arange(0, BLOCK_D)
+ offs_l = tl.arange(0, BLOCK_L)
+ offs_g = tl.arange(0, group_size) # Q-heads within the GQA group
+
+ # --- Load all Q heads of this group at once ---
+ q_head_base = pid_kv * group_size
+ q_offs = (
+ pid_b * stride_q_b
+ + (q_head_base + offs_g[:, None]) * stride_q_h
+ + offs_d[None, :]
+ )
+ mask_q = offs_d[None, :] < D
+ q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
+ # q: (group_size, BLOCK_D)
+
+ # --- Online-softmax state (fp32) ---
m = tl.full([group_size], float("-inf"), dtype=tl.float32)
l_sum = tl.zeros([group_size], dtype=tl.float32)
-
- # Accumulator: (group_size, BLOCK_D) — padded to BLOCK_D
acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(D.to(tl.float32))
- # Load Q for all heads in this group, shape (group_size, D)
- # q_head = kv_head * group_size + g
- kv_head = pid_kv
- q_heads_base = kv_head * group_size * D # offset within a batch's Q
-
- # Load Q rows: (group_size, D) padded to BLOCK_D
- q = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
- for g in range(group_size):
- q_head_idx = kv_head * group_size + g
- q_offs = pid_b * stride_q_b + q_head_idx * stride_q_h + offs_d
- mask_d = offs_d < D
- q_g = tl.load(query_ptr + q_offs, mask=mask_d, other=0.0)
- q[g, :] = q_g.to(tl.float32)
-
- # Page loop
+ # --- Page loop ---
for page_idx in range(num_pages):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
- # Tokens in this page: PAGE_SIZE except possibly the last page
if page_idx == num_pages - 1:
rem = seq_len % P
tokens_this_page = tl.where(rem == 0, P, rem)
else:
tokens_this_page = P
- # Tile loop within a page (a page may be processed as one tile
- # when BLOCK_L == P, but the autotuner may pick other sizes)
+ # Tile loop within a page
for l_start in range(0, P, BLOCK_L):
l_offs = l_start + offs_l
- # Which positions in the page are valid?
- # valid_l_mask: positions within this tile AND within the page token budget
valid_l = l_offs < tokens_this_page
- # --- Load K tile: (BLOCK_L, BLOCK_D) from this page, this KV head ---
- # Global offset for K:
- # kv_cache[blk_idx, pos, kv_head, d] for d in 0..D-1
- # = blk_idx * stride_kv_blk + pos * stride_kv_pos + kv_head * stride_kv_h + d
+ # K / V base offset for this block + kv_head
k_base = (
blk_idx.to(tl.int64) * stride_kv_blk
- + kv_head * stride_kv_h
+ + pid_kv * stride_kv_h
)
- # We need K[l_offs[pos], d] — l_offs[pos] is the in-page position, d is dimension
- # Reshape to 2D indexing: use l_offs for the pos dimension
- k_offs = k_base + (l_offs[:, None] * stride_kv_pos) + offs_d[None, :]
+
+ # --- Load K tile ---
+ k_offs = k_base + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
mask_2d = valid_l[:, None] & (offs_d[None, :] < D)
k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_2d, other=0.0)
- k_tile_fp32 = k_tile.to(tl.float32) # (BLOCK_L, BLOCK_D)
-
- # --- Compute scores: (group_size, BLOCK_L) = Q @ K^T ---
- # Q: (group_size, BLOCK_D), K: (BLOCK_L, BLOCK_D)
- # scores[g, l] = sum_d Q[g, d] * K[l, d]
- scores = tl.dot(q, tl.trans(k_tile_fp32)) # (group_size, BLOCK_L)
- scores = scores * scale # apply scaling
-
- # Mask invalid positions to -inf for the softmax max-finding
+ k_f32 = k_tile.to(tl.float32) # (BLOCK_L, BLOCK_D)
+
+ # --- QK^T ---
+ scores = tl.dot(q, tl.trans(k_f32)) # (group_size, BLOCK_L)
+ scores = scores * scale
+ # Mask invalid positions
scores = tl.where(valid_l[None, :], scores, float("-inf"))
# --- Online
... (2833 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<7934 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}Exit code 1
FAIL: shape 0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} seed 42 case small_q_kv: tolerance exceeded: max_abs_diff=0.012085 max_rel_diff=7053.51 bad=14384/32768 worst_idx=(4, 13, 90) (atol=0.0005, rtol=0.05){
"command": "<216 chars \u2014 see below>",
"description": "Check what stress cases exist and their tolerances"
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
from src.eval.numeric_stress import numeric_stress_cases
for case in numeric_stress_cases('03_paged_attention'):
print(f'{case.name}: {case}')
" 2>&1nominal: NumericStressCase(name='nominal', input_scales={}, state_scales={}, tolerance=None)
small_q_kv: NumericStressCase(name='small_q_kv', input_scales={0: 0.01, 1: 0.01}, state_scales={}, tolerance={'bfloat16': {'atol': 0.0005, 'rtol': 0.05}})
large_q_kv: NumericStressCase(name='large_q_kv', input_scales={0: 8.0, 1: 8.0}, state_scales={}, tolerance={'bfloat16': {'atol': 0.05, 'rtol': 0.05}}){
"command": "<1902 chars \u2014 see below>",
"description": "Debug correctness issue with small_q_kv case",
"timeout": 120000
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
import torch
import reference
import solution
from src.eval.numeric_stress import numeric_stress_cases, numeric_stress_context
import shapes
# Test shape 0, seed 42, small_q_kv case
shape = shapes.SHAPES[0]
reference.BATCH = shape['batch']
reference.NUM_HEADS = shape['num_heads']
reference.NUM_KV_HEADS = shape['num_kv_heads']
reference.HEAD_DIM = shape['head_dim']
reference.SEQ_LEN = shape['seq_len']
reference.PAGE_SIZE = shape['page_size']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
base_inputs = [t.cuda() for t in reference.get_inputs()]
case = [c for c in numeric_stress_cases('03_paged_attention') if c.name == 'small_q_kv'][0]
with numeric_stress_context(ref_model, sol_model, base_inputs, case) as inputs:
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
# Just print basic stats
print('Ref out shape:', ref_out.shape)
print('Sol out shape:', sol_out.shape)
print('Ref out mean:', ref_out.float().mean().item())
print('Sol out mean:', sol_out.float().mean().item())
print('Max abs diff:', (ref_out.float() - sol_out.float()).abs().max().item())
diff = (ref_out.float() - sol_out.float()).abs()
print('Diff > 0.01 count:', (diff > 0.01).sum().item(), '/', diff.numel())
# Check which batch/head elements are bad
bad_mask = diff > 0.005
bad_indices = torch.nonzero(bad_mask)
if len(bad_indices) > 0:
print('First few bad indices (b, h, d):')
for idx in bad_indices[:10]:
b, h, d = idx.tolist()
print(f' [{b},{h},{d}] ref={ref_out[b,h,d].item():.6f} sol={sol_out[b,h,d].item():.6f} diff={diff[b,h,d].item():.6f}')
" 2>&1Ref out shape: torch.Size([8, 32, 128])
Sol out shape: torch.Size([8, 32, 128])
Ref out mean: -6.886963888064201e-08
Sol out mean: -6.972371124902565e-08
Max abs diff: 4.76837158203125e-07
Diff > 0.01 count: 0 / 32768{
"command": "uv run python check.py 2>&1",
"description": "Re-run correctness check",
"timeout": 300000
}Exit code 1
FAIL: shape 0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} seed 42 case small_q_kv: tolerance exceeded: max_abs_diff=0.012085 max_rel_diff=7053.51 bad=14384/32768 worst_idx=(4, 13, 90) (atol=0.0005, rtol=0.05){
"command": "<2698 chars \u2014 see below>",
"description": "Debug the exact failure scenario",
"timeout": 120000
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
import torch
import reference
import solution
from src.eval.numeric_stress import numeric_stress_cases, numeric_stress_context, tolerance_for_case
from src.eval.correctness import check_correctness
import shapes
import yaml
from pathlib import Path
meta = yaml.safe_load(Path('problem.yaml').read_text())
device = torch.device('cuda:0')
shape = shapes.SHAPES[0]
reference.BATCH = shape['batch']
reference.NUM_HEADS = shape['num_heads']
reference.NUM_KV_HEADS = shape['num_kv_heads']
reference.HEAD_DIM = shape['head_dim']
reference.SEQ_LEN = shape['seq_len']
reference.PAGE_SIZE = shape['page_size']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
for seed in (42,):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
base_inputs = [t.to(device) for t in reference.get_inputs()]
for case in numeric_stress_cases(meta.get('name', '')):
print(f'Case: {case.name}')
print(f' input_scales: {case.input_scales}')
with numeric_stress_context(ref_model, sol_model, base_inputs, case) as inputs:
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
print(f' ref_out stats: min={ref_out.float().min().item():.6f} max={ref_out.float().max().item():.6f} mean={ref_out.float().mean().item():.6f}')
print(f' sol_out stats: min={sol_out.float().min().item():.6f} max={sol_out.float().max().item():.6f} mean={sol_out.float().mean().item():.6f}')
tol_override = meta.get('tolerance') or None
tol = tolerance_for_case(tol_override, case)
print(f' tolerance: {tol}')
ok, msg = check_correctness(ref_out, sol_out, dtype=ref_out.dtype, override=tol)
if not ok:
print(f' FAIL: {msg}')
# Find where diffs are large
diff = (ref_out.float() - sol_out.float()).abs()
bad = diff > 0.001
print(f' bad elements: {bad.sum().item()}/{diff.numel()}')
if bad.sum() > 0:
worst = diff.argmax().item()
w_flat = worst
w_b = w_flat // (32*128)
w_h = (w_flat % (32*128)) // 128
w_d = w_flat % 128
print(f' worst element: [{w_b},{w_h},{w_d}] ref={ref_out.flatten()[worst].item():.8f} sol={sol_out.flatten()[worst].item():.8f} diff={diff.flatten()[worst].item():.8f}')
else:
print(f' OK')
" 2>&1Case: nominal
input_scales: {}
ref_out stats: min=-0.012207 max=0.011414 mean=-0.000007
sol_out stats: min=-0.012207 max=0.011414 mean=-0.000006
tolerance: {'bfloat16': 0.02, 'float16': 0.02}
OK
Case: small_q_kv
input_scales: {0: 0.01, 1: 0.01}
ref_out stats: min=-0.000122 max=0.000113 mean=-0.000000
sol_out stats: min=-0.012207 max=0.010437 mean=0.000005
tolerance: {'bfloat16': {'atol': 0.0005, 'rtol': 0.05}, 'float16': 0.02}
FAIL: tolerance exceeded: max_abs_diff=0.012085 max_rel_diff=7053.51 bad=14384/32768 worst_idx=(4, 13, 90) (atol=0.0005, rtol=0.05)
bad elements: 12165/32768
worst element: [4,13,90] ref=-0.00012207 sol=-0.01220703 diff=0.01208496
Case: large_q_kv
input_scales: {0: 8.0, 1: 8.0}
ref_out stats: min=-0.126953 max=0.121094 mean=-0.000008
sol_out stats: min=nan max=nan mean=nan
tolerance: {'bfloat16': {'atol': 0.05, 'rtol': 0.05}, 'float16': 0.02}
FAIL: solution contains NaN
bad elements: 31249/32768
worst element: [0,1,76] ref=0.00163269 sol=nan diff=nan{
"command": "<211 chars \u2014 see below>",
"description": "Check numeric_stress_context implementation"
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
from src.eval.numeric_stress import numeric_stress_context, NumericStressCase
import inspect
print(inspect.getsource(numeric_stress_context))
" 2>&1@contextmanager
def numeric_stress_context(
ref_model: torch.nn.Module,
sol_model: torch.nn.Module,
inputs: Sequence[object],
case: NumericStressCase,
) -> Iterator[list[object]]:
backups = _scale_states([ref_model, sol_model], case.state_scales)
try:
yield _scale_inputs(inputs, case.input_scales)
finally:
if torch.cuda.is_available():
torch.cuda.synchronize()
with torch.no_grad():
for tensor, original in backups:
tensor.copy_(original){
"command": "<174 chars \u2014 see below>",
"description": "Check _scale_inputs implementation"
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
from src.eval.numeric_stress import _scale_inputs
import inspect
print(inspect.getsource(_scale_inputs))
" 2>&1def _scale_inputs(inputs: Sequence[object], scales: Mapping[int, float]) -> list[object]:
out = list(inputs)
for idx, scale in scales.items():
if idx >= len(out):
raise IndexError(f"input scale index {idx} out of range for {len(out)} inputs")
value = out[idx]
if not isinstance(value, torch.Tensor):
raise TypeError(f"input {idx} is not a tensor and cannot be scaled")
out[idx] = _scale_tensor(value, scale)
return out{
"command": "<174 chars \u2014 see below>",
"description": "Check _scale_tensor implementation"
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
from src.eval.numeric_stress import _scale_tensor
import inspect
print(inspect.getsource(_scale_tensor))
" 2>&1def _scale_tensor(tensor: torch.Tensor, scale: float) -> torch.Tensor:
if not torch.is_floating_point(tensor):
raise TypeError(f"cannot scale non-floating tensor dtype={tensor.dtype}")
return (tensor.float() * scale).to(dtype=tensor.dtype, device=tensor.device){
"command": "<4845 chars \u2014 see below>",
"description": "Debug with non-autotuned kernel",
"timeout": 120000
}uv run python -c "
import sys
sys.path.insert(0, '/workspace')
import torch
import reference
import triton
import triton.language as tl
from src.eval.numeric_stress import numeric_stress_cases, numeric_stress_context, tolerance_for_case
from src.eval.correctness import check_correctness
import shapes, yaml
from pathlib import Path
meta = yaml.safe_load(Path('problem.yaml').read_text())
device = torch.device('cuda:0')
shape = shapes.SHAPES[0]
reference.BATCH = shape['batch']
reference.NUM_HEADS = shape['num_heads']
reference.NUM_KV_HEADS = shape['num_kv_heads']
reference.HEAD_DIM = shape['head_dim']
reference.SEQ_LEN = shape['seq_len']
reference.PAGE_SIZE = shape['page_size']
# Non-autotuned kernel
@triton.jit
def test_kernel(
query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr, output_ptr,
B, H, Hkv, D, max_blocks, P,
stride_q_b, stride_q_h,
stride_kv_blk, stride_kv_pos, stride_kv_h,
stride_bt_b,
BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, group_size: tl.constexpr,
):
pid_b = tl.program_id(0)
pid_kv = tl.program_id(1)
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
offs_d = tl.arange(0, BLOCK_D)
offs_l = tl.arange(0, BLOCK_L)
offs_g = tl.arange(0, group_size)
q_head_base = pid_kv * group_size
q_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
mask_q = offs_d[None, :] < D
q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
m = tl.full([group_size], float('-inf'), dtype=tl.float32)
l_sum = tl.zeros([group_size], dtype=tl.float32)
acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(D.to(tl.float32))
for page_idx in range(num_pages):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
if page_idx == num_pages - 1:
rem = seq_len % P
tokens_this_page = tl.where(rem == 0, P, rem)
else:
tokens_this_page = P
for l_start in range(0, P, BLOCK_L):
l_offs = l_start + offs_l
valid_l = l_offs < tokens_this_page
k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
k_offs = k_base + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
mask_2d = valid_l[:, None] & (offs_d[None, :] < D)
k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_2d, other=0.0).to(tl.float32)
scores = tl.dot(q, tl.trans(k_tile)) * scale
scores = tl.where(valid_l[None, :], scores, float('-inf'))
m_new = tl.maximum(m, tl.max(scores, axis=1))
rescale = tl.exp(m - m_new)
acc = acc * rescale[:, None]
l_sum = l_sum * rescale
p = tl.exp(scores - m_new[:, None])
p = tl.where(valid_l[None, :], p, 0.0)
l_sum = l_sum + tl.sum(p, axis=1)
v_offs = k_base + D + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_2d, other=0.0).to(tl.float32)
acc = acc + tl.dot(p, v_tile)
m = m_new
out_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
out_vals = acc / l_sum[:, None]
tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_q)
# Run manual test
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
for case in numeric_stress_cases(meta.get('name', '')):
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
base_inputs = [t.to(device) for t in reference.get_inputs()]
with numeric_stress_context(ref_model, ref_model, base_inputs, case) as inputs:
q, kvc, bt, sl = inputs
B, H, D = q.shape
Hkv = kvc.shape[2]
P = kvc.shape[1]
group_size = H // Hkv
... (845 more chars)Exit code 1
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 25, in <module>
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file--- /dev/null
+++ b//workspace/problems/03_paged_attention/debug_kernel.py
@@ -0,0 +1,162 @@
+"""Debug script to compare ref vs sol for specific cases."""
+import sys
+sys.path.insert(0, '/workspace')
+import torch
+import reference
+import triton
+import triton.language as tl
+from src.eval.numeric_stress import numeric_stress_cases, numeric_stress_context, tolerance_for_case
+from src.eval.correctness import check_correctness
+import shapes, yaml
+from pathlib import Path
+
+meta = yaml.safe_load(Path('problem.yaml').read_text())
+device = torch.device('cuda:0')
+
+# Non-autotuned kernel with proper D handling
+@triton.jit
+def debug_kernel(
+ query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr, output_ptr,
+ B, H, Hkv, D, max_blocks, P,
+ stride_q_b, stride_q_h,
+ stride_kv_blk, stride_kv_pos, stride_kv_h,
+ stride_bt_b,
+ BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, group_size: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_kv = tl.program_id(1)
+ seq_len = tl.load(seq_lens_ptr + pid_b)
+ num_pages = (seq_len + P - 1) // P
+ offs_l = tl.arange(0, BLOCK_L)
+ offs_g = tl.arange(0, group_size)
+
+ q_head_base = pid_kv * group_size
+
+ m = tl.full([group_size], float('-inf'), dtype=tl.float32)
+ l_sum = tl.zeros([group_size], dtype=tl.float32)
+ acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
+ scale = 1.0 / tl.sqrt(D.to(tl.float32))
+
+ for page_idx in range(num_pages):
+ blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
+ if page_idx == num_pages - 1:
+ rem = seq_len % P
+ tokens_this_page = tl.where(rem == 0, P, rem)
+ else:
+ tokens_this_page = P
+
+ for l_start in range(0, P, BLOCK_L):
+ l_offs = l_start + offs_l
+ valid_l = l_offs < tokens_this_page
+ k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
+
+ # Process D in BLOCK_D chunks, accumulating the full dot product
+ scores_sum = tl.zeros([group_size, BLOCK_L], dtype=tl.float32)
+ acc_update = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = offs_d < D
+
+ # Load Q chunk
+ q_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
+ q_chunk = tl.load(query_ptr + q_offs, mask=mask_d[None, :], other=0.0).to(tl.float32)
+
+ # Load K chunk
+ k_offs = k_base + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
+ mask_2d = valid_l[:, None] & mask_d[None, :]
+ k_chunk = tl.load(kv_cache_ptr + k_offs, mask=mask_2d, other=0.0).to(tl.float32)
+
+ # Partial dot product
+ scores_sum += tl.dot(q_chunk, tl.trans(k_chunk))
+
+ scores = scores_sum * scale
+ scores = tl.where(valid_l[None, :], scores, float('-inf'))
+
+ m_new = tl.maximum(m, tl.max(scores, axis=1))
+ rescale = tl.exp(m - m_new)
+ acc = acc * rescale[:, None]
+ l_sum = l_sum * rescale
+
+ p = tl.exp(scores - m_new[:, None])
+ p = tl.where(valid_l[None, :], p, 0.0)
+ l_sum = l_sum + tl.sum(p, axis=1)
+
+ # Accumulate V (also in D chunks)
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = offs_d < D
+
+ v_offs = k_base + D + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
+ mask_2d = valid_l[:, None] & mask_d[None, :]
+ v_chunk = tl.load(kv_cache_ptr + v_offs, mask=mask_2d, other=0.0).to(tl.float32)
+
+ acc_update_d = tl.dot(p, v_chunk) # (group_size, BLOCK_D)
+ # Scatter into acc — only valid D positions
+ acc_d = tl.where(mask_d[None, :], acc_update_d, 0.0)
+ acc += acc_d
+
+ m = m_new
+
+ # Write output (in D chunks)
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = offs_d < D
+
+ out_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
+
+ # Extract the right portion of acc
+ acc_chunk = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
+ # Hmm, this is tricky — acc has shape (group_size, BLOCK_D) but we need
+ # to extract the d_start..d_start+BLOCK_D portion.
+ # Actually acc stores the full D-dim accumulation, but padded to BLOCK_D.
+ # Since BLOCK_D divides D (or we have multiple chunks), this gets complex.
+
+ out_vals = acc / l_sum[:, None]
+ tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_d[None, :])
+
+
+# Test nominal case
+shape = shapes.SHAPES[0]
+reference.BATCH = shape['batch']
+reference.NUM_HEADS = shape['num_heads']
+reference.NUM_KV_HEADS = shape['num_kv_heads']
+reference.HEAD_DIM = shape['head_dim']
+reference.SEQ_LEN = shape['seq_len']
+reference.PAGE_SIZE = shape['page_size']
+
+init_args = reference.get_init_inputs()
+ref_model = reference.Model(*init_args).to(device).eval()
+
+for case in numeric_stress_cases(meta.get('name', '')):
+ torch.manual_seed(42)
+ torch.cuda.manual_seed_all(42)
+ base_inputs = [t.to(device) for t in reference.get_inputs()]
+
+ with numeric_stress_context(ref_model, ref_model, base_inputs, case) as inputs:
+ q, kvc, bt, sl = inputs
+ B, H, D = q.shape
+ Hkv = kvc.shape[2]
+ P = kvc.shape[1]
+ group_size = H // Hkv
+ out = torch.zeros(B, H, D, dtype=t
... (930 more chars){
"file_path": "/workspace/problems/03_paged_attention/debug_kernel.py",
"content": "<6675 chars \u2014 see diff>",
"description": "Debug kernel with proper D chunking",
"timeout": 120000
}<tool_use_error>InputValidationError: Write failed due to the following issues:
An unexpected parameter `description` was provided
An unexpected parameter `timeout` was provided</tool_use_error>{
"command": "<5282 chars \u2014 see below>",
"description": "Write and run debug kernel",
"timeout": 120000
}cat > /workspace/problems/03_paged_attention/debug_kernel.py << 'PYEOF'
"""Debug script to compare ref vs sol for specific cases."""
import sys
sys.path.insert(0, '/workspace')
import torch
import reference
import triton
import triton.language as tl
from src.eval.numeric_stress import numeric_stress_cases, numeric_stress_context, tolerance_for_case
from src.eval.correctness import check_correctness
import shapes, yaml
from pathlib import Path
meta = yaml.safe_load(Path('problem.yaml').read_text())
device = torch.device('cuda:0')
# Simple kernel: BLOCK_D exactly equals D, no chunking needed
@triton.jit
def debug_kernel(
query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr, output_ptr,
B, H, Hkv, D, max_blocks, P,
stride_q_b, stride_q_h,
stride_kv_blk, stride_kv_pos, stride_kv_h,
stride_bt_b,
BLOCK_D: tl.constexpr, group_size: tl.constexpr,
):
pid_b = tl.program_id(0)
pid_kv = tl.program_id(1)
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
offs_l = tl.arange(0, P) # exactly PAGE_SIZE positions per tile
offs_g = tl.arange(0, group_size)
offs_d = tl.arange(0, BLOCK_D)
q_head_base = pid_kv * group_size
m = tl.full([group_size], float('-inf'), dtype=tl.float32)
l_sum = tl.zeros([group_size], dtype=tl.float32)
acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(D.to(tl.float32))
for page_idx in range(num_pages):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
if page_idx == num_pages - 1:
rem = seq_len % P
tokens_this_page = tl.where(rem == 0, P, rem)
else:
tokens_this_page = P
l_offs = offs_l
valid_l = l_offs < tokens_this_page
k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
# Load Q
mask_q = offs_d[None, :] < D
q_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
# Load K
k_offs = k_base + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
mask_k = valid_l[:, None] & (offs_d[None, :] < D)
k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_k, other=0.0).to(tl.float32)
scores = tl.dot(q, tl.trans(k_tile)) * scale
scores = tl.where(valid_l[None, :], scores, float('-inf'))
m_new = tl.maximum(m, tl.max(scores, axis=1))
rescale = tl.exp(m - m_new)
acc = acc * rescale[:, None]
l_sum = l_sum * rescale
p = tl.exp(scores - m_new[:, None])
p = tl.where(valid_l[None, :], p, 0.0)
l_sum = l_sum + tl.sum(p, axis=1)
# Load V
v_offs = k_base + D + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_k, other=0.0).to(tl.float32)
acc = acc + tl.dot(p, v_tile)
m = m_new
out_offs = pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :]
out_vals = acc / l_sum[:, None]
tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_q)
# Test case
shape = shapes.SHAPES[0]
reference.BATCH = shape['batch']
reference.NUM_HEADS = shape['num_heads']
reference.NUM_KV_HEADS = shape['num_kv_heads']
reference.HEAD_DIM = shape['head_dim']
reference.SEQ_LEN = shape['seq_len']
reference.PAGE_SIZE = shape['page_size']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
for case in numeric_stress_cases(meta.get('name', '')):
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
base_inputs = [t.to(device) for t in reference.get_inputs()]
with numeric_stress_context(ref_model, ref_model, base_inputs, case) as inputs:
q, kvc, bt, sl = inputs
B, H, D = q.shape
Hkv
... (1282 more chars)Command running in background with ID: bq06fx2zv. Output is being written to: /tmp/claude-1002/-workspace-problems-03-paged-attention/0ad3b801-0ba8-4dab-9164-459bde2e792c/tasks/bq06fx2zv.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bq06fx2zv",
"block": true,
"timeout": 120000
}<retrieval_status>success</retrieval_status>
<task_id>bq06fx2zv</task_id>
<task_type>local_bash</task_type>
<status>failed</status>
<exit_code>1</exit_code>
<output>
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/debug_kernel.py", line 112, in <module>
debug_kernel[(B, Hkv)](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^
... (1442 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,6 +1,7 @@
"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
Triton-based decode kernel with online softmax and paged KV-cache gathering.
+Uses one block per (batch, kv_head) pair, processes all GQA query heads together.
"""
import math
@@ -24,40 +25,56 @@
# ---------------------------------------------------------------------------
# Triton kernel
# ---------------------------------------------------------------------------
+#
+# Design:
+# - Grid: (batch, num_kv_heads)
+# - Each program handles one GQA group (group_size query heads sharing
+# one KV head) for one batch element.
+# - Q heads are loaded together and all dot products computed vectorised.
+# - KV pages are iterated one at a time (PAGE_SIZE positions each).
+# - Online softmax (fp32) prevents materialising the full attention matrix.
+# - BLOCK_D is set to D at launch time so no dimension chunking is needed.
+# - BLOCK_L is exactly PAGE_SIZE (16) — a single page fits in one tile.
@triton.autotune(
configs=[
- triton.Config({"BLOCK_L": 16, "BLOCK_D": 64}, num_warps=4),
- triton.Config({"BLOCK_L": 16, "BLOCK_D": 64}, num_warps=8),
- triton.Config({"BLOCK_L": 16, "BLOCK_D": 128}, num_warps=4),
- triton.Config({"BLOCK_L": 16, "BLOCK_D": 128}, num_warps=8),
- triton.Config({"BLOCK_L": 32, "BLOCK_D": 64}, num_warps=4),
- triton.Config({"BLOCK_L": 32, "BLOCK_D": 64}, num_warps=8),
+ triton.Config({"BLOCK_D": 64}, num_warps=4),
+ triton.Config({"BLOCK_D": 64}, num_warps=8),
+ triton.Config({"BLOCK_D": 128}, num_warps=4),
+ triton.Config({"BLOCK_D": 128}, num_warps=8),
],
key=["D", "group_size"],
)
@triton.jit
def _paged_attention_decode_kernel(
- query_ptr, # (B, H, D) bf16
- kv_cache_ptr, # (num_blocks, P, Hkv, 2*D) bf16
- block_table_ptr, # (B, max_blocks) int32
- seq_lens_ptr, # (B,) int32
- output_ptr, # (B, H, D) bf16
- B, H, Hkv, D, max_blocks, P,
- stride_q_b, stride_q_h,
- stride_kv_blk, stride_kv_pos, stride_kv_h,
+ query_ptr,
+ kv_cache_ptr,
+ block_table_ptr,
+ seq_lens_ptr,
+ output_ptr,
+ B,
+ H,
+ Hkv,
+ D,
+ max_blocks,
+ P,
+ stride_q_b,
+ stride_q_h,
+ stride_kv_blk,
+ stride_kv_pos,
+ stride_kv_h,
stride_bt_b,
- BLOCK_L: tl.constexpr,
BLOCK_D: tl.constexpr,
group_size: tl.constexpr,
):
"""
- Each program handles one (batch, kv_head) pair.
-
- Loads all Q heads in the GQA group together in one shot, then iterates
- over KV-cache pages in BLOCK_L-sized tiles. Online softmax (fp32)
- rescaling keeps accumulator state per Q head.
+ Paged attention decode — one (batch, kv_head) per program instance.
+
+ BLOCK_D must be >= D (the kernel relies on loading all D elements at once).
+ The autotuner is restricted to BLOCK_D values that cover the actual D:
+ - D=64 → BLOCK_D ∈ {64, 128}
+ - D=128 → BLOCK_D ∈ {128}
"""
pid_b = tl.program_id(0)
pid_kv = tl.program_id(1)
@@ -65,11 +82,12 @@
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
- offs_d = tl.arange(0, BLOCK_D)
- offs_l = tl.arange(0, BLOCK_L)
- offs_g = tl.arange(0, group_size) # Q-heads within the GQA group
-
- # --- Load all Q heads of this group at once ---
+ # Dimension / position offset vectors (compile-time shapes)
+ offs_d = tl.arange(0, BLOCK_D) # [0 .. BLOCK_D-1]
+ offs_l = tl.arange(0, P) # [0 .. 15] (PAGE_SIZE is always 16)
+ offs_g = tl.arange(0, group_size) # [0 .. group_size-1]
+
+ # --- Load all Q heads of this group in one shot ---
q_head_base = pid_kv * group_size
q_offs = (
pid_b * stride_q_b
@@ -78,7 +96,7 @@
)
mask_q = offs_d[None, :] < D
q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
- # q: (group_size, BLOCK_D)
+ # q: (group_size, BLOCK_D), zero-padded beyond D
# --- Online-softmax state (fp32) ---
m = tl.full([group_size], float("-inf"), dtype=tl.float32)
@@ -91,67 +109,65 @@
for page_idx in range(num_pages):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
+ # Tokens in this page: PAGE_SIZE for all but the last page
if page_idx == num_pages - 1:
rem = seq_len % P
tokens_this_page = tl.where(rem == 0, P, rem)
else:
tokens_this_page = P
- # Tile loop within a page
- for l_start in range(0, P, BLOCK_L):
- l_offs = l_start + offs_l
- valid_l = l_offs < tokens_this_page
-
- # K / V base offset for this block + kv_head
- k_base = (
- blk_idx.to(tl.int64) * stride_kv_blk
- + pid_kv * stride_kv_h
- )
-
- # --- Load K tile ---
- k_offs = k_base + l_offs[:, None] * stride_kv_pos + offs_d[None, :]
- mask_2d = valid_l[:, None] & (offs_d[None, :] < D)
- k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_2d, other=0.0)
- k_f32 = k_tile.to(tl.float32) # (BLOCK_L, BLOCK_D)
-
- # --- QK^T ---
- scores = tl.dot(q, tl.trans(k_f32)) # (group_size, BLOCK_L)
- scores = scores * scale
- # Mask invalid positions
- scores = tl.where(valid_l[None, :], scores, float("-inf"))
-
- # --- Online softmax update ---
- m_new = tl.maximum(m, tl.max(scores, axis=1)) # (group_size,)
-
- rescale = tl.exp(m - m_new) # (group_size,)
- acc = acc * rescale[:, None]
- l_sum = l_sum * rescale
-
- p = tl.exp(scores - m_new[:, None]) # (group_size, BLOCK_L)
-
... (3205 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<8423 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -26,15 +26,14 @@
# Triton kernel
# ---------------------------------------------------------------------------
#
-# Design:
-# - Grid: (batch, num_kv_heads)
-# - Each program handles one GQA group (group_size query heads sharing
-# one KV head) for one batch element.
-# - Q heads are loaded together and all dot products computed vectorised.
-# - KV pages are iterated one at a time (PAGE_SIZE positions each).
-# - Online softmax (fp32) prevents materialising the full attention matrix.
-# - BLOCK_D is set to D at launch time so no dimension chunking is needed.
-# - BLOCK_L is exactly PAGE_SIZE (16) — a single page fits in one tile.
+# Key design decisions:
+# - Grid: (batch, num_kv_heads) — one program per GQA group per batch element.
+# - BLOCK_L = PAGE_SIZE = 16 (compile-time constant, matches the page size).
+# - BLOCK_D is autotuned over {64, 128, 256} but the kernel handles the
+# case BLOCK_D < D by accumulating partial dot products across D-chunks.
+# - All Q heads in a GQA group are loaded together and participate in
+# vectorised tl.dot operations.
+# - Online softmax in fp32 prevents materialising the L×H attention matrix.
@triton.autotune(
@@ -69,12 +68,10 @@
group_size: tl.constexpr,
):
"""
- Paged attention decode — one (batch, kv_head) per program instance.
-
- BLOCK_D must be >= D (the kernel relies on loading all D elements at once).
- The autotuner is restricted to BLOCK_D values that cover the actual D:
- - D=64 → BLOCK_D ∈ {64, 128}
- - D=128 → BLOCK_D ∈ {128}
+ Paged attention decode kernel.
+
+ One program per (batch_element, kv_head). Iterates over KV-cache pages
+ in order, loading one full page (16 positions) at a time.
"""
pid_b = tl.program_id(0)
pid_kv = tl.program_id(1)
@@ -82,13 +79,175 @@
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
- # Dimension / position offset vectors (compile-time shapes)
- offs_d = tl.arange(0, BLOCK_D) # [0 .. BLOCK_D-1]
- offs_l = tl.arange(0, P) # [0 .. 15] (PAGE_SIZE is always 16)
+ # The page size is always 16 for all shapes — bake it as the tile length.
+ BLOCK_L: tl.constexpr = 16
+
+ offs_l = tl.arange(0, BLOCK_L) # [0 .. 15]
offs_g = tl.arange(0, group_size) # [0 .. group_size-1]
- # --- Load all Q heads of this group in one shot ---
q_head_base = pid_kv * group_size
+
+ # --- Online-softmax state (fp32) ---
+ m = tl.full([group_size], float("-inf"), dtype=tl.float32)
+ l_sum = tl.zeros([group_size], dtype=tl.float32)
+ acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
+
+ scale = 1.0 / tl.sqrt(D.to(tl.float32))
+
+ # --- Page loop ---
+ for page_idx in range(num_pages):
+ blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
+
+ if page_idx == num_pages - 1:
+ rem = seq_len % P
+ tokens_this_page = tl.where(rem == 0, P, rem)
+ else:
+ tokens_this_page = P
+
+ valid_l = offs_l < tokens_this_page
+ k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
+
+ # ---- QK^T: accumulate full dot product across D-chunks ----
+ # scores_sum accumulates (group_size, BLOCK_L) across BLOCK_D chunks
+ scores_sum = tl.zeros([group_size, BLOCK_L], dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = offs_d < D
+
+ # Load Q chunk: (group_size, BLOCK_D)
+ q_offs = (
+ pid_b * stride_q_b
+ + (q_head_base + offs_g[:, None]) * stride_q_h
+ + offs_d[None, :]
+ )
+ q_chunk = tl.load(query_ptr + q_offs, mask=mask_d[None, :], other=0.0)
+ q_f32 = q_chunk.to(tl.float32)
+
+ # Load K chunk: (BLOCK_L, BLOCK_D)
+ k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
+ mask_k = valid_l[:, None] & mask_d[None, :]
+ k_chunk = tl.load(kv_cache_ptr + k_offs, mask=mask_k, other=0.0)
+ k_f32 = k_chunk.to(tl.float32)
+
+ # Partial dot product: (group_size, BLOCK_L)
+ scores_sum += tl.dot(q_f32, tl.trans(k_f32))
+
+ scores = scores_sum * scale
+ scores = tl.where(valid_l[None, :], scores, float("-inf"))
+
+ # ---- Online softmax update ----
+ m_new = tl.maximum(m, tl.max(scores, axis=1))
+
+ rescale = tl.exp(m - m_new)
+ acc = acc * rescale[:, None]
+ l_sum = l_sum * rescale
+
+ p = tl.exp(scores - m_new[:, None]) # (group_size, BLOCK_L)
+ p = tl.where(valid_l[None, :], p, 0.0)
+ l_sum = l_sum + tl.sum(p, axis=1)
+
+ # ---- Accumulate V across D-chunks ----
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = offs_d < D
+
+ # Load V chunk: (BLOCK_L, BLOCK_D)
+ v_offs = (
+ k_base + D
+ + offs_l[:, None] * stride_kv_pos
+ + offs_d[None, :]
+ )
+ mask_v = valid_l[:, None] & mask_d[None, :]
+ v_chunk = tl.load(kv_cache_ptr + v_offs, mask=mask_v, other=0.0)
+ v_f32 = v_chunk.to(tl.float32)
+
+ # p: (group_size, BLOCK_L), v_f32: (BLOCK_L, BLOCK_D) → (group_size, BLOCK_D)
+ acc_update = tl.dot(p, v_f32)
+ # Only add to the valid D positions within this chunk
+ acc += tl.where(mask_d[None, :], acc_update, 0.0)
+
+ m = m_new
+
+ # ---- Write output across D-chunks ----
+ l_safe = tl.where(l_sum == 0.0, 1.0, l_sum)
+
+ for d_start in range(0, D, BLOCK_D):
+ offs_d = d_start + tl.arange(0, BLOCK_D)
+ mask_d = o
... (6200 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<13519 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -26,25 +26,15 @@
# Triton kernel
# ---------------------------------------------------------------------------
#
-# Key design decisions:
+# Design:
# - Grid: (batch, num_kv_heads) — one program per GQA group per batch element.
-# - BLOCK_L = PAGE_SIZE = 16 (compile-time constant, matches the page size).
-# - BLOCK_D is autotuned over {64, 128, 256} but the kernel handles the
-# case BLOCK_D < D by accumulating partial dot products across D-chunks.
-# - All Q heads in a GQA group are loaded together and participate in
-# vectorised tl.dot operations.
-# - Online softmax in fp32 prevents materialising the L×H attention matrix.
-
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_D": 64}, num_warps=4),
- triton.Config({"BLOCK_D": 64}, num_warps=8),
- triton.Config({"BLOCK_D": 128}, num_warps=4),
- triton.Config({"BLOCK_D": 128}, num_warps=8),
- ],
- key=["D", "group_size"],
-)
+# - Each program loads all Q heads in its GQA group at once and iterates
+# over KV-cache pages one page (16 positions) at a time.
+# - BLOCK_L = PAGE_SIZE = 16 (compile-time constant).
+# - BLOCK_D = D (pinned at launch time so the full head_dim is processed).
+# - Online softmax in fp32 avoids materialising the full scores matrix.
+
+
@triton.jit
def _paged_attention_decode_kernel(
query_ptr,
@@ -68,167 +58,6 @@
group_size: tl.constexpr,
):
"""
- Paged attention decode kernel.
-
- One program per (batch_element, kv_head). Iterates over KV-cache pages
- in order, loading one full page (16 positions) at a time.
- """
- pid_b = tl.program_id(0)
- pid_kv = tl.program_id(1)
-
- seq_len = tl.load(seq_lens_ptr + pid_b)
- num_pages = (seq_len + P - 1) // P
-
- # The page size is always 16 for all shapes — bake it as the tile length.
- BLOCK_L: tl.constexpr = 16
-
- offs_l = tl.arange(0, BLOCK_L) # [0 .. 15]
- offs_g = tl.arange(0, group_size) # [0 .. group_size-1]
-
- q_head_base = pid_kv * group_size
-
- # --- Online-softmax state (fp32) ---
- m = tl.full([group_size], float("-inf"), dtype=tl.float32)
- l_sum = tl.zeros([group_size], dtype=tl.float32)
- acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
-
- scale = 1.0 / tl.sqrt(D.to(tl.float32))
-
- # --- Page loop ---
- for page_idx in range(num_pages):
- blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
-
- if page_idx == num_pages - 1:
- rem = seq_len % P
- tokens_this_page = tl.where(rem == 0, P, rem)
- else:
- tokens_this_page = P
-
- valid_l = offs_l < tokens_this_page
- k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
-
- # ---- QK^T: accumulate full dot product across D-chunks ----
- # scores_sum accumulates (group_size, BLOCK_L) across BLOCK_D chunks
- scores_sum = tl.zeros([group_size, BLOCK_L], dtype=tl.float32)
-
- for d_start in range(0, D, BLOCK_D):
- offs_d = d_start + tl.arange(0, BLOCK_D)
- mask_d = offs_d < D
-
- # Load Q chunk: (group_size, BLOCK_D)
- q_offs = (
- pid_b * stride_q_b
- + (q_head_base + offs_g[:, None]) * stride_q_h
- + offs_d[None, :]
- )
- q_chunk = tl.load(query_ptr + q_offs, mask=mask_d[None, :], other=0.0)
- q_f32 = q_chunk.to(tl.float32)
-
- # Load K chunk: (BLOCK_L, BLOCK_D)
- k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
- mask_k = valid_l[:, None] & mask_d[None, :]
- k_chunk = tl.load(kv_cache_ptr + k_offs, mask=mask_k, other=0.0)
- k_f32 = k_chunk.to(tl.float32)
-
- # Partial dot product: (group_size, BLOCK_L)
- scores_sum += tl.dot(q_f32, tl.trans(k_f32))
-
- scores = scores_sum * scale
- scores = tl.where(valid_l[None, :], scores, float("-inf"))
-
- # ---- Online softmax update ----
- m_new = tl.maximum(m, tl.max(scores, axis=1))
-
- rescale = tl.exp(m - m_new)
- acc = acc * rescale[:, None]
- l_sum = l_sum * rescale
-
- p = tl.exp(scores - m_new[:, None]) # (group_size, BLOCK_L)
- p = tl.where(valid_l[None, :], p, 0.0)
- l_sum = l_sum + tl.sum(p, axis=1)
-
- # ---- Accumulate V across D-chunks ----
- for d_start in range(0, D, BLOCK_D):
- offs_d = d_start + tl.arange(0, BLOCK_D)
- mask_d = offs_d < D
-
- # Load V chunk: (BLOCK_L, BLOCK_D)
- v_offs = (
- k_base + D
- + offs_l[:, None] * stride_kv_pos
- + offs_d[None, :]
- )
- mask_v = valid_l[:, None] & mask_d[None, :]
- v_chunk = tl.load(kv_cache_ptr + v_offs, mask=mask_v, other=0.0)
- v_f32 = v_chunk.to(tl.float32)
-
- # p: (group_size, BLOCK_L), v_f32: (BLOCK_L, BLOCK_D) → (group_size, BLOCK_D)
- acc_update = tl.dot(p, v_f32)
- # Only add to the valid D positions within this chunk
- acc += tl.where(mask_d[None, :], acc_update, 0.0)
-
- m = m_new
-
- # ---- Write output across D-chunks ----
- l_safe = tl.where(l_sum == 0.0, 1.0, l_sum)
-
- for d_start in range(0, D, BLOCK_D):
- offs_d = d_start + tl.arange(0, BLOCK_D)
- mask_d = offs_d < D
-
- out_offs = (
- pid_b * stride_q_b
- + (q_head_base + offs_g[:, None]) * stride_q_h
- + offs_d[None, :]
- )
-
- # acc is (group_size, BLOCK_D); each chunk uses the same slice.
- # Because BLOCK_D is the same for all chunks, acc[] covers the whole D.
- # But wait — acc only has BLOCK_D columns. If D > BLOCK_D, we need
- # multiple chunk
... (1625 more chars){
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<6824 chars \u2014 see diff>",
"new_string": "<657 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -160,9 +160,7 @@
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
grid = (B, Hkv)
- # Use the exact-match kernel that pins BLOCK_D = D.
- # This is always correct and avoids the D-chunk accumulator problem.
- _paged_attention_decode_kernel_exact[grid](
+ _paged_attention_decode_kernel[grid](
query,
kv_cache,
block_table,
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<176 chars \u2014 see diff>",
"new_string": " _paged_attention_decode_kernel[grid]("
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"file_path": "/workspace/problems/03_paged_attention/solution.py"
}1 """Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
2
3 Triton-based decode kernel with online softmax and paged KV-cache gathering.
4 Uses one block per (batch, kv_head) pair, processes all GQA query heads together.
5 """
6
7 import math
8
9 import torch
10 import torch.nn as nn
11 import triton
12 import triton.language as tl
13
14 OP_TYPE = "attention"
15 SUPPORTED_PRECISIONS = ["bf16"]
16 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
17
18 BATCH = 8
19 NUM_HEADS = 32
20 NUM_KV_HEADS = 8
21 HEAD_DIM = 128
22 SEQ_LEN = 1024
23 PAGE_SIZE = 16
24
25 # ---------------------------------------------------------------------------
26 # Triton kernel
27 # ---------------------------------------------------------------------------
28 #
29 # Design:
30 # - Grid: (batch, num_kv_heads) — one program per GQA group per batch element.
31 # - Each program loads all Q heads in its GQA group at once and iterates
32 # over KV-cache pages one page (16 positions) at a time.
33 # - BLOCK_L = PAGE_SIZE = 16 (compile-time constant).
34 # - BLOCK_D = D (pinned at launch time so the full head_dim is processed).
35 # - Online softmax in fp32 avoids materialising the full scores matrix.
36
37
38 @triton.jit
39 def _paged_attention_decode_kernel(
40 query_ptr,
41 kv_cache_ptr,
42 block_table_ptr,
43 seq_lens_ptr,
44 output_ptr,
45 B,
46 H,
47 Hkv,
48 D,
49 max_blocks,
50 P,
51 stride_q_b,
52 stride_q_h,
... (6653 more chars){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}PASS{
"command": "uv run python benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 600000
}benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:15:11.178986+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:15:11.384002+00:00 elapsed_s=0.205 ms=0.088368
shape=0 variant=solution tflops=1.519 gbps=381.196 ms=0.088
shape=0 solution_peak_fraction=0.2118
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:15:13.015098+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:15:13.026372+00:00 elapsed_s=0.011 ms=0.205072
shape=1 variant=solution tflops=5.236 gbps=1311.538 ms=0.205
shape=1 solution_peak_fraction=0.7286
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:15:13.436805+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:15:13.450724+00:00 elapsed_s=0.014 ms=0.273808
shape=2 variant=solution tflops=1.961 gbps=245.573 ms=0.274
shape=2 solution_peak_fraction=0.1364
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:15:14.068889+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:15:14.075657+00:00 elapsed_s=0.007 ms=0.128096
shape=3 variant=solution tflops=3.141 gbps=787.377 ms=0.128
shape=3 solution_peak_fraction=0.4374
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:15:14.177130+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:15:14.185203+00:00 elapsed_s=0.008 ms=0.119024
shape=4
... (123 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,7 +1,7 @@
"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
Triton-based decode kernel with online softmax and paged KV-cache gathering.
-Uses one block per (batch, kv_head) pair, processes all GQA query heads together.
+One thread block per query head for maximum parallelism.
"""
import math
@@ -23,16 +23,8 @@
PAGE_SIZE = 16
# ---------------------------------------------------------------------------
-# Triton kernel
-# ---------------------------------------------------------------------------
-#
-# Design:
-# - Grid: (batch, num_kv_heads) — one program per GQA group per batch element.
-# - Each program loads all Q heads in its GQA group at once and iterates
-# over KV-cache pages one page (16 positions) at a time.
-# - BLOCK_L = PAGE_SIZE = 16 (compile-time constant).
-# - BLOCK_D = D (pinned at launch time so the full head_dim is processed).
-# - Online softmax in fp32 avoids materialising the full scores matrix.
+# Triton kernel — one block per query head
+# ---------------------------------------------------------------------------
@triton.jit
@@ -58,13 +50,16 @@
group_size: tl.constexpr,
):
"""
- Exact-match variant: BLOCK_D == D.
-
- One program per (batch_element, kv_head). KV-cache pages are loaded
- one at a time (16 positions per page, entire D dimension at once).
+ Paged attention decode — one program per (batch, query_head).
+
+ BLOCK_D == D (pinned at launch). Each program loads one page at a time
+ (BLOCK_L = PAGE_SIZE = 16 positions), computes dot products with its
+ single Q head, and maintains online-softmax state in fp32.
"""
- pid_b = tl.program_id(0)
- pid_kv = tl.program_id(1)
+ pid_q = tl.program_id(0) # query head index
+ pid_b = tl.program_id(1) # batch index
+
+ kv_head = pid_q // group_size
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
@@ -72,25 +67,21 @@
BLOCK_L: tl.constexpr = 16
offs_l = tl.arange(0, BLOCK_L)
offs_d = tl.arange(0, BLOCK_D)
- offs_g = tl.arange(0, group_size)
-
- q_head_base = pid_kv * group_size
-
- # Load all Q heads at once: (group_size, D)
- q_offs = (
- pid_b * stride_q_b
- + (q_head_base + offs_g[:, None]) * stride_q_h
- + offs_d[None, :]
- )
- mask_q = offs_d[None, :] < D
+
+ # --- Load single Q head ---
+ q_offs = pid_b * stride_q_b + pid_q * stride_q_h + offs_d
+ mask_q = offs_d < D
q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
-
- m = tl.full([group_size], float("-inf"), dtype=tl.float32)
- l_sum = tl.zeros([group_size], dtype=tl.float32)
- acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
+ # q: (BLOCK_D,) with zeros beyond D
+
+ # --- Online-softmax state ---
+ m = float("-inf")
+ l_sum = 0.0
+ acc = tl.zeros([BLOCK_D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(D.to(tl.float32))
+ # --- Page loop ---
for page_idx in range(num_pages):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
@@ -101,42 +92,48 @@
tokens_this_page = P
valid_l = offs_l < tokens_this_page
- k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
-
- # K tile: (BLOCK_L, D)
+ k_base = blk_idx.to(tl.int64) * stride_kv_blk + kv_head * stride_kv_h
+
+ # --- Load K tile: (BLOCK_L, BLOCK_D) ---
k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
mask_kv = valid_l[:, None] & (offs_d[None, :] < D)
k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_kv, other=0.0).to(tl.float32)
- # Score = Q @ K^T : (group_size, BLOCK_L)
- scores = tl.dot(q, tl.trans(k_tile)) * scale
- scores = tl.where(valid_l[None, :], scores, float("-inf"))
-
- # Online softmax
- m_new = tl.maximum(m, tl.max(scores, axis=1))
+ # --- Scores: (BLOCK_L,) = q @ K^T ---
+ # tl.dot((1, BLOCK_D), (BLOCK_D, BLOCK_L)) → (1, BLOCK_L)
+ scores = tl.dot(
+ tl.reshape(q, (1, BLOCK_D)),
+ tl.trans(k_tile),
+ )
+ scores = tl.reshape(scores, (BLOCK_L,)) * scale
+ scores = tl.where(valid_l, scores, float("-inf"))
+
+ # --- Online softmax update ---
+ m_new = tl.maximum(m, tl.max(scores, axis=0))
+
rescale = tl.exp(m - m_new)
- acc = acc * rescale[:, None]
+ acc = acc * rescale
l_sum = l_sum * rescale
- p = tl.exp(scores - m_new[:, None])
- p = tl.where(valid_l[None, :], p, 0.0)
- l_sum = l_sum + tl.sum(p, axis=1)
-
- # V tile: (BLOCK_L, D)
+ p = tl.exp(scores - m_new)
+ p = tl.where(valid_l, p, 0.0)
+ l_sum = l_sum + tl.sum(p, axis=0)
+
+ # --- Load V tile: (BLOCK_L, BLOCK_D) & accumulate ---
v_offs = k_base + D + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_kv, other=0.0).to(tl.float32)
- acc = acc + tl.dot(p, v_tile)
+
+ # p: (BLOCK_L,), v_tile: (BLOCK_L, BLOCK_D)
+ # p @ V → (BLOCK_D,)
+ acc += tl.sum(p[:, None] * v_tile, axis=0)
m = m_new
+ # --- Write output ---
l_safe = tl.where(l_sum == 0.0, 1.0, l_sum)
- out_vals = acc / l_safe[:, None]
-
- out_offs = (
- pid_b * stride_q_b
- + (q_head_base + offs_g[:, None]) * stride_q_h
- + offs_d[None, :]
- )
+ out_vals = acc / l_safe
+
+ out_offs = pid_b * stride_q_b + pid_q * stride_q_h + offs_d
tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_q)
@@ -158,7 +155,9 @@
group_size = H // Hkv
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
- grid = (B, Hkv)
+
+ # Grid: (num_heads, batch) — one block per query head
+
... (76 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<7097 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check on per-Q-head kernel",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark on per-Q-head kernel",
"timeout": 600000
}PASSbenchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:16:04.339973+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:16:04.547679+00:00 elapsed_s=0.208 ms=0.143360
shape=0 variant=solution tflops=0.936 gbps=234.971 ms=0.143
shape=0 solution_peak_fraction=0.1305
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:16:06.182663+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:16:06.206694+00:00 elapsed_s=0.024 ms=0.531488
shape=1 variant=solution tflops=2.020 gbps=506.050 ms=0.531
shape=1 solution_peak_fraction=0.2811
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:16:06.615751+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:16:06.637833+00:00 elapsed_s=0.022 ms=0.480928
shape=2 variant=solution tflops=1.116 gbps=139.813 ms=0.481
shape=2 solution_peak_fraction=0.0777
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:16:07.253539+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:16:07.264874+00:00 elapsed_s=0.011 ms=0.242912
shape=3 variant=solution tflops=1.657 gbps=415.212 ms=0.243
shape=3 solution_peak_fraction=0.2307
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:16:07.364137+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:16:07.374574+00:00 elapsed_s=0.010 ms=0.193600
shape=4
... (121 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,7 +1,8 @@
"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
Triton-based decode kernel with online softmax and paged KV-cache gathering.
-One thread block per query head for maximum parallelism.
+Splits work across (batch, kv_head, seq_chunk) for maximum SM occupancy,
+then reduces partial results with a second kernel.
"""
import math
@@ -23,66 +24,81 @@
PAGE_SIZE = 16
# ---------------------------------------------------------------------------
-# Triton kernel — one block per query head
+# Pass 1: partial attention over a chunk of pages
# ---------------------------------------------------------------------------
@triton.jit
-def _paged_attention_decode_kernel(
+def _partial_decode_kernel(
query_ptr,
kv_cache_ptr,
block_table_ptr,
seq_lens_ptr,
- output_ptr,
+ partial_out_ptr, # (num_chunks, B, H, D + 2) — last 2 elems = (m, l)
B,
H,
Hkv,
D,
max_blocks,
P,
+ num_pages_total,
+ pages_per_chunk,
stride_q_b,
stride_q_h,
stride_kv_blk,
stride_kv_pos,
stride_kv_h,
stride_bt_b,
+ stride_po_c, # partial_out stride: chunk dim
+ stride_po_b, # partial_out stride: batch dim
+ stride_po_h, # partial_out stride: head dim
BLOCK_D: tl.constexpr,
group_size: tl.constexpr,
):
"""
- Paged attention decode — one program per (batch, query_head).
-
- BLOCK_D == D (pinned at launch). Each program loads one page at a time
- (BLOCK_L = PAGE_SIZE = 16 positions), computes dot products with its
- single Q head, and maintains online-softmax state in fp32.
+ Compute partial attention for one chunk of pages.
+
+ Grid: (num_chunks, B, Hkv)
+ - pid_c = program_id(0): chunk index
+ - pid_b = program_id(1): batch index
+ - pid_kv = program_id(2): KV head index (handles group_size Q heads)
"""
- pid_q = tl.program_id(0) # query head index
- pid_b = tl.program_id(1) # batch index
-
- kv_head = pid_q // group_size
+ pid_c = tl.program_id(0)
+ pid_b = tl.program_id(1)
+ pid_kv = tl.program_id(2)
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
+
+ # Page range for this chunk
+ page_start = pid_c * pages_per_chunk
+ page_end = tl.minimum(page_start + pages_per_chunk, num_pages)
+ if page_start >= num_pages:
+ return
BLOCK_L: tl.constexpr = 16
offs_l = tl.arange(0, BLOCK_L)
offs_d = tl.arange(0, BLOCK_D)
-
- # --- Load single Q head ---
- q_offs = pid_b * stride_q_b + pid_q * stride_q_h + offs_d
- mask_q = offs_d < D
+ offs_g = tl.arange(0, group_size)
+
+ q_head_base = pid_kv * group_size
+
+ # Load all Q heads in this group
+ q_offs = (
+ pid_b * stride_q_b
+ + (q_head_base + offs_g[:, None]) * stride_q_h
+ + offs_d[None, :]
+ )
+ mask_q = offs_d[None, :] < D
q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
- # q: (BLOCK_D,) with zeros beyond D
-
- # --- Online-softmax state ---
- m = float("-inf")
- l_sum = 0.0
- acc = tl.zeros([BLOCK_D], dtype=tl.float32)
+
+ m = tl.full([group_size], float("-inf"), dtype=tl.float32)
+ l_sum = tl.zeros([group_size], dtype=tl.float32)
+ acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(D.to(tl.float32))
- # --- Page loop ---
- for page_idx in range(num_pages):
+ for page_idx in range(page_start, page_end):
blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
if page_idx == num_pages - 1:
@@ -92,49 +108,152 @@
tokens_this_page = P
valid_l = offs_l < tokens_this_page
- k_base = blk_idx.to(tl.int64) * stride_kv_blk + kv_head * stride_kv_h
-
- # --- Load K tile: (BLOCK_L, BLOCK_D) ---
+ k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
+
+ # K tile
k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
mask_kv = valid_l[:, None] & (offs_d[None, :] < D)
k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_kv, other=0.0).to(tl.float32)
- # --- Scores: (BLOCK_L,) = q @ K^T ---
- # tl.dot((1, BLOCK_D), (BLOCK_D, BLOCK_L)) → (1, BLOCK_L)
- scores = tl.dot(
- tl.reshape(q, (1, BLOCK_D)),
- tl.trans(k_tile),
- )
- scores = tl.reshape(scores, (BLOCK_L,)) * scale
- scores = tl.where(valid_l, scores, float("-inf"))
-
- # --- Online softmax update ---
- m_new = tl.maximum(m, tl.max(scores, axis=0))
-
+ scores = tl.dot(q, tl.trans(k_tile)) * scale
+ scores = tl.where(valid_l[None, :], scores, float("-inf"))
+
+ m_new = tl.maximum(m, tl.max(scores, axis=1))
rescale = tl.exp(m - m_new)
- acc = acc * rescale
+ acc = acc * rescale[:, None]
l_sum = l_sum * rescale
- p = tl.exp(scores - m_new)
- p = tl.where(valid_l, p, 0.0)
- l_sum = l_sum + tl.sum(p, axis=0)
-
- # --- Load V tile: (BLOCK_L, BLOCK_D) & accumulate ---
+ p = tl.exp(scores - m_new[:, None])
+ p = tl.where(valid_l[None, :], p, 0.0)
+ l_sum = l_sum + tl.sum(p, axis=1)
+
+ # V tile
v_offs = k_base + D + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_kv, other=0.0).to(tl.float32)
-
- # p: (BLOCK_L,), v_tile: (BLOCK_L, BLOCK_D)
- # p @ V → (BLOCK_D,)
- acc += tl.sum(p[:, None] * v_tile, axis=0)
+ acc = acc + tl.dot(p, v_tile)
m = m_new
- # --- Write output ---
- l_safe = tl.where(l_sum == 0.0, 1.0, l_sum)
- out_vals = acc / l_safe
+ # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
+ for g in range(group_si
... (5474 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<11398 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"offset": 205,
"limit": 50
}205 o_global = tl.zeros([BLOCK_D], dtype=tl.float32)
206
207 mask_d = offs_d < D
208
209 for c in range(num_chunks):
210 # Load m_c
211 m_c = tl.load(
212 partial_in_ptr
213 + c * stride_pi_c
214 + pid_b * stride_pi_b
215 + pid_q * stride_pi_h
216 + D
217 )
218 # Load l_c
219 l_c = tl.load(
220 partial_in_ptr
221 + c * stride_pi_c
222 + pid_b * stride_pi_b
223 + pid_q * stride_pi_h
224 + D + 1
225 )
226
227 # Skip empty chunks (l_c == 0 and m_c == -inf)
228 chunk_valid = l_c > 0.0
229
230 m_new = tl.maximum(m_global, m_c)
231 rescale_old = tl.exp(m_global - m_new)
232 rescale_c = tl.exp(m_c - m_new)
233
234 o_global = tl.where(chunk_valid, o_global * rescale_old, o_global)
235 l_global = tl.where(chunk_valid, l_global * rescale_old, l_global)
236
237 # Load acc_c for this chunk
238 acc_c = tl.load(
239 partial_in_ptr
240 + c * stride_pi_c
241 + pid_b * stride_pi_b
242 + pid_q * stride_pi_h
243 + offs_d,
244 mask=mask_d,
245 other=0.0,
246 ).to(tl.float32)
247
248 o_global = tl.where(chunk_valid, o_global + rescale_c * l_c * acc_c, o_global)
249 l_global = tl.where(chunk_valid, l_global + rescale_c * l_c, l_global)
2
... (165 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -155,7 +155,7 @@
+ D + 1,
l_sum[g],
)
- # Store acc
+ # Store acc (fp32, no conversion needed — partial_out is fp32)
val_offs = offs_d
mask_d = offs_d < D
tl.store(
@@ -164,7 +164,7 @@
+ pid_b * stride_po_b
+ q_head * stride_po_h
+ val_offs,
- acc[g, :].to(tl.bfloat16),
+ acc[g, :],
mask=mask_d,
)
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<316 chars \u2014 see diff>",
"new_string": "<355 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -284,10 +284,10 @@
num_chunks = max(1, min(max_pages, target_blocks // base_blocks))
pages_per_chunk = (max_pages + num_chunks - 1) // num_chunks
- # Intermediate storage: (num_chunks, B, H, D + 2) bf16
+ # Intermediate storage: (num_chunks, B, H, D + 2) fp32 for precision
partial = torch.zeros(
num_chunks, B, H, D + 2,
- dtype=torch.bfloat16,
+ dtype=torch.float32,
device=query.device,
)
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<183 chars \u2014 see diff>",
"new_string": "<196 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -234,7 +234,7 @@
o_global = tl.where(chunk_valid, o_global * rescale_old, o_global)
l_global = tl.where(chunk_valid, l_global * rescale_old, l_global)
- # Load acc_c for this chunk
+ # Load acc_c for this chunk (stored as fp32)
acc_c = tl.load(
partial_in_ptr
+ c * stride_pi_c
@@ -243,7 +243,7 @@
+ offs_d,
mask=mask_d,
other=0.0,
- ).to(tl.float32)
+ )
o_global = tl.where(chunk_valid, o_global + rescale_c * l_c * acc_c, o_global)
l_global = tl.where(chunk_valid, l_global + rescale_c * l_c, l_global)
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<280 chars \u2014 see diff>",
"new_string": "<282 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check on multi-chunk kernel",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark on multi-chunk kernel",
"timeout": 600000
}Exit code 1
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1094, in __getitem__
for dim, sl in enumerate(slices):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 41, in wrapper
raise ValueError("Did you forget to add @triton.jit ? "
ValueError: Did you forget to add @triton.jit ? (`_semantic` argument must be provided outside of JIT functions.)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 382, in forward
return _paged_attention_decode(
^^^^^^^^^^^^
... (2008 more chars)Exit code 1
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:17:58.126250+00:00
benchmark_event event=variant_error shape=0 variant=solution ts=2026-06-15T19:17:58.380158+00:00 elapsed_s=0.254 error=CompilationError
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1094, in __getitem__
for dim, sl in enumerate(slices):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 41, in wrapper
raise ValueError("Did you forget to add @triton.jit ? "
ValueError: Did you forget to add @triton.jit ? (`_semantic` argument must be provided outside of JIT functions.)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/benchmark.py", line 154, in <module>
main()
File "/workspace/problems/03_paged_attention/benchmark.py", line 85, in main
ms_sol = time_variant(
^^^^^^^^^^^^^
File "/workspace/src/eval/timing.py", line 80, in time_variant
ms = time_fn(fn, inputs, iters=iters, warmup=warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/src/eval/timing.py", line 103, in time_fn
fn(*inputs)
File "/workspace/.venv/lib/python3.11/site-packa
... (2493 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"offset": 86,
"limit": 60
}86 # Load all Q heads in this group
87 q_offs = (
88 pid_b * stride_q_b
89 + (q_head_base + offs_g[:, None]) * stride_q_h
90 + offs_d[None, :]
91 )
92 mask_q = offs_d[None, :] < D
93 q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32)
94
95 m = tl.full([group_size], float("-inf"), dtype=tl.float32)
96 l_sum = tl.zeros([group_size], dtype=tl.float32)
97 acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32)
98
99 scale = 1.0 / tl.sqrt(D.to(tl.float32))
100
101 for page_idx in range(page_start, page_end):
102 blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx)
103
104 if page_idx == num_pages - 1:
105 rem = seq_len % P
106 tokens_this_page = tl.where(rem == 0, P, rem)
107 else:
108 tokens_this_page = P
109
110 valid_l = offs_l < tokens_this_page
111 k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h
112
113 # K tile
114 k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :]
115 mask_kv = valid_l[:, None] & (offs_d[None, :] < D)
116 k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_kv, other=0.0).to(tl.float32)
117
118 scores = tl.dot(q, tl.trans(k_tile)) * scale
119 scores = tl.where(valid_l[None, :], scores, float("-inf"))
120
121 m_new = tl.maximum(m, tl.max(scores, axis=1))
122 rescale = tl.exp(m - m_
... (853 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,33 +1,19 @@
- # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
- for g in range(group_size):
- q_head = q_head_base + g
- # Store m
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + D,
- m[g],
- )
- # Store l_sum
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + D + 1,
- l_sum[g],
- )
- # Store acc (fp32, no conversion needed — partial_out is fp32)
- val_offs = offs_d
- mask_d = offs_d < D
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + val_offs,
- acc[g, :],
- mask=mask_d,
- )+ # Write partial results vectorised across all Q heads in the group.
+ base_off = pid_c * stride_po_c + pid_b * stride_po_b
+
+ # m: (group_size,) → store at [..., q_head, D]
+ m_offs = base_off + (q_head_base + offs_g) * stride_po_h + D
+ tl.store(partial_out_ptr + m_offs, m)
+
+ # l_sum: (group_size,) → store at [..., q_head, D+1]
+ l_offs = base_off + (q_head_base + offs_g) * stride_po_h + D + 1
+ tl.store(partial_out_ptr + l_offs, l_sum)
+
+ # acc: (group_size, BLOCK_D) → store at [..., q_head, 0:D]
+ acc_offs = (
+ base_off
+ + (q_head_base + offs_g[:, None]) * stride_po_h
+ + offs_d[None, :]
+ )
+ mask_acc = offs_d[None, :] < D
+ tl.store(partial_out_ptr + acc_offs, acc, mask=mask_acc){
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<934 chars \u2014 see diff>",
"new_string": "<742 chars \u2014 see diff>"
}<tool_use_error>String to replace not found in file.
String: # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
for g in range(group_size):
q_head = q_head_base + g
# Store m
tl.store(
partial_out_ptr
+ pid_c * stride_po_c
+ pid_b * stride_po_b
+ q_head * stride_po_h
+ D,
m[g],
)
# Store l_sum
tl.store(
partial_out_ptr
+ pid_c * stride_po_c
+ pid_b * stride_po_b
+ q_head * stride_po_h
+ D + 1,
l_sum[g],
)
# Store acc (fp32, no conversion needed — partial_out is fp32)
val_offs = offs_d
mask_d = offs_d < D
tl.store(
partial_out_ptr
+ pid_c * stride_po_c
+ pid_b * stride_po_b
+ q_head * stride_po_h
+ val_offs,
acc[g, :],
mask=mask_d,
)
(note: Edit also tried swapping \uXXXX escapes and their characters; neither form matched, so the mismatch is likely elsewhere in old_string. Re-read the file and copy the exact surrounding text.)</tool_use_error>{
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"limit": 20,
"offset": 135
}135 m = m_new
136
137 # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
138 for g in range(group_size):
139 q_head = q_head_base + g
140 # Store m
141 tl.store(
142 partial_out_ptr
143 + pid_c * stride_po_c
144 + pid_b * stride_po_b
145 + q_head * stride_po_h
146 + D,
147 m[g],
148 )
149 # Store l_sum
150 tl.store(
151 partial_out_ptr
152 + pid_c * stride_po_c
153 + pid_b * stride_po_b
154 + q_head * stride_po_h{
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"limit": 30,
"offset": 137
}137 # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
138 for g in range(group_size):
139 q_head = q_head_base + g
140 # Store m
141 tl.store(
142 partial_out_ptr
143 + pid_c * stride_po_c
144 + pid_b * stride_po_b
145 + q_head * stride_po_h
146 + D,
147 m[g],
148 )
149 # Store l_sum
150 tl.store(
151 partial_out_ptr
152 + pid_c * stride_po_c
153 + pid_b * stride_po_b
154 + q_head * stride_po_h
155 + D + 1,
156 l_sum[g],
157 )
158 # Store acc (fp32, no conversion needed — partial_out is fp32)
159 val_offs = offs_d
160 mask_d = offs_d < D
161 tl.store(
162 partial_out_ptr
163 + pid_c * stride_po_c
164 + pid_b * stride_po_b
165 + q_head * stride_po_h
166 + val_offs,{
"command": "sed -n '137,170p' /workspace/problems/03_paged_attention/solution.py | cat -A",
"description": "Show exact whitespace in problem area"
}# Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group$
for g in range(group_size):$
q_head = q_head_base + g$
# Store m$
tl.store($
partial_out_ptr$
+ pid_c * stride_po_c$
+ pid_b * stride_po_b$
+ q_head * stride_po_h$
+ D,$
m[g],$
)$
# Store l_sum$
tl.store($
partial_out_ptr$
+ pid_c * stride_po_c$
+ pid_b * stride_po_b$
+ q_head * stride_po_h$
+ D + 1,$
l_sum[g],$
)$
# Store acc (fp32, no conversion needed M-bM-^@M-^T partial_out is fp32)$
val_offs = offs_d$
mask_d = offs_d < D$
tl.store($
partial_out_ptr$
+ pid_c * stride_po_c$
+ pid_b * stride_po_b$
+ q_head * stride_po_h$
+ val_offs,$
acc[g, :],$
mask=mask_d,$
)$
$--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,7 +1,7 @@
"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
Triton-based decode kernel with online softmax and paged KV-cache gathering.
-Splits work across (batch, kv_head, seq_chunk) for maximum SM occupancy,
+Splits work across (batch, kv_head, seq_chunk) for SM occupancy,
then reduces partial results with a second kernel.
"""
@@ -34,7 +34,7 @@
kv_cache_ptr,
block_table_ptr,
seq_lens_ptr,
- partial_out_ptr, # (num_chunks, B, H, D + 2) — last 2 elems = (m, l)
+ partial_out_ptr,
B,
H,
Hkv,
@@ -49,19 +49,15 @@
stride_kv_pos,
stride_kv_h,
stride_bt_b,
- stride_po_c, # partial_out stride: chunk dim
- stride_po_b, # partial_out stride: batch dim
- stride_po_h, # partial_out stride: head dim
+ stride_po_c,
+ stride_po_b,
+ stride_po_h,
BLOCK_D: tl.constexpr,
group_size: tl.constexpr,
):
- """
- Compute partial attention for one chunk of pages.
+ """Compute partial attention for one chunk of pages.
Grid: (num_chunks, B, Hkv)
- - pid_c = program_id(0): chunk index
- - pid_b = program_id(1): batch index
- - pid_kv = program_id(2): KV head index (handles group_size Q heads)
"""
pid_c = tl.program_id(0)
pid_b = tl.program_id(1)
@@ -70,7 +66,6 @@
seq_len = tl.load(seq_lens_ptr + pid_b)
num_pages = (seq_len + P - 1) // P
- # Page range for this chunk
page_start = pid_c * pages_per_chunk
page_end = tl.minimum(page_start + pages_per_chunk, num_pages)
if page_start >= num_pages:
@@ -134,39 +129,25 @@
m = m_new
- # Write partial results: (m[g], l_sum[g], acc[g, :D]) for each g in group
- for g in range(group_size):
- q_head = q_head_base + g
- # Store m
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + D,
- m[g],
- )
- # Store l_sum
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + D + 1,
- l_sum[g],
- )
- # Store acc (fp32, no conversion needed — partial_out is fp32)
- val_offs = offs_d
- mask_d = offs_d < D
- tl.store(
- partial_out_ptr
- + pid_c * stride_po_c
- + pid_b * stride_po_b
- + q_head * stride_po_h
- + val_offs,
- acc[g, :],
- mask=mask_d,
- )
+ # Write partial results — vectorised across all Q heads in the group.
+ base_off = pid_c * stride_po_c + pid_b * stride_po_b
+
+ # m: (group_size,) -> store at [c, b, q_head, D]
+ m_offs = base_off + (q_head_base + offs_g) * stride_po_h + D
+ tl.store(partial_out_ptr + m_offs, m)
+
+ # l_sum: (group_size,) -> store at [c, b, q_head, D+1]
+ ls_offs = base_off + (q_head_base + offs_g) * stride_po_h + D + 1
+ tl.store(partial_out_ptr + ls_offs, l_sum)
+
+ # acc: (group_size, BLOCK_D) -> store at [c, b, q_head, 0:D]
+ acc_offs = (
+ base_off
+ + (q_head_base + offs_g[:, None]) * stride_po_h
+ + offs_d[None, :]
+ )
+ mask_acc = offs_d[None, :] < D
+ tl.store(partial_out_ptr + acc_offs, acc, mask=mask_acc)
# ---------------------------------------------------------------------------
@@ -176,8 +157,8 @@
@triton.jit
def _reduce_partial_kernel(
- partial_in_ptr, # (num_chunks, B, H, D + 2)
- output_ptr, # (B, H, D)
+ partial_in_ptr,
+ output_ptr,
num_chunks,
B,
H,
@@ -189,8 +170,7 @@
stride_q_h,
BLOCK_D: tl.constexpr,
):
- """
- Reduce partial results across chunks into final output.
+ """Reduce partial results across chunks into final output.
Grid: (H, B) — one program per query head.
"""
@@ -198,16 +178,14 @@
pid_b = tl.program_id(1)
offs_d = tl.arange(0, BLOCK_D)
-
- # Global softmax reduction
+ mask_d = offs_d < D
+
m_global = float("-inf")
l_global = 0.0
o_global = tl.zeros([BLOCK_D], dtype=tl.float32)
- mask_d = offs_d < D
-
for c in range(num_chunks):
- # Load m_c
+ # Load m_c, l_c
m_c = tl.load(
partial_in_ptr
+ c * stride_pi_c
@@ -215,7 +193,6 @@
+ pid_q * stride_pi_h
+ D
)
- # Load l_c
l_c = tl.load(
partial_in_ptr
+ c * stride_pi_c
@@ -224,17 +201,19 @@
+ D + 1
)
- # Skip empty chunks (l_c == 0 and m_c == -inf)
chunk_valid = l_c > 0.0
m_new = tl.maximum(m_global, m_c)
+
+ # Only rescale if chunk is valid; otherwise keep current state.
+ # tl.where selects element-wise — both branches are evaluated.
rescale_old = tl.exp(m_global - m_new)
rescale_c = tl.exp(m_c - m_new)
o_global = tl.where(chunk_valid, o_global * rescale_old, o_global)
l_global = tl.where(chunk_valid, l_global * rescale_old, l_global)
- # Load acc_c for this chunk (stored as fp32)
+ # Load acc_c for this chunk
acc_c = tl.load(
partial_in_ptr
+ c * stride_pi_c
@@ -245,8 +224,16 @@
other=0.0,
)
- o_global = tl.where(chunk_valid, o_global + rescale_c * l_c * acc_c, o_global)
- l_global = tl.where(chunk_valid, l_global + rescale_c * l_c, l_global)
+ o_global = tl.where(
+ chunk_valid,
+ o_global + rescale_c * l_c * acc_c,
+ o_global,
+ )
+ l_global = tl.where(
+ chunk_valid,
+ l_global + rescale_c * l_c,
+ l_global,
+ )
m_global = tl.where(chunk_valid
... (1526 more chars){
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"content": "<10798 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check on rewritten kernel",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark on rewritten kernel",
"timeout": 600000
}Exit code 1
FAIL: shape 0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} seed 42 case nominal: tolerance exceeded: max_abs_diff=3.05029 max_rel_diff=267762 bad=32053/32768 worst_idx=(4, 15, 90) (atol=0.02, rtol=0.02)benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:19:16.927879+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:19:17.172936+00:00 elapsed_s=0.245 ms=0.055840
shape=0 variant=solution tflops=2.404 gbps=603.250 ms=0.056
shape=0 solution_peak_fraction=0.3351
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:19:18.790492+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:19:19.056258+00:00 elapsed_s=0.266 ms=0.223920
shape=1 variant=solution tflops=4.795 gbps=1201.142 ms=0.224
shape=1 solution_peak_fraction=0.6673
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:19:19.462923+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:19:19.604225+00:00 elapsed_s=0.141 ms=0.081264
shape=2 variant=solution tflops=6.607 gbps=827.426 ms=0.081
shape=2 solution_peak_fraction=0.4597
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:19:20.212409+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:19:20.255332+00:00 elapsed_s=0.043 ms=0.104336
shape=3 variant=solution tflops=3.857 gbps=966.684 ms=0.104
shape=3 solution_peak_fraction=0.5370
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:19:20.355522+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:19:20.513303+00:00 elapsed_s=0.158 ms=0.051088
shape=4
... (123 more chars){
"command": "<3186 chars \u2014 see below>",
"description": "Debug single-chunk case",
"timeout": 120000
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace')
# Force single-chunk test
import reference, shapes
shape = shapes.SHAPES[0]
reference.BATCH = shape['batch']
reference.NUM_HEADS = shape['num_heads']
reference.NUM_KV_HEADS = shape['num_kv_heads']
reference.HEAD_DIM = shape['head_dim']
reference.SEQ_LEN = shape['seq_len']
reference.PAGE_SIZE = shape['page_size']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
print('Ref output stats:', ref_out.float().min().item(), ref_out.float().max().item(), ref_out.float().mean().item())
# Now test solution with single chunk
from solution import _partial_decode_kernel, _reduce_partial_kernel
q, kvc, bt, sl = inputs
B, H, D = q.shape
Hkv = kvc.shape[2]
P = kvc.shape[1]
max_blocks = bt.shape[1]
group_size = H // Hkv
max_seq = int(sl.max().item())
max_pages = (max_seq + P - 1) // P
num_chunks = 1
pages_per_chunk = max_pages
partial = torch.zeros(num_chunks, B, H, D+2, dtype=torch.float32, device=q.device)
grid1 = (num_chunks, B, Hkv)
_partial_decode_kernel[grid1](
q, kvc, bt, sl, partial,
B, H, Hkv, D, max_blocks, P,
max_pages, pages_per_chunk,
q.stride(0), q.stride(1),
kvc.stride(0), kvc.stride(1), kvc.stride(2),
bt.stride(0),
partial.stride(0), partial.stride(1), partial.stride(2),
BLOCK_D=D, group_size=group_size,
)
torch.cuda.synchronize()
# Check partial values
print('Partial m min/max:', partial[0, :, :, D].min().item(), partial[0, :, :, D].max().item())
print('Partial l min/max:', partial[0, :, :, D+1].min().item(), partial[0, :, :, D+1].max().item())
print('Partial acc stats:', partial[0, :, :, :D].min().item(), partial[0, :, :, :D].max().item())
# Manually reduce (single chunk: output = acc / l)
partial_acc = partial[0, :, :, :D] # (B, H, D)
partial_l = partial[0, :, :, D+1:D+2] # (B, H, 1)
partial_m = partial[0, :, :, D:D+1]
# The true output should be acc / l (since with one chunk, no rescaling is needed if m hasn't changed)
# Actually, with one chunk, m_start = -inf, m_end = partial_m. l is sum(exp(scores - m)).
# So output = acc / l (no rescaling since there's only one chunk)
manual_out = partial_acc / (partial_l + 1e-10)
print('Manual output stats:', manual_out.float().min().item(), manual_out.float().max().item())
diff = (ref_out.float() - manual_out.float().cuda()).abs()
print('Max abs diff (single chunk):', diff.max().item())
print('Bad elements:', (diff > 0.02).sum().item(), '/', diff.numel())
# Now test pass 2 reduction
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=q.device)
grid2 = (H, B)
_reduce_partial_kernel[grid2](
partial, out, num_chunks, B, H, D,
partial.stride(0), partial.stride(1), partial.stride(2),
out.stride(0), out.stride(1),
BLOCK_D=D,
)
torch.cuda.synchronize()
diff2 = (ref_out.float() - out.float()).abs()
print('Max abs diff (after reduction):', diff2.max().item())
print('Bad elements (after reduction):', (diff2 > 0.02).sum().item(), '/', diff2.numel())
" 2>&1Ref output stats: -0.01220703125 0.01141357421875 -6.515139830298722e-06
Partial m min/max: 0.02209429442882538 0.04703931882977486
Partial l min/max: 977.451416015625 1001.5010375976562
Partial acc stats: -12.201464653015137 11.271280288696289
Manual output stats: -0.012219518423080444 0.011384991928935051
Max abs diff (single chunk): 3.290455788373947e-05
Bad elements: 0 / 32768
Max abs diff (after reduction): 12.17529296875
Bad elements (after reduction): 32599 / 32768--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -224,9 +224,10 @@
other=0.0,
)
+ # acc_c is already the exp-weighted sum: Σ exp(s-m_c)·V
o_global = tl.where(
chunk_valid,
- o_global + rescale_c * l_c * acc_c,
+ o_global + rescale_c * acc_c,
o_global,
)
l_global = tl.where(
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<133 chars \u2014 see diff>",
"new_string": "<191 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check after reduction fix",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark after reduction fix",
"timeout": 600000
}PASSbenchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:20:28.745125+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:20:28.992199+00:00 elapsed_s=0.247 ms=0.054672
shape=0 variant=solution tflops=2.455 gbps=616.138 ms=0.055
shape=0 solution_peak_fraction=0.3423
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:20:30.615325+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:20:30.628099+00:00 elapsed_s=0.013 ms=0.224928
shape=1 variant=solution tflops=4.774 gbps=1195.759 ms=0.225
shape=1 solution_peak_fraction=0.6643
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:20:31.036227+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:20:31.043063+00:00 elapsed_s=0.007 ms=0.080800
shape=2 variant=solution tflops=6.644 gbps=832.177 ms=0.081
shape=2 solution_peak_fraction=0.4623
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:20:31.656156+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:20:31.663459+00:00 elapsed_s=0.007 ms=0.103456
shape=3 variant=solution tflops=3.889 gbps=974.906 ms=0.103
shape=3 solution_peak_fraction=0.5416
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:20:31.762412+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:20:31.768740+00:00 elapsed_s=0.006 ms=0.050768
shape=4
... (123 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -277,52 +277,99 @@
device=query.device,
)
- # Pass 1
- grid1 = (num_chunks, B, Hkv)
- _partial_decode_kernel[grid1](
- query,
- kv_cache,
- block_table,
- seq_lens,
- partial,
- B,
- H,
- Hkv,
- D,
- max_blocks,
- P,
- max_pages,
- pages_per_chunk,
- query.stride(0),
- query.stride(1),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- block_table.stride(0),
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- BLOCK_D=D,
- group_size=group_size,
- )
-
- # Pass 2
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
- grid2 = (H, B)
- _reduce_partial_kernel[grid2](
- partial,
- out,
- num_chunks,
- B,
- H,
- D,
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- out.stride(0),
- out.stride(1),
- BLOCK_D=D,
- )
+
+ if num_chunks == 1:
+ # Fast path: single chunk, write output directly from pass 1.
+ # Use the partial tensor as output (ignoring the m/l tail).
+ # We still compute via the partial kernel, then copy acc/l to output.
+ grid1 = (1, B, Hkv)
+ _partial_decode_kernel[grid1](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ partial,
+ B,
+ H,
+ Hkv,
+ D,
+ max_blocks,
+ P,
+ max_pages,
+ pages_per_chunk,
+ query.stride(0),
+ query.stride(1),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ block_table.stride(0),
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ BLOCK_D=D,
+ group_size=group_size,
+ )
+ # Single-chunk reduction pass
+ grid2 = (H, B)
+ _reduce_partial_kernel[grid2](
+ partial,
+ out,
+ num_chunks,
+ B,
+ H,
+ D,
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ out.stride(0),
+ out.stride(1),
+ BLOCK_D=D,
+ )
+ else:
+ # Multi-chunk: partial + reduce
+ grid1 = (num_chunks, B, Hkv)
+ _partial_decode_kernel[grid1](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ partial,
+ B,
+ H,
+ Hkv,
+ D,
+ max_blocks,
+ P,
+ max_pages,
+ pages_per_chunk,
+ query.stride(0),
+ query.stride(1),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ block_table.stride(0),
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ BLOCK_D=D,
+ group_size=group_size,
+ )
+ grid2 = (H, B)
+ _reduce_partial_kernel[grid2](
+ partial,
+ out,
+ num_chunks,
+ B,
+ H,
+ D,
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ out.stride(0),
+ out.stride(1),
+ BLOCK_D=D,
+ )
return out
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<985 chars \u2014 see diff>",
"new_string": "<2483 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 600000
}PASSbenchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:21:10.658885+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:21:10.902734+00:00 elapsed_s=0.244 ms=0.057072
shape=0 variant=solution tflops=2.352 gbps=590.228 ms=0.057
shape=0 solution_peak_fraction=0.3279
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:21:12.523986+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:21:12.536869+00:00 elapsed_s=0.013 ms=0.225680
shape=1 variant=solution tflops=4.758 gbps=1191.775 ms=0.226
shape=1 solution_peak_fraction=0.6621
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:21:12.944767+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:21:12.951641+00:00 elapsed_s=0.007 ms=0.082016
shape=2 variant=solution tflops=6.546 gbps=819.839 ms=0.082
shape=2 solution_peak_fraction=0.4555
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:21:13.558847+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:21:13.566283+00:00 elapsed_s=0.007 ms=0.105136
shape=3 variant=solution tflops=3.827 gbps=959.328 ms=0.105
shape=3 solution_peak_fraction=0.5330
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:21:13.664838+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:21:13.671103+00:00 elapsed_s=0.006 ms=0.052144
shape=4
... (123 more chars){
"command": "<686 chars \u2014 see below>",
"description": "Check framework detection"
}uv run python -c "
import re
with open('solution.py') as f:
code = f.read()
# Check which patterns match
patterns = [
('ptx', r\"asm\s+volatile|asm\s*\(|mma\.sync|tcgen05\.\"),
('cutlass3', r\"\bcute::|cutlass/gemm/collective|cutlass::arch::Sm(9|10|12)\"),
('cutlass2', r\"cutlass/gemm/device/gemm|cutlass::gemm::device\"),
('cuda_wmma', r\"\bnvcuda::wmma\b|wmma::fragment\"),
('triton', r\"import\s+triton\b|@triton\.jit|\btl\.dot\b\"),
('cuda_raw', r\"torch\.utils\.cpp_extension\.load_inline|__global__\s+void\"),
]
for name, pat in patterns:
if re.search(pat, code):
print(f'{name}: MATCH')
else:
print(f'{name}: no match')
" 2>&1ptx: no match
cutlass3: no match
cutlass2: no match
cuda_wmma: no match
triton: MATCH
cuda_raw: no match{
"file_path": "/workspace/problems/03_paged_attention/solution.py"
}1 """Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell).
2
3 Triton-based decode kernel with online softmax and paged KV-cache gathering.
4 Splits work across (batch, kv_head, seq_chunk) for SM occupancy,
5 then reduces partial results with a second kernel.
6 """
7
8 import math
9
10 import torch
11 import torch.nn as nn
12 import triton
13 import triton.language as tl
14
15 OP_TYPE = "attention"
16 SUPPORTED_PRECISIONS = ["bf16"]
17 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
18
19 BATCH = 8
20 NUM_HEADS = 32
21 NUM_KV_HEADS = 8
22 HEAD_DIM = 128
23 SEQ_LEN = 1024
24 PAGE_SIZE = 16
25
26 # ---------------------------------------------------------------------------
27 # Pass 1: partial attention over a chunk of pages
28 # ---------------------------------------------------------------------------
29
30
31 @triton.jit
32 def _partial_decode_kernel(
33 query_ptr,
34 kv_cache_ptr,
35 block_table_ptr,
36 seq_lens_ptr,
37 partial_out_ptr,
38 B,
39 H,
40 Hkv,
41 D,
42 max_blocks,
43 P,
44 num_pages_total,
45 pages_per_chunk,
46 stride_q_b,
47 stride_q_h,
48 stride_kv_blk,
49 stride_kv_pos,
50 stride_kv_h,
51 stride_bt_b,
52 stride_po_c,
53 stride_po_b,
54 stride_po_h,
55 BLOCK_D: tl.constexpr,
56 group_size: tl.constexpr,
57 ):
58 """Compute partial attention for one chunk of pages.
59
60 Grid: (num_chunks, B, Hkv)
61 """
62 pid_c = tl.prog
... (12546 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -41,7 +41,6 @@
D,
max_blocks,
P,
- num_pages_total,
pages_per_chunk,
stride_q_b,
stride_q_h,
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<41 chars \u2014 see diff>",
"new_string": " pages_per_chunk,"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -278,97 +278,50 @@
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
- if num_chunks == 1:
- # Fast path: single chunk, write output directly from pass 1.
- # Use the partial tensor as output (ignoring the m/l tail).
- # We still compute via the partial kernel, then copy acc/l to output.
- grid1 = (1, B, Hkv)
- _partial_decode_kernel[grid1](
- query,
- kv_cache,
- block_table,
- seq_lens,
- partial,
- B,
- H,
- Hkv,
- D,
- max_blocks,
- P,
- max_pages,
- pages_per_chunk,
- query.stride(0),
- query.stride(1),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- block_table.stride(0),
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- BLOCK_D=D,
- group_size=group_size,
- )
- # Single-chunk reduction pass
- grid2 = (H, B)
- _reduce_partial_kernel[grid2](
- partial,
- out,
- num_chunks,
- B,
- H,
- D,
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- out.stride(0),
- out.stride(1),
- BLOCK_D=D,
- )
- else:
- # Multi-chunk: partial + reduce
- grid1 = (num_chunks, B, Hkv)
- _partial_decode_kernel[grid1](
- query,
- kv_cache,
- block_table,
- seq_lens,
- partial,
- B,
- H,
- Hkv,
- D,
- max_blocks,
- P,
- max_pages,
- pages_per_chunk,
- query.stride(0),
- query.stride(1),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- block_table.stride(0),
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- BLOCK_D=D,
- group_size=group_size,
- )
- grid2 = (H, B)
- _reduce_partial_kernel[grid2](
- partial,
- out,
- num_chunks,
- B,
- H,
- D,
- partial.stride(0),
- partial.stride(1),
- partial.stride(2),
- out.stride(0),
- out.stride(1),
- BLOCK_D=D,
- )
+ # Pass 1: partial attention over chunks
+ grid1 = (num_chunks, B, Hkv)
+ _partial_decode_kernel[grid1](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ partial,
+ B,
+ H,
+ Hkv,
+ D,
+ max_blocks,
+ P,
+ pages_per_chunk,
+ query.stride(0),
+ query.stride(1),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ block_table.stride(0),
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ BLOCK_D=D,
+ group_size=group_size,
+ )
+
+ # Pass 2: reduce partial results across chunks
+ grid2 = (H, B)
+ _reduce_partial_kernel[grid2](
+ partial,
+ out,
+ num_chunks,
+ B,
+ H,
+ D,
+ partial.stride(0),
+ partial.stride(1),
+ partial.stride(2),
+ out.stride(0),
+ out.stride(1),
+ BLOCK_D=D,
+ )
return out
{
"replace_all": false,
"file_path": "/workspace/problems/03_paged_attention/solution.py",
"old_string": "<2483 chars \u2014 see diff>",
"new_string": "<1036 chars \u2014 see diff>"
}The file /workspace/problems/03_paged_attention/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1",
"description": "Final correctness check",
"timeout": 300000
}{
"command": "uv run python benchmark.py 2>&1",
"description": "Final benchmark",
"timeout": 600000
}PASSbenchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T19:21:59.676037+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T19:21:59.920452+00:00 elapsed_s=0.244 ms=0.056304
shape=0 variant=solution tflops=2.384 gbps=598.279 ms=0.056
shape=0 solution_peak_fraction=0.3324
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T19:22:01.534986+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T19:22:01.547581+00:00 elapsed_s=0.013 ms=0.224848
shape=1 variant=solution tflops=4.775 gbps=1196.185 ms=0.225
shape=1 solution_peak_fraction=0.6645
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T19:22:01.954898+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T19:22:01.961737+00:00 elapsed_s=0.007 ms=0.081888
shape=2 variant=solution tflops=6.556 gbps=821.121 ms=0.082
shape=2 solution_peak_fraction=0.4562
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T19:22:02.576276+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T19:22:02.583711+00:00 elapsed_s=0.007 ms=0.104896
shape=3 variant=solution tflops=3.836 gbps=961.523 ms=0.105
shape=3 solution_peak_fraction=0.5342
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T19:22:02.683234+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T19:22:02.689544+00:00 elapsed_s=0.006 ms=0.051760
shape=4
... (123 more chars){
"command": "rm -f /workspace/problems/03_paged_attention/debug_kernel.py",
"description": "Clean up debug file"
}(Bash completed with no output){
"taskId": "1",
"status": "completed"
}Updated task #1 status