"""Paged attention decode kernel implemented in Triton.

Conforms to the same Model/get_inputs/get_init_inputs interface as reference.py.
"""
import math

import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16


@triton.jit
def paged_decode_kernel(
    q_ptr,
    kv_ptr,
    block_table_ptr,
    seq_lens_ptr,
    out_ptr,
    stride_qb,
    stride_qh,
    stride_qd,
    stride_kvn,
    stride_kvt,
    stride_kvh,
    stride_kvd,
    stride_bt_b,
    stride_bt_p,
    stride_ob,
    stride_oh,
    stride_od,
    SCALE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    G: tl.constexpr,
):
    """Single-pass decode: one program per (batch, kv_head), all G query heads.

    Walks every page of the sequence, maintaining a running max (``ms``),
    running softmax denominator (``ls``) and un-normalised output (``accs``)
    per query head, then writes ``accs / ls`` as bf16.

    K is read from the first ``HEAD_DIM`` entries of the cache's last axis and
    V from the following ``HEAD_DIM`` entries.  A sequence of length 0 yields
    an all-zero output row.
    """
    pid_b = tl.program_id(0)
    pid_hkv = tl.program_id(1)

    seq_len = tl.load(seq_lens_ptr + pid_b)
    num_pages = tl.cdiv(seq_len, PAGE_SIZE)

    offs_d = tl.arange(0, HEAD_DIM)
    offs_t = tl.arange(0, PAGE_SIZE)
    offs_g = tl.arange(0, G)[:, None]

    # Load the G query heads for this KV group: (G, HEAD_DIM).
    q_ptrs = (
        q_ptr
        + pid_b * stride_qb
        + (pid_hkv * G + offs_g) * stride_qh
        + offs_d[None, :] * stride_qd
    )
    qs = tl.load(q_ptrs, mask=offs_d[None, :] < HEAD_DIM, other=0.0).to(tl.float32)

    ms = tl.full((G,), float("-inf"), dtype=tl.float32)
    ls = tl.zeros((G,), dtype=tl.float32)
    accs = tl.zeros((G, HEAD_DIM), dtype=tl.float32)

    for p in range(num_pages):
        # Logical page p -> physical page; widened to int64 before it is
        # multiplied by the page stride.
        physical_block = tl.load(
            block_table_ptr + pid_b * stride_bt_b + p * stride_bt_p
        ).to(tl.int64)
        token_offset = p * PAGE_SIZE
        # Number of live tokens in this page (the last page may be partial).
        valid = tl.minimum(PAGE_SIZE, seq_len - token_offset)

        kv_base = kv_ptr + physical_block * stride_kvn + pid_hkv * stride_kvh
        k_ptrs = kv_base + (offs_t[:, None] * stride_kvt + offs_d[None, :] * stride_kvd)
        k = tl.load(
            k_ptrs,
            mask=(offs_t[:, None] < valid) & (offs_d[None, :] < HEAD_DIM),
            other=0.0,
        ).to(tl.float32)
        # V sits HEAD_DIM elements after K along the last cache axis.
        v_ptrs = kv_base + HEAD_DIM * stride_kvd + (
            offs_t[:, None] * stride_kvt + offs_d[None, :] * stride_kvd
        )
        v = tl.load(
            v_ptrs,
            mask=(offs_t[:, None] < valid) & (offs_d[None, :] < HEAD_DIM),
            other=0.0,
        ).to(tl.float32)

        for g in tl.static_range(G):
            # Row g of each per-head state is extracted with a one-hot mask
            # and a reduction, then written back with tl.where below.
            mask_g = tl.arange(0, G) == g
            mask_g2 = mask_g[:, None]
            q = tl.sum(tl.where(mask_g2, qs, 0.0), axis=0)
            acc_cur = tl.sum(tl.where(mask_g2, accs, 0.0), axis=0)
            m_cur = tl.sum(tl.where(mask_g, ms, 0.0), axis=0)
            l_cur = tl.sum(tl.where(mask_g, ls, 0.0), axis=0)

            scores = tl.reshape(tl.dot(q[None, :], tl.trans(k)), (PAGE_SIZE,)) * SCALE
            scores = tl.where(offs_t < valid, scores, float("-inf"))
            # Online-softmax update: rescale the old state to the new max.
            m_new = tl.maximum(m_cur, tl.max(scores, axis=0))
            exp_scale = tl.exp(m_cur - m_new)
            exp_scores = tl.exp(scores - m_new)
            weighted = tl.reshape(tl.dot(exp_scores[None, :], v), (HEAD_DIM,))
            new_acc = acc_cur * exp_scale + weighted
            new_l = l_cur * exp_scale + tl.sum(exp_scores, axis=0)

            accs = tl.where(mask_g2, new_acc[None, :], accs)
            ls = tl.where(mask_g, new_l, ls)
            ms = tl.where(mask_g, m_new, ms)

    out_ptrs = (
        out_ptr
        + pid_b * stride_ob
        + (pid_hkv * G + offs_g) * stride_oh
        + offs_d[None, :] * stride_od
    )
    # A zero-length sequence never enters the page loop, so ls stays 0.
    # Divide by 1 in that case so the row is 0 instead of NaN (0 / 0).
    safe_ls = tl.where(ls > 0.0, ls, 1.0)
    tl.store(
        out_ptrs,
        (accs / safe_ls[:, None]).to(tl.bfloat16),
        mask=offs_d[None, :] < HEAD_DIM,
    )


