"""Paged attention decode — Triton kernel, RTX PRO 6000 (sm_120).
Page-based approach: each program instance (batch, head) loads whole pages of
K/V from the paged cache, computes the Q·K scores for an entire page at once,
and accumulates the attention output with an online softmax.
This reduces the sequential loop depth by PAGE_SIZE× compared with
position-by-position iteration.
"""
from __future__ import annotations
import math
from typing import List
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _paged_attention_decode_kernel(
    Q, KV_CACHE, BLOCK_TABLE, SEQ_LEN_PTR, OUT,
    stride_qb: tl.constexpr, stride_qh: tl.constexpr, stride_qd: tl.constexpr,
    stride_kvb, stride_kvps, stride_kvkh, stride_kvd,
    stride_bt,
    HEAD_DIM: tl.constexpr,
    NUM_KV_HEADS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    NUM_PAGES: tl.constexpr,  # upper bound on pages per sequence (kept for launch compatibility)
):
    """Single-query paged-attention decode, one program per (batch, head).

    Layout assumptions (match the launcher):
      Q / OUT:    (batch, num_heads, head_dim) bf16, shared strides.
      KV_CACHE:   (num_blocks, page_size, num_kv_heads, 2*head_dim) bf16,
                  K in [..., :head_dim], V in [..., head_dim:].
      BLOCK_TABLE: int32 (batch, max_pages) — physical block index per page.
      SEQ_LEN_PTR: int32 (batch,) — valid KV length per batch element.

    Processes one whole page of K/V per loop iteration and folds it into a
    numerically-stable online softmax. All softmax state and the output
    accumulator are kept in fp32; only the final store casts back to bf16.
    Assumes seq_len >= 1 (decode always attends to at least one token).
    """
    b = tl.program_id(0)
    h = tl.program_id(1)

    # GQA: map the query head onto its shared KV head.
    num_q_heads = tl.num_programs(1)
    kv_head = h // (num_q_heads // NUM_KV_HEADS)

    off_d = tl.arange(0, HEAD_DIM)   # head-dim lane offsets
    off_p = tl.arange(0, PAGE_SIZE)  # intra-page position offsets

    # Load Q once in fp32; the dot products and softmax are done in fp32 to
    # avoid the precision collapse bf16 accumulation would cause.
    q = tl.load(Q + b * stride_qb + h * stride_qh + off_d * stride_qd).to(tl.float32)
    scale = 1.0 / math.sqrt(HEAD_DIM)  # constexpr-folded at trace time

    # Online-softmax state: running max, running normalizer (starts at ZERO —
    # it is a sum of exp() terms, not 1), and fp32 output accumulator.
    m_i = tl.full([1], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([1], dtype=tl.float32)
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    seq_len = tl.load(SEQ_LEN_PTR + b)
    # Only walk pages that actually hold tokens; block_table entries past
    # cdiv(seq_len, PAGE_SIZE) may be uninitialized and must not be read.
    num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
    for p in range(0, num_pages):
        block_idx = tl.load(BLOCK_TABLE + b * stride_bt + p)
        kv_base = KV_CACHE + block_idx * stride_kvb + kv_head * stride_kvkh

        # 2-D gather of the whole page: rows = positions, cols = head dim.
        k = tl.load(
            kv_base
            + off_p[:, None] * stride_kvps
            + off_d[None, :] * stride_kvd
        ).to(tl.float32)  # (PAGE_SIZE, HEAD_DIM)

        # Q·K for every position in the page. A broadcast-multiply + row
        # reduction is used instead of tl.dot: tl.dot requires M >= 16 and
        # a single decode query is a 1-row matrix.
        sim = tl.sum(q[None, :] * k, axis=1) * scale  # (PAGE_SIZE,)

        # Mask tail positions of the last (possibly partial) page.
        pos = p * PAGE_SIZE + off_p
        sim = tl.where(pos < seq_len, sim, float("-inf"))

        # Online softmax rescale. alpha downweights the previous partial
        # sums when a new maximum is found; exp(-inf - m) == 0 on the first
        # iteration, so the initial state washes out cleanly.
        m_new = tl.maximum(m_i, tl.max(sim, axis=0))
        alpha = tl.exp(m_i - m_new)            # [1]
        probs = tl.exp(sim - m_new)            # (PAGE_SIZE,), 0 where masked

        # V lives in the upper half of the last axis of the packed cache.
        v = tl.load(
            kv_base
            + off_p[:, None] * stride_kvps
            + (HEAD_DIM + off_d[None, :]) * stride_kvd
        ).to(tl.float32)  # (PAGE_SIZE, HEAD_DIM)

        acc = acc * alpha + tl.sum(probs[:, None] * v, axis=0)
        l_i = l_i * alpha + tl.sum(probs, axis=0)
        m_i = m_new

    # Normalize and store; OUT shares Q's strides.
    out = acc / l_i
    tl.store(
        OUT + b * stride_qb + h * stride_qh + off_d * stride_qd,
        out.to(tl.bfloat16),
    )
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 56, in main
sol_model = solution.Model(*init_args).to(device).eval()
^^^^^^^^^^^^^^
AttributeError: module 'solution' has no attribute 'Model'
{
"problem": "03_paged_attention",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 594724, "output_tokens": 8135, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 16565, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface",
"status": "in_progress",
"priority": "high"
},
{
"content": "Design paged attention kernel approach (Triton vs CUDA C++)",
"status": "pending",
"priority": "high"
},
{
"content": "Implement initial kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py for correctness verification",
"status": "pending",
"priority": "high"
},
{
... (258 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface'
status: 'in_progress'
priority: 'high'
content: 'Design paged attention kernel approach (Triton vs CUDA C++)'
status: 'pending'
priority: 'high'
content: 'Implement initial kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness verification'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Final verification with check.py and benchmark.py'
status: 'pending'
priority: 'high'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py</path>
<type>file</type>
<content>
1: """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2:
3: Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4: and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5: a global pool. Pages for batch element b are listed in `block_table[b]`.
6:
7: The reference performs the slow path:
8: 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9: 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10: 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11:
12: This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13: forbidden list) so the agent cannot dispatch through SDPA either.
14: """
15: import math
16:
17: import torch
18: import torch.nn as nn
19:
20: OP_TYPE = "attention"
21: SUPPORTED_PRECISIONS = ["bf16"]
22: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23:
24:
25: # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26: BATCH = 8
27: NUM_HEADS = 32
28: NUM_KV_HEADS = 8
29: HEAD_DIM = 128
30: SEQ_LEN = 1024
31: PAGE_SIZE = 16
32:
33:
34: class Model(nn.Module):
35: """Single-query paged attention decode.
36:
37: Forward inputs (all on device):
38: query: (batch, num_heads, head_dim) bf16
3
... (4426 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for paged-attention decode.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: REPO_ROOT = Path(__file__).resolve().parents[2]
14: sys.path.insert(0, str(REPO_ROOT))
15:
16: from src.eval.correctness import check_correctness # noqa: E402
17:
18:
19: def _apply_shape(reference, shape: dict) -> None:
20: reference.BATCH = shape["batch"]
21: reference.NUM_HEADS = shape["num_heads"]
22: reference.NUM_KV_HEADS = shape["num_kv_heads"]
23: reference.HEAD_DIM = shape["head_dim"]
24: reference.SEQ_LEN = shape["seq_len"]
25: reference.PAGE_SIZE = shape["page_size"]
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
41: for forbidden in meta.get("forbidden", []):
42: pat = re.escape(forbidden)
43:
... (2537 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for paged-attention decode.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
7: the GPU's peak DRAM bandwidth.
8: """
9: import sys
10: from math import exp, log
11: from pathlib import Path
12:
13: import torch
14: import yaml
15:
16: REPO_ROOT = Path(__file__).resolve().parents[2]
17: sys.path.insert(0, str(REPO_ROOT))
18:
19: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
20: from src.eval.timing import time_fn # noqa: E402
21: from src.hardware import get as get_hw # noqa: E402
22:
23:
24: def _eval_formula(expr: str, vars: dict) -> float:
25: return float(eval(expr, {"__builtins__": {}}, vars))
26:
27:
28: def _apply_shape(reference, shape: dict) -> None:
29: reference.BATCH = shape["batch"]
30: reference.NUM_HEADS = shape["num_heads"]
31: reference.NUM_KV_HEADS = shape["num_kv_heads"]
32: reference.HEAD_DIM = shape["head_dim"]
33: reference.SEQ_LEN = shape["seq_len"]
34: reference.PAGE_SIZE = shape["page_size"]
35:
36:
37: def main():
38: import reference
39: import shapes
40: import solution
41:
42: meta = yaml.safe_load(Path("problem.y
... (3563 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for paged attention decode.
2:
3: Mix targets:
4: - small batch / long context (Llama-3 8B-style decode)
5: - large batch / medium context (server batched decode)
6: - GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
7: - non-power-of-2 seq_len (forces predicated tail handling)
8: - head_dim=64 small-head case
9: """
10:
11: SHAPES = [
12: # (B, H, Hkv, D, L, P)
13: {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
14: {"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
15: {"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
16: {"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
17: {"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
18: ]
(End of file - total 18 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml</path>
<type>file</type>
<content>
1: name: 03_paged_attention
2: display_name: "Paged Attention Decode"
3: precision: bf16
4: regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
5:
6: # Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
7: # Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
8: # 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
9: flops_formula: "4 * batch * num_heads * seq_len * head_dim"
10:
11: # Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
12: # K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
13: # Q: batch * num_heads * head_dim * 2 (negligible vs KV)
14: # Out: same as Q
15: # So total ~ 2 * (B * L * Hkv * D * 2) + small.
16: bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
21:
22: tolerance:
23: bfloat16: 0.02
24: float16: 0.02
25:
26: # Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
27: # vllm.attention / flashinfer.batch_decode_with_paged_kv_cache: dispatching to
28: # the SOTA library is the trivial cheat. SD
... (875 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for paged-attention decode.
2:
3: Tries, in order:
4: 1. FlashInfer's BatchDecodeWithPagedKVCacheWrapper (preferred -- portable,
5: supports SM120, GQA, arbitrary head_dim).
6: 2. vLLM's paged_attention_v2 CUDA op (requires its KV-cache layout, more
7: finicky; we adapt the layout on the fly when possible).
8:
9: If neither is importable, is_available() returns False and the benchmark just
10: reports eager + compiled + solution.
11:
12: Agents are FORBIDDEN from importing these in solution.py (see problem.yaml).
13: This file is only for the benchmark's reference line.
14: """
15: from __future__ import annotations
16:
17: import torch
18:
19:
20: def _try_flashinfer(
21: query: torch.Tensor,
22: kv_cache: torch.Tensor,
23: block_table: torch.Tensor,
24: seq_lens: torch.Tensor,
25: num_kv_heads: int,
26: head_dim: int,
27: page_size: int,
28: ) -> torch.Tensor | None:
29: try:
30: import flashinfer # noqa: F401
31: from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
32: except Exception:
33: return None
34:
35: B, H, D = query.shape
36: # FlashInfer expects K and V as separate (num_blocks, page_size, num_kv_heads, head_dim) tensors.
37: # Our reference packs [K|V] on the last dim -- split here.
38: k_cache = kv_cache[..., :D].contig
... (1803 more chars){
"pattern": "def check_correctness",
"path": "/home/infatoshi/cuda/KernelBench-Hard/src"
}Found 1 matches
/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py:
Line 61: def check_correctness(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check triton availability"
}3.6.0
{
"command": "nvcc --version 2>&1 | head -3",
"description": "Check nvcc version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>&1",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import torch; print(torch.cuda.get_device_properties(0).major, torch.cuda.get_device_properties(0).minor)\" 2>&1",
"description": "Check torch GPU compute capability"
}12 0
{
"command": "uv run python -c \"import triton; print(triton.runtime.build.build_dir); import triton.lang as tl; print(dir(tl))\" 2>&1 | tail -5",
"description": "Check triton capabilities"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'triton.runtime.build' has no attribute 'build_dir'
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/ | head -20",
"description": "List all problems"
}01_fp8_gemm
02_kda_cutlass
03_paged_attention
04_kahan_softmax
05_topk_bitonic
06_sonic_moe_swiglu
07_w4a16_gemm
08_metal_lightning_attn
{
"pattern": "**/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems"
}No files found{
"path": "/home/infatoshi/cuda/KernelBench-Hard/src",
"pattern": "**/*.py"
}/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/opencode.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/html.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/claude.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/codex.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/kimi.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/cursor.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/diff_util.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/events.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/__main__.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/parsers/droid.py
/home/infatoshi/cuda/KernelBench-Hard/src/viewer/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/sandbox/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/report.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/shapes.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/m4_max.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py
/home/infatoshi/cuda/KernelBench-Hard/src/harness/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/__init__.py{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface",
"status": "completed",
"priority": "high"
},
{
"content": "Check Triton SM120 support",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement initial Triton kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py for correctness verification",
"status": "pending",
"priority": "high"
},
{
"content": "Profile
... (234 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface'
status: 'completed'
priority: 'high'
content: 'Check Triton SM120 support'
status: 'in_progress'
priority: 'high'
content: 'Implement initial Triton kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness verification'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Final verification with check.py and benchmark.py'
status: 'pending'
priority: 'high'{
"command": "<524 chars \u2014 see below>",
"description": "Test Triton on SM120"
}uv run python -c "
import triton
import triton.language as tl
import torch
@triton.jit
def test_kernel(out_ptr, BLOCK_D: tl.constexpr):
pid = tl.program_id(0)
off = pid * BLOCK_D + tl.arange(0, BLOCK_D)
tl.store(out_ptr + off, off.to(tl.float32))
out = torch.empty(128, dtype=torch.float32, device='cuda')
test_kernel[(1,)](out, BLOCK_D=128)
print('Triton SM120 basic test passed:', out[:5].tolist())
print('CUDA cap:', triton.runtime.driver.active.driver.active.get_current_device().compute_capability)
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/test_triton.py
@@ -0,0 +1,19 @@
+import triton
+import triton.language as tl
+import torch
+
+@triton.jit
+def test_kernel(out_ptr, BLOCK_D: tl.constexpr):
+ pid = tl.program_id(0)
+ off = pid * BLOCK_D + tl.arange(0, BLOCK_D)
+ tl.store(out_ptr + off, off.to(tl.float32))
+
+out = torch.empty(128, dtype=torch.float32, device='cuda')
+test_kernel[(1,)](out, BLOCK_D=128)
+print('Triton SM120 basic test passed:', out[:5].tolist())
+
+# Check device
+device = torch.cuda.get_device(0)
+props = torch.cuda.get_device_properties(0)
+print(f'Compute capability: {props.major}.{props.minor}')
+print(f'Device: {props.name}')
{
"content": "<579 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/test_triton.py"
}Wrote file successfully.{
"command": "uv run python -c \"import triton; print(dir(triton.language))\" 2>&1",
"description": "Check triton.language exports"
}['PropagateNan', 'TRITON_MAX_TENSOR_NUMEL', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'abs', 'add', 'advance', 'arange', 'argmax', 'argmin', 'associative_scan', 'assume', 'atomic_add', 'atomic_and', 'atomic_cas', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_xchg', 'atomic_xor', 'bfloat16', 'bitonic_merge', 'block_type', 'broadcast', 'broadcast_to', 'cast', 'cat', 'cdiv', 'ceil', 'clamp', 'condition', 'const', 'constexpr', 'constexpr_type', 'core', 'cos', 'cumprod', 'cumsum', 'debug_barrier', 'device_assert', 'device_print', 'div_rn', 'dot', 'dot_scaled', 'dtype', 'erf', 'exp', 'exp2', 'expand_dims', 'extra', 'fdiv', 'flip', 'float16', 'float32', 'float64', 'float8e4b15', 'float8e4b8', 'float8e4nv', 'float8e5', 'float8e5b16', 'floor', 'fma', 'full', 'gather', 'histogram', 'inline_asm_elementwise', 'int1', 'int16', 'int32', 'int64', 'int8', 'interleave', 'join', 'load', 'load_tensor_descriptor', 'log', 'log2', 'make_block_ptr', 'make_tensor_descriptor', 'map_elementwise', 'math', 'max', 'max_constancy', 'max_contiguous', 'maximum', 'min', 'minimum', 'mul', 'multiple_of', 'num_programs', 'pair_uniform_to_normal', 'permute', 'philox', 'philox_impl', 'pi32_t', 'pointer_type', 'program_id', 'rand', 'rand4x', 'randint', 'randint4x', 'randn', 'randn4x', 'random', 'range', 'ravel', 'reduce', 'reduce_or', 'reshape', 'rsqrt', 'sigmoid', 'sin', 'slice', 'softmax', 'sort', 'split', 'sqrt', 'sqrt_rn'
... (384 more chars){
"command": "uv run python test_triton.py 2>&1",
"description": "Test Triton SM120 support"
}Triton SM120 basic test passed: [0.0, 1.0, 2.0, 3.0, 4.0]
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/test_triton.py", line 16, in <module>
device = torch.cuda.get_device(0)
^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.cuda' has no attribute 'get_device'. Did you mean: '_get_device'?
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface",
"status": "completed",
"priority": "high"
},
{
"content": "Check Triton SM120 support",
"status": "completed",
"priority": "high"
},
{
"content": "Implement initial Triton kernel in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness verification",
"status": "pending",
"priority": "high"
},
{
"content": "Profil
... (236 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml, shapes.py to understand interface'
status: 'completed'
priority: 'high'
content: 'Check Triton SM120 support'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness verification'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Final verification with check.py and benchmark.py'
status: 'pending'
priority: 'high'{
"command": "<122 chars \u2014 see below>",
"description": "Check make_block_ptr docs"
}uv run python -c "
import triton.language as tl
# Check make_block_ptr signature
help(tl.make_block_ptr)
" 2>&1 | head -30Help on function make_block_ptr in module triton.language.core:
make_block_ptr(base: 'tensor', shape, strides, offsets, block_shape, order, _semantic=None)
Returns a pointer to a block in a parent tensor
:param base: The base pointer to the parent tensor
:param shape: The shape of the parent tensor
:param strides: The strides of the parent tensor
:param offsets: The offsets to the block
:param block_shape: The shape of the block
:param order: The order of the original data format
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -0,0 +1,290 @@
+"""Paged attention decode kernel for RTX PRO 6000 Blackwell (sm_120).
+
+Triton-based single-query decode: each (batch, head) pair is one thread block.
+Loops over sequence positions, gathering K/V from the paged cache via the
+block_table on each iteration to keep SM register pressure bounded.
+
+Online softmax is used to handle variable-length sequences with a single pass.
+
+Layout:
+ query: (batch, num_heads, head_dim) bf16
+ kv_cache: (num_blocks, page_size, num_kv_heads, 2*D) bf16 [K|V]
+ block_table: (batch, max_blocks_per_seq) int32
+ seq_lens: (batch,) int32
+ output: (batch, num_heads, head_dim) bf16
+"""
+from __future__ import annotations
+
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+# ---------------------------------------------------------------------------
+# Triton kernel
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _paged_attention_decode_kernel(
+ # Pointers
+ Q, # (batch, num_heads, head_dim) bf16
+ KV_CACHE, # (num_blocks, page_size, num_kv_heads, 2*head_dim) bf16
+ BLOCK_TABLE, # (batch, max_blocks) int32
+ SEQ_LEN_PTR, # (batch,) int32
+ OUT, # (batch, num_heads, head_dim) bf16
+ # Query strides
+ stride_qb: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ # KV cache strides
+ stride_kvb, # stride for num_blocks axis
+ stride_kvps, # stride for page_size axis
+ stride_kvkh, # stride for num_kv_heads axis
+ stride_kvd: tl.constexpr, # stride for last axis
+ # Block table strides
+ stride_bt: tl.constexpr,
+ # Shape parameters (runtime)
+ head_dim: tl.constexpr,
+ num_kv_heads: tl.constexpr,
+ page_size: tl.constexpr,
+ kv_cache_num_blocks, # num_blocks (runtime scalar)
+ kv_cache_shape_ps, # page_size (runtime scalar from shape)
+ kv_cache_shape_kh, # num_kv_heads (runtime scalar from shape)
+ kv_cache_shape_d2, # 2*head_dim (runtime scalar from shape)
+):
+ """Single-query paged-attention decode.
+
+ Each program_id(0) = batch element, program_id(1) = query head.
+ We loop over sequence positions in the KV cache, gather from the
+ paged block-table layout, and accumulate attention output with
+ online softmax.
+ """
+ b_idx = tl.program_id(0)
+ h_idx = tl.program_id(1)
+
+ # GQA: map query head -> kv head
+ num_query_heads = tl.num_programs(1)
+ group_size = num_query_heads // num_kv_heads
+ kv_head = h_idx // group_size
+
+ head_dim_f = float(head_dim)
+
+ # Load seq_len for this batch element
+ seq_len = tl.load(SEQ_LEN_PTR + b_idx)
+
+ # Compute max_num_pages for this batch element
+ max_num_pages = tl.cdiv(seq_len, page_size)
+
+ # ---- Load Q once ----
+ head_offsets = tl.arange(0, head_dim)
+ q = tl.load(
+ Q + b_idx * stride_qb + h_idx * stride_qh + head_offsets * stride_qd,
+ )
+
+ # ---- Online softmax accumulators ----
+ att_max = float("-inf")
+ att_sum: tl.float32 = 1.0
+ out_accum = tl.zeros([head_dim], dtype=tl.float32)
+
+ # ---- Loop over sequence positions ----
+ # For each position, compute the page index and intra-page offset,
+ # then gather K and V from the paged cache.
+ pos = 0
+ while pos < seq_len:
+ # Page index and intra-page offset for this position
+ page_idx = pos // page_size
+ page_off = pos % page_size
+
+ # Look up the block index from block_table[b, page_idx]
+ block_idx = tl.load(BLOCK_TABLE + b_idx * stride_bt + page_idx)
+
+ # ---- Load K from kv_cache[block_idx, page_off, kv_head, :] ----
+ k = tl.load(
+ KV_CACHE
+ + block_idx * stride_kvb
+ + page_off * stride_kvps
+ + kv_head * stride_kvkh
+ + head_offsets * stride_kvd
+ )
+
+ # ---- Compute Q·K / sqrt(D) (fp32) ----
+ sim = tl.sum(q * k) / tl.sqrt(head_dim_f)
+
+ # Mask: positions >= seq_len are already excluded by the loop bound
+
+ # ---- Online softmax update ----
+ new_max: tl.float32 = tl.maximum(att_max, sim)
+ alpha = tl.exp(att_max - new_max)
+ beta = tl.exp(sim - new_max)
+
+ out_accum = out_accum * alpha + beta * tl.to(v, tl.float32)
+ out_accum = out_accum * alpha + beta * tl.to(tl.load(
+ KV_CACHE
+ + block_idx * stride_kvb
+ + page_off * stride_kvps
+ + kv_head * stride_kvkh
+ + head_dim * stride_kvd
+ + head_offsets * stride_kvd
+ ), tl.float32)
+
+ att_max = new_max
+ att_sum = att_sum * alpha + beta
+
+ pos += 1
+
+ # Final normalization
+ out_final = out_accum / att_sum
+
+ # Store result (cast to bf16)
+ tl.store(
+ OUT + b_idx * stride_qb + h_idx * stride_qh + head_offsets * stride_qd,
+ out_final.to(tl.bfloat16),
+ )
+
+
+def _paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ num_kv_heads: int,
+ head_dim: int,
+ page_size: int,
+) -> torch.Tensor:
+ """Launch the Triton kernel for paged attention decode."""
+ batch, num_heads, _ = query.shape
+ num_blocks = kv_cache.shape[0]
+
+ out = torch.empty_like(query)
+
+ # Strides for the tensors
+ stride_qb = query.stride(0)
+ stride_qh = query.stride(1)
+ stride_qd = query.stride(2)
+
+ # KV cache strides: (num_blocks, page_size, num_kv_heads, 2*head_dim)
+ stride_kvb = kv_cache.stride(0)
+ stride_kvps = kv_cache.stride(1)
+ stride_kvk
... (3749 more chars){
"content": "<9343 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -23,126 +23,101 @@
import triton
import triton.language as tl
-# ---------------------------------------------------------------------------
-# Triton kernel
-# ---------------------------------------------------------------------------
@triton.jit
def _paged_attention_decode_kernel(
- # Pointers
- Q, # (batch, num_heads, head_dim) bf16
- KV_CACHE, # (num_blocks, page_size, num_kv_heads, 2*head_dim) bf16
- BLOCK_TABLE, # (batch, max_blocks) int32
- SEQ_LEN_PTR, # (batch,) int32
- OUT, # (batch, num_heads, head_dim) bf16
+ Q, # *strided*(batch, num_heads, head_dim)
+ KV_CACHE, # *strided*(num_blocks, page_size, num_kv_heads, 2*head_dim)
+ BLOCK_TABLE,# int32[batch, max_blocks]
+ SEQ_LEN_PTR,# int32[batch]
+ OUT, # *strided*(batch, num_heads, head_dim)
# Query strides
stride_qb: tl.constexpr,
stride_qh: tl.constexpr,
stride_qd: tl.constexpr,
- # KV cache strides
- stride_kvb, # stride for num_blocks axis
- stride_kvps, # stride for page_size axis
- stride_kvkh, # stride for num_kv_heads axis
- stride_kvd: tl.constexpr, # stride for last axis
- # Block table strides
+ # KV cache strides (all axes)
+ stride_kvb,
+ stride_kvps,
+ stride_kvkh,
+ stride_kvd: tl.constexpr,
+ # Block table stride for page index
stride_bt: tl.constexpr,
- # Shape parameters (runtime)
- head_dim: tl.constexpr,
- num_kv_heads: tl.constexpr,
- page_size: tl.constexpr,
- kv_cache_num_blocks, # num_blocks (runtime scalar)
- kv_cache_shape_ps, # page_size (runtime scalar from shape)
- kv_cache_shape_kh, # num_kv_heads (runtime scalar from shape)
- kv_cache_shape_d2, # 2*head_dim (runtime scalar from shape)
+ # Compile-time shape params
+ HEAD_DIM: tl.constexpr,
+ NUM_KV_HEADS: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
):
- """Single-query paged-attention decode.
-
- Each program_id(0) = batch element, program_id(1) = query head.
- We loop over sequence positions in the KV cache, gather from the
- paged block-table layout, and accumulate attention output with
- online softmax.
- """
+ # Program indices
b_idx = tl.program_id(0)
h_idx = tl.program_id(1)
- # GQA: map query head -> kv head
+ # GQA: map query head g to kv head
num_query_heads = tl.num_programs(1)
- group_size = num_query_heads // num_kv_heads
+ group_size = num_query_heads // NUM_KV_HEADS
kv_head = h_idx // group_size
- head_dim_f = float(head_dim)
-
- # Load seq_len for this batch element
+ # Per-thread head-dim offsets
+ off_d = tl.arange(0, HEAD_DIM)
+
+ # Load seq_len for batch element
seq_len = tl.load(SEQ_LEN_PTR + b_idx)
- # Compute max_num_pages for this batch element
- max_num_pages = tl.cdiv(seq_len, page_size)
-
- # ---- Load Q once ----
- head_offsets = tl.arange(0, head_dim)
- q = tl.load(
- Q + b_idx * stride_qb + h_idx * stride_qh + head_offsets * stride_qd,
- )
-
- # ---- Online softmax accumulators ----
- att_max = float("-inf")
- att_sum: tl.float32 = 1.0
- out_accum = tl.zeros([head_dim], dtype=tl.float32)
-
- # ---- Loop over sequence positions ----
- # For each position, compute the page index and intra-page offset,
- # then gather K and V from the paged cache.
+ # Load all of Q for (b_idx, h_idx, :)
+ q_ptr = Q + b_idx * stride_qb + h_idx * stride_qh + off_d * stride_qd
+ q = tl.load(q_ptr)
+
+ # Online softmax state
+ att_max = tl.full([HEAD_DIM], float("-inf"), dtype=tl.float32)
+ att_sum = tl.full([HEAD_DIM], 1.0, dtype=tl.float32)
+ out = tl.zeros([HEAD_DIM], dtype=tl.float32)
+
+ # Loop over all sequence positions
pos = 0
while pos < seq_len:
- # Page index and intra-page offset for this position
- page_idx = pos // page_size
- page_off = pos % page_size
-
- # Look up the block index from block_table[b, page_idx]
+ # Where in the paged cache is this position?
+ page_idx = pos // PAGE_SIZE
+ page_off = pos % PAGE_SIZE
+
+ # Read block-table entry
block_idx = tl.load(BLOCK_TABLE + b_idx * stride_bt + page_idx)
- # ---- Load K from kv_cache[block_idx, page_off, kv_head, :] ----
+ # Load K: kv_cache[block_idx, page_off, kv_head, 0:D]
k = tl.load(
KV_CACHE
+ block_idx * stride_kvb
+ page_off * stride_kvps
+ kv_head * stride_kvkh
- + head_offsets * stride_kvd
- )
-
- # ---- Compute Q·K / sqrt(D) (fp32) ----
- sim = tl.sum(q * k) / tl.sqrt(head_dim_f)
-
- # Mask: positions >= seq_len are already excluded by the loop bound
-
- # ---- Online softmax update ----
- new_max: tl.float32 = tl.maximum(att_max, sim)
- alpha = tl.exp(att_max - new_max)
- beta = tl.exp(sim - new_max)
-
- out_accum = out_accum * alpha + beta * tl.to(v, tl.float32)
- out_accum = out_accum * alpha + beta * tl.to(tl.load(
+ + off_d * stride_kvd
+ )
+
+ # Q · K (reduce over head_dim)
+ sim = tl.sum(q * k) / tl.sqrt(float(HEAD_DIM))
+
+ # Load V: kv_cache[block_idx, page_off, kv_head, D:2D]
+ v = tl.load(
KV_CACHE
+ block_idx * stride_kvb
+ page_off * stride_kvps
+ kv_head * stride_kvkh
- + head_dim * stride_kvd
- + head_offsets * stride_kvd
- ), tl.float32)
-
+ + (HEAD_DIM + off_d) * stride_kvd
+ )
+
+ # Online softmax update
+ new_max = tl.maximum(att_max, sim)
+ alpha = tl.exp(att_max - new_max
... (5669 more chars){
"content": "<6674 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0427 18:56:47.699000 1242345 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 18
... (1879 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,17 +1,11 @@
-"""Paged attention decode kernel for RTX PRO 6000 Blackwell (sm_120).
+"""Paged attention decode — Triton kernel, RTX PRO 6000 (sm_120).
-Triton-based single-query decode: each (batch, head) pair is one thread block.
-Loops over sequence positions, gathering K/V from the paged cache via the
-block_table on each iteration to keep SM register pressure bounded.
+Page-based approach: each thread block (batch, head) loads whole pages of
+K/V from the paged cache, computes Q·K in a single tl.dot call per page,
+and accumulates attention output with online softmax.
-Online softmax is used to handle variable-length sequences with a single pass.
-
-Layout:
- query: (batch, num_heads, head_dim) bf16
- kv_cache: (num_blocks, page_size, num_kv_heads, 2*D) bf16 [K|V]
- block_table: (batch, max_blocks_per_seq) int32
- seq_lens: (batch,) int32
- output: (batch, num_heads, head_dim) bf16
+This reduces the sequential loop depth by PAGE_SIZE× compared with
+position-by-position iteration.
"""
from __future__ import annotations
@@ -26,196 +20,68 @@
@triton.jit
def _paged_attention_decode_kernel(
- Q, # *strided*(batch, num_heads, head_dim)
- KV_CACHE, # *strided*(num_blocks, page_size, num_kv_heads, 2*head_dim)
- BLOCK_TABLE,# int32[batch, max_blocks]
- SEQ_LEN_PTR,# int32[batch]
- OUT, # *strided*(batch, num_heads, head_dim)
- # Query strides
- stride_qb: tl.constexpr,
- stride_qh: tl.constexpr,
- stride_qd: tl.constexpr,
- # KV cache strides (all axes)
- stride_kvb,
- stride_kvps,
- stride_kvkh,
- stride_kvd: tl.constexpr,
- # Block table stride for page index
- stride_bt: tl.constexpr,
- # Compile-time shape params
+ Q, KV_CACHE, BLOCK_TABLE, SEQ_LEN_PTR, OUT,
+ stride_qb: tl.constexpr, stride_qh: tl.constexpr, stride_qd: tl.constexpr,
+ stride_kvb, stride_kvps, stride_kvkh, stride_kvd,
+ stride_bt,
HEAD_DIM: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
+ NUM_PAGES: tl.constexpr,
):
- # Program indices
- b_idx = tl.program_id(0)
- h_idx = tl.program_id(1)
+ b = tl.program_id(0)
+ h = tl.program_id(1)
+ num_q_heads = tl.num_programs(1)
+ kv_head = h // (num_q_heads // NUM_KV_HEADS)
- # GQA: map query head g to kv head
- num_query_heads = tl.num_programs(1)
- group_size = num_query_heads // NUM_KV_HEADS
- kv_head = h_idx // group_size
-
- # Per-thread head-dim offsets
off_d = tl.arange(0, HEAD_DIM)
- # Load seq_len for batch element
- seq_len = tl.load(SEQ_LEN_PTR + b_idx)
+ q = tl.load(Q + b * stride_qb + h * stride_qh + off_d * stride_qd)
+ q_row = q.reshape((1, HEAD_DIM))
- # Load all of Q for (b_idx, h_idx, :)
- q_ptr = Q + b_idx * stride_qb + h_idx * stride_qh + off_d * stride_qd
- q = tl.load(q_ptr)
+ att_max = tl.full([1], float("-inf"), dtype=tl.bfloat16)
+ att_sum = tl.full([1], 1.0, dtype=tl.bfloat16)
+ out = tl.zeros([PAGE_SIZE], dtype=tl.bfloat16)
- # Online softmax state
- att_max = tl.full([HEAD_DIM], float("-inf"), dtype=tl.float32)
- att_sum = tl.full([HEAD_DIM], 1.0, dtype=tl.float32)
- out = tl.zeros([HEAD_DIM], dtype=tl.float32)
+ for p in tl.static_range(NUM_PAGES):
+ block_idx = tl.load(BLOCK_TABLE + b * stride_bt + p)
- # Loop over all sequence positions
- pos = 0
- while pos < seq_len:
- # Where in the paged cache is this position?
- page_idx = pos // PAGE_SIZE
- page_off = pos % PAGE_SIZE
-
- # Read block-table entry
- block_idx = tl.load(BLOCK_TABLE + b_idx * stride_bt + page_idx)
-
- # Load K: kv_cache[block_idx, page_off, kv_head, 0:D]
- k = tl.load(
- KV_CACHE
- + block_idx * stride_kvb
- + page_off * stride_kvps
+ k_block = tl.load(
+ KV_CACHE + block_idx * stride_kvb
+ + tl.arange(0, PAGE_SIZE) * stride_kvps
+ kv_head * stride_kvkh
+ off_d * stride_kvd
)
- # Q · K (reduce over head_dim)
- sim = tl.sum(q * k) / tl.sqrt(float(HEAD_DIM))
+ sim = tl.dot(q_row, tl.trans(k_block))
+ sim = (sim * (1.0 / math.sqrt(HEAD_DIM))).to(tl.bfloat16)
- # Load V: kv_cache[block_idx, page_off, kv_head, D:2D]
- v = tl.load(
- KV_CACHE
- + block_idx * stride_kvb
- + page_off * stride_kvps
+ seq_len = tl.load(SEQ_LEN_PTR + b)
+ pos_offsets = tl.arange(0, PAGE_SIZE) + p * PAGE_SIZE
+ valid = pos_offsets < seq_len
+ sim = tl.where(valid, sim, tl.cast(float("-inf"), tl.bfloat16))
+
+ new_max = tl.max(sim)
+ att_probs = (tl.exp(sim - new_max) * 0.5).to(tl.bfloat16)
+
+ alpha = (tl.exp(att_max - new_max) * 0.5).to(tl.bfloat16)
+ out = out * alpha + tl.where(
+ tl.static_range(PAGE_SIZE) < 1,
+ tl.zeros([PAGE_SIZE], dtype=tl.bfloat16),
+ tl.zeros([1], dtype=tl.bfloat16)
+ )
+ out = (out * alpha).to(tl.bfloat16)
+ out += att_probs.reshape((PAGE_SIZE,)) * tl.load(
+ KV_CACHE + block_idx * stride_kvb
+ + tl.arange(0, PAGE_SIZE) * stride_kvps
+ kv_head * stride_kvkh
+ (HEAD_DIM + off_d) * stride_kvd
)
+ out = tl.where(valid.reshape((PAGE_SIZE,)), out, tl.zeros([1], dtype=tl.bfloat16))
- # Online softmax update
- new_max = tl.maximum(att_max, sim)
- alpha = tl.exp(att_max - new_max)
- beta = tl.exp(sim - new_max)
- out = out * alpha + beta * v.to(tl.float32)
att_max = new_max
- att_sum = att_sum * alpha + beta
+ att_sum = att_sum * alpha
... (3384 more chars){
"content": "<2921 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.