"""Paged attention decode — Triton kernel, (B, H) grid for max parallelism.
Single-query decode: each batch element has a query (num_heads, head_dim) and
attends over a paged KV cache. Grid is (batch, num_heads). Each program handles
one query head, loads its KV data via page indirection, and computes attention
with online softmax in fp32.
"""
import math
import torch
import triton
import triton.language as tl
# Problem metadata consumed by the benchmark/correctness harness.
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Default shape knobs — check.py / benchmark.py override these per entry
# in shapes.py before building inputs.
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8  # GQA: NUM_HEADS // NUM_KV_HEADS query heads share each KV head
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16  # tokens per KV-cache page
@triton.jit
def _paged_attn_kernel(
    q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
    stride_qb, stride_qh,
    stride_kvb, stride_kvps, stride_kvkh,
    stride_btb,
    stride_ob, stride_oh,
    scale,
    HEAD_DIM: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
):
    """One program per (batch, query head): single-query paged attention.

    Walks the page list in `bt_ptr` row `batch`, accumulating
    softmax(q @ K^T * scale) @ V with a numerically stable online softmax
    in fp32, then stores the bf16 result vector for this head.

    The KV cache packs [K | V] along the last dimension of each page,
    so V lives at a +HEAD_DIM offset from K.
    """
    batch = tl.program_id(0)
    head = tl.program_id(1)
    # GQA: GROUP_SIZE consecutive query heads share one KV head.
    kv_head = head // GROUP_SIZE

    offs_d = tl.arange(0, HEAD_DIM)
    # Load the query once and reuse it against every page.
    q_f32 = tl.load(
        q_ptr + batch * stride_qb + head * stride_qh + offs_d
    ).to(tl.float32)

    L = tl.load(sl_ptr + batch)
    num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
    offs_t = tl.arange(0, PAGE_SIZE)

    # Online-softmax running state: max, denominator, weighted V sum.
    m_i = -float('inf')
    d_i = 0.0
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    for p in range(num_pages):
        block_id = tl.load(bt_ptr + batch * stride_btb + p)
        remaining = L - p * PAGE_SIZE
        tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
        mask = offs_t < tokens_in_page

        kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
        # other=0.0: without it, masked lanes are undefined in Triton and
        # may decode to NaN; NaN in V would poison the accumulator even
        # with a zero attention weight (0 * NaN == NaN).
        k_fp32 = tl.load(
            kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
            mask=mask[:, None],
            other=0.0,
        ).to(tl.float32)
        v_fp32 = tl.load(
            kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
            mask=mask[:, None],
            other=0.0,
        ).to(tl.float32)

        # Elementwise q.k per token in fp32 (no tl.dot: a single query row
        # is too thin for a Tensor-Core MMA to pay off).
        scores = tl.sum(q_f32[None, :] * k_fp32, axis=1) * scale
        # Tail tokens beyond seq_len get -inf so they vanish in softmax.
        scores = tl.where(mask, scores, float('-inf'))

        # Safe online-softmax update (Milakov & Gimelshein style).
        m_block = tl.max(scores, 0)
        m_new = tl.maximum(m_i, m_block)
        alpha = tl.exp(m_i - m_new)
        exp_scores = tl.exp(scores - m_new)
        acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)
        d_i = d_i * alpha + tl.sum(exp_scores)
        m_i = m_new

    out = acc / d_i
    tl.store(
        out_ptr + batch * stride_ob + head * stride_oh + offs_d,
        out.to(tl.bfloat16),
    )
class Model(torch.nn.Module):
    """Paged-attention decode: thin launcher around `_paged_attn_kernel`.

    Holds the shape configuration and launches one Triton program per
    (batch element, query head) pair in `forward`.
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        # GQA group width: query heads served by each KV head.
        self.group_size = num_heads // num_kv_heads
        # Standard 1/sqrt(d) attention scaling.
        self.scale = 1.0 / math.sqrt(head_dim)
        # Non-persistent buffer so the module tracks device/dtype moves.
        self.register_buffer(
            "_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False
        )

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Run decode attention; returns a (batch, num_heads, head_dim) tensor."""
        n_batch, n_heads, dim = query.shape
        out = torch.empty(
            n_batch, n_heads, dim, dtype=query.dtype, device=query.device
        )
        # One program per (batch, head); strides are passed explicitly so
        # non-contiguous inputs still index correctly.
        _paged_attn_kernel[(n_batch, n_heads)](
            query, kv_cache, block_table, seq_lens, out,
            query.stride(0), query.stride(1),
            kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
            block_table.stride(0),
            out.stride(0), out.stride(1),
            self.scale,
            HEAD_DIM=dim,
            PAGE_SIZE=self.page_size,
            GROUP_SIZE=self.group_size,
            num_warps=2,
        )
        return out