@triton.jit
def paged_decode_part_kernel(
    q_ptr,
    kv_ptr,
    block_table_ptr,
    seq_lens_ptr,
    part_out_ptr,
    part_m_ptr,
    part_l_ptr,
    stride_qb,
    stride_qh,
    stride_qd,
    stride_kvn,
    stride_kvt,
    stride_kvh,
    stride_kvd,
    stride_bt_b,
    stride_bt_p,
    stride_pob,
    stride_poh,
    stride_pos,
    stride_pod,
    stride_pmb,
    stride_pmh,
    stride_pms,
    SCALE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    G: tl.constexpr,
    NUM_SPLITS: tl.constexpr,
):
    """Split decode: one program per (batch, kv_head, split).

    Each program covers the page range
    ``[pid_s * pages_per_split, min(num_pages, ...))`` and writes the
    un-normalised accumulator, running max and running denominator for all G
    query heads.  A split whose range is empty writes ``acc = 0``,
    ``m = -inf`` and ``l = 0``; ``combine_kernel`` merges the partials.

    ``part_l_ptr`` is addressed with the ``stride_pm*`` strides, so the ``m``
    and ``l`` buffers must share the same layout.
    """
    pid_b = tl.program_id(0)
    pid_hkv = tl.program_id(1)
    pid_s = tl.program_id(2)

    seq_len = tl.load(seq_lens_ptr + pid_b)
    num_pages = tl.cdiv(seq_len, PAGE_SIZE)
    pages_per_split = tl.cdiv(num_pages, NUM_SPLITS)
    start_page = pid_s * pages_per_split
    end_page = tl.minimum(num_pages, start_page + pages_per_split)

    offs_d = tl.arange(0, HEAD_DIM)
    offs_t = tl.arange(0, PAGE_SIZE)
    offs_g = tl.arange(0, G)[:, None]
    offs_g1 = tl.arange(0, G)

    # Load the G query heads for this KV group: (G, HEAD_DIM).
    q_ptrs = (
        q_ptr
        + pid_b * stride_qb
        + (pid_hkv * G + offs_g) * stride_qh
        + offs_d[None, :] * stride_qd
    )
    qs = tl.load(q_ptrs, mask=offs_d[None, :] < HEAD_DIM, other=0.0).to(tl.float32)

    ms = tl.full((G,), float("-inf"), dtype=tl.float32)
    ls = tl.zeros((G,), dtype=tl.float32)
    accs = tl.zeros((G, HEAD_DIM), dtype=tl.float32)

    for p in range(start_page, end_page):
        # Logical page p -> physical page; widened to int64 before it is
        # multiplied by the page stride.
        physical_block = tl.load(
            block_table_ptr + pid_b * stride_bt_b + p * stride_bt_p
        ).to(tl.int64)
        token_offset = p * PAGE_SIZE
        # Number of live tokens in this page (the last page may be partial).
        valid = tl.minimum(PAGE_SIZE, seq_len - token_offset)

        kv_base = kv_ptr + physical_block * stride_kvn + pid_hkv * stride_kvh
        k_ptrs = kv_base + (offs_t[:, None] * stride_kvt + offs_d[None, :] * stride_kvd)
        k = tl.load(
            k_ptrs,
            mask=(offs_t[:, None] < valid) & (offs_d[None, :] < HEAD_DIM),
            other=0.0,
        ).to(tl.float32)
        # V sits HEAD_DIM elements after K along the last cache axis.
        v_ptrs = kv_base + HEAD_DIM * stride_kvd + (
            offs_t[:, None] * stride_kvt + offs_d[None, :] * stride_kvd
        )
        v = tl.load(
            v_ptrs,
            mask=(offs_t[:, None] < valid) & (offs_d[None, :] < HEAD_DIM),
            other=0.0,
        ).to(tl.float32)

        for g in tl.static_range(G):
            # Row g of each per-head state is extracted with a one-hot mask
            # and a reduction, then written back with tl.where below.
            mask_g = tl.arange(0, G) == g
            mask_g2 = mask_g[:, None]
            q = tl.sum(tl.where(mask_g2, qs, 0.0), axis=0)
            acc_cur = tl.sum(tl.where(mask_g2, accs, 0.0), axis=0)
            m_cur = tl.sum(tl.where(mask_g, ms, 0.0), axis=0)
            l_cur = tl.sum(tl.where(mask_g, ls, 0.0), axis=0)

            scores = tl.reshape(tl.dot(q[None, :], tl.trans(k)), (PAGE_SIZE,)) * SCALE
            scores = tl.where(offs_t < valid, scores, float("-inf"))
            # Online-softmax update: rescale the old state to the new max.
            m_new = tl.maximum(m_cur, tl.max(scores, axis=0))
            exp_scale = tl.exp(m_cur - m_new)
            exp_scores = tl.exp(scores - m_new)
            weighted = tl.reshape(tl.dot(exp_scores[None, :], v), (HEAD_DIM,))
            new_acc = acc_cur * exp_scale + weighted
            new_l = l_cur * exp_scale + tl.sum(exp_scores, axis=0)

            accs = tl.where(mask_g2, new_acc[None, :], accs)
            ls = tl.where(mask_g, new_l, ls)
            ms = tl.where(mask_g, m_new, ms)

    # Write partials for all G query heads in the group at once.
    out_ptrs = (
        part_out_ptr
        + pid_b * stride_pob
        + (pid_hkv * G + offs_g) * stride_poh
        + pid_s * stride_pos
        + offs_d[None, :] * stride_pod
    )
    tl.store(out_ptrs, accs, mask=offs_d[None, :] < HEAD_DIM)

    m_ptrs = (
        part_m_ptr
        + pid_b * stride_pmb
        + (pid_hkv * G + offs_g1) * stride_pmh
        + pid_s * stride_pms
    )
    tl.store(m_ptrs, ms)

    l_ptrs = (
        part_l_ptr
        + pid_b * stride_pmb
        + (pid_hkv * G + offs_g1) * stride_pmh
        + pid_s * stride_pms
    )
    tl.store(l_ptrs, ls)


