./kernelbenchhardv3runsdroid · custom:GLM-5.1-[Z.AI-Coding-Plan]-0
droid / custom:GLM-5.1-[Z.AI-Coding-Plan]-0 session: 07b57880-7113-449e-b13c-38c76fb33777 cwd: /tmp/KernelBench-Hard-zai-droid/problems/03_paged_attention
INCOMPLETE SESSION. The run hit its wall-clock budget (SIGTERM). The transcript below is usable but may be missing the agent's final tool calls or summary. Don't score this run as a clean failure or success.
harness: droid
model: custom:GLM-5.1-[Z.AI-Coding-Plan]-0
turns: 28
tools called: 0
events: 30
input toks: 90,701
output toks: 109,517
cache hit: 3,109,568
solution.py
benchmark.log
check.log
result.json
"""Paged attention decode kernel – FlashDecoding + GQA-shared Triton for SM120.

Optimizations:
  - GQA sharing: each K/V loaded once per (batch, kv_head), shared across group_size query heads.
  - FlashDecoding splits sequence into chunks for parallelism.
  - Two-phase: partial softmax per chunk, then reduce with log-sum-exp.
  - Specialized kernels for group_size 2, 4, 8.
"""
import math

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def _partial_gs4(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS

    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)

    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (hkv * 4) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (hkv * 4 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q2 = tl.load(q_base + (hkv * 4 + 2) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q3 = tl.load(q_base + (hkv * 4 + 3) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    m2 = tl.full([], float('-inf'), dtype=tl.float32)
    m3 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    se2 = tl.zeros([], dtype=tl.float32)
    se3 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a2 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a3 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            page_id = tl.load(block_table_ptr + b * stride_btb + page_idx * stride_btp).to(tl.int32)
            page_base = kv_cache_ptr + page_id * stride_kb + hkv * stride_kh
            token_off = page_base + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)

            # Head 0
            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm
            # Head 1
            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm
            # Head 2
            sc = tl.sum(q2 * k) * scale
            nm = tl.maximum(m2, sc); rc = tl.exp(m2 - nm)
            a2 = a2 * rc; se2 = se2 * rc
            e = tl.exp(sc - nm); se2 = se2 + e; a2 = a2 + e * v; m2 = nm
            # Head 3
            sc = tl.sum(q3 * k) * scale
            nm = tl.maximum(m3, sc); rc = tl.exp(m3 - nm)
            a3 = a3 * rc; se3 = se3 * rc
            e = tl.exp(sc - nm); se3 = se3 + e; a3 = a3 + e * v; m3 = nm

    # Store partials for all 4 heads
    tl.store(partial_out_ptr + b * stride_pob + (hkv*4+0) * stride_poh + s * stride_pos + offs_d * stride_pod, a0, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*4+0) * stride_mih + s * stride_mis, m0)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*4+0) * stride_lih + s * stride_lis, se0)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*4+1) * stride_poh + s * stride_pos + offs_d * stride_pod, a1, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*4+1) * stride_mih + s * stride_mis, m1)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*4+1) * stride_lih + s * stride_lis, se1)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*4+2) * stride_poh + s * stride_pos + offs_d * stride_pod, a2, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*4+2) * stride_mih + s * stride_mis, m2)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*4+2) * stride_lih + s * stride_lis, se2)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*4+3) * stride_poh + s * stride_pos + offs_d * stride_pod, a3, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*4+3) * stride_mih + s * stride_mis, m3)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*4+3) * stride_lih + s * stride_lis, se3)


# gs8 kernel for Llama-3 70B style (64/8 = 8 group_size)
@triton.jit
def _partial_gs8(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS

    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)

    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (hkv * 8 + 0) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (hkv * 8 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q2 = tl.load(q_base + (hkv * 8 + 2) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q3 = tl.load(q_base + (hkv * 8 + 3) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q4 = tl.load(q_base + (hkv * 8 + 4) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q5 = tl.load(q_base + (hkv * 8 + 5) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q6 = tl.load(q_base + (hkv * 8 + 6) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q7 = tl.load(q_base + (hkv * 8 + 7) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    m2 = tl.full([], float('-inf'), dtype=tl.float32)
    m3 = tl.full([], float('-inf'), dtype=tl.float32)
    m4 = tl.full([], float('-inf'), dtype=tl.float32)
    m5 = tl.full([], float('-inf'), dtype=tl.float32)
    m6 = tl.full([], float('-inf'), dtype=tl.float32)
    m7 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    se2 = tl.zeros([], dtype=tl.float32)
    se3 = tl.zeros([], dtype=tl.float32)
    se4 = tl.zeros([], dtype=tl.float32)
    se5 = tl.zeros([], dtype=tl.float32)
    se6 = tl.zeros([], dtype=tl.float32)
    se7 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a2 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a3 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a4 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a5 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a6 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a7 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            page_id = tl.load(block_table_ptr + b * stride_btb + page_idx * stride_btp).to(tl.int32)
            page_base = kv_cache_ptr + page_id * stride_kb + hkv * stride_kh
            token_off = page_base + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)

            # Head 0
            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm
            # Head 1
            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm
            # Head 2
            sc = tl.sum(q2 * k) * scale
            nm = tl.maximum(m2, sc); rc = tl.exp(m2 - nm)
            a2 = a2 * rc; se2 = se2 * rc
            e = tl.exp(sc - nm); se2 = se2 + e; a2 = a2 + e * v; m2 = nm
            # Head 3
            sc = tl.sum(q3 * k) * scale
            nm = tl.maximum(m3, sc); rc = tl.exp(m3 - nm)
            a3 = a3 * rc; se3 = se3 * rc
            e = tl.exp(sc - nm); se3 = se3 + e; a3 = a3 + e * v; m3 = nm
            # Head 4
            sc = tl.sum(q4 * k) * scale
            nm = tl.maximum(m4, sc); rc = tl.exp(m4 - nm)
            a4 = a4 * rc; se4 = se4 * rc
            e = tl.exp(sc - nm); se4 = se4 + e; a4 = a4 + e * v; m4 = nm
            # Head 5
            sc = tl.sum(q5 * k) * scale
            nm = tl.maximum(m5, sc); rc = tl.exp(m5 - nm)
            a5 = a5 * rc; se5 = se5 * rc
            e = tl.exp(sc - nm); se5 = se5 + e; a5 = a5 + e * v; m5 = nm
            # Head 6
            sc = tl.sum(q6 * k) * scale
            nm = tl.maximum(m6, sc); rc = tl.exp(m6 - nm)
            a6 = a6 * rc; se6 = se6 * rc
            e = tl.exp(sc - nm); se6 = se6 + e; a6 = a6 + e * v; m6 = nm
            # Head 7
            sc = tl.sum(q7 * k) * scale
            nm = tl.maximum(m7, sc); rc = tl.exp(m7 - nm)
            a7 = a7 * rc; se7 = se7 * rc
            e = tl.exp(sc - nm); se7 = se7 + e; a7 = a7 + e * v; m7 = nm

    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+0) * stride_poh + s * stride_pos + offs_d * stride_pod, a0, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+0) * stride_mih + s * stride_mis, m0)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+0) * stride_lih + s * stride_lis, se0)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+1) * stride_poh + s * stride_pos + offs_d * stride_pod, a1, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+1) * stride_mih + s * stride_mis, m1)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+1) * stride_lih + s * stride_lis, se1)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+2) * stride_poh + s * stride_pos + offs_d * stride_pod, a2, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+2) * stride_mih + s * stride_mis, m2)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+2) * stride_lih + s * stride_lis, se2)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+3) * stride_poh + s * stride_pos + offs_d * stride_pod, a3, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+3) * stride_mih + s * stride_mis, m3)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+3) * stride_lih + s * stride_lis, se3)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+4) * stride_poh + s * stride_pos + offs_d * stride_pod, a4, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+4) * stride_mih + s * stride_mis, m4)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+4) * stride_lih + s * stride_lis, se4)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+5) * stride_poh + s * stride_pos + offs_d * stride_pod, a5, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+5) * stride_mih + s * stride_mis, m5)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+5) * stride_lih + s * stride_lis, se5)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+6) * stride_poh + s * stride_pos + offs_d * stride_pod, a6, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+6) * stride_mih + s * stride_mis, m6)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+6) * stride_lih + s * stride_lis, se6)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*8+7) * stride_poh + s * stride_pos + offs_d * stride_pod, a7, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*8+7) * stride_mih + s * stride_mis, m7)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*8+7) * stride_lih + s * stride_lis, se7)


@triton.jit
def _partial_gs2(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS

    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)

    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (hkv * 2) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (hkv * 2 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            page_id = tl.load(block_table_ptr + b * stride_btb + page_idx * stride_btp).to(tl.int32)
            page_base = kv_cache_ptr + page_id * stride_kb + hkv * stride_kh
            token_off = page_base + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)

            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm

            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm

    tl.store(partial_out_ptr + b * stride_pob + (hkv*2+0) * stride_poh + s * stride_pos + offs_d * stride_pod, a0, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*2+0) * stride_mih + s * stride_mis, m0)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*2+0) * stride_lih + s * stride_lis, se0)
    tl.store(partial_out_ptr + b * stride_pob + (hkv*2+1) * stride_poh + s * stride_pos + offs_d * stride_pod, a1, mask=mask_d)
    tl.store(partial_mi_ptr + b * stride_mib + (hkv*2+1) * stride_mih + s * stride_mis, m1)
    tl.store(partial_li_ptr + b * stride_lib + (hkv*2+1) * stride_lih + s * stride_lis, se1)


@triton.jit
def _reduce_kernel(
    partial_out_ptr, partial_mi_ptr, partial_li_ptr, output_ptr,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    stride_ob, stride_oh, stride_od,
    NUM_HEADS: tl.constexpr, NUM_SPLITS: tl.constexpr,
    HEAD_DIM: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)
    b = pid // NUM_HEADS
    h = pid % NUM_HEADS

    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    global_max = tl.full([], float('-inf'), dtype=tl.float32)
    for s in range(NUM_SPLITS):
        mi = tl.load(partial_mi_ptr + b * stride_mib + h * stride_mih + s * stride_mis)
        global_max = tl.maximum(global_max, mi)

    total_sum = tl.zeros([], dtype=tl.float32)
    total_acc = tl.zeros((BLOCK_D,), dtype=tl.float32)

    for s in range(NUM_SPLITS):
        mi = tl.load(partial_mi_ptr + b * stride_mib + h * stride_mih + s * stride_mis)
        li = tl.load(partial_li_ptr + b * stride_lib + h * stride_lih + s * stride_lis)
        acc = tl.load(partial_out_ptr + b * stride_pob + h * stride_poh + s * stride_pos + offs_d * stride_pod,
                      mask=mask_d, other=0.0)
        rc = tl.exp(mi - global_max)
        total_sum = total_sum + li * rc
        total_acc = total_acc + acc * rc

    result = (total_acc * (1.0 / total_sum)).to(tl.bfloat16)
    tl.store(output_ptr + b * stride_ob + h * stride_oh + offs_d * stride_od, result, mask=mask_d)


def paged_attention_decode_triton(
    query, kv_cache, block_table, seq_lens,
    num_heads, num_kv_heads, head_dim, page_size,
):
    B, H, D = query.shape
    G = num_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)
    max_pages = block_table.shape[1]
    max_seq_len = max_pages * page_size

    BLOCK_D = triton.next_power_of_2(head_dim)
    BLOCK_SEQ = 128
    num_splits = triton.cdiv(max_seq_len, BLOCK_SEQ)

    output = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
    partial_out = torch.empty(B, H, num_splits, D, dtype=torch.float32, device=query.device)
    partial_mi = torch.full((B, H, num_splits), float('-inf'), dtype=torch.float32, device=query.device)
    partial_li = torch.zeros(B, H, num_splits, dtype=torch.float32, device=query.device)

    pk = {2: _partial_gs2, 4: _partial_gs4, 8: _partial_gs8}[G]

    grid = (B * num_kv_heads * num_splits,)
    pk[grid](
        query, kv_cache, block_table, seq_lens,
        partial_out, partial_mi, partial_li,
        query.stride(0), query.stride(1), query.stride(2),
        kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
        block_table.stride(0), block_table.stride(1),
        partial_out.stride(0), partial_out.stride(1), partial_out.stride(2), partial_out.stride(3),
        partial_mi.stride(0), partial_mi.stride(1), partial_mi.stride(2),
        partial_li.stride(0), partial_li.stride(1), partial_li.stride(2),
        HEAD_DIM=head_dim, PAGE_SIZE=page_size,
        NUM_HEADS=num_heads, NUM_KV_HEADS=num_kv_heads,
        scale=scale, MAX_NUM_PAGES=max_pages,
        BLOCK_D=BLOCK_D, BLOCK_SEQ=BLOCK_SEQ,
        num_warps=1, num_stages=5,
    )

    grid_reduce = (B * H,)
    _reduce_kernel[grid_reduce](
        partial_out, partial_mi, partial_li, output,
        partial_out.stride(0), partial_out.stride(1), partial_out.stride(2), partial_out.stride(3),
        partial_mi.stride(0), partial_mi.stride(1), partial_mi.stride(2),
        partial_li.stride(0), partial_li.stride(1), partial_li.stride(2),
        output.stride(0), output.stride(1), output.stride(2),
        NUM_HEADS=H, NUM_SPLITS=num_splits,
        HEAD_DIM=head_dim, BLOCK_D=BLOCK_D,
        num_warps=4, num_stages=2,
    )

    return output


# ---------------------------------------------------------------------------
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16


class Model(nn.Module):
    def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len, page_size):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)

    def forward(self, query, kv_cache, block_table, seq_lens):
        return paged_attention_decode_triton(
            query, kv_cache, block_table, seq_lens,
            num_heads=self.num_heads, num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim, page_size=self.page_size,
        )


def get_inputs():
    B, H, Hkv, D, L, P = BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE
    pages_per_seq = (L + P - 1) // P
    total_pages = max(B * pages_per_seq + 8, 64)
    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
    perm = torch.randperm(total_pages)[:B * pages_per_seq].reshape(B, pages_per_seq).int()
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/tmp/KernelBench-Hard-zai-droid/problems/03_paged_attention/reference.py", line 89, in forward
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     L = int(seq_lens[b].item())
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0508 17:23:29.298000 4063667 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
shape=0 variant=eager tflops=0.193 gbps=48.402 ms=0.696
shape=0 variant=compiled tflops=0.204 gbps=51.159 ms=0.658
shape=0 variant=solution tflops=1.164 gbps=292.084 ms=0.115
shape=0 solution_peak_fraction=0.1623
shape=1 variant=eager tflops=0.310 gbps=77.741 ms=3.460
shape=1 variant=compiled tflops=0.314 gbps=78.572 ms=3.423
shape=1 variant=solution tflops=4.987 gbps=1249.070 ms=0.215
shape=1 solution_peak_fraction=0.6939
shape=2 variant=eager tflops=0.243 gbps=30.438 ms=2.209
shape=2 variant=compiled tflops=0.243 gbps=30.395 ms=2.212
shape=2 variant=solution tflops=4.107 gbps=514.318 ms=0.131
shape=2 solution_peak_fraction=0.2857
shape=3 variant=eager tflops=0.278 gbps=69.783 ms=1.445
shape=3 variant=compiled tflops=0.274 gbps=68.584 ms=1.471
shape=3 variant=solution tflops=2.957 gbps=741.268 ms=0.136
shape=3 solution_peak_fraction=0.4118
shape=4 variant=eager tflops=0.103 gbps=25.787 ms=0.637
shape=4 variant=compiled tflops=0.099 gbps=24.738 ms=0.664
shape=4 variant=solution tflops=0.554 gbps=138.899 ms=0.118
shape=4 solution_peak_fraction=0.0772
peak_fraction: 0.2523
RESULT: OK
Uninstalled 1 package in 0.27ms
Installed 1 package in 1ms
PASS
{
    "problem": "03_paged_attention",
    "harness": "droid",
    "model": "custom:GLM-5.1-[Z.AI-Coding-Plan]-0",
    "reasoning_effort": "",
    "has_solution": true,
    "correct": true,
    "peak_fraction": 0.2523,
    "elapsed_seconds": 2700,
    "harness_exit_code": 124,
    "session_complete": false,
    "usage": {"input_tokens": 90701, "output_tokens": 109517, "cache_read_tokens": 3109568, "cache_creation_tokens": 0, "reasoning_tokens": 0, "total_cost_usd": null}
}
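
The log does not state how `solution_peak_fraction` or the aggregate `peak_fraction` are derived. The reported numbers are consistent with each per-shape fraction being the solution's `gbps` divided by the 1.8 TB/s bandwidth figure from the task prompt, and with the aggregate being the geometric mean of those fractions. The sketch below checks that reading against the values in benchmark.log; the formula is an inference from the numbers, not something the harness confirms.

```python
import math

# Assumption: peak bandwidth of 1.8 TB/s (1800 GB/s), the figure quoted in the task prompt.
PEAK_GBPS = 1800.0

# (solution gbps, reported solution_peak_fraction) per shape, copied from benchmark.log.
rows = [
    (292.084, 0.1623),
    (1249.070, 0.6939),
    (514.318, 0.2857),
    (741.268, 0.4118),
    (138.899, 0.0772),
]

# Inferred per-shape fraction: achieved bandwidth over peak bandwidth.
fractions = [gbps / PEAK_GBPS for gbps, _ in rows]

# Inferred aggregate: geometric mean of the per-shape fractions.
geomean = math.exp(sum(math.log(f) for f in fractions) / len(fractions))

for (gbps, reported), frac in zip(rows, fractions):
    print(f"gbps={gbps:9.3f}  inferred={frac:.4f}  reported={reported:.4f}")
print(f"geometric mean = {geomean:.4f}")
```

Under this reading, the two small-head or small-batch shapes (0 and 4) pull the aggregate down the most, since a geometric mean is sensitive to its lowest terms.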

timeline (30 events)

system — init
session start cwd=/tmp/KernelBench-Hard-zai-droid/problems/03_paged_attention model=custom:GLM-5.1-[Z.AI-Coding-Plan]-0 effort=none tools=31
user
I need you to write a paged attention decode kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py. This is single-query decode: each batch element has a query of shape (num_heads, head_dim) and attends over a KV cache of seq_len[b] tokens stored as fixed-size pages in a global pool, with block_table[b] listing which pages belong to batch element b. The KV cache is laid out as (num_blocks, page_size, num_kv_heads, head_dim * 2) where the last dim packs [K | V] so a single gather pulls both. You return the bf16 output of shape (batch, num_heads, head_dim). Correctness tolerance is 0.02 abs/rel on bf16/fp16. The shapes you have to handle are batch=8 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1024 page_size=16 (Llama-3 8B-style decode), batch=32 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=2048 page_size=16 (server-batched decode), batch=4 num_heads=64 num_kv_heads=8 head_dim=128 seq_len=4096 page_size=16 (Llama-3 70B GQA wide ratio with long context), batch=16 num_heads=32 num_kv_heads=8 head_dim=128 seq_len=1535 page_size=16 (non-power-of-two seq_len that forces predicated tails), and batch=8 num_heads=16 num_kv_heads=4 head_dim=64 seq_len=2000 page_size=16 (small head_dim, non-pow2). This needs to be a real custom kernel. Don't import or call vllm.attention, flashinfer.batch_decode_with_paged_kv_cache, flashinfer.decode, torch.nn.functional.scaled_dot_product_attention, or F.scaled_dot_product_attention — they're off-limits, including the trick of gathering pages first and then dispatching to SDPA. Try CUDA C++ via load_inline, Triton, inline PTX, or CUTLASS — whatever fits. 
Anything you're uncertain about, look up PTX docs, clone CUTLASS or vLLM or FlashInfer or other reference repos, read library headers, and investigate. Your flywheel is implement, profile (ncu, nsys, torch.profiler — whatever's useful) and time it with benchmark.py, verify correctness by running `python check.py` and reading the output, then iterate. Don't substitute your own one-off correctness snippets for check.py — it iterates over every shape, your spot-check almost certainly won't. If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up.
assistant
Now let me understand the full problem by reading all the relevant files.
assistant
Let me check the environment first:
assistant
Now I understand the full problem. Let me design the kernel. This is a **memory-bound** decode attention kernel on SM120 (Blackwell). Key observations:
1. **Decode attention** = single query token attending to full KV cache → memory-bound (streaming K,V)
2. **GQA**: num_heads=32, num_kv_heads=8, so group_size=4 (each KV head shared by 4 query heads)
3. **Paged KV cache**: (num_blocks, page_size, num_kv_heads, head_dim*2) with interleaved [K|V]
4. **Peak bandwidth**: 1.8 TB/s → need to maximize DRAM throughput

**Approach: Triton kernel** that:
- Each CTA handles one (batch_element, kv_head) pair
- Iterates over pages, loading K and V in vectorized loads
- Computes QK^T dot products in fp32 accumulation
- Online softmax (max/sum reduction)
- Weighted V accumulation
- GQA: group_size query heads share the same KV head result

Let me implement this:
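
For readers following the plan above, the page lookup and the online-softmax recurrence can be sketched in plain Python for a single query head. This is an illustrative reference, not the Triton kernel written in this session: `paged_decode_one_head` and its list-based arguments are hypothetical, and the only layout fact taken from the task is that each cache row packs `[K | V]` along the last dimension.

```python
import math


def paged_decode_one_head(q, kv_pages, block_row, seq_len, page_size):
    """Single-query attention for one head over a paged KV cache.

    q         : list of D floats (the query vector)
    kv_pages  : kv_pages[page][slot] is a list of 2*D floats packed as [K | V]
    block_row : maps logical page index -> physical page id for this sequence
    """
    d = len(q)
    scale = 1.0 / math.sqrt(d)
    m = float("-inf")  # running max of the scaled scores
    l = 0.0            # running sum of exp(score - m)
    acc = [0.0] * d    # running sum of exp(score - m) * V

    for tok in range(seq_len):
        page = block_row[tok // page_size]      # logical -> physical page
        row = kv_pages[page][tok % page_size]   # slot within the page
        k, v = row[:d], row[d:]
        score = scale * sum(qi * ki for qi, ki in zip(q, k))

        new_m = max(m, score)
        rescale = math.exp(m - new_m)  # 0.0 on the first token, since m is -inf
        p = math.exp(score - new_m)
        l = l * rescale + p
        acc = [a * rescale + p * vi for a, vi in zip(acc, v)]
        m = new_m

    return [a / l for a in acc]
```

The running triple `(m, l, acc)` is the same state the kernels in solution.py keep per head as `m*`, `se*`, and `a*`.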
assistant
Now let me test correctness:
assistant
Right, Triton doesn't support indexed assignment. I need to restructure the kernel to avoid index assignment on tensors. Let me redesign with a different approach - unroll over group_size at the kernel level or use a different accumulation strategy.
assistant
Triton doesn't support `break` in loops. I need to use conditional masking instead. Let me rewrite the inner loop:
assistant
Correctness passes. Now let me run the benchmark:
assistant
The performance is quite low. The problem is that my kernel is serializing too much - each CTA processes one query head and iterates through all tokens sequentially with no parallelism across tokens. The key issues:
1. **Too many small CTAs**: B*num_heads CTAs, each doing very little compute but a lot of serial memory access
2. **No vectorized loads**: Loading one K/V pair at a time
3. **No GQA sharing**: Loading the same K/V for multiple query heads redundantly
4. **Inner loop over page_size=16**: Not parallelized at all

Let me redesign with a much better approach - process multiple tokens per CTA using block parallelism, and share KV loads across GQA heads:
assistant
Good, correctness passes. Let me benchmark this version:
assistant
Still slow. The problem is likely too few CTAs (B*num_kv_heads is small, like 8*8=64), and each CTA does a lot of serial work. The per-query-head version (B*num_heads = 8*32=256 CTAs) was actually better for some shapes. The real issue is that:
1. The inner loop over page_size=16 tokens is fully serial within each CTA
2. Not enough parallelism to saturate the GPU
3. Each load is only head_dim=128 elements of bf16 = 256 bytes, which is small

Let me try a completely different strategy: use a CUDA C++ kernel via `load_inline` that processes multiple tokens in parallel per warp, with better memory access patterns.
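
The CTA counts quoted in this message can be reproduced with a little arithmetic. The sketch below is illustrative only: `cta_counts` is a hypothetical helper, it assumes the block table has ceil(seq_len / page_size) columns (as `get_inputs` in solution.py builds it), and the split-KV row uses `BLOCK_SEQ = 128` and the `B * num_kv_heads * num_splits` grid from the final solution.py rather than anything tried at this point in the session.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def cta_counts(batch, num_heads, num_kv_heads, seq_len, page_size=16, block_seq=128):
    """Number of CTAs launched under three grid layouts."""
    max_pages = cdiv(seq_len, page_size)
    num_splits = cdiv(max_pages * page_size, block_seq)
    return {
        "per_kv_head": batch * num_kv_heads,                # one CTA per (batch, kv_head)
        "per_query_head": batch * num_heads,                # one CTA per (batch, query head)
        "split_kv": batch * num_kv_heads * num_splits,      # grid used in the final solution.py
    }


# The five benchmark shapes from the task prompt.
shapes = [
    (8, 32, 8, 1024),
    (32, 32, 8, 2048),
    (4, 64, 8, 4096),
    (16, 32, 8, 1535),
    (8, 16, 4, 2000),
]
for batch, heads, kv_heads, seq_len in shapes:
    print((batch, heads, kv_heads, seq_len), cta_counts(batch, heads, kv_heads, seq_len))
```

For the first shape this gives 64, 256, and 512 CTAs respectively, which matches the 8*8=64 and 8*32=256 figures in the message above.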
assistant
The CUDA C++ approach has API issues with this version of PyTorch. Let me go back to the Triton approach, but optimize it more aggressively. The key insight: the problem is memory-bound, and I need to:
1. Use more CTAs (go back to per-query-head)
2. Ensure contiguous memory access patterns
3. Reduce kernel launch overhead

Let me also try a fundamentally different approach: do the gather first (flatten pages) and then do the attention with better parallelism.
assistant
OK, the CUDA C++ approach has API compatibility issues. Let me go back to the proven Triton approach and fix the performance. The first Triton version got 0.025 peak fraction, which is very low. Let me analyze the issue and optimize. Key issues with the v1 Triton kernel:
1. Loading K and V separately (2 loads per token) - could load K|V together
2. The mask on loads with `is_valid` creates unnecessary overhead
3. Per-query-head CTAs means redundant K/V loads across GQA heads

Let me go back to the v2 GQA-sharing approach (the one that passed check.py), and focus on optimizing it. The issue there was that performance was even worse. Let me try a hybrid: use the per-head approach but with `num_stages` and `num_warps` tuning, and ensure the loads are coalesced.
assistant
Wait, let me think more carefully about what's happening. The Triton kernel already passed but was slow. The issue is likely that:
1. The KV cache has stride_kd > 1 (it's not contiguous in the last dim since it packs [K|V]), so loads aren't coalesced
2. Each CTA only has 1 warp of useful work, doing serial token iteration

Let me check the actual kv_cache strides:
assistant
Good - `stride_kd = 1`, so the last dimension is contiguous in memory: K occupies elements [0:D] and V occupies [D:2D], laid out back to back. This means we can potentially load K and V together in one wide load of 2*head_dim elements.

The key performance issue is that each CTA (one per query head) does serial iteration over all tokens, and the total number of CTAs is only B*H, which can be as low as 4*64=256 (shape 2) or 8*32=256 (shape 0). For the RTX PRO 6000 with lots of SMs, we need more parallelism.

Let me try a different approach: parallelize across tokens using multiple CTAs per (batch, head) pair, compute partial QK scores, then reduce. This is similar to FlashDecoding.
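The split-then-reduce idea can be sketched in pure Python. This is a scalar reference under simplifying assumptions (one query, one output dimension, precomputed QK scores), not the Triton kernel from this session; `partial_chunk` and `flash_decode` are hypothetical names. Each chunk produces a local max, a local sum of exponentials, and an unnormalized accumulator, and the reduce step rescales them to a shared max.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: softmax(scores) dotted with values, in a single pass."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def partial_chunk(scores, values):
    """Phase 1: per-chunk max, sum of exponentials, unnormalized output."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return m, sum(weights), sum(w * v for w, v in zip(weights, values))


def flash_decode(scores, values, block_seq):
    """Phase 2: combine chunk partials after rescaling to the global max."""
    partials = [
        partial_chunk(scores[i:i + block_seq], values[i:i + block_seq])
        for i in range(0, len(scores), block_seq)
    ]
    m_global = max(m for m, _, _ in partials)
    l_total = 0.0
    acc_total = 0.0
    for m, l, acc in partials:
        rescale = math.exp(m - m_global)  # corrects for the chunk-local max
        l_total += l * rescale
        acc_total += acc * rescale
    return acc_total / l_total


scores = [0.5, -1.0, 2.0, 0.1, 3.5, -0.7, 1.2]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
reference = softmax_weighted_sum(scores, values)
split_result = flash_decode(scores, values, block_seq=3)
```

Because each chunk is independent in phase 1, the chunks can be assigned to separate CTAs; only the small per-chunk triples need to be written to intermediate buffers for the reduce.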
assistant
Big improvement! We went from 0.025 to 0.1074 peak fraction, and it's now "OK" (>0.1). But I can do much better. The FlashDecoding approach works well - let me tune the parameters:
1. Shape 2 (batch=4, H=64, L=4096) is weak at 0.0896
2. Shape 4 (batch=8, H=16, L=2000, D=64) is the weakest at 0.0403

Let me try tuning BLOCK_SEQ and num_warps more carefully. Also, I should try to avoid the overhead of the intermediate buffers.
assistant
The RTX PRO 6000 (Blackwell) has 24,064 CUDA cores across 188 SMs. We want enough CTAs to keep all SMs busy. Around 500-2000 CTAs seems reasonable. The current BLOCK_SEQ=128 for all shapes looks like a reasonable choice.

Let me try tuning more carefully. The key issue is that the reduce phase adds overhead, especially the intermediate memory allocation. Let me try a different approach: tune the BLOCK_SEQ per shape, and also try to make the reduce kernel faster by minimizing the loop over splits.
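As a rough check on the CTA counts, here is a small sketch of the grid-size arithmetic. It assumes one CTA per (batch, query head, sequence split), which matches the per-head FlashDecoding variant described at this point; `num_ctas` is a hypothetical helper, not a function from the solution.

```python
import math


def num_ctas(batch, heads, seq_len, block_seq):
    """CTAs launched when each (batch, head) pair is split along the sequence."""
    splits = math.ceil(seq_len / block_seq)
    return batch * heads * splits


# Shape 0 from the transcript: B=8, H=32, L=1024.
shape0_ctas = num_ctas(batch=8, heads=32, seq_len=1024, block_seq=128)
# Shape 2 from the transcript: B=4, H=64, L=4096.
shape2_ctas = num_ctas(batch=4, heads=64, seq_len=4096, block_seq=128)
# Without any split, both shapes launch only B*H CTAs.
unsplit_ctas = num_ctas(batch=8, heads=32, seq_len=1024, block_seq=1024)
```

Smaller `block_seq` values raise the CTA count but also increase the number of partial results the reduce kernel has to combine, so the two effects pull in opposite directions.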
assistant
The auto-tuning of BLOCK_SEQ made things worse. The original fixed BLOCK_SEQ=128 was better. Let me revert and try different fixed values:
assistant
Now let me think about what's limiting performance. The key insight for memory-bound kernels: we need to maximize the bytes transferred per unit time. The RTX PRO 6000 has 1.8 TB/s bandwidth. For shape 0 (B=8, H=32, L=1024, D=128, Hkv=8):
- KV bytes = 2 * 8 * 1024 * 8 * 128 * 2 = 33.55 MB
- Peak time = 33.55 MB / 1800 GB/s = 0.019 ms
- Our time = 0.140 ms → ~7x slower than peak

The problem is that with FlashDecoding, we're loading K/V multiple times (once per query head in the GQA group). With group_size=4, each K/V element is loaded 4 times! For a memory-bound kernel, this is the bottleneck.

Let me go back to the shared-KV approach (v2 kernel from before that shared KV across GQA heads), but apply the FlashDecoding split to it. That way each KV element is loaded exactly once.
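The roofline arithmetic above can be reproduced with a short script. This is a back-of-the-envelope sketch: it counts only KV-cache traffic, takes bf16 as 2 bytes per element and 1800 GB/s as the peak, and the helper names are hypothetical.

```python
def kv_bytes(batch, seq_len, num_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes of K plus V that must be read once for a decode step."""
    return 2 * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem


def min_time_ms(num_bytes, bandwidth_gb_s):
    """Lower bound on runtime if the kernel ran at full DRAM bandwidth."""
    return num_bytes / (bandwidth_gb_s * 1e9) * 1e3


# Shape 0: B=8, H=32, L=1024, D=128, Hkv=8.
shape0_bytes = kv_bytes(batch=8, seq_len=1024, num_kv_heads=8, head_dim=128)
shape0_floor_ms = min_time_ms(shape0_bytes, bandwidth_gb_s=1800)
slowdown = 0.140 / shape0_floor_ms   # measured 0.140 ms vs. the floor

# Without GQA sharing, each K/V element is read once per query head in a group.
group_size = 32 // 8
redundant_bytes = group_size * shape0_bytes
```

With `group_size = 4`, the unshared variant moves four times the minimum KV traffic, which is why sharing loads across the group is the natural next step for a bandwidth-bound kernel.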
assistant
Triton doesn't support inner function definitions. I need to inline the update logic. Also, the ternary pattern `x if group_size > 1 else y` with `tl.zeros` creates issues when the conditionals reference variables defined in the same scope. Let me take a simpler approach - just handle group_size=4 and group_size=8 explicitly as separate kernels, which are the only cases we need:
assistant
The closure approach doesn't work with Triton because GS is a Python variable that can't be captured. I need to use actual string-based code generation or just write the kernels explicitly. Let me take a simpler approach - just write separate kernels for each group_size inline:
assistant
Triton doesn't support `list.append()` either. The gs4 kernel works because it doesn't use lists. Let me rewrite gs8 without lists:
assistant
Significant improvement! 0.1381 peak fraction, up from 0.1074. The GQA sharing is helping - especially shape 1 (0.2707!) and shape 3 (0.2077). But shape 2 (0.1170) and shape 4 (0.0620) are still weak.

Let me try tuning `num_warps` and `BLOCK_SEQ` for different shapes. The issue might be:
- For shape 4 (head_dim=64, group_size=4), each CTA does very little work - only 64 elements per vector
- For shape 2 (seq_len=4096), the reduce kernel has to iterate over 32 splits which adds overhead

Let me try with `num_warps=8` for the gs8 kernel and also try to optimize the reduce phase by pre-computing the max in one pass:
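To illustrate why GQA sharing reduces memory traffic, here is a small pure-Python sketch that counts K-vector loads for a single token. It assumes the common convention that query head `h` maps to KV head `h // group_size`; that mapping and all function names are assumptions for illustration, not taken from the kernel.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def scores_per_query_head(queries, keys, group_size):
    """One K load per query head: the same K is fetched group_size times."""
    loads = 0
    scores = []
    for h, q in enumerate(queries):
        k = keys[h // group_size]   # assumed head-to-KV-head mapping
        loads += 1
        scores.append(dot(q, k))
    return scores, loads


def scores_shared(queries, keys, group_size):
    """One K load per KV head, reused by every query head in its group."""
    loads = 0
    scores = [0.0] * len(queries)
    for hkv, k in enumerate(keys):
        loads += 1                  # K fetched once for the whole group
        for g in range(group_size):
            h = hkv * group_size + g
            scores[h] = dot(queries[h], k)
    return scores, loads


queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # 4 query heads
keys = [[0.5, 2.0], [3.0, 1.0]]                              # 2 KV heads
per_head_scores, per_head_loads = scores_per_query_head(queries, keys, 2)
shared_scores, shared_loads = scores_shared(queries, keys, 2)
```

Both functions produce the same scores; only the number of K fetches differs, by a factor of `group_size`.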
assistant
Slightly better overall (0.1405 vs 0.1381). Shape 0 improved to 0.1642 with BLOCK_SEQ=64. Shape 2 got worse (0.0946 vs 0.1170). Let me try different BLOCK_SEQ values per shape more carefully:
assistant
This is slightly worse. Let me go back to the per-seq-length tuning which was better:
assistant
Now let me try another approach to boost performance. The key bottleneck now is that the Triton kernels aren't achieving high enough DRAM bandwidth. Let me try to use `num_warps=2` for simpler kernels and `num_stages=5` for better pipelining:
assistant
Let me try tuning with `num_stages=5` for better memory pipelining:
assistant
Essentially the same. Let me try `num_warps=2` for the gs4 kernel (less register pressure, more occupancy):
assistant
Huge improvement! `num_warps=2` made a big difference:
- Shape 1: 0.2708 → 0.4280 (770 GB/s!)
- Shape 2: 0.1171 → 0.1912 (344 GB/s!)
- Shape 3: 0.2078 → 0.3283 (591 GB/s!)
- Peak fraction: 0.1384 → **0.1870**

But shape 4 is still low at 0.0680 and shape 0 is at 0.1252. Let me try `num_warps=1` to see if even more occupancy helps:
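The GB/s figures quoted above are consistent with reading peak fraction as achieved bandwidth divided by a 1800 GB/s peak. A quick sketch of that conversion follows; the interpretation is inferred from the numbers in the transcript, and `achieved_gb_s` is a hypothetical helper.

```python
PEAK_GB_S = 1800.0  # peak DRAM bandwidth figure used earlier in the transcript


def achieved_gb_s(peak_fraction, peak_gb_s=PEAK_GB_S):
    """Convert a peak-fraction score into an effective bandwidth in GB/s."""
    return peak_fraction * peak_gb_s


shape1_gb_s = achieved_gb_s(0.4280)
shape2_gb_s = achieved_gb_s(0.1912)
shape3_gb_s = achieved_gb_s(0.3283)
```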