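"""Paged-attention decode kernels for SM120: FlashDecoding + GQA head sharing in Triton.

Optimizations:
- GQA sharing: each K/V vector is loaded once per (batch, kv_head, split) program and
  reused by all group_size query heads that map to that KV head.
- FlashDecoding: the sequence is split into fixed-size chunks (BLOCK_SEQ tokens) that are
  processed in parallel.
- Two phases: each chunk produces a partial online-softmax state (running max, exp-sum,
  exp-weighted V sum); a second kernel merges the chunks by rescaling them to a common
  max (a log-sum-exp merge).
- Specialized kernels for group_size 2, 4 and 8; other group sizes are not supported.
"""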
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _partial_gs4(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    """Phase 1 for group_size 4: partial softmax state of one (batch, kv_head, split).

    The program id decodes as pid = (b * NUM_KV_HEADS + hkv) * num_splits + s.
    The program walks tokens [s * BLOCK_SEQ, min((s + 1) * BLOCK_SEQ, seq_len))
    of batch b, loads each token's K/V for KV head hkv once, and updates an
    online softmax for the 4 query heads hkv*4 .. hkv*4+3. For each head h it
    writes (score_t = scale * dot(q_h, k_t), m = max_t score_t):
        partial_out[b, h, s, :] = sum_t exp(score_t - m) * v_t
        partial_mi[b, h, s]     = m      (-inf when the split holds no tokens)
        partial_li[b, h, s]     = sum_t exp(score_t - m)
    K is read from channels [0, HEAD_DIM) and V from [HEAD_DIM, 2 * HEAD_DIM)
    of the last kv_cache axis. NUM_HEADS is accepted for signature parity with
    the host launcher but is not used here.
    """
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS
    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    # Query heads served by KV head hkv are h0 .. h0 + 3.
    h0 = hkv * 4
    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (h0 + 0) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (h0 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q2 = tl.load(q_base + (h0 + 2) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q3 = tl.load(q_base + (h0 + 3) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    # Online-softmax state per head: running max m*, exp-sum se*, weighted-V acc a*.
    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    m2 = tl.full([], float('-inf'), dtype=tl.float32)
    m3 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    se2 = tl.zeros([], dtype=tl.float32)
    se3 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a2 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a3 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    # Loop-invariant bases: this batch's block-table row and this KV head's slice.
    bt_row = block_table_ptr + b * stride_btb
    kv_head_base = kv_cache_ptr + hkv * stride_kh
    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            # Widen to int64 before scaling by the page stride: in 32-bit,
            # page_id * stride_kb overflows for KV caches larger than 2**31
            # elements and the load would hit the wrong address.
            page_id = tl.load(bt_row + page_idx * stride_btp).to(tl.int64)
            token_off = kv_head_base + page_id * stride_kb + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            # Each head: rescale old state by exp(m_old - m_new), then add this token.
            # Head 0
            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm
            # Head 1
            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm
            # Head 2
            sc = tl.sum(q2 * k) * scale
            nm = tl.maximum(m2, sc); rc = tl.exp(m2 - nm)
            a2 = a2 * rc; se2 = se2 * rc
            e = tl.exp(sc - nm); se2 = se2 + e; a2 = a2 + e * v; m2 = nm
            # Head 3
            sc = tl.sum(q3 * k) * scale
            nm = tl.maximum(m3, sc); rc = tl.exp(m3 - nm)
            a3 = a3 * rc; se3 = se3 * rc
            e = tl.exp(sc - nm); se3 = se3 + e; a3 = a3 + e * v; m3 = nm

    # Empty splits still store (acc=0, m=-inf, se=0); the reduce step weights
    # them by exp(-inf - max) == 0.
    po_base = partial_out_ptr + b * stride_pob + s * stride_pos + offs_d * stride_pod
    mi_base = partial_mi_ptr + b * stride_mib + s * stride_mis
    li_base = partial_li_ptr + b * stride_lib + s * stride_lis
    tl.store(po_base + (h0 + 0) * stride_poh, a0, mask=mask_d)
    tl.store(mi_base + (h0 + 0) * stride_mih, m0)
    tl.store(li_base + (h0 + 0) * stride_lih, se0)
    tl.store(po_base + (h0 + 1) * stride_poh, a1, mask=mask_d)
    tl.store(mi_base + (h0 + 1) * stride_mih, m1)
    tl.store(li_base + (h0 + 1) * stride_lih, se1)
    tl.store(po_base + (h0 + 2) * stride_poh, a2, mask=mask_d)
    tl.store(mi_base + (h0 + 2) * stride_mih, m2)
    tl.store(li_base + (h0 + 2) * stride_lih, se2)
    tl.store(po_base + (h0 + 3) * stride_poh, a3, mask=mask_d)
    tl.store(mi_base + (h0 + 3) * stride_mih, m3)
    tl.store(li_base + (h0 + 3) * stride_lih, se3)
# gs8 kernel for Llama-3 70B style (64/8 = 8 group_size)
@triton.jit
def _partial_gs8(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    """Phase 1 for group_size 8: partial softmax state of one (batch, kv_head, split).

    The program id decodes as pid = (b * NUM_KV_HEADS + hkv) * num_splits + s.
    The program walks tokens [s * BLOCK_SEQ, min((s + 1) * BLOCK_SEQ, seq_len))
    of batch b, loads each token's K/V for KV head hkv once, and updates an
    online softmax for the 8 query heads hkv*8 .. hkv*8+7. For each head h it
    writes (score_t = scale * dot(q_h, k_t), m = max_t score_t):
        partial_out[b, h, s, :] = sum_t exp(score_t - m) * v_t
        partial_mi[b, h, s]     = m      (-inf when the split holds no tokens)
        partial_li[b, h, s]     = sum_t exp(score_t - m)
    K is read from channels [0, HEAD_DIM) and V from [HEAD_DIM, 2 * HEAD_DIM)
    of the last kv_cache axis. NUM_HEADS is accepted for signature parity with
    the host launcher but is not used here.
    """
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS
    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    # Query heads served by KV head hkv are h0 .. h0 + 7.
    h0 = hkv * 8
    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (h0 + 0) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (h0 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q2 = tl.load(q_base + (h0 + 2) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q3 = tl.load(q_base + (h0 + 3) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q4 = tl.load(q_base + (h0 + 4) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q5 = tl.load(q_base + (h0 + 5) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q6 = tl.load(q_base + (h0 + 6) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q7 = tl.load(q_base + (h0 + 7) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    # Online-softmax state per head: running max m*, exp-sum se*, weighted-V acc a*.
    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    m2 = tl.full([], float('-inf'), dtype=tl.float32)
    m3 = tl.full([], float('-inf'), dtype=tl.float32)
    m4 = tl.full([], float('-inf'), dtype=tl.float32)
    m5 = tl.full([], float('-inf'), dtype=tl.float32)
    m6 = tl.full([], float('-inf'), dtype=tl.float32)
    m7 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    se2 = tl.zeros([], dtype=tl.float32)
    se3 = tl.zeros([], dtype=tl.float32)
    se4 = tl.zeros([], dtype=tl.float32)
    se5 = tl.zeros([], dtype=tl.float32)
    se6 = tl.zeros([], dtype=tl.float32)
    se7 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a2 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a3 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a4 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a5 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a6 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a7 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    # Loop-invariant bases: this batch's block-table row and this KV head's slice.
    bt_row = block_table_ptr + b * stride_btb
    kv_head_base = kv_cache_ptr + hkv * stride_kh
    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            # Widen to int64 before scaling by the page stride: in 32-bit,
            # page_id * stride_kb overflows for KV caches larger than 2**31
            # elements and the load would hit the wrong address.
            page_id = tl.load(bt_row + page_idx * stride_btp).to(tl.int64)
            token_off = kv_head_base + page_id * stride_kb + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            # Each head: rescale old state by exp(m_old - m_new), then add this token.
            # Head 0
            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm
            # Head 1
            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm
            # Head 2
            sc = tl.sum(q2 * k) * scale
            nm = tl.maximum(m2, sc); rc = tl.exp(m2 - nm)
            a2 = a2 * rc; se2 = se2 * rc
            e = tl.exp(sc - nm); se2 = se2 + e; a2 = a2 + e * v; m2 = nm
            # Head 3
            sc = tl.sum(q3 * k) * scale
            nm = tl.maximum(m3, sc); rc = tl.exp(m3 - nm)
            a3 = a3 * rc; se3 = se3 * rc
            e = tl.exp(sc - nm); se3 = se3 + e; a3 = a3 + e * v; m3 = nm
            # Head 4
            sc = tl.sum(q4 * k) * scale
            nm = tl.maximum(m4, sc); rc = tl.exp(m4 - nm)
            a4 = a4 * rc; se4 = se4 * rc
            e = tl.exp(sc - nm); se4 = se4 + e; a4 = a4 + e * v; m4 = nm
            # Head 5
            sc = tl.sum(q5 * k) * scale
            nm = tl.maximum(m5, sc); rc = tl.exp(m5 - nm)
            a5 = a5 * rc; se5 = se5 * rc
            e = tl.exp(sc - nm); se5 = se5 + e; a5 = a5 + e * v; m5 = nm
            # Head 6
            sc = tl.sum(q6 * k) * scale
            nm = tl.maximum(m6, sc); rc = tl.exp(m6 - nm)
            a6 = a6 * rc; se6 = se6 * rc
            e = tl.exp(sc - nm); se6 = se6 + e; a6 = a6 + e * v; m6 = nm
            # Head 7
            sc = tl.sum(q7 * k) * scale
            nm = tl.maximum(m7, sc); rc = tl.exp(m7 - nm)
            a7 = a7 * rc; se7 = se7 * rc
            e = tl.exp(sc - nm); se7 = se7 + e; a7 = a7 + e * v; m7 = nm

    # Empty splits still store (acc=0, m=-inf, se=0); the reduce step weights
    # them by exp(-inf - max) == 0.
    po_base = partial_out_ptr + b * stride_pob + s * stride_pos + offs_d * stride_pod
    mi_base = partial_mi_ptr + b * stride_mib + s * stride_mis
    li_base = partial_li_ptr + b * stride_lib + s * stride_lis
    tl.store(po_base + (h0 + 0) * stride_poh, a0, mask=mask_d)
    tl.store(mi_base + (h0 + 0) * stride_mih, m0)
    tl.store(li_base + (h0 + 0) * stride_lih, se0)
    tl.store(po_base + (h0 + 1) * stride_poh, a1, mask=mask_d)
    tl.store(mi_base + (h0 + 1) * stride_mih, m1)
    tl.store(li_base + (h0 + 1) * stride_lih, se1)
    tl.store(po_base + (h0 + 2) * stride_poh, a2, mask=mask_d)
    tl.store(mi_base + (h0 + 2) * stride_mih, m2)
    tl.store(li_base + (h0 + 2) * stride_lih, se2)
    tl.store(po_base + (h0 + 3) * stride_poh, a3, mask=mask_d)
    tl.store(mi_base + (h0 + 3) * stride_mih, m3)
    tl.store(li_base + (h0 + 3) * stride_lih, se3)
    tl.store(po_base + (h0 + 4) * stride_poh, a4, mask=mask_d)
    tl.store(mi_base + (h0 + 4) * stride_mih, m4)
    tl.store(li_base + (h0 + 4) * stride_lih, se4)
    tl.store(po_base + (h0 + 5) * stride_poh, a5, mask=mask_d)
    tl.store(mi_base + (h0 + 5) * stride_mih, m5)
    tl.store(li_base + (h0 + 5) * stride_lih, se5)
    tl.store(po_base + (h0 + 6) * stride_poh, a6, mask=mask_d)
    tl.store(mi_base + (h0 + 6) * stride_mih, m6)
    tl.store(li_base + (h0 + 6) * stride_lih, se6)
    tl.store(po_base + (h0 + 7) * stride_poh, a7, mask=mask_d)
    tl.store(mi_base + (h0 + 7) * stride_mih, m7)
    tl.store(li_base + (h0 + 7) * stride_lih, se7)
@triton.jit
def _partial_gs2(
    query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr,
    partial_out_ptr, partial_mi_ptr, partial_li_ptr,
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kp, stride_kh, stride_kd,
    stride_btb, stride_btp,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr,
    scale, MAX_NUM_PAGES: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_SEQ: tl.constexpr,
):
    """Phase 1 for group_size 2: partial softmax state of one (batch, kv_head, split).

    The program id decodes as pid = (b * NUM_KV_HEADS + hkv) * num_splits + s.
    The program walks tokens [s * BLOCK_SEQ, min((s + 1) * BLOCK_SEQ, seq_len))
    of batch b, loads each token's K/V for KV head hkv once, and updates an
    online softmax for the 2 query heads hkv*2 and hkv*2+1. For each head h it
    writes (score_t = scale * dot(q_h, k_t), m = max_t score_t):
        partial_out[b, h, s, :] = sum_t exp(score_t - m) * v_t
        partial_mi[b, h, s]     = m      (-inf when the split holds no tokens)
        partial_li[b, h, s]     = sum_t exp(score_t - m)
    K is read from channels [0, HEAD_DIM) and V from [HEAD_DIM, 2 * HEAD_DIM)
    of the last kv_cache axis. NUM_HEADS is accepted for signature parity with
    the host launcher but is not used here.
    """
    pid = tl.program_id(0)
    num_splits = tl.cdiv(MAX_NUM_PAGES * PAGE_SIZE, BLOCK_SEQ)
    s = pid % num_splits
    remainder = pid // num_splits
    hkv = remainder % NUM_KV_HEADS
    b = remainder // NUM_KV_HEADS
    seq_len = tl.load(seq_lens_ptr + b).to(tl.int32)
    token_start = s * BLOCK_SEQ
    token_end = tl.minimum(token_start + BLOCK_SEQ, seq_len)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM

    # Query heads served by KV head hkv are h0 and h0 + 1.
    h0 = hkv * 2
    q_base = query_ptr + b * stride_qb + offs_d * stride_qd
    q0 = tl.load(q_base + (h0 + 0) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)
    q1 = tl.load(q_base + (h0 + 1) * stride_qh, mask=mask_d, other=0.0).to(tl.float32)

    # Online-softmax state per head: running max m*, exp-sum se*, weighted-V acc a*.
    m0 = tl.full([], float('-inf'), dtype=tl.float32)
    m1 = tl.full([], float('-inf'), dtype=tl.float32)
    se0 = tl.zeros([], dtype=tl.float32)
    se1 = tl.zeros([], dtype=tl.float32)
    a0 = tl.zeros((BLOCK_D,), dtype=tl.float32)
    a1 = tl.zeros((BLOCK_D,), dtype=tl.float32)

    # Loop-invariant bases: this batch's block-table row and this KV head's slice.
    bt_row = block_table_ptr + b * stride_btb
    kv_head_base = kv_cache_ptr + hkv * stride_kh
    if token_start < seq_len:
        for token_flat in range(token_start, token_end):
            page_idx = token_flat // PAGE_SIZE
            t = token_flat % PAGE_SIZE
            # Widen to int64 before scaling by the page stride: in 32-bit,
            # page_id * stride_kb overflows for KV caches larger than 2**31
            # elements and the load would hit the wrong address.
            page_id = tl.load(bt_row + page_idx * stride_btp).to(tl.int64)
            token_off = kv_head_base + page_id * stride_kb + t * stride_kp
            k = tl.load(token_off + offs_d * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            v = tl.load(token_off + (offs_d + HEAD_DIM) * stride_kd, mask=mask_d, other=0.0).to(tl.float32)
            # Each head: rescale old state by exp(m_old - m_new), then add this token.
            # Head 0
            sc = tl.sum(q0 * k) * scale
            nm = tl.maximum(m0, sc); rc = tl.exp(m0 - nm)
            a0 = a0 * rc; se0 = se0 * rc
            e = tl.exp(sc - nm); se0 = se0 + e; a0 = a0 + e * v; m0 = nm
            # Head 1
            sc = tl.sum(q1 * k) * scale
            nm = tl.maximum(m1, sc); rc = tl.exp(m1 - nm)
            a1 = a1 * rc; se1 = se1 * rc
            e = tl.exp(sc - nm); se1 = se1 + e; a1 = a1 + e * v; m1 = nm

    # Empty splits still store (acc=0, m=-inf, se=0); the reduce step weights
    # them by exp(-inf - max) == 0.
    po_base = partial_out_ptr + b * stride_pob + s * stride_pos + offs_d * stride_pod
    mi_base = partial_mi_ptr + b * stride_mib + s * stride_mis
    li_base = partial_li_ptr + b * stride_lib + s * stride_lis
    tl.store(po_base + (h0 + 0) * stride_poh, a0, mask=mask_d)
    tl.store(mi_base + (h0 + 0) * stride_mih, m0)
    tl.store(li_base + (h0 + 0) * stride_lih, se0)
    tl.store(po_base + (h0 + 1) * stride_poh, a1, mask=mask_d)
    tl.store(mi_base + (h0 + 1) * stride_mih, m1)
    tl.store(li_base + (h0 + 1) * stride_lih, se1)
@triton.jit
def _reduce_kernel(
    partial_out_ptr, partial_mi_ptr, partial_li_ptr, output_ptr,
    stride_pob, stride_poh, stride_pos, stride_pod,
    stride_mib, stride_mih, stride_mis,
    stride_lib, stride_lih, stride_lis,
    stride_ob, stride_oh, stride_od,
    NUM_HEADS: tl.constexpr, NUM_SPLITS: tl.constexpr,
    HEAD_DIM: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Phase 2: merge the per-split softmax states of one (batch, head) into bf16.

    One program per pid = b * NUM_HEADS + h. With M = max_s partial_mi[b, h, s]
    and w_s = exp(partial_mi[b, h, s] - M):
        output[b, h, :] = sum_s w_s * partial_out[b, h, s, :]
                          / sum_s w_s * partial_li[b, h, s]
    A row whose splits are all empty (every partial_mi == -inf, e.g. a sequence
    of length 0) is written as zeros instead of NaN.
    """
    pid = tl.program_id(0)
    b = pid // NUM_HEADS
    h = pid % NUM_HEADS
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < HEAD_DIM
    mi_row = partial_mi_ptr + b * stride_mib + h * stride_mih
    li_row = partial_li_ptr + b * stride_lib + h * stride_lih
    po_row = partial_out_ptr + b * stride_pob + h * stride_poh + offs_d * stride_pod

    # Pass 1: global max over splits so every rescale factor below is <= 1.
    neg_inf = tl.full([], float('-inf'), dtype=tl.float32)
    global_max = neg_inf
    for s in range(NUM_SPLITS):
        mi = tl.load(mi_row + s * stride_mis)
        global_max = tl.maximum(global_max, mi)
    # If every split was empty, global_max is -inf and exp(-inf - (-inf)) would
    # be NaN; shifting by 0 instead makes every factor exp(-inf) == 0.
    ref_max = tl.where(global_max > neg_inf, global_max, 0.0)

    # Pass 2: rescale each split's (sum, acc) to the common max and accumulate.
    total_sum = tl.zeros([], dtype=tl.float32)
    total_acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
    for s in range(NUM_SPLITS):
        mi = tl.load(mi_row + s * stride_mis)
        li = tl.load(li_row + s * stride_lis)
        acc = tl.load(po_row + s * stride_pos, mask=mask_d, other=0.0)
        rc = tl.exp(mi - ref_max)
        total_sum = total_sum + li * rc
        total_acc = total_acc + acc * rc

    # total_sum is 0 only for all-empty rows, where total_acc is 0 as well;
    # dividing by 1 then yields zeros rather than 0/0.
    safe_sum = tl.where(total_sum > 0.0, total_sum, 1.0)
    result = (total_acc * (1.0 / safe_sum)).to(tl.bfloat16)
    tl.store(output_ptr + b * stride_ob + h * stride_oh + offs_d * stride_od, result, mask=mask_d)
def paged_attention_decode_triton(
    query, kv_cache, block_table, seq_lens,
    num_heads, num_kv_heads, head_dim, page_size,
):
    """Single-token paged attention decode (FlashDecoding split + GQA head sharing).

    Args:
        query: tensor indexed [b, h, d]; its shape must be (B, num_heads, head_dim).
        kv_cache: tensor indexed [page, slot_in_page, kv_head, channel]. K occupies
            channels [0, head_dim) and V channels [head_dim, 2 * head_dim).
        block_table: tensor indexed [b, logical_page] holding the physical page id
            in kv_cache; its width fixes the maximum sequence capacity.
        seq_lens: per-batch number of valid tokens (one entry per b).
        num_heads: number of query heads.
        num_kv_heads: number of KV heads; num_heads // num_kv_heads must be 2, 4 or 8.
        head_dim: channels per head; the softmax scale is 1 / sqrt(head_dim).
        page_size: tokens per KV-cache page.

    Returns:
        A bfloat16 tensor of shape (B, num_heads, head_dim) on query's device.

    Raises:
        ValueError: if num_heads/head_dim disagree with query.shape, num_kv_heads
            does not evenly divide num_heads, or the group size has no kernel.
    """
    B, H, D = query.shape
    # Validate before launching anything: a num_heads larger than query's head
    # axis would make the kernels index past the end of query.
    if H != num_heads or D != head_dim:
        raise ValueError(
            f"query shape {tuple(query.shape)} does not match "
            f"num_heads={num_heads}, head_dim={head_dim}"
        )
    if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_kv_heads={num_kv_heads} must be positive and divide num_heads={num_heads}"
        )
    G = num_heads // num_kv_heads
    if G not in (2, 4, 8):
        raise ValueError(
            f"unsupported GQA group size {G}; kernels exist for group sizes 2, 4 and 8"
        )

    # The partial kernels read seq_lens as seq_lens_ptr + b (unit stride).
    seq_lens = seq_lens.contiguous()

    scale = 1.0 / math.sqrt(head_dim)
    max_pages = block_table.shape[1]
    max_seq_len = max_pages * page_size
    BLOCK_D = triton.next_power_of_2(head_dim)
    BLOCK_SEQ = 128
    num_splits = triton.cdiv(max_seq_len, BLOCK_SEQ)

    output = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device)
    partial_out = torch.empty(B, H, num_splits, D, dtype=torch.float32, device=query.device)
    partial_mi = torch.full((B, H, num_splits), float('-inf'), dtype=torch.float32, device=query.device)
    partial_li = torch.zeros(B, H, num_splits, dtype=torch.float32, device=query.device)

    # Phase 1: one program per (batch, kv_head, split); each writes all G heads.
    pk = {2: _partial_gs2, 4: _partial_gs4, 8: _partial_gs8}[G]
    grid = (B * num_kv_heads * num_splits,)
    pk[grid](
        query, kv_cache, block_table, seq_lens,
        partial_out, partial_mi, partial_li,
        query.stride(0), query.stride(1), query.stride(2),
        kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
        block_table.stride(0), block_table.stride(1),
        partial_out.stride(0), partial_out.stride(1), partial_out.stride(2), partial_out.stride(3),
        partial_mi.stride(0), partial_mi.stride(1), partial_mi.stride(2),
        partial_li.stride(0), partial_li.stride(1), partial_li.stride(2),
        HEAD_DIM=head_dim, PAGE_SIZE=page_size,
        NUM_HEADS=num_heads, NUM_KV_HEADS=num_kv_heads,
        scale=scale, MAX_NUM_PAGES=max_pages,
        BLOCK_D=BLOCK_D, BLOCK_SEQ=BLOCK_SEQ,
        num_warps=1, num_stages=5,
    )

    # Phase 2: one program per (batch, head) merges the num_splits partials.
    grid_reduce = (B * H,)
    _reduce_kernel[grid_reduce](
        partial_out, partial_mi, partial_li, output,
        partial_out.stride(0), partial_out.stride(1), partial_out.stride(2), partial_out.stride(3),
        partial_mi.stride(0), partial_mi.stride(1), partial_mi.stride(2),
        partial_li.stride(0), partial_li.stride(1), partial_li.stride(2),
        output.stride(0), output.stride(1), output.stride(2),
        NUM_HEADS=H, NUM_SPLITS=num_splits,
        HEAD_DIM=head_dim, BLOCK_D=BLOCK_D,
        num_warps=4, num_stages=2,
    )
    return output
# ---------------------------------------------------------------------------
# Benchmark metadata plus the default problem size consumed by get_inputs()
# and get_init_inputs().
OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

# Default decode shape: 32 query heads over 8 KV heads, i.e. GQA group size 4.
BATCH, NUM_HEADS, NUM_KV_HEADS = 8, 32, 8
HEAD_DIM = 128
SEQ_LEN, PAGE_SIZE = 1024, 16
class Model(nn.Module):
    """nn.Module wrapper around paged_attention_decode_triton for a fixed head layout.

    Constructor arguments follow the order returned by get_init_inputs(); forward()
    takes (query, kv_cache, block_table, seq_lens). `batch` and `seq_len` are stored
    for reference only: forward() derives sizes and lengths from its inputs.
    """

    def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len, page_size):
        super().__init__()
        # Raise instead of assert so the check still runs under `python -O`.
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_kv_heads={num_kv_heads} must be positive and divide num_heads={num_heads}"
            )
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        # Non-persistent bf16 buffer (not saved in state_dict). Presumably it gives
        # .to(device)/.cuda() something to move so the module has a device;
        # NOTE(review): confirm against the benchmarking harness.
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)

    def forward(self, query, kv_cache, block_table, seq_lens):
        """Run paged decode attention; returns a bf16 tensor with query's (B, H, D) shape."""
        return paged_attention_decode_triton(
            query, kv_cache, block_table, seq_lens,
            num_heads=self.num_heads, num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim, page_size=self.page_size,
        )
def get_inputs():
    """Create one random decode problem for the default configuration.

    Returns [query, kv_cache, block_table, seq_lens]:
      query       (BATCH, NUM_HEADS, HEAD_DIM) bf16, standard normal scaled by 0.1
      kv_cache    (pool, PAGE_SIZE, NUM_KV_HEADS, 2 * HEAD_DIM) bf16, same scaling
      block_table (BATCH, ceil(SEQ_LEN / PAGE_SIZE)) int32, distinct shuffled page ids
      seq_lens    (BATCH,) int32, every entry equal to SEQ_LEN
    The pool holds at least 8 spare pages and at least 64 pages in total.
    """
    batch, n_heads, n_kv_heads = BATCH, NUM_HEADS, NUM_KV_HEADS
    dim, length, page = HEAD_DIM, SEQ_LEN, PAGE_SIZE

    pages_per_seq = -(-length // page)  # ceiling division
    pool_size = max(batch * pages_per_seq + 8, 64)

    # Random draws happen in a fixed order (query, kv_cache, permutation);
    # seeded runs depend on it.
    query = 0.1 * torch.randn(batch, n_heads, dim, dtype=torch.bfloat16)
    kv_cache = 0.1 * torch.randn(pool_size, page, n_kv_heads, 2 * dim, dtype=torch.bfloat16)
    shuffled = torch.randperm(pool_size)
    needed = shuffled[: batch * pages_per_seq]
    block_table = needed.reshape(batch, pages_per_seq).to(torch.int32).contiguous()
    seq_lens = torch.full((batch,), length, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]
def get_init_inputs():
    """Positional args for Model(batch, num_heads, num_kv_heads, head_dim, seq_len, page_size)."""
    init_args = (BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE)
    return list(init_args)