@triton.jit
def combine_kernel(
    part_out_ptr,
    part_m_ptr,
    part_l_ptr,
    out_ptr,
    stride_pob,
    stride_poh,
    stride_pos,
    stride_pod,
    stride_pmb,
    stride_pmh,
    stride_pms,
    stride_ob,
    stride_oh,
    stride_od,
    HEAD_DIM: tl.constexpr,
    NUM_SPLITS: tl.constexpr,
):
    """Combine partial attention results across splits.

    One program per (batch, query head).  Each split's accumulator and
    denominator are rescaled to a common running max, summed, and the
    normalised result is written as bf16.  If every split of a row is empty
    (running max stays ``-inf``) the row is written as zeros.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_d = tl.arange(0, HEAD_DIM)

    M = tl.full((), float("-inf"), dtype=tl.float32)
    denom = tl.zeros((), dtype=tl.float32)
    acc = tl.zeros((HEAD_DIM,), dtype=tl.float32)

    base_out = part_out_ptr + pid_b * stride_pob + pid_h * stride_poh
    base_m = part_m_ptr + pid_b * stride_pmb + pid_h * stride_pmh
    # The l buffer is addressed with the m buffer's strides (same layout).
    base_l = part_l_ptr + pid_b * stride_pmb + pid_h * stride_pmh

    for s in tl.static_range(NUM_SPLITS):
        m_s = tl.load(base_m + s * stride_pms)
        l_s = tl.load(base_l + s * stride_pms)
        acc_s = tl.load(
            base_out + s * stride_pos + offs_d * stride_pod,
            mask=offs_d < HEAD_DIM,
            other=0.0,
        )

        M_new = tl.maximum(M, m_s)
        # While nothing but empty splits has been seen, M_new is -inf and
        # exp(-inf - -inf) would be NaN; use a factor of 0 instead.
        all_empty = M_new == float("-inf")
        scale_global = tl.where(all_empty, 0.0, tl.exp(M - M_new))
        scale_s = tl.where(all_empty, 0.0, tl.exp(m_s - M_new))
        acc = acc * scale_global + acc_s * scale_s
        denom = denom * scale_global + l_s * scale_s
        M = M_new

    out_ptrs = out_ptr + pid_b * stride_ob + pid_h * stride_oh + offs_d * stride_od
    # denom is 0 only when every split was empty; avoid 0 / 0 there.
    safe_denom = tl.where(denom > 0.0, denom, 1.0)
    tl.store(
        out_ptrs,
        (acc / safe_denom).to(tl.bfloat16),
        mask=offs_d < HEAD_DIM,
    )


class Model(nn.Module):
    """Paged GQA decode attention backed by the Triton kernels in this file.

    ``forward`` picks between a single-pass kernel and a split-sequence
    kernel followed by a combine step, depending on how many
    (batch, kv_head) tiles there are.
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        """Store the attention geometry.

        ``num_heads`` must be a multiple of ``num_kv_heads``; the ratio is the
        number of query heads served by each KV head.
        """
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads (GQA)"
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.scale = 1.0 / math.sqrt(head_dim)
        # Non-persistent placeholder buffer; forward() does not read it.
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)

    @staticmethod
    def _pick_num_splits(num_tiles: int, max_pages: int, target_blocks: int = 256) -> int:
        """Return how many sequence splits to launch per (batch, kv_head) tile.

        Aims for roughly ``target_blocks`` programs in total, never uses more
        splits than there are pages, and always returns at least 1 (also when
        ``max_pages`` is 0 or ``num_tiles`` is 0).
        """
        wanted = max(1, target_blocks // max(1, num_tiles))
        return max(1, min(wanted, max_pages))

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Run paged decode attention.

        Args:
            query: 3-D tensor unpacked as (batch, num_heads, head_dim).
            kv_cache: 4-D paged cache; ``get_inputs`` builds it as
                (pages, page_size, num_kv_heads, 2 * head_dim) with K in the
                first half of the last axis and V in the second half.
            block_table: 2-D tensor of physical page indices, one row per
                batch entry.
            seq_lens: per-batch sequence lengths.

        Returns:
            Tensor with the shape, dtype and device of ``query``.  Rows whose
            sequence length is 0 are zero.
        """
        B, H, D = query.shape
        Hkv = self.num_kv_heads
        G = self.group_size
        P = self.page_size
        device = query.device

        out = torch.empty(B, H, D, dtype=query.dtype, device=device)
        if B == 0:
            # Nothing to launch; seq_lens.max() would also fail on an empty tensor.
            return out

        # Use sequence splitting when there are too few (batch, kv_head) tiles to
        # keep the GPU busy. Aim for ~256 active blocks.
        # NOTE: .item() copies the max length to the host.
        max_pages = (int(seq_lens.max().item()) + P - 1) // P
        num_splits = self._pick_num_splits(B * Hkv, max_pages)

        if num_splits == 1:
            grid = (B, Hkv)
            paged_decode_kernel[grid](
                query,
                kv_cache,
                block_table,
                seq_lens,
                out,
                query.stride(0),
                query.stride(1),
                query.stride(2),
                kv_cache.stride(0),
                kv_cache.stride(1),
                kv_cache.stride(2),
                kv_cache.stride(3),
                block_table.stride(0),
                block_table.stride(1),
                out.stride(0),
                out.stride(1),
                out.stride(2),
                SCALE=self.scale,
                PAGE_SIZE=P,
                HEAD_DIM=D,
                G=G,
                num_warps=4,
            )
            return out

        # fp32 partials, one slot per split; part_m and part_l must keep the
        # same layout because the kernels address both with part_m's strides.
        part_out = torch.empty(B, H, num_splits, D, dtype=torch.float32, device=device)
        part_m = torch.empty(B, H, num_splits, dtype=torch.float32, device=device)
        part_l = torch.empty(B, H, num_splits, dtype=torch.float32, device=device)

        grid_part = (B, Hkv, num_splits)
        paged_decode_part_kernel[grid_part](
            query,
            kv_cache,
            block_table,
            seq_lens,
            part_out,
            part_m,
            part_l,
            query.stride(0),
            query.stride(1),
            query.stride(2),
            kv_cache.stride(0),
            kv_cache.stride(1),
            kv_cache.stride(2),
            kv_cache.stride(3),
            block_table.stride(0),
            block_table.stride(1),
            part_out.stride(0),
            part_out.stride(1),
            part_out.stride(2),
            part_out.stride(3),
            part_m.stride(0),
            part_m.stride(1),
            part_m.stride(2),
            SCALE=self.scale,
            PAGE_SIZE=P,
            HEAD_DIM=D,
            G=G,
            NUM_SPLITS=num_splits,
            num_warps=4,
        )

        grid_combine = (B, H)
        combine_kernel[grid_combine](
            part_out,
            part_m,
            part_l,
            out,
            part_out.stride(0),
            part_out.stride(1),
            part_out.stride(2),
            part_out.stride(3),
            part_m.stride(0),
            part_m.stride(1),
            part_m.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            HEAD_DIM=D,
            NUM_SPLITS=num_splits,
            num_warps=2,
        )
        return out


def get_inputs():
    """Build one random set of forward() arguments from the module constants.

    Returns ``[query, kv_cache, block_table, seq_lens]``: bf16 query and cache
    scaled by 0.1, an int32 block table of distinct page indices, and int32
    sequence lengths all equal to ``SEQ_LEN``.
    """
    B = BATCH
    H = NUM_HEADS
    Hkv = NUM_KV_HEADS
    D = HEAD_DIM
    L = SEQ_LEN
    P = PAGE_SIZE

    pages_per_seq = (L + P - 1) // P
    # A few spare pages beyond what the batch needs, and at least 64 in total.
    total_pages = max(B * pages_per_seq + 8, 64)

    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1

    # Random, non-repeating physical pages for every (batch, logical page).
    perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)

    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]