"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16
@triton.jit
def _paged_decode_gqa_kernel(
    Q_ptr,
    KV_ptr,
    BT_ptr,
    SL_ptr,
    O_ptr,
    stride_qb,
    stride_qh,
    stride_qd,
    stride_kv_block,
    stride_kv_page,
    stride_kv_h,
    stride_kv_d,
    stride_bt_b,
    stride_bt_p,
    stride_ob,
    stride_oh,
    stride_od,
    num_heads,
    head_dim,
    page_size,
    group_size,
    scale,
    BLOCK_G: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Single-query paged attention for one (batch element, KV head) pair.

    Program ``(program_id(0), program_id(1))`` handles batch element
    ``program_id(0)`` and KV head ``program_id(1)``. It loads the query rows of
    the ``group_size`` heads mapped to that KV head, walks the pages listed in
    the block table, and accumulates ``softmax(q . k * scale) @ v`` one page at
    a time with a running-max / running-sum (online softmax) update.

    Args:
        Q_ptr: Query base pointer, addressed as [batch, head, dim] through
            ``stride_qb`` / ``stride_qh`` / ``stride_qd``.
        KV_ptr: Cache base pointer, addressed as
            [block, token-in-page, kv head, dim] through ``stride_kv_block`` /
            ``stride_kv_page`` / ``stride_kv_h`` / ``stride_kv_d``. K is read
            from last-axis indices ``[0, head_dim)`` and V from
            ``[head_dim, 2 * head_dim)``.
        BT_ptr: Block table, addressed as [batch, page] through
            ``stride_bt_b`` / ``stride_bt_p``; entries are block ids.
        SL_ptr: Per-batch sequence lengths, indexed by batch element.
        O_ptr: Output base pointer, addressed like ``Q_ptr`` through
            ``stride_ob`` / ``stride_oh`` / ``stride_od``; written as bfloat16.
        num_heads: Total number of query heads.
        head_dim: Number of valid elements along the head dimension.
        page_size: Number of token slots per page.
        group_size: Number of query heads sharing one KV head.
        scale: Multiplier applied to the q . k scores.
        BLOCK_G: Tile size over the heads of a group (>= group_size).
        BLOCK_T: Tile size over the tokens of a page (>= page_size).
        BLOCK_D: Tile size over the head dimension (>= head_dim).
    """
    batch_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    seq_len = tl.load(SL_ptr + batch_idx)
    num_pages = (seq_len + page_size - 1) // page_size

    offs_g = tl.arange(0, BLOCK_G)
    offs_t = tl.arange(0, BLOCK_T)
    offs_d = tl.arange(0, BLOCK_D)

    head_ids = kv_head_idx * group_size + offs_g
    # Lanes at or beyond group_size are padding introduced when BLOCK_G is
    # rounded up. Without the first term they would alias the query heads of
    # the *next* KV head and overwrite their outputs with results computed
    # against the wrong K/V.
    g_mask = (offs_g < group_size) & (head_ids < num_heads)
    d_mask = offs_d < head_dim
    gd_mask = g_mask[:, None] & d_mask[None, :]

    q_ptrs = (
        Q_ptr
        + batch_idx * stride_qb
        + head_ids[:, None] * stride_qh
        + offs_d[None, :] * stride_qd
    )
    q = tl.load(q_ptrs, mask=gd_mask, other=0.0)

    # Online-softmax state, one entry per head in the group.
    m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)  # running max
    l_i = tl.zeros([BLOCK_G], dtype=tl.float32)  # running denominator
    acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)  # running numerator

    for page_idx in range(num_pages):
        # Widen to int64 before multiplying by the block stride.
        block_id = tl.load(
            BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p
        ).to(tl.int64)

        page_start = page_idx * page_size
        # The last page of a sequence may be only partially filled.
        tokens_in_page = tl.minimum(page_size, seq_len - page_start)

        kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
        t_valid_row = offs_t[None, :] < tokens_in_page
        t_valid_col = offs_t[:, None] < tokens_in_page

        k_ptrs = (
            KV_ptr
            + kv_h_base
            + offs_t[:, None] * stride_kv_page
            + offs_d[None, :] * stride_kv_d
        )
        # V sits head_dim elements after K along the last axis.
        v_ptrs = (
            KV_ptr
            + kv_h_base
            + offs_t[:, None] * stride_kv_page
            + (head_dim + offs_d)[None, :] * stride_kv_d
        )
        kv_mask = t_valid_col & d_mask[None, :]
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

        qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale
        # Token slots past the end of the sequence must not contribute.
        qk = tl.where(t_valid_row, qk, -float("inf"))

        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        alpha = tl.exp(m_i - m_ij)  # rescales the previously accumulated state
        p = tl.exp(qk - m_ij[:, None])
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v).to(tl.float32)
        m_i = m_ij

    # With seq_len == 0 the loop never runs and l_i stays 0; dividing would
    # store NaN. Substitute 1 so an empty sequence yields zeros (acc is 0).
    l_safe = tl.where(l_i > 0.0, l_i, 1.0)
    out_vals = acc / l_safe[:, None]

    o_ptrs = (
        O_ptr
        + batch_idx * stride_ob
        + head_ids[:, None] * stride_oh
        + offs_d[None, :] * stride_od
    )
    tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=gd_mask)
def paged_attention_decode(
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Launch the GQA paged-attention decode kernel.

    Args:
        query: 3-D tensor unpacked as (B, H, D).
        kv_cache: 4-D tensor unpacked as (num_blocks, P, Hkv, 2 * D); the
            kernel reads K from the first D entries of the last axis and V
            from the remaining D.
        block_table: 2-D tensor of block ids; its two strides are forwarded to
            the kernel as the (batch, page) strides.
        seq_lens: Per-batch sequence lengths, indexed by batch element inside
            the kernel.
        scale: Multiplier applied to the q . k scores.

    Returns:
        A tensor with the same shape and dtype as ``query``.

    Raises:
        AssertionError: If the last cache axis is not ``2 * D`` or ``H`` is
            not a multiple of ``Hkv``.
    """
    B, H, D = query.shape
    _, P, Hkv, D2 = kv_cache.shape
    assert D2 == 2 * D, "kv_cache last dim must pack [K | V] (2 * head_dim)"
    # A floor-divided group size would leave trailing heads of `out` unwritten.
    assert H % Hkv == 0, "num_heads must be a multiple of num_kv_heads"
    group_size = H // Hkv

    out = torch.empty_like(query)

    BLOCK_G = triton.next_power_of_2(group_size)
    # tl.arange needs a power-of-two extent and tl.dot a reduction dimension of
    # at least 16; the kernel masks token slots beyond the page, so rounding the
    # tile up is safe for any page size (and a no-op for P == 16).
    BLOCK_T = max(16, triton.next_power_of_2(P))
    BLOCK_D = triton.next_power_of_2(D)

    num_warps = 4 if D <= 64 else 8
    num_stages = 3

    # One program per (batch element, KV head).
    grid = (B, Hkv)
    _paged_decode_gqa_kernel[grid](
        query,
        kv_cache,
        block_table,
        seq_lens,
        out,
        *query.stride(),
        *kv_cache.stride(),
        block_table.stride(0),
        block_table.stride(1),
        *out.stride(),
        num_heads=H,
        head_dim=D,
        page_size=P,
        group_size=group_size,
        scale=scale,
        BLOCK_G=BLOCK_G,
        BLOCK_T=BLOCK_T,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out
class Model(nn.Module):
    """Module wrapper around ``paged_attention_decode``.

    The constructor only records the problem dimensions and derives the
    GQA group size and the softmax scale ``1 / sqrt(head_dim)``; the module
    has no learnable parameters.
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        super().__init__()
        heads_per_kv, leftover = divmod(num_heads, num_kv_heads)
        assert leftover == 0

        # Plain configuration values, stored as given.
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size

        # Derived values.
        self.group_size = heads_per_kv
        self.scale = 1.0 / math.sqrt(head_dim)

        # Non-persistent placeholder buffer (excluded from the state dict).
        # NOTE(review): presumably here so the module owns at least one
        # tensor for device/dtype moves -- confirm with the harness.
        placeholder = torch.zeros(1, dtype=torch.bfloat16)
        self.register_buffer("_dummy", placeholder, persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Run paged attention decode using the scale fixed at construction."""
        result = paged_attention_decode(
            query, kv_cache, block_table, seq_lens, self.scale
        )
        return result
def get_inputs():
    """Build one random set of forward inputs from the module-level shape knobs.

    Returns:
        ``[query, kv_cache, block_table, seq_lens]`` where query is
        (BATCH, NUM_HEADS, HEAD_DIM) bf16, kv_cache is
        (total_pages, PAGE_SIZE, NUM_KV_HEADS, 2 * HEAD_DIM) bf16, block_table
        is (BATCH, pages_per_seq) int32 holding distinct page ids, and seq_lens
        is (BATCH,) int32 filled with SEQ_LEN.
    """
    n_batch, n_heads, n_kv_heads = BATCH, NUM_HEADS, NUM_KV_HEADS
    dim, length, page = HEAD_DIM, SEQ_LEN, PAGE_SIZE

    pages_per_seq = (length + page - 1) // page  # ceil(length / page)
    pages_in_use = n_batch * pages_per_seq
    # Pool is slightly larger than what the sequences need, and at least 64.
    total_pages = max(pages_in_use + 8, 64)

    query = torch.randn(n_batch, n_heads, dim, dtype=torch.bfloat16) * 0.1
    kv_cache = (
        torch.randn(total_pages, page, n_kv_heads, 2 * dim, dtype=torch.bfloat16)
        * 0.1
    )

    # A prefix of a random permutation gives every sequence distinct pages.
    shuffled = torch.randperm(total_pages)[:pages_in_use]
    block_table = shuffled.reshape(n_batch, pages_per_seq).int().contiguous()
    seq_lens = torch.full((n_batch,), length, dtype=torch.int32)

    return [query, kv_cache, block_table, seq_lens]
def get_init_inputs():
    """Return the ``Model`` constructor arguments from the module-level knobs."""
    init_args = [
        BATCH,
        NUM_HEADS,
        NUM_KV_HEADS,
        HEAD_DIM,
        SEQ_LEN,
        PAGE_SIZE,
    ]
    return init_args
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:07:48.291710+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:07:48.607538+00:00 elapsed_s=0.316 ms=0.068608
shape=0 variant=solution tflops=1.956 gbps=490.985 ms=0.069
shape=0 solution_peak_fraction=0.2728
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:07:50.271253+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:07:50.290533+00:00 elapsed_s=0.019 ms=0.204224
shape=1 variant=solution tflops=5.258 gbps=1316.984 ms=0.204
shape=1 solution_peak_fraction=0.7317
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:07:50.716370+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:07:50.790609+00:00 elapsed_s=0.074 ms=0.240000
shape=2 variant=solution tflops=2.237 gbps=280.166 ms=0.240
shape=2 solution_peak_fraction=0.1556
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T18:07:51.422622+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T18:07:51.430098+00:00 elapsed_s=0.007 ms=0.126784
shape=3 variant=solution tflops=3.174 gbps=795.525 ms=0.127
shape=3 solution_peak_fraction=0.4420
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T18:07:51.531844+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T18:07:51.539666+00:00 elapsed_s=0.008 ms=0.113504
shape=4 variant=solution tflops=0.577 gbps=144.636 ms=0.114
shape=4 solution_peak_fraction=0.0804
peak_fraction: 0.2562
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_120402_cursor_composer-2.5-fast_03_paged_attention/repo
Installed 56 packages in 71ms
PASS
{
"run_id": "20260615_120402_cursor_composer-2.5-fast_03_paged_attention",
"run_group": "",
"problem": "03_paged_attention",
"harness": "cursor",
"model": "composer-2.5-fast",
"reasoning_effort": "",
"started_at": "2026-06-15T12:04:02-06:00",
"harness_finished_at": "2026-06-15T12:07:34-06:00",
"finished_at": "2026-06-15T12:07:51-06:00",
"start_epoch": 1781546642,
"harness_end_epoch": 1781546854,
"end_epoch": 1781546871,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.2562,
"template_mutated": false,
"elapsed_seconds": 212,
"total_elapsed_seconds": 229,
"check_elapsed_seconds": 13,
"benchmark_elapsed_seconds": 4,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 100.72169811320755,
"usage": {"input_tokens": 40058, "output_tokens": 21353, "cache_read_tokens": 687168, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"targetDirectory": "/workspace/problems/03_paged_attention",
"globPattern": "**/*"
}{
"path": "/workspace/problems/03_paged_attention/reference.py"
}"""Naive PyTorch paged-attention decode reference (correctness oracle, not SOTA).
Single-query decode: each batch element has a query of shape (num_heads, head_dim)
and attends over a KV cache of `seq_len[b]` tokens stored as fixed-size pages in
a global pool. Pages for batch element b are listed in `block_table[b]`.
The reference performs the slow path:
1. Gather pages -> contiguous (seq_len, num_kv_heads, head_dim) per batch element.
2. Repeat KV heads for grouped-query (broadcast num_kv_heads -> num_heads).
3. Manual softmax(QK^T / sqrt(d)) @ V in fp32, cast back to bf16.
This avoids torch.nn.functional.scaled_dot_product_attention (which is on the
forbidden list) so the agent cannot dispatch through SDPA either.
"""
import math
import torch
import torch.nn as nn
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# --- Shape knobs (overridden by check.py / benchmark.py from shapes.py) ----
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16
class Model(nn.Module):
"""Single-query paged attention decode.
Forward inputs (all on device):
query: (batch, num_heads, head_dim) bf16
kv_cache: (num_blocks, page_size, num_kv_heads, head_dim * 2)
Layout: last dim packs [K | V] so a single gather pulls both.
Stored as bf16.
block_table: (batch, max_blocks) int32
seq
... (3651 more chars)null{
"path": "/workspace/problems/03_paged_attention/check.py"
}{
"path": "/workspace/problems/03_paged_attention/benchmark.py"
}"""Correctness runner for paged-attention decode.
Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
"""
import re
import sys
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.correctness import check_correctness # noqa: E402
from src.eval.numeric_stress import ( # noqa: E402
numeric_stress_cases,
numeric_stress_context,
tolerance_for_case,
)
def _apply_shape(reference, shape: dict) -> None:
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
def main():
try:
import reference
import shapes
import solution
except Exception as e:
print(f"FAIL: import error: {e}")
sys.exit(1)
problem_yaml = Path("problem.yaml")
meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
for forbidden in meta.get("forbidden", []):
pat = re.escape(forbidden)
if re.search(pat, sol_src):
print(f"FAIL: forbidden op used: {forbidden}")
sys.exit(1)
device = torch.device("cuda
... (2341 more chars)"""Roofline benchmark for paged-attention decode.
For each shape: times the agent's solution first. Optional eager, compiled,
and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
Decode is memory-bound, so peak_fraction is computed from achieved GB/s vs
the GPU's peak DRAM bandwidth.
"""
import sys
from math import exp, log
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
from src.hardware import get as get_hw # noqa: E402
def _eval_formula(expr: str, vars: dict) -> float:
return float(eval(expr, {"__builtins__": {}}, vars))
def _apply_shape(reference, shape: dict) -> None:
reference.BATCH = shape["batch"]
reference.NUM_HEADS = shape["num_heads"]
reference.NUM_KV_HEADS = shape["num_kv_heads"]
reference.HEAD_DIM = shape["head_dim"]
reference.SEQ_LEN = shape["seq_len"]
reference.PAGE_SIZE = shape["page_size"]
def main():
import reference
import shapes
import solution
meta = yaml.safe_load(Path("problem.yaml").read_text())
hw = get_hw(meta["hardware"][0])
peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
peak_gbps = hw.peak_bandwidth_gb_s
regime = meta.get("regime", "memory")
... (3887 more chars){
"path": "/workspace/problems/03_paged_attention/shapes.py"
}{
"path": "/workspace/problems/03_paged_attention/problem.yaml"
}{
"path": "/workspace/problems/03_paged_attention/sota.py"
}{
"targetDirectory": "/workspace/problems/03_paged_attention",
"globPattern": "solution.py"
}{
"command": "cd /workspace/problems/03_paged_attention && uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0))\"",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_877e023f-f316-4a85-8dd0-616a76f3b02",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value"
... (1148 more chars)name: 03_paged_attention
display_name: "Paged Attention Decode"
precision: bf16
regime: memory # decode is bandwidth-bound (KV cache streamed once per token)
# Effective FLOPs per call: QK^T + softmax*V across all batches/heads/seq.
# Sum-formula evaluated per-shape; the benchmark substitutes batch/num_heads/etc.
# 2 * B * H * L * D for QK^T plus 2 * B * H * L * D for AV => 4 * B * H * L * D
flops_formula: "4 * batch * num_heads * seq_len * head_dim"
# Bytes moved (the real bottleneck): the KV cache must be streamed end-to-end.
# K and V each: batch * seq_len * num_kv_heads * head_dim * 2 bytes (bf16)
# Q: batch * num_heads * head_dim * 2 (negligible vs KV)
# Out: same as Q
# So total ~ 2 * (B * L * Hkv * D * 2) + small.
bytes_formula: "2 * batch * seq_len * num_kv_heads * head_dim * 2 + batch * num_heads * head_dim * 2 * 2"
hardware: [RTX_PRO_6000]
peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
peak_bandwidth_key: dram # 1.8 TB/s GDDR7 on RTX PRO 6000 Blackwell
tolerance:
bfloat16: 0.02
float16: 0.02
# Forbidden ops -- using any of these in solution.py fails correctness post-hoc.
# vllm.attention / flashinfer.batch_decode_with_paged_kv_cache: dispatching to
# the SOTA library is the trivial cheat. SDPA is also banned because the agent
# could gather pages then call SDPA and inherit FlashAttention "for free".
forbidden:
- "vllm.attention"
- "flashinfer.batch_decode_with_paged_kv_cache"
- "flashinfer.decode"
- "t
... (530 more chars)"""Shape sweep for paged attention decode.
Mix targets:
- small batch / long context (Llama-3 8B-style decode)
- large batch / medium context (server batched decode)
- GQA wide ratio (Llama-3 70B: 64 heads / 8 kv-heads)
- non-power-of-2 seq_len (forces predicated tail handling)
- head_dim=64 small-head case
"""
SHAPES = [
# (B, H, Hkv, D, L, P)
{"batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1024, "page_size": 16},
{"batch": 32, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 2048, "page_size": 16},
{"batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "seq_len": 4096, "page_size": 16},
{"batch": 16, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "seq_len": 1535, "page_size": 16}, # non-pow2
{"batch": 8, "num_heads": 16, "num_kv_heads": 4, "head_dim": 64, "seq_len": 2000, "page_size": 16}, # small-D, non-pow2
]
"""SOTA reference for paged-attention decode.
Tries, in order:
1. FlashInfer's BatchDecodeWithPagedKVCacheWrapper (preferred -- portable,
supports SM120, GQA, arbitrary head_dim).
2. vLLM's paged_attention_v2 CUDA op (requires its KV-cache layout, more
finicky; we adapt the layout on the fly when possible).
If neither is importable, is_available() returns False and the benchmark just
reports eager + compiled + solution.
Agents are FORBIDDEN from importing these in solution.py (see problem.yaml).
This file is only for the benchmark's reference line.
"""
from __future__ import annotations
import torch
def _try_flashinfer(
query: torch.Tensor,
kv_cache: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
num_kv_heads: int,
head_dim: int,
page_size: int,
) -> torch.Tensor | None:
try:
import flashinfer # noqa: F401
from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
except Exception:
return None
B, H, D = query.shape
# FlashInfer expects K and V as separate (num_blocks, page_size, num_kv_heads, head_dim) tensors.
# Our reference packs [K|V] on the last dim -- split here.
k_cache = kv_cache[..., :D].contiguous()
v_cache = kv_cache[..., D:].contiguous()
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=query.device)
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace, kv_layout="NHD")
# Build the indptr / indices / last_
... (1319 more chars)nullnull{
"command": "cd /workspace/problems/03_paged_attention && uv run python -c \"import triton; print('triton', triton.__version__)\"",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_c2d5755a-9d28-4604-b659-127e6f09178",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
... (940 more chars){
"query": "paged attention decode kernel triton or CUDA implementation",
"targetDirectories": [],
"explanation": ""
}{
"targetDirectory": "/workspace",
"globPattern": "**/solution.py"
}errorMessage: 'Codebase search is unavailable for this conversation because the codebase index could not be found. Do not retry SemanticSearch. Use Grep, Glob, and Read instead.'nullnull--- /dev/null
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -0,0 +1,212 @@
+"""Paged attention decode kernel via Triton."""
+import math
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+@triton.jit
+def _paged_decode_kernel(
+ Q_ptr,
+ KV_ptr,
+ BT_ptr,
+ SL_ptr,
+ O_ptr,
+ stride_qb,
+ stride_qh,
+ stride_qd,
+ stride_kv_block,
+ stride_kv_page,
+ stride_kv_h,
+ stride_kv_d,
+ stride_bt_b,
+ stride_bt_p,
+ stride_ob,
+ stride_oh,
+ stride_od,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ page_size,
+ group_size,
+ scale,
+ BLOCK_D: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+
+ kv_head_idx = head_idx // group_size
+ seq_len = tl.load(SL_ptr + batch_idx)
+
+ offs_d = tl.arange(0, BLOCK_D)
+ d_mask = offs_d < head_dim
+
+ q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
+ q = tl.load(q_ptr + offs_d * stride_qd, mask=d_mask, other=0.0).to(tl.float32)
+
+ m_i = tl.full([], -float("inf"), tl.float32)
+ l_i = tl.full([], 0.0, tl.float32)
+ acc = tl.zeros([BLOCK_D], dtype=tl.float32)
+
+ num_pages = (seq_len + page_size - 1) // page_size
+
+ for page_idx in range(num_pages):
+ block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
+
+ page_start = page_idx * page_size
+ tokens_in_page = tl.minimum(page_size, seq_len - page_start)
+
+ for t in range(BLOCK_T):
+ if t >= tokens_in_page:
+ break
+
+ token_idx = page_start + t
+ kv_base = (
+ block_id * stride_kv_block
+ + t * stride_kv_page
+ + kv_head_idx * stride_kv_h
+ )
+
+ k = tl.load(
+ KV_ptr + kv_base + offs_d * stride_kv_d,
+ mask=d_mask,
+ other=0.0,
+ ).to(tl.float32)
+ v = tl.load(
+ KV_ptr + kv_base + (head_dim + offs_d) * stride_kv_d,
+ mask=d_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ score = tl.sum(q * k, axis=0) * scale
+ m_ij = tl.maximum(m_i, score)
+ alpha = tl.exp(m_i - m_ij)
+ beta = tl.exp(score - m_ij)
+ l_i = l_i * alpha + beta
+ acc = acc * alpha + beta * v
+ m_i = m_ij
+
+ out = (acc / l_i).to(tl.bfloat16)
+ o_ptr = O_ptr + batch_idx * stride_ob + head_idx * stride_oh
+ tl.store(o_ptr + offs_d * stride_od, out, mask=d_mask)
+
+
+def paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ B, H, D = query.shape
+ _, P, Hkv, D2 = kv_cache.shape
+ assert D2 == 2 * D
+ group_size = H // Hkv
+
+ out = torch.empty_like(query)
+
+ BLOCK_D = triton.next_power_of_2(D)
+ BLOCK_T = P # page_size
+
+ grid = (B, H)
+ _paged_decode_kernel[grid](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ out,
+ query.stride(0),
+ query.stride(1),
+ query.stride(2),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ kv_cache.stride(3),
+ block_table.stride(0),
+ block_table.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ num_heads=H,
+ num_kv_heads=Hkv,
+ head_dim=D,
+ page_size=P,
+ group_size=group_size,
+ scale=scale,
+ BLOCK_D=BLOCK_D,
+ BLOCK_T=BLOCK_T,
+ num_warps=4,
+ num_stages=2,
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ batch: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ seq_len: int,
+ page_size: int,
+ ):
+ super().__init__()
+ assert num_heads % num_kv_heads == 0
+ self.batch = batch
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ self.seq_len = seq_len
+ self.page_size = page_size
+ self.group_size = num_heads // num_kv_heads
+ self.scale = 1.0 / math.sqrt(head_dim)
+ self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale)
+
+
+def get_inputs():
+ B = BATCH
+ H = NUM_HEADS
+ Hkv = NUM_KV_HEADS
+ D = HEAD_DIM
+ L = SEQ_LEN
+ P = PAGE_SIZE
+
+ pages_per_seq = (L + P - 1) // P
+ total_pages = max(B * pages_per_seq + 8, 64)
+
+ query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
+ kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
+
+ perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
+ block_table = perm.contiguous()
+ seq_lens = torch.full((B,), L, dtype=torch.int32)
+
+ return [query, kv_cache, block_table, seq_lens]
+
+
+def get_init_inputs():
+ return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]
{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<5405 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_9d9137cd-7f35-4134-82b9-1708c2de791",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText": "cd /workspace/problems/03_p
... (908 more chars)failure:
command: 'cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 187, in forward
return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 126, in paged_attention_decode
_paged_decode_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
... (4962 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,80 +1,5 @@
-"""Paged attention decode kernel via Triton."""
-import math
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-OP_TYPE = "attention"
-SUPPORTED_PRECISIONS = ["bf16"]
-HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-
-BATCH = 8
-NUM_HEADS = 32
-NUM_KV_HEADS = 8
-HEAD_DIM = 128
-SEQ_LEN = 1024
-PAGE_SIZE = 16
-
-
-@triton.jit
-def _paged_decode_kernel(
- Q_ptr,
- KV_ptr,
- BT_ptr,
- SL_ptr,
- O_ptr,
- stride_qb,
- stride_qh,
- stride_qd,
- stride_kv_block,
- stride_kv_page,
- stride_kv_h,
- stride_kv_d,
- stride_bt_b,
- stride_bt_p,
- stride_ob,
- stride_oh,
- stride_od,
- num_heads,
- num_kv_heads,
- head_dim,
- page_size,
- group_size,
- scale,
- BLOCK_D: tl.constexpr,
- BLOCK_T: tl.constexpr,
-):
- batch_idx = tl.program_id(0)
- head_idx = tl.program_id(1)
-
- kv_head_idx = head_idx // group_size
- seq_len = tl.load(SL_ptr + batch_idx)
-
- offs_d = tl.arange(0, BLOCK_D)
- d_mask = offs_d < head_dim
-
- q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
- q = tl.load(q_ptr + offs_d * stride_qd, mask=d_mask, other=0.0).to(tl.float32)
-
- m_i = tl.full([], -float("inf"), tl.float32)
- l_i = tl.full([], 0.0, tl.float32)
- acc = tl.zeros([BLOCK_D], dtype=tl.float32)
-
- num_pages = (seq_len + page_size - 1) // page_size
-
- for page_idx in range(num_pages):
- block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
-
- page_start = page_idx * page_size
- tokens_in_page = tl.minimum(page_size, seq_len - page_start)
-
for t in range(BLOCK_T):
- if t >= tokens_in_page:
- break
-
- token_idx = page_start + t
+ t_valid = t < tokens_in_page
kv_base = (
block_id * stride_kv_block
+ t * stride_kv_page
@@ -83,130 +8,20 @@
k = tl.load(
KV_ptr + kv_base + offs_d * stride_kv_d,
- mask=d_mask,
+ mask=d_mask & t_valid,
other=0.0,
).to(tl.float32)
v = tl.load(
KV_ptr + kv_base + (head_dim + offs_d) * stride_kv_d,
- mask=d_mask,
+ mask=d_mask & t_valid,
other=0.0,
).to(tl.float32)
score = tl.sum(q * k, axis=0) * scale
+ score = tl.where(t_valid, score, -float("inf"))
m_ij = tl.maximum(m_i, score)
alpha = tl.exp(m_i - m_ij)
beta = tl.exp(score - m_ij)
l_i = l_i * alpha + beta
acc = acc * alpha + beta * v
- m_i = m_ij
-
- out = (acc / l_i).to(tl.bfloat16)
- o_ptr = O_ptr + batch_idx * stride_ob + head_idx * stride_oh
- tl.store(o_ptr + offs_d * stride_od, out, mask=d_mask)
-
-
-def paged_attention_decode(
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- scale: float,
-) -> torch.Tensor:
- B, H, D = query.shape
- _, P, Hkv, D2 = kv_cache.shape
- assert D2 == 2 * D
- group_size = H // Hkv
-
- out = torch.empty_like(query)
-
- BLOCK_D = triton.next_power_of_2(D)
- BLOCK_T = P # page_size
-
- grid = (B, H)
- _paged_decode_kernel[grid](
- query,
- kv_cache,
- block_table,
- seq_lens,
- out,
- query.stride(0),
- query.stride(1),
- query.stride(2),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- kv_cache.stride(3),
- block_table.stride(0),
- block_table.stride(1),
- out.stride(0),
- out.stride(1),
- out.stride(2),
- num_heads=H,
- num_kv_heads=Hkv,
- head_dim=D,
- page_size=P,
- group_size=group_size,
- scale=scale,
- BLOCK_D=BLOCK_D,
- BLOCK_T=BLOCK_T,
- num_warps=4,
- num_stages=2,
- )
- return out
-
-
-class Model(nn.Module):
- def __init__(
- self,
- batch: int,
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- seq_len: int,
- page_size: int,
- ):
- super().__init__()
- assert num_heads % num_kv_heads == 0
- self.batch = batch
- self.num_heads = num_heads
- self.num_kv_heads = num_kv_heads
- self.head_dim = head_dim
- self.seq_len = seq_len
- self.page_size = page_size
- self.group_size = num_heads // num_kv_heads
- self.scale = 1.0 / math.sqrt(head_dim)
- self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
-
- def forward(
- self,
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- ) -> torch.Tensor:
- return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale)
-
-
-def get_inputs():
- B = BATCH
- H = NUM_HEADS
- Hkv = NUM_KV_HEADS
- D = HEAD_DIM
- L = SEQ_LEN
- P = PAGE_SIZE
-
- pages_per_seq = (L + P - 1) // P
- total_pages = max(B * pages_per_seq + 8, 64)
-
- query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
- kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
-
- perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
- block_table = perm.contiguous()
- seq_lens = torch.full((B,), L, dtype=torch.int32)
-
- return [query, kv_cache, block_table, seq_lens]
-
-
-def get_init_inputs():
- return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]
+ m_i = m_ij{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<936 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_cd952fd0-ecc2-40a5-83cf-372e2dc58db",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText": "cd /workspace/problems/03_p
... (914 more chars)null{
"command": "cd /workspace/problems/03_paged_attention && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_e5337487-7811-451c-b413-e06ef103e69",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText": "cd /workspace/problems/
... (924 more chars)null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,27 +1,234 @@
- for t in range(BLOCK_T):
- t_valid = t < tokens_in_page
- kv_base = (
- block_id * stride_kv_block
- + t * stride_kv_page
- + kv_head_idx * stride_kv_h
- )
-
- k = tl.load(
+"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
+import math
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+@triton.jit
+def _paged_decode_gqa_kernel(
+ Q_ptr,
+ KV_ptr,
+ BT_ptr,
+ SL_ptr,
+ O_ptr,
+ stride_qb,
+ stride_qh,
+ stride_qd,
+ stride_kv_block,
+ stride_kv_page,
+ stride_kv_h,
+ stride_kv_d,
+ stride_bt_b,
+ stride_bt_p,
+ stride_ob,
+ stride_oh,
+ stride_od,
+ num_heads,
+ head_dim,
+ page_size,
+ group_size,
+ scale,
+ BLOCK_G: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ kv_head_idx = tl.program_id(1)
+
+ seq_len = tl.load(SL_ptr + batch_idx)
+ num_pages = (seq_len + page_size - 1) // page_size
+
+ offs_g = tl.arange(0, BLOCK_G)
+ offs_t = tl.arange(0, BLOCK_T)
+ offs_d = tl.arange(0, BLOCK_D)
+ g_mask = (kv_head_idx * group_size + offs_g) < num_heads
+ d_mask = offs_d < head_dim
+
+ q = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
+ head_base = kv_head_idx * group_size
+ for g in tl.static_range(BLOCK_G):
+ head_idx = head_base + g
+ q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
+ q_row = tl.load(
+ q_ptr + offs_d * stride_qd,
+ mask=g_mask[g] & d_mask,
+ other=0.0,
+ ).to(tl.float32)
+ q = tl.where(g_mask[g], tl.expand_dims(q_row, 0), q) # need proper assignment
+
+ # Load Q rows explicitly
+ q_rows = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
+ for g in tl.static_range(BLOCK_G):
+ head_idx = head_base + g
+ q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
+ q_row = tl.load(q_ptr + offs_d * stride_qd, mask=d_mask, other=0.0).to(tl.float32)
+ q_rows = tl.where(tl.expand_dims(g_mask[g], 1), tl.expand_dims(q_row, 0), q_rows)
+ q = q_rows
+
+ m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
+ l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
+
+ for page_idx in range(num_pages):
+ block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
+ page_start = page_idx * page_size
+ tokens_in_page = tl.minimum(page_size, seq_len - page_start)
+
+ t_mask = offs_t[:, None] < tokens_in_page
+ kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
+
+ k = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.float32)
+ v = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.float32)
+ for t in tl.static_range(BLOCK_T):
+ kv_base = kv_h_base + t * stride_kv_page
+ k_row = tl.load(
KV_ptr + kv_base + offs_d * stride_kv_d,
- mask=d_mask & t_valid,
+ mask=d_mask,
other=0.0,
).to(tl.float32)
- v = tl.load(
+ v_row = tl.load(
KV_ptr + kv_base + (head_dim + offs_d) * stride_kv_d,
- mask=d_mask & t_valid,
+ mask=d_mask,
other=0.0,
).to(tl.float32)
-
- score = tl.sum(q * k, axis=0) * scale
- score = tl.where(t_valid, score, -float("inf"))
- m_ij = tl.maximum(m_i, score)
- alpha = tl.exp(m_i - m_ij)
- beta = tl.exp(score - m_ij)
- l_i = l_i * alpha + beta
- acc = acc * alpha + beta * v
- m_i = m_ij+ k = tl.where(tl.expand_dims(offs_t == t, 1), tl.expand_dims(k_row, 0), k)
+ v = tl.where(tl.expand_dims(offs_t == t, 1), tl.expand_dims(v_row, 0), v)
+
+ qk = tl.dot(q, tl.trans(k)) * scale
+ qk = tl.where(t_mask, qk, -float("inf"))
+
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+ alpha = tl.exp(m_i - m_ij)
+ p = tl.exp(qk - tl.expand_dims(m_ij, 1))
+ l_i = l_i * alpha + tl.sum(p, axis=1)
+ acc = acc * tl.expand_dims(alpha, 1) + tl.dot(p.to(tl.bfloat16), v)
+ m_i = m_ij
+
+ out_vals = acc / tl.expand_dims(l_i, 1)
+ for g in tl.static_range(BLOCK_G):
+ head_idx = head_base + g
+ o_ptr = O_ptr + batch_idx * stride_ob + head_idx * stride_oh
+ o_row = tl.load(out_vals + g) # invalid
+ tl.store(o_ptr + offs_d * stride_od, out_vals[g, :].to(tl.bfloat16), mask=g_mask[g] & d_mask)
+
+
+def paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ B, H, D = query.shape
+ _, P, Hkv, D2 = kv_cache.shape
+ assert D2 == 2 * D
+ group_size = H // Hkv
+
+ out = torch.empty_like(query)
+
+ BLOCK_G = triton.next_power_of_2(group_size)
+ BLOCK_T = P
+ BLOCK_D = triton.next_power_of_2(D)
+
+ grid = (B, Hkv)
+ _paged_decode_gqa_kernel[grid](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ out,
+ query.stride(0),
+ query.stride(1),
+ query.stride(2),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ kv_cache.stride(3),
+ block_table.stride(0),
+ block_table.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ num_heads=H,
+
... (1955 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<6870 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_6b6b8cfd-dc37-4ee0-b91b-86a271586fd",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText": "cd /workspace/problems/03_p
... (920 more chars)failure:
command: 'cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: constexpr[0]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/03_paged_attention/solution.py", line 209, in forward
return paged_attentio
... (6140 more chars){
"command": "<494 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_0fadcf3e-e516-474e-8809-74f71d90892",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText": "cd /workspace/problems/03_paged_attention"
},
{
"nam
... (1832 more chars)cd /workspace/problems/03_paged_attention && uv run python -c "
import torch
import reference, shapes
from benchmark import _apply_shape
device = torch.device('cuda:0')
_apply_shape(reference, shapes.SHAPES[0])
inputs = [t.to(device) for t in reference.get_inputs()]
try:
import sota
if sota.is_available():
q,kvc,bt,sl = inputs
out = sota.sota_forward(q,kvc,bt,sl,8,128,16)
print('flashinfer ok', out.shape)
except Exception as e:
print('no sota', e)
" 2>&1null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -55,29 +55,18 @@
offs_g = tl.arange(0, BLOCK_G)
offs_t = tl.arange(0, BLOCK_T)
offs_d = tl.arange(0, BLOCK_D)
- g_mask = (kv_head_idx * group_size + offs_g) < num_heads
+
+ head_ids = kv_head_idx * group_size + offs_g
+ g_mask = head_ids < num_heads
d_mask = offs_d < head_dim
- q = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
- head_base = kv_head_idx * group_size
- for g in tl.static_range(BLOCK_G):
- head_idx = head_base + g
- q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
- q_row = tl.load(
- q_ptr + offs_d * stride_qd,
- mask=g_mask[g] & d_mask,
- other=0.0,
- ).to(tl.float32)
- q = tl.where(g_mask[g], tl.expand_dims(q_row, 0), q) # need proper assignment
-
- # Load Q rows explicitly
- q_rows = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
- for g in tl.static_range(BLOCK_G):
- head_idx = head_base + g
- q_ptr = Q_ptr + batch_idx * stride_qb + head_idx * stride_qh
- q_row = tl.load(q_ptr + offs_d * stride_qd, mask=d_mask, other=0.0).to(tl.float32)
- q_rows = tl.where(tl.expand_dims(g_mask[g], 1), tl.expand_dims(q_row, 0), q_rows)
- q = q_rows
+ q_ptrs = (
+ Q_ptr
+ + batch_idx * stride_qb
+ + head_ids[:, None] * stride_qh
+ + offs_d[None, :] * stride_qd
+ )
+ q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
@@ -88,56 +77,51 @@
page_start = page_idx * page_size
tokens_in_page = tl.minimum(page_size, seq_len - page_start)
- t_mask = offs_t[:, None] < tokens_in_page
kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
-
- k = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.float32)
- v = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.float32)
- for t in tl.static_range(BLOCK_T):
- kv_base = kv_h_base + t * stride_kv_page
- k_row = tl.load(
- KV_ptr + kv_base + offs_d * stride_kv_d,
- mask=d_mask,
- other=0.0,
- ).to(tl.float32)
- v_row = tl.load(
- KV_ptr + kv_base + (head_dim + offs_d) * stride_kv_d,
- mask=d_mask,
- other=0.0,
- ).to(tl.float32)
- k = tl.where(tl.expand_dims(offs_t == t, 1), tl.expand_dims(k_row, 0), k)
- v = tl.where(tl.expand_dims(offs_t == t, 1), tl.expand_dims(v_row, 0), v)
+ t_valid = offs_t[:, None] < tokens_in_page
+
+ k_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + offs_d[None, :] * stride_kv_d
+ )
+ v_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + (head_dim + offs_d)[None, :] * stride_kv_d
+ )
+ kv_mask = t_valid & d_mask[None, :]
+
+ k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+ v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
qk = tl.dot(q, tl.trans(k)) * scale
- qk = tl.where(t_mask, qk, -float("inf"))
+ qk = tl.where(t_valid, qk, -float("inf"))
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
alpha = tl.exp(m_i - m_ij)
- p = tl.exp(qk - tl.expand_dims(m_ij, 1))
+ p = tl.exp(qk - m_ij[:, None])
l_i = l_i * alpha + tl.sum(p, axis=1)
- acc = acc * tl.expand_dims(alpha, 1) + tl.dot(p.to(tl.bfloat16), v)
+ acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v)
m_i = m_ij
- out_vals = acc / tl.expand_dims(l_i, 1)
- for g in tl.static_range(BLOCK_G):
- head_idx = head_base + g
- o_ptr = O_ptr + batch_idx * stride_ob + head_idx * stride_oh
- o_row = tl.load(out_vals + g) # invalid
- tl.store(o_ptr + offs_d * stride_od, out_vals[g, :].to(tl.bfloat16), mask=g_mask[g] & d_mask)
-
-
-def paged_attention_decode(
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- scale: float,
-) -> torch.Tensor:
+ out_vals = acc / l_i[:, None]
+ o_ptrs = (
+ O_ptr
+ + batch_idx * stride_ob
+ + head_ids[:, None] * stride_oh
+ + offs_d[None, :] * stride_od
+ )
+ tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :])
+
+
+def _launch(query, kv_cache, block_table, seq_lens, scale, num_warps, num_stages):
B, H, D = query.shape
_, P, Hkv, D2 = kv_cache.shape
assert D2 == 2 * D
group_size = H // Hkv
-
out = torch.empty_like(query)
BLOCK_G = triton.next_power_of_2(group_size)
@@ -171,10 +155,23 @@
BLOCK_G=BLOCK_G,
BLOCK_T=BLOCK_T,
BLOCK_D=BLOCK_D,
- num_warps=4 if D <= 64 else 8,
- num_stages=2,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
return out
+
+
+def paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ D = query.shape[-1]
+ if D <= 64:
+ return _launch(query, kv_cache, block_table, seq_lens, scale, num_warps=4, num_stages=2)
+ return _launch(query, kv_cache, block_table, seq_lens, scale, num_warps=8, num_stages=3)
class Model(nn.Module):
{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<6231 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_32aa772e-6e86-4156-a04f-b9b885080b4",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
... (1485 more chars)failure:
command: 'cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2476, in where
return _semantic.where(condition, x, y)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1677, in where
condition, x = self.broadcast_impl_value(condition, x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 752, in broadcast_impl_value
raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
ValueError: Cannot make_shape_compatible: incompatible dimensions at index 0: 16 and 4
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^
... (7739 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,231 +1,5 @@
-"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
-import math
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-OP_TYPE = "attention"
-SUPPORTED_PRECISIONS = ["bf16"]
-HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-
-BATCH = 8
-NUM_HEADS = 32
-NUM_KV_HEADS = 8
-HEAD_DIM = 128
-SEQ_LEN = 1024
-PAGE_SIZE = 16
-
-
-@triton.jit
-def _paged_decode_gqa_kernel(
- Q_ptr,
- KV_ptr,
- BT_ptr,
- SL_ptr,
- O_ptr,
- stride_qb,
- stride_qh,
- stride_qd,
- stride_kv_block,
- stride_kv_page,
- stride_kv_h,
- stride_kv_d,
- stride_bt_b,
- stride_bt_p,
- stride_ob,
- stride_oh,
- stride_od,
- num_heads,
- head_dim,
- page_size,
- group_size,
- scale,
- BLOCK_G: tl.constexpr,
- BLOCK_T: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- batch_idx = tl.program_id(0)
- kv_head_idx = tl.program_id(1)
-
- seq_len = tl.load(SL_ptr + batch_idx)
- num_pages = (seq_len + page_size - 1) // page_size
-
- offs_g = tl.arange(0, BLOCK_G)
- offs_t = tl.arange(0, BLOCK_T)
- offs_d = tl.arange(0, BLOCK_D)
-
- head_ids = kv_head_idx * group_size + offs_g
- g_mask = head_ids < num_heads
- d_mask = offs_d < head_dim
-
- q_ptrs = (
- Q_ptr
- + batch_idx * stride_qb
- + head_ids[:, None] * stride_qh
- + offs_d[None, :] * stride_qd
- )
- q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
-
- m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
- l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
- acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
-
- for page_idx in range(num_pages):
- block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
page_start = page_idx * page_size
tokens_in_page = tl.minimum(page_size, seq_len - page_start)
kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
- t_valid = offs_t[:, None] < tokens_in_page
-
- k_ptrs = (
- KV_ptr
- + kv_h_base
- + offs_t[:, None] * stride_kv_page
- + offs_d[None, :] * stride_kv_d
- )
- v_ptrs = (
- KV_ptr
- + kv_h_base
- + offs_t[:, None] * stride_kv_page
- + (head_dim + offs_d)[None, :] * stride_kv_d
- )
- kv_mask = t_valid & d_mask[None, :]
-
- k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
- v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
-
- qk = tl.dot(q, tl.trans(k)) * scale
- qk = tl.where(t_valid, qk, -float("inf"))
-
- m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
- alpha = tl.exp(m_i - m_ij)
- p = tl.exp(qk - m_ij[:, None])
- l_i = l_i * alpha + tl.sum(p, axis=1)
- acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v)
- m_i = m_ij
-
- out_vals = acc / l_i[:, None]
- o_ptrs = (
- O_ptr
- + batch_idx * stride_ob
- + head_ids[:, None] * stride_oh
- + offs_d[None, :] * stride_od
- )
- tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :])
-
-
-def _launch(query, kv_cache, block_table, seq_lens, scale, num_warps, num_stages):
- B, H, D = query.shape
- _, P, Hkv, D2 = kv_cache.shape
- assert D2 == 2 * D
- group_size = H // Hkv
- out = torch.empty_like(query)
-
- BLOCK_G = triton.next_power_of_2(group_size)
- BLOCK_T = P
- BLOCK_D = triton.next_power_of_2(D)
-
- grid = (B, Hkv)
- _paged_decode_gqa_kernel[grid](
- query,
- kv_cache,
- block_table,
- seq_lens,
- out,
- query.stride(0),
- query.stride(1),
- query.stride(2),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- kv_cache.stride(3),
- block_table.stride(0),
- block_table.stride(1),
- out.stride(0),
- out.stride(1),
- out.stride(2),
- num_heads=H,
- head_dim=D,
- page_size=P,
- group_size=group_size,
- scale=scale,
- BLOCK_G=BLOCK_G,
- BLOCK_T=BLOCK_T,
- BLOCK_D=BLOCK_D,
- num_warps=num_warps,
- num_stages=num_stages,
- )
- return out
-
-
-def paged_attention_decode(
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- scale: float,
-) -> torch.Tensor:
- D = query.shape[-1]
- if D <= 64:
- return _launch(query, kv_cache, block_table, seq_lens, scale, num_warps=4, num_stages=2)
- return _launch(query, kv_cache, block_table, seq_lens, scale, num_warps=8, num_stages=3)
-
-
-class Model(nn.Module):
- def __init__(
- self,
- batch: int,
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- seq_len: int,
- page_size: int,
- ):
- super().__init__()
- assert num_heads % num_kv_heads == 0
- self.batch = batch
- self.num_heads = num_heads
- self.num_kv_heads = num_kv_heads
- self.head_dim = head_dim
- self.seq_len = seq_len
- self.page_size = page_size
- self.group_size = num_heads // num_kv_heads
- self.scale = 1.0 / math.sqrt(head_dim)
- self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
-
- def forward(
- self,
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- ) -> torch.Tensor:
- return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale)
-
-
-def get_inputs():
- B = BATCH
- H = NUM_HEADS
- Hkv = NUM_KV_HEADS
- D = HEAD_DIM
- L = SEQ_LEN
- P =
... (645 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<237 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,5 +1,3 @@
- page_start = page_idx * page_size
- tokens_in_page = tl.minimum(page_size, seq_len - page_start)
-
kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
- t_valid = offs_t[None, :] < tokens_in_page+ t_valid_row = offs_t[None, :] < tokens_in_page
+ t_valid_2d = offs_t[:, None] < tokens_in_page{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<183 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,3 +1 @@
- kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
- t_valid_row = offs_t[None, :] < tokens_in_page
- t_valid_2d = offs_t[:, None] < tokens_in_page+ kv_mask = t_valid_2d & d_mask[None, :]{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": " kv_mask = t_valid_2d & d_mask[None, :]"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1 +1 @@
- kv_mask = t_valid_2d & d_mask[None, :]+ qk = tl.where(t_valid_row, qk, -float("inf")){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": " qk = tl.where(t_valid_row, qk, -float(\"inf\"))"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_790720c3-2c3a-4959-a7fb-0ded62ed38b",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
... (1485 more chars)failure:
command: 'cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2054, in dot
res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1492, in dot
assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Both operands must be same dtype. Got bf16 and fp32
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/03_paged_attention/check.py", line 116, in <module>
main()
File "/workspace/problems/03_paged_attention/check.py", line 79, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in
... (7609 more chars)--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1 +1 @@
- qk = tl.where(t_valid_row, qk, -float("inf"))+ acc = acc * alpha[:, None] + tl.dot(p, v){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": " acc = acc * alpha[:, None] + tl.dot(p, v)"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_acd1aa3e-b794-4a14-974e-9ff6bba3087",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
... (1485 more chars)null{
"command": "cd /workspace/problems/03_paged_attention && KBH_BENCHMARK_BASELINES=1 uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_74c4a366-9fbb-4038-a480-03c291a9404",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
"fullText"
... (954 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py"
}"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16
@triton.jit
def _paged_decode_gqa_kernel(
Q_ptr,
KV_ptr,
BT_ptr,
SL_ptr,
O_ptr,
stride_qb,
stride_qh,
stride_qd,
stride_kv_block,
stride_kv_page,
stride_kv_h,
stride_kv_d,
stride_bt_b,
stride_bt_p,
stride_ob,
stride_oh,
stride_od,
num_heads,
head_dim,
page_size,
group_size,
scale,
BLOCK_G: tl.constexpr,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
batch_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
seq_len = tl.load(SL_ptr + batch_idx)
num_pages = (seq_len + page_size - 1) // page_size
offs_g = tl.arange(0, BLOCK_G)
offs_t = tl.arange(0, BLOCK_T)
offs_d = tl.arange(0, BLOCK_D)
head_ids = kv_head_idx * group_size + offs_g
g_mask = head_ids < num_heads
d_mask = offs_d < head_dim
q_ptrs = (
Q_ptr
+ batch_idx * stride_qb
+ head_ids[:, None] * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
... (4780 more chars)null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1 +1,247 @@
- acc = acc * alpha[:, None] + tl.dot(p, v)+"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
+import math
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64, "PAGE_UNROLL": 1}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64, "PAGE_UNROLL": 2}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 32, "BLOCK_D": 64, "PAGE_UNROLL": 2}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128, "PAGE_UNROLL": 1}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128, "PAGE_UNROLL": 2}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 32, "BLOCK_D": 128, "PAGE_UNROLL": 2}, num_warps=8, num_stages=4),
+ ],
+ key=["num_heads", "head_dim", "page_size", "group_size"],
+)
+@triton.jit
+def _paged_decode_gqa_kernel(
+ Q_ptr,
+ KV_ptr,
+ BT_ptr,
+ SL_ptr,
+ O_ptr,
+ stride_qb,
+ stride_qh,
+ stride_qd,
+ stride_kv_block,
+ stride_kv_page,
+ stride_kv_h,
+ stride_kv_d,
+ stride_bt_b,
+ stride_bt_p,
+ stride_ob,
+ stride_oh,
+ stride_od,
+ num_heads,
+ head_dim,
+ page_size,
+ group_size,
+ scale,
+ BLOCK_G: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ PAGE_UNROLL: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ kv_head_idx = tl.program_id(1)
+
+ seq_len = tl.load(SL_ptr + batch_idx)
+ num_pages = (seq_len + page_size - 1) // page_size
+ tokens_per_iter = BLOCK_T // PAGE_UNROLL
+
+ offs_g = tl.arange(0, BLOCK_G)
+ offs_t = tl.arange(0, BLOCK_T)
+ offs_d = tl.arange(0, BLOCK_D)
+
+ head_ids = kv_head_idx * group_size + offs_g
+ g_mask = head_ids < num_heads
+ d_mask = offs_d < head_dim
+
+ q_ptrs = (
+ Q_ptr
+ + batch_idx * stride_qb
+ + head_ids[:, None] * stride_qh
+ + offs_d[None, :] * stride_qd
+ )
+ q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0)
+
+ m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
+ l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
+
+ for page_base in range(0, num_pages, PAGE_UNROLL):
+ k = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.bfloat16)
+ v = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.bfloat16)
+
+ for pu in tl.static_range(PAGE_UNROLL):
+ page_idx = page_base + pu
+ page_active = page_idx < num_pages
+
+ block_id = tl.load(
+ BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p,
+ mask=page_active,
+ other=0,
+ ).to(tl.int64)
+
+ page_start = page_idx * page_size
+ tokens_in_page = tl.minimum(page_size, seq_len - page_start)
+ kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
+
+ local_t = offs_t - pu * tokens_per_iter
+ t_in_page = (local_t >= 0) & (local_t < tokens_per_iter)
+ t_valid = t_in_page & (local_t < tokens_in_page) & page_active
+
+ k_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + local_t[:, None] * stride_kv_page
+ + offs_d[None, :] * stride_kv_d
+ )
+ v_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + local_t[:, None] * stride_kv_page
+ + (head_dim + offs_d)[None, :] * stride_kv_d
+ )
+ kv_mask = t_valid[:, None] & d_mask[None, :]
+
+ k_page = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+ v_page = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+
+ slot = pu * tokens_per_iter + local_t
+ k = tl.where(slot[:, None] == offs_t[:, None], k_page, k)
+ v = tl.where(slot[:, None] == offs_t[:, None], v_page, v)
+
+ t_valid_row = offs_t[None, :] < tl.minimum(seq_len - page_base * tokens_per_iter, BLOCK_T)
+ qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale
+ qk = tl.where(t_valid_row, qk, -float("inf"))
+
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+ alpha = tl.exp(m_i - m_ij)
+ p = tl.exp(qk - m_ij[:, None])
+ l_i = l_i * alpha + tl.sum(p, axis=1)
+ acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v).to(tl.float32)
+ m_i = m_ij
+
+ out_vals = acc / l_i[:, None]
+ o_ptrs = (
+ O_ptr
+ + batch_idx * stride_ob
+ + head_ids[:, None] * stride_oh
+ + offs_d[None, :] * stride_od
+ )
+ tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :])
+
+
+def paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ B, H, D = query.shape
+ _, P, Hkv, D2 = kv_cache.shape
+ assert D2 == 2 * D
+ group_size = H // Hkv
+ out = torch.empty_like(query)
+
+ grid = (B, Hkv)
+ _paged_decode_gqa_kernel[grid](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ out,
+ query.stride(0),
+ query.stride(1),
+ query.stride(2),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ kv_cache.stride(3),
+ block_table.stride(0),
+ block_table.stride(1),
+ out.stride(0),
+ out.stride(1),
+ o
... (1852 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<7425 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -20,12 +20,12 @@
@triton.autotune(
configs=[
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64, "PAGE_UNROLL": 1}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64, "PAGE_UNROLL": 2}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 32, "BLOCK_D": 64, "PAGE_UNROLL": 2}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128, "PAGE_UNROLL": 1}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128, "PAGE_UNROLL": 2}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 32, "BLOCK_D": 128, "PAGE_UNROLL": 2}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=4),
],
key=["num_heads", "head_dim", "page_size", "group_size"],
)
@@ -56,14 +56,12 @@
BLOCK_G: tl.constexpr,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
- PAGE_UNROLL: tl.constexpr,
):
batch_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
seq_len = tl.load(SL_ptr + batch_idx)
num_pages = (seq_len + page_size - 1) // page_size
- tokens_per_iter = BLOCK_T // PAGE_UNROLL
offs_g = tl.arange(0, BLOCK_G)
offs_t = tl.arange(0, BLOCK_T)
@@ -85,50 +83,32 @@
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
- for page_base in range(0, num_pages, PAGE_UNROLL):
- k = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.bfloat16)
- v = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.bfloat16)
-
- for pu in tl.static_range(PAGE_UNROLL):
- page_idx = page_base + pu
- page_active = page_idx < num_pages
-
- block_id = tl.load(
- BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p,
- mask=page_active,
- other=0,
- ).to(tl.int64)
-
- page_start = page_idx * page_size
- tokens_in_page = tl.minimum(page_size, seq_len - page_start)
- kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
-
- local_t = offs_t - pu * tokens_per_iter
- t_in_page = (local_t >= 0) & (local_t < tokens_per_iter)
- t_valid = t_in_page & (local_t < tokens_in_page) & page_active
-
- k_ptrs = (
- KV_ptr
- + kv_h_base
- + local_t[:, None] * stride_kv_page
- + offs_d[None, :] * stride_kv_d
- )
- v_ptrs = (
- KV_ptr
- + kv_h_base
- + local_t[:, None] * stride_kv_page
- + (head_dim + offs_d)[None, :] * stride_kv_d
- )
- kv_mask = t_valid[:, None] & d_mask[None, :]
-
- k_page = tl.load(k_ptrs, mask=kv_mask, other=0.0)
- v_page = tl.load(v_ptrs, mask=kv_mask, other=0.0)
-
- slot = pu * tokens_per_iter + local_t
- k = tl.where(slot[:, None] == offs_t[:, None], k_page, k)
- v = tl.where(slot[:, None] == offs_t[:, None], v_page, v)
-
- t_valid_row = offs_t[None, :] < tl.minimum(seq_len - page_base * tokens_per_iter, BLOCK_T)
+ for page_idx in range(num_pages):
+ block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
+ page_start = page_idx * page_size
+ tokens_in_page = tl.minimum(page_size, seq_len - page_start)
+
+ kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
+ t_valid_row = offs_t[None, :] < tokens_in_page
+ t_valid_2d = offs_t[:, None] < tokens_in_page
+
+ k_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + offs_d[None, :] * stride_kv_d
+ )
+ v_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + (head_dim + offs_d)[None, :] * stride_kv_d
+ )
+ kv_mask = t_valid_2d & d_mask[None, :]
+
+ k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+ v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale
qk = tl.where(t_valid_row, qk, -float("inf"))
@@ -161,6 +141,10 @@
assert D2 == 2 * D
group_size = H // Hkv
out = torch.empty_like(query)
+
+ BLOCK_G = triton.next_power_of_2(group_size)
+ BLOCK_T = P
+ BLOCK_D = triton.next_power_of_2(D)
grid = (B, Hkv)
_paged_decode_gqa_kernel[grid](
@@ -186,6 +170,9 @@
page_size=P,
group_size=group_size,
scale=scale,
+ BLOCK_G=BLOCK_G,
+ BLOCK_T=BLOCK_T,
+ BLOCK_D=BLOCK_D,
)
return out
{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<6588 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,151 +1,3 @@
-"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
-import math
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-OP_TYPE = "attention"
-SUPPORTED_PRECISIONS = ["bf16"]
-HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-
-BATCH = 8
-NUM_HEADS = 32
-NUM_KV_HEADS = 8
-HEAD_DIM = 128
-SEQ_LEN = 1024
-PAGE_SIZE = 16
-
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=4),
- ],
- key=["num_heads", "head_dim", "page_size", "group_size"],
-)
-@triton.jit
-def _paged_decode_gqa_kernel(
- Q_ptr,
- KV_ptr,
- BT_ptr,
- SL_ptr,
- O_ptr,
- stride_qb,
- stride_qh,
- stride_qd,
- stride_kv_block,
- stride_kv_page,
- stride_kv_h,
- stride_kv_d,
- stride_bt_b,
- stride_bt_p,
- stride_ob,
- stride_oh,
- stride_od,
- num_heads,
- head_dim,
- page_size,
- group_size,
- scale,
- BLOCK_G: tl.constexpr,
- BLOCK_T: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- batch_idx = tl.program_id(0)
- kv_head_idx = tl.program_id(1)
-
- seq_len = tl.load(SL_ptr + batch_idx)
- num_pages = (seq_len + page_size - 1) // page_size
-
- offs_g = tl.arange(0, BLOCK_G)
- offs_t = tl.arange(0, BLOCK_T)
- offs_d = tl.arange(0, BLOCK_D)
-
- head_ids = kv_head_idx * group_size + offs_g
- g_mask = head_ids < num_heads
- d_mask = offs_d < head_dim
-
- q_ptrs = (
- Q_ptr
- + batch_idx * stride_qb
- + head_ids[:, None] * stride_qh
- + offs_d[None, :] * stride_qd
- )
- q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0)
-
- m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
- l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
- acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
-
- for page_idx in range(num_pages):
- block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
- page_start = page_idx * page_size
- tokens_in_page = tl.minimum(page_size, seq_len - page_start)
-
- kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
- t_valid_row = offs_t[None, :] < tokens_in_page
- t_valid_2d = offs_t[:, None] < tokens_in_page
-
- k_ptrs = (
- KV_ptr
- + kv_h_base
- + offs_t[:, None] * stride_kv_page
- + offs_d[None, :] * stride_kv_d
- )
- v_ptrs = (
- KV_ptr
- + kv_h_base
- + offs_t[:, None] * stride_kv_page
- + (head_dim + offs_d)[None, :] * stride_kv_d
- )
- kv_mask = t_valid_2d & d_mask[None, :]
-
- k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
- v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
-
- qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale
- qk = tl.where(t_valid_row, qk, -float("inf"))
-
- m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
- alpha = tl.exp(m_i - m_ij)
- p = tl.exp(qk - m_ij[:, None])
- l_i = l_i * alpha + tl.sum(p, axis=1)
- acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v).to(tl.float32)
- m_i = m_ij
-
- out_vals = acc / l_i[:, None]
- o_ptrs = (
- O_ptr
- + batch_idx * stride_ob
- + head_ids[:, None] * stride_oh
- + offs_d[None, :] * stride_od
- )
- tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :])
-
-
-def paged_attention_decode(
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- scale: float,
-) -> torch.Tensor:
- B, H, D = query.shape
- _, P, Hkv, D2 = kv_cache.shape
- assert D2 == 2 * D
- group_size = H // Hkv
- out = torch.empty_like(query)
-
- BLOCK_G = triton.next_power_of_2(group_size)
- BLOCK_T = P
- BLOCK_D = triton.next_power_of_2(D)
-
grid = (B, Hkv)
_paged_decode_gqa_kernel[grid](
query,
@@ -170,65 +22,4 @@
page_size=P,
group_size=group_size,
scale=scale,
- BLOCK_G=BLOCK_G,
- BLOCK_T=BLOCK_T,
- BLOCK_D=BLOCK_D,
- )
- return out
-
-
-class Model(nn.Module):
- def __init__(
- self,
- batch: int,
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- seq_len: int,
- page_size: int,
- ):
- super().__init__()
- assert num_heads % num_kv_heads == 0
- self.batch = batch
- self.num_heads = num_heads
- self.num_kv_heads = num_kv_heads
- self.head_dim = head_dim
- self.seq_len = seq_len
- self.page_size = page_size
- self.group_size = num_heads // num_kv_heads
- self.scale = 1.0 / math.sqrt(head_dim)
- self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
-
- def forward(
- self,
- query: torch.Tensor,
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- seq_lens: torch.Tensor,
- ) -> torch.Tensor:
- return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale)
-
-
-def get_inputs():
- B = BATCH
- H = NUM_HEADS
- Hkv = NUM_KV_HEADS
- D = HEAD_DIM
- L = SEQ_LEN
- P = PAGE_SIZE
-
- pages_per_seq = (L + P - 1) // P
- total_pages
... (533 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<578 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,25 +1,9 @@
- grid = (B, Hkv)
- _paged_decode_gqa_kernel[grid](
- query,
- kv_cache,
- block_table,
- seq_lens,
- out,
- query.stride(0),
- query.stride(1),
- query.stride(2),
- kv_cache.stride(0),
- kv_cache.stride(1),
- kv_cache.stride(2),
- kv_cache.stride(3),
- block_table.stride(0),
- block_table.stride(1),
- out.stride(0),
- out.stride(1),
- out.stride(2),
- num_heads=H,
- head_dim=D,
- page_size=P,
- group_size=group_size,
- scale=scale,
- )+ configs=[
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=4),
+ ],{
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<695 chars \u2014 see diff>"
}null--- a//workspace/problems/03_paged_attention/solution.py
+++ b//workspace/problems/03_paged_attention/solution.py
@@ -1,9 +1,230 @@
- configs=[
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_G": 4, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_G": 8, "BLOCK_T": 16, "BLOCK_D": 128}, num_warps=8, num_stages=4),
- ],+"""Paged attention decode kernel via Triton (GQA-aware, page-batched)."""
+import math
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+BATCH = 8
+NUM_HEADS = 32
+NUM_KV_HEADS = 8
+HEAD_DIM = 128
+SEQ_LEN = 1024
+PAGE_SIZE = 16
+
+
+@triton.jit
+def _paged_decode_gqa_kernel(
+ Q_ptr,
+ KV_ptr,
+ BT_ptr,
+ SL_ptr,
+ O_ptr,
+ stride_qb,
+ stride_qh,
+ stride_qd,
+ stride_kv_block,
+ stride_kv_page,
+ stride_kv_h,
+ stride_kv_d,
+ stride_bt_b,
+ stride_bt_p,
+ stride_ob,
+ stride_oh,
+ stride_od,
+ num_heads,
+ head_dim,
+ page_size,
+ group_size,
+ scale,
+ BLOCK_G: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ kv_head_idx = tl.program_id(1)
+
+ seq_len = tl.load(SL_ptr + batch_idx)
+ num_pages = (seq_len + page_size - 1) // page_size
+
+ offs_g = tl.arange(0, BLOCK_G)
+ offs_t = tl.arange(0, BLOCK_T)
+ offs_d = tl.arange(0, BLOCK_D)
+
+ head_ids = kv_head_idx * group_size + offs_g
+ g_mask = head_ids < num_heads
+ d_mask = offs_d < head_dim
+
+ q_ptrs = (
+ Q_ptr
+ + batch_idx * stride_qb
+ + head_ids[:, None] * stride_qh
+ + offs_d[None, :] * stride_qd
+ )
+ q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0)
+
+ m_i = tl.full([BLOCK_G], -float("inf"), tl.float32)
+ l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32)
+
+ for page_idx in range(num_pages):
+ block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64)
+ page_start = page_idx * page_size
+ tokens_in_page = tl.minimum(page_size, seq_len - page_start)
+
+ kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h
+ t_valid_row = offs_t[None, :] < tokens_in_page
+ t_valid_2d = offs_t[:, None] < tokens_in_page
+
+ k_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + offs_d[None, :] * stride_kv_d
+ )
+ v_ptrs = (
+ KV_ptr
+ + kv_h_base
+ + offs_t[:, None] * stride_kv_page
+ + (head_dim + offs_d)[None, :] * stride_kv_d
+ )
+ kv_mask = t_valid_2d & d_mask[None, :]
+
+ k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+ v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+
+ qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale
+ qk = tl.where(t_valid_row, qk, -float("inf"))
+
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+ alpha = tl.exp(m_i - m_ij)
+ p = tl.exp(qk - m_ij[:, None])
+ l_i = l_i * alpha + tl.sum(p, axis=1)
+ acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v).to(tl.float32)
+ m_i = m_ij
+
+ out_vals = acc / l_i[:, None]
+ o_ptrs = (
+ O_ptr
+ + batch_idx * stride_ob
+ + head_ids[:, None] * stride_oh
+ + offs_d[None, :] * stride_od
+ )
+ tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :])
+
+
+def paged_attention_decode(
+ query: torch.Tensor,
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ seq_lens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ B, H, D = query.shape
+ _, P, Hkv, D2 = kv_cache.shape
+ assert D2 == 2 * D
+ group_size = H // Hkv
+ out = torch.empty_like(query)
+
+ BLOCK_G = triton.next_power_of_2(group_size)
+ BLOCK_T = P
+ BLOCK_D = triton.next_power_of_2(D)
+
+ if D <= 64:
+ num_warps, num_stages = 4, 3
+ else:
+ num_warps, num_stages = 8, 3
+
+ grid = (B, Hkv)
+ _paged_decode_gqa_kernel[grid](
+ query,
+ kv_cache,
+ block_table,
+ seq_lens,
+ out,
+ query.stride(0),
+ query.stride(1),
+ query.stride(2),
+ kv_cache.stride(0),
+ kv_cache.stride(1),
+ kv_cache.stride(2),
+ kv_cache.stride(3),
+ block_table.stride(0),
+ block_table.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ num_heads=H,
+ head_dim=D,
+ page_size=P,
+ group_size=group_size,
+ scale=scale,
+ BLOCK_G=BLOCK_G,
+ BLOCK_T=BLOCK_T,
+ BLOCK_D=BLOCK_D,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ batch: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ seq_len: int,
+ page_size: int,
+ ):
+ super().__init__()
+ assert num_heads % num_kv_heads == 0
+ self.batch = batch
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ self.seq_len = seq_len
+ self.page_size = page_size
+ self.group_size = num_heads // num_kv_heads
+ se
... (1133 more chars){
"path": "/workspace/problems/03_paged_attention/solution.py",
"streamContent": "<6067 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/03_paged_attention && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_5f3a524f-d617-4a2b-a6dd-8bebe4a6781",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/03_paged_attention"
}
],
... (1491 more chars)null