def get_inputs():
    """Build one random problem instance from the module-level shape knobs.

    Returns [query, kv_cache, block_table, seq_lens]; pages for each batch
    element are drawn without replacement from a global page pool.
    """
    pages_per_seq = -(-SEQ_LEN // PAGE_SIZE)  # ceil(seq_len / page_size)
    total_pages = max(BATCH * pages_per_seq + 8, 64)

    # Small magnitude (x0.1) keeps bf16 attention numerically tame.
    query = torch.randn(BATCH, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) * 0.1
    kv_cache = (
        torch.randn(
            total_pages, PAGE_SIZE, NUM_KV_HEADS, 2 * HEAD_DIM,
            dtype=torch.bfloat16,
        )
        * 0.1
    )
    # Distinct random page ids per sequence, laid out row-per-batch.
    page_ids = torch.randperm(total_pages)[: BATCH * pages_per_seq]
    block_table = page_ids.reshape(BATCH, pages_per_seq).int().contiguous()
    seq_lens = torch.full((BATCH,), SEQ_LEN, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]
def get_init_inputs():
    """Constructor arguments for Model, mirroring the module shape knobs."""
    shape_args = (BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE)
    return list(shape_args)
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:52:14.578000 1127683 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=0 variant=eager tflops=0.230 gbps=57.689 ms=0.584
shape=0 variant=compiled tflops=0.226 gbps=56.794 ms=0.593
shape=0 variant=solution tflops=1.101 gbps=276.292 ms=0.122
shape=0 solution_peak_fraction=0.1535
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=1 variant=eager tflops=0.349 gbps=87.493 ms=3.074
shape=1 variant=compiled tflops=0.348 gbps=87.257 ms=3.082
shape=1 variant=solution tflops=3.768 gbps=943.798 ms=0.285
shape=1 solution_peak_fraction=0.5243
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=2 variant=eager tflops=0.244 gbps=30.559 ms=2.200
shape=2 variant=compiled tflops=0.243 gbps=30.393 ms=2.212
shape=2 variant=solution tflops=1.332 gbps=166.858 ms=0.403
shape=2 solution_peak_fraction=0.0927
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=3 variant=eager tflops=0.303 gbps=76.039 ms=1.326
shape=3 variant=compiled tflops=0.301 gbps=75.450 ms=1.337
shape=3 variant=solution tflops=2.134 gbps=534.896 ms=0.189
shape=3 solution_peak_fraction=0.2972
[sota unavailable] ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
shape=4 variant=eager tflops=0.113 gbps=28.338 ms=0.579
shape=4 variant=compiled tflops=0.112 gbps=27.945 ms=0.587
shape=4 variant=solution tflops=0.425 gbps=106.359 ms=0.154
shape=4 solution_peak_fraction=0.0591
peak_fraction: 0.1673
RESULT: OK
PASS
{
"problem": "03_paged_attention",
"harness": "opencode",
"model": "deepseek/deepseek-v4-flash",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1673,
"elapsed_seconds": 1676,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 82842, "output_tokens": 24493, "cache_read_tokens": 6434688, "cache_creation_tokens": 0, "reasoning_tokens": 85324, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py</path>
<type>file</type>
<content>
1: """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2:
3: Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4: and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5: a global pool. Pages for batch element b are listed in `block_table[b]`.
6:
7: The reference performs the slow path:
8: 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9: 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10: 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11:
12: This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13: forbidden list) so the agent cannot dispatch through SDPA either.
14: """
15: import math
16:
17: import torch
18: import torch.nn as nn
19:
20: OP_TYPE = "attention"
21: SUPPORTED_PRECISIONS = ["bf16"]
22: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23:
24:
25: # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26: BATCH = 8
27: NUM_HEADS = 32
28: NUM_KV_HEADS = 8
29: HEAD_DIM = 128
30: SEQ_LEN = 1024
31: PAGE_SIZE = 16
32:
33:
34: class Model(nn.Module):
35: """Single-query paged attention decode.
36:
37: Forward inputs (all on device):
38: query: (batch, num_heads, head_dim) bf16
3
... (4426 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for paged attention decode.
2:
3: Mix targets:
4: - small batch / long context (Llama-3 8B-style decode)
5: - large batch / medium context (server batched decode)
6: - GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
7: - non-power-of-2 seq_len (forces predicated tail handling)
8: - head_dim=64 small-head case
9: """
10:
11: SHAPES = [
12: # (B, H, Hkv, D, L, P)
13: {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
14: {"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
15: {"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
16: {"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
17: {"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
18: ]
(End of file - total 18 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for paged-attention decode.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: REPO_ROOT = Path(__file__).resolve().parents[2]
14: sys.path.insert(0, str(REPO_ROOT))
15:
16: from src.eval.correctness import check_correctness # noqa: E402
17:
18:
19: def _apply_shape(reference, shape: dict) -> None:
20: reference.BATCH = shape["batch"]
21: reference.NUM_HEADS = shape["num_heads"]
22: reference.NUM_KV_HEADS = shape["num_kv_heads"]
23: reference.HEAD_DIM = shape["head_dim"]
24: reference.SEQ_LEN = shape["seq_len"]
25: reference.PAGE_SIZE = shape["page_size"]
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
41: for forbidden in meta.get("forbidden", []):
42: pat = re.escape(forbidden)
43:
... (2537 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for paged-attention decode.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
7: the GPU's peak DRAM bandwidth.
8: """
9: import sys
10: from math import exp, log
11: from pathlib import Path
12:
13: import torch
14: import yaml
15:
16: REPO_ROOT = Path(__file__).resolve().parents[2]
17: sys.path.insert(0, str(REPO_ROOT))
18:
19: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
20: from src.eval.timing import time_fn # noqa: E402
21: from src.hardware import get as get_hw # noqa: E402
22:
23:
24: def _eval_formula(expr: str, vars: dict) -> float:
25: return float(eval(expr, {"__builtins__": {}}, vars))
26:
27:
28: def _apply_shape(reference, shape: dict) -> None:
29: reference.BATCH = shape["batch"]
30: reference.NUM_HEADS = shape["num_heads"]
31: reference.NUM_KV_HEADS = shape["num_kv_heads"]
32: reference.HEAD_DIM = shape["head_dim"]
33: reference.SEQ_LEN = shape["seq_len"]
34: reference.PAGE_SIZE = shape["page_size"]
35:
36:
37: def main():
38: import reference
39: import shapes
40: import solution
41:
42: meta = yaml.safe_load(Path("problem.y
... (3563 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml</path>
<type>file</type>
<content>
1: name: 03_paged_attention
2: display_name: "Paged Attention Decode"
3: precision: bf16
4: regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
5:
6: # Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
7: # Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
8: # 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
9: flops_formula: "4 * batch * num_heads * seq_len * head_dim"
10:
11: # Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
12: # K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
13: # Q: batch * num_heads * head_dim * 2 (negligible vs KV)
14: # Out: same as Q
15: # So total ~ 2 * (B * L * Hkv * D * 2) + small.
16: bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
21:
22: tolerance:
23: bfloat16: 0.02
24: float16: 0.02
25:
26: # Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
27: # vllm.attention / flashinfer.batch_decode_with_paged_kv_cache: dispatching to
28: # the SOTA library is the trivial cheat. SD
... (875 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for paged-attention decode.
2:
3: Tries, in order:
4: 1. FlashInfer's BatchDecodeWithPagedKVCacheWrapper (preferred -- portable,
5: supports SM120, GQA, arbitrary head_dim).
6: 2. vLLM's paged_attention_v2 CUDA op (requires its KV-cache layout, more
7: finicky; we adapt the layout on the fly when possible).
8:
9: If neither is importable, is_available() returns False and the benchmark just
10: reports eager + compiled + solution.
11:
12: Agents are FORBIDDEN from importing these in solution.py (see problem.yaml).
13: This file is only for the benchmark's reference line.
14: """
15: from __future__ import annotations
16:
17: import torch
18:
19:
20: def _try_flashinfer(
21: query: torch.Tensor,
22: kv_cache: torch.Tensor,
23: block_table: torch.Tensor,
24: seq_lens: torch.Tensor,
25: num_kv_heads: int,
26: head_dim: int,
27: page_size: int,
28: ) -> torch.Tensor | None:
29: try:
30: import flashinfer # noqa: F401
31: from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
32: except Exception:
33: return None
34:
35: B, H, D = query.shape
36: # FlashInfer expects K and V as separate (num_blocks, page_size, num_kv_heads, head_dim) tensors.
37: # Our reference packs [K|V] on the last dim -- split here.
38: k_cache = kv_cache[..., :D].contig
... (1803 more chars){
"command": "python3 -c \"import triton; print(triton.__version__); print(triton.__file__)\"",
"description": "Check Triton version"
}3.6.0
/home/infatoshi/.local/lib/python3.12/site-packages/triton/__init__.py
{
"command": "nvidia-smi",
"description": "Check GPU status"
}Mon Apr 27 08:27:02 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 26C P8 10W / 600W | 383MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (281 more chars){
"pattern": "src/eval/*.py"
}No files found{
"command": "python3 -c \"import torch; print(torch.cuda.get_device_capability()); print(torch.__version__)\"",
"description": "Check CUDA compute capability"
}(12, 0)
2.11.0+cu130
{
"command": "python3 -c \"import triton; print(triton.backends.nvidia.gpu_archs)\"",
"description": "Check Triton supported GPU archs"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'triton.backends.nvidia' has no attribute 'gpu_archs'
{
"description": "Check Triton SM120 support",
"command": "<352 chars \u2014 see below>"
}python3 -c "
import triton
# Check supported GPUs
from triton.backends.nvidia.driver import get_device_properties
props = get_device_properties(0)
print('SM:', props.major, props.minor)
print('Name:', props.name)
print('Cores:', props.multi_processor_count)
print('Shared mem per SM:', props.shared_memory_per_sm)
print('Warp size:', props.warp_size)
"Traceback (most recent call last):
File "<string>", line 4, in <module>
ImportError: cannot import name 'get_device_properties' from 'triton.backends.nvidia.driver' (/home/infatoshi/.local/lib/python3.12/site-packages/triton/backends/nvidia/driver.py)
{
"description": "Find hardware.py and eval modules",
"command": "find /home/infatoshi/cuda/KernelBench-Hard/src -name \"*.py\" | head -30"
}/home/infatoshi/cuda/KernelBench-Hard/src/eval/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/report.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/shapes.py
/home/infatoshi/cuda/KernelBench-Hard/src/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/m4_max.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/html.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/events.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/diff_util.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/kimi.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/opencode.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/cursor.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/claude.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/droid.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/codex.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/__main__.py
/home/infatoshi/cuda/KernelBench-Hard/src/sandbox/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/harness/__init__.py
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"todos": [
{
"content": "Write initial Triton paged attention kernel (solution.py)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py to verify correctness across all shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to measure performance",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize kernel based on results",
"status": "pending",
"priority": "medium"
}
]
}content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel based on results'
status: 'pending'
priority: 'medium'{
"pattern": "problems/03_paged_attention/solution*"
}No files found{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/",
"description": "Check for existing solution.py"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -0,0 +1,199 @@
+"""Paged attention decode — Triton kernel with online softmax + Tensor Cores.
+
+Single-query decode: each batch element has a query (num_heads, head_dim) and
+attends over a paged KV cache. The grid is (batch, num_kv_heads) so that KV
+data is loaded once and shared across all query heads in the same GQA group.
+"""
+
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+@triton.jit
+def _paged_attn_kernel(
+ q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
+ stride_qb, stride_qh,
+ stride_kvb, stride_kvps, stride_kvkh,
+ stride_btb,
+ stride_ob, stride_oh,
+ scale,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ batch = tl.program_id(0)
+ kv_head = tl.program_id(1)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+
+ L = tl.load(sl_ptr + batch)
+ num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
+
+ offs_t = tl.arange(0, PAGE_SIZE)
+
+ # Per-query-head online-softmax state across all pages
+ m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
+ d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
+
+ for p in range(num_pages):
+ block_id = tl.load(bt_ptr + batch * stride_btb + p)
+ remaining = L - p * PAGE_SIZE
+ tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
+ mask = offs_t < tokens_in_page
+
+ kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
+
+ # Load K (bf16 — stays bf16 for Tensor-Core dot) and V (cast to fp32)
+ k_bf16 = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
+ mask=mask[:, None],
+ )
+ v_fp32 = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
+ mask=mask[:, None],
+ ).to(tl.float32)
+
+ # Process each query head in the GQA group
+ for g in range(GROUP_SIZE):
+ head = kv_head * GROUP_SIZE + g
+
+ q_bf16 = tl.load(
+ q_ptr + batch * stride_qb + head * stride_qh + offs_d
+ )
+
+ # Dot product via Tensor Cores (bf16 in, fp32 out)
+ q_2d = q_bf16[None, :]
+ k_t = tl.trans(k_bf16)
+ scores_2d = tl.dot(q_2d, k_t)
+ scores = scores_2d[0, :] * scale
+ scores = tl.where(mask, scores, float('-inf'))
+
+ # Online safe softmax
+ m_block = tl.max(scores, 0)
+
+ g_mask_1d = tl.arange(0, GROUP_SIZE) == g
+ m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
+ d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+
+ m_new = tl.maximum(m_i_g, m_block)
+
+ alpha = tl.exp(m_i_g - m_new)
+ exp_scores = tl.exp(scores - m_new)
+ block_d = tl.sum(exp_scores)
+
+ acc_g = tl.sum(
+ acc * g_mask_1d[:, None].to(tl.float32), axis=0
+ )
+ new_acc_g = acc_g * alpha + tl.sum(
+ exp_scores[:, None] * v_fp32, axis=0
+ )
+ new_d_g = d_i_g * alpha + block_d
+
+ # Commit updated state for this head
+ m_i = tl.where(g_mask_1d, m_new, m_i)
+ d_i = tl.where(g_mask_1d, new_d_g, d_i)
+ acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc)
+
+ # Write results
+ for g in range(GROUP_SIZE):
+ head = kv_head * GROUP_SIZE + g
+ g_mask_1d = tl.arange(0, GROUP_SIZE) == g
+ d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+ acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
+ result = acc_g / d_i_g
+ tl.store(
+ out_ptr + batch * stride_ob + head * stride_oh + offs_d,
+ result.to(tl.bfloat16),
+ )
+
+
+class Model(torch.nn.Module):
+ def __init__(
+ self,
+ batch: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ seq_len: int,
+ page_size: int,
+ ):
+ super().__init__()
+ assert num_heads % num_kv_heads == 0
+ self.batch = batch
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ self.seq_len = seq_len
+ self.page_size = page_size
+ self.group_size = num_heads // num_kv_heads
+ self.scale = 1.0 / math.sqrt(head_dim)
+ self.register_buffer(
+ "_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False
+ )
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ B, H, D = query.shape
+ Hkv = self.num_kv_heads
+ P = self.page_size
+ G = self.group_size
+
+ out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
+
+ grid = (B, Hkv)
+ _paged_attn_kernel[grid](
+ query, kv_cache, block_table, seq_lens, out,
+ query.stride(0), query.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
+ block_table.stride(0),
+ out.stride(0), out.stride(1),
+ self.scale,
+ HEAD_DIM=D,
+ PAGE_SIZE=P,
+ GROUP_SIZE=G,
+ )
+
+ return out
+
+
+def get_inputs():
+ B = BATCH
+ H = NUM_HEADS
+ Hkv = NUM_KV_HEADS
+ D = HEAD_DIM
+ L = SEQ_LEN
+ P = PAGE_SIZE
+ pages_per_seq = (L + P - 1) // P
+ total_pages = max(B * pages_per_seq + 8, 64)
+ query = torch.randn(B, H, D, dtype=torch.bfloat16)
... (448 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"content": "<6133 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: constexpr[0]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 164, in forward
_paged_attn_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs:
... (1767 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"offset": 30,
"limit": 80
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
30: stride_kvb, stride_kvps, stride_kvkh,
31: stride_btb,
32: stride_ob, stride_oh,
33: scale,
34: HEAD_DIM: tl.constexpr,
35: PAGE_SIZE: tl.constexpr,
36: GROUP_SIZE: tl.constexpr,
37: ):
38: batch = tl.program_id(0)
39: kv_head = tl.program_id(1)
40:
41: offs_d = tl.arange(0, HEAD_DIM)
42:
43: L = tl.load(sl_ptr + batch)
44: num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
45:
46: offs_t = tl.arange(0, PAGE_SIZE)
47:
48: # Per-query-head online-softmax state across all pages
49: m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
50: d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
51: acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
52:
53: for p in range(num_pages):
54: block_id = tl.load(bt_ptr + batch * stride_btb + p)
55: remaining = L - p * PAGE_SIZE
56: tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
57: mask = offs_t < tokens_in_page
58:
59: kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
60:
61: # Load K (bf16 — stays bf16 for Tensor-Core dot) and V (cast to fp32)
62: k_bf16 = tl.load(
63: kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
64: mask=mask[:, None],
65: )
66: v_fp32 = tl.load(
67: kv_base + of
... (1738 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -58,11 +58,11 @@
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
- # Load K (bf16 — stays bf16 for Tensor-Core dot) and V (cast to fp32)
- k_bf16 = tl.load(
+ # Load K and V in fp32 for elementwise dot products
+ k_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
- )
+ ).to(tl.float32)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
@@ -72,15 +72,12 @@
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
- q_bf16 = tl.load(
+ q_fp32 = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
- )
+ ).to(tl.float32)
- # Dot product via Tensor Cores (bf16 in, fp32 out)
- q_2d = q_bf16[None, :]
- k_t = tl.trans(k_bf16)
- scores_2d = tl.dot(q_2d, k_t)
- scores = scores_2d[0, :] * scale
+ # Elementwise dot product, fp32 accumulate
+ scores = tl.sum(q_fp32[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))
# Online safe softmax
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<982 chars \u2014 see below>",
"newString": "<899 chars \u2014 see below>"
} kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
# Load K (bf16 — stays bf16 for Tensor-Core dot) and V (cast to fp32)
k_bf16 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
# Process each query head in the GQA group
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
q_bf16 = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
)
# Dot product via Tensor Cores (bf16 in, fp32 out)
q_2d = q_bf16[None, :]
k_t = tl.trans(k_bf16)
scores_2d = tl.dot(q_2d, k_t)
scores = scores_2d[0, :] * scale
scores = tl.where(mask, scores, float('-inf')) kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
# Load K and V in fp32 for elementwise dot products
k_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
).to(tl.float32)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
# Process each query head in the GQA group
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
q_fp32 = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
).to(tl.float32)
# Elementwise dot product, fp32 accumulate
scores = tl.sum(q_fp32[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"description": "Run correctness check after fix",
"timeout": 120000
}FAIL: shape 0 {'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16} seed 42: solution contains NaN
{
"command": "<1088 chars \u2014 see below>",
"description": "Debug kernel with simple test",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import torch
import math
import solution
import reference
torch.manual_seed(42)
# Use a single shape
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
sol = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
# Load state dict
sd = ref.state_dict()
sol.load_state_dict(sd, strict=True)
# Get inputs
reference.BATCH, reference.NUM_HEADS = B, H
reference.NUM_KV_HEADS, reference.HEAD_DIM = Hkv, D
reference.SEQ_LEN, reference.PAGE_SIZE = L, P
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref(*inputs)
print('Ref output range:', ref_out.min().item(), ref_out.max().item())
print('Ref has NaN:', torch.isnan(ref_out).any().item())
sol_out = sol(*inputs)
print('Sol output range:', sol_out.min().item(), sol_out.max().item())
print('Sol has NaN:', torch.isnan(sol_out).any().item())
max_diff = (ref_out.float() - sol_out.float()).abs().max().item()
print('Max diff:', max_diff)
" 2>&1Ref output range: -0.01220703125 0.01141357421875
Ref has NaN: False
Sol output range: nan nan
Sol has NaN: True
Max diff: nan
{
"command": "<537 chars \u2014 see below>",
"description": "Test basic Triton kernel on SM120",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import torch
import triton
import triton.language as tl
# Minimal test: just load and store
@triton.jit
def test_kernel(x_ptr, y_ptr, n: tl.constexpr):
pid = tl.program_id(0)
offs = tl.arange(0, n)
x = tl.load(x_ptr + offs)
tl.store(y_ptr + offs, x + 1.0)
x = torch.ones(128, device='cuda', dtype=torch.float32)
y = torch.zeros(128, device='cuda', dtype=torch.float32)
test_kernel[(1,)](x, y, 128)
print('kernel works:', y[:4])
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 7, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
{
"command": "<968 chars \u2014 see below>",
"description": "Test with GROUP_SIZE right on command",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import sys; sys.path.insert(0, '.')
import solution
# Check the kernel works by running with GROUP_SIZE=1
import torch
B, H, Hkv, D, L, P = 2, 4, 2, 128, 64, 16
ref = __import__('reference').Model(B, H, Hkv, D, L, P).cuda().eval()
sol = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
import reference as ref_mod
ref_mod.BATCH = B
ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv
ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L
ref_mod.PAGE_SIZE = P
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
with torch.no_grad():
ref_out = ref(*inputs)
sol_out = sol(*inputs)
print('Ref has NaN:', torch.isnan(ref_out).any().item())
print('Sol has NaN:', torch.isnan(sol_out).any().item())
if not torch.isnan(sol_out).any():
print('Max diff:', (ref_out.float() - sol_out.float()).abs().max().item())
" 2>&1Ref has NaN: False
Sol has NaN: True
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
@@ -0,0 +1,141 @@
+"""Debug script for paged attention kernel."""
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+
+# Step 1: Minimal Triton kernel on SM120
+@triton.jit
+def test_kernel(x_ptr, y_ptr, N: tl.constexpr):
+ pid = tl.program_id(0)
+ offs = tl.arange(0, N)
+ x = tl.load(x_ptr + offs)
+ tl.store(y_ptr + offs, x + 1.0)
+
+
+def test_basic():
+ N = 128
+ x = torch.ones(N, device='cuda', dtype=torch.float32)
+ y = torch.zeros(N, device='cuda', dtype=torch.float32)
+ test_kernel[(1,)](x, y, N=N)
+ print(f"Basic kernel works: {y[:4].tolist()}")
+
+
+# Step 2: Simple paged attention kernel (1 head, 1 batch)
+@triton.jit
+def simple_paged_attn_kernel(
+ q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
+ stride_qb, stride_qh,
+ stride_kvb, stride_kvps, stride_kvkh,
+ stride_btb,
+ stride_ob, stride_oh,
+ scale,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+):
+ batch = tl.program_id(0)
+ head = tl.program_id(1)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ q = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d).to(tl.float32)
+
+ L = tl.load(sl_ptr + batch)
+ num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
+
+ m_i = -float('inf')
+ d_i = 0.0
+ acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
+
+ offs_t = tl.arange(0, PAGE_SIZE)
+
+ for p in range(num_pages):
+ block_id = tl.load(bt_ptr + batch * stride_btb + p)
+ remaining = L - p * PAGE_SIZE
+ tokens_in_page = PAGE_SIZE
+ if remaining < PAGE_SIZE:
+ tokens_in_page = remaining
+ mask = offs_t < tokens_in_page
+
+ kv_base = kv_ptr + block_id * stride_kvb + head * stride_kvkh
+
+ k = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
+ mask=mask[:, None],
+ ).to(tl.float32)
+ v = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
+ mask=mask[:, None],
+ ).to(tl.float32)
+
+ scores = tl.sum(q[None, :] * k, axis=1) * scale
+ scores = tl.where(mask, scores, float('-inf'))
+
+ m_block = tl.max(scores, 0)
+ m_new = tl.maximum(m_i, m_block)
+
+ alpha = tl.exp(m_i - m_new)
+ exp_scores = tl.exp(scores - m_new)
+ block_d = tl.sum(exp_scores)
+
+ acc = acc * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
+ d_i = d_i * alpha + block_d
+ m_i = m_new
+
+ out = acc / d_i
+ tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d, out.to(tl.bfloat16))
+
+
+def test_simple_paged():
+ B, H, Hkv, D, L, P = 2, 1, 1, 128, 64, 16
+ scale = 1.0 / math.sqrt(D)
+
+ pages_per_seq = (L + P - 1) // P
+ total_pages = B * pages_per_seq + 4
+
+ torch.manual_seed(42)
+ query = torch.randn(B, H, D, device='cuda', dtype=torch.bfloat16) * 0.1
+ kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, device='cuda', dtype=torch.bfloat16) * 0.1
+ block_table = torch.arange(B * pages_per_seq, dtype=torch.int32, device='cuda').reshape(B, pages_per_seq)
+ seq_lens = torch.full((B,), L, dtype=torch.int32, device='cuda')
+
+ out = torch.empty(B, H, D, device='cuda', dtype=torch.bfloat16)
+
+ grid = (B, H)
+ simple_paged_attn_kernel[grid](
+ query, kv_cache, block_table, seq_lens, out,
+ query.stride(0), query.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
+ block_table.stride(0),
+ out.stride(0), out.stride(1),
+ scale,
+ HEAD_DIM=D, PAGE_SIZE=P,
+ )
+
+ ref_out = torch.zeros_like(out)
+ for b in range(B):
+ L_val = int(seq_lens[b].item())
+ num_p = (L_val + P - 1) // P
+ pages = block_table[b, :num_p].long()
+ kv = kv_cache.index_select(0, pages)
+ kv = kv.reshape(num_p * P, Hkv, 2 * D)[:L_val]
+ k = kv[..., :D]
+ v = kv[..., D:]
+ q = query[b]
+ qf = q.float()
+ kf = k.float()
+ vf = v.float()
+ scores = torch.einsum("hd,lhd->hl", qf, kf) * scale
+ probs = torch.softmax(scores, dim=-1)
+ o = torch.einsum("hl,lhd->hd", probs, vf)
+ ref_out[b] = o.to(query.dtype)
+
+ print(f"Simple paged has NaN: {torch.isnan(out).any().item()}")
+ max_diff = (ref_out.float() - out.float()).abs().max().item()
+ print(f"Max diff: {max_diff}")
+
+
+if __name__ == "__main__":
+ test_basic()
+ test_simple_paged()
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py",
"content": "<4318 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python debug_test.py 2>&1",
"description": "Run debug test",
"timeout": 60000
}Basic kernel works: [2.0, 2.0, 2.0, 2.0]
Simple paged has NaN: False
Max diff: 0.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
@@ -136,6 +136,143 @@
print(f"Max diff: {max_diff}")
+@triton.jit
+def gqa_paged_attn_kernel(
+ q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
+ stride_qb, stride_qh,
+ stride_kvb, stride_kvps, stride_kvkh,
+ stride_btb,
+ stride_ob, stride_oh,
+ scale,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ batch = tl.program_id(0)
+ kv_head = tl.program_id(1)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+
+ L = tl.load(sl_ptr + batch)
+ num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
+
+ offs_t = tl.arange(0, PAGE_SIZE)
+
+ m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
+ d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
+
+ for p in range(num_pages):
+ block_id = tl.load(bt_ptr + batch * stride_btb + p)
+ remaining = L - p * PAGE_SIZE
+ tokens_in_page = PAGE_SIZE
+ if remaining < PAGE_SIZE:
+ tokens_in_page = remaining
+ mask = offs_t < tokens_in_page
+
+ kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
+
+ k = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
+ mask=mask[:, None],
+ ).to(tl.float32)
+ v = tl.load(
+ kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
+ mask=mask[:, None],
+ ).to(tl.float32)
+
+ for g in range(GROUP_SIZE):
+ head = kv_head * GROUP_SIZE + g
+ q = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d).to(tl.float32)
+
+ scores = tl.sum(q[None, :] * k, axis=1) * scale
+ scores = tl.where(mask, scores, float('-inf'))
+
+ m_block = tl.max(scores, 0)
+
+ g_mask_1d = tl.arange(0, GROUP_SIZE) == g
+ m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
+ d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+
+ m_new = tl.maximum(m_i_g, m_block)
+
+ alpha = tl.exp(m_i_g - m_new)
+ exp_scores = tl.exp(scores - m_new)
+ block_d = tl.sum(exp_scores)
+
+ acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
+ new_acc_g = acc_g * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
+ new_d_g = d_i_g * alpha + block_d
+
+ m_i = tl.where(g_mask_1d, m_new, m_i)
+ d_i = tl.where(g_mask_1d, new_d_g, d_i)
+ acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc)
+
+ for g in range(GROUP_SIZE):
+ head = kv_head * GROUP_SIZE + g
+ g_mask_1d = tl.arange(0, GROUP_SIZE) == g
+ d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+ acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
+ result = acc_g / d_i_g
+ tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d,
+ result.to(tl.bfloat16))
+
+
+def test_gqa_paged():
+ B, H, Hkv, D, L, P = 2, 4, 2, 128, 64, 16
+ G = H // Hkv
+ scale = 1.0 / math.sqrt(D)
+
+ pages_per_seq = (L + P - 1) // P
+ total_pages = B * pages_per_seq + 4
+
+ torch.manual_seed(42)
+ query = torch.randn(B, H, D, device='cuda', dtype=torch.bfloat16) * 0.1
+ kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, device='cuda', dtype=torch.bfloat16) * 0.1
+ block_table = torch.arange(B * pages_per_seq, dtype=torch.int32, device='cuda').reshape(B, pages_per_seq)
+ seq_lens = torch.full((B,), L, dtype=torch.int32, device='cuda')
+
+ out = torch.empty(B, H, D, device='cuda', dtype=torch.bfloat16)
+
+ grid = (B, Hkv)
+ gqa_paged_attn_kernel[grid](
+ query, kv_cache, block_table, seq_lens, out,
+ query.stride(0), query.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
+ block_table.stride(0),
+ out.stride(0), out.stride(1),
+ scale,
+ HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
+ )
+
+ print(f"GQA paged has NaN: {torch.isnan(out).any().item()}")
+
+ # Reference
+ ref_out = torch.zeros_like(out)
+ for b in range(B):
+ L_val = int(seq_lens[b].item())
+ num_p = (L_val + P - 1) // P
+ pages = block_table[b, :num_p].long()
+ kv = kv_cache.index_select(0, pages)
+ kv = kv.reshape(num_p * P, Hkv, 2 * D)[:L_val]
+ k = kv[..., :D]
+ v = kv[..., D:]
+ kv_head_repeated = k.repeat_interleave(G, dim=1)
+ v_head_repeated = v.repeat_interleave(G, dim=1)
+ q = query[b]
+ qf = q.float()
+ kf = kv_head_repeated.float()
+ vf = v_head_repeated.float()
+ scores = torch.einsum("hd,lhd->hl", qf, kf) * scale
+ probs = torch.softmax(scores, dim=-1)
+ o = torch.einsum("hl,lhd->hd", probs, vf)
+ ref_out[b] = o.to(query.dtype)
+
+ max_diff = (ref_out.float() - out.float()).abs().max().item()
+ print(f"GQA Max diff: {max_diff}")
+
+
if __name__ == "__main__":
test_basic()
test_simple_paged()
+ test_gqa_paged()
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py",
"oldString": "<67 chars \u2014 see below>",
"newString": "<4848 chars \u2014 see below>"
}if __name__ == "__main__":
test_basic()
test_simple_paged()@triton.jit
def gqa_paged_attn_kernel(
q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
stride_qb, stride_qh,
stride_kvb, stride_kvps, stride_kvkh,
stride_btb,
stride_ob, stride_oh,
scale,
HEAD_DIM: tl.constexpr,
PAGE_SIZE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
batch = tl.program_id(0)
kv_head = tl.program_id(1)
offs_d = tl.arange(0, HEAD_DIM)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
remaining = L - p * PAGE_SIZE
tokens_in_page = PAGE_SIZE
if remaining < PAGE_SIZE:
tokens_in_page = remaining
mask = offs_t < tokens_in_page
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
k = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
).to(tl.float32)
v = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
q = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d).to(tl.float32)
scores = tl.sum(q[None, :] * k, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))
m_block = tl.max(scores, 0)
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
m_new = tl.maximum(m_i_g, m_block)
alpha = tl.exp(m_i_g - m_new)
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
new_acc_g = acc_g * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
new_d_g = d_i_g * alpha + block_d
m_i = tl.where(g_mask_1d, m_new, m_i)
d_i = tl.where(g_mask_1d, new_d_g, d_i)
acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc)
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
result = acc_g / d_i_g
tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d,
result.to(tl.bfloat16))
def test_gqa_paged():
B, H, Hkv, D, L, P = 2, 4, 2, 128, 64, 16
G = H // Hkv
scale = 1.0 / math.sqrt(D)
pages_per_seq = (L + P - 1) // P
total_pages = B * pages_per_seq + 4
torch.manual_seed(42)
query = torch.randn(B, H, D, device='cuda', dtype=torch.bfloat16) * 0.1
kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, device='cuda', dtype=torch.bfloat16) * 0.1
block_table = torch.arange(B * pages_per_seq, dtype=torch.int32, device='cuda').reshape(B, pages_per_seq)
seq_lens = torch.full((B,), L, dtype=torch.int32, device='cuda')
out = torch.empty(B, H, D, device='cuda', dtype=torch.bfloat16)
grid = (B, Hkv)
gqa_paged_attn_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0), out.stride(1),
scale,
HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
)
print(f"GQA paged has NaN: {torch.isnan(out).any().item()}")
# Reference
ref_out = torch.zeros_like(out)
for b in range(B):
L_val = int(seq_lens[b].item())
... (848 more chars)Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python debug_test.py 2>&1",
"description": "Test GQA paged kernel",
"timeout": 60000
}Basic kernel works: [2.0, 2.0, 2.0, 2.0]
Simple paged has NaN: False
Max diff: 0.0
GQA paged has NaN: True
GQA Max diff: nan
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
@@ -272,7 +272,43 @@
print(f"GQA Max diff: {max_diff}")
+# Test the group extraction logic in isolation
+@triton.jit
+def test_group_kernel(
+ out_ptr, GROUP_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
+):
+ m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
+ d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
+
+ # For g=0, set m_i_g = 5.0, d_i_g = 10.0, acc_g = 3.0
+ g = 0
+ g_mask_1d = tl.arange(0, GROUP_SIZE) == g
+ m_i = tl.where(g_mask_1d, 5.0, m_i)
+ d_i = tl.where(g_mask_1d, 10.0, d_i)
+ acc_g_new = tl.full([HEAD_DIM], 3.0, dtype=tl.float32)
+ acc = tl.where(g_mask_1d[:, None], acc_g_new[None, :], acc)
+
+ # Read back g=0
+ d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+ m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
+
+ # Read back g=0 acc
+ read_acc = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
+
+ tl.store(out_ptr + 0, d_i_g)
+ tl.store(out_ptr + 1, m_i_g)
+ tl.store(out_ptr + 2, tl.sum(read_acc))
+
+
+def test_group_ops():
+ out = torch.zeros(8, device='cuda', dtype=torch.float32)
+ test_group_kernel[(1,)](out, GROUP_SIZE=4, HEAD_DIM=128)
+ print(f"Group ops: d_i={out[0].item()}, m_i={out[1].item()}, acc_sum={out[2].item()}")
+
+
if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()
+ test_group_ops()
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py",
"oldString": "<88 chars \u2014 see below>",
"newString": "<1330 chars \u2014 see below>"
}if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()# Test the group extraction logic in isolation
@triton.jit
def test_group_kernel(
out_ptr, GROUP_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,
):
m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
# For g=0, set m_i_g = 5.0, d_i_g = 10.0, acc_g = 3.0
g = 0
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
m_i = tl.where(g_mask_1d, 5.0, m_i)
d_i = tl.where(g_mask_1d, 10.0, d_i)
acc_g_new = tl.full([HEAD_DIM], 3.0, dtype=tl.float32)
acc = tl.where(g_mask_1d[:, None], acc_g_new[None, :], acc)
# Read back g=0
d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
# Read back g=0 acc
read_acc = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
tl.store(out_ptr + 0, d_i_g)
tl.store(out_ptr + 1, m_i_g)
tl.store(out_ptr + 2, tl.sum(read_acc))
def test_group_ops():
out = torch.zeros(8, device='cuda', dtype=torch.float32)
test_group_kernel[(1,)](out, GROUP_SIZE=4, HEAD_DIM=128)
print(f"Group ops: d_i={out[0].item()}, m_i={out[1].item()}, acc_sum={out[2].item()}")
if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()
test_group_ops()Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python debug_test.py 2>&1",
"description": "Test group extraction logic",
"timeout": 60000
}Basic kernel works: [2.0, 2.0, 2.0, 2.0]
Simple paged has NaN: False
Max diff: 0.0
GQA paged has NaN: True
GQA Max diff: nan
Group ops: d_i=10.0, m_i=nan, acc_sum=384.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -84,8 +84,9 @@
m_block = tl.max(scores, 0)
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
- m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
- d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
+ zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ m_i_g = tl.sum(tl.where(g_mask_1d, m_i, zeros))
+ d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
m_new = tl.maximum(m_i_g, m_block)
@@ -94,7 +95,8 @@
block_d = tl.sum(exp_scores)
acc_g = tl.sum(
- acc * g_mask_1d[:, None].to(tl.float32), axis=0
+ tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
+ axis=0,
)
new_acc_g = acc_g * alpha + tl.sum(
exp_scores[:, None] * v_fp32, axis=0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<843 chars \u2014 see below>",
"newString": "<937 chars \u2014 see below>"
} g_mask_1d = tl.arange(0, GROUP_SIZE) == g
m_i_g = tl.sum(m_i * g_mask_1d.to(tl.float32))
d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
m_new = tl.maximum(m_i_g, m_block)
alpha = tl.exp(m_i_g - m_new)
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
acc_g = tl.sum(
acc * g_mask_1d[:, None].to(tl.float32), axis=0
)
new_acc_g = acc_g * alpha + tl.sum(
exp_scores[:, None] * v_fp32, axis=0
)
new_d_g = d_i_g * alpha + block_d
# Commit updated state for this head
m_i = tl.where(g_mask_1d, m_new, m_i)
d_i = tl.where(g_mask_1d, new_d_g, d_i)
acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc) g_mask_1d = tl.arange(0, GROUP_SIZE) == g
zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
m_i_g = tl.sum(tl.where(g_mask_1d, m_i, zeros))
d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
m_new = tl.maximum(m_i_g, m_block)
alpha = tl.exp(m_i_g - m_new)
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
acc_g = tl.sum(
tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
axis=0,
)
new_acc_g = acc_g * alpha + tl.sum(
exp_scores[:, None] * v_fp32, axis=0
)
new_d_g = d_i_g * alpha + block_d
# Commit updated state for this head
m_i = tl.where(g_mask_1d, m_new, m_i)
d_i = tl.where(g_mask_1d, new_d_g, d_i)
acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -112,8 +112,12 @@
for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
- d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
- acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
+ zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
+ acc_g = tl.sum(
+ tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
+ axis=0,
+ )
result = acc_g / d_i_g
tl.store(
out_ptr + batch * stride_ob + head * stride_oh + offs_d,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<279 chars \u2014 see below>",
"newString": "<386 chars \u2014 see below>"
} for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
d_i_g = tl.sum(d_i * g_mask_1d.to(tl.float32))
acc_g = tl.sum(acc * g_mask_1d[:, None].to(tl.float32), axis=0)
result = acc_g / d_i_g for g in range(GROUP_SIZE):
head = kv_head * GROUP_SIZE + g
g_mask_1d = tl.arange(0, GROUP_SIZE) == g
zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
acc_g = tl.sum(
tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
axis=0,
)
result = acc_g / d_i_gEdit applied successfully.{
"description": "Run check.py after fixing NaN issue",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"timeout": 120000
}PASS
{
"todos": [
{
"content": "Write initial Triton paged attention kernel (solution.py)",
"priority": "high",
"status": "completed"
},
{
"content": "Run check.py to verify correctness across all shapes",
"priority": "high",
"status": "completed"
},
{
"content": "Run benchmark.py to measure performance",
"priority": "high",
"status": "in_progress"
},
{
"content": "Profile and optimize kernel based on results",
"priority": "medium",
"status": "pending"
}
]
}content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize kernel based on results'
status: 'pending'
priority: 'medium'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:37:03.958000 1120916 torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=0.227 gbps=56.941 ms=0.592
shape=0 variant=compiled tflops=0.222 gbps=55.600 ms=0.606
shape=0 variant=solution tflops=0.456 gbps=114.346 ms=0.295
shape=0 solution_peak_fraction=0.0635
shape=1 variant=eager tflops=0.348 gbps=87.093 ms=3.088
shape=1 vari
... (816 more chars){
"todos": [
{
"priority": "high",
"content": "Write initial Triton paged attention kernel (solution.py)",
"status": "completed"
},
{
"priority": "high",
"content": "Run check.py to verify correctness across all shapes",
"status": "completed"
},
{
"priority": "high",
"content": "Run benchmark.py to measure performance",
"status": "completed"
},
{
"priority": "high",
"content": "Switch to (B, H) grid for better parallelism on small shapes",
"status": "in_progress"
},
{
"priority":
... (264 more chars)content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'completed'
priority: 'high'
content: 'Switch to (B, H) grid for better parallelism on small shapes'
status: 'in_progress'
priority: 'high'
content: 'Run check.py and benchmark.py on (B, H) grid version'
status: 'pending'
priority: 'medium'
content: 'Profile and further optimize based on benchmark results'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Paged attention decode — Triton kernel with online softmax + Tensor Cores.
2:
3: Single-query decode: each batch element has a query (num_heads, head_dim) and
4: attends over a paged KV cache. The grid is (batch, num_kv_heads) so that KV
5: data is loaded once and shared across all query heads in the same GQA group.
6: """
7:
8: import math
9:
10: import torch
11: import triton
12: import triton.language as tl
13:
14: OP_TYPE = "attention"
15: SUPPORTED_PRECISIONS = ["bf16"]
16: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
17:
18: BATCH = 8
19: NUM_HEADS = 32
20: NUM_KV_HEADS = 8
21: HEAD_DIM = 128
22: SEQ_LEN = 1024
23: PAGE_SIZE = 16
24:
25:
26: @triton.jit
27: def _paged_attn_kernel(
28: q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
29: stride_qb, stride_qh,
30: stride_kvb, stride_kvps, stride_kvkh,
31: stride_btb,
32: stride_ob, stride_oh,
33: scale,
34: HEAD_DIM: tl.constexpr,
35: PAGE_SIZE: tl.constexpr,
36: GROUP_SIZE: tl.constexpr,
37: ):
38: batch = tl.program_id(0)
39: kv_head = tl.program_id(1)
40:
41: offs_d = tl.arange(0, HEAD_DIM)
42:
43: L = tl.load(sl_ptr + batch)
44: num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
45:
46: offs_t = tl.arange(0, PAGE_SIZE)
47:
48: # Per-query-head online-softmax state across all pages
49: m_i = tl.full([GROUP_SIZE], float('-inf'
... (5815 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,8 +1,9 @@
-"""Paged attention decode — Triton kernel with online softmax + Tensor Cores.
+"""Paged attention decode — Triton kernel, (B, H) grid for max parallelism.
Single-query decode: each batch element has a query (num_heads, head_dim) and
-attends over a paged KV cache. The grid is (batch, num_kv_heads) so that KV
-data is loaded once and shared across all query heads in the same GQA group.
+attends over a paged KV cache. Grid is (batch, num_heads). Each program handles
+one query head, loads its KV data via page indirection, and computes attention
+with online softmax in fp32.
"""
import math
@@ -36,19 +37,25 @@
GROUP_SIZE: tl.constexpr,
):
batch = tl.program_id(0)
- kv_head = tl.program_id(1)
+ head = tl.program_id(1)
+ kv_head = head // GROUP_SIZE
offs_d = tl.arange(0, HEAD_DIM)
+
+ # Load query once and reuse across all page iterations
+ q = tl.load(
+ q_ptr + batch * stride_qb + head * stride_qh + offs_d
+ ).to(tl.float32)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
- # Per-query-head online-softmax state across all pages
- m_i = tl.full([GROUP_SIZE], float('-inf'), dtype=tl.float32)
- d_i = tl.zeros([GROUP_SIZE], dtype=tl.float32)
- acc = tl.zeros([GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
+ # Online-softmax state across all pages
+ m_i = -float('inf')
+ d_i = 0.0
+ acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
@@ -58,71 +65,34 @@
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
- # Load K and V in fp32 for elementwise dot products
- k_fp32 = tl.load(
+ k = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
).to(tl.float32)
- v_fp32 = tl.load(
+ v = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
- # Process each query head in the GQA group
- for g in range(GROUP_SIZE):
- head = kv_head * GROUP_SIZE + g
+ scores = tl.sum(q[None, :] * k, axis=1) * scale
+ scores = tl.where(mask, scores, float('-inf'))
- q_fp32 = tl.load(
- q_ptr + batch * stride_qb + head * stride_qh + offs_d
- ).to(tl.float32)
+ m_block = tl.max(scores, 0)
+ m_new = tl.maximum(m_i, m_block)
- # Elementwise dot product, fp32 accumulate
- scores = tl.sum(q_fp32[None, :] * k_fp32, axis=1) * scale
- scores = tl.where(mask, scores, float('-inf'))
+ alpha = tl.exp(m_i - m_new)
+ exp_scores = tl.exp(scores - m_new)
+ block_d = tl.sum(exp_scores)
- # Online safe softmax
- m_block = tl.max(scores, 0)
+ acc = acc * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
+ d_i = d_i * alpha + block_d
+ m_i = m_new
- g_mask_1d = tl.arange(0, GROUP_SIZE) == g
- zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
- m_i_g = tl.sum(tl.where(g_mask_1d, m_i, zeros))
- d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
-
- m_new = tl.maximum(m_i_g, m_block)
-
- alpha = tl.exp(m_i_g - m_new)
- exp_scores = tl.exp(scores - m_new)
- block_d = tl.sum(exp_scores)
-
- acc_g = tl.sum(
- tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
- axis=0,
- )
- new_acc_g = acc_g * alpha + tl.sum(
- exp_scores[:, None] * v_fp32, axis=0
- )
- new_d_g = d_i_g * alpha + block_d
-
- # Commit updated state for this head
- m_i = tl.where(g_mask_1d, m_new, m_i)
- d_i = tl.where(g_mask_1d, new_d_g, d_i)
- acc = tl.where(g_mask_1d[:, None], new_acc_g[None, :], acc)
-
- # Write results
- for g in range(GROUP_SIZE):
- head = kv_head * GROUP_SIZE + g
- g_mask_1d = tl.arange(0, GROUP_SIZE) == g
- zeros = tl.zeros([GROUP_SIZE], dtype=tl.float32)
- d_i_g = tl.sum(tl.where(g_mask_1d, d_i, zeros))
- acc_g = tl.sum(
- tl.where(g_mask_1d[:, None], acc, tl.zeros_like(acc)),
- axis=0,
- )
- result = acc_g / d_i_g
- tl.store(
- out_ptr + batch * stride_ob + head * stride_oh + offs_d,
- result.to(tl.bfloat16),
- )
+ out = acc / d_i
+ tl.store(
+ out_ptr + batch * stride_ob + head * stride_oh + offs_d,
+ out.to(tl.bfloat16),
+ )
class Model(torch.nn.Module):
@@ -157,13 +127,12 @@
seq_lens: torch.Tensor,
) -> torch.Tensor:
B, H, D = query.shape
- Hkv = self.num_kv_heads
P = self.page_size
G = self.group_size
out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
- grid = (B, Hkv)
+ grid = (B, H)
_paged_attn_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"content": "<4831 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"description": "Run check.py on (B,H) grid version",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python benchmark.py 2>&1",
"description": "Benchmark (B,H) grid version",
"timeout": 300000
}W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:40:02.834000 1122178 torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=0.227 gbps=56.874 ms=0.592
shape=0 variant=compiled tflops=0.221 gbps=55.563 ms=0.606
shape=0 variant=solution tflops=1.151 gbps=288.997 ms=0.117
shape=0 solution_peak_fraction=0.1606
shape=1 variant=eager tflops=0.346 gbps=86.709 ms=3.102
shape=1 vari
... (816 more chars){
"command": "<785 chars \u2014 see below>",
"description": "Test tl.reshape in Triton 3.6",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
# Quick test: does tl.reshape work in Triton 3.6?
import os
os.chdir('/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention')
# Check from the compiled kernel - run a test
import torch
import triton
import triton.language as tl
@triton.jit
def test_reshape_kernel(in_ptr, out_ptr, N: tl.constexpr):
pid = tl.program_id(0)
offs = tl.arange(0, N)
x = tl.load(in_ptr + offs)
x_2d = x[None, :] # (1, N)
# Try reshape
y = tl.reshape(x_2d, [N])
tl.store(out_ptr + offs, y)
x = torch.arange(16, device='cuda', dtype=torch.float32)
y = torch.zeros(16, device='cuda', dtype=torch.float32)
test_reshape_kernel[(1,)](x, y, N=16)
print('reshape works:', y[:4])
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 11, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py
@@ -307,8 +307,45 @@
print(f"Group ops: d_i={out[0].item()}, m_i={out[1].item()}, acc_sum={out[2].item()}")
+# Test tl.reshape
+@triton.jit
+def test_reshape_kernel(in_ptr, out_ptr, N: tl.constexpr):
+ pid = tl.program_id(0)
+ offs = tl.arange(0, N)
+ x = tl.load(in_ptr + offs)
+ x_2d = x[None, :] # (1, N)
+ y = tl.reshape(x_2d, [N])
+ tl.store(out_ptr + offs, y)
+
+
+@triton.jit
+def test_dot_kernel(q_ptr, k_ptr, out_ptr, D: tl.constexpr, N: tl.constexpr):
+ pid = tl.program_id(0)
+ offs_d = tl.arange(0, D)
+ offs_n = tl.arange(0, N)
+ q = tl.load(q_ptr + offs_d)
+ k = tl.load(k_ptr + offs_n[:, None] * D + offs_d[None, :])
+ q_2d = q[None, :].to(tl.bfloat16)
+ k_t = tl.trans(k.to(tl.bfloat16))
+ scores_2d = tl.dot(q_2d, k_t)
+ # Try to get scores as 1D
+ scores = tl.reshape(scores_2d, [N])
+ tl.store(out_ptr + offs_n, scores * 0.125)
+
+
+def test_tl_dot():
+ D, N = 128, 16
+ q = torch.randn(D, device='cuda', dtype=torch.bfloat16) * 0.1
+ k = torch.randn(N, D, device='cuda', dtype=torch.bfloat16) * 0.1
+ out = torch.zeros(N, device='cuda', dtype=torch.float32)
+ test_dot_kernel[(1,)](q, k, out, D=D, N=N)
+ ref = (q.float() @ k.float().T * 0.125).float()
+ print(f"tl.reshape + tl.dot works: {(out - ref).abs().max().item()}")
+
+
if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()
test_group_ops()
+ test_tl_dot()
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py",
"oldString": "<109 chars \u2014 see below>",
"newString": "<1305 chars \u2014 see below>"
}if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()
test_group_ops()# Test tl.reshape
@triton.jit
def test_reshape_kernel(in_ptr, out_ptr, N: tl.constexpr):
pid = tl.program_id(0)
offs = tl.arange(0, N)
x = tl.load(in_ptr + offs)
x_2d = x[None, :] # (1, N)
y = tl.reshape(x_2d, [N])
tl.store(out_ptr + offs, y)
@triton.jit
def test_dot_kernel(q_ptr, k_ptr, out_ptr, D: tl.constexpr, N: tl.constexpr):
pid = tl.program_id(0)
offs_d = tl.arange(0, D)
offs_n = tl.arange(0, N)
q = tl.load(q_ptr + offs_d)
k = tl.load(k_ptr + offs_n[:, None] * D + offs_d[None, :])
q_2d = q[None, :].to(tl.bfloat16)
k_t = tl.trans(k.to(tl.bfloat16))
scores_2d = tl.dot(q_2d, k_t)
# Try to get scores as 1D
scores = tl.reshape(scores_2d, [N])
tl.store(out_ptr + offs_n, scores * 0.125)
def test_tl_dot():
D, N = 128, 16
q = torch.randn(D, device='cuda', dtype=torch.bfloat16) * 0.1
k = torch.randn(N, D, device='cuda', dtype=torch.bfloat16) * 0.1
out = torch.zeros(N, device='cuda', dtype=torch.float32)
test_dot_kernel[(1,)](q, k, out, D=D, N=N)
ref = (q.float() @ k.float().T * 0.125).float()
print(f"tl.reshape + tl.dot works: {(out - ref).abs().max().item()}")
if __name__ == "__main__":
test_basic()
test_simple_paged()
test_gqa_paged()
test_group_ops()
test_tl_dot()Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python debug_test.py 2>&1",
"description": "Test tl.reshape and tl.dot",
"timeout": 60000
}Basic kernel works: [2.0, 2.0, 2.0, 2.0]
Simple paged has NaN: False
Max diff: 0.0
GQA paged has NaN: True
GQA Max diff: nan
Group ops: d_i=10.0, m_i=nan, acc_sum=384.0
tl.reshape + tl.dot works: 7.450580596923828e-09
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -42,10 +42,10 @@
offs_d = tl.arange(0, HEAD_DIM)
- # Load query once and reuse across all page iterations
- q = tl.load(
+ # Load query as bf16 once, reuse across all pages via Tensor-Core dot
+ q_bf16 = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
- ).to(tl.float32)
+ )
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
@@ -65,16 +65,21 @@
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
- k = tl.load(
+ # K stays bf16 for Tensor-Core dot; V loads as fp32 for accumulation
+ k_bf16 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
- ).to(tl.float32)
- v = tl.load(
+ )
+ v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
- scores = tl.sum(q[None, :] * k, axis=1) * scale
+ # Tensor-Core dot: bf16 * bf16 -> fp32
+ q_2d = q_bf16[None, :]
+ k_t = tl.trans(k_bf16)
+ scores_2d = tl.dot(q_2d, k_t)
+ scores = tl.reshape(scores_2d, [PAGE_SIZE]) * scale
scores = tl.where(mask, scores, float('-inf'))
m_block = tl.max(scores, 0)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<1191 chars \u2014 see below>",
"newString": "<1419 chars \u2014 see below>"
} offs_d = tl.arange(0, HEAD_DIM)
# Load query once and reuse across all page iterations
q = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
).to(tl.float32)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
# Online-softmax state across all pages
m_i = -float('inf')
d_i = 0.0
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
remaining = L - p * PAGE_SIZE
tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
mask = offs_t < tokens_in_page
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
k = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
).to(tl.float32)
v = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
scores = tl.sum(q[None, :] * k, axis=1) * scale
scores = tl.where(mask, scores, float('-inf')) offs_d = tl.arange(0, HEAD_DIM)
# Load query as bf16 once, reuse across all pages via Tensor-Core dot
q_bf16 = tl.load(
q_ptr + batch * stride_qb + head * stride_qh + offs_d
)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
# Online-softmax state across all pages
m_i = -float('inf')
d_i = 0.0
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
remaining = L - p * PAGE_SIZE
tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
mask = offs_t < tokens_in_page
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
# K stays bf16 for Tensor-Core dot; V loads as fp32 for accumulation
k_bf16 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
# Tensor-Core dot: bf16 * bf16 -> fp32
q_2d = q_bf16[None, :]
k_t = tl.trans(k_bf16)
scores_2d = tl.dot(q_2d, k_t)
scores = tl.reshape(scores_2d, [PAGE_SIZE]) * scale
scores = tl.where(mask, scores, float('-inf'))Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"offset": 55,
"limit": 35
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
55: # Online-softmax state across all pages
56: m_i = -float('inf')
57: d_i = 0.0
58: acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
59:
60: for p in range(num_pages):
61: block_id = tl.load(bt_ptr + batch * stride_btb + p)
62: remaining = L - p * PAGE_SIZE
63: tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
64: mask = offs_t < tokens_in_page
65:
66: kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
67:
68: # K stays bf16 for Tensor-Core dot; V loads as fp32 for accumulation
69: k_bf16 = tl.load(
70: kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
71: mask=mask[:, None],
72: )
73: v_fp32 = tl.load(
74: kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
75: mask=mask[:, None],
76: ).to(tl.float32)
77:
78: # Tensor-Core dot: bf16 * bf16 -> fp32
79: q_2d = q_bf16[None, :]
80: k_t = tl.trans(k_bf16)
81: scores_2d = tl.dot(q_2d, k_t)
82: scores = tl.reshape(scores_2d, [PAGE_SIZE]) * scale
83: scores = tl.where(mask, scores, float('-inf'))
84:
85: m_block = tl.max(scores, 0)
86: m_new = tl.maximum(m_i, m_block)
87:
88: alpha = tl.exp(m_i - m_new)
89: exp_scores = tl.exp(
... (84 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"offset": 89,
"limit": 15
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
89: exp_scores = tl.exp(scores - m_new)
90: block_d = tl.sum(exp_scores)
91:
92: acc = acc * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
93: d_i = d_i * alpha + block_d
94: m_i = m_new
95:
96: out = acc / d_i
97: tl.store(
98: out_ptr + batch * stride_ob + head * stride_oh + offs_d,
99: out.to(tl.bfloat16),
100: )
101:
102:
103: class Model(torch.nn.Module):
(Showing lines 89-103 of 176. Use offset=104 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -89,7 +89,7 @@
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
- acc = acc * alpha + tl.sum(exp_scores[:, None] * v, axis=0)
+ acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)
d_i = d_i * alpha + block_d
m_i = m_new
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": " acc = acc * alpha + tl.sum(exp_scores[:, None] * v, axis=0)",
"newString": " acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)"
}Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"description": "Run check.py after Tensor Core optimization",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python benchmark.py 2>&1",
"description": "Benchmark Tensor Core version",
"timeout": 300000
}W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:41:31.562000 1123042 torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=0.224 gbps=56.306 ms=0.598
shape=0 variant=compiled tflops=0.219 gbps=54.941 ms=0.613
shape=0 variant=solution tflops=1.104 gbps=277.055 ms=0.122
shape=0 solution_peak_fraction=0.1539
shape=1 variant=eager tflops=0.345 gbps=86.486 ms=3.110
shape=1 vari
... (816 more chars){
"todos": [
{
"priority": "high",
"content": "Write initial Triton paged attention kernel (solution.py)",
"status": "completed"
},
{
"priority": "high",
"content": "Run check.py to verify correctness across all shapes",
"status": "completed"
},
{
"priority": "high",
"content": "Run benchmark.py to measure performance",
"status": "completed"
},
{
"priority": "high",
"content": "Switch to (B, H) grid for better parallelism on small shapes",
"status": "completed"
},
{
"priority": "
... (393 more chars)content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'completed'
priority: 'high'
content: 'Switch to (B, H) grid for better parallelism on small shapes'
status: 'completed'
priority: 'high'
content: 'Revert to elementwise dot (faster than tl.dot for small matmuls)'
status: 'in_progress'
priority: 'high'
content: 'Run check.py and benchmark.py'
status: 'pending'
priority: 'medium'
content: 'Profile and further optimize based on benchmark results'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -65,21 +65,18 @@
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
- # K stays bf16 for Tensor-Core dot; V loads as fp32 for accumulation
- k_bf16 = tl.load(
+ # Load K and V as fp32 for elementwise dot
+ k_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
- )
+ ).to(tl.float32)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
- # Tensor-Core dot: bf16 * bf16 -> fp32
- q_2d = q_bf16[None, :]
- k_t = tl.trans(k_bf16)
- scores_2d = tl.dot(q_2d, k_t)
- scores = tl.reshape(scores_2d, [PAGE_SIZE]) * scale
+ # Elementwise dot, fp32 accumulate
+ scores = tl.sum(q_bf16.to(tl.float32)[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))
m_block = tl.max(scores, 0)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<717 chars \u2014 see below>",
"newString": "<623 chars \u2014 see below>"
} kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
# K stays bf16 for Tensor-Core dot; V loads as fp32 for accumulation
k_bf16 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
# Tensor-Core dot: bf16 * bf16 -> fp32
q_2d = q_bf16[None, :]
k_t = tl.trans(k_bf16)
scores_2d = tl.dot(q_2d, k_t)
scores = tl.reshape(scores_2d, [PAGE_SIZE]) * scale
scores = tl.where(mask, scores, float('-inf')) kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
# Load K and V as fp32 for elementwise dot
k_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None],
).to(tl.float32)
v_fp32 = tl.load(
kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None],
).to(tl.float32)
# Elementwise dot, fp32 accumulate
scores = tl.sum(q_bf16.to(tl.float32)[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))Edit applied successfully.{
"command": "rm /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/debug_test.py",
"description": "Clean up debug file"
}(no output){
"description": "Run check.py after cleanup",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"timeout": 120000
}PASS
{
"command": "<886 chars \u2014 see below>",
"description": "Warmup for profiling",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import torch, sys, os
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref_mod.BATCH = B
ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv
ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L
ref_mod.PAGE_SIZE = P
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
# Warmup
for _ in range(10):
model(*inputs)
torch.cuda.synchronize()
# Timed
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
out = model(*inputs)
end.record()
torch.cuda.synchronize()
print(f'Time: {start.elapsed_time(end):.3f} ms')
print(f'Output range: {out.min().item():.6f} to {out.max().item():.6f}')
" > /dev/null && echo "Warmup done"Warmup done
{
"command": "<1163 chars \u2014 see below>",
"description": "Profile with ncu",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import sys
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref_mod.BATCH = B
ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv
ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L
ref_mod.PAGE_SIZE = P
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
import torch
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
for _ in range(5):
model(*inputs)
torch.cuda.synchronize()
" 2>&1 | tail -5
echo "---"
# Run ncu for a quick profile
ncu --set full -k _paged_attn_kernel --launch-count 2 --launch-skip 1 -o profile_ncu python3 -c "
import sys
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref_mod.BATCH = B
ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv
ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L
ref_mod.PAGE_SIZE = P
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
import torch
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
model(*inputs)
torch.cuda.synchronize()
" 2>&1 | tail -30---
==PROF== Connected to process 1124285 (/usr/bin/python3.12)
==ERROR== ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
==PROF== Disconnected from process 1124285
{
"command": "<1234 chars \u2014 see below>",
"description": "Kernel metadata and memory info",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import sys
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
import torch
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref_mod.BATCH = B
ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv
ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L
ref_mod.PAGE_SIZE = P
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
# Get the kernel info
ktn = sol._paged_attn_kernel
print('Kernel function:', ktn)
# Run once
out = model(*inputs)
# Print Triton kernel metadata if available
if hasattr(ktn, 'cache_key'):
print('cache_key:', ktn.cache_key)
if hasattr(ktn, 'cached_values'):
print('cached values count:', len(ktn.cached_values))
# Warm up then do detailed timing
for _ in range(10):
model(*inputs)
torch.cuda.synchronize()
# Check memory usage
print('query:', inputs[0].shape, inputs[0].element_size() * inputs[0].numel() / 1024, 'KB')
print('kv_cache:', inputs[1].shape, inputs[1].element_size() * inputs[1].numel() / 1024 / 1024, 'MB')
print('block_table:', inputs[2].shape, inputs[2].element_size() * inputs[2].numel() / 1024, 'KB')
" 2>&1Kernel function: JITFunction(solution:_paged_attn_kernel)
cache_key: 7447de8c95cb0c70611be48ee50461845581a4b987d527af465afcc8544d731f
query: torch.Size([8, 32, 128]) 64.0 KB
kv_cache: torch.Size([520, 16, 8, 256]) 32.5 MB
block_table: torch.Size([8, 64]) 2.0 KB
{
"description": "Find flashinfer source",
"command": "python3 -c \"import flashinfer; print(flashinfer.__file__)\" 2>&1 && find /home/infatoshi/.local/lib/python3.12/site-packages/flashinfer -name \"*.cu\" -o -name \"*.cuh\" | head -10",
"timeout": 10000
}Traceback (most recent call last):
File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'flashinfer'
{
"command": "<2676 chars \u2014 see below>",
"description": "Check Triton debug and num_warps",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import sys
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
import torch
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
ref_mod.BATCH = B; ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv; ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L; ref_mod.PAGE_SIZE = P
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
# Trigger compilation
out = model(*inputs)
# Try different num_warps
print('Testing different num_warps...')
@triton.jit
def _paged_attn_kernel_nw(
q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
stride_qb, stride_qh,
stride_kvb, stride_kvps, stride_kvkh,
stride_btb,
stride_ob, stride_oh,
scale,
HEAD_DIM: tl.constexpr,
PAGE_SIZE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
batch = tl.program_id(0)
head = tl.program_id(1)
kv_head = head // GROUP_SIZE
offs_d = tl.arange(0, HEAD_DIM)
q_bf16 = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
m_i = -float('inf')
d_i = 0.0
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
remaining = L - p * PAGE_SIZE
tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
mask = offs_t < tokens_in_page
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
k_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None]).to(tl.float32)
v_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None]).to(tl.float32)
scores = tl.sum(q_bf16.to(tl.float32)[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))
m_block = tl.max(scores, 0)
m_new = tl.maximum(m_i, m_block)
alpha = tl.exp(m_i - m_new)
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)
d_i = d_i * alpha + block_d
m_i = m_new
out = acc / d_i
tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d, out.to(tl.bfloat16))
import triton
for nw in [2, 4, 8, 16]:
try:
k = _paged_attn_kernel_nw[grid](...,
num_warps=nw)
except:
pass
print(f' num_warps={nw}')
" 2>&1 | head -20Traceback (most recent call last):
File "<string>", line 23, in <module>
NameError: name 'triton' is not defined
Testing different num_warps...
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1 && echo \"---\" && python benchmark.py 2>&1 | tail -25",
"description": "Final check and benchmark",
"timeout": 300000
}PASS
---
W0427 08:47:52.722000 1125823 torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:47:52.722000 1125823 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:47:52.722000 1125823 torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=0.226 gbps=56.763 ms=0.593
shape=0 variant=compiled tflops=0.222 gbps=55.600 ms=0.606
shape=0 variant=solution tflops=1.152 gbps=289.037 ms=0.117
shape=0 solution_peak_fraction=0.1606
shape=1 variant=eager tflops=0.347 gbps=86.888 ms=3.095
shape=1 variant=compiled tflops=0.346 gbps=86.695 ms=3.102
shape=1 variant=solution tflops=2.462 gbps=616.767 ms=0.436
shape=1 solution_peak_fraction=0.3426
shape=2 variant=eager tflops=0.243 gbps=30.442 ms=2.209
shape=2 variant=compiled tflops=0.242 gbps=30.350 ms=2.216
shape=2 variant=solution tflops=1.367 gbps=171.146 ms=0.393
shape=2 solution_peak_fraction=0.0951
shape=3 variant=eager tflops=0.301 gbps=75.426 ms=1.337
shape=3 variant=compiled tflops=0.298 gbps=74.712 ms=1.350
shape=3 variant=solution tflops=2.057 gbps=515.517 ms=0.196
shape=3 solution_peak_fraction=0.2864
shape=4 variant=eager tflops=0.110 gbps=27.659 ms=0.594
shape=4 variant=compiled tflops=0.108 gbps=27.002 ms=0.608
shape=4 variant=solution tflops=0.379 gbps=94.838 ms=0.173
shape=4 solution_peak_fraction=0.0527
peak_fraction: 0.1512
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Paged attention decode — Triton kernel, (B, H) grid for max parallelism.
2:
3: Single-query decode: each batch element has a query (num_heads, head_dim) and
4: attends over a paged KV cache. Grid is (batch, num_heads). Each program handles
5: one query head, loads its KV data via page indirection, and computes attention
6: with online softmax in fp32.
7: """
8:
9: import math
10:
11: import torch
12: import triton
13: import triton.language as tl
14:
15: OP_TYPE = "attention"
16: SUPPORTED_PRECISIONS = ["bf16"]
17: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
18:
19: BATCH = 8
20: NUM_HEADS = 32
21: NUM_KV_HEADS = 8
22: HEAD_DIM = 128
23: SEQ_LEN = 1024
24: PAGE_SIZE = 16
25:
26:
27: @triton.jit
28: def _paged_attn_kernel(
29: q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
30: stride_qb, stride_qh,
31: stride_kvb, stride_kvps, stride_kvkh,
32: stride_btb,
33: stride_ob, stride_oh,
34: scale,
35: HEAD_DIM: tl.constexpr,
36: PAGE_SIZE: tl.constexpr,
37: GROUP_SIZE: tl.constexpr,
38: ):
39: batch = tl.program_id(0)
40: head = tl.program_id(1)
41: kv_head = head // GROUP_SIZE
42:
43: offs_d = tl.arange(0, HEAD_DIM)
44:
45: # Load query as bf16 once, reuse across all pages via Tensor-Core dot
46: q_bf16 = tl.load(
47: q_ptr + batch * stride_qb + head * stride_qh + offs_d
48: )
... (4389 more chars){
"command": "<4320 chars \u2014 see below>",
"description": "Test different num_warps for shape 4",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python3 -c "
import sys
sys.path.insert(0, '.')
import solution as sol
import reference as ref_mod
import torch
import triton
import triton.language as tl
# Test different num_warps on shape 4 (worst performer)
B, H, Hkv, D, L, P = 8, 16, 4, 64, 2000, 16
ref_mod.BATCH = B; ref_mod.NUM_HEADS = H
ref_mod.NUM_KV_HEADS = Hkv; ref_mod.HEAD_DIM = D
ref_mod.SEQ_LEN = L; ref_mod.PAGE_SIZE = P
G = H // Hkv
scale = 1.0 / (D ** 0.5)
model = sol.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in ref_mod.get_inputs()]
# Define manual kernel with different num_warps
@triton.jit
def _paged_kernel(
q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
stride_qb, stride_qh,
stride_kvb, stride_kvps, stride_kvkh,
stride_btb,
stride_ob, stride_oh,
scale,
HEAD_DIM: tl.constexpr,
PAGE_SIZE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
batch = tl.program_id(0)
head = tl.program_id(1)
kv_head = head // GROUP_SIZE
offs_d = tl.arange(0, HEAD_DIM)
q_bf16 = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d)
L = tl.load(sl_ptr + batch)
num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
offs_t = tl.arange(0, PAGE_SIZE)
m_i = -float('inf')
d_i = 0.0
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for p in range(num_pages):
block_id = tl.load(bt_ptr + batch * stride_btb + p)
remaining = L - p * PAGE_SIZE
tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
mask = offs_t < tokens_in_page
kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
k_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
mask=mask[:, None]).to(tl.float32)
v_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
mask=mask[:, None]).to(tl.float32)
scores = tl.sum(q_bf16.to(tl.float32)[None, :] * k_fp32, axis=1) * scale
scores = tl.where(mask, scores, float('-inf'))
m_block = tl.max(scores, 0)
m_new = tl.maximum(m_i, m_block)
alpha = tl.exp(m_i - m_new)
exp_scores = tl.exp(scores - m_new)
block_d = tl.sum(exp_scores)
acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)
d_i = d_i * alpha + block_d
m_i = m_new
out = acc / d_i
tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d, out.to(tl.bfloat16))
query, kv_cache, block_table, seq_lens = inputs
out = torch.empty_like(query)
grid = (B, H)
# Warmup original
for _ in range(5):
_paged_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0), out.stride(1),
scale,
HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
)
torch.cuda.synchronize()
# Try various num_warps
for nw in [2, 4, 8, 16]:
torch.cuda.empty_cache()
# Warmup
for _ in range(3):
_paged_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0), out.stride(1),
scale,
HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
num_warps=nw,
)
torch.cuda.synchronize()
# Timed
times = []
for _ in range(20):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
_paged_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0)
... (320 more chars)Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 24, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/tune_warps.py
@@ -0,0 +1,138 @@
+"""Tune num_warps for each shape."""
+import sys
+sys.path.insert(0, '.')
+import torch
+import triton
+import triton.language as tl
+import solution as sol
+import reference as ref_mod
+from shapes import SHAPES
+
+@triton.jit
+def _paged_kernel(
+ q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
+ stride_qb, stride_qh,
+ stride_kvb, stride_kvps, stride_kvkh,
+ stride_btb,
+ stride_ob, stride_oh,
+ scale,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ batch = tl.program_id(0)
+ head = tl.program_id(1)
+ kv_head = head // GROUP_SIZE
+ offs_d = tl.arange(0, HEAD_DIM)
+ q_bf16 = tl.load(q_ptr + batch * stride_qb + head * stride_qh + offs_d)
+ L = tl.load(sl_ptr + batch)
+ num_pages = (L + PAGE_SIZE - 1) // PAGE_SIZE
+ offs_t = tl.arange(0, PAGE_SIZE)
+ m_i = -float('inf')
+ d_i = 0.0
+ acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
+ for p in range(num_pages):
+ block_id = tl.load(bt_ptr + batch * stride_btb + p)
+ remaining = L - p * PAGE_SIZE
+ tokens_in_page = tl.where(remaining > PAGE_SIZE, PAGE_SIZE, remaining)
+ mask = offs_t < tokens_in_page
+ kv_base = kv_ptr + block_id * stride_kvb + kv_head * stride_kvkh
+ k_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :],
+ mask=mask[:, None]).to(tl.float32)
+ v_fp32 = tl.load(kv_base + offs_t[:, None] * stride_kvps + offs_d[None, :] + HEAD_DIM,
+ mask=mask[:, None]).to(tl.float32)
+ scores = tl.sum(q_bf16.to(tl.float32)[None, :] * k_fp32, axis=1) * scale
+ scores = tl.where(mask, scores, float('-inf'))
+ m_block = tl.max(scores, 0)
+ m_new = tl.maximum(m_i, m_block)
+ alpha = tl.exp(m_i - m_new)
+ exp_scores = tl.exp(scores - m_new)
+ block_d = tl.sum(exp_scores)
+ acc = acc * alpha + tl.sum(exp_scores[:, None] * v_fp32, axis=0)
+ d_i = d_i * alpha + block_d
+ m_i = m_new
+ out = acc / d_i
+ tl.store(out_ptr + batch * stride_ob + head * stride_oh + offs_d, out.to(tl.bfloat16))
+
+
+def tune_shape(shape):
+ B = shape["batch"]
+ H = shape["num_heads"]
+ Hkv = shape["num_kv_heads"]
+ D = shape["head_dim"]
+ L = shape["seq_len"]
+ P = shape["page_size"]
+ G = H // Hkv
+ scale_val = 1.0 / (D ** 0.5)
+
+ ref_mod.BATCH = B; ref_mod.NUM_HEADS = H
+ ref_mod.NUM_KV_HEADS = Hkv; ref_mod.HEAD_DIM = D
+ ref_mod.SEQ_LEN = L; ref_mod.PAGE_SIZE = P
+
+ torch.manual_seed(42)
+ inputs = [t.cuda() for t in ref_mod.get_inputs()]
+ query, kv_cache, block_table, seq_lens = inputs
+ out = torch.empty_like(query)
+ grid = (B, H)
+
+ best_ms = float('inf')
+ best_nw = 4
+
+ for nw in [2, 4, 8, 16]:
+ torch.cuda.empty_cache()
+ try:
+ for _ in range(5):
+ _paged_kernel[grid](
+ query, kv_cache, block_table, seq_lens, out,
+ query.stride(0), query.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
+ block_table.stride(0),
+ out.stride(0), out.stride(1),
+ scale_val,
+ HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
+ num_warps=nw,
+ )
+ torch.cuda.synchronize()
+ except Exception as e:
+ print(f" num_warps={nw} FAIL: {e}")
+ continue
+
+ times = []
+ for _ in range(30):
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ _paged_kernel[grid](
+ query, kv_cache, block_table, seq_lens, out,
+ query.stride(0), query.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
+ block_table.stride(0),
+ out.stride(0), out.stride(1),
+ scale_val,
+ HEAD_DIM=D, PAGE_SIZE=P, GROUP_SIZE=G,
+ num_warps=nw,
+ )
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+ med = sorted(times)[len(times)//2]
+ print(f" shape B={B} H={H} D={D} L={L}: nw={nw} -> {med:.4f} ms")
+ if med < best_ms:
+ best_ms = med
+ best_nw = nw
+
+ print(f" BEST for shape B={B} H={H} D={D}: num_warps={best_nw} ({best_ms:.4f} ms)")
+ return best_nw
+
+
+if __name__ == "__main__":
+ best_nws = {}
+ for i, shape in enumerate(SHAPES):
+ print(f"\n=== Shape {i} ===")
+ try:
+ best_nws[i] = tune_shape(shape)
+ except Exception as e:
+ print(f"Shape {i} ERROR: {e}")
+ import traceback
+ traceback.print_exc()
+ print(f"\nBest num_warps: {best_nws}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/tune_warps.py",
"content": "<4799 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python tune_warps.py 2>&1",
"description": "Tune num_warps for all shapes",
"timeout": 300000
}=== Shape 0 ===
shape B=8 H=32 D=128 L=1024: nw=2 -> 0.0687 ms
shape B=8 H=32 D=128 L=1024: nw=4 -> 0.0908 ms
shape B=8 H=32 D=128 L=1024: nw=8 -> 0.1486 ms
shape B=8 H=32 D=128 L=1024: nw=16 -> 0.2704 ms
BEST for shape B=8 H=32 D=128: num_warps=2 (0.0687 ms)
=== Shape 1 ===
shape B=32 H=32 D=128 L=2048: nw=2 -> 0.2353 ms
shape B=32 H=32 D=128 L=2048: nw=4 -> 0.4340 ms
shape B=32 H=32 D=128 L=2048: nw=8 -> 0.7944 ms
shape B=32 H=32 D=128 L=2048: nw=16 -> 1.5252 ms
BEST for shape B=32 H=32 D=128: num_warps=2 (0.2353 ms)
=== Shape 2 ===
shape B=4 H=64 D=128 L=4096: nw=2 -> 0.2465 ms
shape B=4 H=64 D=128 L=4096: nw=4 -> 0.3289 ms
shape B=4 H=64 D=128 L=4096: nw=8 -> 0.5623 ms
shape B=4 H=64 D=128 L=4096: nw=16 -> 1.0488 ms
BEST for shape B=4 H=64 D=128: num_warps=2 (0.2465 ms)
=== Shape 3 ===
shape B=16 H=32 D=128 L=1535: nw=2 -> 0.1066 ms
shape B=16 H=32 D=128 L=1535: nw=4 -> 0.1731 ms
shape B=16 H=32 D=128 L=1535: nw=8 -> 0.3125 ms
shape B=16 H=32 D=128 L=1535: nw=16 -> 0.5855 ms
BEST for shape B=16 H=32 D=128: num_warps=2 (0.1066 ms)
=== Shape 4 ===
shape B=8 H=16 D=64 L=2000: nw=2 -> 0.0783 ms
shape B=8 H=16 D=64 L=2000: nw=4 -> 0.1004 ms
shape B=8 H=16 D=64 L=2000: nw=8 -> 0.1189 ms
shape B=8 H=16 D=64 L=2000: nw=16 -> 0.1762 ms
BEST for shape B=8 H=16 D=64: num_warps=2 (0.0783 ms)
Best num_warps: {0: 2, 1: 2, 2: 2, 3: 2, 4: 2}
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -145,6 +145,7 @@
HEAD_DIM=D,
PAGE_SIZE=P,
GROUP_SIZE=G,
+ num_warps=2,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<416 chars \u2014 see below>",
"newString": "<441 chars \u2014 see below>"
} grid = (B, H)
_paged_attn_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0), out.stride(1),
self.scale,
HEAD_DIM=D,
PAGE_SIZE=P,
GROUP_SIZE=G,
) grid = (B, H)
_paged_attn_kernel[grid](
query, kv_cache, block_table, seq_lens, out,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2),
block_table.stride(0),
out.stride(0), out.stride(1),
self.scale,
HEAD_DIM=D,
PAGE_SIZE=P,
GROUP_SIZE=G,
num_warps=2,
)Edit applied successfully.{
"todos": [
{
"content": "Write initial Triton paged attention kernel (solution.py)",
"priority": "high",
"status": "completed"
},
{
"content": "Run check.py to verify correctness across all shapes",
"priority": "high",
"status": "completed"
},
{
"content": "Run benchmark.py to measure performance",
"priority": "high",
"status": "completed"
},
{
"content": "Switch to (B, H) grid for better parallelism",
"priority": "high",
"status": "completed"
},
{
"content": "Revert to element
... (464 more chars)content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'completed'
priority: 'high'
content: 'Switch to (B, H) grid for better parallelism'
status: 'completed'
priority: 'high'
content: 'Revert to elementwise dot (faster than tl.dot for small matmuls)'
status: 'completed'
priority: 'high'
content: 'Tune num_warps (2 is fastest)'
status: 'completed'
priority: 'high'
content: 'Run final check.py and benchmark.py'
status: 'in_progress'
priority: 'high'
content: 'Clean up temp files'
status: 'pending'
priority: 'medium'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1",
"description": "Run check.py with num_warps=2",
"timeout": 120000
}PASS
{
"description": "Run benchmark with num_warps=2",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python benchmark.py 2>&1",
"timeout": 300000
}W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 08:50:41.134000 1127153 torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=0.225 gbps=56.347 ms=0.598
shape=0 variant=compiled tflops=0.217 gbps=54.473 ms=0.618
shape=0 variant=solution tflops=1.098 gbps=275.569 ms=0.122
shape=0 solution_peak_fraction=0.1531
shape=1 variant=eager tflops=0.343 gbps=85.805 ms=3.135
shape=1 vari
... (817 more chars){
"command": "rm -f /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/tune_warps.py /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/framework.txt",
"description": "Clean up temp files"
}(no output){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Paged attention decode — Triton kernel, (B, H) grid for max parallelism.
2:
3: Single-query decode: each batch element has a query (num_heads, head_dim) and
4: attends over a paged KV cache. Grid is (batch, num_heads). Each program handles
5: one query head, loads its KV data via page indirection, and computes attention
6: with online softmax in fp32.
7: """
8:
9: import math
10:
11: import torch
12: import triton
13: import triton.language as tl
14:
15: OP_TYPE = "attention"
16: SUPPORTED_PRECISIONS = ["bf16"]
17: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
18:
19: BATCH = 8
20: NUM_HEADS = 32
21: NUM_KV_HEADS = 8
22: HEAD_DIM = 128
23: SEQ_LEN = 1024
24: PAGE_SIZE = 16
25:
26:
27: @triton.jit
28: def _paged_attn_kernel(
29: q_ptr, kv_ptr, bt_ptr, sl_ptr, out_ptr,
30: stride_qb, stride_qh,
31: stride_kvb, stride_kvps, stride_kvkh,
32: stride_btb,
33: stride_ob, stride_oh,
34: scale,
35: HEAD_DIM: tl.constexpr,
36: PAGE_SIZE: tl.constexpr,
37: GROUP_SIZE: tl.constexpr,
38: ):
39: batch = tl.program_id(0)
40: head = tl.program_id(1)
41: kv_head = head // GROUP_SIZE
42:
43: offs_d = tl.arange(0, HEAD_DIM)
44:
45: # Load query as bf16 once, reuse across all pages via Tensor-Core dot
46: q_bf16 = tl.load(
47: q_ptr + batch * stride_qb + head * stride_qh + offs_d
48: )
... (4419 more chars){
"description": "Run ruff check on solution",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/03_paged_attention/solution.py --fix 2>&1",
"timeout": 30000
}All checks passed!
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/",
"description": "Verify clean directory"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
__pycache__
reference.py
shapes.py
solution.py
sota.py
{
"description": "Final correctness + benchmark run",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention && python check.py 2>&1 && echo \"===\" && python benchmark.py 2>&1 | tail -10",
"timeout": 300000
}PASS
===
shape=3 variant=eager tflops=0.300 gbps=75.271 ms=1.340
shape=3 variant=compiled tflops=0.298 gbps=74.791 ms=1.349
shape=3 variant=solution tflops=2.144 gbps=537.312 ms=0.188
shape=3 solution_peak_fraction=0.2985
shape=4 variant=eager tflops=0.113 gbps=28.219 ms=0.582
shape=4 variant=compiled tflops=0.110 gbps=27.492 ms=0.597
shape=4 variant=solution tflops=0.424 gbps=106.095 ms=0.155
shape=4 solution_peak_fraction=0.0589
peak_fraction: 0.1671
RESULT: OK
{
"todos": [
{
"content": "Write initial Triton paged attention kernel (solution.py)",
"priority": "high",
"status": "completed"
},
{
"content": "Run check.py to verify correctness across all shapes",
"priority": "high",
"status": "completed"
},
{
"content": "Run benchmark.py to measure performance",
"priority": "high",
"status": "completed"
},
{
"content": "Switch to (B, H) grid for better parallelism",
"priority": "high",
"status": "completed"
},
{
"content": "Revert to element
... (464 more chars)content: 'Write initial Triton paged attention kernel (solution.py)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py to measure performance'
status: 'completed'
priority: 'high'
content: 'Switch to (B, H) grid for better parallelism'
status: 'completed'
priority: 'high'
content: 'Revert to elementwise dot (faster than tl.dot for small matmuls)'
status: 'completed'
priority: 'high'
content: 'Tune num_warps (2 is fastest)'
status: 'completed'
priority: 'high'
content: 'Run final check.py and benchmark.py'
status: 'completed'
priority: 'high'
content: 'Clean up temp files'
status: 'completed'
priority: 'medium'