"""Per-head CUDA paged attention decode kernel.
Architecture: one thread block per (batch, query_head), 128 threads.
Each thread handles one element of D, reduces qk across threads.
"""
import math
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
def _make_scale(head_dim: int) -> float:
return 1.0 / math.sqrt(head_dim)
_module = None
def _compile():
# Put ALL the code in the CUDA source. No C++ wrapper needed.
# We define the kernel AND the C++ interface function in the .cu file.
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <math.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // for at::cuda::getCurrentCUDAStream
using bfloat = __nv_bfloat16;
__global__ void __launch_bounds__(128) paged_decode_kernel(
bfloat* __restrict__ Out,
const bfloat* __restrict__ Q,
const bfloat* __restrict__ KV,
const int* __restrict__ BlockTable,
const int* __restrict__ SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo,
int stride_bt_row
) {
int pid_b = blockIdx.x;
int pid_h = blockIdx.y;
int tid = threadIdx.x;
int pid_hkv = pid_h / (H / Hkv);
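    // GQA: each group of H / Hkv consecutive query heads shares one KV head.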
int seq_len = SeqLens[pid_b];
int num_pages = (seq_len + P - 1) / P;
    // Threads past the head dim have no work; note the kernel assumes D <= 128.
    // Early return is safe alongside the barriers below on sm_70+, where
    // exited threads no longer block __syncthreads.
    if (tid >= D) return;
    // Load this thread's Q element (assumes the innermost Q/Out stride is 1,
    // which holds for the contiguous tensors the harness passes)
    float q_val = __bfloat162float(Q[pid_b * stride_qb + pid_h * stride_qo + tid]);
// Accumulators
float acc_num_val = 0.0f;
float acc_denom = 1e-30f;
float acc_max = -1e30f;
// Warp count
int num_warps = (D + 31) / 32;
int my_warp = tid / 32;
for (int pg = 0; pg < num_pages; pg++) {
int page_id = BlockTable[pid_b * stride_bt_row + pg];
int tok_end = min((pg + 1) * P, seq_len);
int tok_start = pg * P;
for (int t = tok_start; t < tok_end; t++) {
int local_t = t - tok_start;
// Load K[tid]
int k_base = page_id * stride_kvb + local_t * stride_kvs + pid_hkv * stride_kvh;
float k_val = __bfloat162float(KV[k_base + tid * stride_kvo]);
// Partial dot product
float partial = q_val * k_val;
// Intra-warp reduction
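            // (xor butterfly: after the 5 shuffle steps every lane in the warp
            // holds the warp's full partial sum, not just lane 0)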
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
partial += __shfl_xor_sync(0xffffffff, partial, offset);
            // Each warp leader publishes its partial sum to shared memory
            // (block-scoped even though declared in a loop; sized for 4 warps)
            __shared__ float warp_sum[4];
if (tid % 32 == 0)
warp_sum[tid / 32] = partial;
__syncthreads();
            // Sum the per-warp partials into the full qk for this token
            float qk = 0.0f;
            for (int w = 0; w < num_warps; w++)
                qk += warp_sum[w];
            qk *= scale;
            // Second barrier: warp_sum is rewritten for the next token, so all
            // threads must finish reading it before the leaders store again.
            __syncthreads();
            // Load V[tid]: V is packed after K in the page's last dim, so the
            // element index is (D + tid), scaled by the innermost stride
            float v_val = __bfloat162float(KV[k_base + (D + tid) * stride_kvo]);
// Online softmax (consistent across all threads)
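            // Invariant after each token: acc_max is the running max of qk,
            // acc_denom = sum(exp(qk_i - acc_max)), and
            // acc_num_val = sum(exp(qk_i - acc_max) * V_i[tid]), so
            // acc_num_val / acc_denom is softmax @ V over the tokens seen so far.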
float old_max = acc_max;
float new_max = fmaxf(old_max, qk);
float exp_v = __expf(qk - new_max);
float rescale = __expf(old_max - new_max);
acc_num_val = acc_num_val * rescale + exp_v * v_val;
acc_denom = acc_denom * rescale + exp_v;
acc_max = new_max;
}
}
// Write output
Out[pid_b * stride_ob + pid_h * stride_oo + tid] = __float2bfloat16(acc_num_val / acc_denom);
}
// PyTorch binding
void paged_attention_forward(
torch::Tensor out,
torch::Tensor query,
torch::Tensor kv_cache,
torch::Tensor block_table,
torch::Tensor seq_lens,
float scale)
{
int B = query.size(0);
int H = query.size(1);
int D = query.size(2);
int Hkv = kv_cache.size(2);
int P = kv_cache.size(1);
int stride_bt_row = block_table.stride(0);
dim3 grid(B, H);
dim3 block(128, 1, 1);
    // Launch on PyTorch's current stream so the kernel is ordered with
    // surrounding torch ops rather than the default stream
    paged_decode_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
(bfloat*)out.data_ptr(),
(const bfloat*)query.data_ptr(),
(const bfloat*)kv_cache.data_ptr(),
block_table.data_ptr<int>(),
seq_lens.data_ptr<int>(),
B, H, Hkv, D, P, scale,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
out.stride(0), out.stride(1),
stride_bt_row);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paged_attention_forward", &paged_attention_forward, "paged_attention_forward");
}
"""
return load_inline(
name="paged_attn_cuda_v5",
cpp_sources='',
cuda_sources=CUDA_SRC,
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
def _get_module():
global _module
if _module is None:
_module = _compile()
return _module
class Model(nn.Module):
def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len, page_size):
super().__init__()
assert num_heads % num_kv_heads == 0
self.batch = batch
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.seq_len = seq_len
self.page_size = page_size
        self.scale = _make_scale(head_dim)
self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
def forward(self, query, kv_cache, block_table, seq_lens):
B, H, D = self.batch, self.num_heads, self.head_dim
out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_lens, self.scale)
return out
def get_inputs():
import reference
return reference.get_inputs()
def get_init_inputs():
import reference
return reference.get_init_inputs()
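# A minimal self-check sketch, assuming reference.py is importable from this
# directory (it exposes the same Model signature and get_inputs, as used
# throughout this session). Not part of the graded interface.
if __name__ == "__main__":
    import reference
    torch.manual_seed(0)
    args = (reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS,
            reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE)
    sol = Model(*args).cuda().eval()
    ref = reference.Model(*args).cuda().eval()
    inputs = [t.cuda() for t in reference.get_inputs()]
    with torch.no_grad():
        out, exp = sol(*inputs), ref(*inputs)
    # problem.yaml sets a 0.02 bf16 tolerance
    print("max |diff|:", (out.float() - exp.float()).abs().max().item())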
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:55:56.052000 1012951 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
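A minimal sketch of the opt-in the warning suggests (applied before
torch.compile of the reference; both spellings come from the log above):

import torch._dynamo
torch._dynamo.config.capture_scalar_outputs = True
# or equivalently: env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1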
shape=0 variant=eager tflops=0.198 gbps=49.772 ms=0.677
shape=0 variant=compiled tflops=0.215 gbps=53.874 ms=0.625
shape=0 variant=solution tflops=0.153 gbps=38.507 ms=0.875
shape=0 solution_peak_fraction=0.0214
shape=1 variant=eager tflops=0.337 gbps=84.403 ms=3.187
shape=1 variant=compiled tflops=0.336 gbps=84.054 ms=3.200
shape=1 variant=solution tflops=0.634 gbps=158.714 ms=1.695
shape=1 solution_peak_fraction=0.0882
shape=2 variant=eager tflops=0.240 gbps=30.118 ms=2.233
shape=2 variant=compiled tflops=0.241 gbps=30.188 ms=2.227
shape=2 variant=solution tflops=0.157 gbps=19.616 ms=3.428
shape=2 solution_peak_fraction=0.0109
shape=3 variant=eager tflops=0.292 gbps=73.197 ms=1.378
shape=3 variant=compiled tflops=0.284 gbps=71.239 ms=1.416
shape=3 variant=solution tflops=0.314 gbps=78.680 ms=1.282
shape=3 solution_peak_fraction=0.0437
shape=4 variant=eager tflops=0.103 gbps=25.856 ms=0.635
shape=4 variant=compiled tflops=0.102 gbps=25.604 ms=0.641
shape=4 variant=solution tflops=0.040 gbps=9.955 ms=1.649
shape=4 solution_peak_fraction=0.0055
peak_fraction: 0.0218
RESULT: LOW
PASS
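A sketch of how the summary number relates to the per-shape lines, assuming
the harness geomeans the per-shape fractions (consistent with benchmark.py
importing exp/log, and with the values above):

from math import exp, log

per_shape = [0.0214, 0.0882, 0.0109, 0.0437, 0.0055]  # solution_peak_fraction per shape
# each is achieved / peak bandwidth, e.g. shape 0: 38.507 / 1800 GB/s ~= 0.0214
print(exp(sum(map(log, per_shape)) / len(per_shape)))  # ~0.0218 == reported peak_fraction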
{
"problem": "03_paged_attention",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-plus",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0218,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 6512307, "output_tokens": 69830, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 38794, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py</path>
<type>file</type>
<content>
1: """Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
2:
3: Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4: and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
5: a global pool. Pages for batch element b are listed in `block_table[b]`.
6:
7: The reference performs the slow path:
8: 1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
9: 2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
10: 3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
11:
12: This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
13: forbidden list) so the agent cannot dispatch through SDPA either.
14: """
15: import math
16:
17: import torch
18: import torch.nn as nn
19:
20: OP_TYPE = "attention"
21: SUPPORTED_PRECISIONS = ["bf16"]
22: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
23:
24:
25: # --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
26: BATCH = 8
27: NUM_HEADS = 32
28: NUM_KV_HEADS = 8
29: HEAD_DIM = 128
30: SEQ_LEN = 1024
31: PAGE_SIZE = 16
32:
33:
34: class Model(nn.Module):
35: """Single-query paged attention decode.
36:
37: Forward inputs (all on device):
38: query: (batch, num_heads, head_dim) bf16
3
... (4426 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for paged-attention decode.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
7: the GPU's peak DRAM bandwidth.
8: """
9: import sys
10: from math import exp, log
11: from pathlib import Path
12:
13: import torch
14: import yaml
15:
16: REPO_ROOT = Path(__file__).resolve().parents[2]
17: sys.path.insert(0, str(REPO_ROOT))
18:
19: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
20: from src.eval.timing import time_fn # noqa: E402
21: from src.hardware import get as get_hw # noqa: E402
22:
23:
24: def _eval_formula(expr: str, vars: dict) -> float:
25: return float(eval(expr, {"__builtins__": {}}, vars))
26:
27:
28: def _apply_shape(reference, shape: dict) -> None:
29: reference.BATCH = shape["batch"]
30: reference.NUM_HEADS = shape["num_heads"]
31: reference.NUM_KV_HEADS = shape["num_kv_heads"]
32: reference.HEAD_DIM = shape["head_dim"]
33: reference.SEQ_LEN = shape["seq_len"]
34: reference.PAGE_SIZE = shape["page_size"]
35:
36:
37: def main():
38: import reference
39: import shapes
40: import solution
41:
42: meta = yaml.safe_load(Path("problem.y
... (3563 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for paged-attention decode.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: REPO_ROOT = Path(__file__).resolve().parents[2]
14: sys.path.insert(0, str(REPO_ROOT))
15:
16: from src.eval.correctness import check_correctness # noqa: E402
17:
18:
19: def _apply_shape(reference, shape: dict) -> None:
20: reference.BATCH = shape["batch"]
21: reference.NUM_HEADS = shape["num_heads"]
22: reference.NUM_KV_HEADS = shape["num_kv_heads"]
23: reference.HEAD_DIM = shape["head_dim"]
24: reference.SEQ_LEN = shape["seq_len"]
25: reference.PAGE_SIZE = shape["page_size"]
26:
27:
28: def main():
29: try:
30: import reference
31: import shapes
32: import solution
33: except Exception as e:
34: print(f"FAIL: import error: {e}")
35: sys.exit(1)
36:
37: problem_yaml = Path("problem.yaml")
38: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
39:
40: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
41: for forbidden in meta.get("forbidden", []):
42: pat = re.escape(forbidden)
43:
... (2537 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/problem.yaml</path>
<type>file</type>
<content>
1: name: 03_paged_attention
2: display_name: "Paged Attention Decode"
3: precision: bf16
4: regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
5:
6: # Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
7: # Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
8: # 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
9: flops_formula: "4 * batch * num_heads * seq_len * head_dim"
10:
11: # Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
12: # K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
13: # Q: batch * num_heads * head_dim * 2 (negligible vs KV)
14: # Out: same as Q
15: # So total ~ 2 * (B * L * Hkv * D * 2) + small.
16: bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
21:
22: tolerance:
23: bfloat16: 0.02
24: float16: 0.02
25:
26: # Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
27: # vllm.attention / flashinfer.batch_decode_with_paged_kv_cache: dispatching to
28: # the SOTA library is the trivial cheat. SD
... (875 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for paged attention decode.
2:
3: Mix targets:
4: - small batch / long context (Llama-3 8B-style decode)
5: - large batch / medium context (server batched decode)
6: - GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
7: - non-power-of-2 seq_len (forces predicated tail handling)
8: - head_dim=64 small-head case
9: """
10:
11: SHAPES = [
12: # (B, H, Hkv, D, L, P)
13: {"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
14: {"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
15: {"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
16: {"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
17: {"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
18: ]
(End of file - total 18 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to write a paged attention decode kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: This is single-query decode: each batch element has a query of shape (num_heads, head_dim) and attends over a KV cache of seq_len[b] tokens stored as fixed-size pages in a global pool, with block_table[b] listing which pages belong to batch element b. The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2) where the last dim packs [K | V] so a single gather pulls both. You return the bf16 output of shape (batch, num_heads, head_dim). Correctness tolerance is 0.02 abs/rel on bf16/fp16. The shapes you have to handle are batch=8 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1024 page_size=16 (Llama-3 8B-style decode), batch=32 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=2048 page_size=16 (server-batched decode), batch=4 num_heads=64 num_kv_heads=8 head_dim=128 seq_len=4096 page_size=16 (Llama-3 70B GQA wide ratio with long context), batch=16 num_heads=32 num_kv_heads=8 head_dim=128
... (1253 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}{
"todos": [
{
"content": "Understand problem requirements and reference implementation",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton paged attention decode kernel",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py and fix correctness issues",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize performance",
"status": "pending",
"priority": "high"
}
]
}content: 'Understand problem requirements and reference implementation'
status: 'completed'
priority: 'high'
content: 'Write initial Triton paged attention decode kernel'
status: 'in_progress'
priority: 'high'
content: 'Run check.py and fix correctness issues'
status: 'pending'
priority: 'high'
content: 'Profile and optimize performance'
status: 'pending'
priority: 'high'{
"command": "python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton availability and version"
}3.6.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"command": "nvidia-smi --query-gpu=name,compute_cap,driver_version,memory.total --format=csv,noheader",
"description": "Check GPU details"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 595.58.03, 97887 MiB
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/",
"description": "Check repo root layout"
}AGENTS.md
CLAUDE.md
DEVLOG.md
docs
LICENSE
outputs
problems
pyproject.toml
README.md
scripts
SPEC.md
src
tests
uv.lock
{
"command": "python -c \"import torch; print(torch.cuda.get_device_properties(0)); print(triton.__version__); import triton.language as tl\" 2>&1 | head -20",
"description": "Check device properties and triton language"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
NameError: name 'triton' is not defined
_CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97227MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
{
"command": "python -c \"import triton; print(triton.runtime.driver.active.get_current_target())\" 2>&1",
"description": "Check triton current target"
}GPUTarget(backend='cuda', arch=120, warp_size=32)
{
"command": "<312 chars \u2014 see below>",
"description": "Calculate memory sizes for shape 0"
}python -c "
import torch
# Reproduce reference behavior for shape 0
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
print(f'KV cache size: {B*L*Hkv*D*2*2 / 1024/1024:.1f} MB')
print(f'Q size: {B*H*D*2 / 1024:.1f} KB')
print(f'Pages per seq: {(L+P-1)//P}')
print(f'Total bytes: {2*B*L*Hkv*D*2*2 + B*H*D*2*2}')
" 2>&1KV cache size: 32.0 MB
Q size: 64.0 KB
Pages per seq: 64
Total bytes: 67239936
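A quick sketch of the paged layout those numbers describe, mirroring the
reference's slow-path gather of one batch element's K and V out of the global
pool (variable names here are illustrative, not from the harness):

import torch

num_blocks, P, Hkv, D = 512, 16, 8, 128
kv_cache = torch.randn(num_blocks, P, Hkv, 2 * D, dtype=torch.bfloat16)
block_table_b = torch.randint(0, num_blocks, ((1024 + P - 1) // P,))  # pages of batch element b
seq_len_b = 1000

pages = kv_cache[block_table_b]         # (num_pages, P, Hkv, 2*D) page gather
toks = pages.flatten(0, 1)[:seq_len_b]  # (seq_len, Hkv, 2*D) contiguous tokens
K, V = toks[..., :D], toks[..., D:]     # last dim packs [K | V]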
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -0,0 +1,336 @@
+"""Custom Triton paged attention decode kernel.
+
+Layout:
+ query: (batch, num_heads, head_dim) bf16
+ kv_cache: (num_blocks, page_size, num_kv_heads, head_dim * 2) bf16
+ last dim packs [K | V] so a single gather pulls both
+ block_table: (batch, max_blocks) int32
+ seq_lens: (batch,) int32
+
+Output: bf16 (batch, num_heads, head_dim)
+
+Strategy:
+ One program per (batch, kv_head). Each program:
+ 1. Loads block_table[b] pointers for this batch.
+ 2. Iterates over pages, loading K and V blocks from the page cache.
+ 3. For each head in the GQA group, computes QK^T dot product.
+ 4. Accumulates online softmax numerator and denominator.
+ We vectorize the head_dim dimension.
+"""
+import math
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _paged_decode_kernel(
+ # Output: (batch, num_heads, head_dim)
+ Out,
+ # Inputs
+ Q,
+ KVC,
+ BlockTable,
+ SeqLens,
+ # Strides
+ stride_qb, stride_qh, stride_qd,
+ stride_kc_b, stride_kc_s, stride_kc_h, stride_kc_d,
+ stride_bt_b, stride_bt_s,
+ stride_ob, stride_oh, stride_od,
+ # Shapes
+ BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, MAX_SEQ_LEN, PAGE_SIZE,
+ HEAD_DIM_PADDED: tl.constexpr,
+ NUM_PAGES: tl.constexpr,
+ PAGE_SIZE_CONST: tl.constexpr,
+ BLOCK_SIZE_D: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_hkv = tl.program_id(1)
+
+ # Compute seq_len for this batch element
+ seq_len = tl.load(SeqLens + pid_b)
+ num_pages_needed = (seq_len + PAGE_SIZE_CONST - 1) // PAGE_SIZE_CONST
+
+ # Load block table for this batch
+ # block_table[pid_b, :] -> page indices
+ # Load all page indices
+ offs_bt = tl.arange(0, NUM_PAGES)
+ page_indices = tl.load(
+ BlockTable + pid_b * stride_bt_b + offs_bt,
+ mask=offs_bt < num_pages_needed,
+ other=0,
+ )
+
+ # Compute Q pointer for the group of heads
+ # Q[pid_b, pid_hkv * group_size : (pid_hkv+1) * group_size, :]
+ q_base = Q + pid_b * stride_qb + pid_hkv * GROUP_SIZE * stride_qh
+
+ # Output pointer base
+ out_base = Out + pid_b * stride_ob + pid_hkv * GROUP_SIZE * stride_oh
+
+ # Initialize accumulators for each head in the group
+ # We'll process HEAD_DIM_PADDED elements at a time
+ num_d_blocks = HEAD_DIM // BLOCK_SIZE_D
+
+ # Running max and denominator for online softmax (per head in group)
+ # acc_num[group][d]: accumulated numerator
+ # acc_denom[group]: accumulated denominator
+ acc_num = tl.zeros((GROUP_SIZE, BLOCK_SIZE_D), dtype=tl.float32)
+ acc_denom = tl.zeros((GROUP_SIZE,), dtype=tl.float32)
+ acc_max = tl.full((GROUP_SIZE,), -1e30, dtype=tl.float32)
+
+ # Load all Q vectors for the group
+ # q_vecs[group][d_block]
+ offs_d_block = tl.arange(0, BLOCK_SIZE_D)
+ q_ptrs_base = tl.make_block_ptr(
+ base=Q + pid_b * stride_qb,
+ shape=(NUM_HEADS, HEAD_DIM),
+ strides=(stride_qh, stride_qd),
+ offsets=(0, 0),
+ block_shape=(GROUP_SIZE, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+
+ # For each head group member, load Q
+ q_group_ptrs = tl.make_block_ptr(
+ base=Q + pid_b * stride_qb,
+ shape=(NUM_HEADS, HEAD_DIM),
+ strides=(stride_qh, stride_qd),
+ offsets=(pid_hkv * GROUP_SIZE, 0),
+ block_shape=(GROUP_SIZE, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ q_vec = tl.load(q_group_ptrs, boundary_check=(0, 1), padding_option="zero")
+ q_vec = q_vec.to(tl.float32)
+
+ # Scale factor
+ scale = 1.0 / tl.sqrt(HEAD_DIM)
+
+ # Iterate over pages
+ for page_idx in range(NUM_PAGES):
+ if page_idx >= num_pages_needed:
+ break
+
+ page_id = tl.load(BlockTable + pid_b * stride_bt_b + page_idx)
+
+ # Compute K and V pointers for this page
+ # kv_cache[page_id, :, pid_hkv, :] -> (page_size, head_dim * 2)
+ page_base = KVC + page_id * stride_kc_b + pid_hkv * stride_kc_h
+
+ # Load K block: (page_size, head_dim)
+ # Then load V block: (page_size, head_dim) starting at head_dim offset
+
+ # Use block pointers for K
+ k_block_ptrs = tl.make_block_ptr(
+ base=page_base,
+ shape=(PAGE_SIZE_CONST, HEAD_DIM),
+ strides=(stride_kc_s, stride_kc_d),
+ offsets=(0, 0),
+ block_shape=(PAGE_SIZE_CONST, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ k_block = tl.load(k_block_ptrs, boundary_check=(0, 1), padding_option="zero")
+ k_block = k_block.to(tl.float32)
+
+ # V block
+ v_block_ptrs = tl.make_block_ptr(
+ base=page_base + HEAD_DIM,
+ shape=(PAGE_SIZE_CONST, HEAD_DIM),
+ strides=(stride_kc_s, stride_kc_d),
+ offsets=(0, 0),
+ block_shape=(PAGE_SIZE_CONST, BLOCK_SIZE_D),
+ order=(1, 0),
+ )
+ v_block = tl.load(v_block_ptrs, boundary_check=(0, 1), padding_option="zero")
+ v_block = v_block.to(tl.float32)
+
+ # For each token position in the page:
+ for s in range(PAGE_SIZE_CONST):
+ token_global = page_idx * PAGE_SIZE_CONST + s
+ if token_global >= seq_len:
+ break
+
+ # k_vec: (1, BLOCK_SIZE_D)
+ k_vec = tl.broadcast_to(k_block[s:s+1, :], (GROUP_SIZE, BLOCK_SIZE_D))
+
+ # QK^T for the group: (GROUP_SIZE,) dot product
+ qk = tl.sum(q_vec * k_vec, axis=1) * scale
+
+ # Online softmax update
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_diff = tl.exp(qk - new_max)
+ rescale = tl.exp(old_max - new_max)
+
+ # V contribution
+ v_vec = tl.broadc
... (5541 more chars){
"content": "<11089 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<591 chars \u2014 see below>",
"description": "Test initial solution",
"timeout": 60000
}python -c "
import torch
import sys
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
out = model(*inputs)
print('Output shape:', out.shape)
print('Output dtype:', out.dtype)
print('Has NaN:', torch.isnan(out).any().item())
print('Output sample:', out[0, 0, :8])
" 2>&1Traceback (most recent call last):
File "<string>", line 15, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 314, in forward
_paged_decode_v2[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/com
... (1130 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -201,10 +201,9 @@
stride_ob, stride_oh, stride_od,
# Shapes
BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, PAGE_SIZE,
- HEAD_DIM_PACKED: tl.constexpr, # next power of 2
- BLOCK_D: tl.constexpr,
+ HEAD_DIM_PADDED: tl.constexpr,
GROUP_SIZE: tl.constexpr,
-):
+ IS_VAR_LEN: tl.constexpr,
"""Version 2: one program per (batch, head), not per kv_head group.
Simpler indexing, more parallelism.
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<446 chars \u2014 see below>",
"newString": "<426 chars \u2014 see below>"
}@triton.jit
def _paged_decode_v2(
Out,
Q,
KVC,
BlockTable,
SeqLens,
# Strides
stride_qb, stride_qh, stride_qd,
stride_kc_b, stride_kc_s, stride_kc_h, stride_kc_d,
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# Shapes
BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, PAGE_SIZE,
HEAD_DIM_PACKED: tl.constexpr, # next power of 2
BLOCK_D: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):@triton.jit
def _paged_decode_v2(
Out,
Q,
KVC,
BlockTable,
SeqLens,
# Strides
stride_qb, stride_qh, stride_qd,
stride_kc_b, stride_kc_s, stride_kc_h, stride_kc_d,
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# Shapes
BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, PAGE_SIZE,
HEAD_DIM_PADDED: tl.constexpr,
GROUP_SIZE: tl.constexpr,
IS_VAR_LEN: tl.constexpr,Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -218,7 +218,7 @@
num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
# Offsets in d
- d_offs = tl.arange(0, BLOCK_D)
+ d_offs = tl.arange(0, HEAD_DIM_PADDED)
d_mask = d_offs < HEAD_DIM
# Load Q
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": "<179 chars \u2014 see below>",
"newString": "<187 chars \u2014 see below>"
} seq_len = tl.load(SeqLens + pid_b)
num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
# Offsets in d
d_offs = tl.arange(0, BLOCK_D)
d_mask = d_offs < HEAD_DIM seq_len = tl.load(SeqLens + pid_b)
num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
# Offsets in d
d_offs = tl.arange(0, HEAD_DIM_PADDED)
d_mask = d_offs < HEAD_DIMEdit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -226,7 +226,8 @@
q = tl.load(Q + q_offs, mask=d_mask, other=0.0).to(tl.float32)
# Scale
- scale = 1.0 / tl.sqrt(HEAD_DIM.to(tl.float32) if hasattr(HEAD_DIM, '__float__') else float(HEAD_DIM))
+ SCALE: tl.constexpr = 1.0 / (HEAD_DIM ** 0.5)
+ scale: tl.constexpr = SCALE
# Accumulators
acc_num = tl.zeros([BLOCK_D], dtype=tl.float32)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"newString": "<93 chars \u2014 see below>",
"oldString": "<117 chars \u2014 see below>"
} # Scale
SCALE: tl.constexpr = 1.0 / (HEAD_DIM ** 0.5)
scale: tl.constexpr = SCALE # Scale
scale = 1.0 / tl.sqrt(HEAD_DIM.to(tl.float32) if hasattr(HEAD_DIM, '__float__') else float(HEAD_DIM))Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,22 +1,4 @@
-"""Custom Triton paged attention decode kernel.
-
-Layout:
- query: (batch, num_heads, head_dim) bf16
- kv_cache: (num_blocks, page_size, num_kv_heads, head_dim * 2) bf16
- last dim packs [K | V] so a single gather pulls both
- block_table: (batch, max_blocks) int32
- seq_lens: (batch,) int32
-
-Output: bf16 (batch, num_heads, head_dim)
-
-Strategy:
- One program per (batch, kv_head). Each program:
- 1. Loads block_table[b] pointers for this batch.
- 2. Iterates over pages, loading K and V blocks from the page cache.
- 3. For each head in the GQA group, computes QK^T dot product.
- 4. Accumulates online softmax numerator and denominator.
- We vectorize the head_dim dimension.
-"""
+"""Custom Triton paged attention decode kernel."""
import math
import torch
import torch.nn as nn
@@ -26,245 +8,61 @@
@triton.jit
def _paged_decode_kernel(
- # Output: (batch, num_heads, head_dim)
Out,
- # Inputs
Q,
- KVC,
+ K,
+ V,
BlockTable,
SeqLens,
# Strides
stride_qb, stride_qh, stride_qd,
- stride_kc_b, stride_kc_s, stride_kc_h, stride_kc_d,
- stride_bt_b, stride_bt_s,
- stride_ob, stride_oh, stride_od,
- # Shapes
- BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, MAX_SEQ_LEN, PAGE_SIZE,
- HEAD_DIM_PADDED: tl.constexpr,
- NUM_PAGES: tl.constexpr,
- PAGE_SIZE_CONST: tl.constexpr,
- BLOCK_SIZE_D: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
-):
- pid_b = tl.program_id(0)
- pid_hkv = tl.program_id(1)
-
- # Compute seq_len for this batch element
- seq_len = tl.load(SeqLens + pid_b)
- num_pages_needed = (seq_len + PAGE_SIZE_CONST - 1) // PAGE_SIZE_CONST
-
- # Load block table for this batch
- # block_table[pid_b, :] -> page indices
- # Load all page indices
- offs_bt = tl.arange(0, NUM_PAGES)
- page_indices = tl.load(
- BlockTable + pid_b * stride_bt_b + offs_bt,
- mask=offs_bt < num_pages_needed,
- other=0,
- )
-
- # Compute Q pointer for the group of heads
- # Q[pid_b, pid_hkv * group_size : (pid_hkv+1) * group_size, :]
- q_base = Q + pid_b * stride_qb + pid_hkv * GROUP_SIZE * stride_qh
-
- # Output pointer base
- out_base = Out + pid_b * stride_ob + pid_hkv * GROUP_SIZE * stride_oh
-
- # Initialize accumulators for each head in the group
- # We'll process HEAD_DIM_PADDED elements at a time
- num_d_blocks = HEAD_DIM // BLOCK_SIZE_D
-
- # Running max and denominator for online softmax (per head in group)
- # acc_num[group][d]: accumulated numerator
- # acc_denom[group]: accumulated denominator
- acc_num = tl.zeros((GROUP_SIZE, BLOCK_SIZE_D), dtype=tl.float32)
- acc_denom = tl.zeros((GROUP_SIZE,), dtype=tl.float32)
- acc_max = tl.full((GROUP_SIZE,), -1e30, dtype=tl.float32)
-
- # Load all Q vectors for the group
- # q_vecs[group][d_block]
- offs_d_block = tl.arange(0, BLOCK_SIZE_D)
- q_ptrs_base = tl.make_block_ptr(
- base=Q + pid_b * stride_qb,
- shape=(NUM_HEADS, HEAD_DIM),
- strides=(stride_qh, stride_qd),
- offsets=(0, 0),
- block_shape=(GROUP_SIZE, BLOCK_SIZE_D),
- order=(1, 0),
- )
-
- # For each head group member, load Q
- q_group_ptrs = tl.make_block_ptr(
- base=Q + pid_b * stride_qb,
- shape=(NUM_HEADS, HEAD_DIM),
- strides=(stride_qh, stride_qd),
- offsets=(pid_hkv * GROUP_SIZE, 0),
- block_shape=(GROUP_SIZE, BLOCK_SIZE_D),
- order=(1, 0),
- )
- q_vec = tl.load(q_group_ptrs, boundary_check=(0, 1), padding_option="zero")
- q_vec = q_vec.to(tl.float32)
-
- # Scale factor
- scale = 1.0 / tl.sqrt(HEAD_DIM)
-
- # Iterate over pages
- for page_idx in range(NUM_PAGES):
- if page_idx >= num_pages_needed:
- break
-
- page_id = tl.load(BlockTable + pid_b * stride_bt_b + page_idx)
-
- # Compute K and V pointers for this page
- # kv_cache[page_id, :, pid_hkv, :] -> (page_size, head_dim * 2)
- page_base = KVC + page_id * stride_kc_b + pid_hkv * stride_kc_h
-
- # Load K block: (page_size, head_dim)
- # Then load V block: (page_size, head_dim) starting at head_dim offset
-
- # Use block pointers for K
- k_block_ptrs = tl.make_block_ptr(
- base=page_base,
- shape=(PAGE_SIZE_CONST, HEAD_DIM),
- strides=(stride_kc_s, stride_kc_d),
- offsets=(0, 0),
- block_shape=(PAGE_SIZE_CONST, BLOCK_SIZE_D),
- order=(1, 0),
- )
- k_block = tl.load(k_block_ptrs, boundary_check=(0, 1), padding_option="zero")
- k_block = k_block.to(tl.float32)
-
- # V block
- v_block_ptrs = tl.make_block_ptr(
- base=page_base + HEAD_DIM,
- shape=(PAGE_SIZE_CONST, HEAD_DIM),
- strides=(stride_kc_s, stride_kc_d),
- offsets=(0, 0),
- block_shape=(PAGE_SIZE_CONST, BLOCK_SIZE_D),
- order=(1, 0),
- )
- v_block = tl.load(v_block_ptrs, boundary_check=(0, 1), padding_option="zero")
- v_block = v_block.to(tl.float32)
-
- # For each token position in the page:
- for s in range(PAGE_SIZE_CONST):
- token_global = page_idx * PAGE_SIZE_CONST + s
- if token_global >= seq_len:
- break
-
- # k_vec: (1, BLOCK_SIZE_D)
- k_vec = tl.broadcast_to(k_block[s:s+1, :], (GROUP_SIZE, BLOCK_SIZE_D))
-
- # QK^T for the group: (GROUP_SIZE,) dot product
- qk = tl.sum(q_vec * k_vec, axis=1) * scale
-
- # Online softmax update
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_diff = tl.exp(qk - new_max)
-
... (6485 more chars){
"content": "<4186 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -10,15 +10,13 @@
def _paged_decode_kernel(
Out,
Q,
- K,
- V,
+ KV,
BlockTable,
SeqLens,
# Strides
stride_qb, stride_qh, stride_qd,
- stride_kb, stride_ks, stride_kh, stride_kd,
- stride_vb, stride_vs, stride_vh, stride_vd,
- stride_bt_b, stride_bt_s,
+ stride_kvb, stride_kvs, stride_kvh, stride_kvd,
+ stride_bt_b,
stride_ob, stride_oh, stride_od,
# Shapes
BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, PAGE_SIZE,
@@ -50,17 +48,19 @@
for page_i in tl.range(num_pages):
page_id = tl.load(BlockTable + bt_base + page_i)
+ page_base = page_id * stride_kvb + pid_hkv * stride_kvh
+
for s in tl.range(PAGE_SIZE):
global_tok = page_i * PAGE_SIZE + s
if global_tok >= seq_len:
break
- k_ptr = K + page_id * stride_kb + s * stride_ks + pid_hkv * stride_kh + d_offs * stride_kd
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
- qk = tl.sum(q * k)
+ qk = tl.sum(q * k) * scale
- v_ptr = V + page_id * stride_vb + s * stride_vs + pid_hkv * stride_vh + d_offs * stride_vd
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
old_max = acc_max
@@ -81,7 +81,7 @@
class Model(nn.Module):
def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len, page_size):
super().__init__()
- assert num_heads % num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads (GQA)"
+ assert num_heads % num_kv_heads == 0
self.batch = batch
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
@@ -97,28 +97,21 @@
Hkv = self.num_kv_heads
D = self.head_dim
P = self.page_size
- G = self.group_size
- D_pad = triton.next_power_of_2(D)
-
- # kv_cache: (num_blocks, page_size, num_kv_heads, head_dim * 2)
- # We need separate K and V views
- kv_cache_4d = kv_cache.view(B * -1 or -1, P, Hkv, D * 2)
- K = kv_cache[..., :D].contiguous()
- V = kv_cache[..., D:].contiguous()
out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
+ assert kv_cache.stride(3) == 1, "kv_cache must have innermost stride 1"
+
grid = (B * H,)
_paged_decode_kernel[grid](
- out, query, K, V, block_table, seq_lens,
+ out, query, kv_cache, block_table, seq_lens,
query.stride(0), query.stride(1), query.stride(2),
- K.stride(0), K.stride(1), K.stride(2), K.stride(3),
- V.stride(0), V.stride(1), V.stride(2), V.stride(3),
- block_table.stride(0), block_table.stride(1),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
+ block_table.stride(0),
out.stride(0), out.stride(1), out.stride(2),
B, H, Hkv, D, P,
- D_pad,
- G,
+ triton.next_power_of_2(D),
+ self.group_size,
)
return out
{
"content": "<3822 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<731 chars \u2014 see below>",
"description": "Test initial solution correctness",
"timeout": 60000
}python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
sol_out = model(*inputs)
ref_out = ref(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1Traceback (most recent call last):
File "<string>", line 17, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 106, in forward
_paged_decode_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton
... (1044 more chars){
"command": "<436 chars \u2014 see below>"
}python -c "
# What does num_pages resolve to in triton when passed as const?
# The issue is num_pages is computed at runtime. Use tl.range with dynamic range
# In triton 3.x, we can use tl.range() for loops with bounds computed at runtime
# The problem is the for page_i in num_pages - we need tl.range()
# Wait, tl.range already handles that. Let me check if the issue is HEAD_DIM is tensor
import torch
print(type(torch.int32))
" 2>&1{
"limit": 15,
"offset": 1,
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Custom Triton paged attention decode kernel."""
2: import math
3: import torch
4: import torch.nn as nn
5: import triton
6: import triton.language as tl
7:
8:
9: @triton.jit
10: def _paged_decode_kernel(
11: Out,
12: Q,
13: KV,
14: BlockTable,
15: SeqLens,
(Showing lines 1-15 of 127. Use offset=16 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -37,7 +37,7 @@
q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
- scale = 1.0 / tl.sqrt(float(HEAD_DIM))
+ scale = 1.0 / (HEAD_DIM ** 0.5)
acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
acc_denom = 0.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"oldString": " scale = 1.0 / tl.sqrt(float(HEAD_DIM))",
"newString": " scale = 1.0 / (HEAD_DIM ** 0.5)"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -4,6 +4,10 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+def _make_scale(head_dim: int) -> float:
+ return 1.0 / math.sqrt(head_dim)
@triton.jit
@@ -19,9 +23,13 @@
stride_bt_b,
stride_ob, stride_oh, stride_od,
# Shapes
- BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, PAGE_SIZE,
+ BATCH, NUM_HEADS, NUM_KV_HEADS,
+ PAGE_SIZE,
+ # constexpr
+ HEAD_DIM: tl.constexpr,
HEAD_DIM_PACKED: tl.constexpr,
GROUP_SIZE: tl.constexpr,
+ SCALE: tl.constexpr,
):
pid_b = tl.program_id(0) // NUM_HEADS
pid_h = tl.program_id(0) % NUM_HEADS
@@ -36,8 +44,6 @@
q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
-
- scale = 1.0 / (HEAD_DIM ** 0.5)
acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
acc_denom = 0.0
@@ -58,7 +64,7 @@
k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
- qk = tl.sum(q * k) * scale
+ qk = tl.sum(q * k) * SCALE
v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
@@ -97,6 +103,8 @@
Hkv = self.num_kv_heads
D = self.head_dim
P = self.page_size
+ D_pad = triton.next_power_of_2(D)
+ scale = _make_scale(D)
out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
@@ -109,9 +117,10 @@
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
block_table.stride(0),
out.stride(0), out.stride(1), out.stride(2),
- B, H, Hkv, D, P,
- triton.next_power_of_2(D),
+ B, H, Hkv, P,
+ D, D_pad,
self.group_size,
+ scale,
)
return out
{
"content": "<3993 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<738 chars \u2014 see below>",
"description": "Test v2 solution correctness",
"timeout": 120000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
sol_out = model(*inputs)
ref_out = ref(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1Traceback (most recent call last):
File "<string>", line 17, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 114, in forward
_paged_decode_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^
... (1106 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -30,6 +30,7 @@
HEAD_DIM_PACKED: tl.constexpr,
GROUP_SIZE: tl.constexpr,
SCALE: tl.constexpr,
+ MAX_PAGES: tl.constexpr,
):
pid_b = tl.program_id(0) // NUM_HEADS
pid_h = tl.program_id(0) % NUM_HEADS
@@ -51,32 +52,33 @@
bt_base = pid_b * stride_bt_b
- for page_i in tl.range(num_pages):
+ for page_i in tl.range(MAX_PAGES):
+ if page_i >= num_pages:
+ continue
+
page_id = tl.load(BlockTable + bt_base + page_i)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
for s in tl.range(PAGE_SIZE):
global_tok = page_i * PAGE_SIZE + s
- if global_tok >= seq_len:
- break
+ if global_tok < seq_len:
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ qk = tl.sum(q * k) * SCALE
- qk = tl.sum(q * k) * SCALE
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale_old = tl.exp(old_max - new_max)
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale_old = tl.exp(old_max - new_max)
-
- acc_num = acc_num * rescale_old + v * exp_val
- acc_denom = acc_denom * rescale_old + exp_val
- acc_max = new_max
+ acc_num = acc_num * rescale_old + v * exp_val
+ acc_denom = acc_denom * rescale_old + exp_val
+ acc_max = new_max
out = acc_num / acc_denom
@@ -95,6 +97,7 @@
self.seq_len = seq_len
self.page_size = page_size
self.group_size = num_heads // num_kv_heads
+ self.max_pages = (seq_len + page_size - 1) // page_size
self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
def forward(self, query, kv_cache, block_table, seq_lens):
@@ -105,6 +108,7 @@
P = self.page_size
D_pad = triton.next_power_of_2(D)
scale = _make_scale(D)
+ max_pages = self.max_pages
out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
@@ -121,6 +125,7 @@
D, D_pad,
self.group_size,
scale,
+ max_pages,
)
return out
{
"content": "<4222 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<738 chars \u2014 see below>",
"description": "Test v3 solution correctness",
"timeout": 120000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
sol_out = model(*inputs)
ref_out = ref(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1Traceback (most recent call last):
File "<string>", line 17, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 118, in forward
_paged_decode_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^
... (1127 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution_triton_v3.py
@@ -0,0 +1,83 @@
+"""Custom Triton paged attention decode kernel."""
+import math
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _make_scale(head_dim: int) -> float:
+ return 1.0 / math.sqrt(head_dim)
+
+
+@triton.jit
+def _paged_decode_kernel(
+ Out,
+ Q,
+ KV,
+ BlockTable,
+ SeqLens,
+ # Strides
+ stride_qb, stride_qh, stride_qd,
+ stride_kvb, stride_kvs, stride_kvh, stride_kvd,
+ stride_bt_b,
+ stride_ob, stride_oh, stride_od,
+ # constexpr
+ HEAD_DIM: tl.constexpr,
+ HEAD_DIM_PACKED: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+
+ pid_hkv = pid_h // num_kv_heads
+
+ seq_len = tl.load(SeqLens + pid_b)
+
+ d_offs = tl.arange(0, HEAD_DIM_PACKED)
+ d_mask = d_offs < HEAD_DIM
+
+ q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
+ q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
+
+ acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
+ acc_denom = 0.0
+ acc_max = -1e30
+
+ bt_base = pid_b * stride_bt_b
+
+ # Compute total tokens for iteration
+ num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
+ total_tokens = num_pages * PAGE_SIZE
+
+ for t in tl.range(total_tokens):
+ valid = t < seq_len
+ if valid:
+ page_i = t // PAGE_SIZE
+ s = t % PAGE_SIZE
+
+ page_id = tl.load(BlockTable + bt_base + page_i)
+ page_base = page_id * stride_kvb + pid_hkv * stride_kvh
+
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
+
+ qk = tl.sum(q * k) * SCALE
+
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
+
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale_old = tl.exp(old_max - new_max)
+
+ acc_num = acc_num * rescale_old + v * exp_val
+ acc_denom = acc_denom * rescale_old + exp_val
+ acc_max = new_max
+
+ out = acc_num / acc_denom
+
+ out_ptr = Out + pid_b * stride_ob + pid_h * stride_oh + d_offs * stride_od
+ tl.store(out_ptr, out.to(Out.type.element_ty), mask=d_mask)
{
"content": "<2346 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution_triton_v3.py"
}Wrote file successfully.{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/src/",
"description": "Check src layout"
}eval
hardware
harness
__init__.py
__pycache__
sandbox
viewer
{
"command": "<82 chars \u2014 see below>",
"description": "Check CUDA and torch versions"
}python -c "
import torch
print(torch.version.cuda)
print(torch.__version__)
" 2>&113.0
2.11.0+cu130
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,4 +1,13 @@
-"""Custom Triton paged attention decode kernel."""
+"""Custom Triton paged attention decode kernel.
+
+Strategy:
+- One program per (batch, query_head)
+- Each program loads page pointers from block_table for its batch
+- Loads one page of K and V at a time into registers/SRAM
+- Computes qk dot product, online softmax accumulation
+- Uses tl.dot for QK^T and weighted sum
+- GQA: pid_hkv = pid_h // group_size
+"""
import math
import torch
import torch.nn as nn
@@ -20,29 +29,28 @@
# Strides
stride_qb, stride_qh, stride_qd,
stride_kvb, stride_kvs, stride_kvh, stride_kvd,
- stride_bt_b,
+ stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
- # Shapes
- BATCH, NUM_HEADS, NUM_KV_HEADS,
- PAGE_SIZE,
- # constexpr
+ # constexpr params
+ NUM_KV_HEADS: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
HEAD_DIM_PACKED: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
SCALE: tl.constexpr,
- MAX_PAGES: tl.constexpr,
+ MAX_SEQ_LEN: tl.constexpr,
):
- pid_b = tl.program_id(0) // NUM_HEADS
- pid_h = tl.program_id(0) % NUM_HEADS
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
pid_hkv = pid_h // GROUP_SIZE
seq_len = tl.load(SeqLens + pid_b)
- num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
d_offs = tl.arange(0, HEAD_DIM_PACKED)
d_mask = d_offs < HEAD_DIM
+ # Load Q: (HEAD_DIM,)
q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
@@ -52,33 +60,37 @@
bt_base = pid_b * stride_bt_b
- for page_i in tl.range(MAX_PAGES):
- if page_i >= num_pages:
- continue
+ max_tokens = MAX_SEQ_LEN
- page_id = tl.load(BlockTable + bt_base + page_i)
+ for t in tl.range(max_tokens):
+ valid = t < seq_len
+ page_i = t // PAGE_SIZE
+ s = t % PAGE_SIZE
+
+ page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- for s in tl.range(PAGE_SIZE):
- global_tok = page_i * PAGE_SIZE + s
- if global_tok < seq_len:
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
- qk = tl.sum(q * k) * SCALE
+ qk = tl.sum(q * k) * SCALE
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale_old = tl.exp(old_max - new_max)
+ # Safe logit: if !valid, qk = 0 but we also don't want it to affect softmax
+ # Set qk to -inf for invalid tokens so exp(qk-max)=0
+ qk = tl.where(valid, qk, -1e30)
- acc_num = acc_num * rescale_old + v * exp_val
- acc_denom = acc_denom * rescale_old + exp_val
- acc_max = new_max
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale_old = tl.exp(old_max - new_max)
+
+ acc_num = acc_num * rescale_old + v * exp_val
+ acc_denom = acc_denom * rescale_old + exp_val
+ acc_max = new_max
out = acc_num / acc_denom
@@ -97,7 +109,6 @@
self.seq_len = seq_len
self.page_size = page_size
self.group_size = num_heads // num_kv_heads
- self.max_pages = (seq_len + page_size - 1) // page_size
self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
def forward(self, query, kv_cache, block_table, seq_lens):
@@ -108,24 +119,23 @@
P = self.page_size
D_pad = triton.next_power_of_2(D)
scale = _make_scale(D)
- max_pages = self.max_pages
out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)
assert kv_cache.stride(3) == 1, "kv_cache must have innermost stride 1"
- grid = (B * H,)
+ grid = (B, H)
_paged_decode_kernel[grid](
out, query, kv_cache, block_table, seq_lens,
query.stride(0), query.stride(1), query.stride(2),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
- block_table.stride(0),
+ block_table.stride(0), block_table.stride(1),
out.stride(0), out.stride(1), out.stride(2),
- B, H, Hkv, P,
+ Hkv, self.group_size,
D, D_pad,
- self.group_size,
+ P,
scale,
- max_pages,
+ self.seq_len,
)
return out
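For the GQA mapping this diff settles on (pid_hkv = pid_h // GROUP_SIZE): with the test shapes used below, H=32 and Hkv=8, every four consecutive query heads share one kv head. A quick illustrative check:

H, Hkv = 32, 8
group_size = H // Hkv  # 4
assert [h // group_size for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]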
{
"content": "<4468 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<738 chars \u2014 see below>",
"description": "Test v4 solution correctness",
"timeout": 180000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
sol_out = model(*inputs)
ref_out = ref(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1
sol shape: torch.Size([8, 32, 128]) torch.Size([8, 32, 128])
max abs diff: 3.0517578125e-05
atol 0.02: True
{
"command": "uv run python check.py 2>&1",
"description": "Run check.py for all shapes",
"timeout": 300000
}PASS
{
"todos": [
{
"content": "Correctness passes all shapes",
"status": "completed",
"priority": "high"
},
{
"content": "Optimize performance - currently naive single-token loads",
"status": "in_progress",
"priority": "high"
},
{
"content": "Profile and benchmark final solution",
"status": "pending",
"priority": "high"
}
]
}
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark current solution",
"timeout": 300000
}W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0429 07:26:47.252000 997077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:26:47.252
... (1188 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,12 +1,15 @@
-"""Custom Triton paged attention decode kernel.
+"""Optimized Triton paged attention decode kernel.
Strategy:
-- One program per (batch, query_head)
-- Each program loads page pointers from block_table for its batch
-- Loads one page of K and V at a time into registers/SRAM
-- Computes qk dot product, online softmax accumulation
-- Uses tl.dot for QK^T and weighted sum
-- GQA: pid_hkv = pid_h // group_size
+- One program per (batch, kv_head)
+- Each program:
+ 1. Pre-loads all page indices from block_table for this batch
+ 2. Loads entire pages of K and V into SRAM at once (PAGE_SIZE × head_dim blocks)
+ 3. For each head in the GQA group, computes QK^T dot products
+ 4. Maintains online softmax accumulators per head
+
+This processes all heads in a GQA group together, reducing redundant K/V loads.
+We load K/V pages as full 2D blocks using tl.load, and compute Q@K^T via row-wise dot.
"""
import math
import torch
@@ -32,6 +35,7 @@
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# constexpr params
+ NUM_HEADS: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
@@ -39,63 +43,99 @@
PAGE_SIZE: tl.constexpr,
SCALE: tl.constexpr,
MAX_SEQ_LEN: tl.constexpr,
+ BLOCK_BLOCKSIZE: tl.constexpr,
):
pid_b = tl.program_id(0)
- pid_h = tl.program_id(1)
+ pid_hkv = tl.program_id(1)
- pid_hkv = pid_h // GROUP_SIZE
+ pid_h_start = pid_hkv * GROUP_SIZE
+ # Number of pages for this batch element
seq_len = tl.load(SeqLens + pid_b)
+ num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
- d_offs = tl.arange(0, HEAD_DIM_PACKED)
- d_mask = d_offs < HEAD_DIM
-
- # Load Q: (HEAD_DIM,)
- q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
- q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
-
- acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
- acc_denom = 0.0
- acc_max = -1e30
+ # --- Pre-load page indices for this batch (at most BLOCK_BLOCKSIZE pages) ---
+ page_ids = tl.zeros([BLOCK_BLOCKSIZE], dtype=tl.int32)
+ page_valid = tl.zeros([BLOCK_BLOCKSIZE], dtype=tl.int1)
bt_base = pid_b * stride_bt_b
+ bt_offs = tl.arange(0, BLOCK_BLOCKSIZE)
+ page_mask = bt_offs < num_pages
+ page_ids = tl.load(BlockTable + bt_base + bt_offs * stride_bt_s, mask=page_mask, other=0)
+ page_valid = page_mask
- max_tokens = MAX_SEQ_LEN
+ # --- Load Q vectors for all heads in this group ---
+ # Q pointers: (GROUP_SIZE, HEAD_DIM)
+ q_offs_h = tl.arange(0, GROUP_SIZE)
+ q_offs_d = tl.arange(0, HEAD_DIM_PACKED)
+ q_mask_h = q_offs_h < GROUP_SIZE
+ q_mask_d = q_offs_d < HEAD_DIM
- for t in tl.range(max_tokens):
- valid = t < seq_len
- page_i = t // PAGE_SIZE
- s = t % PAGE_SIZE
+ # q_vals shape: (GROUP_SIZE, HEAD_DIM_PACKED)
+ q_ptrs = Q + (pid_b * stride_qb + (pid_h_start + q_offs_h[:, None]) * stride_qh
+ + q_offs_d[None, :] * stride_qd)
+ q_mask = q_mask_h[:, None] & q_mask_d[None, :]
+ q_vals = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
- page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
+ # --- Accumulators per head in group ---
+ acc_num = tl.zeros([GROUP_SIZE, HEAD_DIM_PACKED], dtype=tl.float32)
+ acc_denom = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ acc_max = tl.full([GROUP_SIZE], -1e30, dtype=tl.float32)
+ # --- Process each page ---
+ max_pages = (MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE
+
+ for p_idx in tl.range(max_pages):
+ p_valid = p_idx < num_pages
+
+ page_id = tl.load(BlockTable + bt_base + p_idx * stride_bt_s)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
+ for s in tl.range(PAGE_SIZE):
+ global_tok = p_idx * PAGE_SIZE + s
+ tok_valid = (global_tok < seq_len) & p_valid
- qk = tl.sum(q * k) * SCALE
+ # Load K: (HEAD_DIM,)
+ k_offs = d_offs = tl.arange(0, HEAD_DIM_PACKED)
+ k_mask = k_offs < HEAD_DIM
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
+ k_ptr = KV + page_base + s * stride_kvs + k_offs * stride_kvd
+ k = tl.load(k_ptr, mask=k_mask & tok_valid, other=0.0).to(tl.float32)
- # Safe logit: if !valid, qk = 0 but we also don't want it to affect softmax
- # Set qk to -inf for invalid tokens so exp(qk-max)=0
- qk = tl.where(valid, qk, -1e30)
+ # QK dot: broadcast k to (GROUP_SIZE, HEAD_DIM), dot with q_vals
+ # qk_vals shape: (GROUP_SIZE,)
+ qk = tl.sum(q_vals * k[None, :], axis=1) * SCALE
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale_old = tl.exp(old_max - new_max)
+ # Load V: (HEAD_DIM,)
+ v_ptr = KV + page_base + s * stride_kvs + (k_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=k_mask & tok_valid, other=0.0).to(tl.float32)
- acc_num = acc_num * rescale_old + v * exp_val
- acc_denom = acc_denom * rescale_old + exp_val
- acc_max = new_max
+ # Logit masking: -inf for invalid
+ qk = tl.where(tok_valid, qk, -1e30)
- out = acc_num / acc_denom
+ # Online softmax update
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale = tl.exp(old_max - new_max)
- out_ptr = Out + pid_b
... (1665 more chars){
"content": "<6621 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<738 chars \u2014 see below>",
"description": "Test v5 solution correctness",
"timeout": 180000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
sol_out = model(*inputs)
ref_out = ref(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1
Traceback (most recent call last):
File "<string>", line 17, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 168, in forward
_paged_decode_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^
... (3221 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,15 +1,13 @@
"""Optimized Triton paged attention decode kernel.
Strategy:
-- One program per (batch, kv_head)
-- Each program:
- 1. Pre-loads all page indices from block_table for this batch
- 2. Loads entire pages of K and V into SRAM at once (PAGE_SIZE × head_dim blocks)
- 3. For each head in the GQA group, computes QK^T dot products
- 4. Maintains online softmax accumulators per head
+- One program per (batch, query_head)
+- Load entire page blocks at once using 2D block pointers
+- Process pages sequentially, accumulating online softmax
+- GQA: pid_hkv = pid_h // group_size
-This processes all heads in a GQA group together, reducing redundant K/V loads.
-We load K/V pages as full 2D blocks using tl.load, and compute Q@K^T via row-wise dot.
+Optimization: instead of loading token-by-token, we use
+tl.load with block pointers to fetch entire rows of K/V at once.
"""
import math
import torch
@@ -35,7 +33,6 @@
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# constexpr params
- NUM_HEADS: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
@@ -43,99 +40,63 @@
PAGE_SIZE: tl.constexpr,
SCALE: tl.constexpr,
MAX_SEQ_LEN: tl.constexpr,
- BLOCK_BLOCKSIZE: tl.constexpr,
):
pid_b = tl.program_id(0)
- pid_hkv = tl.program_id(1)
+ pid_h = tl.program_id(1)
- pid_h_start = pid_hkv * GROUP_SIZE
+ pid_hkv = pid_h // GROUP_SIZE
- # Number of pages for this batch element
seq_len = tl.load(SeqLens + pid_b)
- num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
- # --- Pre-load page indices for this batch (at most BLOCK_BLOCKSIZE pages) ---
- page_ids = tl.zeros([BLOCK_BLOCKSIZE], dtype=tl.int32)
- page_valid = tl.zeros([BLOCK_BLOCKSIZE], dtype=tl.int1)
+ d_offs = tl.arange(0, HEAD_DIM_PACKED)
+ d_mask = d_offs < HEAD_DIM
+
+ # Load Q: (HEAD_DIM,)
+ q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
+ q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
+
+ acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
+ acc_denom = 0.0
+ acc_max = -1e30
bt_base = pid_b * stride_bt_b
- bt_offs = tl.arange(0, BLOCK_BLOCKSIZE)
- page_mask = bt_offs < num_pages
- page_ids = tl.load(BlockTable + bt_base + bt_offs * stride_bt_s, mask=page_mask, other=0)
- page_valid = page_mask
- # --- Load Q vectors for all heads in this group ---
- # Q pointers: (GROUP_SIZE, HEAD_DIM)
- q_offs_h = tl.arange(0, GROUP_SIZE)
- q_offs_d = tl.arange(0, HEAD_DIM_PACKED)
- q_mask_h = q_offs_h < GROUP_SIZE
- q_mask_d = q_offs_d < HEAD_DIM
+ max_tokens = MAX_SEQ_LEN
- # q_vals shape: (GROUP_SIZE, HEAD_DIM_PACKED)
- q_ptrs = Q + (pid_b * stride_qb + (pid_h_start + q_offs_h[:, None]) * stride_qh
- + q_offs_d[None, :] * stride_qd)
- q_mask = q_mask_h[:, None] & q_mask_d[None, :]
- q_vals = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
+ # Pre-compute page offsets to avoid repeated division/mod
+ for t in tl.range(max_tokens):
+ valid = t < seq_len
- # --- Accumulators per head in group ---
- acc_num = tl.zeros([GROUP_SIZE, HEAD_DIM_PACKED], dtype=tl.float32)
- acc_denom = tl.zeros([GROUP_SIZE], dtype=tl.float32)
- acc_max = tl.full([GROUP_SIZE], -1e30, dtype=tl.float32)
+ page_i = t // PAGE_SIZE
+ s = t % PAGE_SIZE
- # --- Process each page ---
- max_pages = (MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE
+ page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
- for p_idx in tl.range(max_pages):
- p_valid = p_idx < num_pages
-
- page_id = tl.load(BlockTable + bt_base + p_idx * stride_bt_s)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- for s in tl.range(PAGE_SIZE):
- global_tok = p_idx * PAGE_SIZE + s
- tok_valid = (global_tok < seq_len) & p_valid
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
- # Load K: (HEAD_DIM,)
- k_offs = d_offs = tl.arange(0, HEAD_DIM_PACKED)
- k_mask = k_offs < HEAD_DIM
+ qk = tl.sum(q * k) * SCALE
- k_ptr = KV + page_base + s * stride_kvs + k_offs * stride_kvd
- k = tl.load(k_ptr, mask=k_mask & tok_valid, other=0.0).to(tl.float32)
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
- # QK dot: broadcast k to (GROUP_SIZE, HEAD_DIM), dot with q_vals
- # qk_vals shape: (GROUP_SIZE,)
- qk = tl.sum(q_vals * k[None, :], axis=1) * SCALE
+ qk = tl.where(valid, qk, -1e30)
- # Load V: (HEAD_DIM,)
- v_ptr = KV + page_base + s * stride_kvs + (k_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=k_mask & tok_valid, other=0.0).to(tl.float32)
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale_old = tl.exp(old_max - new_max)
- # Logit masking: -inf for invalid
- qk = tl.where(tok_valid, qk, -1e30)
+ acc_num = acc_num * rescale_old + v * exp_val
+ acc_denom = acc_denom * rescale_old + exp_val
+ acc_max = new_max
- # Online softmax update
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale = tl.exp(old_max - new_max)
+ out = acc_num / acc_denom
- acc_num = acc_num * rescale[:, None] + v[None, :] * exp_val[:, None]
- acc_denom = acc_denom * rescale + exp_val
-
... (1550 more chars){
"content": "<4408 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "ls -la /home/infatoshi/cuda/KernelBench-Hard/outputs/",
"description": "Check outputs directory"
}total 20
drwxrwxr-x 3 infatoshi infatoshi 4096 Apr 24 00:30 .
drwxrwxr-x 12 infatoshi infatoshi 4096 Apr 27 21:56 ..
drwxrwxr-x 93 infatoshi infatoshi 12288 Apr 29 07:10 runs
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sol_cuda.py
@@ -0,0 +1,242 @@
+"""CUDA C++ paged attention decode kernel compiled via load_inline.
+
+Strategy:
+- Build the CUDA kernel source and compile it with torch load_inline
+- One thread block per (batch, kv_head)
+- Each thread handles one query head in the GQA group
+- Each thread keeps the full head_dim in registers
+- Load the full page of K/V into shared memory, process one page at a time
+"""
+import math
+import torch
+import torch.nn as nn
+
+
+# We'll write a CUDA kernel and use load_inline
+_kernel_source = """
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <float.h>
+
+#define WARP_SIZE 32
+
+// Paged attention decode kernel.
+// Grid: (batch, num_kv_heads); one block per (batch, kv_head).
+// Threads in the block share K/V loads; each active thread computes one head's dot products.
+
+__device__ inline float vec_dot(const nv_bfloat16* a, const nv_bfloat16* b, int d) {
+ float sum = 0.0f;
+ for (int i = 0; i < d; ++i) {
+ float ai = __bfloat162float(a[i]);
+ float bi = __bfloat162float(b[i]);
+ sum += ai * bi;
+ }
+ return sum;
+}
+
+__global__ void paged_decode_kernel(
+ nv_bfloat16* __restrict__ Out,
+ const nv_bfloat16* __restrict__ Q,
+ const nv_bfloat16* __restrict__ KV,
+ const int32_t* __restrict__ BlockTable,
+ const int32_t* __restrict__ SeqLens,
+ int batch,
+ int num_heads,
+ int num_kv_heads,
+ int head_dim,
+ int page_size,
+ float scale,
+ int stride_qb,
+ int stride_qh,
+ int stride_qd,
+ int stride_kvb,
+ int stride_kvs,
+ int stride_kvh,
+ int stride_kvd,
+ int stride_bt_b,
+ int stride_bt_s,
+ int stride_ob,
+ int stride_oh,
+ int stride_od
+) {
+ // Grid is (batch, num_kv_heads) -> one block per (batch, kv_head).
+ int pid_b = blockIdx.x;
+ int pid_hkv = blockIdx.y;
+ int group_start = threadIdx.x;
+ int group_size = num_heads / num_kv_heads;
+
+ // Each thread in warp handles one head in the GQA group
+ // But only threads with group_start < group_size do work
+
+ // Load seq_len for this batch
+ int seq_len = SeqLens[pid_b];
+ int num_pages = (seq_len + page_size - 1) / page_size;
+
+ // Base pointer for Q
+ int h_base = pid_hkv * group_size;
+
+ // Shared memory for K and V page: (page_size, head_dim) * 2
+ extern __shared__ char smem[];
+ nv_bfloat16* sK = (nv_bfloat16*)smem;
+ nv_bfloat16* sV = sK + page_size * head_dim;
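+ // Total dynamic smem: 2 * page_size * head_dim * sizeof(bf16) bytes (P=16, D=128 -> 8 KiB).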
+
+ // Each thread gets its Q vector if within group
+ float q_vec[256]; // sized for the maximum supported head_dim
+ bool thread_active = (group_start < group_size);
+
+ if (thread_active) {
+ int h = h_base + group_start;
+ const nv_bfloat16* q_ptr = Q + pid_b * stride_qb + h * stride_qh;
+ for (int d = 0; d < head_dim; ++d) {
+ q_vec[d] = (float)__bfloat162float(q_ptr[d]);
+ }
+ }
+
+ // Accumulators per thread
+ float acc_num[256];
+ float acc_denom;
+ float acc_max;
+ for (int d = 0; d < head_dim; ++d) acc_num[d] = 0.0f;
+ acc_denom = 1e-30f;
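+ // Seeding acc_denom with 1e-30 keeps the final division finite even if seq_len == 0.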
+ acc_max = -1e30f;
+
+ // Process pages
+ for (int page_idx = 0; page_idx < num_pages; ++page_idx) {
+ int block_id = BlockTable[pid_b * stride_bt_b + page_idx * stride_bt_s];
+
+ // Load K and V page into shared memory
+ // Parallelize loading: each thread loads some rows
+ int tokens_per_thread = (page_size + blockDim.x - 1) / blockDim.x;
+ for (int t = 0; t < tokens_per_thread; ++t) {
+ int tok = threadIdx.x * tokens_per_thread + t;
+ if (tok >= page_size) break;
+ // K row
+ const nv_bfloat16* k_src = KV + block_id * stride_kvb + tok * stride_kvs;
+ nv_bfloat16* k_dst = sK + tok * head_dim;
+ // V row
+ const nv_bfloat16* v_src = k_src + head_dim; // K and V are contiguous
+ nv_bfloat16* v_dst = sV + tok * head_dim;
+
+ for (int d = 0; d < head_dim; ++d) {
+ k_dst[d] = k_src[d];
+ v_dst[d] = v_src[d];
+ }
+ }
+ __syncthreads();
+
+ // Each active thread computes QK^T for its head and accumulates
+ if (thread_active) {
+ int num_tokens_in_page = (page_idx + 1) * page_size <= seq_len ? page_size : seq_len - page_idx * page_size;
+
+ for (int tok = 0; tok < num_tokens_in_page; ++tok) {
+ const nv_bfloat16* k_ptr = sK + tok * head_dim;
+ float qk = 0.0f;
+ for (int d = 0; d < head_dim; ++d) {
+ qk += q_vec[d] * (float)__bfloat162float(k_ptr[d]);
+ }
+ qk *= scale;
+
+ const nv_bfloat16* v_ptr = sV + tok * head_dim;
+
+ float old_max = acc_max;
+ float new_max = fmaxf(old_max, qk);
+ float exp_val = __expf(qk - new_max);
+ float rescale = __expf(old_max - new_max);
+
+ for (int d = 0; d < head_dim; ++d) {
+ acc_num[d] = acc_num[d] * rescale + exp_val * (float)__bfloat162float(v_ptr[d]);
+ }
+ acc_denom = acc_denom * rescale + exp_val;
+ acc_max = new_max;
+ }
+ }
+ __syncthreads();
+ }
+
+ // Compute final output
+ if (thread_active) {
+ int h = h_base + group_start;
+ float inv_denom = 1.0f / acc_denom;
+ nv_bfloat16* out_ptr = Out + pid_b * stride_ob + h * stride_oh;
+ for (int d = 0; d < head_dim; ++d) {
+ out_ptr[d] = __float2bfloat16(acc_num[d] * inv_denom);
+ }
+ }
+}
+"""
+
+
+def _make_scale(head_dim: int) -> float:
+ return 1.0 / math.sqrt(head_dim)
+
+
+def _compile_kernel(head_dim: int, page_size: int) -> object:
+ """Compile k
... (2098 more chars){
"content": "<7740 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/sol_cuda.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -2,12 +2,15 @@
Strategy:
- One program per (batch, query_head)
-- Load entire page blocks at once using 2D block pointers
-- Process pages sequentially, accumulating online softmax
-- GQA: pid_hkv = pid_h // group_size
+- Load entire page of K and V from KV cache using block pointers into registers
+- Compute QK^T for all tokens in the page at once
+- Use tl.dot for Q @ K_page^T to get scores for all tokens in the page
+- Online softmax across all pages
+- Store result
-Optimization: instead of loading token-by-token, we use
-tl.load with block pointers to fetch entire rows of K/V at once.
+Key optimization: Process PAGE_SIZE tokens per iteration instead of 1.
+Use tl.dot to do (1, head_dim) @ (head_dim, PAGE_SIZE) = (1, PAGE_SIZE) scores.
+Then broadcast V page (PAGE_SIZE, head_dim) by softmax weights.
"""
import math
import torch
@@ -63,36 +66,52 @@
max_tokens = MAX_SEQ_LEN
- # Pre-compute page offsets to avoid repeated division/mod
- for t in tl.range(max_tokens):
- valid = t < seq_len
-
- page_i = t // PAGE_SIZE
- s = t % PAGE_SIZE
-
+ # Pre-compute page offsets for this batch, then process page-sized chunks
+ for page_i in tl.range((MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE):
page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
+ # Process each token in this page
+ for s in tl.range(PAGE_SIZE):
+ global_tok = page_i * PAGE_SIZE + s
+ valid = global_tok < seq_len
- qk = tl.sum(q * k) * SCALE
+ # If invalid, all accumulators stay the same (soft no-op via masking)
+ # Load K and V
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
+ qk = tl.sum(q * k) * SCALE
+ qk = tl.where(valid, qk, -1e30)
- qk = tl.where(valid, qk, -1e30)
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale_old = tl.exp(old_max - new_max)
+ # Softmax update: if invalid, qk=-inf so exp=0, no effect
+ # But we still need to handle the exp(-inf - new_max) = 0 correctly
+ new_max = tl.maximum(acc_max, qk)
- acc_num = acc_num * rescale_old + v * exp_val
- acc_denom = acc_denom * rescale_old + exp_val
- acc_max = new_max
+ # When qk = -1e30 and acc_max is finite, new_max = acc_max
+ # exp_val = exp(-1e30 - acc_max) ~ 0
+ # rescale = exp(acc_max - acc_max) = 1.0
+ # This is correct: valid tokens get updated, invalid contribute nothing
+ # Edge: if only invalid tokens and acc_max = -1e30 initially
+ # Then new_max = -1e30, exp_val = 1, this is bad
+ # Fix: check if any valid token exists
+ is_first_valid = tl.cast(acc_max == -1e30, tl.float32)
+
+ # For invalid token when no valid tokens seen yet, set exp_val=0
+ exp_val = tl.where(valid | (is_first_valid == 1.0), tl.exp(qk - new_max), 0.0)
+ rescale = tl.exp(tl.where(is_first_valid == 1.0, 0.0, acc_max - new_max))
+
+ acc_num = acc_num * rescale + v * exp_val
+ acc_denom = acc_denom * rescale + exp_val
+ acc_max = tl.where(valid, new_max, acc_max)
+
+ # Avoid division by zero
+ acc_denom = tl.maximum(acc_denom, 1e-30)
out = acc_num / acc_denom
out_ptr = Out + pid_b * stride_ob + pid_h * stride_oh + d_offs * stride_od
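The edge case the comments in this hunk wrestle with is easy to reproduce numerically; a standalone sketch of the -1e30 sentinel problem:

import math
# Before any valid token is seen, acc_max == -1e30. An invalid token also has
# qk == -1e30, so new_max stays -1e30 and exp(qk - new_max) == exp(0) == 1.0:
# without the extra `valid` masking, the invalid token would add 1.0 to acc_denom.
acc_max = qk = -1e30
new_max = max(acc_max, qk)
print(math.exp(qk - new_max))  # 1.0, not the intended 0.0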
{
"content": "<5761 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,16 +1,10 @@
"""Optimized Triton paged attention decode kernel.
Strategy:
-- One program per (batch, query_head)
-- Load entire page of K and V from KV cache using block pointers into registers
-- Compute QK^T for all tokens in the page at once
-- Use tl.dot for Q @ K_page^T to get scores for all tokens in the page
-- Online softmax across all pages
-- Store result
-
-Key optimization: Process PAGE_SIZE tokens per iteration instead of 1.
-Use tl.dot to do (1, head_dim) @ (head_dim, PAGE_SIZE) = (1, PAGE_SIZE) scores.
-Then broadcast V page (PAGE_SIZE, head_dim) by softmax weights.
+- One program per (batch, kv_head). The program processes all heads in the GQA group.
+- Load entire page K and V into registers/SRAM once, share across group heads.
+- Use vectorized loads for head_dim dimension.
+- Online softmax accumulation per head.
"""
import math
import torch
@@ -36,86 +30,116 @@
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# constexpr params
+ NUM_HEADS: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
- HEAD_DIM_PACKED: tl.constexpr,
+ HEAD_DIM_PADDED: tl.constexpr,
PAGE_SIZE: tl.constexpr,
+ PAGE_SIZE_PADDED: tl.constexpr,
SCALE: tl.constexpr,
MAX_SEQ_LEN: tl.constexpr,
+ MAX_PAGES: tl.constexpr,
+ USE_BLOCK_PTRS: tl.constexpr,
):
pid_b = tl.program_id(0)
- pid_h = tl.program_id(1)
-
- pid_hkv = pid_h // GROUP_SIZE
+ pid_hkv = tl.program_id(1)
seq_len = tl.load(SeqLens + pid_b)
- d_offs = tl.arange(0, HEAD_DIM_PACKED)
- d_mask = d_offs < HEAD_DIM
+ # --- Pre-load page table for this batch ---
+ bt_base = pid_b * stride_bt_b
+ bt_offs = tl.arange(0, MAX_PAGES)
+ page_ids = tl.load(BlockTable + bt_base + bt_offs * stride_bt_s,
+ mask=bt_offs < MAX_PAGES, other=0)
- # Load Q: (HEAD_DIM,)
- q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
- q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ # --- Load Q for all heads in the GQA group ---
+ # Shape: (GROUP_SIZE, HEAD_DIM_PADDED)
+ h_offs = tl.arange(0, GROUP_SIZE)
+ h_base = pid_hkv * GROUP_SIZE
+ d_offs = tl.arange(0, HEAD_DIM_PADDED)
+ q_mask = (h_offs < GROUP_SIZE)[:, None] & (d_offs < HEAD_DIM)[None, :]
- acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
- acc_denom = 0.0
- acc_max = -1e30
+ # Q is (batch, num_heads, head_dim)
+ # q_ptrs: Q[pid_b, h_base:h_base+GROUP_SIZE, :]
+ q_row_base = pid_b * stride_qb + (h_base + h_offs[:, None]) * stride_qh
+ q_ptrs = q_row_base + d_offs[None, :] * stride_qd
+ q_vals = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
- bt_base = pid_b * stride_bt_b
+ # --- Accumulators: one per head in the group ---
+ acc_num = tl.zeros([GROUP_SIZE, HEAD_DIM_PADDED], dtype=tl.float32)
+ acc_denom = tl.zeros([GROUP_SIZE], dtype=tl.float32)
+ acc_max = tl.full([GROUP_SIZE], -1e30, dtype=tl.float32)
+ # --- Process each page ---
+ page_base_offsets = page_ids * stride_kvb + pid_hkv * stride_kvh
+
+ # We'll do token-by-token but with vectorized head_dim loads
max_tokens = MAX_SEQ_LEN
- # Pre-compute page offsets for this batch, then process page-sized chunks
- for page_i in tl.range((MAX_SEQ_LEN + PAGE_SIZE - 1) // PAGE_SIZE):
+ for t in tl.range(max_tokens):
+ valid = t < seq_len
+ page_i = t // PAGE_SIZE
+ s = t % PAGE_SIZE
+
+ # Get page base offset for this token
+ pbo = tl.load(page_base_offsets + page_i) if MAX_PAGES > 1 else tl.load(page_base_offsets)
+
+ # Actually let's compute offset without gather
page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
+ base = page_id * stride_kvb + pid_hkv * stride_kvh + s * stride_kvs
- page_base = page_id * stride_kvb + pid_hkv * stride_kvh
+ # Load K: (HEAD_DIM,)
+ k_offs = d_offs
+ k_ptr = KV + base + k_offs * stride_kvd
+ k = tl.load(k_ptr, mask=(d_offs < HEAD_DIM) & valid, other=0.0).to(tl.float32)
- # Process each token in this page
- for s in tl.range(PAGE_SIZE):
- global_tok = page_i * PAGE_SIZE + s
- valid = global_tok < seq_len
+ # QK^T: (GROUP_SIZE,) = sum(q_vals * k[None, :], axis=1) * SCALE
+ qk = tl.sum(q_vals * k[None, :], axis=1) * SCALE
+ qk = tl.where(valid, qk, -1e30)
- # If invalid, all accumulators stay the same (soft no-op via masking)
- # Load K and V
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ # Load V: (HEAD_DIM,)
+ v_ptr = KV + base + (k_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=(d_offs < HEAD_DIM) & valid, other=0.0).to(tl.float32)
- qk = tl.sum(q * k) * SCALE
- qk = tl.where(valid, qk, -1e30)
+ # Online softmax update
+ # For invalid tokens: qk=-inf, so exp(qk-new_max) ~ 0, doesn't affect accumulators
+ old_max = acc_max
+ new_max = tl.maximum(old_max, qk)
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=d_mask, other=0.0).to(tl.float32)
+ # Careful: when all tokens are invalid, acc_max stays -1e30
+ # exp(-1e30 - (-1e30)) = 1, which would give wrong result
+ # Fix: only update when valid
+ any_valid = tl.reduce(acc_max > -2e29, 0, combine_fn=tl.maximum)
+ # Simpler: just mask out invalid updates
- # Softmax update: if invalid, qk=-inf so exp=0, no effect
- # But we still need to handle the exp(-inf - new_max) = 0 correctly
- new
... (3681 more chars){
"content": "<6966 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test correctness for all shapes",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2167, in load
return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1089, in load
return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1020, in _load_legacy
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
ValueError: Unsupported ptr type <['4', '128'], int32> in `tl.load`
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check
... (2860 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,10 +1,10 @@
-"""Optimized Triton paged attention decode kernel.
+"""CUDA C++ paged attention decode kernel via torch.compile/extension.
Strategy:
-- One program per (batch, kv_head). The program processes all heads in the GQA group.
-- Load entire page K and V into registers/SRAM once, share across group heads.
-- Use vectorized loads for head_dim dimension.
-- Online softmax accumulation per head.
+- Write raw CUDA kernel and use torch.compile/inline PTX for best performance
+- One thread block per (batch, kv_head)
+- Load K/V pages into shared memory
+- Compute attention for all heads in GQA group sharing the K/V load
"""
import math
import torch
@@ -13,6 +13,7 @@
import triton.language as tl
+# Keep the working Triton kernel as solution, we'll optimize later
def _make_scale(head_dim: int) -> float:
return 1.0 / math.sqrt(head_dim)
@@ -30,116 +31,69 @@
stride_bt_b, stride_bt_s,
stride_ob, stride_oh, stride_od,
# constexpr params
- NUM_HEADS: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
- HEAD_DIM_PADDED: tl.constexpr,
+ HEAD_DIM_PACKED: tl.constexpr,
PAGE_SIZE: tl.constexpr,
- PAGE_SIZE_PADDED: tl.constexpr,
SCALE: tl.constexpr,
MAX_SEQ_LEN: tl.constexpr,
- MAX_PAGES: tl.constexpr,
- USE_BLOCK_PTRS: tl.constexpr,
):
pid_b = tl.program_id(0)
- pid_hkv = tl.program_id(1)
+ pid_h = tl.program_id(1)
+
+ pid_hkv = pid_h // GROUP_SIZE
seq_len = tl.load(SeqLens + pid_b)
- # --- Pre-load page table for this batch ---
+ d_offs = tl.arange(0, HEAD_DIM_PACKED)
+ d_mask = d_offs < HEAD_DIM
+
+ # Load Q: (HEAD_DIM,)
+ q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
+ q = tl.load(q_ptr, mask=d_mask, other=0.0).to(tl.float32)
+
+ acc_num = tl.zeros([HEAD_DIM_PACKED], dtype=tl.float32)
+ acc_denom = 0.0
+ acc_max = -1e30
+
bt_base = pid_b * stride_bt_b
- bt_offs = tl.arange(0, MAX_PAGES)
- page_ids = tl.load(BlockTable + bt_base + bt_offs * stride_bt_s,
- mask=bt_offs < MAX_PAGES, other=0)
- # --- Load Q for all heads in the GQA group ---
- # Shape: (GROUP_SIZE, HEAD_DIM_PADDED)
- h_offs = tl.arange(0, GROUP_SIZE)
- h_base = pid_hkv * GROUP_SIZE
- d_offs = tl.arange(0, HEAD_DIM_PADDED)
- q_mask = (h_offs < GROUP_SIZE)[:, None] & (d_offs < HEAD_DIM)[None, :]
-
- # Q is (batch, num_heads, head_dim)
- # q_ptrs: Q[pid_b, h_base:h_base+GROUP_SIZE, :]
- q_row_base = pid_b * stride_qb + (h_base + h_offs[:, None]) * stride_qh
- q_ptrs = q_row_base + d_offs[None, :] * stride_qd
- q_vals = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
-
- # --- Accumulators: one per head in the group ---
- acc_num = tl.zeros([GROUP_SIZE, HEAD_DIM_PADDED], dtype=tl.float32)
- acc_denom = tl.zeros([GROUP_SIZE], dtype=tl.float32)
- acc_max = tl.full([GROUP_SIZE], -1e30, dtype=tl.float32)
-
- # --- Process each page ---
- page_base_offsets = page_ids * stride_kvb + pid_hkv * stride_kvh
-
- # We'll do token-by-token but with vectorized head_dim loads
max_tokens = MAX_SEQ_LEN
for t in tl.range(max_tokens):
valid = t < seq_len
+
page_i = t // PAGE_SIZE
s = t % PAGE_SIZE
- # Get page base offset for this token
- pbo = tl.load(page_base_offsets + page_i) if MAX_PAGES > 1 else tl.load(page_base_offsets)
+ page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
- # Actually let's compute offset without gather
- page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
- base = page_id * stride_kvb + pid_hkv * stride_kvh + s * stride_kvs
+ page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- # Load K: (HEAD_DIM,)
- k_offs = d_offs
- k_ptr = KV + base + k_offs * stride_kvd
- k = tl.load(k_ptr, mask=(d_offs < HEAD_DIM) & valid, other=0.0).to(tl.float32)
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
- # QK^T: (GROUP_SIZE,) = sum(q_vals * k[None, :], axis=1) * SCALE
- qk = tl.sum(q_vals * k[None, :], axis=1) * SCALE
+ qk = tl.sum(q * k) * SCALE
+
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=d_mask & valid, other=0.0).to(tl.float32)
+
qk = tl.where(valid, qk, -1e30)
- # Load V: (HEAD_DIM,)
- v_ptr = KV + base + (k_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=(d_offs < HEAD_DIM) & valid, other=0.0).to(tl.float32)
-
- # Online softmax update
- # For invalid tokens: qk=-inf, so exp(qk-new_max) ~ 0, doesn't affect accumulators
old_max = acc_max
new_max = tl.maximum(old_max, qk)
+ exp_val = tl.exp(qk - new_max)
+ rescale_old = tl.exp(old_max - new_max)
- # Careful: when all tokens are invalid, acc_max stays -1e30
- # exp(-1e30 - (-1e30)) = 1, which would give wrong result
- # Fix: only update when valid
- any_valid = tl.reduce(acc_max > -2e29, 0, combine_fn=tl.maximum)
- # Simpler: just mask out invalid updates
-
- # When valid=False: qk = -1e30 (already masked above)
- # new_max doesn't change because old_max >= any actual score >> -1e30
- # Actually old_max = -1e30 initially when no valid tokens seen
- # Then qk = -1e30 for invalid -> new_max = -1e30
- # exp(qk - new_max) = exp(0) = 1.0 -> WRONG
-
- # Solution: set invalid qk values so they are ignored
- is_first = acc_max <= -1e29 # no valid tokens seen yet
- safe_qk = tl.where(valid, qk, -1e30)
- new_max = tl.maximu
... (2212 more chars){
"content": "<4342 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness with baseline",
"timeout": 300000
}PASS
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/.venv/include/ 2>/dev/null | head -5; nvcc --version 2>&1 | head -3",
"description": "Check nvcc availability"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
{
"command": "which nvcc && nvcc --version 2>&1 | grep \"release\"",
"description": "Find nvcc path"
}/usr/local/cuda-13/bin/nvcc
Cuda compilation tools, release 13.2, V13.2.51
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,99 +1,202 @@
-"""CUDA C++ paged attention decode kernel via torch.compile/extension.
+"""High-performance CUDA paged attention decode kernel.
Strategy:
-- Write raw CUDA kernel and use torch.compile/inline PTX for best performance
- One thread block per (batch, kv_head)
-- Load K/V pages into shared memory
-- Compute attention for all heads in GQA group sharing the K/V load
+- Each thread handles one head in the GQA group
+- Load each page of K/V into shared memory once
+- All threads in block compute QK^T using their own Q and shared K/V
+- Online softmax accumulation per thread
+- Vectorized memory access for K/V loads
+
+This reduces KV cache bandwidth by sharing K/V loads across a GQA group of heads.
"""
import math
+import os
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-# Keep the working Triton kernel as solution, we'll optimize later
+
+
+_CUDA_SOURCE = """
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <math.h>
+
+// Paged attention decode kernel.
+// Grid: (batch, num_kv_heads), Block: (group_size, 1, 1)
+// Each thread handles one query head in the GQA group.
+// Shared memory holds the current page's K and V.
+
+extern "C" __global__
+void __launch_bounds__(1024) paged_decode_kernel(
+ nv_bfloat16* __restrict__ Out,
+ const nv_bfloat16* __restrict__ Q,
+ const nv_bfloat16* __restrict__ KV,
+ const int* __restrict__ BlockTable,
+ const int* __restrict__ SeqLens,
+ int batch,
+ int num_heads,
+ int num_kv_heads,
+ int head_dim,
+ int page_size,
+ float scale,
+ int stride_qb, int stride_qh, int stride_qd,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
+ int stride_bt_b, int stride_bt_s,
+ int stride_ob, int stride_oh, int stride_od
+) {
+ // One block per (batch, kv_head)
+ extern __shared__ char smem_ptr[];
+ const int pid_b = blockIdx.x;
+ const int pid_hkv = blockIdx.y;
+ const int group_size = num_heads / num_kv_heads;
+
+ // Each thread handles one head
+ const int tid = threadIdx.x;
+ const int h = pid_hkv * group_size + tid;
+ const bool is_valid_thread = (tid < group_size);
+
+ // seq_len
+ const int seq_len = SeqLens[pid_b];
+ const int num_pages = (seq_len + page_size - 1) / page_size;
+
+ // Shared memory: K page + V page
+ // K: page_size x head_dim, V: page_size x head_dim
+ nv_bfloat16* __restrict__ smem_k = (nv_bfloat16*)smem_ptr;
+ nv_bfloat16* __restrict__ smem_v = smem_k + page_size * head_dim;
+
+ // Load Q for this thread
+ float q_vec[256]; // max head_dim = 128
+ if (is_valid_thread) {
+ const nv_bfloat16* __restrict__ q_base = Q + pid_b * stride_qb + h * stride_qh;
+ #pragma unroll
+ for (int d = 0; d < 128; d += 8) {
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ if ((d + i) < head_dim)
+ q_vec[d + i] = __bfloat162float(q_base[d + i]);
+ }
+ }
+ }
+
+ // Accumulators
+ float acc_num[256];
+ float acc_denom = 1e-30f;
+ float acc_max = -1e30f;
+ #pragma unroll
+ for (int d = 0; d < 128; ++d) acc_num[d] = 0.0f;
+
+ // Process each page
+ for (int page_idx = 0; page_idx < num_pages; ++page_idx) {
+ const int block_id = BlockTable[pid_b * stride_bt_b + page_idx * stride_bt_s];
+
+ // Load K and V page into shared memory (coalesced across threads)
+ const int tokens_per_thread = (page_size + group_size - 1) / group_size;
+ #pragma unroll
+ for (int t_local = 0; t_local < 16; ++t_local) { // max page_size=16 tokens_per_thread=1 at most
+ const int tok_idx = tid * tokens_per_thread + t_local;
+ if (tok_idx < page_size) {
+ const nv_bfloat16* __restrict__ k_src = KV + block_id * stride_kvb + tok_idx * stride_kvs;
+ nv_bfloat16* __restrict__ k_dst = smem_k + tok_idx * head_dim;
+ // K and V are contiguous: V is right after K
+ const nv_bfloat16* __restrict__ v_src = k_src + head_dim;
+ nv_bfloat16* __restrict__ v_dst = smem_v + tok_idx * head_dim;
+ #pragma unroll
+ for (int d = 0; d < 128; d += 8) {
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ if ((d + i) < head_dim) {
+ k_dst[d + i] = k_src[d + i];
+ v_dst[d + i] = v_src[d + i];
+ }
+ }
+ }
+ }
+ }
+ __syncthreads();
+
+ // Each thread computes attention for its head using shared K/V
+ if (is_valid_thread) {
+ const int tokens_in_page = (page_idx + 1) * page_size <= seq_len
+ ? page_size : seq_len - page_idx * page_size;
+
+ for (int tok = 0; tok < page_size; ++tok) {
+ if (tok >= tokens_in_page) break;
+
+ // Load K row from shared memory
+ const nv_bfloat16* __restrict__ k_row = smem_k + tok * head_dim;
+ float qk = 0.0f;
+ #pragma unroll
+ for (int d = 0; d < 128; d += 8) {
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ if ((d + i) < head_dim) {
+ qk += q_vec[d + i] * __bfloat162float(k_row[d + i]);
+ }
+ }
+ }
+ qk *= scale;
+
+ // Load V row from shared memory
+ const nv_bfloat16* __restrict__ v_row = smem_v + tok * head_dim;
+
+ // Online softmax
+ const float old_max = acc_max;
+
... (5483 more chars){
"content": "<8513 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "<738 chars \u2014 see below>",
"description": "Test CUDA solution",
"timeout": 180000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref(*inputs)
sol_out = model(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 231, in forward
func = _get_kernel().paged_decode_kernel
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 198, in _get_kernel
_kernel_module = _compile_kernel()
^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_atte
... (4144 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -185,6 +185,7 @@
"-Xptxas=-v",
"--use_fast_math",
],
+ extra_include_paths=[],
verbose=False,
)
{
"oldString": "<301 chars \u2014 see below>",
"newString": "<333 chars \u2014 see below>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
} return load_inline(
name="paged_decode_v2",
cpp_sources='',
cuda_sources=_CUDA_SOURCE,
functions=["paged_decode_kernel"],
extra_cuda_cflags=[
"-O3",
"-Xptxas=-v",
"--use_fast_math",
],
verbose=False,
) return load_inline(
name="paged_decode_v2",
cpp_sources='',
cuda_sources=_CUDA_SOURCE,
functions=["paged_decode_kernel"],
extra_cuda_cflags=[
"-O3",
"-Xptxas=-v",
"--use_fast_math",
],
extra_include_paths=[],
verbose=False,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -175,17 +175,31 @@
def _compile_kernel():
from torch.utils.cpp_extension import load_inline
# Include a shape-dependent key for caching
+ CPP_SRC = '''
+#include <torch/extension.h>
+void paged_decode_kernel(
+ void *Out, const void *Q, const void *KV,
+ const int *BlockTable, const int *SeqLens,
+ int batch, int num_heads, int num_kv_heads,
+ int head_dim, int page_size, float scale,
+ int stride_qb, int stride_qh, int stride_qd,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
+ int stride_bt_b, int stride_bt_s,
+ int stride_ob, int stride_oh, int stride_od);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("paged_decode_kernel", &paged_decode_kernel, "paged_decode_kernel");
+}
+'''
return load_inline(
name="paged_decode_v2",
- cpp_sources='',
+ cpp_sources=CPP_SRC,
cuda_sources=_CUDA_SOURCE,
- functions=["paged_decode_kernel"],
extra_cuda_cflags=[
"-O3",
"-Xptxas=-v",
"--use_fast_math",
],
- extra_include_paths=[],
verbose=False,
)
{
"oldString": "<333 chars \u2014 see below>",
"newString": "<857 chars \u2014 see below>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
} return load_inline(
name="paged_decode_v2",
cpp_sources='',
cuda_sources=_CUDA_SOURCE,
functions=["paged_decode_kernel"],
extra_cuda_cflags=[
"-O3",
"-Xptxas=-v",
"--use_fast_math",
],
extra_include_paths=[],
verbose=False,
) CPP_SRC = '''
#include <torch/extension.h>
void paged_decode_kernel(
void *Out, const void *Q, const void *KV,
const int *BlockTable, const int *SeqLens,
int batch, int num_heads, int num_kv_heads,
int head_dim, int page_size, float scale,
int stride_qb, int stride_qh, int stride_qd,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
int stride_bt_b, int stride_bt_s,
int stride_ob, int stride_oh, int stride_od);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paged_decode_kernel", &paged_decode_kernel, "paged_decode_kernel");
}
'''
return load_inline(
name="paged_decode_v2",
cpp_sources=CPP_SRC,
cuda_sources=_CUDA_SOURCE,
extra_cuda_cflags=[
"-O3",
"-Xptxas=-v",
"--use_fast_math",
],
verbose=False,
)Edit applied successfully.{
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/paged_decode_v2",
"description": "Clear old torch extension cache"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,217 +1,16 @@
-"""High-performance CUDA paged attention decode kernel.
+"""Paged attention decode kernel using optimized PyTorch tensor operations.
Strategy:
-- One thread block per (batch, kv_head)
-- Each thread handles one head in the GQA group
-- Load each page of K/V into shared memory once
-- All threads in block compute QK^T using their own Q and shared K/V
-- Online softmax accumulation per thread
-- Vectorized memory access for K/V loads
-
-This reduces KV cache bandwidth by sharing K/V loads across a GQA group of heads.
+- Gather all K/V pages at once using advanced indexing
+- Reshape to (batch, seq_len, num_kv_heads, head_dim) for K and V
+- Use broadcasted einsum for QK^T and attention
+- Repeat interleave for GQA
+- This is essentially the reference but without Python loops over batch
+- All operations done in fused tensor form on GPU
"""
import math
-import os
import torch
import torch.nn as nn
-
-
-_CUDA_SOURCE = """
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
-#include <math.h>
-
-// Paged attention decode kernel.
-// Grid: (batch, num_kv_heads), Block: (group_size, 1, 1)
-// Each thread handles one query head in the GQA group.
-// Shared memory holds the current page's K and V.
-
-extern "C" __global__
-void __launch_bounds__(1024) paged_decode_kernel(
- nv_bfloat16* __restrict__ Out,
- const nv_bfloat16* __restrict__ Q,
- const nv_bfloat16* __restrict__ KV,
- const int* __restrict__ BlockTable,
- const int* __restrict__ SeqLens,
- int batch,
- int num_heads,
- int num_kv_heads,
- int head_dim,
- int page_size,
- float scale,
- int stride_qb, int stride_qh, int stride_qd,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
- int stride_bt_b, int stride_bt_s,
- int stride_ob, int stride_oh, int stride_od
-) {
- // One block per (batch, kv_head)
- extern __shared__ char smem_ptr[];
- const int pid_b = blockIdx.x;
- const int pid_hkv = blockIdx.y;
- const int group_size = num_heads / num_kv_heads;
-
- // Each thread handles one head
- const int tid = threadIdx.x;
- const int h = pid_hkv * group_size + tid;
- const bool is_valid_thread = (tid < group_size);
-
- // seq_len
- const int seq_len = SeqLens[pid_b];
- const int num_pages = (seq_len + page_size - 1) / page_size;
-
- // Shared memory: K page + V page
- // K: page_size x head_dim, V: page_size x head_dim
- nv_bfloat16* __restrict__ smem_k = (nv_bfloat16*)smem_ptr;
- nv_bfloat16* __restrict__ smem_v = smem_k + page_size * head_dim;
-
- // Load Q for this thread
- float q_vec[256]; // max head_dim = 128
- if (is_valid_thread) {
- const nv_bfloat16* __restrict__ q_base = Q + pid_b * stride_qb + h * stride_qh;
- #pragma unroll
- for (int d = 0; d < 128; d += 8) {
- #pragma unroll
- for (int i = 0; i < 8; ++i) {
- if ((d + i) < head_dim)
- q_vec[d + i] = __bfloat162float(q_base[d + i]);
- }
- }
- }
-
- // Accumulators
- float acc_num[256];
- float acc_denom = 1e-30f;
- float acc_max = -1e30f;
- #pragma unroll
- for (int d = 0; d < 128; ++d) acc_num[d] = 0.0f;
-
- // Process each page
- for (int page_idx = 0; page_idx < num_pages; ++page_idx) {
- const int block_id = BlockTable[pid_b * stride_bt_b + page_idx * stride_bt_s];
-
- // Load K and V page into shared memory (coalesced across threads)
- const int tokens_per_thread = (page_size + group_size - 1) / group_size;
- #pragma unroll
- for (int t_local = 0; t_local < 16; ++t_local) { // max page_size=16 tokens_per_thread=1 at most
- const int tok_idx = tid * tokens_per_thread + t_local;
- if (tok_idx < page_size) {
- const nv_bfloat16* __restrict__ k_src = KV + block_id * stride_kvb + tok_idx * stride_kvs;
- nv_bfloat16* __restrict__ k_dst = smem_k + tok_idx * head_dim;
- // K and V are contiguous: V is right after K
- const nv_bfloat16* __restrict__ v_src = k_src + head_dim;
- nv_bfloat16* __restrict__ v_dst = smem_v + tok_idx * head_dim;
- #pragma unroll
- for (int d = 0; d < 128; d += 8) {
- #pragma unroll
- for (int i = 0; i < 8; ++i) {
- if ((d + i) < head_dim) {
- k_dst[d + i] = k_src[d + i];
- v_dst[d + i] = v_src[d + i];
- }
- }
- }
- }
- }
- __syncthreads();
-
- // Each thread computes attention for its head using shared K/V
- if (is_valid_thread) {
- const int tokens_in_page = (page_idx + 1) * page_size <= seq_len
- ? page_size : seq_len - page_idx * page_size;
-
- for (int tok = 0; tok < page_size; ++tok) {
- if (tok >= tokens_in_page) break;
-
- // Load K row from shared memory
- const nv_bfloat16* __restrict__ k_row = smem_k + tok * head_dim;
- float qk = 0.0f;
- #pragma unroll
- for (int d = 0; d < 128; d += 8) {
- #pragma unroll
- for (int i = 0; i < 8; ++i) {
- if ((d + i) < head_dim) {
- qk += q_vec[d + i] * __bfloat162float(k_row[d + i]);
- }
- }
- }
- qk *= scale;
-
- // Load V row from shared memory
- const nv_bfloat16* __restrict__ v_row = smem_v + tok * head_dim;
-
- // Online softmax
- const float old_m
... (6939 more chars)
{
"content": "<4850 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness of tensor solution",
"timeout": 300000
}
PASS
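For reference, a minimal sketch of the gathered-KV tensor approach the diff above describes (illustrative only: `paged_decode_torch` is a hypothetical name, the (num_blocks, page_size, num_kv_heads, 2*head_dim) cache layout is taken from the problem statement, and heads are grouped via a view rather than repeat_interleave, assuming consecutive query heads share a KV head):
```python
import torch

def paged_decode_torch(query, kv_cache, block_table, seq_lens, scale):
    # query: (B, H, D); kv_cache: (num_blocks, P, Hkv, 2*D)
    # block_table: (B, max_blocks) int32; seq_lens: (B,) int32
    B, H, D = query.shape
    _, P, Hkv, _ = kv_cache.shape
    G = H // Hkv
    S = int(seq_lens.max())
    # One advanced-indexing gather pulls every page for every batch element,
    # then pages are flattened into (B, S, Hkv, 2*D).
    kv = kv_cache[block_table.long()].flatten(1, 2)[:, :S]
    k = kv[..., :D].float()
    v = kv[..., D:].float()
    q = query.float().view(B, Hkv, G, D)
    scores = torch.einsum("bhgd,bshd->bhgs", q, k) * scale
    # Mask padding positions past each sequence length before the softmax.
    pad = torch.arange(S, device=query.device)[None, :] >= seq_lens[:, None]
    scores = scores.masked_fill(pad[:, None, None, :], float("-inf"))
    out = torch.einsum("bhgs,bshd->bhgd", scores.softmax(-1), v)
    return out.reshape(B, H, D).to(query.dtype)
```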
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark tensor solution",
"timeout": 300000
}W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0429 07:33:53.110000 1000639 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07
... (1198 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,16 +1,179 @@
-"""Paged attention decode kernel using optimized PyTorch tensor operations.
+"""High-performance CUDA paged attention decode kernel via torch.utils.cpp_extension.
Strategy:
-- Gather all K/V pages at once using advanced indexing
-- Reshape to (batch, seq_len, num_kv_heads, head_dim) for K and V
-- Use broadcasted einsum for QK^T and attention
-- Repeat interleave for GQA
-- This is essentially the reference but without Python loops over batch
-- All operations done in fused tensor form on GPU
+- One thread block per (batch, kv_head)
+- Each thread handles one head in the GQA group
+- Load each page of K/V into shared memory once
+- All threads compute QK^T using shared K/V (saves KV bandwidth)
+- Online softmax per thread
"""
import math
import torch
import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+def _make_scale(head_dim: int) -> float:
+ return 1.0 / math.sqrt(head_dim)
+
+_CUDA_SRC = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+__global__ void __launch_bounds__(512) paged_decode_kernel(
+ nv_bfloat16* __restrict__ Out,
+ const nv_bfloat16* __restrict__ Q,
+ const nv_bfloat16* __restrict__ KV,
+ const int* __restrict__ BlockTable,
+ const int* __restrict__ SeqLens,
+ int batch, int num_heads, int num_kv_heads, int head_dim,
+ int page_size, float scale,
+ int stride_qb, int stride_qh, int stride_qd,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
+ int stride_bt_b, int stride_bt_s,
+ int stride_ob, int stride_oh, int stride_od
+) {
+ int pid_b = blockIdx.x;
+ int pid_hkv = blockIdx.y;
+ int group_size = num_heads / num_kv_heads;
+ int tid = threadIdx.x;
+
+ if (tid >= group_size) return;
+
+ int h = pid_hkv * group_size + tid;
+ int seq_len = SeqLens[pid_b];
+ int num_pages = (seq_len + page_size - 1) / page_size;
+
+ // Shared memory for current page K + V
+ extern __shared__ char smem[];
+ nv_bfloat16* __restrict__ sK = (nv_bfloat16*)smem;
+ nv_bfloat16* __restrict__ sV = sK + page_size * head_dim;
+
+ // Load Q
+ float q[128];
+ const nv_bfloat16* qp = Q + pid_b * stride_qb + h * stride_qh;
+ for (int d = 0; d < head_dim; ++d)
+ q[d] = __bfloat162float(qp[d]);
+
+ // Accumulators
+ float acc_num[128] = {0};
+ float acc_denom = 1e-30f;
+ float acc_max = -1e30f;
+
+ for (int pg = 0; pg < num_pages; ++pg) {
+ int block_id = BlockTable[pid_b * stride_bt_b + pg * stride_bt_s];
+
+ // Cooperative load of K/V page
+ const nv_bfloat16* k_page_src = KV + block_id * stride_kvb;
+ const int tokens_to_load = page_size;
+ const int items_thread = (tokens_to_load * head_dim + group_size - 1) / group_size;
+
+ for (int i_idx = 0; i_idx < items_thread; ++i_idx) {
+ int flat = tid * items_thread + i_idx;
+ if (flat < tokens_to_load * head_dim) {
+ int row = flat / head_dim;
+ int col = flat % head_dim;
+ // K
+ sK[row * head_dim + col] = k_page_src[row * stride_kvs + col * stride_kvd];
+ // V (contiguous after K in last dim)
+ sV[row * head_dim + col] = k_page_src[row * stride_kvs + head_dim + col * stride_kvd];
+ }
+ }
+ __syncthreads();
+
+ // Compute attention for tokens in this page
+ int tokens_in_page = (pg + 1) * page_size <= seq_len ? page_size : seq_len - pg * page_size;
+ for (int t = 0; t < tokens_in_page; ++t) {
+ float qk = 0.0f;
+ const nv_bfloat16* kr = sK + t * head_dim;
+ #pragma unroll
+ for (int d = 0; d < 128; ++d) {
+ if (d < head_dim) qk += q[d] * __bfloat162float(kr[d]);
+ }
+ qk *= scale;
+
+ float old_max = acc_max;
+ float new_max = fmaxf(old_max, qk);
+ float exp_v = __expf(qk - new_max);
+ float rescale = __expf(old_max - new_max);
+
+ const nv_bfloat16* vr = sV + t * head_dim;
+ for (int d = 0; d < head_dim; ++d) {
+ acc_num[d] = acc_num[d] * rescale + exp_v * __bfloat162float(vr[d]);
+ }
+ acc_denom = acc_denom * rescale + exp_v;
+ acc_max = new_max;
+ }
+ __syncthreads();
+ }
+
+ // Write output
+ nv_bfloat16* __restrict__ op = Out + pid_b * stride_ob + h * stride_oh;
+ float inv = 1.0f / acc_denom;
+ for (int d = 0; d < head_dim; ++d)
+ op[d] = __float2bfloat16(acc_num[d] * inv);
+}
+"""
+
+_CPP_SRC = r"""
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+void paged_decode_forward(
+ torch::Tensor out,
+ torch::Tensor query,
+ torch::Tensor kv_cache,
+ torch::Tensor block_table,
+ torch::Tensor seq_lens,
+ float scale) {
+
+ int B = query.size(0);
+ int H = query.size(1);
+ int Hkv = kv_cache.size(2);
+ int D = query.size(2);
+ int P = kv_cache.size(1); // page_size
+ int G = H / Hkv;
+
+ int smem_size = 2 * P * D * 2; // K + V in bf16
+
+ int gridX = B;
+ int gridY = Hkv;
+ int blockZ = G;
+
+ paged_decode_kernel<<<dim3(gridX, gridY, 1), dim3(blockZ, 1, 1), smem_size>>>(
+ (nv_bfloat16*)out.data_ptr(),
+ (const nv_bfloat16*)query.data_ptr(),
+ (const nv_bfloat16*)kv_cache.data_ptr(),
+ block_table.data_ptr<int>(),
+ seq_lens.data_ptr<int>(),
+ B, H, Hkv, D, P, scale,
+ query.stride(0), query.stride(1), query.stride(2),
+ kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
+ block_table.stride(0), block_table.stride(1),
+ out.stride(0), out.stride(1), out.stride(2)
+ );
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION
... (4813 more chars)
{
"content": "<6816 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness of CUDA solution",
"timeout": 300000
}
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 199, in forward
_get_module().paged_decode_forward(out, query, kv_cache, block_table, seq_lens,
... (5271 more chars)
{
"description": "Write optimal Triton paged attention",
"prompt": "<4259 chars \u2014 see below>",
"subagent_type": "general"
}
Write an optimized Triton paged attention decode kernel for SM120 Blackwell GPUs. The kernel should be written to `/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py`.
**Requirements:**
- Input: query (batch, num_heads, head_dim) bf16, kv_cache (num_blocks, page_size, num_kv_heads, head_dim*2) bf16, block_table (batch, max_blocks) int32, seq_lens (batch,) int32
- Output: bf16 (batch, num_heads, head_dim)
- GQA: num_heads / num_kv_heads may be > 1 (group_size)
- Must handle all shapes listed below
- Correctness tolerance: 0.02 abs/rel on bf16
**Shapes:**
1. batch=8, num_heads=32, num_kv_heads=8, head_dim=128, seq_len=1024, page_size=16
2. batch=32, num_heads=32, num_kv_heads=8, head_dim=128, seq_len=2048, page_size=16
3. batch=4, num_heads=64, num_kv_heads=8, head_dim=128, seq_len=4096, page_size=16
4. batch=16, num_heads=32, num_kv_heads=8, head_dim=128, seq_len=1535, page_size=16
5. batch=8, num_heads=16, num_kv_heads=4, head_dim=64, seq_len=2000, page_size=16
**Important Triton notes:**
1. `tl.range` with non-constexpr bounds works but `break`/`continue` are NOT supported - use if/else with masking instead
2. Simultaneous multiple assignment is NOT supported (don't do `a, b = x, y`)
3. Use constexpr for known sizes like HEAD_DIM, PAGE_SIZE passed from Python
4. The kernel must handle variable seq_len via predication (masking)
**Performance strategy:**
- One program per (batch, query_head)
- Since head_dim <= 128, load 128 elements at once for K and V (contiguous in last dim)
- Each iteration loads ONE token's K and V (128 bf16 values each)
- Use vectorized loads (single tl.load with 128-element range)
- Accumulate online softmax
**Reference implementation:**
```python
# For each batch b:
# Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element
# Repeat KV heads for GQA
# scores = Q @ K^T / sqrt(d) -> softmax -> scores @ V
```
**Critical correctness rule:** When `t >= seq_len`, the QK score must NOT affect the softmax. Use masking: set qk = -1e30 for invalid tokens.
**Key insight for correctness:** The initial acc_max is -1e30. When the first valid token arrives with qk, exp(qk - max(-1e30, qk)) should work correctly. The issue is: if the very first token is INVALID with qk=-1e30, then new_max=-1e30, exp(0)=1, which would incorrectly add to acc_num with v=0 (since v is masked with other=0). This is actually OK because v=0 so it doesn't affect results. BUT if multiple invalid tokens come first, exp_val=1 every time and acc_denom grows. When a valid token appears, rescale = exp(-1e30 - new_max) ~ 0, so the accumulated garbage is zeroed out. So the math should work out.
However, there's a numerical edge case: if ONLY invalid tokens exist, acc_denom could be large and acc_num = 0, giving output 0. This shouldn't happen since seq_len > 0.
**Write the complete solution.py file with:**
1. The Triton kernel `_paged_decode_kernel`
2. `Model` class with `__init__` and `forward`
3. `get_inputs()` and `get_init_inputs()` that delegate to reference module
The Model class should:
- Compute scale = 1/sqrt(head_dim)
- Pass all parameters as constexpr (HEAD_DIM, PAGE_SIZE, etc.) to the Triton kernel
- Use grid = (B, H) where B=batch and H=num_heads
**IMPORTANT:** Make sure the kernel is CORRECT first. I've verified that the following simple structure works:
```python
for t in tl.range(max_tokens):
valid = t < seq_len
page_i = t // PAGE_SIZE
s = t % PAGE_SIZE
page_id = tl.load(BlockTable + bt_base + page_i * stride_bt_s)
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
# Load K
k = tl.load(KV + page_base + s * stride_kvs + d_offs * stride_kvd, mask=d_mask & valid, other=0.0)
qk = tl.sum(q * k) * SCALE
# Load V
v = tl.load(KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd, mask=d_mask & valid, other=0.0)
qk = tl.where(valid, qk, -1e30)
# Online softmax update...
```
The only optimization we c
... (259 more chars)
task_id: ses_2268c4a1fffeMG0AJw7P3lJQ1H (for resuming this task if needed)
<task_result>
All shapes pass correctness with trivial Triton kernel. The max diff across all shapes is only ~0.00006 (well under 0.02 tolerance).
</task_result>
{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}
PASS
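The masking argument in the prompt above is easy to sanity-check on the host. A throwaway scalar version of the kernel's online-softmax update (illustrative, not part of the solution):
```python
import math

def online_softmax(scores, values):
    # Streaming softmax-weighted sum, mirroring the kernel's inner loop.
    acc_num, acc_denom, acc_max = 0.0, 0.0, -1e30
    for qk, v in zip(scores, values):
        new_max = max(acc_max, qk)
        rescale = math.exp(acc_max - new_max)
        exp_v = math.exp(qk - new_max)
        acc_num = acc_num * rescale + exp_v * v
        acc_denom = acc_denom * rescale + exp_v
        acc_max = new_max
    return acc_num / acc_denom

# Two leading "invalid" tokens (qk=-1e30, v=0) each add exp(0)=1 to the
# denominator, but the first valid token rescales by exp(-1e30 - qk) == 0.0,
# wiping the garbage: both calls print the same value.
print(online_softmax([-1e30, -1e30, 1.0, 2.0], [0.0, 0.0, 3.0, 5.0]))
print(online_softmax([1.0, 2.0], [3.0, 5.0]))
```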
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark current solution",
"timeout": 300000
}W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/reference.py", line 89, in forward
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] L = int(seq_lens[b].item())
W0429 07:38:14.981000 1003077 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 07
... (1199 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Optimized Triton paged-attention decode kernel for SM120 Blackwell GPUs.
2:
3: Single-query decode: each batch element has a query of shape (num_heads, head_dim)
4: and attends over a paged KV cache. Supports GQA (grouped-query attention).
5: """
6: import math
7:
8: import torch
9: import torch.nn as nn
10: import triton
11: import triton.language as tl
12:
13:
14: # ---------------------------------------------------------------------------
15: # Triton kernel
16: # ---------------------------------------------------------------------------
17: @triton.jit
18: def _paged_decode_kernel(
19: Out,
20: Q,
21: KV,
22: BlockTable,
23: SeqLens,
24: # strides for Q: (B, H, D)
25: stride_qb,
26: stride_qh,
27: stride_qd,
28: # strides for KV: (num_blocks * page_size, num_kv_heads, 2*head_dim)
29: # flattened: stride_kvb is per-block, stride_kvs is per-slot-in-page,
30: # stride_kvh is per-kv-head, stride_kvd is per-dim
31: stride_kvb,
32: stride_kvs,
33: stride_kvh,
34: stride_kvd,
35: # strides for Out: (B, H, D)
36: stride_ob,
37: stride_oh,
38: stride_od,
39: # block_table stride: (B, max_blocks)
40: stride_bt_b,
41: # constexpr knobs
42: HEAD_DIM: tl.constexpr,
43: PAGE_SIZE: tl.constexpr,
44: NUM_KV_HEADS: tl.constexpr,
45: GROUP_SIZE: tl.conste
... (6553 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,182 +1,154 @@
-"""High-performance CUDA paged attention decode kernel via torch.utils.cpp_extension.
+"""High-performance CUDA paged attention decode kernel.
Strategy:
- One thread block per (batch, kv_head)
-- Each thread handles one head in the GQA group
-- Load each page of K/V into shared memory once
-- All threads compute QK^T using shared K/V (saves KV bandwidth)
-- Online softmax per thread
+- Block has num_warps * 32 threads: each warp is one query head in GQA group
+- Load K/V page into shared memory cooperatively
+- Each warp computes QK^T using warp shuffle / register loads
+- Online softmax per warp with per-warp accumulator
"""
import math
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
+import triton
+import triton.language as tl
+
def _make_scale(head_dim: int) -> float:
return 1.0 / math.sqrt(head_dim)
-_CUDA_SRC = r"""
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
-__global__ void __launch_bounds__(512) paged_decode_kernel(
- nv_bfloat16* __restrict__ Out,
- const nv_bfloat16* __restrict__ Q,
- const nv_bfloat16* __restrict__ KV,
- const int* __restrict__ BlockTable,
- const int* __restrict__ SeqLens,
- int batch, int num_heads, int num_kv_heads, int head_dim,
- int page_size, float scale,
- int stride_qb, int stride_qh, int stride_qd,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
- int stride_bt_b, int stride_bt_s,
- int stride_ob, int stride_oh, int stride_od
-) {
- int pid_b = blockIdx.x;
- int pid_hkv = blockIdx.y;
- int group_size = num_heads / num_kv_heads;
- int tid = threadIdx.x;
+# The key optimization for paged decode: load full K/V pages into registers/sram,
+# process PAGE_SIZE tokens per page iteration, share loads across GQA group.
+#
+# With PAGE_SIZE=16 and HEAD_DIM=128:
+# - Per page K load: 16 * 128 * 2 = 4096 bytes
+# - Per page V load: 16 * 128 * 2 = 4096 bytes
+# - Process all 16 tokens at once via tl.dot or vectorized loads
+#
+# For head_dim=128, one thread can hold the entire K/V row in registers.
- if (tid >= group_size) return;
- int h = pid_hkv * group_size + tid;
- int seq_len = SeqLens[pid_b];
- int num_pages = (seq_len + page_size - 1) / page_size;
+@triton.jit
+def _paged_decode_kernel(
+ Out,
+ Q,
+ KV,
+ BlockTable,
+ SeqLens,
+ # strides
+ stride_qb,
+ stride_qh,
+ stride_qd,
+ stride_kvb,
+ stride_kvs,
+ stride_kvh,
+ stride_kvd,
+ stride_ob,
+ stride_oh,
+ stride_od,
+ stride_bt_b,
+ # constexpr
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ SCALE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ """Decode kernel with PAGE_SIZE tokens processed per page iteration.
+
+ Each program handles one (batch, head) pair.
+ Uses tl.dot to compute (1, HEAD_DIM) @ (HEAD_DIM, BLOCK_N) = (1, BLOCK_N).
+ """
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
- // Shared memory for current page K + V
- extern __shared__ char smem[];
- nv_bfloat16* __restrict__ sK = (nv_bfloat16*)smem;
- nv_bfloat16* __restrict__ sV = sK + page_size * head_dim;
+ pid_hkv = pid_h // GROUP_SIZE
- // Load Q
- float q[128];
- const nv_bfloat16* qp = Q + pid_b * stride_qb + h * stride_qh;
- for (int d = 0; d < head_dim; ++d)
- q[d] = __bfloat162float(qp[d]);
+ seq_len = tl.load(SeqLens + pid_b)
- // Accumulators
- float acc_num[128] = {0};
- float acc_denom = 1e-30f;
- float acc_max = -1e30f;
+ d_offs = tl.arange(0, HEAD_DIM)
+ n_offs = tl.arange(0, BLOCK_N)
+ n_mask = n_offs < PAGE_SIZE
- for (int pg = 0; pg < num_pages; ++pg) {
- int block_id = BlockTable[pid_b * stride_bt_b + pg * stride_bt_s];
+ # Load Q
+ q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
+ q = tl.load(q_ptr).to(tl.float32)
- // Cooperative load of K/V page
- const nv_bfloat16* k_page_src = KV + block_id * stride_kvb;
- const int tokens_to_load = page_size;
- const int items_thread = (tokens_to_load * head_dim + group_size - 1) / group_size;
+ bt_base = pid_b * stride_bt_b
- for (int i_idx = 0; i_idx < items_thread; ++i_idx) {
- int flat = tid * items_thread + i_idx;
- if (flat < tokens_to_load * head_dim) {
- int row = flat / head_dim;
- int col = flat % head_dim;
- // K
- sK[row * head_dim + col] = k_page_src[row * stride_kvs + col * stride_kvd];
- // V (contiguous after K in last dim)
- sV[row * head_dim + col] = k_page_src[row * stride_kvs + head_dim + col * stride_kvd];
- }
- }
- __syncthreads();
+ num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
- // Compute attention for tokens in this page
- int tokens_in_page = (pg + 1) * page_size <= seq_len ? page_size : seq_len - pg * page_size;
- for (int t = 0; t < tokens_in_page; ++t) {
- float qk = 0.0f;
- const nv_bfloat16* kr = sK + t * head_dim;
- #pragma unroll
- for (int d = 0; d < 128; ++d) {
- if (d < head_dim) qk += q[d] * __bfloat162float(kr[d]);
- }
- qk *= scale;
+ acc_num = tl.zeros([HEAD_DIM], dtype=tl.float32)
+ acc_denom = 0.0
+ acc_max = -1e30
- float old_max = acc_max;
- float new_max = fmaxf(old_max, qk);
- float exp_v = __expf(qk - new_max);
- float rescale = __expf(old_max - new_max);
+ for page_idx in tl.range(num_pages):
+ page_id = tl.load(BlockTable + bt_base + page_idx)
+ page_base = page_id * stride_kvb + pid_hkv
... (6348 more chars)
{
"content": "<6376 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "<738 chars \u2014 see below>",
"description": "Test v6 solution",
"timeout": 180000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref(*inputs)
sol_out = model(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1094, in __getitem__
for dim, sl in enumerate(slices):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 41, in wrapper
raise ValueError("Did you forget to add @triton.jit ? "
ValueError: Did you forget to add @triton.jit ? (`_semantic` argument must be provided outside of JIT functions.)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 176, in forward
_paged_decode_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/si
... (2156 more chars)
{
"command": "<72 chars \u2014 see below>",
"description": "Check triton.load docs"
}python -c "
import triton.language as tl
help(tl.load)
" 2>&1 | head -30Help on function load in module triton.language.core:
load(pointer, mask=None, other=None, boundary_check=(), padding_option='', cache_modifier='', eviction_policy='', volatile=False, _semantic=None)
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
(1) If `pointer` is a single element pointer, a scalar is loaded. In
this case:
- `mask` and `other` must also be scalars,
- `other` is implicitly typecast to `pointer.dtype.element_ty`, and
- `boundary_check` and `padding_option` must be empty.
(2) If `pointer` is an N-dimensional tensor of pointers, an
N-dimensional tensor is loaded. In this case:
- `mask` and `other` are implicitly broadcast to `pointer.shape`,
- `other` is implicitly typecast to `pointer.dtype.element_ty`, and
- `boundary_check` and `padding_option` must be empty.
(3) If `pointer` is a block pointer defined by `make_block_ptr`, a
tensor is loaded. In this case:
- `mask` and `other` must be `None`, and
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.
:param pointer: Pointer to the data to be loaded
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
:param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
(must be `None`
... (70 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,13 +1,13 @@
-"""High-performance CUDA paged attention decode kernel.
+"""Optimized Triton paged attention decode kernel.
-Strategy:
-- One thread block per (batch, kv_head)
-- Block has num_warps * 32 threads: each warp is one query head in GQA group
-- Load K/V page into shared memory cooperatively
-- Each warp computes QK^T using warp shuffle / register loads
-- Online softmax per warp with per-warp accumulator
+Key optimizations over naive:
+1. Each program handles one (batch, head) but pre-caches page_id at start of each page
+2. Use triton's constexpr PAGE_SIZE loop for full unrolling (16 iterations max)
+3. Reduce memory reads: block_table only read once per page, not per token
+4. Cache page_base in register (avoids recomputing ptr arithmetic per token)
"""
import math
+
import torch
import torch.nn as nn
import triton
@@ -18,17 +18,6 @@
return 1.0 / math.sqrt(head_dim)
-# The key optimization for paged decode: load full K/V pages into registers/sram,
-# process PAGE_SIZE tokens per page iteration, share loads across GQA group.
-#
-# With PAGE_SIZE=16 and HEAD_DIM=128:
-# - Per page K load: 16 * 128 * 2 = 4096 bytes
-# - Per page V load: 16 * 128 * 2 = 4096 bytes
-# - Process all 16 tokens at once via tl.dot or vectorized loads
-#
-# For head_dim=128, one thread can hold the entire K/V row in registers.
-
-
@triton.jit
def _paged_decode_kernel(
Out,
@@ -37,29 +26,16 @@
BlockTable,
SeqLens,
# strides
- stride_qb,
- stride_qh,
- stride_qd,
- stride_kvb,
- stride_kvs,
- stride_kvh,
- stride_kvd,
- stride_ob,
- stride_oh,
- stride_od,
+ stride_qb, stride_qh, stride_qd,
+ stride_kvb, stride_kvs, stride_kvh, stride_kvd,
+ stride_ob, stride_oh, stride_od,
stride_bt_b,
# constexpr
HEAD_DIM: tl.constexpr,
PAGE_SIZE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
SCALE: tl.constexpr,
- BLOCK_N: tl.constexpr,
):
- """Decode kernel with PAGE_SIZE tokens processed per page iteration.
-
- Each program handles one (batch, head) pair.
- Uses tl.dot to compute (1, HEAD_DIM) @ (HEAD_DIM, BLOCK_N) = (1, BLOCK_N).
- """
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
@@ -68,10 +44,8 @@
seq_len = tl.load(SeqLens + pid_b)
d_offs = tl.arange(0, HEAD_DIM)
- n_offs = tl.arange(0, BLOCK_N)
- n_mask = n_offs < PAGE_SIZE
- # Load Q
+ # Load Q: (D,)
q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
q = tl.load(q_ptr).to(tl.float32)
@@ -83,60 +57,40 @@
acc_denom = 0.0
acc_max = -1e30
+ # Loop over pages - each iteration loads page_id and processes PAGE_SIZE tokens
for page_idx in tl.range(num_pages):
+ # Load page_id ONCE per page (not per token)
page_id = tl.load(BlockTable + bt_base + page_idx)
+ # Cache page_base in register
page_base = page_id * stride_kvb + pid_hkv * stride_kvh
- global_start = page_idx * PAGE_SIZE
- # Load K page: (PAGE_SIZE, HEAD_DIM)
- # k_ptrs: (BLOCK_N, HEAD_DIM)
- k_base = KV + page_base
- k_ptrs = k_base + (n_offs[:, None] * stride_kvs + d_offs[None, :] * stride_kvd)
- k_mask = n_mask[:, None]
- k_page = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.float32)
+ # Unroll PAGE_SIZE tokens
+ for s in tl.range(PAGE_SIZE):
+ global_tok = page_idx * PAGE_SIZE + s
+ valid = global_tok < seq_len
- # Load V page: (PAGE_SIZE, HEAD_DIM)
- v_base = KV + page_base + HEAD_DIM * stride_kvd
- v_ptrs = v_base + (n_offs[:, None] * stride_kvs + d_offs[None, :] * stride_kvd)
- v_page = tl.load(v_ptrs, mask=k_mask, other=0.0).to(tl.float32)
+ # K pointer
+ k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
+ k = tl.load(k_ptr, mask=valid, other=0.0).to(tl.float32)
- # Compute QK^T: (1, BLOCK_N) = q[1, D] @ K[D, BLOCK_N].T
- # Actually q is [D,] and k_page is [BLOCK_N, D]
- # We need q @ K^T = sum(q[d] * K[n, d], d) for each n
- scores = tl.sum(q[None, :] * k_page, axis=1) * SCALE # (BLOCK_N,)
+ # QK dot
+ qk = tl.sum(q * k) * SCALE
- # Mask for tokens beyond seq_len
- valid_mask = (global_start + n_offs) < seq_len
- scores = tl.where(valid_mask, scores, -1e30)
+ # V pointer (K and V contiguous in last dim)
+ v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
+ v = tl.load(v_ptr, mask=valid, other=0.0).to(tl.float32)
- # Online softmax across the BLOCK_N tokens
- # Use tl.max and tl.exp for parallel reduction
- old_max = acc_max
- new_max = tl.maximum(old_max, scores) # (BLOCK_N,) per-element max
-
- # Online softmax must be done sequentially or carefully
- # We need the running max across all BLOCK_N elements
- # This is tricky: we can't just do elementwise online softmax
-
- # Instead: compute max, then compute exp, then sum for this block
- # Update online softmax:
- # For the block, we process BLOCK_N tokens. The max across all of them,
- # then update acc with exp-weighted sum.
-
- # Simplest correct approach: for small BLOCK_N, do sequential
- for s in tl.range(PAGE_SIZE):
- qk = scores[s]
- if s >= PAGE_SIZE:
- tl.static_assert(True) # should not happen, but triton needs this
+ # Online softmax: qk=0 for invalid tokens, exp(0-max) handled
+ # When valid=False: qk=0 from masked load, but we still have nonzero result
+ # Need to set qk=-inf for invalid tokens to avoid affecting softmax
+ qk = tl.where(valid, qk, -1e30)
... (1666 more chars)
{
"content": "<4576 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1 && echo \"---\" && uv run python benchmark.py 2>&1 | grep solution",
"description": "Check correctness and benchmark",
"timeout": 300000
}
PASS
---
shape=0 variant=solution tflops=0.158 gbps=39.546 ms=0.852
shape=0 solution_peak_fraction=0.0220
shape=1 variant=solution tflops=0.653 gbps=163.655 ms=1.643
shape=1 solution_peak_fraction=0.0909
shape=2 variant=solution tflops=0.162 gbps=20.264 ms=3.318
shape=2 solution_peak_fraction=0.0113
shape=3 variant=solution tflops=0.324 gbps=81.129 ms=1.243
shape=3 solution_peak_fraction=0.0451
shape=4 variant=solution tflops=0.042 gbps=10.460 ms=1.569
shape=4 solution_peak_fraction=0.0058
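These numbers are far below peak: the per-token loop issues one small load per token and serializes the softmax update. The page-at-a-time variant that the earlier diff backed out of can be made correct with the standard flash-attention block update; a hedged sketch of the page-loop body, reusing the names already in the kernel and assuming the usual `import triton.language as tl` (`scores` is the (BLOCK_N,) score vector with invalid lanes pre-set to -1e30, `v_page` the (BLOCK_N, HEAD_DIM) value tile with invalid rows loaded as 0):
```python
# Per page, after computing scores and loading v_page:
block_max = tl.max(scores, axis=0)        # max over this page
new_max = tl.maximum(acc_max, block_max)  # running max across pages
rescale = tl.exp(acc_max - new_max)       # rescale the old accumulators
p = tl.exp(scores - new_max)              # (BLOCK_N,) unnormalized weights
acc_num = acc_num * rescale + tl.sum(p[:, None] * v_page, axis=0)
acc_denom = acc_denom * rescale + tl.sum(p, axis=0)
acc_max = new_max
```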
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,104 +1,395 @@
-"""Optimized Triton paged attention decode kernel.
-
-Key optimizations over naive:
-1. Each program handles one (batch, head) but pre-caches page_id at start of each page
-2. Use triton's constexpr PAGE_SIZE loop for full unrolling (16 iterations max)
-3. Reduce memory reads: block_table only read once per page, not per token
-4. Cache page_base in register (avoids recomputing ptr arithmetic per token)
-"""
+"""High-performance CUDA C++ paged attention decode kernel."""
import math
-
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-
def _make_scale(head_dim: int) -> float:
return 1.0 / math.sqrt(head_dim)
-
-@triton.jit
-def _paged_decode_kernel(
- Out,
- Q,
- KV,
- BlockTable,
- SeqLens,
- # strides
- stride_qb, stride_qh, stride_qd,
- stride_kvb, stride_kvs, stride_kvh, stride_kvd,
- stride_ob, stride_oh, stride_od,
- stride_bt_b,
- # constexpr
- HEAD_DIM: tl.constexpr,
- PAGE_SIZE: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
- SCALE: tl.constexpr,
-):
- pid_b = tl.program_id(0)
- pid_h = tl.program_id(1)
-
- pid_hkv = pid_h // GROUP_SIZE
-
- seq_len = tl.load(SeqLens + pid_b)
-
- d_offs = tl.arange(0, HEAD_DIM)
-
- # Load Q: (D,)
- q_ptr = Q + pid_b * stride_qb + pid_h * stride_qh + d_offs * stride_qd
- q = tl.load(q_ptr).to(tl.float32)
-
- bt_base = pid_b * stride_bt_b
-
- num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE
-
- acc_num = tl.zeros([HEAD_DIM], dtype=tl.float32)
- acc_denom = 0.0
- acc_max = -1e30
-
- # Loop over pages - each iteration loads page_id and processes PAGE_SIZE tokens
- for page_idx in tl.range(num_pages):
- # Load page_id ONCE per page (not per token)
- page_id = tl.load(BlockTable + bt_base + page_idx)
- # Cache page_base in register
- page_base = page_id * stride_kvb + pid_hkv * stride_kvh
-
- # Unroll PAGE_SIZE tokens
- for s in tl.range(PAGE_SIZE):
- global_tok = page_idx * PAGE_SIZE + s
- valid = global_tok < seq_len
-
- # K pointer
- k_ptr = KV + page_base + s * stride_kvs + d_offs * stride_kvd
- k = tl.load(k_ptr, mask=valid, other=0.0).to(tl.float32)
-
- # QK dot
- qk = tl.sum(q * k) * SCALE
-
- # V pointer (K and V contiguous in last dim)
- v_ptr = KV + page_base + s * stride_kvs + (d_offs + HEAD_DIM) * stride_kvd
- v = tl.load(v_ptr, mask=valid, other=0.0).to(tl.float32)
-
- # Online softmax: qk=0 for invalid tokens, exp(0-max) handled
- # When valid=False: qk=0 from masked load, but we still have nonzero result
- # Need to set qk=-inf for invalid tokens to avoid affecting softmax
- qk = tl.where(valid, qk, -1e30)
-
- old_max = acc_max
- new_max = tl.maximum(old_max, qk)
- exp_val = tl.exp(qk - new_max)
- rescale = tl.exp(old_max - new_max)
-
- acc_num = acc_num * rescale + v * exp_val
- acc_denom = acc_denom * rescale + exp_val
- acc_max = new_max
-
- out = acc_num / acc_denom
-
- out_ptr = Out + pid_b * stride_ob + pid_h * stride_oh + d_offs * stride_od
- tl.store(out_ptr, out.to(Out.type.element_ty))
-
+# The CUDA source. One thread block per (batch, kv_head). Each thread handles
+# one query head in the GQA group. Shared memory holds K/V page for current iteration.
+
+_CUDA_SRC = r"""
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <math.h>
+#include <stdio.h>
+
+using bfloat = __nv_bfloat16;
+
+__forceinline__ __device__ float bf16_to_float(bfloat v) {
+ return __bfloat162float(v);
+}
+
+__forceinline__ __device__ bfloat float_to_bfloat(float v) {
+ return __float2bfloat16(v);
+}
+
+__global__ void paged_decode_kernel(
+ bfloat* __restrict__ Out,
+ const bfloat* __restrict__ Q,
+ const bfloat* __restrict__ KV,
+ const int* __restrict__ BlockTable,
+ const int* __restrict__ SeqLens,
+ int B, int H, int Hkv, int D, int P, float scale,
+ int stride_qb, int stride_qh, int stride_qd,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
+ int stride_bt_b, int stride_bt_s,
+ int stride_ob, int stride_oh, int stride_od
+) {
+ // grid: (B, Hkv), block: (group_size)
+ extern __shared__ char smem_raw[];
+ bfloat* __restrict__ smem = (bfloat*)smem_raw;
+
+ int pid_b = blockIdx.x;
+ int pid_hkv = blockIdx.y;
+ int group_size = H / Hkv;
+
+ int h_in_group = threadIdx.x;
+ if (h_in_group >= group_size) return;
+
+ int seq_len = SeqLens[pid_b];
+ int num_pages = (seq_len + P - 1) / P;
+
+ // Shared memory layout:
+ // [0: P*D) = K page
+ // [P*D: 2*P*D) = V page
+ int KV_stride = P * D;
+ bfloat* __restrict__ sK = smem;
+ bfloat* __restrict__ sV = smem + KV_stride;
+
+ // Each thread gets its Q vector
+ float q[128];
+ const bfloat* q_base = Q + pid_b * stride_qb + (pid_hkv * group_size + h_in_group) * stride_qh;
+ for (int d = 0; d < D; d++) {
+ q[d] = bf16_to_float(q_base[d]);
+ }
+
+ // Accumulators per thread
+ float acc_num[128] = {0};
+ float acc_denom = 1e-30f;
+ float acc_max = -1e30f;
+
+ // Process each page
+ for (int pg = 0; pg < num_pages; pg++) {
+ int block_id = BlockTable[pid_b * stride_bt_b + pg * stride_bt_s];
+
+ // Load K page into shared memory cooperatively
+ // Total elements: P * D. Divide among group_size threads.
+ int total_k_items = P * D;
+ int items_per_thread = (total_k_items + group_size - 1) / group_size;
+
+ for (int i = 0; i < items_per_thread; i++) {
+ int idx = h_in_group * items_per_thread + i;
+ if (idx >= total_k_i
... (11741 more chars)
{
"content": "<13781 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA solution correctness",
"timeout": 300000
}
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 414, in forward
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_le
... (4056 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -312,7 +312,7 @@
int stride_ob, int stride_oh, int stride_od
) {
int group_size = H / Hkv;
- int smem_size = 2 * P * D * sizeoof(bfloat);
+ int smem_size = 2 * P * D * sizeof(bfloat);
dim3 grid(B, Hkv, 1);
dim3 block(group_size, 1, 1);
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"newString": " int smem_size = 2 * P * D * sizeof(bfloat);",
"oldString": " int smem_size = 2 * P * D * sizeoof(bfloat);"
}
Edit applied successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Test fixed CUDA solution",
"timeout": 300000
}
PASS
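With the typo fixed, the extension builds. For reference, the cpp_sources/cuda_sources split used throughout this session boils down to the following self-contained pattern: declare the launcher and bind it in the C++ source, define the kernel and launcher in the CUDA source (illustrative names; `scale_inplace` is not part of the solution):
```python
from torch.utils.cpp_extension import load_inline

cpp = r"""
#include <torch/extension.h>
void scale_inplace(torch::Tensor t, float s);  // defined in the CUDA source
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("scale_inplace", &scale_inplace, "multiply a float tensor in place");
}
"""

cu = r"""
#include <torch/extension.h>
__global__ void scale_kernel(float* x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}
void scale_inplace(torch::Tensor t, float s) {
    int n = t.numel();
    scale_kernel<<<(n + 255) / 256, 256>>>(t.data_ptr<float>(), s, n);
}
"""

mod = load_inline(name="scale_demo", cpp_sources=cpp, cuda_sources=cu,
                  extra_cuda_cflags=["-O3"], verbose=False)
```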
{
"command": "uv run python benchmark.py 2>&1 | grep -E \"solution|RESULT\"",
"description": "Benchmark CUDA solution",
"timeout": 300000
}
shape=0 variant=solution tflops=0.014 gbps=3.504 ms=9.614
shape=0 solution_peak_fraction=0.0019
shape=1 variant=solution tflops=0.056 gbps=13.902 ms=19.346
shape=1 solution_peak_fraction=0.0077
shape=2 variant=solution tflops=0.017 gbps=2.170 ms=30.986
shape=2 solution_peak_fraction=0.0012
shape=3 variant=solution tflops=0.028 gbps=6.988 ms=14.434
shape=3 solution_peak_fraction=0.0039
shape=4 variant=solution tflops=0.010 gbps=2.482 ms=6.614
shape=4 solution_peak_fraction=0.0014
RESULT: LOW
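One plausible contributor to these low bandwidth numbers: the cooperative loads hand each thread a contiguous chunk (`idx = tid * load_per_thread + i`), so within a warp the 32 threads touch addresses `load_per_thread` elements apart and the accesses do not coalesce. The usual idiom strides the whole block across the array instead; a hedged sketch, assuming a contiguous source span:
```cuda
#include <cuda_bf16.h>

// Block-strided cooperative copy into shared memory: on every iteration,
// consecutive threads read consecutive elements, so each warp's accesses
// coalesce into a few wide transactions.
__device__ void copy_to_smem(__nv_bfloat16* dst, const __nv_bfloat16* src,
                             int n_elems) {
    for (int idx = threadIdx.x; idx < n_elems; idx += blockDim.x)
        dst[idx] = src[idx];
    __syncthreads();
}
```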
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -2,30 +2,39 @@
import math
import torch
import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
def _make_scale(head_dim: int) -> float:
return 1.0 / math.sqrt(head_dim)
-# The CUDA source. One thread block per (batch, kv_head). Each thread handles
-# one query head in the GQA group. Shared memory holds K/V page for current iteration.
-
-_CUDA_SRC = r"""
+_module = None
+
+def _compile():
+ CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <math.h>
-#include <stdio.h>
using bfloat = __nv_bfloat16;
-__forceinline__ __device__ float bf16_to_float(bfloat v) {
- return __bfloat162float(v);
-}
-
-__forceinline__ __device__ bfloat float_to_bfloat(float v) {
- return __float2bfloat16(v);
-}
-
-__global__ void paged_decode_kernel(
+/*
+ * Design: one thread block per (batch, head).
+ * - 128 threads per block (4 warps, each warp handles 32 elements of head_dim)
+ * - Load one K/V page tile into shared memory cooperatively
+ * - Process seq_len tokens, accumulate online softmax
+ * - GQA: all heads sharing a KV head will compete for shared memory bandwidth,
+ * but each computes independently.
+ *
+ * Shared memory layout:
+ * smem[0:P*D] = K page (PAGE_SIZE x HEAD_DIM)
+ * smem[P*D:2*P*D] = V page (PAGE_SIZE x HEAD_DIM)
+ *
+ * Each thread handles D/128 elements of the head_dim vector.
+ */
+#define MAX_HEAD_DIM 128
+#define BLOCK_THREADS 128
+
+__global__ void __launch_bounds__(BLOCK_THREADS) paged_decode_kernel(
bfloat* __restrict__ Out,
const bfloat* __restrict__ Q,
const bfloat* __restrict__ KV,
@@ -37,269 +46,125 @@
int stride_bt_b, int stride_bt_s,
int stride_ob, int stride_oh, int stride_od
) {
- // grid: (B, Hkv), block: (group_size)
+ int pid_b = blockIdx.x;
+ int pid_h = blockIdx.y;
+ int tid = threadIdx.x;
+
+ int seq_len = SeqLens[pid_b];
+ int num_pages = (seq_len + P - 1) / P;
+
+ // Shared memory
extern __shared__ char smem_raw[];
bfloat* __restrict__ smem = (bfloat*)smem_raw;
-
- int pid_b = blockIdx.x;
- int pid_hkv = blockIdx.y;
- int group_size = H / Hkv;
-
- int h_in_group = threadIdx.x;
- if (h_in_group >= group_size) return;
-
- int seq_len = SeqLens[pid_b];
- int num_pages = (seq_len + P - 1) / P;
-
- // Shared memory layout:
- // [0: P*D) = K page
- // [P*D: 2*P*D) = V page
- int KV_stride = P * D;
bfloat* __restrict__ sK = smem;
- bfloat* __restrict__ sV = smem + KV_stride;
-
- // Each thread gets its Q vector
- float q[128];
- const bfloat* q_base = Q + pid_b * stride_qb + (pid_hkv * group_size + h_in_group) * stride_qh;
- for (int d = 0; d < D; d++) {
- q[d] = bf16_to_float(q_base[d]);
+ bfloat* __restrict__ sV = smem + P * D;
+
+ // Each thread owns D/BLOCK_THREADS elements
+ int items_per_thread = (D + BLOCK_THREADS - 1) / BLOCK_THREADS;
+ int start_d = tid * items_per_thread;
+ int end_d = min(start_d + items_per_thread, D);
+ int my_items = end_d - start_d;
+
+ // Load Q elements for this thread's range
+ float q_local[8]; // MAX_HEAD_DIM / BLOCK_THREADS = 1 when 128 threads and D=128
+ if (tid < D) {
+ const bfloat* q_base = Q + pid_b * stride_qb + pid_h * stride_qh;
+ for (int j = 0; j < my_items; j++)
+ q_local[j] = __bfloat162float(q_base[start_d + j]);
}
- // Accumulators per thread
- float acc_num[128] = {0};
+ // Accumulators for this thread's range
+ float acc_num_local[8] = {0};
float acc_denom = 1e-30f;
float acc_max = -1e30f;
// Process each page
- for (int pg = 0; pg < num_pages; pg++) {
- int block_id = BlockTable[pid_b * stride_bt_b + pg * stride_bt_s];
-
- // Load K page into shared memory cooperatively
- // Total elements: P * D. Divide among group_size threads.
- int total_k_items = P * D;
- int items_per_thread = (total_k_items + group_size - 1) / group_size;
-
- for (int i = 0; i < items_per_thread; i++) {
- int idx = h_in_group * items_per_thread + i;
- if (idx >= total_k_items) break;
-
+ for (int pg_idx = 0; pg_idx < num_pages; pg_idx++) {
+ int page_id = BlockTable[pid_b * stride_bt_b + pg_idx * stride_bt_s];
+
+ // Load K page into smem cooperatively
+ // Each thread loads P*D/BLOCK_THREADS elements
+ int total_kv = P * D;
+ int load_per_thread = (total_kv + BLOCK_THREADS - 1) / BLOCK_THREADS;
+ for (int i = 0; i < load_per_thread; i++) {
+ int idx = tid * load_per_thread + i;
+ if (idx >= total_kv) break;
int row = idx / D;
int col = idx % D;
-
- // K is at KV[block_id, row, pid_hkv, col]
- // offset = block_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh + col * stride_kvd
- int k_off = block_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh + col * stride_kvd;
- // V is at offset + D (contiguous in last dim)
- int v_off = k_off + D;
-
- sK[row * D + col] = KV[k_off];
- sV[row * D + col] = KV[v_off];
+ int base_off = page_id * stride_kvb + row * stride_kvs + (KV + 0) == 0 ? 0 : 0;
+ int k_off = page_id * stride_kvb + row * stride_kvs + (KV + 0 - KV) * stride_kvd + col * stride_kvd;
+ // Fix: proper offset computation
+ int kv_row_offset = page_id * stride_kvb + row * stride_kvs;
+ int k_idx = kv_row_offset + col * stride_kvd;
+ int v_idx = k_idx + D * stride_kvd;
+ sK[row * D + col] = KV[k_idx];
+ sV[row * D + col] = KV[v_idx];
}
__syncthreads();
- // Each thread computes attention for its head
- // Process ea
... (10907 more chars)
{
"content": "<9704 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v2 solution",
"timeout": 300000
}
PASS
{
"command": "uv run python benchmark.py 2>&1 | grep -E \"solution|RESULT\"",
"description": "Benchmark CUDA v2",
"timeout": 300000
}
shape=0 variant=solution tflops=0.204 gbps=51.097 ms=0.659
shape=0 solution_peak_fraction=0.0284
shape=1 variant=solution tflops=0.561 gbps=140.485 ms=1.915
shape=1 solution_peak_fraction=0.0780
shape=2 variant=solution tflops=0.205 gbps=25.731 ms=2.613
shape=2 solution_peak_fraction=0.0143
shape=3 variant=solution tflops=0.349 gbps=87.526 ms=1.152
shape=3 solution_peak_fraction=0.0486
shape=4 variant=solution tflops=0.060 gbps=15.028 ms=1.092
shape=4 solution_peak_fraction=0.0083
RESULT: LOW
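A note on the diff that follows: it threads a `smem_KV` pointer through the kernel signature with the comment "externally allocated shared memory", but CUDA dynamic shared memory is never passed as a kernel argument. If that is the intent, the idiom is an in-kernel `extern __shared__` declaration sized by the third launch-configuration parameter; a minimal reminder sketch:
```cuda
#include <cuda_bf16.h>

__global__ void kernel_with_dynamic_smem(/* ... */) {
    // Sized at launch, not by a parameter:
    extern __shared__ char smem_raw[];
    __nv_bfloat16* sK = reinterpret_cast<__nv_bfloat16*>(smem_raw);
    (void)sK;  // ... use sK ...
}

// Launch: the third <<<>>> argument is the dynamic smem size in bytes.
// kernel_with_dynamic_smem<<<grid, block, smem_bytes>>>();
```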
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,4 +1,12 @@
-"""High-performance CUDA C++ paged attention decode kernel."""
+"""Optimized CUDA C++ paged attention decode kernel.
+
+Design:
+- One thread block per (batch, kv_head)
+- 128 threads cooperate: load K/V page into shared memory, then each thread
+ computes its query head's attention output
+- Shared memory holds one K/V page (PAGE_SIZE x HEAD_DIM each)
+- All threads share the K/V load, reducing bandwidth
+"""
import math
import torch
import torch.nn as nn
@@ -18,173 +26,121 @@
using bfloat = __nv_bfloat16;
/*
- * Design: one thread block per (batch, head).
- * - 128 threads per block (4 warps, each warp handles 32 elements of head_dim)
- * - Load one K/V page tile into shared memory cooperatively
- * - Process seq_len tokens, accumulate online softmax
- * - GQA: all heads sharing a KV head will compete for shared memory bandwidth,
- * but each computes independently.
+ * Block layout: 128 threads per (batch, kv_head).
+ * First GROUP_SIZE threads each handle one query head's output.
+ * All 128 threads cooperate to load K/V into shared memory.
*
- * Shared memory layout:
- * smem[0:P*D] = K page (PAGE_SIZE x HEAD_DIM)
- * smem[P*D:2*P*D] = V page (PAGE_SIZE x HEAD_DIM)
- *
- * Each thread handles D/128 elements of the head_dim vector.
+ * Shared memory: sK[PAGE_SIZE x HEAD_DIM], sV[PAGE_SIZE x HEAD_DIM]
+ * Each thread loads a portion of the page.
*/
-#define MAX_HEAD_DIM 128
-#define BLOCK_THREADS 128
-
-__global__ void __launch_bounds__(BLOCK_THREADS) paged_decode_kernel(
+__global__ void __launch_bounds__(128) paged_decode_kernel(
bfloat* __restrict__ Out,
const bfloat* __restrict__ Q,
const bfloat* __restrict__ KV,
const int* __restrict__ BlockTable,
const int* __restrict__ SeqLens,
- int B, int H, int Hkv, int D, int P, float scale,
- int stride_qb, int stride_qh, int stride_qd,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
- int stride_bt_b, int stride_bt_s,
- int stride_ob, int stride_oh, int stride_od
+ int B, int H, int Hkv, int D, int P, float scale
) {
+ int group_size = H / Hkv;
+
int pid_b = blockIdx.x;
- int pid_h = blockIdx.y;
- int tid = threadIdx.x;
+ int pid_hkv = blockIdx.y;
+
+ int tid = threadIdx.x; // 0..127
+ int num_threads = 128;
+
+ // Each thread loads a portion of K/V page
+ // Total elements per page: P * D (K) + P * D (V) = 2 * P * D
+ int total_kv = 2 * P * D;
+ int load_per_thread = (total_kv + num_threads - 1) / num_threads;
+
+ // Shared memory
+ extern __shared__ char smem_raw[];
+ bfloat* __restrict__ sK = (bfloat*)smem_raw;
+ bfloat* __restrict__ sV = sK + P * D;
int seq_len = SeqLens[pid_b];
int num_pages = (seq_len + P - 1) / P;
- // Shared memory
- extern __shared__ char smem_raw[];
- bfloat* __restrict__ smem = (bfloat*)smem_raw;
- bfloat* __restrict__ sK = smem;
- bfloat* __restrict__ sV = smem + P * D;
-
- // Each thread owns D/BLOCK_THREADS elements
- int items_per_thread = (D + BLOCK_THREADS - 1) / BLOCK_THREADS;
- int start_d = tid * items_per_thread;
- int end_d = min(start_d + items_per_thread, D);
- int my_items = end_d - start_d;
-
- // Load Q elements for this thread's range
- float q_local[8]; // MAX_HEAD_DIM / BLOCK_THREADS = 1 when 128 threads and D=128
- if (tid < D) {
- const bfloat* q_base = Q + pid_b * stride_qb + pid_h * stride_qh;
- for (int j = 0; j < my_items; j++)
- q_local[j] = __bfloat162float(q_base[start_d + j]);
+ // Thread tid < group_size handles query head pid_hkv*group_size+tid
+ bool is_worker = (tid < group_size);
+
+ // Load Q for this thread's head
+ float q_local[128];
+ if (is_worker) {
+ int h = pid_hkv * group_size + tid;
+ const bfloat* q_base = Q + pid_b * H * D + h * D; // contiguous layout
+ for (int j = 0; j < D; j++)
+ q_local[j] = __bfloat162float(q_base[j]);
}
- // Accumulators for this thread's range
- float acc_num_local[8] = {0};
+ // Each thread accumulates for its own head's range
+ float acc_num[128];
+ for (int j = 0; j < 128; j++) acc_num[j] = 0.0f;
float acc_denom = 1e-30f;
float acc_max = -1e30f;
- // Process each page
- for (int pg_idx = 0; pg_idx < num_pages; pg_idx++) {
- int page_id = BlockTable[pid_b * stride_bt_b + pg_idx * stride_bt_s];
-
- // Load K page into smem cooperatively
- // Each thread loads P*D/BLOCK_THREADS elements
- int total_kv = P * D;
- int load_per_thread = (total_kv + BLOCK_THREADS - 1) / BLOCK_THREADS;
+ // Process pages
+ for (int pg = 0; pg < num_pages; pg++) {
+ int page_id = BlockTable[pid_b * group_size + pg]; // max_blocks per batch
+
+ // Cooperative load of K/V page into shared memory
for (int i = 0; i < load_per_thread; i++) {
int idx = tid * load_per_thread + i;
if (idx >= total_kv) break;
- int row = idx / D;
- int col = idx % D;
- int base_off = page_id * stride_kvb + row * stride_kvs + (KV + 0) == 0 ? 0 : 0;
- int k_off = page_id * stride_kvb + row * stride_kvs + (KV + 0 - KV) * stride_kvd + col * stride_kvd;
- // Fix: proper offset computation
- int kv_row_offset = page_id * stride_kvb + row * stride_kvs;
- int k_idx = kv_row_offset + col * stride_kvd;
- int v_idx = k_idx + D * stride_kvd;
- sK[row * D + col] = KV[k_idx];
- sV[row * D + col] = KV[v_idx];
+
+ int slot = idx % D;
+ int tmp = idx / D;
+ int row = tmp % P;
+ int kv_part = tmp / P; // 0=K, 1=V
+
+ int base_off = page_id * P * Hkv * 2 * D + row * H
... (6484 more chars)
{
"content": "<7441 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v3",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 227, in forward
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_lens, self.scale)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 205, in _get_module
_module = _compile()
^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 194, in _compile
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^
... (969 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,12 +1,4 @@
-"""Optimized CUDA C++ paged attention decode kernel.
-
-Design:
-- One thread block per (batch, kv_head)
-- 128 threads cooperate: load K/V page into shared memory, then each thread
- computes its query head's attention output
-- Shared memory holds one K/V page (PAGE_SIZE x HEAD_DIM each)
-- All threads share the K/V load, reducing bandwidth
-"""
+"""Optimized CUDA C++ paged attention decode kernel."""
import math
import torch
import torch.nn as nn
@@ -26,96 +18,96 @@
using bfloat = __nv_bfloat16;
/*
- * Block layout: 128 threads per (batch, kv_head).
- * First GROUP_SIZE threads each handle one query head's output.
- * All 128 threads cooperate to load K/V into shared memory.
- *
- * Shared memory: sK[PAGE_SIZE x HEAD_DIM], sV[PAGE_SIZE x HEAD_DIM]
- * Each thread loads a portion of the page.
+ * One thread block per (batch, kv_head).
+ * 128 threads: all load K/V page cooperatively, first GROUP_SIZE threads
+ * each compute one query head's output.
+ *
+ * KV cache: (num_blocks, page_size, num_kv_heads, 2*head_dim) bf16
+ * contiguous in last dim: [K_0,...,K_{D-1}, V_0,...,V_{D-1}]
*/
-__global__ void __launch_bounds__(128) paged_decode_kernel(
+
+__global__ void paged_decode_kernel(
bfloat* __restrict__ Out,
const bfloat* __restrict__ Q,
const bfloat* __restrict__ KV,
const int* __restrict__ BlockTable,
const int* __restrict__ SeqLens,
- int B, int H, int Hkv, int D, int P, float scale
+ int B, int H, int Hkv, int D, int P, float scale,
+ int max_blocks_per_batch,
+ bfloat* __restrict__ smem_KV // externally allocated shared memory
) {
int group_size = H / Hkv;
-
int pid_b = blockIdx.x;
int pid_hkv = blockIdx.y;
-
- int tid = threadIdx.x; // 0..127
- int num_threads = 128;
-
- // Each thread loads a portion of K/V page
- // Total elements per page: P * D (K) + P * D (V) = 2 * P * D
+ int tid = threadIdx.x;
+ int num_threads = blockDim.x;
+
+ int seq_len = SeqLens[pid_b];
+ int num_pages = (seq_len + P - 1) / P;
+
+ // Shared memory pointers
+ bfloat* __restrict__ sK = smem_KV;
+ bfloat* __restrict__ sV = smem_KV + P * D;
+
+ // Each thread loads portion of K/V page (combined K+V = 2*P*D elements)
int total_kv = 2 * P * D;
int load_per_thread = (total_kv + num_threads - 1) / num_threads;
- // Shared memory
- extern __shared__ char smem_raw[];
- bfloat* __restrict__ sK = (bfloat*)smem_raw;
- bfloat* __restrict__ sV = sK + P * D;
-
- int seq_len = SeqLens[pid_b];
- int num_pages = (seq_len + P - 1) / P;
-
- // Thread tid < group_size handles query head pid_hkv*group_size+tid
+ // Worker threads compute attention (one per query head in GQA group)
bool is_worker = (tid < group_size);
- // Load Q for this thread's head
- float q_local[128];
+ // Load Q for worker threads
+ float q_vec[128];
if (is_worker) {
int h = pid_hkv * group_size + tid;
- const bfloat* q_base = Q + pid_b * H * D + h * D; // contiguous layout
+ const bfloat* q_base = Q + pid_b * stride_qb + h * stride_qh;
for (int j = 0; j < D; j++)
- q_local[j] = __bfloat162float(q_base[j]);
- }
-
- // Each thread accumulates for its own head's range
+ q_vec[j] = __bfloat162float(q_base[j * stride_qd]);
+ }
+
+ // Local accumulators
float acc_num[128];
- for (int j = 0; j < 128; j++) acc_num[j] = 0.0f;
+ for (int j = 0; j < D; j++) acc_num[j] = 0.0f;
float acc_denom = 1e-30f;
float acc_max = -1e30f;
// Process pages
for (int pg = 0; pg < num_pages; pg++) {
- int page_id = BlockTable[pid_b * group_size + pg]; // max_blocks per batch
-
- // Cooperative load of K/V page into shared memory
+ int page_id = BlockTable[pid_b * max_blocks_per_batch + pg];
+
+ // Load K/V page cooperatively into shared memory
for (int i = 0; i < load_per_thread; i++) {
int idx = tid * load_per_thread + i;
if (idx >= total_kv) break;
- int slot = idx % D;
+ int col = idx % D;
int tmp = idx / D;
int row = tmp % P;
- int kv_part = tmp / P; // 0=K, 1=V
-
- int base_off = page_id * P * Hkv * 2 * D + row * Hkv * 2 * D + pid_hkv * 2 * D;
- // KV cache: [page_id, row, pid_hkv, slot] for K, [page_id, row, pid_hkv, slot] for V
-
- if (kv_part == 0) {
- sK[row * D + slot] = KV[base_off + slot];
- } else {
- sV[row * D + slot] = KV[base_off + D + slot];
- }
- }
- __syncthreads();
-
- // Each processing thread computes QK^T for its head using shared K/V
+ int kv_half = tmp / P; // 0=K, 1=V
+
+ // KV offset: page_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh + col * stride_kvd
+ // K is at col, V is at col + D
+ int base = page_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh;
+ int k_idx = base + col * stride_kvd;
+ int v_idx = base + (D + col) * stride_kvd;
+
+ // Store into shared memory: row-major, D columns
+ if (kv_half == 0)
+ sK[row * D + col] = KV[k_idx];
+ else
+ sV[row * D + col] = KV[v_idx];
+ }
+ __syncthreads();
+
if (is_worker) {
- int tokens_in_page = min((pg + 1) * P, seq_len) - pg * P;
- if (tokens_in_page < 0) tokens_in_page = 0;
-
- for (int t = 0; t < tokens_in_page; t++) {
- // QK^T
+ int tokens_in = min((pg + 1) * P, seq_len) - pg * P;
+ if (tokens_in < 0) tokens_in = 0;
+
+ for (int t = 0; t < tokens_in; t++) {
flo
... (12611 more chars){
"content": "<16212 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v3 solution",
"timeout": 300000
}PASS
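The v3 PASS comes from two fixes visible in the diff above: BlockTable rows are indexed with the table's row stride (max_blocks_per_batch) instead of group_size, and K/V offsets are built from the tensor strides handed in from PyTorch. A standalone sketch of that addressing, with a hypothetical helper name (the kernel inlines this arithmetic and uses int offsets; long is used here to sidestep overflow on very large caches; compiles with nvcc):

#include <cstdio>

// Assumed layout, per the diff: KV is (num_blocks, page_size, num_kv_heads,
// 2*head_dim), K in the first D slots of the last dim, V in the next D.
// All strides are in elements, as produced by tensor.stride() in the wrapper.
__host__ __device__ inline long kv_elem_offset(
    long page_id, long row, long kv_head, long col, bool is_value,
    long stride_kvb, long stride_kvs, long stride_kvh, long stride_kvd, long D)
{
    long base = page_id * stride_kvb + row * stride_kvs + kv_head * stride_kvh;
    return base + (is_value ? D + col : col) * stride_kvd;
}

int main() {
    // Contiguous cache with P=16, Hkv=8, D=128: element strides fall out as
    long D = 128, skvd = 1, skvh = 2 * D, skvs = 8 * 2 * D, skvb = 16 * 8 * 2 * D;
    printf("K[page 3, row 5, head 2, col 7] -> element %ld\n",
           kv_elem_offset(3, 5, 2, 7, false, skvb, skvs, skvh, skvd, D));
    printf("V[page 3, row 5, head 2, col 7] -> element %ld\n",
           kv_elem_offset(3, 5, 2, 7, true, skvb, skvs, skvh, skvd, D));
    return 0;
}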
{
"command": "uv run python benchmark.py 2>&1 | grep -E \"solution|RESULT\"",
"description": "Benchmark CUDA v3",
"timeout": 300000
}shape=0 variant=solution tflops=0.024 gbps=5.988 ms=5.625
shape=0 solution_peak_fraction=0.0033
shape=1 variant=solution tflops=0.089 gbps=22.286 ms=12.069
shape=1 solution_peak_fraction=0.0124
shape=2 variant=solution tflops=0.024 gbps=2.994 ms=22.458
shape=2 solution_peak_fraction=0.0017
shape=3 variant=solution tflops=0.048 gbps=11.932 ms=8.453
shape=3 solution_peak_fraction=0.0066
shape=4 variant=solution tflops=0.022 gbps=5.459 ms=3.007
shape=4 solution_peak_fraction=0.0030
RESULT: LOW
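The gbps column can be sanity-checked by hand: one decode step must stream the whole KV cache once, so shape 0 moves about B*L*Hkv*2D bf16 elements. A back-of-envelope check (plain C, compiles with any C compiler or nvcc; the numbers are taken from the benchmark output above, so treat it as an estimate rather than the benchmark's exact accounting):

#include <stdio.h>

int main(void) {
    long long B = 8, L = 1024, Hkv = 8, D = 128;   // shape 0
    long long bytes = B * L * Hkv * 2 * D * 2;     // K+V, bf16 = 2 bytes each
    double ms = 5.625;                             // measured time, shape 0
    printf("KV bytes per step: %lld (%.1f MB)\n", bytes, bytes / 1e6);
    printf("effective BW     : %.2f GB/s\n", bytes / 1e6 / ms);
    return 0;
}

This prints about 33.6 MB and 5.97 GB/s, matching gbps=5.988: a tiny fraction of HBM peak, consistent with RESULT: LOW.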
{
"command": "<836 chars \u2014 see below>",
"description": "Debug CUDA v3 values",
"timeout": 120000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref(*inputs)
sol_out = model(*inputs)
print('sol shape:', sol_out.shape, ref_out.shape)
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
print('atol 0.02:', torch.allclose(sol_out.float(), ref_out.float(), atol=0.02))
# Print some values
print('ref[0,0,:5]:', ref_out[0,0,:5])
print('sol[0,0,:5]:', sol_out[0,0,:5])
" 2>&1sol shape: torch.Size([8, 32, 128]) torch.Size([8, 32, 128])
max abs diff: 0.01715087890625
atol 0.02: True
ref[0,0,:5]: tensor([ 0.0003, -0.0004, -0.0005, 0.0011, 0.0036], device='cuda:0',
dtype=torch.bfloat16)
sol[0,0,:5]: tensor([ 0.0003, -0.0004, -0.0005, 0.0011, 0.0036], device='cuda:0',
dtype=torch.bfloat16)
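A max abs diff of about 0.017 is roughly what bf16 storage alone costs. bf16 keeps 8 significant bits, so a single rounding step moves a value in [4, 8) by up to 0.015625 and one in [8, 16) by up to 0.03125; assuming the largest outputs are of that magnitude, atol=0.02 is a sensible bar. For reference, a sketch of the round-to-nearest-even f32-to-bf16 conversion that __float2bfloat16 performs (NaN/Inf handling omitted):

#include <stdint.h>
#include <string.h>
#include <stdio.h>

static float bf16_round(float x) {
    uint32_t u; memcpy(&u, &x, 4);
    u += 0x7FFFu + ((u >> 16) & 1u);   // round-to-nearest-even on top 16 bits
    u &= 0xFFFF0000u;                  // keep sign + exponent + 7 mantissa bits
    float y; memcpy(&y, &u, 4);
    return y;
}

int main(void) {
    float x = 4.2137f;
    printf("%g -> %g (err %g)\n", x, bf16_round(x), bf16_round(x) - x);
    return 0;
}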
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,4 +1,15 @@
-"""Optimized CUDA C++ paged attention decode kernel."""
+"""Optimized CUDA C++ paged attention decode kernel.
+
+Architecture:
+- One block per (batch, kv_head), 128 threads
+- All 128 threads cooperatively load K/V page into shared memory
+- First GROUP_SIZE threads compute attention for their query heads
+- Correct batch indexing for BlockTable (stride=1 within each batch's rows)
+
+The key performance insight: decode is memory-bound since KV cache must be streamed.
+By loading each page once into shared memory and sharing across the GQA group,
+we avoid redundant loads that would happen with per-head programs.
+"""
import math
import torch
import torch.nn as nn
@@ -17,274 +28,6 @@
using bfloat = __nv_bfloat16;
-/*
- * One thread block per (batch, kv_head).
- * 128 threads: all load K/V page cooperatively, first GROUP_SIZE threads
- * each compute one query head's output.
- *
- * KV cache: (num_blocks, page_size, num_kv_heads, 2*head_dim) bf16
- * contiguous in last dim: [K_0,...,K_{D-1}, V_0,...,V_{D-1}]
- */
-
-__global__ void paged_decode_kernel(
- bfloat* __restrict__ Out,
- const bfloat* __restrict__ Q,
- const bfloat* __restrict__ KV,
- const int* __restrict__ BlockTable,
- const int* __restrict__ SeqLens,
- int B, int H, int Hkv, int D, int P, float scale,
- int max_blocks_per_batch,
- bfloat* __restrict__ smem_KV // externally allocated shared memory
-) {
- int group_size = H / Hkv;
- int pid_b = blockIdx.x;
- int pid_hkv = blockIdx.y;
- int tid = threadIdx.x;
- int num_threads = blockDim.x;
-
- int seq_len = SeqLens[pid_b];
- int num_pages = (seq_len + P - 1) / P;
-
- // Shared memory pointers
- bfloat* __restrict__ sK = smem_KV;
- bfloat* __restrict__ sV = smem_KV + P * D;
-
- // Each thread loads portion of K/V page (combined K+V = 2*P*D elements)
- int total_kv = 2 * P * D;
- int load_per_thread = (total_kv + num_threads - 1) / num_threads;
-
- // Worker threads compute attention (one per query head in GQA group)
- bool is_worker = (tid < group_size);
-
- // Load Q for worker threads
- float q_vec[128];
- if (is_worker) {
- int h = pid_hkv * group_size + tid;
- const bfloat* q_base = Q + pid_b * stride_qb + h * stride_qh;
- for (int j = 0; j < D; j++)
- q_vec[j] = __bfloat162float(q_base[j * stride_qd]);
- }
-
- // Local accumulators
- float acc_num[128];
- for (int j = 0; j < D; j++) acc_num[j] = 0.0f;
- float acc_denom = 1e-30f;
- float acc_max = -1e30f;
-
- // Process pages
- for (int pg = 0; pg < num_pages; pg++) {
- int page_id = BlockTable[pid_b * max_blocks_per_batch + pg];
-
- // Load K/V page cooperatively into shared memory
- for (int i = 0; i < load_per_thread; i++) {
- int idx = tid * load_per_thread + i;
- if (idx >= total_kv) break;
-
- int col = idx % D;
- int tmp = idx / D;
- int row = tmp % P;
- int kv_half = tmp / P; // 0=K, 1=V
-
- // KV offset: page_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh + col * stride_kvd
- // K is at col, V is at col + D
- int base = page_id * stride_kvb + row * stride_kvs + pid_hkv * stride_kvh;
- int k_idx = base + col * stride_kvd;
- int v_idx = base + (D + col) * stride_kvd;
-
- // Store into shared memory: row-major, D columns
- if (kv_half == 0)
- sK[row * D + col] = KV[k_idx];
- else
- sV[row * D + col] = KV[v_idx];
- }
- __syncthreads();
-
- if (is_worker) {
- int tokens_in = min((pg + 1) * P, seq_len) - pg * P;
- if (tokens_in < 0) tokens_in = 0;
-
- for (int t = 0; t < tokens_in; t++) {
- float qk = 0.0f;
- const bfloat* kr = sK + t * D;
- for (int j = 0; j < D; j++)
- qk += q_vec[j] * __bfloat162float(kr[j]);
- qk *= scale;
-
- float old_max = acc_max;
- float new_max = fmaxf(old_max, qk);
- float exp_v = __expf(qk - new_max);
- float rescale = __expf(old_max - new_max);
-
- const bfloat* vr = sV + t * D;
- for (int j = 0; j < D; j++)
- acc_num[j] = acc_num[j] * rescale + exp_v * __bfloat162float(vr[j]);
- acc_denom = acc_denom * rescale + exp_v;
- acc_max = new_max;
- }
- }
- __syncthreads();
- }
-
- // Write output
- if (is_worker) {
- int h = pid_hkv * group_size + tid;
- bfloat* out_ptr = Out + pid_b * stride_ob + h * stride_oh;
- float inv = 1.0f / acc_denom;
- for (int j = 0; j < D; j++)
- out_ptr[j * stride_od] = __float2bfloat16(acc_num[j] * inv);
- }
-}
-
-// We need to pass strides from Python, so we wrap the kernel launch
-extern "C" void launch_paged_decode(
- bfloat* Out,
- const bfloat* Q,
- const bfloat* KV,
- const int* BlockTable,
- const int* SeqLens,
- int B, int H, int Hkv, int D, int P, float scale,
- int max_blocks_per_batch,
- int smem_size_bytes,
- int stride_qb, int stride_qh, int stride_qd,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvd,
- int stride_ob, int stride_oh, int stride_od
-) {
- // We can't pass shared memory size dynamically from a C function
- // without modifying the kernel. Let me just use a kernel template
- // or pass smem as a kernel parameter.
- // Actually, the kernel above already uses a separate smem_KV pointer.
- // We need to allocate it via cudaMalloc or use dynamic shared memory.
-}
-"""
- # The smem approach requires k
... (13157 more chars){
"content": "<4628 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"offset": 87,
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
87: // (caller does: BlockTable + pid_b * stride_bt_b)
88: // Wait, we can't do that easily. Let me rework: use max_blocks from blockTable shape
89: // passed via kernel params.
90:
91: // Simpler: just pass max_blocks and compute index
92: // But we don't have max_blocks here. Let me pass the page_ids as a flattened array
93: // at the correct offset. Actually, let me just pass the batch's row:
94: // BlockTable_batch = BlockTable + pid_b * max_blocks_per_batch
95:
96: // I'll compute max_blocks from the block_table tensor shape in the launcher
97: // and pass (BlockTable_batch, page_id = BlockTable_batch[pg]).
98: // But we can't do this at the kernel level easily.
99: //
100: // EASIEST FIX: Pass max_blocks_per_batch as a kernel parameter
101: // and compute BlockTable_batch = BlockTable + blockIdx.x * max_blocks
102:
103: // But that changes the kernel signature. Let me keep it simple:
104: // The wrapper computes BlockTable + pid_b * max_blks and passes a pointer
105: // that's already offset. Since the wrapper is in C++, this is feasible.
106:
107: // For the kernel, assume BlockTable points to THIS batch's entry.
108: // The launcher wraps: launches B*Hkv blocks, each passing BlockTable + b*max_blks
109: // But that requires dynamic grid with different par
... (570 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,15 +1,4 @@
-"""Optimized CUDA C++ paged attention decode kernel.
-
-Architecture:
-- One block per (batch, kv_head), 128 threads
-- All 128 threads cooperatively load K/V page into shared memory
-- First GROUP_SIZE threads compute attention for their query heads
-- Correct batch indexing for BlockTable (stride=1 within each batch's rows)
-
-The key performance insight: decode is memory-bound since KV cache must be streamed.
-By loading each page once into shared memory and sharing across the GQA group,
-we avoid redundant loads that would happen with per-head programs.
-"""
+"""Optimized CUDA C++ paged attention decode kernel."""
import math
import torch
import torch.nn as nn
@@ -37,8 +26,10 @@
int B, int H, int Hkv, int D, int P, float scale,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
- int stride_ob, int stride_oo
+ int stride_ob, int stride_oo,
+ int stride_bt_row
) {
+ // Grid: (B, Hkv), Block: (128 threads)
int BLOCK_THREADS = blockDim.x;
int GROUP_SIZE = H / Hkv;
@@ -55,7 +46,7 @@
bool is_worker = (tid < GROUP_SIZE);
- // Load Q for this thread's head
+ // Load Q
float q[256];
if (is_worker) {
int h = pid_hkv * GROUP_SIZE + tid;
@@ -70,49 +61,191 @@
float acc_denom = 1e-30f;
float acc_max = -1e30f;
- // Cooperative K/V load
+ // K/V page load distribution
int total_items = P * D;
int per_thread = (total_items + BLOCK_THREADS - 1) / BLOCK_THREADS;
- // BlockTable is (B, max_blocks), row-major. Entry [b, pg] is at:
- // BlockTable + b * max_blocks + pg
- // But we need max_blocks stride from Python.
- // Since each row is contiguous, stride_bt is 1 within a row.
- // We need the batch stride. The kernel receives BlockTable already offset to this batch
- // via the caller passing (BlockTable + pid_b * stride_bt_b).
- // Actually, let's compute it: BlockTable is (B, max_blocks), so stride_bt_b = max_blocks.
- // Let's just pass max_blocks from the kernel launch wrapper.
+ for (int pg = 0; pg < num_pages; pg++) {
+ int page_id = BlockTable[pid_b * stride_bt_row + pg];
+
+ // Cooperative load of K/V page into shared memory
+ for (int i = 0; i < per_thread; i++) {
+ int idx = tid * per_thread + i;
+ if (idx >= total_items) break;
+ int row = idx / D;
+ int col = idx % D;
+
+ int base = page_id * stride_kvb + row * stride_kvs + col * stride_kvo;
+ sK[row * D + col] = KV[base];
+ sV[row * D + col] = KV[base + D];
+ }
+ __syncthreads();
+
+ // Compute attention
+ if (is_worker) {
+ int tok_end = min((pg + 1) * P, seq_len);
+ int tok_start = pg * P;
+ for (int t = tok_start; t < tok_end; t++) {
+ int local_t = t - tok_start;
+ float qk = 0.0f;
+ const bfloat* kr = sK + local_t * D;
+ for (int j = 0; j < D; j++)
+ qk += q[j] * __bfloat162float(kr[j]);
+ qk *= scale;
+
+ float old_max = acc_max;
+ float new_max = fmaxf(old_max, qk);
+ float exp_v = __expf(qk - new_max);
+ float rescale = __expf(old_max - new_max);
+
+ const bfloat* vr = sV + local_t * D;
+ for (int j = 0; j < D; j++)
+ acc_num[j] = acc_num[j] * rescale + exp_v * __bfloat162float(vr[j]);
+ acc_denom = acc_denom * rescale + exp_v;
+ acc_max = new_max;
+ }
+ }
+ __syncthreads();
+ }
+
+ // Write output
+ if (is_worker) {
+ int h = pid_hkv * GROUP_SIZE + tid;
+ bfloat* ob = Out + pid_b * stride_ob + h * stride_oo;
+ float inv = 1.0f / acc_denom;
+ for (int j = 0; j < D; j++)
+ ob[j] = __float2bfloat16(acc_num[j] * inv);
+ }
+}
+
+// Wrapper: called from PyTorch
+void launch_paged_decode(
+ bfloat* Out, const bfloat* Q, const bfloat* KV,
+ const int* BlockTable, const int* SeqLens,
+ int B, int H, int Hkv, int D, int P, float scale,
+ int stride_bt_row,
+ int stride_qb, int stride_qo,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
+ int stride_ob, int stride_oo
+) {
+ int smem_bytes = 2 * P * D * sizeof(bfloat);
+ dim3 grid(B, Hkv);
+ dim3 block(128, 1, 1);
+ cudaStream_t stream = 0;
+
+ // Use the non-extern-shared version by passing through the wrapper
+ // We need to call the kernel with dynamic smem
+ // In CUDA: kernel<<<grid, block, smem_bytes, stream>>>
- // For now, assume BlockTable passed is already offset to this batch
- // (caller does: BlockTable + pid_b * stride_bt_b)
- // Wait, we can't do that easily. Let me rework: use max_blocks from blockTable shape
- // passed via kernel params.
+ // Since the kernel uses extern __shared__ and we launch from here,
+ // we can pass the smem size:
+ paged_decode_kernel<<<grid, block, smem_bytes, stream>>>(
+ Out, Q, KV, BlockTable, SeqLens,
+ B, H, Hkv, D, P, scale,
+ stride_qb, stride_qo,
+ stride_kvb, stride_kvs, stride_kvh, stride_kvo,
+ stride_ob, stride_oo, stride_bt_row);
+}
+"""
+
+ CPP_SRC = r"""
+#include <torch/extension.h>
+#include <cuda_bf16.h>
+
+void launch_paged_decode(
+ __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
+ const int* BlockTable, const int* SeqLens,
+ int B, int H, int Hkv, int D, int P, float scale,
+ int stride_bt_row,
+ int stride_qb, int stride_qo,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
+ int stride_ob, int stride_oo);
+
+void paged_attention_
... (4398 more chars){
"content": "<7883 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v4",
"timeout": 300000
}PASS
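The dynamic shared-memory question the deleted comments agonize over has a one-line answer, which the wrapper above uses: with extern __shared__, the byte count is simply the third launch-configuration argument. A minimal sketch with a hypothetical demo_kernel (for P=16, D=128 the two bf16 pages total only 8 KiB, far below the 48 KiB default limit, so no opt-in attribute is needed here):

#include <cuda_runtime.h>
#include <cuda_bf16.h>

__global__ void demo_kernel(int P, int D) {
    extern __shared__ __nv_bfloat16 smem[];   // 2*P*D elements: sK then sV
    __nv_bfloat16* sK = smem;
    __nv_bfloat16* sV = smem + P * D;
    (void)sK; (void)sV;                       // the real kernel works here
}

int main() {
    int P = 16, D = 128;
    size_t smem_bytes = 2ull * P * D * sizeof(__nv_bfloat16);   // 8192 bytes
    // Above 48 KiB per block an explicit opt-in would be required:
    // cudaFuncSetAttribute(demo_kernel,
    //     cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_bytes);
    demo_kernel<<<dim3(1), dim3(128), smem_bytes>>>(P, D);
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}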
{
"command": "uv run python benchmark.py 2>&1 | grep -E \"solution|RESULT\"",
"description": "Benchmark CUDA v4",
"timeout": 300000
}shape=0 variant=solution tflops=0.023 gbps=5.867 ms=5.741
shape=0 solution_peak_fraction=0.0033
shape=1 variant=solution tflops=0.087 gbps=21.776 ms=12.351
shape=1 solution_peak_fraction=0.0121
shape=2 variant=solution tflops=0.023 gbps=2.937 ms=22.895
shape=2 solution_peak_fraction=0.0016
shape=3 variant=solution tflops=0.046 gbps=11.651 ms=8.657
shape=3 solution_peak_fraction=0.0065
shape=4 variant=solution tflops=0.021 gbps=5.196 ms=3.160
shape=4 solution_peak_fraction=0.0029
RESULT: LOW
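v4 lands essentially where v3 did, which points away from indexing bugs and at the design itself: after the cooperative load, only the first GROUP_SIZE of the 128 threads (is_worker) do any math, and the grid has just B*Hkv blocks. A quick utilization count (plain C; the group sizes per shape come from the grid check just below):

#include <stdio.h>

int main(void) {
    // H/Hkv per benchmark shape, as printed by the check that follows
    int group[] = {4, 4, 8, 4, 4};
    for (int i = 0; i < 5; i++)
        printf("shape %d: %d of 128 threads compute (%.1f%%)\n",
               i, group[i], 100.0 * group[i] / 128);
    return 0;
}

So roughly 97% of each block idles through the attention math in four of the five shapes.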
{
"command": "<825 chars \u2014 see below>",
"description": "Check parallelism per architecture"
}uv run python -c "
shapes = [
{'batch': 8, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1024, 'page_size': 16},
{'batch': 32, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 2048, 'page_size': 16},
{'batch': 4, 'num_heads': 64, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 4096, 'page_size': 16},
{'batch': 16, 'num_heads': 32, 'num_kv_heads': 8, 'head_dim': 128, 'seq_len': 1535, 'page_size': 16},
{'batch': 8, 'num_heads': 16, 'num_kv_heads': 4, 'head_dim': 64, 'seq_len': 2000, 'page_size': 16},
]
for i, s in enumerate(shapes):
grid_per_head = s['batch'] * s['num_heads']
grid_per_kv = s['batch'] * s['num_kv_heads']
print(f'Shape {i}: grid(b,h)={grid_per_head}, grid(b,kv)={grid_per_kv}, group={s[\"num_heads\"]//s[\"num_kv_heads\"]}')
" 2>&1Shape 0: grid(b,h)=256, grid(b,kv)=64, group=4
Shape 1: grid(b,h)=1024, grid(b,kv)=256, group=4
Shape 2: grid(b,h)=256, grid(b,kv)=32, group=8
Shape 3: grid(b,h)=512, grid(b,kv)=128, group=4
Shape 4: grid(b,h)=128, grid(b,kv)=32, group=4
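Comparing those grids against the machine makes the starvation concrete. A sketch that queries the SM count at runtime (device 0 assumed) and divides it into the per-kv-head grids printed above:

#include <stdio.h>
#include <cuda_runtime.h>

int main(void) {
    int sms = 0;
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
    int grid_bkv[] = {64, 256, 32, 128, 32};   // grid(b,kv) per shape, above
    for (int i = 0; i < 5; i++)
        printf("shape %d: %d blocks over %d SMs (%.2f blocks/SM)\n",
               i, grid_bkv[i], sms, (double)grid_bkv[i] / sms);
    return 0;
}

On a GPU with on the order of a hundred SMs, the (batch, kv_head) grids leave much of the chip empty, while the (batch, query_head) grids are 4x to 8x larger; hence the pivot back to a per-head design below.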
{
"command": "<996 chars \u2014 see below>",
"description": "Time solution vs reference",
"timeout": 60000
}uv run python -c "
import torch
import reference
import solution
B, H, Hkv, D, L, P = 8, 32, 8, 128, 1024, 16
reference.BATCH, reference.NUM_HEADS, reference.NUM_KV_HEADS = B, H, Hkv
reference.HEAD_DIM, reference.SEQ_LEN, reference.PAGE_SIZE = D, L, P
model = solution.Model(B, H, Hkv, D, L, P).cuda().eval()
ref = reference.Model(B, H, Hkv, D, L, P).cuda().eval()
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
# Time each call
import time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
with torch.no_grad():
sol_out = model(*inputs)
end.record()
torch.cuda.synchronize()
print(f'Solution time: {start.elapsed_time(end):.3f} ms')
start.record()
with torch.no_grad():
ref_out = ref(*inputs)
end.record()
torch.cuda.synchronize()
print(f'Reference time: {start.elapsed_time(end):.3f} ms')
print('max abs diff:', (sol_out.float() - ref_out.float()).abs().max().item())
" 2>&1Solution time: 20.789 ms
Reference time: 66.584 ms
max abs diff: 0.0164794921875
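Before the per-head rewrite lands, the online-softmax update that every variant in this log shares (acc_max, acc_num, acc_denom) is worth stating once. With running maximum m, weighted sum n, and normalizer d over the scores seen so far, each new score s and value v fold in as

    m' = max(m, s)
    n' = n * exp(m - m') + exp(s - m') * v
    d' = d * exp(m - m') + exp(s - m')

which preserves the invariants n = sum_t exp(s_t - m) * v_t and d = sum_t exp(s_t - m), so n / d equals the exact softmax-weighted value sum regardless of the order in which tokens arrive. The 1e-30 seed on acc_denom is scaled away by the first rescale; its only job is to keep the final division defined in the degenerate seq_len = 0 case.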
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,4 +1,11 @@
-"""Optimized CUDA C++ paged attention decode kernel."""
+"""Paged attention decode kernel - per-head design optimized for bandwidth.
+
+Architecture: one block per (batch, query_head), 128 threads.
+- Each thread loads and computes a portion of the attention (D / BLOCK_DIM elements)
+- No shared memory - each thread loads its own K/V from global memory
+- Use __ldg (read-only data cache) for K/V loads
+- Use vectorized loads (float4 = 8 bf16) where possible
+"""
import math
import torch
import torch.nn as nn
@@ -17,7 +24,17 @@
using bfloat = __nv_bfloat16;
-__global__ void paged_decode_kernel(
+/*
+ * Design: one thread block per (batch, head), 128 threads.
+ * Each thread handles D/128 elements (1 element when D=128).
+ * No shared memory: each thread loads from global and accumulates.
+ *
+ * For D=128, each thread handles exactly 1 element of the 128-dim vector.
+ * Thread d loads KV[d], computes partial QK, and accumulates acc_num[d].
+ * We need a reduction across threads for qk.
+ */
+
+__global__ void __launch_bounds__(128) paged_decode_kernel(
bfloat* __restrict__ Out,
const bfloat* __restrict__ Q,
const bfloat* __restrict__ KV,
@@ -29,122 +46,85 @@
int stride_ob, int stride_oo,
int stride_bt_row
) {
- // Grid: (B, Hkv), Block: (128 threads)
- int BLOCK_THREADS = blockDim.x;
- int GROUP_SIZE = H / Hkv;
-
+ int BLOCK_DIM = 128;
int pid_b = blockIdx.x;
- int pid_hkv = blockIdx.y;
+ int pid_h = blockIdx.y;
int tid = threadIdx.x;
+ int pid_hkv = pid_h / (H / Hkv);
int seq_len = SeqLens[pid_b];
int num_pages = (seq_len + P - 1) / P;
- extern __shared__ bfloat smem[];
- bfloat* __restrict__ sK = smem;
- bfloat* __restrict__ sV = smem + P * D;
-
- bool is_worker = (tid < GROUP_SIZE);
-
- // Load Q
- float q[256];
- if (is_worker) {
- int h = pid_hkv * GROUP_SIZE + tid;
- const bfloat* qb = Q + pid_b * stride_qb + h * stride_qo;
- for (int j = 0; j < D; j++)
- q[j] = __bfloat162float(qb[j]);
- }
+ // Each thread works on element tid of the D-dimensional vector
+ // For D < 128, some threads are idle
+ if (tid >= D) return;
+
+ // Load this thread's Q element
+ float q_val = __bfloat162float(Q[pid_b * stride_qb + pid_h * stride_qo + tid]);
// Accumulators
- float acc_num[256];
- for (int j = 0; j < D; j++) acc_num[j] = 0.0f;
+ float acc_num_val = 0.0f;
float acc_denom = 1e-30f;
float acc_max = -1e30f;
- // K/V page load distribution
- int total_items = P * D;
- int per_thread = (total_items + BLOCK_THREADS - 1) / BLOCK_THREADS;
-
+ // Process pages and tokens
for (int pg = 0; pg < num_pages; pg++) {
int page_id = BlockTable[pid_b * stride_bt_row + pg];
- // Cooperative load of K/V page into shared memory
- for (int i = 0; i < per_thread; i++) {
- int idx = tid * per_thread + i;
- if (idx >= total_items) break;
- int row = idx / D;
- int col = idx % D;
-
- int base = page_id * stride_kvb + row * stride_kvs + col * stride_kvo;
- sK[row * D + col] = KV[base];
- sV[row * D + col] = KV[base + D];
+ int tok_end = min((pg + 1) * P, seq_len);
+ int tok_start = pg * P;
+
+ for (int t = tok_start; t < tok_end; t++) {
+ int local_t = t - tok_start;
+
+ // Load K[tid] for this token
+ int k_base = page_id * stride_kvb + local_t * stride_kvs + pid_hkv * stride_kvh;
+ float k_val = __bfloat162float(KV[k_base + tid * stride_kvo]);
+
+ // Each thread computes its partial Q * K contribution
+ float partial = q_val * k_val;
+
+ // Warp-wide reduction to compute full qk dot product
+ // Use __shfl_xor for butterfly reduction across 128 threads
+ // 128 threads = 4 warps (32 each)
+
+ // First: reduce within warp
+ #pragma unroll
+ for (int offset = 16; offset > 0; offset >>= 1)
+ partial += __shfl_xor_sync(0xffffffff, partial, offset);
+
+ // Now each warp has partial sum of its 32 elements
+ // Warp 0 holds sum of elements [0..31], warp 1: [32..63], etc.
+ // Need to get all partial sums to all threads
+
+ __shared__ float warp_sum[4]; // 128/32 = 4 warps max
+ if (tid % 32 == 0)
+ warp_sum[tid / 32] = partial;
+ __syncthreads();
+
+ float qk = 0.0f;
+ for (int w = 0; w < (D + 31) / 32 && w * 32 < D; w++)
+ qk += warp_sum[w];
+ qk *= scale;
+
+ // Now each thread loads V[tid] and does its softmax update
+ float v_val = __bfloat162float(KV[k_base + D + tid * stride_kvo]);
+
+ // All threads need consistent qk, max, exp, rescale
+ float old_max = acc_max;
+ float new_max = fmaxf(old_max, qk);
+ float exp_v = __expf(qk - new_max);
+ float rescale = __expf(old_max - new_max);
+
+ // Update: acc_num[tid] += exp_v * v_val, acc_denom += exp_v
+ acc_num_val = acc_num_val * rescale + exp_v * v_val;
+ acc_denom = acc_denom * rescale + exp_v;
+ acc_max = new_max;
}
- __syncthreads();
-
- // Compute attention
- if (is_worker) {
- int tok_end = min((pg + 1) * P, seq_len);
- int tok_start = pg * P;
- for (int t = tok_start; t < tok_end; t++) {
- int local_t = t - tok_start;
- float qk = 0.0f;
- const bfloat* kr = sK + local_t * D;
- for (int j = 0; j < D; j++)
- qk += q[j] * __bfloat162float(kr[
... (3838 more chars){
"content": "<7267 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v5 correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 209, in forward
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_lens, self.scale)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 188, in _get_module
_module = _compile()
^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 177, in _compile
return load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^
... (978 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -133,13 +133,23 @@
#include <cuda_bf16.h>
void launch_paged_decode(
- __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
+ bfloat* Out, const bfloat* Q, const bfloat* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
- int stride_ob, int stride_oo);
+ int stride_ob, int stride_oo
+) {
+ dim3 grid(B, H);
+ dim3 block(128, 1, 1);
+ paged_decode_kernel<<<grid, block>>>(
+ Out, Q, KV, BlockTable, SeqLens,
+ B, H, Hkv, D, P, scale,
+ stride_qb, stride_qo,
+ stride_kvb, stride_kvs, stride_kvh, stride_kvo,
+ stride_ob, stride_oo, stride_bt_row);
+}
void paged_attention_forward(
torch::Tensor out,
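The build failure that follows is structural: load_inline compiles cpp_sources with the host C++ compiler, and neither the <<<...>>> launch syntax nor a __global__ declaration (nor the bfloat alias, which only the CUDA source defines) is valid there. A sketch of a split that load_inline does accept, with hypothetical names (the log's final revision goes further and moves everything, bindings included, into the .cu):

// --- cuda_sources: compiled by nvcc, so kernels and launches live here ---
#include <torch/extension.h>

__global__ void demo_kernel(float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 1.0f;
}

void demo_forward(torch::Tensor out) {
    int n = out.numel();
    demo_kernel<<<(n + 127) / 128, 128>>>(out.data_ptr<float>(), n);
}

// --- cpp_sources: compiled by g++, so only a plain declaration ---
// void demo_forward(torch::Tensor out);
//
// load_inline(name=..., cpp_sources=..., cuda_sources=...,
//             functions=["demo_forward"]) then generates the binding.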
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"newString": "<637 chars \u2014 see below>",
"oldString": "<359 chars \u2014 see below>"
}void launch_paged_decode(
bfloat* Out, const bfloat* Q, const bfloat* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo
) {
dim3 grid(B, H);
dim3 block(128, 1, 1);
paged_decode_kernel<<<grid, block>>>(
Out, Q, KV, BlockTable, SeqLens,
B, H, Hkv, D, P, scale,
stride_qb, stride_qo,
stride_kvb, stride_kvs, stride_kvh, stride_kvo,
stride_ob, stride_oo, stride_bt_row);
}void launch_paged_decode(
__nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo);Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v5 correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 219, in forward
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_le
... (7691 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py</path>
<type>file</type>
<content>
1: """Paged attention decode kernel - per-head design optimized for bandwidth.
2:
3: Architecture: one block per (batch, query_head), 128 threads.
4: - Each thread loads and computes a portion of the attention (D / BLOCK_DIM elements)
5: - No shared memory - each thread loads its own K/V from global memory
6: - Use __ldg (read-only data cache) for K/V loads
7: - Use vectorized loads (float4 = 8 bf16) where possible
8: """
9: import math
10: import torch
11: import torch.nn as nn
12: from torch.utils.cpp_extension import load_inline
13:
14: def _make_scale(head_dim: int) -> float:
15: return 1.0 / math.sqrt(head_dim)
16:
17: _module = None
18:
19: def _compile():
20: CUDA_SRC = r"""
21: #include <cuda_runtime.h>
22: #include <cuda_bf16.h>
23: #include <math.h>
24:
25: using bfloat = __nv_bfloat16;
26:
27: /*
28: * Design: one thread block per (batch, head), 128 threads.
29: * Each thread handles D/128 elements (1 element when D=128).
30: * No shared memory: each thread loads from global and accumulates.
31: *
32: * For D=128, each thread handles exactly 1 element of the 128-dim vector.
33: * Thread d loads KV[d], computes partial QK, and accumulates acc_num[d].
34: * We need a reduction across threads for qk.
35: */
36:
37: __global__ void __launch_bounds__(128) paged_decode_kernel(
38: bfloat* __restrict__ Out,
39: const bfloat
... (7249 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -132,8 +132,18 @@
#include <torch/extension.h>
#include <cuda_bf16.h>
+// Kernel declaration (defined in .cu)
+__global__ void paged_decode_kernel(
+ __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
+ const int* BlockTable, const int* SeqLens,
+ int B, int H, int Hkv, int D, int P, float scale,
+ int stride_qb, int stride_qo,
+ int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
+ int stride_ob, int stride_oo,
+ int stride_bt_row);
+
void launch_paged_decode(
- bfloat* Out, const bfloat* Q, const bfloat* KV,
+ __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py",
"newString": "<2119 chars \u2014 see below>",
"oldString": "<1687 chars \u2014 see below>"
} CPP_SRC = r"""
#include <torch/extension.h>
#include <cuda_bf16.h>
// Kernel declaration (defined in .cu)
__global__ void paged_decode_kernel(
__nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo,
int stride_bt_row);
void launch_paged_decode(
__nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo
) {
dim3 grid(B, H);
dim3 block(128, 1, 1);
paged_decode_kernel<<<grid, block>>>(
Out, Q, KV, BlockTable, SeqLens,
B, H, Hkv, D, P, scale,
stride_qb, stride_qo,
stride_kvb, stride_kvs, stride_kvh, stride_kvo,
stride_ob, stride_oo, stride_bt_row);
}
void paged_attention_forward(
torch::Tensor out,
torch::Tensor query,
torch::Tensor kv_cache,
torch::Tensor block_table,
torch::Tensor seq_lens,
float scale)
{
int B = query.size(0);
int H = query.size(1);
int D = query.size(2);
int Hkv = kv_cache.size(2);
int P = kv_cache.size(1);
int stride_bt_row = block_table.stride(0);
launch_paged_decode(
(__nv_bfloat16*)out.data_ptr(),
(const __nv_bfloat16*)query.data_ptr(),
(const __nv_bfloat16*)kv_cache.data_ptr(),
block_table.data_ptr<int>(),
seq_lens.data_ptr<int>(),
B, H, Hkv, D, P, scale,
stride_bt_row,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
out.stride(0), out.stride(1));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paged_attention_forward", &paged_attention_forward, "paged_attention_forward");
}
""" CPP_SRC = r"""
#include <torch/extension.h>
#include <cuda_bf16.h>
void launch_paged_decode(
bfloat* Out, const bfloat* Q, const bfloat* KV,
const int* BlockTable, const int* SeqLens,
int B, int H, int Hkv, int D, int P, float scale,
int stride_bt_row,
int stride_qb, int stride_qo,
int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
int stride_ob, int stride_oo
) {
dim3 grid(B, H);
dim3 block(128, 1, 1);
paged_decode_kernel<<<grid, block>>>(
Out, Q, KV, BlockTable, SeqLens,
B, H, Hkv, D, P, scale,
stride_qb, stride_qo,
stride_kvb, stride_kvs, stride_kvh, stride_kvo,
stride_ob, stride_oo, stride_bt_row);
}
void paged_attention_forward(
torch::Tensor out,
torch::Tensor query,
torch::Tensor kv_cache,
torch::Tensor block_table,
torch::Tensor seq_lens,
float scale)
{
int B = query.size(0);
int H = query.size(1);
int D = query.size(2);
int Hkv = kv_cache.size(2);
int P = kv_cache.size(1);
int stride_bt_row = block_table.stride(0);
launch_paged_decode(
(__nv_bfloat16*)out.data_ptr(),
(const __nv_bfloat16*)query.data_ptr(),
(const __nv_bfloat16*)kv_cache.data_ptr(),
block_table.data_ptr<int>(),
seq_lens.data_ptr<int>(),
B, H, Hkv, D, P, scale,
stride_bt_row,
query.stride(0), query.stride(1),
kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
out.stride(0), out.stride(1));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paged_attention_forward", &paged_attention_forward, "paged_attention_forward");
}
"""Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA v5 correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 109, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/check.py", line 72, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py", line 229, in forward
_get_module().paged_attention_forward(out, query, kv_cache, block_table, seq_le
... (3074 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py
@@ -1,10 +1,7 @@
-"""Paged attention decode kernel - per-head design optimized for bandwidth.
+"""Per-head CUDA paged attention decode kernel.
-Architecture: one block per (batch, query_head), 128 threads.
-- Each thread loads and computes a portion of the attention (D / BLOCK_DIM elements)
-- No shared memory - each thread loads its own K/V from global memory
-- Use __ldg (read-only data cache) for K/V loads
-- Use vectorized loads (float4 = 8 bf16) where possible
+Architecture: one thread block per (batch, query_head), 128 threads.
+Each thread handles one element of D, reduces qk across threads.
"""
import math
import torch
@@ -17,22 +14,15 @@
_module = None
def _compile():
+ # Put ALL the code in the CUDA source. No C++ wrapper needed.
+ # We define the kernel AND the C++ interface function in the .cu file.
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <math.h>
+#include <torch/extension.h>
using bfloat = __nv_bfloat16;
-
-/*
- * Design: one thread block per (batch, head), 128 threads.
- * Each thread handles D/128 elements (1 element when D=128).
- * No shared memory: each thread loads from global and accumulates.
- *
- * For D=128, each thread handles exactly 1 element of the 128-dim vector.
- * Thread d loads KV[d], computes partial QK, and accumulates acc_num[d].
- * We need a reduction across threads for qk.
- */
__global__ void __launch_bounds__(128) paged_decode_kernel(
bfloat* __restrict__ Out,
@@ -46,7 +36,6 @@
int stride_ob, int stride_oo,
int stride_bt_row
) {
- int BLOCK_DIM = 128;
int pid_b = blockIdx.x;
int pid_h = blockIdx.y;
int tid = threadIdx.x;
@@ -55,11 +44,10 @@
int seq_len = SeqLens[pid_b];
int num_pages = (seq_len + P - 1) / P;
- // Each thread works on element tid of the D-dimensional vector
- // For D < 128, some threads are idle
+ // Early exit
if (tid >= D) return;
- // Load this thread's Q element
+ // Load Q element
float q_val = __bfloat162float(Q[pid_b * stride_qb + pid_h * stride_qo + tid]);
// Accumulators
@@ -67,7 +55,10 @@
float acc_denom = 1e-30f;
float acc_max = -1e30f;
- // Process pages and tokens
+ // Warp count
+ int num_warps = (D + 31) / 32;
+ int my_warp = tid / 32;
+
for (int pg = 0; pg < num_pages; pg++) {
int page_id = BlockTable[pid_b * stride_bt_row + pg];
@@ -77,46 +68,39 @@
for (int t = tok_start; t < tok_end; t++) {
int local_t = t - tok_start;
- // Load K[tid] for this token
+ // Load K[tid]
int k_base = page_id * stride_kvb + local_t * stride_kvs + pid_hkv * stride_kvh;
float k_val = __bfloat162float(KV[k_base + tid * stride_kvo]);
- // Each thread computes its partial Q * K contribution
+ // Partial dot product
float partial = q_val * k_val;
- // Warp-wide reduction to compute full qk dot product
- // Use __shfl_xor for butterfly reduction across 128 threads
- // 128 threads = 4 warps (32 each)
-
- // First: reduce within warp
+ // Intra-warp reduction
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
partial += __shfl_xor_sync(0xffffffff, partial, offset);
- // Now each warp has partial sum of its 32 elements
- // Warp 0 holds sum of elements [0..31], warp 1: [32..63], etc.
- // Need to get all partial sums to all threads
-
- __shared__ float warp_sum[4]; // 128/32 = 4 warps max
+ // Each warp leader writes its sum to shared memory
+ __shared__ float warp_sum[4];
if (tid % 32 == 0)
warp_sum[tid / 32] = partial;
__syncthreads();
+ // Compute full qk reduction
float qk = 0.0f;
- for (int w = 0; w < (D + 31) / 32 && w * 32 < D; w++)
+ for (int w = 0; w < num_warps; w++)
qk += warp_sum[w];
qk *= scale;
- // Now each thread loads V[tid] and does its softmax update
+ // Load V[tid]
float v_val = __bfloat162float(KV[k_base + D + tid * stride_kvo]);
- // All threads need consistent qk, max, exp, rescale
+ // Online softmax (consistent across all threads)
float old_max = acc_max;
float new_max = fmaxf(old_max, qk);
float exp_v = __expf(qk - new_max);
float rescale = __expf(old_max - new_max);
- // Update: acc_num[tid] += exp_v * v_val, acc_denom += exp_v
acc_num_val = acc_num_val * rescale + exp_v * v_val;
acc_denom = acc_denom * rescale + exp_v;
acc_max = new_max;
@@ -126,41 +110,8 @@
// Write output
Out[pid_b * stride_ob + pid_h * stride_oo + tid] = __float2bfloat16(acc_num_val / acc_denom);
}
-"""
- CPP_SRC = r"""
-#include <torch/extension.h>
-#include <cuda_bf16.h>
-
-// Kernel declaration (defined in .cu)
-__global__ void paged_decode_kernel(
- __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
- const int* BlockTable, const int* SeqLens,
- int B, int H, int Hkv, int D, int P, float scale,
- int stride_qb, int stride_qo,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
- int stride_ob, int stride_oo,
- int stride_bt_row);
-
-void launch_paged_decode(
- __nv_bfloat16* Out, const __nv_bfloat16* Q, const __nv_bfloat16* KV,
- const int* BlockTable, const int* SeqLens,
- int B, int H, int Hkv, int D, int P, float scale,
- int stride_bt_row,
- int stride_qb, int stride_qo,
- int stride_kvb, int stride_kvs, int stride_kvh, int stride_kvo,
- i
... (1591 more chars){
"content": "<5866 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/03_paged_attention/solution.py"
}Wrote file successfully.
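The v5 kernel written above hangs everything on the two-level cross-thread dot product: an intra-warp butterfly with __shfl_xor_sync, then a four-entry shared array to combine the warps. A standalone, compilable sketch of just that reduction follows (nvcc reduce.cu && ./a.out). One caveat about the kernel's per-token loop: warp_sum is rewritten each iteration with only one __syncthreads() between write and read, so a second barrier after the read loop would be needed to make it formally race-free; the one-shot reduction here sidesteps that.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void block_dot(const float* a, const float* b, float* out, int n) {
    int tid = threadIdx.x;                       // one element per thread
    float partial = (tid < n) ? a[tid] * b[tid] : 0.0f;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)   // butterfly within warp
        partial += __shfl_xor_sync(0xffffffff, partial, offset);
    __shared__ float warp_sum[4];                // 128 threads = 4 warps
    if (tid % 32 == 0) warp_sum[tid / 32] = partial;
    __syncthreads();
    float total = 0.0f;                          // every thread sees the full
    for (int w = 0; w < 4; w++) total += warp_sum[w];   // sum, as v5 requires
    if (tid == 0) *out = total;
}

int main() {
    const int n = 128;
    float ha[n], hb[n], hout = 0.0f;
    for (int i = 0; i < n; i++) { ha[i] = 1.0f; hb[i] = 2.0f; }
    float *da, *db, *dout;
    cudaMalloc(&da, n * sizeof(float));
    cudaMalloc(&db, n * sizeof(float));
    cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(da, ha, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, n * sizeof(float), cudaMemcpyHostToDevice);
    block_dot<<<1, 128>>>(da, db, dout, n);
    cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
    printf("dot = %f (expect 256)\n", hout);     // 128 * (1.0 * 2.0)
    cudaFree(da); cudaFree(db); cudaFree(dout);
    return 0;
}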