"""Triton-based paged attention decode kernel for packed KV cache layout.
The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2)
with K and V packed on the last dimension: [K | V].
This kernel uses a two-stage approach:
1. Stage 1 splits the KV sequence into NUM_KV_SPLITS parts; each program
   computes partial attention over its part with an online softmax.
2. Stage 2 reduces the partial results across splits.
On the packed layout, each page iteration issues two coalesced per-token loads,
one for the K half and one for the V half of the last dimension.
"""
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16
# --------------------------------------------------------------------------- #
# Stage 1: compute partial attention per KV split (grouped for GQA)
# --------------------------------------------------------------------------- #
@triton.jit
def _fwd_kernel_stage1(
Q,
KV_Cache,
block_table,
seq_lens,
sm_scale,
Mid_O,
stride_qb,
stride_qh,
stride_bt0,
stride_bt1,
stride_kv0,
stride_kv1,
stride_kv2,
stride_kv3,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
    # Which KV head this query-head group attends to; when kv_group_num > BLOCK_H,
    # cdiv(kv_group_num, BLOCK_H) consecutive groups share one KV head.
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
# Number of valid query heads in this block
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
# Load Q for this head group
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
# Determine this split's token range
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
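    # Worked example (illustrative): with seq_len=1535 and NUM_KV_SPLITS=4,
    # kv_len_per_split = cdiv(1535, 4) = 384, so the splits cover
    # [0, 384), [384, 768), [768, 1152), and [1152, 1535) (the last one clamped).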
# Online softmax accumulators
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
# Iterate over pages within this split
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(
block_table + cur_batch * stride_bt0 + page_idx * stride_bt1,
)
# Global token indices for this page
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
            # The lower bound guards the first page of a split whose start is
            # not page-aligned; without it those tokens would also be counted
            # by the previous split.
            token_mask = (global_token_idx >= split_kv_start) & (
                global_token_idx < split_kv_end
            )
# Base address for this page and KV head
page_base = (
KV_Cache
+ physical_page * stride_kv0
+ cur_kv_head * stride_kv2
)
# Load K and V separately (packed layout: [K|V] on last dim)
k = tl.load(
page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :],
other=0.0,
)
v = tl.load(
page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :],
other=0.0,
)
# Compute QK^T
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
# Mask tokens beyond split end or seq_len
qk = tl.where(
mask_h[:, None] & token_mask[None, :],
qk,
float("-inf"),
)
# Online softmax update
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
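            # Loop invariant (illustrative): after each page,
            #   acc[h]   == sum_j exp(qk[h, j] - e_max[h]) * v[j]
            #   e_sum[h] == sum_j exp(qk[h, j] - e_max[h])
            # over all tokens j processed so far, so acc / e_sum is the running
            # softmax-weighted output and e_max + log(e_sum) is the running LSE.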
# Store partial output and LSE
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:, None] * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv[None, :]
)
tl.store(
Mid_O + offs_mid_o,
acc / e_sum[:, None],
mask=mask_h[:, None] & mask_dv[None, :],
)
offs_mid_lse = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ HEAD_DIM
)
tl.store(
Mid_O + offs_mid_lse,
e_max + tl.log(e_sum),
mask=mask_h,
)
# --------------------------------------------------------------------------- #
# Stage 2: reduce partial results across KV splits
# --------------------------------------------------------------------------- #
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
O,
seq_lens,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_lse = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + HEAD_DIM
for split_kv_id in range(0, NUM_KV_SPLITS):
split_kv_start = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) * split_kv_id
split_kv_end = tl.minimum(
split_kv_start + tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS),
cur_batch_seq_len,
)
if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os,
mask=mask_d,
other=0.0,
)
tlogic = tl.load(Mid_O + offs_lse + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
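# Stage 2 relies on the standard log-sum-exp merge (illustrative): given two
# partial results (o_1, lse_1) and (o_2, lse_2) over disjoint token sets,
#   m   = max(lse_1, lse_2)
#   w_i = exp(lse_i - m)
#   o   = (w_1 * o_1 + w_2 * o_2) / (w_1 + w_2)
# which is exactly the running update performed in the loop above, one split
# at a time.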
# --------------------------------------------------------------------------- #
# Python wrappers
# --------------------------------------------------------------------------- #
def _decode_grouped_att_m_fwd(
q,
kv_cache,
block_table,
seq_lens,
mid_o,
num_kv_splits,
sm_scale,
page_size,
BLOCK_H: int = 8,
):
batch, head_num = q.shape[0], q.shape[1]
num_kv_heads = kv_cache.shape[2]
head_dim = kv_cache.shape[3] // 2
kv_group_num = head_num // num_kv_heads
    # Grid: each "head group" covers up to BLOCK_H query heads that all map to
    # the same KV head. Dividing by min(BLOCK_H, kv_group_num) ensures every
    # KV head gets its own set of programs.
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
num_kv_splits,
)
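    # Worked example (illustrative): with head_num=64, num_kv_heads=8 and
    # BLOCK_H=8, kv_group_num=8, so each program covers one full group of 8
    # query heads and this grid axis has 64 / 8 = 8 programs; with head_num=32
    # and num_kv_heads=8, kv_group_num=4 < BLOCK_H, so each program covers 4
    # query heads and the axis still has 32 / 4 = 8 programs.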
BLOCK_DMODEL = triton.next_power_of_2(head_dim)
BLOCK_DV = triton.next_power_of_2(head_dim)
stride_kv = kv_cache.stride()
stride_bt = block_table.stride()
_fwd_kernel_stage1[grid](
q,
kv_cache,
block_table,
seq_lens,
sm_scale,
mid_o,
q.stride(0),
q.stride(1),
stride_bt[0],
stride_bt[1],
stride_kv[0],
stride_kv[1],
stride_kv[2],
stride_kv[3],
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=num_kv_splits,
PAGE_SIZE=page_size,
HEAD_DIM=head_dim,
num_warps=4,
num_stages=2,
)
def _decode_softmax_reducev_fwd(
mid_o,
o,
seq_lens,
num_kv_splits,
):
batch, head_num = o.shape[0], o.shape[1]
head_dim = o.shape[2]
BLOCK_DV = triton.next_power_of_2(head_dim)
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
mid_o,
o,
seq_lens,
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=num_kv_splits,
BLOCK_DV=BLOCK_DV,
HEAD_DIM=head_dim,
num_warps=4,
num_stages=2,
)
def paged_attention_decode(
query,
kv_cache,
block_table,
seq_lens,
sm_scale,
page_size,
num_kv_splits,
):
batch, num_heads, head_dim = query.shape
# Intermediate buffers for stage 1 -> stage 2
# Layout: (batch, num_heads, num_kv_splits, head_dim+1)
# Last element of each slice stores LSE
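    # i.e. mid_o[b, h, s, :head_dim] holds split s's partial output for head h
    # and mid_o[b, h, s, head_dim] holds its log-sum-exp.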
mid_o = torch.empty(
(batch, num_heads, num_kv_splits, head_dim + 1),
dtype=torch.float32,
device=query.device,
)
output = torch.empty_like(query)
_decode_grouped_att_m_fwd(
query,
kv_cache,
block_table,
seq_lens,
mid_o,
num_kv_splits,
sm_scale,
page_size,
)
_decode_softmax_reducev_fwd(
mid_o,
output,
seq_lens,
num_kv_splits,
)
return output
# --------------------------------------------------------------------------- #
# Model interface
# --------------------------------------------------------------------------- #
class Model(nn.Module):
def __init__(
self,
batch: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
seq_len: int,
page_size: int,
):
super().__init__()
assert num_heads % num_kv_heads == 0
self.batch = batch
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.seq_len = seq_len
self.page_size = page_size
self.group_size = num_heads // num_kv_heads
self.scale = 1.0 / math.sqrt(head_dim)
self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
def forward(
self,
query: torch.Tensor,
kv_cache: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
) -> torch.Tensor:
        # Heuristic: choose num_kv_splits so that the total number of stage-1
        # programs (batch * head_groups * num_kv_splits) keeps the GPU well
        # occupied. Fewer query heads means fewer natural programs, and with a
        # smaller head_dim each program does less work per page, so both cases
        # call for more splits to hide latency.
head_groups = (self.num_heads + min(8, self.group_size) - 1) // min(8, self.group_size)
base_blocks = self.batch * head_groups
target_blocks = 640 if self.head_dim <= 64 else 512
num_kv_splits = max(1, min(24, target_blocks // base_blocks))
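        # Worked examples (illustrative, for the benchmark shapes): batch=8,
        # num_heads=32, num_kv_heads=8, head_dim=128 gives head_groups=8,
        # base_blocks=64 and num_kv_splits = min(24, 512 // 64) = 8; batch=4,
        # num_heads=64, num_kv_heads=8 gives base_blocks=32 and
        # num_kv_splits = min(24, 512 // 32) = 16.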
return paged_attention_decode(
query,
kv_cache,
block_table,
seq_lens,
self.scale,
self.page_size,
num_kv_splits,
)
def get_inputs():
B = BATCH
H = NUM_HEADS
Hkv = NUM_KV_HEADS
D = HEAD_DIM
L = SEQ_LEN
P = PAGE_SIZE
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
block_table = perm.contiguous()
seq_lens = torch.full((B,), L, dtype=torch.int32)
return [query, kv_cache, block_table, seq_lens]
def get_init_inputs():
return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]
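# Minimal smoke test, illustrative only (assumes a CUDA device with bf16
# support); the check/benchmark harness never executes this block.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda")
    model = Model(*get_init_inputs()).to(device)
    inputs = [t.to(device) for t in get_inputs()]
    out = model(*inputs)
    print("output shape:", tuple(out.shape), "dtype:", out.dtype)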
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:45:19.467000 415152 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=0 variant=eager tflops=0.211 gbps=52.909 ms=0.637
shape=0 variant=compiled tflops=0.197 gbps=49.456 ms=0.681
shape=0 variant=solution tflops=2.794 gbps=701.314 ms=0.048
shape=0 solution_peak_fraction=0.3896
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=1 variant=eager tflops=0.330 gbps=82.706 ms=3.252
shape=1 variant=compiled tflops=0.328 gbps=82.038 ms=3.278
shape=1 variant=solution tflops=5.454 gbps=1366.110 ms=0.197
shape=1 solution_peak_fraction=0.7590
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=2 variant=eager tflops=0.242 gbps=30.275 ms=2.221
shape=2 variant=compiled tflops=0.240 gbps=30.099 ms=2.234
shape=2 variant=solution tflops=6.671 gbps=835.486 ms=0.080
shape=2 solution_peak_fraction=0.4642
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=3 variant=eager tflops=0.294 gbps=73.650 ms=1.369
shape=3 variant=compiled tflops=0.287 gbps=71.956 ms=1.402
shape=3 variant=solution tflops=3.986 gbps=999.009 ms=0.101
shape=3 solution_peak_fraction=0.5550
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=4 variant=eager tflops=0.106 gbps=26.543 ms=0.618
shape=4 variant=compiled tflops=0.102 gbps=25.475 ms=0.644
shape=4 variant=solution tflops=1.412 gbps=353.688 ms=0.046
shape=4 solution_peak_fraction=0.1965
peak_fraction: 0.4316
RESULT: OK
PASS
{
"problem": "03_paged_attention",
"harness": "kimi",
"model": "kimi-k2.6",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.4316,
"elapsed_seconds": 2656,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": null, "output_tokens": null, "cache_read_tokens": null, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py"
}{
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/03"
}<system>144 lines read from file starting from line 1. Total lines in file: 144. End of file reached.</system>
1 """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2
3 Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4 and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5 a global pool. Pages for batch element b are listed in `block_table[b]`.
6
7 The reference performs the slow path:
8 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11
12 This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13 forbidden list) so the agent cannot dispatch through SDPA either.
14 """
15 import math
16
17 import torch
18 import torch.nn as nn
19
20 OP_TYPE = "attention"
21 SUPPORTED_PRECISIONS = ["bf16"]
22 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23
24
25 # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26 BATCH = 8
27 NUM_HEADS = 32
28 NUM_KV_HEADS = 8
29 HEAD_DIM = 128
30 SEQ_LEN = 1024
31 PAGE_SIZE = 16
32
33
34 class Model(nn.Module):
35 """Single-query paged attention decode.
36
... (4770 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py"
}<system>109 lines read from file starting from line 1. Total lines in file: 109. End of file reached.</system>
1 """Correctness runner for paged-attention decode.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 REPO_ROOT = Path(__file__).resolve().parents[2]
14 sys.path.insert(0, str(REPO_ROOT))
15
16 from src.eval.correctness import check_correctness # noqa: E402
17
18
19 def _apply_shape(reference, shape: dict) -> None:
20 reference.BATCH = shape["batch"]
21 reference.NUM_HEADS = shape["num_heads"]
22 reference.NUM_KV_HEADS = shape["num_kv_heads"]
23 reference.HEAD_DIM = shape["head_dim"]
24 reference.SEQ_LEN = shape["seq_len"]
25 reference.PAGE_SIZE = shape["page_size"]
26
27
28 def main():
29 try:
30 import reference
31 import shapes
32 import solution
33 except Exception as e:
34 print(f"FAIL: import error: {e}")
35 sys.exit(1)
36
37 problem_yaml = Path("problem.yaml")
38 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39
40 sol_src = Path("solution.py").read_text() if Path("soluti
... (2815 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/Kernel"
}<system>131 lines read from file starting from line 1. Total lines in file: 131. End of file reached.</system>
1 """Roofline benchmark for paged-attention decode.
2
3 For each shape: times eager reference, compiled reference, SOTA (if available),
4 and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
7 the GPU's peak DRAM bandwidth.
8 """
9 import sys
10 from math import exp, log
11 from pathlib import Path
12
13 import torch
14 import yaml
15
16 REPO_ROOT = Path(__file__).resolve().parents[2]
17 sys.path.insert(0, str(REPO_ROOT))
18
19 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
20 from src.eval.timing import time_fn # noqa: E402
21 from src.hardware import get as get_hw # noqa: E402
22
23
24 def _eval_formula(expr: str, vars: dict) -> float:
25 return float(eval(expr, {"__builtins__": {}}, vars))
26
27
28 def _apply_shape(reference, shape: dict) -> None:
29 reference.BATCH = shape["batch"]
30 reference.NUM_HEADS = shape["num_heads"]
31 reference.NUM_KV_HEADS = shape["num_kv_heads"]
32 reference.HEAD_DIM = shape["head_dim"]
33 reference.SEQ_LEN = shape["seq_len"]
34 reference.PAGE_SIZE = shape["page_size"]
35
36
37 def main():
38
... (3881 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml"
}<system>18 lines read from file starting from line 1. Total lines in file: 18. End of file reached.</system>
1 """Shape sweep for paged attention decode.
2
3 Mix targets:
4 - small batch / long context (Llama-3 8B-style decode)
5 - large batch / medium context (server batched decode)
6 - GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
7 - non-power-of-2 seq_len (forces predicated tail handling)
8 - head_dim=64 small-head case
9 """
10
11 SHAPES = [
12 # (B, H, Hkv, D, L, P)
13 {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
14 {"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
15 {"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
16 {"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
17 {"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
18 ]
{
"raw": "{\"path\": \"/home/infatos"
}<system>48 lines read from file starting from line 1. Total lines in file: 48. End of file reached.</system>
1 name: 03_paged_attention
2 display_name: "Paged Attention Decode"
3 precision: bf16
4 regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
5
6 # Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
7 # Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
8 # 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
9 flops_formula: "4 * batch * num_heads * seq_len * head_dim"
10
11 # Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
12 # K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
13 # Q: batch * num_heads * head_dim * 2 (negligible vs KV)
14 # Out: same as Q
15 # So total ~ 2 * (B * L * Hkv * D * 2) + small.
16 bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
17
18 hardware: [RTX_PRO_6000]
19 peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20 peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
21
22 tolerance:
23 bfloat16: 0.02
24 float16: 0.02
25
26 # Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
27 # vllm.attention / flashinfer.batch_decode_with
... (975 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/PROMPT.txt"
}<system>84 lines read from file starting from line 1. Total lines in file: 84. End of file reached.</system>
1 """SOTA reference for paged-attention decode.
2
3 Tries, in order:
4 1. FlashInfer's BatchDecodeWithPagedKVCacheWrapper (preferred -- portable,
5 supports SM120, GQA, arbitrary head_dim).
6 2. vLLM's paged_attention_v2 CUDA op (requires its KV-cache layout, more
7 finicky; we adapt the layout on the fly when possible).
8
9 If neither is importable, is_available() returns False and the benchmark just
10 reports eager + compiled + solution.
11
12 Agents are FORBIDDEN from importing these in solution.py (see problem.yaml).
13 This file is only for the benchmark's reference line.
14 """
15 from __future__ import annotations
16
17 import torch
18
19
20 def _try_flashinfer(
21 query: torch.Tensor,
22 kv_cache: torch.Tensor,
23 block_table: torch.Tensor,
24 seq_lens: torch.Tensor,
25 num_kv_heads: int,
26 head_dim: int,
27 page_size: int,
28 ) -> torch.Tensor | None:
29 try:
30 import flashinfer # noqa: F401
31 from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
32 except Exception:
33 return None
34
35 B, H, D = query.shape
36 # FlashInfer expects K and V as separate (num_blocks, page_size, num_kv_heads, head_dim) tensors
... (2016 more chars)<system>7 lines read from file starting from line 1. Total lines in file: 7. End of file reached.</system>
1 I need you to write a paged attention decode kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 This is single-query decode: each batch element has a query of shape (num_heads, head_dim) and attends over a KV cache of seq_len[b] tokens stored as fixed-size pages in a global pool, with block_table[b] listing which pages belong to batch element b. The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2) where the last dim packs [K | V] so a single gather pulls both. You return the bf16 output of shape (batch, num_heads, head_dim). Correctness tolerance is 0.02 abs/rel on bf16/fp16. The shapes you have to handle are batch=8 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1024 page_size=16 (Llama-3 8B-style decode), batch=32 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=2048 page_size=16 (server-batched decode), batch=4 num_heads=64 num_kv_heads=8 head_dim=128 seq_len=4096 page_size=16 (Llama-3 70B GQA wide ratio with long context), batch=16 num_heads=32 num_kv_heads=8 head_dim=12
... (1229 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name()); print(torch.cuda.get_device_capability())\""
}{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && python -c \"import triton; print(triton.__version__)\""
}<system>Command executed successfully.</system>
3.6.0
<system>Command executed successfully.</system>
2.11.0+cu130
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
(12, 0)
{
"command": "python -c \"import triton.language as tl; print(dir(tl))\" | tr ',' '\\n' | grep -i dot"
}{
"raw": "{\"command\": \""
}<system>Command executed successfully.</system>
'dot'
'dot_scaled'
{
"command": "find /home -path \"*/vllm/*\" -name \"*.py\" 2>/dev/null | head -20"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/config.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/audio.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/processing.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/__init__.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/image.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/profiling.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/video.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/parse.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/inputs.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/hasher.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/utils.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/registry.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/base.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/beam_search.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/test_utils.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/__init__.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/interface.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/xpu.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/cpu.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/cuda.py
<system>Command executed successfully.</system>
/home/infatoshi/PufferLib/.venv/lib/python3.12/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
/home/infatoshi/PufferLib/.venv/lib/python3.12/site-packages/jax/experimental/pallas/ops/gpu/paged_attention.py
{
"path": "/home/infatoshi/PufferLib/.venv/lib/python3.12/site-packages/jax/experimental/pallas/ops/gpu/paged_attention.py"
}{
"raw": "{\"command\": \"find /home -"
}<system>461 lines read from file starting from line 1. Total lines in file: 461. End of file reached.</system>
1 # Copyright 2023 The JAX Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Module containing decode attention."""
16 from __future__ import annotations
17
18 import functools
19 import math
20 from typing import Any
21
22 import jax
23 from jax import lax
24 from jax.experimental import pallas as pl
25 from jax.experimental.pallas import triton as plgpu
26 import jax.numpy as jnp
27 import numpy as np
28
29 DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
30
31
32 def paged_attention_kernel(
33 # inputs
34 q_ref, # [block_h, head_dim]
35 k_pages_ref, # [total_num_pages, page_size, head_dim]
36 k_scales_pages_ref, # [total_num_pages, page_
... (18029 more chars){
"command": "find /home/infatoshi/.cache/uv -path \"*/vllm*\" -name \"*.py\" 2>/dev/null | head -50"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/config.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/audio.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/processing.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/__init__.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/image.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/profiling.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/video.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/parse.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/inputs.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/hasher.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/utils.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/registry.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/multimodal/base.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/beam_search.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/test_utils.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/__init__.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/interface.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/xpu.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/cpu.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/cuda.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/neuron.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/tpu.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/rocm.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vl
... (1943 more chars)<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/rocm.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/_ipex_ops.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/backends/rocm_flash_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/paged_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/chunked_prefill_paged_decode.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/nki_flash_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/v1/attention/backends/cpu_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/v1/attention/backends/pallas.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/v1/attention/backends/rocm_aiter_fa.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/_custom_ops.py
{
"path": "/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/paged_attn.py"
}{
"raw": "{\"command\": \"find /home -"
}<system>256 lines read from file starting from line 1. Total lines in file: 256. End of file reached.</system>
1 # SPDX-License-Identifier: Apache-2.0
2 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4 from dataclasses import dataclass
5 from typing import List, Optional, Tuple
6
7 import torch
8
9 from vllm import _custom_ops as ops
10 from vllm.triton_utils import HAS_TRITON
11
12 if HAS_TRITON:
13 from vllm.attention.ops.prefix_prefill import context_attention_fwd
14
15 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
16 _PARTITION_SIZE = 512
17
18
19 @dataclass
20 class PagedAttentionMetadata:
21 """Metadata for PagedAttention."""
22 # (batch_size,). The length of sequences (entire tokens seen so far) per
23 # sequence.
24 seq_lens_tensor: Optional[torch.Tensor]
25 # Maximum sequence length in the batch. 0 if it is prefill-only batch.
26 max_decode_seq_len: int
27 # (batch_size, max_blocks_per_seq).
28 # Block addresses per sequence. (Seq id -> list of physical block)
29 # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
30 # in the kv cache. Each block can contain up to block_size tokens.
31 # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
32 # captured.
33 block_tables: Optional[t
... (8791 more chars)<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/config.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/cuda.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/platforms/rocm.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/triton_utils/__init__.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/triton_utils/importing.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/backends/rocm_flash_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/backends/xformers.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/backends/triton_mla.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/backends/mla/common.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/paged_attn.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/chunked_prefill_paged_decode.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/merge_attn_states.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/prefix_prefill.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_unified_attention.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_flash_attention.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_merge_attn_states.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/compilation/compiler_interface.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/envs.py
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/transformers_utils/configs/mpt.p
... (2 more chars){
"path": "/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_decode_attention.py"
}<system>674 lines read from file starting from line 1. Total lines in file: 674. End of file reached.</system>
1 # SPDX-License-Identifier: Apache-2.0
2 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4 # Adapted from
5 # https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
6 # which was originally adapted from
7 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
8 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
9
10 # Changes:
11 # - Add support for page size >= 1.
12
13 # Copyright 2025 vLLM Team
14 # Copyright 2023-2024 SGLang Team
15 # Licensed under the Apache License, Version 2.0 (the "License");
16 # you may not use this file except in compliance with the License.
17 # You may obtain a copy of the License at
18 #
19 # http://www.apache.org/licenses/LICENSE-2.0
20 #
21 # Unless required by applicable law or agreed to in writing, software
22 # distributed under the License is distributed on an "AS IS" BASIS,
23 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 # See the License for the specific language governing
... (22519 more chars){
"command": "python -c \"import torch; print(torch.cuda.get_device_properties(0))\""
}<system>Command executed successfully.</system>
_CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97227MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
{
"command": "find /home -name \"decode_attention.py\" 2>/dev/null | head -10"
}{
"command": "find /home -name \"*decode*attention*.py\" 2>/dev/null | head -10"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/git-v0/checkouts/b454a06192129efd/e67a0d488/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/.cache/uv/git-v0/checkouts/b454a06192129efd/e67a0d488/python/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Lt9zfFedX0Z5V0YyUK_8L/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Lt9zfFedX0Z5V0YyUK_8L/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Jg4v4v72a09iziNGrAKVc/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Jg4v4v72a09iziNGrAKVc/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/PufferLib/.venv/lib/python3.12/site-packages/jax/experimental/pallas/ops/gpu/decode_attention.py
/home/infatoshi/qwen3.6-dense/.venv-train/lib/python3.12/site-packages/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/qwen3.6-dense/.venv-train/lib/python3.12/site-packages/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/qwen3.6-dense/.venv-sglang/lib/python3.12/site-packages/sglang/srt/layers/attention/triton_ops/decode_attention.py
<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/.tmpR4nmo3/vllm/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/git-v0/checkouts/b454a06192129efd/e67a0d488/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/.cache/uv/git-v0/checkouts/b454a06192129efd/e67a0d488/python/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/.cache/uv/.tmpWxc6XF/vllm/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Meg-XZj8kLE62I6i-jxiV/vllm/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/iFZrSVxXzSlD8QXgcEf3G/vllm/v1/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Lt9zfFedX0Z5V0YyUK_8L/sglang/srt/layers/attention/triton_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/Lt9zfFedX0Z5V0YyUK_8L/sglang/srt/layers/attention/wave_ops/decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/pD3XxLT5AwqAQ7PonqFRK/vllm/v1/attention/ops/triton_decode_attention.py
/home/infatoshi/.cache/uv/archive-v0/YOuc6VDeRqbreR7_03swo/vllm/v1/attention/ops/triton_decode_attention.py
{
"path": "/home/infatoshi/qwen3.6-dense/.venv-sglang/lib/python3.12/site-packages/sglang/srt/layers/attention/triton_ops/decode_attention.py"
}<system>787 lines read from file starting from line 1. Total lines in file: 787. End of file reached.</system>
1 # Copyright 2023-2024 SGLang Team
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 # ==============================================================================
14 """
15 Memory-efficient attention for decoding.
16 It supports page size = 1.
17 """
18
19 # Adapted from
20 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
21 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
22
23 import logging
24
25 import triton
26 import triton.language as tl
27
28 from sglang.srt.utils import is_hip
29
30 _is_hip = is_hip()
3
... (24947 more chars){
"todos": [
{
"title": "Write initial Triton paged attention decode kernel for packed KV layout",
"status": "in_progress"
},
{
"title": "Test correctness with check.py",
"status": "pending"
},
{
"title": "Profile and optimize performance",
"status": "pending"
},
{
"title": "Iterate until peak_fraction target is met",
"status": "pending"
}
]
}<system>Todo list updated</system>
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -0,0 +1,482 @@
+"""Triton-based paged attention decode kernel for packed KV cache layout.
+
+The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2)
+with K and V packed on the last dimension: [K | V].
+
+This kernel uses a two-stage approach:
+1. Stage 1 splits the KV sequence into NUM_KV_SPLITS parts, each block computes
+ partial attention with online softmax.
+2. Stage 2 reduces partial results across splits.
+
+For efficiency on the packed layout, each iteration loads an entire page
+(page_size tokens x 2*head_dim elements) as a coalesced block, then extracts
+K and V slices.
+"""
+import math
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+# --------------------------------------------------------------------------- #
+# Stage 1: compute partial attention per KV split (grouped for GQA)
+# --------------------------------------------------------------------------- #
+@triton.jit
+def _fwd_kernel_stage1(
+ Q,
+ KV_Cache,
+ block_table,
+ seq_lens,
+ sm_scale,
+ Mid_O,
+ Mid_LSE,
+ stride_qb,
+ stride_qh,
+ stride_bt0,
+ stride_bt1,
+ stride_kv0,
+ stride_kv1,
+ stride_kv2,
+ stride_kv3,
+ stride_mid_ob,
+ stride_mid_oh,
+ stride_mid_os,
+ kv_group_num: tl.constexpr,
+ q_head_num: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_DV: tl.constexpr,
+ BLOCK_H: tl.constexpr,
+ NUM_KV_SPLITS: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+):
+ cur_batch = tl.program_id(0)
+ cur_head_id = tl.program_id(1)
+ split_kv_id = tl.program_id(2)
+
+ # Which KV head this group belongs to
+ heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
+ cur_kv_head = cur_head_id // heads_per_kv
+
+ # Number of valid query heads in this block
+ if kv_group_num > BLOCK_H:
+ VALID_BLOCK_H: tl.constexpr = BLOCK_H
+ else:
+ VALID_BLOCK_H: tl.constexpr = kv_group_num
+
+ cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+ mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
+ mask_h = mask_h & (cur_head < q_head_num)
+
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_dv = tl.arange(0, BLOCK_DV)
+ mask_d = offs_d < HEAD_DIM
+ mask_dv = offs_dv < HEAD_DIM
+
+ cur_batch_seq_len = tl.load(seq_lens + cur_batch)
+
+ # Load Q for this head group
+ offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
+ q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
+
+ # Determine this split's token range
+ kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+ split_kv_start = kv_len_per_split * split_kv_id
+ split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+ # Online softmax accumulators
+ e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+ e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
+
+ if split_kv_end > split_kv_start:
+ # Iterate over pages within this split
+ start_page = split_kv_start // PAGE_SIZE
+ end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
+
+ offs_token_in_page = tl.arange(0, PAGE_SIZE)
+ offs_dim = tl.arange(0, 2 * HEAD_DIM)
+
+ for page_idx in range(start_page, end_page):
+ physical_page = tl.load(
+ block_table + cur_batch * stride_bt0 + page_idx * stride_bt1,
+ )
+
+ # Global token indices for this page
+ global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
+ token_mask = global_token_idx < split_kv_end
+
+ # Base address for this page and KV head
+ page_base = (
+ KV_Cache
+ + physical_page * stride_kv0
+ + cur_kv_head * stride_kv2
+ )
+
+ # Load full page: (PAGE_SIZE, 2*HEAD_DIM)
+ page_data = tl.load(
+ page_base + offs_token_in_page[:, None] * stride_kv1 + offs_dim[None, :],
+ mask=token_mask[:, None],
+ other=0.0,
+ )
+
+ # Extract K and V
+ k = page_data[:, :HEAD_DIM] # (PAGE_SIZE, HEAD_DIM)
+ v = page_data[:, HEAD_DIM:] # (PAGE_SIZE, HEAD_DIM)
+
+ # Compute QK^T
+ qk = tl.dot(q, k.to(q.dtype).T)
+ qk *= sm_scale
+
+ # Mask tokens beyond split end or seq_len
+ qk = tl.where(
+ mask_h[:, None] & token_mask[None, :],
+ qk,
+ float("-inf"),
+ )
+
+ # Online softmax update
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+ re_scale = tl.exp(e_max - n_e_max)
+ p = tl.exp(qk - n_e_max[:, None])
+
+ acc *= re_scale[:, None]
+ acc += tl.dot(p.to(v.dtype), v)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # Store partial output and LSE
+ offs_mid_o = (
+ cur_batch * stride_mid_ob
+ + cur_head[:, None] * stride_mid_oh
+ + split_kv_id * stride_mid_os
+ + offs_dv[None, :]
+ )
+ tl.store(
+ Mid_O + offs_mid_o,
+ acc / e_sum[:, None],
+ mask=mask_h[:, None] & mask_dv[None, :],
+ )
+
+ offs_mid_lse = (
+ cur_batch * stride_mid_ob
+ + cur_head * stride_mid_oh
+ + split_kv_id * stride_mid_os
+ + HEAD_DIM
+ )
+ tl.store(
+ Mid_LSE + offs_mid_lse,
+ e_max + tl.log(e_sum),
+
... (8036 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"content": "<13438 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 13438 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: <triton.language.core.slice object at 0x7bcd4a549c10>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 449,
... (2442 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": {
"old": " # Load full page: (PAGE_SIZE, 2*HEAD_DIM)\n page_data = tl.load(\n page_base + offs_token_in_page[:, None] * stride_kv1 + offs_dim[None, :],\n mask=token_mask[:, None],\n other=0.0,\n )\n\n # Extract K and V\n k = page_data[:, :HEAD_DIM] # (PAGE_SIZE, HEAD_DIM)\n v = page_data[:, HEAD_DIM:] # (PAGE_SIZE, HEAD_DIM)",
"new": " # Load K and V separately (
... (493 more chars)<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>ERROR: Command failed with exit code: 1.</system>
FAIL: shape 0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} seed 456: tolerance exceeded: max_abs_diff=6.2576 (atol=0.02, rtol=0.02)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": [
{
"old": "@triton.jit\ndef _fwd_kernel_stage1(\n Q,\n KV_Cache,\n block_table,\n seq_lens,\n sm_scale,\n Mid_O,\n Mid_LSE,\n stride_qb,\n stride_qh,\n stride_bt0,\n stride_bt1,\n stride_kv0,\n stride_kv1,\n stride_kv2,\n stride_kv3,\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n kv_group_num: tl.constexpr,\n q_head_num: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_H: tl.cons
... (10153 more chars)<system>File successfully edited. Applied 6 edit(s) with 6 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0428 11:09:30.936000 395740 .venv/lib/python3.11/site-packages/torch/_dynamo/var
... (1920 more chars){
"command": "<1249 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && cat > /tmp/test_perf.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
for shape_idx, shape in enumerate(shapes.SHAPES):
apply_shape(shape)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(5):
sol_model(*inputs)
torch.cuda.synchronize()
# Time it
import time
start = time.perf_counter()
for _ in range(30):
sol_model(*inputs)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 30 * 1000
print(f"shape={shape_idx} {shape} ms={elapsed:.3f}")
EOF
uv run python /tmp/test_perf.py<system>Command executed successfully.</system>
shape=0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} ms=0.046
shape=1 {'batch': 32, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 2048, 'page_size': 16} ms=0.195
shape=2 {'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16} ms=0.050
shape=3 {'batch': 16, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1535, 'page_size': 16} ms=0.054
shape=4 {'batch': 8, 'num_heads': 16, 'num_kv_heads': 4, 'head_dim': 64, 'seq_len': 2000, 'page_size': 16} ms=0.055
{
"command": "find /home/infatoshi/cuda/KernelBench-Hard -name \"timing.py\" | xargs cat"
}<system>Command executed successfully.</system>
"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import statistics
import torch
# Blackwell consumer (R
... (4223 more chars){
"command": "python -c \"import triton.language as tl; print([x for x in dir(tl) if 'cat' in x or 'join' in x or 'concat' in x or 'stack' in x])\""
}<system>Command executed successfully.</system>
['cat', 'join']
{
"command": "python -c \"import triton.language as tl; help(tl.cat)\""
}<system>Command executed successfully.</system>
Help on function cat in module triton.language.core:
cat(input, other, can_reorder=False, _semantic=None)
Concatenate the given blocks
:param input: The first input tensor.
:type input: Tensor
:param other: The second input tensor.
:type other: Tensor
:param reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs. Only use if the
order does not matter (e.g., result is only used in reduction ops).
Current implementation of `cat` supports only can_reorder=True.
{
"command": "<930 chars \u2014 see below>"
}cat > /tmp/test_blockh.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(Q, K, Out, BLOCK_H: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_N: tl.constexpr):
offs_h = tl.arange(0, BLOCK_H)
offs_d = tl.arange(0, BLOCK_D)
q = tl.load(Q + offs_h[:, None] * BLOCK_D + offs_d[None, :])
k = tl.load(K + offs_d[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :])
out = tl.dot(q, k)
tl.store(Out + offs_h[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :], out)
Q = torch.randn(4, 128, dtype=torch.bfloat16, device='cuda')
K = torch.randn(128, 16, dtype=torch.bfloat16, device='cuda')
Out = torch.empty(4, 16, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](Q, K, Out, BLOCK_H=4, BLOCK_D=128, BLOCK_N=16)
print("BLOCK_H=4 works!")
test_kernel[(1,)](Q, K, Out, BLOCK_H=8, BLOCK_D=128, BLOCK_N=16)
print("BLOCK_H=8 works!")
EOF
uv run python /tmp/test_blockh.py<system>Command executed successfully.</system>
BLOCK_H=4 works!
BLOCK_H=8 works!
{
"command": "<840 chars \u2014 see below>"
}cat > /tmp/test_blockh1.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(Q, K, Out, BLOCK_H: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_N: tl.constexpr):
offs_h = tl.arange(0, BLOCK_H)
offs_d = tl.arange(0, BLOCK_D)
q = tl.load(Q + offs_h[:, None] * BLOCK_D + offs_d[None, :])
k = tl.load(K + offs_d[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :])
out = tl.dot(q, k)
tl.store(Out + offs_h[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :], out)
Q = torch.randn(1, 128, dtype=torch.bfloat16, device='cuda')
K = torch.randn(128, 16, dtype=torch.bfloat16, device='cuda')
Out = torch.empty(1, 16, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](Q, K, Out, BLOCK_H=1, BLOCK_D=128, BLOCK_N=16)
print("BLOCK_H=1 works!")
EOF
uv run python /tmp/test_blockh1.py<system>Command executed successfully.</system>
BLOCK_H=1 works!
{
"command": "<1192 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && cat > /tmp/profile_nsys.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
# Profile shape 0
shape = shapes.SHAPES[0]
apply_shape(shape)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(10):
sol_model(*inputs)
torch.cuda.synchronize()
# Profile
import torch.profiler
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
for _ in range(5):
sol_model(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
EOF
uv run python /tmp/profile_nsys.py<system>Command executed successfully.</system>
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
_fwd_kernel_stage1 0.00% 0.000us 0.00% 0.000us 0.000us 102.365us 87.67% 102.365us 20.473us 5
aten::max 43.34% 764.116us 91.37% 1.611ms 322.141us 6.462us 5.53% 8.061us 1.612us 5
void at::native::reduce_kernel<512, 1, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 6.462us 5.53% 6.462us
... (3061 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": [
{
"old": "def _decode_grouped_att_m_fwd(\n q,\n kv_cache,\n block_table,\n seq_lens,\n mid_o,\n num_kv_splits,\n sm_scale,\n page_size,\n):\n batch, head_num = q.shape[0], q.shape[1]\n num_kv_heads = kv_cache.shape[2]\n head_dim = kv_cache.shape[3] // 2\n kv_group_num = head_num // num_kv_heads\n\n BLOCK_H = 16\n grid = (\n batch,\n triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),\n num_kv_splits,\n )\n\n BLOCK
... (3681 more chars)<system>File successfully edited. Applied 2 edit(s) with 2 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0428 11:15:00.024000 398830 .venv/lib/python3.11/site-packages/torch/_dynamo/var
... (1921 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": {
"old": " # Use BLOCK_H matching the group size for better tensor-core utilization.\n # Round up to the next power of two if the group size is not a power of two,\n # but all our shapes have power-of-two group sizes.\n BLOCK_H = min(16, kv_group_num)\n if BLOCK_H < 1:\n BLOCK_H = 1",
"new": " BLOCK_H = 16"
}
}<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0428 11:15:25.490000 399040 .venv/lib/python3.11/site-packages/torch/_dynamo/var
... (1922 more chars){
"command": "<1473 chars \u2014 see below>"
}cat > /tmp/sweep_params.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import triton
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
# Test shape 0 and shape 4
for shape_idx in [0, 4]:
shape = shapes.SHAPES[shape_idx]
apply_shape(shape)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(10):
sol_model(*inputs)
torch.cuda.synchronize()
# Time with L2 flush
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import statistics
print(f"shape={shape_idx} baseline ms={statistics.median(times):.3f}")
EOF
uv run python /tmp/sweep_params.py<system>Command executed successfully.</system>
shape=0 baseline ms=0.082
shape=4 baseline ms=0.068
{
"command": "<1383 chars \u2014 see below>"
}cat > /tmp/test_variant.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import shapes
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
# Copy the kernel and modify parameters
device = torch.device("cuda:0")
shape = shapes.SHAPES[0]
apply_shape(shape)
init_args = reference.get_init_inputs()
# Import solution to get the kernel code
import solution
# Time the current solution
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(10):
sol_model(*inputs)
torch.cuda.synchronize()
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import statistics
print(f"Current: ms={statistics.median(times):.3f}")
EOF
uv run python /tmp/test_variant.py<system>Command executed successfully.</system>
Current: ms=0.084
{
"command": "<1404 chars \u2014 see below>"
}cat > /tmp/test_reshape.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(KV, K_Out, V_Out, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr):
page_base = KV + tl.arange(0, PAGE_SIZE * 2 * HEAD_DIM)
page_data = tl.load(page_base)
# Try reshape to (PAGE_SIZE, 2, HEAD_DIM)
page_3d = tl.reshape(page_data, (PAGE_SIZE, 2, HEAD_DIM))
k = page_3d[:, 0, :]
v = page_3d[:, 1, :]
# Store to verify
k_flat = tl.reshape(k, (PAGE_SIZE * HEAD_DIM,))
v_flat = tl.reshape(v, (PAGE_SIZE * HEAD_DIM,))
tl.store(K_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), k_flat)
tl.store(V_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), v_flat)
PAGE_SIZE = 16
HEAD_DIM = 128
kv = torch.randn(PAGE_SIZE, 2, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# Flatten to simulate packed layout: (PAGE_SIZE, 2*HEAD_DIM)
kv_packed = kv.view(PAGE_SIZE, 2 * HEAD_DIM)
k_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
k_expected = kv[:, 0, :].reshape(-1)
v_expected = kv[:, 1, :].reshape(-1)
print("K match:", torch.allclose(k_out, k_expected, atol=1e-5))
print("V match:", torch.allclose(v_out, v_expected, atol=1e-5))
EOF
uv run python /tmp/test_reshape.py<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: constexpr[0]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_reshape.py", line 30, in <module>
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
... (1161 more chars){
"command": "<1411 chars \u2014 see below>"
}cat > /tmp/test_gather.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(KV, K_Out, V_Out, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr):
n = PAGE_SIZE * 2 * HEAD_DIM
page_data = tl.load(KV + tl.arange(0, n))
# Build 2D indices for K
offs_t = tl.arange(0, PAGE_SIZE)[:, None]
offs_d = tl.arange(0, HEAD_DIM)[None, :]
idx_k = offs_t * (2 * HEAD_DIM) + offs_d
k = page_data[idx_k]
idx_v = offs_t * (2 * HEAD_DIM) + offs_d + HEAD_DIM
v = page_data[idx_v]
# Flatten and store
tl.store(K_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), tl.reshape(k, (PAGE_SIZE * HEAD_DIM,)))
tl.store(V_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), tl.reshape(v, (PAGE_SIZE * HEAD_DIM,)))
PAGE_SIZE = 16
HEAD_DIM = 128
kv = torch.randn(PAGE_SIZE, 2, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
kv_packed = kv.view(PAGE_SIZE * 2 * HEAD_DIM)
k_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
k_expected = kv[:, 0, :].reshape(-1)
v_expected = kv[:, 1, :].reshape(-1)
print("K match:", torch.allclose(k_out, k_expected, atol=1e-5))
print("V match:", torch.allclose(v_out, v_expected, atol=1e-5))
EOF
uv run python /tmp/test_gather.py<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_gather.py", line 32, in <module>
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs,
... (1273 more chars){
"command": "<1421 chars \u2014 see below>"
}cat > /tmp/test_gather2.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(KV, K_Out, V_Out, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr):
total = PAGE_SIZE * 2 * HEAD_DIM
page_data = tl.load(KV + tl.arange(0, total))
# Build 2D indices for K
offs_t = tl.arange(0, PAGE_SIZE)[:, None]
offs_d = tl.arange(0, HEAD_DIM)[None, :]
idx_k = offs_t * (2 * HEAD_DIM) + offs_d
k = page_data[idx_k]
idx_v = offs_t * (2 * HEAD_DIM) + offs_d + HEAD_DIM
v = page_data[idx_v]
# Flatten and store
tl.store(K_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), tl.reshape(k, (PAGE_SIZE * HEAD_DIM,)))
tl.store(V_Out + tl.arange(0, PAGE_SIZE * HEAD_DIM), tl.reshape(v, (PAGE_SIZE * HEAD_DIM,)))
PAGE_SIZE = 16
HEAD_DIM = 128
kv = torch.randn(PAGE_SIZE, 2, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
kv_packed = kv.view(PAGE_SIZE * 2 * HEAD_DIM)
k_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
k_expected = kv[:, 0, :].reshape(-1)
v_expected = kv[:, 1, :].reshape(-1)
print("K match:", torch.allclose(k_out, k_expected, atol=1e-5))
print("V match:", torch.allclose(v_out, v_expected, atol=1e-5))
EOF
uv run python /tmp/test_gather2.py<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_gather2.py", line 32, in <module>
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs,
... (1282 more chars){
"command": "<1432 chars \u2014 see below>"
}cat > /tmp/test_gather3.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(KV, K_Out, V_Out, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr):
total: tl.constexpr = PAGE_SIZE * 2 * HEAD_DIM
page_data = tl.load(KV + tl.arange(0, total))
# Build 2D indices for K
offs_t = tl.arange(0, PAGE_SIZE)[:, None]
offs_d = tl.arange(0, HEAD_DIM)[None, :]
idx_k = offs_t * (2 * HEAD_DIM) + offs_d
k = page_data[idx_k]
idx_v = offs_t * (2 * HEAD_DIM) + offs_d + HEAD_DIM
v = page_data[idx_v]
# Flatten and store
k_total: tl.constexpr = PAGE_SIZE * HEAD_DIM
tl.store(K_Out + tl.arange(0, k_total), tl.reshape(k, (k_total,)))
tl.store(V_Out + tl.arange(0, k_total), tl.reshape(v, (k_total,)))
PAGE_SIZE = 16
HEAD_DIM = 128
kv = torch.randn(PAGE_SIZE, 2, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
kv_packed = kv.view(PAGE_SIZE * 2 * HEAD_DIM)
k_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_out = torch.empty(PAGE_SIZE * HEAD_DIM, dtype=torch.bfloat16, device='cuda')
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
k_expected = kv[:, 0, :].reshape(-1)
v_expected = kv[:, 1, :].reshape(-1)
print("K match:", torch.allclose(k_out, k_expected, atol=1e-5))
print("V match:", torch.allclose(v_out, v_expected, atol=1e-5))
EOF
uv run python /tmp/test_gather3.py<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1094, in __getitem__
for dim, sl in enumerate(slices):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 41, in wrapper
raise ValueError("Did you forget to add @triton.jit ? "
ValueError: Did you forget to add @triton.jit ? (`_semantic` argument must be provided outside of JIT functions.)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_gather3.py", line 33, in <module>
test_kernel[(1,)](kv_packed, k_out, v_out, PAGE_SIZE=PAGE_SIZE, HEAD_DIM=HEAD_DIM)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs
... (1520 more chars){
"command": "<10499 chars \u2014 see below>",
"timeout": 300
}cat > /tmp/sweep_blockh.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
# Kernel variants
@triton.jit
def _fwd_kernel_stage1_v2(
Q,
KV_Cache,
block_table,
seq_lens,
sm_scale,
Mid_O,
stride_qb,
stride_qh,
stride_bt0,
stride_bt1,
stride_kv0,
stride_kv1,
stride_kv2,
stride_kv3,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(
block_table + cur_batch * stride_bt0 + page_idx * stride_bt1,
)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = (
KV_Cache
+ physical_page * stride_kv0
+ cur_kv_head * stride_kv2
)
k = tl.load(
page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :],
other=0.0,
)
v = tl.load(
page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :],
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(
mask_h[:, None] & token_mask[None, :],
qk,
float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:
... (6499 more chars)<system>Command executed successfully.</system>
shape=0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16}
BLOCK_H= 1 splits=1 ms=0.084 max_diff=0.0001
BLOCK_H= 1 splits=2 ms=0.066 max_diff=0.0001
BLOCK_H= 1 splits=4 ms=0.047 max_diff=0.0001
BLOCK_H= 1 splits=8 ms=0.051 max_diff=0.0001
BLOCK_H= 2 splits=1 ms=0.088 max_diff=0.0001
BLOCK_H= 2 splits=2 ms=0.069 max_diff=0.0001
BLOCK_H= 2 splits=4 ms=0.043 max_diff=0.0001
BLOCK_H= 2 splits=8 ms=0.043 max_diff=0.0001
BLOCK_H= 4 splits=1 ms=0.085 max_diff=0.0001
BLOCK_H= 4 splits=2 ms=0.064 max_diff=0.0001
BLOCK_H= 4 splits=4 ms=0.044 max_diff=0.0001
BLOCK_H= 4 splits=8 ms=0.052 max_diff=0.0001
BLOCK_H= 8 splits=8 ms=0.038 max_diff=0.0001
BLOCK_H=16 splits=8 ms=0.034 max_diff=0.0001
shape=2 {'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16}
BLOCK_H= 1 splits=1 ms=0.239 max_diff=0.0000
BLOCK_H= 1 splits=2 ms=0.144 max_diff=0.0000
BLOCK_H= 1 splits=4 ms=0.114 max_diff=0.0000
BLOCK_H= 1 splits=8 ms=0.139 max_diff=0.0000
BLOCK_H= 2 splits=1 ms=0.250 max_diff=0.0000
BLOCK_H= 2 splits=2 ms=0.142 max_diff=0.0000
BLOCK_H= 2 splits=4 ms=0.101 max_diff=0.0000
BLOCK_H= 2 splits=8 ms=0.071 max_diff=0.0000
BLOCK_H= 4 splits=1 ms=0.234 max_diff=0.0000
BLOCK_H= 4 splits=2 ms=0.138 max_diff=0.0000
BLOCK_H= 4 splits=4 ms=0.095 max_diff=0.0000
BLOCK_H= 4 splits=8 ms=0.075 max_diff=0.0000
BLOCK_H= 8
... (1271 more chars){
"command": "<8038 chars \u2014 see below>"
}cat > /tmp/verify_bug.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import math
shape = {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16}
B, H, Hkv, D, L, P = shape["batch"], shape["num_heads"], shape["num_kv_heads"], shape["head_dim"], shape["seq_len"], shape["page_size"]
device = torch.device("cuda:0")
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16, device=device) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16, device=device) * 0.1
block_table = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int().to(device)
seq_lens = torch.full((B,), L, dtype=torch.int32, device=device)
ref_model = reference.Model(B, H, Hkv, D, L, P).to(device).eval()
with torch.no_grad():
ref_out = ref_model(query, kv_cache, block_table, seq_lens)
# Import the sweep kernel
import importlib.util
spec = importlib.util.spec_from_file_location("sweep", "/tmp/sweep_blockh.py")
# Actually, just run the kernel directly using the code from sweep
@triton.jit
def _fwd_kernel_stage1_v2(
Q, KV_Cache, block_table, seq_lens, sm_scale, Mid_O,
stride_qb, stride_qh, stride_bt0, stride_bt1,
stride_kv0, stride_kv1, stride_kv2, stride_kv3,
stride_mid_ob, stride_mid_oh, stride_mid_os,
kv_group_num: tl.constexpr, q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(block_table + cur_batch * stride_bt0 + page_idx * stride_bt1)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = KV_Cache + physical_page * stride_kv0 + cur_kv_head * stride_kv2
k = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :], other=0.0)
v = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :], other=0.0)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(mask_h[:, None] & token_mask[None, :], qk, float("-inf"))
... (4038 more chars)<system>Command executed successfully.</system>
BLOCK_H=4 max_diff=0.0001
PASSED
BLOCK_H=16 max_diff=0.0001
PASSED
{
"command": "<8481 chars \u2014 see below>"
}cat > /tmp/verify_bug2.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import math
shape = {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16}
B, H, Hkv, D, L, P = shape["batch"], shape["num_heads"], shape["num_kv_heads"], shape["head_dim"], shape["seq_len"], shape["page_size"]
device = torch.device("cuda:0")
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16, device=device) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16, device=device) * 0.1
block_table = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int().to(device)
seq_lens = torch.full((B,), L, dtype=torch.int32, device=device)
ref_model = reference.Model(B, H, Hkv, D, L, P).to(device).eval()
with torch.no_grad():
ref_out = ref_model(query, kv_cache, block_table, seq_lens)
# Copy kernel code from sweep
@triton.jit
def _fwd_kernel_stage1_v2(
Q, KV_Cache, block_table, seq_lens, sm_scale, Mid_O,
stride_qb, stride_qh, stride_bt0, stride_bt1,
stride_kv0, stride_kv1, stride_kv2, stride_kv3,
stride_mid_ob, stride_mid_oh, stride_mid_os,
kv_group_num: tl.constexpr, q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(block_table + cur_batch * stride_bt0 + page_idx * stride_bt1)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = KV_Cache + physical_page * stride_kv0 + cur_kv_head * stride_kv2
k = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :], other=0.0)
v = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :], other=0.0)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(mask_h[:, None] & token_mask[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *
... (4481 more chars)<system>Command executed successfully.</system>
Output shape: torch.Size([8, 32, 128])
Ref output shape: torch.Size([8, 32, 128])
Max diff overall: 0.012939453125
Large diff at batch=0, head=24: max_diff=0.0100
Large diff at batch=0, head=25: max_diff=0.0100
Large diff at batch=0, head=26: max_diff=0.0100
Large diff at batch=0, head=27: max_diff=0.0100
Large diff at batch=0, head=28: max_diff=0.0109
Large diff at batch=0, head=29: max_diff=0.0110
Large diff at batch=0, head=30: max_diff=0.0109
Large diff at batch=0, head=31: max_diff=0.0109
Large diff at batch=1, head=12: max_diff=0.0129
Large diff at batch=1, head=13: max_diff=0.0129
Large diff at batch=1, head=14: max_diff=0.0129
Large diff at batch=1, head=15: max_diff=0.0129
Large diff at batch=4, head=12: max_diff=0.0105
Large diff at batch=4, head=13: max_diff=0.0104
Large diff at batch=4, head=14: max_diff=0.0105
Large diff at batch=4, head=15: max_diff=0.0105
Large diff at batch=4, head=16: max_diff=0.0104
Large diff at batch=4, head=17: max_diff=0.0105
Large diff at batch=4, head=18: max_diff=0.0105
Large diff at batch=4, head=19: max_diff=0.0104
Large diff at batch=5, head=16: max_diff=0.0121
Large diff at batch=5, head=17: max_diff=0.0121
Large diff at batch=5, head=18: max_diff=0.0121
Large diff at batch=5, head=19: max_diff=0.0121
Large diff at batch=6, head=24: max_diff=0.0104
Large diff at batch=6, head=25: max_diff=0.0104
Large diff at batch=6, head=26: max
... (1641 more chars){
"command": "<9690 chars \u2014 see below>",
"timeout": 300
}cat > /tmp/sweep_correct.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
@triton.jit
def _fwd_kernel_stage1_v2(
Q, KV_Cache, block_table, seq_lens, sm_scale, Mid_O,
stride_qb, stride_qh, stride_bt0, stride_bt1,
stride_kv0, stride_kv1, stride_kv2, stride_kv3,
stride_mid_ob, stride_mid_oh, stride_mid_os,
kv_group_num: tl.constexpr, q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(block_table + cur_batch * stride_bt0 + page_idx * stride_bt1)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = KV_Cache + physical_page * stride_kv0 + cur_kv_head * stride_kv2
k = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :], other=0.0)
v = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :], other=0.0)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(mask_h[:, None] & token_mask[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :]
tl.store(Mid_O + offs_mid_o, acc / e_sum[:, None], mask=mask_h[:, None] & mask_dv[None, :])
offs_mid_lse = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + HEAD_DIM
tl.store(Mid_O + offs_mid_lse, e_max + tl.log(e_sum), mask=mask_h)
@tri
... (5690 more chars)<system>Command executed successfully.</system>
shape=0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16}
BLOCK_H= 1 splits=1 ms=0.085 max_diff=0.0001
BLOCK_H= 1 splits=2 ms=0.068 max_diff=0.0001
BLOCK_H= 1 splits=4 ms=0.045 max_diff=0.0001
BLOCK_H= 1 splits=8 ms=0.051 max_diff=0.0001
BLOCK_H= 2 splits=1 ms=0.092 max_diff=0.0001
BLOCK_H= 2 splits=2 ms=0.066 max_diff=0.0001
BLOCK_H= 2 splits=4 ms=0.044 max_diff=0.0001
BLOCK_H= 2 splits=8 ms=0.048 max_diff=0.0001
BLOCK_H= 4 splits=1 ms=0.097 max_diff=0.0001
BLOCK_H= 4 splits=2 ms=0.076 max_diff=0.0001
BLOCK_H= 4 splits=4 ms=0.053 max_diff=0.0001
BLOCK_H= 4 splits=8 ms=0.047 max_diff=0.0001
BLOCK_H= 8 splits=1 ms=0.091 max_diff=0.0001
BLOCK_H= 8 splits=2 ms=0.065 max_diff=0.0001
BLOCK_H= 8 splits=4 ms=0.044 max_diff=0.0001
BLOCK_H= 8 splits=8 ms=0.045 max_diff=0.0001
BLOCK_H=16 splits=1 ms=0.087 max_diff=0.0001
BLOCK_H=16 splits=2 ms=0.068 max_diff=0.0001
BLOCK_H=16 splits=4 ms=0.044 max_diff=0.0001
BLOCK_H=16 splits=8 ms=0.045 max_diff=0.0001
shape=1 {'batch': 32, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 2048, 'page_size': 16}
BLOCK_H= 1 splits=1 ms=0.222 max_diff=0.0000
BLOCK_H= 1 splits=2 ms=0.241 max_diff=0.0001
BLOCK_H= 1 splits=4 ms=0.222 max_diff=0.0001
BLOCK_H= 1 splits=8 ms=0.221 max_diff=0.0001
BLOCK_H= 2 splits=1 ms=0.198 max_diff=0.0000
BLOCK_H= 2 splits=2 ms=0.199 max_diff=0.0001
BLOCK_H=
... (3794 more chars){
"command": "<9703 chars \u2014 see below>",
"timeout": 600
}cat > /tmp/sweep_warps.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
@triton.jit
def _fwd_kernel_stage1_v2(
Q, KV_Cache, block_table, seq_lens, sm_scale, Mid_O,
stride_qb, stride_qh, stride_bt0, stride_bt1,
stride_kv0, stride_kv1, stride_kv2, stride_kv3,
stride_mid_ob, stride_mid_oh, stride_mid_os,
kv_group_num: tl.constexpr, q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(block_table + cur_batch * stride_bt0 + page_idx * stride_bt1)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = KV_Cache + physical_page * stride_kv0 + cur_kv_head * stride_kv2
k = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :], other=0.0)
v = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :], other=0.0)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(mask_h[:, None] & token_mask[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :]
tl.store(Mid_O + offs_mid_o, acc / e_sum[:, None], mask=mask_h[:, None] & mask_dv[None, :])
offs_mid_lse = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + HEAD_DIM
tl.store(Mid_O + offs_mid_lse, e_max + tl.log(e_sum), mask=mask_h)
@trito
... (5703 more chars)<system>ERROR: Error validating JSON arguments: 1 validation error for Params
Value error, timeout must be <= 300s for foreground commands; use run_in_background=true for longer timeouts (up to 86400s) [type=value_error, input_value={'command': 'cat > /tmp/s...y 2>&1', 'timeout': 600}, input_type=dict]
For further information visit https://errors.pydantic.dev/2.12/v/value_error</system>{
"command": "<9667 chars \u2014 see below>",
"timeout": 300
}cat > /tmp/sweep_warps.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import triton
import triton.language as tl
import reference
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
@triton.jit
def _fwd_kernel_stage1_v2(
Q, KV_Cache, block_table, seq_lens, sm_scale, Mid_O,
stride_qb, stride_qh, stride_bt0, stride_bt1,
stride_kv0, stride_kv1, stride_kv2, stride_kv3,
stride_mid_ob, stride_mid_oh, stride_mid_os,
kv_group_num: tl.constexpr, q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
split_kv_id = tl.program_id(2)
heads_per_kv = tl.cdiv(kv_group_num, BLOCK_H)
cur_kv_head = cur_head_id // heads_per_kv
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < HEAD_DIM
mask_dv = offs_dv < HEAD_DIM
cur_batch_seq_len = tl.load(seq_lens + cur_batch)
offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
start_page = split_kv_start // PAGE_SIZE
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
for page_idx in range(start_page, end_page):
physical_page = tl.load(block_table + cur_batch * stride_bt0 + page_idx * stride_bt1)
global_token_idx = page_idx * PAGE_SIZE + offs_token_in_page
token_mask = global_token_idx < split_kv_end
page_base = KV_Cache + physical_page * stride_kv0 + cur_kv_head * stride_kv2
k = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
mask=token_mask[:, None] & mask_d[None, :], other=0.0)
v = tl.load(page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
mask=token_mask[:, None] & mask_dv[None, :], other=0.0)
qk = tl.dot(q, k.to(q.dtype).T)
qk *= sm_scale
qk = tl.where(mask_h[:, None] & token_mask[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :]
tl.store(Mid_O + offs_mid_o, acc / e_sum[:, None], mask=mask_h[:, None] & mask_dv[None, :])
offs_mid_lse = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + HEAD_DIM
tl.store(Mid_O + offs_mid_lse, e_max + tl.log(e_sum), mask=mask_h)
@trito
... (5667 more chars)<system>Command executed successfully.</system>
shape=0 BLOCK_H=8 splits=4
warps=2 stages=1 ms=0.060
warps=2 stages=2 ms=0.043
warps=4 stages=1 ms=0.059
warps=4 stages=2 ms=0.043
warps=8 stages=1 ms=0.061
warps=8 stages=2 ms=0.046
shape=2 BLOCK_H=8 splits=8
warps=2 stages=1 ms=0.089
warps=2 stages=2 ms=0.067
warps=4 stages=1 ms=0.097
warps=4 stages=2 ms=0.073
warps=8 stages=1 ms=0.089
warps=8 stages=2 ms=0.079
shape=4 BLOCK_H=8 splits=8
warps=2 stages=1 ms=0.057
warps=2 stages=2 ms=0.046
warps=4 stages=1 ms=0.049
warps=4 stages=2 ms=0.043
warps=8 stages=1 ms=0.049
warps=8 stages=2 ms=0.046
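The sweep above only varies launch meta-parameters, not the kernel body; the full sweep script is truncated, so here is a minimal, self-contained sketch (with a hypothetical _scale_kernel) of how num_warps and num_stages are plumbed through a Triton launch, the same way the sweep passes them alongside the constexpr arguments.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(X, Out, scale, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(Out + offs, tl.load(X + offs) * scale)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# num_warps / num_stages are launch meta-parameters: they go at the call site,
# next to the constexpr arguments, exactly as the sweep above varies them.
_scale_kernel[(1,)](x, out, 2.0, BLOCK=1024, num_warps=4, num_stages=2)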
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -8,9 +8,8 @@
partial attention with online softmax.
2. Stage 2 reduces partial results across splits.
-For efficiency on the packed layout, each iteration loads an entire page
-(page_size tokens x 2*head_dim elements) as a coalesced block, then extracts
-K and V slices.
+For efficiency on the packed layout, each iteration loads K and V separately
+using coalesced per-token loads.
"""
import math
@@ -42,7 +41,6 @@
seq_lens,
sm_scale,
Mid_O,
- Mid_LSE,
stride_qb,
stride_qh,
stride_bt0,
@@ -108,7 +106,6 @@
end_page = (split_kv_end + PAGE_SIZE - 1) // PAGE_SIZE
offs_token_in_page = tl.arange(0, PAGE_SIZE)
- offs_dim = tl.arange(0, 2 * HEAD_DIM)
for page_idx in range(start_page, end_page):
physical_page = tl.load(
@@ -126,16 +123,17 @@
+ cur_kv_head * stride_kv2
)
- # Load full page: (PAGE_SIZE, 2*HEAD_DIM)
- page_data = tl.load(
- page_base + offs_token_in_page[:, None] * stride_kv1 + offs_dim[None, :],
- mask=token_mask[:, None],
+ # Load K and V separately (packed layout: [K|V] on last dim)
+ k = tl.load(
+ page_base + offs_token_in_page[:, None] * stride_kv1 + offs_d[None, :],
+ mask=token_mask[:, None] & mask_d[None, :],
other=0.0,
)
-
- # Extract K and V
- k = page_data[:, :HEAD_DIM] # (PAGE_SIZE, HEAD_DIM)
- v = page_data[:, HEAD_DIM:] # (PAGE_SIZE, HEAD_DIM)
+ v = tl.load(
+ page_base + offs_token_in_page[:, None] * stride_kv1 + (HEAD_DIM + offs_dv)[None, :],
+ mask=token_mask[:, None] & mask_dv[None, :],
+ other=0.0,
+ )
# Compute QK^T
qk = tl.dot(q, k.to(q.dtype).T)
@@ -179,7 +177,7 @@
+ HEAD_DIM
)
tl.store(
- Mid_LSE + offs_mid_lse,
+ Mid_O + offs_mid_lse,
e_max + tl.log(e_sum),
mask=mask_h,
)
@@ -191,7 +189,6 @@
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
- Mid_LSE,
O,
seq_lens,
stride_mid_ob,
@@ -231,7 +228,7 @@
mask=mask_d,
other=0.0,
)
- tlogic = tl.load(Mid_LSE + offs_lse + split_kv_id * stride_mid_os)
+ tlogic = tl.load(Mid_O + offs_lse + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
@@ -258,17 +255,19 @@
block_table,
seq_lens,
mid_o,
- mid_lse,
num_kv_splits,
sm_scale,
page_size,
+ BLOCK_H: int = 8,
):
batch, head_num = q.shape[0], q.shape[1]
num_kv_heads = kv_cache.shape[2]
head_dim = kv_cache.shape[3] // 2
kv_group_num = head_num // num_kv_heads
- BLOCK_H = 16
+ # Correct grid: each "head group" covers up to BLOCK_H query heads that
+ # map to the same KV head. We must divide by min(BLOCK_H, kv_group_num)
+ # so that every KV head gets its own set of blocks.
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
@@ -288,7 +287,6 @@
seq_lens,
sm_scale,
mid_o,
- mid_lse,
q.stride(0),
q.stride(1),
stride_bt[0],
@@ -315,7 +313,6 @@
def _decode_softmax_reducev_fwd(
mid_o,
- mid_lse,
o,
seq_lens,
num_kv_splits,
@@ -327,7 +324,6 @@
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
mid_o,
- mid_lse,
o,
seq_lens,
mid_o.stride(0),
@@ -353,7 +349,6 @@
num_kv_splits,
):
batch, num_heads, head_dim = query.shape
- num_kv_heads = kv_cache.shape[2]
# Intermediate buffers for stage 1 -> stage 2
# Layout: (batch, num_heads, num_kv_splits, head_dim+1)
@@ -363,12 +358,6 @@
dtype=torch.float32,
device=query.device,
)
- # Use a separate LSE tensor for clarity
- mid_lse = torch.empty(
- (batch, num_heads, num_kv_splits),
- dtype=torch.float32,
- device=query.device,
- )
output = torch.empty_like(query)
@@ -378,14 +367,12 @@
block_table,
seq_lens,
mid_o,
- mid_lse,
num_kv_splits,
sm_scale,
page_size,
)
_decode_softmax_reducev_fwd(
mid_o,
- mid_lse,
output,
seq_lens,
num_kv_splits,
@@ -427,24 +414,30 @@
block_table: torch.Tensor,
seq_lens: torch.Tensor,
) -> torch.Tensor:
- # Heuristic: more splits for longer sequences to increase parallelism.
- # Also use fewer splits when there are already many heads to keep
- # occupancy high without excessive intermediate traffic.
- max_seq_len = int(seq_lens.max().item())
- total_heads = self.batch * self.num_heads
-
- if max_seq_len <= 512:
+ # Heuristic: use more KV splits when the sequence is longer or when
+ # there are fewer query heads (low natural occupancy). More splits
+ # increase the number of stage-1 blocks and improve parallelism.
+ max_seq_len = self.seq_len
+
+ # Target ~256-512 stage-1 blocks for good occupancy on Blackwell.
+ head_groups = (self.num_heads + min(8, self.group_size) - 1) // min(8, self.group_size)
+ base_blocks = self.batch * head_groups
+
+ if base_blocks < 64:
+ num_kv_splits = 8
+ elif base_blocks < 128:
+ num_kv_splits = 4
+ elif base_blocks < 256:
+ num_kv_splits = 2
+ else:
num_kv_splits = 1
- elif max_seq_len <= 1024:
- num_kv_splits = 2
- elif max_seq_len <=
... (520 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"content": "<13593 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 13593 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
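For concreteness, the grid arithmetic defended by the comment in the diff above, worked through on shape 0 (batch=8, 32 query heads, 8 KV heads, so kv_group_num=4). This is a sketch using the defaults visible in the diff, not code lifted verbatim from solution.py.

import triton  # only for triton.cdiv

batch, head_num, num_kv_heads = 8, 32, 8      # shape 0
kv_group_num = head_num // num_kv_heads       # 4 query heads share each KV head
BLOCK_H = 8                                   # default from the diff
num_kv_splits = 4                             # e.g. what the heuristic picks here

# Each stage-1 program covers min(BLOCK_H, kv_group_num) = 4 query heads, so the
# head axis needs cdiv(32, 4) = 8 programs -- one per KV head -- and the full
# grid is 8 * 8 * 4 = 256 stage-1 programs.
heads_per_program = min(BLOCK_H, kv_group_num)
grid = (batch, triton.cdiv(head_num, heads_per_program), num_kv_splits)
assert grid == (8, 8, 4)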
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0428 11:29:40.535000 406834 .venv/lib/python3.11/site-packages/torch/_dynamo/var
... (1921 more chars){
"command": "<2589 chars \u2014 see below>"
}cat > /tmp/test_splits.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import shapes
import math
# Reuse the kernel from solution.py
import solution
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
# Test shape 4 with more splits
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(42)
B = shape["batch"]
H = shape["num_heads"]
Hkv = shape["num_kv_heads"]
D = shape["head_dim"]
L = shape["seq_len"]
P = shape["page_size"]
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16, device=device) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16, device=device) * 0.1
block_table = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int().to(device)
seq_lens_tensor = torch.full((B,), L, dtype=torch.int32, device=device)
sm_scale = 1.0 / math.sqrt(D)
ref_model = reference.Model(B, H, Hkv, D, L, P).to(device).eval()
with torch.no_grad():
ref_out = ref_model(query, kv_cache, block_table, seq_lens_tensor)
for num_kv_splits in [8, 12, 16, 20]:
try:
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
max_diff = (out - ref_out).abs().max().item()
for _ in range(10):
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import statistics
print(f"splits={num_kv_splits:2d} ms={statistics.median(times):.3f} max_diff={max_diff:.4f}")
except Exception as e:
print(f"splits={num_kv_splits} ERROR: {e}")
EOF
uv run python /tmp/test_splits.py<system>Command executed successfully.</system>
splits= 8 ms=0.044 max_diff=0.0014
splits=12 ms=0.043 max_diff=0.0016
splits=16 ms=0.034 max_diff=0.0016
splits=20 ms=0.034 max_diff=0.0017
{
"command": "<2815 chars \u2014 see below>",
"timeout": 300
}cat > /tmp/test_all_splits.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import shapes
import math
import solution
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
for shape_idx in range(5):
shape = shapes.SHAPES[shape_idx]
apply_shape(shape)
torch.manual_seed(42)
B = shape["batch"]
H = shape["num_heads"]
Hkv = shape["num_kv_heads"]
D = shape["head_dim"]
L = shape["seq_len"]
P = shape["page_size"]
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16, device=device) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16, device=device) * 0.1
block_table = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int().to(device)
seq_lens_tensor = torch.full((B,), L, dtype=torch.int32, device=device)
sm_scale = 1.0 / math.sqrt(D)
ref_model = reference.Model(B, H, Hkv, D, L, P).to(device).eval()
with torch.no_grad():
ref_out = ref_model(query, kv_cache, block_table, seq_lens_tensor)
print(f"\nshape={shape_idx} {shape}")
for num_kv_splits in [1, 2, 4, 8, 12, 16, 20]:
try:
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
max_diff = (out - ref_out).abs().max().item()
if max_diff > 0.02:
continue
for _ in range(10):
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = solution.paged_attention_decode(query, kv_cache, block_table, seq_lens_tensor, sm_scale, P, num_kv_splits)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import statistics
print(f" splits={num_kv_splits:2d} ms={statistics.median(times):.3f}")
except Exception as e:
pass
EOF
uv run python /tmp/test_all_splits.py 2>&1<system>Command executed successfully.</system>
shape=0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16}
splits= 1 ms=0.088
splits= 2 ms=0.064
splits= 4 ms=0.043
splits= 8 ms=0.044
splits=12 ms=0.057
splits=16 ms=0.044
splits=20 ms=0.044
shape=1 {'batch': 32, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 2048, 'page_size': 16}
splits= 1 ms=0.198
splits= 2 ms=0.205
splits= 4 ms=0.194
splits= 8 ms=0.202
splits=12 ms=0.216
splits=16 ms=0.207
splits=20 ms=0.219
shape=2 {'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16}
splits= 1 ms=0.239
splits= 2 ms=0.136
splits= 4 ms=0.092
splits= 8 ms=0.068
splits=12 ms=0.069
splits=16 ms=0.067
splits=20 ms=0.069
shape=3 {'batch': 16, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1535, 'page_size': 16}
splits= 1 ms=0.124
splits= 2 ms=0.086
splits= 4 ms=0.090
splits= 8 ms=0.087
splits=12 ms=0.095
splits=16 ms=0.102
splits=20 ms=0.096
shape=4 {'batch': 8, 'num_heads': 16, 'num_kv_heads': 4, 'head_dim': 64, 'seq_len': 2000, 'page_size': 16}
splits= 1 ms=0.125
splits= 2 ms=0.076
splits= 4 ms=0.053
splits= 8 ms=0.042
splits=12 ms=0.043
splits=16 ms=0.034
splits=20 ms=0.033
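The sweep has a simple reading: shapes that launch few stage-1 blocks (small batch, few GQA head groups) want more KV splits, while shape 1 (batch 32) is already saturated at a single split. A minimal sketch of the block-count heuristic this suggests, consistent with the truncated solution.py edits that follow (the function name and exact thresholds here are illustrative, not the committed code):

def pick_num_kv_splits(batch, num_heads, group_size, head_dim, max_splits=20):
    # Stage 1 launches batch * head_groups blocks per split; aim for a few
    # hundred blocks in total so short decode kernels keep the GPU occupied.
    heads_per_block = min(8, group_size)                 # BLOCK_H-sized GQA groups
    head_groups = (num_heads + heads_per_block - 1) // heads_per_block
    base_blocks = batch * head_groups
    target_blocks = 512 if head_dim <= 64 else 256       # assumed targets
    return max(1, min(max_splits, target_blocks // base_blocks))

# Shape 2 (batch=4, 64 q heads, GQA group 8): base_blocks = 4 * 8 = 32 -> 8 splits.
# Shape 4 (batch=8, 16 q heads, GQA group 4, head_dim 64): 512 // 32 = 16 splits.
# Both land on or near the fastest configurations measured above.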
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": {
"old": " # Target ~256-512 stage-1 blocks for good occupancy on Blackwell.\n head_groups = (self.num_heads + min(8, self.group_size) - 1) // min(8, self.group_size)\n base_blocks = self.batch * head_groups\n\n if base_blocks < 64:\n num_kv_splits = 8\n elif base_blocks < 128:\n num_kv_splits = 4\n elif base_blocks < 256:\n num_kv_splits = 2\n else:\n num_kv_splits = 1\n\n # Also increase
... (775 more chars)<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
[TorchDynamo `Tensor.item()` graph-break warnings repeated, identical to the benchmark.py run above]
... (1921 more chars){
"command": "find /home -path \"*vllm*\" -name \"*.cu\" 2>/dev/null | head -20"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/archive-v0/kvJWw2ZG8XP2FEdKYGrv-/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/VsLpCJM2fhU0k2Kta8DwV/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/zk-sxLfrHb5QFAczxYcwm/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/-5aPHd24t_AzllbqxJdMn/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/GDUPxInMuZUwIuZYN8FED/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/qwgWqR30GLvk0NdjyRg8w/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/YzRtIX0bUl6ROWuo6QXt2/lib/python3.13/site-packages/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/39Zqyllb8nIEeKLgLnYxN/lib/python3.13/site-packages/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/ceCm9ZiQJxctmuFAV-Dh9/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.cache/uv/archive-v0/DfQaSBvTfbrh7Uk5TlMh1/tilelang/3rdparty/tvm/src/runtime/contrib/vllm/cache_kernels.cu
/home/infatoshi/.cache/uv/archive-v0/DfQaSBvTfbrh7Uk5TlMh1/tilelang/3rdparty/tvm/src/runtime/contrib/vllm/attention_kernels.cu
/home/infatoshi/.local/share/uv/tools/lm-eval/lib/python3.13/site-packages/flashinfer/data/csrc/vllm_custom_all_reduce.cu
/home/infatoshi/.local/share/uv/tools/vllm/lib/python3.12/site-packa
... (816 more chars){
"command": "find /home -path \"*vllm*\" -name \"paged_attention*.cu\" 2>/dev/null | head -10"
}<system>Command executed successfully.</system>
/home/infatoshi/vllm/csrc/attention/paged_attention_v1.cu
/home/infatoshi/vllm/csrc/attention/paged_attention_v2.cu
{
"path": "/home/infatoshi/vllm/csrc/attention/paged_attention_v2.cu",
"n_lines": 200
}<system>196 lines read from file starting from line 1. Total lines in file: 196. End of file reached.</system>
1 /*
2 * Adapted from
3 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4 * Copyright (c) 2023, The vLLM team.
5 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19 #include "attention_kernels.cuh"
20 #include "../cuda_compat.h"
21
22 #define MAX(a, b) ((a) > (b) ? (a) : (b))
23 #define MIN(a, b) ((a) < (b) ? (a) : (b))
24 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
25
26 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
27 vllm::paged_attention_v2_kernel<T, CACHE_T, H
... (9347 more chars){
"command": "find /home/infatoshi/vllm -name \"attention_kernels.cuh\" | head -5"
}<system>Command executed successfully.</system>
/home/infatoshi/vllm/csrc/attention/attention_kernels.cuh
{
"path": "/home/infatoshi/vllm/csrc/attention/attention_kernels.cuh",
"n_lines": 300
}<system>300 lines read from file starting from line 1. Total lines in file: 670.</system>
1 /*
2 * Adapted from
3 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
4 * Copyright (c) 2023, The vLLM team.
5 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include <torch/all.h>
21 #include <ATen/cuda/CUDAContext.h>
22 #include <c10/cuda/CUDAGuard.h>
23 #include <algorithm>
24
25 #include "attention_dtypes.h"
26 #include "attention_utils.cuh"
27 #include "../cuda_compat.h"
28
29 #ifdef USE_ROCM
30 #include <hip/hip_bf16.h>
31 #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
32 typedef __
... (13684 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"edit": {
"old": " # Target a total stage-1 block count that keeps the GPU well occupied.\n # With smaller head_dim each block does less work per page, so we need\n # more blocks (hence more splits) to hide latency.\n head_groups = (self.num_heads + min(8, self.group_size) - 1) // min(8, self.group_size)\n base_blocks = self.batch * head_groups\n target_blocks = 512 if self.head_dim <= 64 else 256\n num_kv_splits = max(1, min(20, target_blocks // b
... (529 more chars)<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
[TorchDynamo `Tensor.item()` graph-break warnings repeated, identical to the earlier benchmark.py runs]
... (1921 more chars){
"command": "<2361 chars \u2014 see below>"
}cat > /tmp/test_model_timing.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
for shape_idx in [2, 4]:
shape = shapes.SHAPES[shape_idx]
apply_shape(shape)
torch.manual_seed(42)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(10):
sol_model(*inputs)
torch.cuda.synchronize()
# Time with L2 flush
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import statistics
print(f"shape={shape_idx} {shape} ms={statistics.median(times):.3f}")
# Also time paged_attention_decode directly with the heuristic's num_kv_splits
# Use the same split count for both shapes in this quick direct-call check.
num_kv_splits = 16
times2 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], 1.0/math.sqrt(shape["head_dim"]), shape["page_size"], num_kv_splits)
e.record()
torch.cuda.synchronize()
times2.append(s.elapsed_time(e))
print(f" direct splits={num_kv_splits} ms={statistics.median(times2):.3f}")
EOF
uv run python /tmp/test_model_timing.py<system>Command executed successfully.</system>
shape=2 {'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16} ms=0.078
direct splits=16 ms=0.068
shape=4 {'batch': 8, 'num_heads': 16, 'num_kv_heads': 4, 'head_dim': 64, 'seq_len': 2000, 'page_size': 16} ms=0.035
direct splits=16 ms=0.034
{
"command": "<2478 chars \u2014 see below>"
}cat > /tmp/test_overhead.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[2]
apply_shape(shape)
torch.manual_seed(42)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup both
for _ in range(15):
sol_model(*inputs)
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], 1.0/math.sqrt(128), 16, 16)
torch.cuda.synchronize()
# Time model
import statistics
times_model = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times_model.append(s.elapsed_time(e))
print(f"Model: ms={statistics.median(times_model):.3f}")
# Time direct
times_direct = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], 1.0/math.sqrt(128), 16, 16)
e.record()
torch.cuda.synchronize()
times_direct.append(s.elapsed_time(e))
print(f"Direct: ms={statistics.median(times_direct):.3f}")
# Time with state dict load
sd = sol_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.cuda.synchronize()
times_post_load = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times_post_load.append(s.elapsed_time(e))
print(f"Post-load: ms={statistics.median(times_post_load):.3f}")
EOF
uv run python /tmp/test_overhead.py<system>Command executed successfully.</system>
Model: ms=0.074
Direct: ms=0.068
Post-load: ms=0.069
{
"command": "<2020 chars \u2014 see below>"
}cat > /tmp/test_benchmark_mimic.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
# Mimic benchmark.py exactly for shape 4
shape_idx = 4
shape = shapes.SHAPES[shape_idx]
apply_shape(shape)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sd = ref_model.state_dict()
try:
sol_model.load_state_dict(sd, strict=True)
except RuntimeError:
pass
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup exactly like time_fn
for _ in range(10):
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
# Time exactly like time_fn
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"shape={shape_idx} ms={statistics.median(times):.3f}")
# Also print num_kv_splits
head_groups = (shape["num_heads"] + min(8, shape["num_heads"]//shape["num_kv_heads"]) - 1) // min(8, shape["num_heads"]//shape["num_kv_heads"])
base_blocks = shape["batch"] * head_groups
target_blocks = 640 if shape["head_dim"] <= 64 else 512
num_kv_splits = max(1, min(24, target_blocks // base_blocks))
print(f" base_blocks={base_blocks} num_kv_splits={num_kv_splits}")
EOF
uv run python /tmp/test_benchmark_mimic.py<system>Command executed successfully.</system>
shape=4 ms=0.051
base_blocks=32 num_kv_splits=20
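Working the printed values back through the formula in the script: min(8, 16 // 4) = 4 query heads per block, head_groups = (16 + 4 - 1) // 4 = 4, base_blocks = 8 * 4 = 32; head_dim is 64 so target_blocks = 640, and num_kv_splits = max(1, min(24, 640 // 32)) = 20, which matches the "base_blocks=32 num_kv_splits=20" line above. The direct-call sweep that follows checks whether 20 splits actually beats 16 under this seed.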
{
"command": "<1967 chars \u2014 see below>"
}cat > /tmp/test_direct_20.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(2026)
B = shape["batch"]
H = shape["num_heads"]
Hkv = shape["num_kv_heads"]
D = shape["head_dim"]
L = shape["seq_len"]
P = shape["page_size"]
pages_per_seq = (L + P - 1) // P
total_pages = max(B * pages_per_seq + 8, 64)
query = torch.randn(B, H, D, dtype=torch.bfloat16, device=device) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16, device=device) * 0.1
block_table = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int().to(device)
seq_lens = torch.full((B,), L, dtype=torch.int32, device=device)
sm_scale = 1.0 / math.sqrt(D)
for num_kv_splits in [8, 16, 20]:
# Warmup
for _ in range(15):
solution.paged_attention_decode(query, kv_cache, block_table, seq_lens, sm_scale, P, num_kv_splits)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
solution.paged_attention_decode(query, kv_cache, block_table, seq_lens, sm_scale, P, num_kv_splits)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"splits={num_kv_splits} ms={statistics.median(times):.3f}")
EOF
uv run python /tmp/test_direct_20.py<system>Command executed successfully.</system>
splits=8 ms=0.043
splits=16 ms=0.034
splits=20 ms=0.037
{
"command": "<2823 chars \u2014 see below>"
}cat > /tmp/test_model_simple.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
init_args = reference.get_init_inputs()
# Test 1: Model without load_state_dict
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(15):
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"Model (no load_state_dict): ms={statistics.median(times):.3f}")
# Test 2: Model with load_state_dict
ref_model = reference.Model(*init_args).to(device).eval()
sol_model2 = solution.Model(*init_args).to(device).eval()
sd = ref_model.state_dict()
sol_model2.load_state_dict(sd, strict=True)
torch.cuda.synchronize()
for _ in range(15):
with torch.no_grad():
sol_model2(*inputs)
torch.cuda.synchronize()
times2 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model2(*inputs)
e.record()
torch.cuda.synchronize()
times2.append(s.elapsed_time(e))
print(f"Model (with load_state_dict): ms={statistics.median(times2):.3f}")
# Test 3: Direct call
sm_scale = 1.0 / math.sqrt(128)
num_kv_splits = 16
for _ in range(15):
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, shape["page_size"], num_kv_splits)
torch.cuda.synchronize()
times3 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, shape["page_size"], num_kv_splits)
e.record()
torch.cuda.synchronize()
times3.append(s.elapsed_time(e))
print(f"Direct: ms={statistics.median(times3):.3f}")
EOF
uv run python /tmp/test_model_simple.py<system>Command executed successfully.</system>
Model (no load_state_dict): ms=0.043
Model (with load_state_dict): ms=0.040
Direct: ms=0.034
{
"command": "<1488 chars \u2014 see below>"
}cat > /tmp/test_no_buffer.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
# Temporarily modify solution.Model to not use register_buffer
class TestModel(solution.Model):
def __init__(self, *args, **kwargs):
torch.nn.Module.__init__(self)
assert args[1] % args[2] == 0
self.batch = args[0]
self.num_heads = args[1]
self.num_kv_heads = args[2]
self.head_dim = args[3]
self.seq_len = args[4]
self.page_size = args[5]
self.group_size = args[1] // args[2]
import math
self.scale = 1.0 / math.sqrt(args[3])
# No buffer!
for shape in shapes.SHAPES:
apply_shape(shape)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
test_model = TestModel(*init_args).to(device).eval()
sd = ref_model.state_dict()
try:
test_model.load_state_dict(sd, strict=True)
print(f"load_state_dict passed for {shape}")
except RuntimeError as e:
print(f"load_state_dict FAILED for {shape}: {e}")
EOF
uv run python /tmp/test_no_buffer.py<system>Command executed successfully.</system>
load_state_dict passed for {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16}
load_state_dict passed for {'batch': 32, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 2048, 'page_size': 16}
load_state_dict passed for {'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16}
load_state_dict passed for {'batch': 16, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1535, 'page_size': 16}
load_state_dict passed for {'batch': 8, 'num_heads': 16, 'num_kv_heads': 4, 'head_dim': 64, 'seq_len': 2000, 'page_size': 16}
{
"command": "python -c \"import torch; print(torch.__version__); m = torch.nn.Linear(2,2); m.register_buffer('b', torch.zeros(1)); sd = m.state_dict(); m2 = torch.nn.Linear(2,2); m2.load_state_dict(sd, strict=True); print('loaded')\""
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2639, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Linear:
Unexpected key(s) in state_dict: "b".
2.11.0+cu130
{
"path": "reference.py"
}<system>144 lines read from file starting from line 1. Total lines in file: 144. End of file reached.</system>
1 """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2
3 Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4 and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5 a global pool. Pages for batch element b are listed in `block_table[b]`.
6
7 The reference performs the slow path:
8 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11
12 This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13 forbidden list) so the agent cannot dispatch through SDPA either.
14 """
15 import math
16
17 import torch
18 import torch.nn as nn
19
20 OP_TYPE = "attention"
21 SUPPORTED_PRECISIONS = ["bf16"]
22 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23
24
25 # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26 BATCH = 8
27 NUM_HEADS = 32
28 NUM_KV_HEADS = 8
29 HEAD_DIM = 128
30 SEQ_LEN = 1024
31 PAGE_SIZE = 16
32
33
34 class Model(nn.Module):
35 """Single-query paged attention decode.
36
... (4770 more chars){
"command": "<415 chars \u2014 see below>"
}python -c "
import sys
sys.path.insert(0, '.')
import reference
import shapes
reference.BATCH = 8
reference.NUM_HEADS = 32
reference.NUM_KV_HEADS = 8
reference.HEAD_DIM = 128
reference.SEQ_LEN = 1024
reference.PAGE_SIZE = 16
ref_model = reference.Model(*reference.get_init_inputs())
print('ref state_dict keys:', list(ref_model.state_dict().keys()))
print('ref buffers:', dict(ref_model.named_buffers()).keys())
"<system>Command executed successfully.</system>
ref state_dict keys: []
ref buffers: dict_keys(['_dummy'])
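So the reference registers `_dummy` as a buffer but keeps it out of its state_dict, which is consistent with a non-persistent buffer; the solution's Model can mirror that without tripping strict=True loading. A minimal sketch reduced to just the buffer handling (the `_dummy` name comes from the printout above; everything else is illustrative):

import torch
import torch.nn as nn

class SolutionModelSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict(), so
        # load_state_dict(ref_state_dict, strict=True) sees no unexpected keys.
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)

m = SolutionModelSketch()
assert list(m.state_dict().keys()) == []       # matches the reference printout above
assert "_dummy" in dict(m.named_buffers())     # buffer still moves with .to(device)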
{
"command": "<1548 chars \u2014 see below>"
}cat > /tmp/test_page_aligned.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(2026)
init_args = reference.get_init_inputs()
inputs = [t.to(device) for t in reference.get_inputs()]
sm_scale = 1.0 / math.sqrt(64)
# Test with aligned splits
for num_kv_splits in [4, 8, 16, 20]:
for _ in range(10):
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, 16, num_kv_splits)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, 16, num_kv_splits)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"splits={num_kv_splits}: ms={statistics.median(times):.3f}")
EOF
uv run python /tmp/test_page_aligned.py<system>Command executed successfully.</system>
splits=4: ms=0.053
splits=8: ms=0.042
splits=16: ms=0.043
splits=20: ms=0.033
{
"command": "<2201 chars \u2014 see below>"
}cat > /tmp/test_verify_splits.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(2026)
init_args = reference.get_init_inputs()
inputs = [t.to(device) for t in reference.get_inputs()]
# Create model and check num_kv_splits
sol_model = solution.Model(*init_args).to(device).eval()
print(f"Model num_kv_splits heuristic:")
print(f" head_dim={sol_model.head_dim}")
print(f" seq_len={sol_model.seq_len}")
print(f" batch={sol_model.batch}")
print(f" num_heads={sol_model.num_heads}")
base_blocks = sol_model.num_kv_splits if hasattr(sol_model, 'num_kv_splits') else None
print(f" base_blocks (num_kv_splits attr): {base_blocks}")
# Compute manually
kv_group_num = sol_model.num_heads // sol_model.num_kv_heads
head_groups = sol_model.num_heads // sol_model.num_kv_heads
base_blocks = (sol_model.batch * sol_model.num_kv_heads * head_groups * sol_model.seq_len) // sol_model.page_size
target_blocks = 640 if sol_model.head_dim <= 64 else 512
num_kv_splits = max(1, min(24, target_blocks // base_blocks))
print(f" Manual: base_blocks={base_blocks}, target_blocks={target_blocks}, num_kv_splits={num_kv_splits}")
# Time it
for _ in range(15):
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"Model time: ms={statistics.median(times):.3f}")
EOF
uv run python /tmp/test_verify_splits.py<system>Command executed successfully.</system>
Model num_kv_splits heuristic:
head_dim=64
seq_len=2000
batch=8
num_heads=16
base_blocks (num_kv_splits attr): None
Manual: base_blocks=16000, target_blocks=640, num_kv_splits=1
Model time: ms=0.052
{
"path": "solution.py"
}<system>463 lines read from file starting from line 1. Total lines in file: 463. End of file reached.</system>
1 """Triton-based paged attention decode kernel for packed KV cache layout.
2
3 The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2)
4 with K and V packed on the last dimension: [K | V].
5
6 This kernel uses a two-stage approach:
7 1. Stage 1 splits the KV sequence into NUM_KV_SPLITS parts, each block computes
8 partial attention with online softmax.
9 2. Stage 2 reduces partial results across splits.
10
11 For efficiency on the packed layout, each iteration loads K and V separately
12 using coalesced per-token loads.
13 """
14 import math
15
16 import torch
17 import torch.nn as nn
18 import triton
19 import triton.language as tl
20
21 OP_TYPE = "attention"
22 SUPPORTED_PRECISIONS = ["bf16"]
23 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
24
25 BATCH = 8
26 NUM_HEADS = 32
27 NUM_KV_HEADS = 8
28 HEAD_DIM = 128
29 SEQ_LEN = 1024
30 PAGE_SIZE = 16
31
32
33 # --------------------------------------------------------------------------- #
34 # Stage 1: compute partial attention per KV split (grouped for GQA)
35 # --------------------------------------------------------------------------- #
36 @triton.jit
37 def _fwd_kernel_stage1(
38 Q,
39 KV_Cache,
40
... (15211 more chars){
"command": "<3375 chars \u2014 see below>"
}cat > /tmp/test_debug_overhead.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(2026)
init_args = reference.get_init_inputs()
inputs = [t.to(device) for t in reference.get_inputs()]
# Test 1: Direct call
sm_scale = 1.0 / math.sqrt(64)
num_kv_splits = 20
for _ in range(15):
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, 16, num_kv_splits)
torch.cuda.synchronize()
times1 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, 16, num_kv_splits)
e.record()
torch.cuda.synchronize()
times1.append(s.elapsed_time(e))
print(f"Direct: ms={statistics.median(times1):.3f}")
# Test 2: Model.call with identical args
sol_model = solution.Model(*init_args).to(device).eval()
print(f"Model.scale={sol_model.scale}, direct_scale={sm_scale}, match={sol_model.scale==sm_scale}")
for _ in range(15):
with torch.no_grad():
sol_model.forward(inputs[0], inputs[1], inputs[2], inputs[3])
torch.cuda.synchronize()
times2 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model.forward(inputs[0], inputs[1], inputs[2], inputs[3])
e.record()
torch.cuda.synchronize()
times2.append(s.elapsed_time(e))
print(f"Model.forward: ms={statistics.median(times2):.3f}")
# Test 3: Model.__call__ with identical args
for _ in range(15):
with torch.no_grad():
sol_model(inputs[0], inputs[1], inputs[2], inputs[3])
torch.cuda.synchronize()
times3 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(inputs[0], inputs[1], inputs[2], inputs[3])
e.record()
torch.cuda.synchronize()
times3.append(s.elapsed_time(e))
print(f"Model.__call__: ms={statistics.median(times3):.3f}")
# Test 4: Model.__call__ without torch.no_grad
for _ in range(15):
sol_model(inputs[0], inputs[1], inputs[2], inputs[3])
torch.cuda.synchronize()
times4 = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
sol_model(inputs[0], inputs[1], inputs[2], inputs[3])
e.record()
torch.cuda.synchronize()
times4.append(s.elapsed_time(e))
print(f"Model.__call__ (no no_grad): ms={statistics.median(times4):.3f}")
EOF
uv run python /tmp/test_debug_overhead.py<system>Command executed successfully.</system>
Direct: ms=0.042
Model.scale=0.125, direct_scale=0.125, match=True
Model.forward: ms=0.034
Model.__call__: ms=0.035
Model.__call__ (no no_grad): ms=0.034
{
"command": "uv run python /tmp/test_debug_overhead.py"
}<system>Command executed successfully.</system>
Direct: ms=0.034
Model.scale=0.125, direct_scale=0.125, match=True
Model.forward: ms=0.051
Model.__call__: ms=0.035
Model.__call__ (no no_grad): ms=0.034
{
"command": "<2080 chars \u2014 see below>"
}cat > /tmp/test_random_order.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
import random
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[4]
apply_shape(shape)
torch.manual_seed(2026)
init_args = reference.get_init_inputs()
inputs = [t.to(device) for t in reference.get_inputs()]
sm_scale = 1.0 / math.sqrt(64)
num_kv_splits = 20
def run_test(name, fn, warmup=15, iters=30):
for _ in range(warmup):
with torch.no_grad():
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
fn()
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f"{name}: ms={statistics.median(times):.3f}")
sol_model = solution.Model(*init_args).to(device).eval()
# Run tests in random order multiple times
for i in range(5):
order = ['direct', 'forward', 'call']
random.shuffle(order)
print(f"\nRun {i+1}, order: {order}")
for name in order:
if name == 'direct':
run_test('direct', lambda: solution.paged_attention_decode(inputs[0], inputs[1], inputs[2], inputs[3], sm_scale, 16, num_kv_splits))
elif name == 'forward':
run_test('forward', lambda: sol_model.forward(inputs[0], inputs[1], inputs[2], inputs[3]))
else:
run_test('call', lambda: sol_model(inputs[0], inputs[1], inputs[2], inputs[3]))
EOF
uv run python /tmp/test_random_order.py<system>Command executed successfully.</system>
Run 1, order: ['direct', 'call', 'forward']
direct: ms=0.044
call: ms=0.035
forward: ms=0.034
Run 2, order: ['forward', 'call', 'direct']
forward: ms=0.034
call: ms=0.043
direct: ms=0.033
Run 3, order: ['forward', 'direct', 'call']
forward: ms=0.034
direct: ms=0.033
call: ms=0.047
Run 4, order: ['forward', 'call', 'direct']
forward: ms=0.035
call: ms=0.035
direct: ms=0.033
Run 5, order: ['forward', 'direct', 'call']
forward: ms=0.041
direct: ms=0.035
call: ms=0.034
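The ordering experiment shows the direct call, Model.forward, and Model.__call__ trading places from run to run, which points at measurement noise rather than Python wrapper overhead; the clock query that follows shows the SM sitting at 187 MHz between these sub-50 us kernels, so part of each timed sample is spent ramping clocks. One hedged mitigation (not something done in this session) is to burn a short busy loop immediately before each timed region so the clocks are already boosted:

import torch

def warm_clocks(target_ms: float = 50.0, device: str = "cuda") -> None:
    # Run dummy matmuls for roughly target_ms so SM clocks ramp before timing.
    a = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    elapsed = 0.0
    while elapsed < target_ms:
        a @ b
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)

Locking clocks via nvidia-smi (if permissions allow) would be the more direct fix.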
{
"command": "nvidia-smi -q -d CLOCK | head -40"
}<system>Command executed successfully.</system>
==============NVSMI LOG==============
Timestamp : Tue Apr 28 11:44:06 2026
Driver Version : 595.58.03
CUDA Version : 13.2
Attached GPUs : 1
GPU 00000000:01:00.0
Clocks
Graphics : 187 MHz
SM : 187 MHz
Memory : 405 MHz
Video : 607 MHz
Applications Clocks
Graphics : Requested functionality has been deprecated
Memory : Requested functionality has been deprecated
Default Applications Clocks
Graphics : Requested functionality has been deprecated
Memory : Requested functionality has been deprecated
Deferred Clocks
Memory : N/A
Max Clocks
Graphics : 3090 MHz
SM : 3090 MHz
Memory : 14001 MHz
Video : 3090 MHz
Max Customer Boost Clock
... (658 more chars){
"command": "<1562 chars \u2014 see below>"
}cat > /tmp/test_shape0.py << 'EOF'
import sys
sys.path.insert(0, '.')
import torch
import reference
import solution
import shapes
import math
import statistics
def apply_shape(shape):
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
device = torch.device("cuda:0")
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
shape = shapes.SHAPES[0]
apply_shape(shape)
torch.manual_seed(2026)
init_args = reference.get_init_inputs()
inputs = [t.to(device) for t in reference.get_inputs()]
sol_model = solution.Model(*init_args).to(device).eval()
for _ in range(15):
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(30):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
ms = statistics.median(times)
B, H, D = inputs[0].shape
Hkv = sol_model.num_kv_heads
L = sol_model.seq_len
bytes_moved = 2 * B * L * Hkv * D * 2 + B * H * D * 2 + B * H * D * 2 # KV + Q + Out in bytes
bw = bytes_moved / (ms / 1000) / 1e9
print(f"Shape 0: ms={ms:.3f}, BW={bw:.1f} GB/s, fraction={bw/1800:.3f}")
EOF
uv run python /tmp/test_shape0.py<system>Command executed successfully.</system>
Shape 0: ms=0.046, BW=729.2 GB/s, fraction=0.405
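As a sanity check on that figure: the KV traffic counted in the script is 2 * 8 * 1024 * 8 * 128 * 2 bytes, about 33.6 MB, plus roughly 0.13 MB for the query and output, so about 33.7 MB in 0.046 ms is roughly 730 GB/s, and 729.2 / 1800 is 0.405 of the assumed 1.8 TB/s peak. The printed number is internally consistent; whether 1800 GB/s is the right peak bandwidth for this GPU is an assumption baked into the script.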
{
"command": "uv run python check.py"
}<system>Command executed successfully.</system>
PASS
{
"command": "uv run python benchmark.py"
}<system>Command executed successfully.</system>
[TorchDynamo `Tensor.item()` graph-break warnings repeated, identical to the earlier benchmark.py runs]
... (1921 more chars)