"""Paged attention decode kernel via Triton (GQA-aware, page-batched).""" import math import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] BATCH = 8 NUM_HEADS = 32 NUM_KV_HEADS = 8 HEAD_DIM = 128 SEQ_LEN = 1024 PAGE_SIZE = 16 @triton.jit def _paged_decode_gqa_kernel( Q_ptr, KV_ptr, BT_ptr, SL_ptr, O_ptr, stride_qb, stride_qh, stride_qd, stride_kv_block, stride_kv_page, stride_kv_h, stride_kv_d, stride_bt_b, stride_bt_p, stride_ob, stride_oh, stride_od, num_heads, head_dim, page_size, group_size, scale, BLOCK_G: tl.constexpr, BLOCK_T: tl.constexpr, BLOCK_D: tl.constexpr, ): batch_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) seq_len = tl.load(SL_ptr + batch_idx) num_pages = (seq_len + page_size - 1) // page_size offs_g = tl.arange(0, BLOCK_G) offs_t = tl.arange(0, BLOCK_T) offs_d = tl.arange(0, BLOCK_D) head_ids = kv_head_idx * group_size + offs_g g_mask = head_ids < num_heads d_mask = offs_d < head_dim q_ptrs = ( Q_ptr + batch_idx * stride_qb + head_ids[:, None] * stride_qh + offs_d[None, :] * stride_qd ) q = tl.load(q_ptrs, mask=g_mask[:, None] & d_mask[None, :], other=0.0) m_i = tl.full([BLOCK_G], -float("inf"), tl.float32) l_i = tl.zeros([BLOCK_G], dtype=tl.float32) acc = tl.zeros([BLOCK_G, BLOCK_D], dtype=tl.float32) for page_idx in range(num_pages): block_id = tl.load(BT_ptr + batch_idx * stride_bt_b + page_idx * stride_bt_p).to(tl.int64) page_start = page_idx * page_size tokens_in_page = tl.minimum(page_size, seq_len - page_start) kv_h_base = block_id * stride_kv_block + kv_head_idx * stride_kv_h t_valid_row = offs_t[None, :] < tokens_in_page t_valid_2d = offs_t[:, None] < tokens_in_page k_ptrs = ( KV_ptr + kv_h_base + offs_t[:, None] * stride_kv_page + offs_d[None, :] * stride_kv_d ) v_ptrs = ( KV_ptr + kv_h_base + offs_t[:, None] * stride_kv_page + (head_dim + offs_d)[None, :] * stride_kv_d ) kv_mask = t_valid_2d & d_mask[None, :] k = tl.load(k_ptrs, mask=kv_mask, other=0.0) v = tl.load(v_ptrs, mask=kv_mask, other=0.0) qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale qk = tl.where(t_valid_row, qk, -float("inf")) m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) alpha = tl.exp(m_i - m_ij) p = tl.exp(qk - m_ij[:, None]) l_i = l_i * alpha + tl.sum(p, axis=1) acc = acc * alpha[:, None] + tl.dot(p.to(tl.bfloat16), v).to(tl.float32) m_i = m_ij out_vals = acc / l_i[:, None] o_ptrs = ( O_ptr + batch_idx * stride_ob + head_ids[:, None] * stride_oh + offs_d[None, :] * stride_od ) tl.store(o_ptrs, out_vals.to(tl.bfloat16), mask=g_mask[:, None] & d_mask[None, :]) def paged_attention_decode( query: torch.Tensor, kv_cache: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, scale: float, ) -> torch.Tensor: B, H, D = query.shape _, P, Hkv, D2 = kv_cache.shape assert D2 == 2 * D group_size = H // Hkv out = torch.empty_like(query) BLOCK_G = triton.next_power_of_2(group_size) BLOCK_T = P BLOCK_D = triton.next_power_of_2(D) if D <= 64: num_warps, num_stages = 4, 3 else: num_warps, num_stages = 8, 3 grid = (B, Hkv) _paged_decode_gqa_kernel[grid]( query, kv_cache, block_table, seq_lens, out, query.stride(0), query.stride(1), query.stride(2), kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3), block_table.stride(0), block_table.stride(1), out.stride(0), out.stride(1), out.stride(2), num_heads=H, head_dim=D, page_size=P, group_size=group_size, scale=scale, BLOCK_G=BLOCK_G, BLOCK_T=BLOCK_T, BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=num_stages, ) return out class Model(nn.Module): def __init__( self, batch: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_len: int, page_size: int, ): super().__init__() assert num_heads % num_kv_heads == 0 self.batch = batch self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.seq_len = seq_len self.page_size = page_size self.group_size = num_heads // num_kv_heads self.scale = 1.0 / math.sqrt(head_dim) self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False) def forward( self, query: torch.Tensor, kv_cache: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, ) -> torch.Tensor: return paged_attention_decode(query, kv_cache, block_table, seq_lens, self.scale) def get_inputs(): B = BATCH H = NUM_HEADS Hkv = NUM_KV_HEADS D = HEAD_DIM L = SEQ_LEN P = PAGE_SIZE pages_per_seq = (L + P - 1) // P total_pages = max(B * pages_per_seq + 8, 64) query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1 kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1 perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int() block_table = perm.contiguous() seq_lens = torch.full((B,), L, dtype=torch.int32) return [query, kv_cache, block_table, seq_lens] def get_init_inputs(): return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]