"""Paged attention decode kernel for RTX PRO 6000 (SM120 Blackwell). Triton-based decode kernel with online softmax and paged KV-cache gathering. Splits work across (batch, kv_head, seq_chunk) for SM occupancy, then reduces partial results with a second kernel. """ import math import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] BATCH = 8 NUM_HEADS = 32 NUM_KV_HEADS = 8 HEAD_DIM = 128 SEQ_LEN = 1024 PAGE_SIZE = 16 # --------------------------------------------------------------------------- # Pass 1: partial attention over a chunk of pages # --------------------------------------------------------------------------- @triton.jit def _partial_decode_kernel( query_ptr, kv_cache_ptr, block_table_ptr, seq_lens_ptr, partial_out_ptr, B, H, Hkv, D, max_blocks, P, pages_per_chunk, stride_q_b, stride_q_h, stride_kv_blk, stride_kv_pos, stride_kv_h, stride_bt_b, stride_po_c, stride_po_b, stride_po_h, BLOCK_D: tl.constexpr, group_size: tl.constexpr, ): """Compute partial attention for one chunk of pages. 
Grid: (num_chunks, B, Hkv) """ pid_c = tl.program_id(0) pid_b = tl.program_id(1) pid_kv = tl.program_id(2) seq_len = tl.load(seq_lens_ptr + pid_b) num_pages = (seq_len + P - 1) // P page_start = pid_c * pages_per_chunk page_end = tl.minimum(page_start + pages_per_chunk, num_pages) if page_start >= num_pages: return BLOCK_L: tl.constexpr = 16 offs_l = tl.arange(0, BLOCK_L) offs_d = tl.arange(0, BLOCK_D) offs_g = tl.arange(0, group_size) q_head_base = pid_kv * group_size # Load all Q heads in this group q_offs = ( pid_b * stride_q_b + (q_head_base + offs_g[:, None]) * stride_q_h + offs_d[None, :] ) mask_q = offs_d[None, :] < D q = tl.load(query_ptr + q_offs, mask=mask_q, other=0.0).to(tl.float32) m = tl.full([group_size], float("-inf"), dtype=tl.float32) l_sum = tl.zeros([group_size], dtype=tl.float32) acc = tl.zeros([group_size, BLOCK_D], dtype=tl.float32) scale = 1.0 / tl.sqrt(D.to(tl.float32)) for page_idx in range(page_start, page_end): blk_idx = tl.load(block_table_ptr + pid_b * stride_bt_b + page_idx) if page_idx == num_pages - 1: rem = seq_len % P tokens_this_page = tl.where(rem == 0, P, rem) else: tokens_this_page = P valid_l = offs_l < tokens_this_page k_base = blk_idx.to(tl.int64) * stride_kv_blk + pid_kv * stride_kv_h # K tile k_offs = k_base + offs_l[:, None] * stride_kv_pos + offs_d[None, :] mask_kv = valid_l[:, None] & (offs_d[None, :] < D) k_tile = tl.load(kv_cache_ptr + k_offs, mask=mask_kv, other=0.0).to(tl.float32) scores = tl.dot(q, tl.trans(k_tile)) * scale scores = tl.where(valid_l[None, :], scores, float("-inf")) m_new = tl.maximum(m, tl.max(scores, axis=1)) rescale = tl.exp(m - m_new) acc = acc * rescale[:, None] l_sum = l_sum * rescale p = tl.exp(scores - m_new[:, None]) p = tl.where(valid_l[None, :], p, 0.0) l_sum = l_sum + tl.sum(p, axis=1) # V tile v_offs = k_base + D + offs_l[:, None] * stride_kv_pos + offs_d[None, :] v_tile = tl.load(kv_cache_ptr + v_offs, mask=mask_kv, other=0.0).to(tl.float32) acc = acc + tl.dot(p, v_tile) m = m_new # 
Write partial results — vectorised across all Q heads in the group. base_off = pid_c * stride_po_c + pid_b * stride_po_b # m: (group_size,) -> store at [c, b, q_head, D] m_offs = base_off + (q_head_base + offs_g) * stride_po_h + D tl.store(partial_out_ptr + m_offs, m) # l_sum: (group_size,) -> store at [c, b, q_head, D+1] ls_offs = base_off + (q_head_base + offs_g) * stride_po_h + D + 1 tl.store(partial_out_ptr + ls_offs, l_sum) # acc: (group_size, BLOCK_D) -> store at [c, b, q_head, 0:D] acc_offs = ( base_off + (q_head_base + offs_g[:, None]) * stride_po_h + offs_d[None, :] ) mask_acc = offs_d[None, :] < D tl.store(partial_out_ptr + acc_offs, acc, mask=mask_acc) # --------------------------------------------------------------------------- # Pass 2: reduce partial results across chunks # --------------------------------------------------------------------------- @triton.jit def _reduce_partial_kernel( partial_in_ptr, output_ptr, num_chunks, B, H, D, stride_pi_c, stride_pi_b, stride_pi_h, stride_q_b, stride_q_h, BLOCK_D: tl.constexpr, ): """Reduce partial results across chunks into final output. Grid: (H, B) — one program per query head. """ pid_q = tl.program_id(0) pid_b = tl.program_id(1) offs_d = tl.arange(0, BLOCK_D) mask_d = offs_d < D m_global = float("-inf") l_global = 0.0 o_global = tl.zeros([BLOCK_D], dtype=tl.float32) for c in range(num_chunks): # Load m_c, l_c m_c = tl.load( partial_in_ptr + c * stride_pi_c + pid_b * stride_pi_b + pid_q * stride_pi_h + D ) l_c = tl.load( partial_in_ptr + c * stride_pi_c + pid_b * stride_pi_b + pid_q * stride_pi_h + D + 1 ) chunk_valid = l_c > 0.0 m_new = tl.maximum(m_global, m_c) # Only rescale if chunk is valid; otherwise keep current state. # tl.where selects element-wise — both branches are evaluated. 
rescale_old = tl.exp(m_global - m_new) rescale_c = tl.exp(m_c - m_new) o_global = tl.where(chunk_valid, o_global * rescale_old, o_global) l_global = tl.where(chunk_valid, l_global * rescale_old, l_global) # Load acc_c for this chunk acc_c = tl.load( partial_in_ptr + c * stride_pi_c + pid_b * stride_pi_b + pid_q * stride_pi_h + offs_d, mask=mask_d, other=0.0, ) # acc_c is already the exp-weighted sum: Σ exp(s-m_c)·V o_global = tl.where( chunk_valid, o_global + rescale_c * acc_c, o_global, ) l_global = tl.where( chunk_valid, l_global + rescale_c * l_c, l_global, ) m_global = tl.where(chunk_valid, m_new, m_global) l_safe = tl.where(l_global == 0.0, 1.0, l_global) out_vals = o_global / l_safe out_offs = pid_b * stride_q_b + pid_q * stride_q_h + offs_d tl.store(output_ptr + out_offs, out_vals.to(tl.bfloat16), mask=mask_d) # --------------------------------------------------------------------------- # Host-side dispatch # --------------------------------------------------------------------------- def _paged_attention_decode( query: torch.Tensor, kv_cache: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, ) -> torch.Tensor: B, H, D = query.shape Hkv = kv_cache.shape[2] P = kv_cache.shape[1] max_blocks = block_table.shape[1] group_size = H // Hkv max_seq = int(seq_lens.max().item()) max_pages = (max_seq + P - 1) // P # Target at least 256 blocks for good SM occupancy. 
base_blocks = B * Hkv target_blocks = 256 num_chunks = max(1, min(max_pages, target_blocks // base_blocks)) pages_per_chunk = (max_pages + num_chunks - 1) // num_chunks # Intermediate storage: (num_chunks, B, H, D + 2) fp32 partial = torch.zeros( num_chunks, B, H, D + 2, dtype=torch.float32, device=query.device, ) out = torch.empty(B, H, D, dtype=torch.bfloat16, device=query.device) # Pass 1: partial attention over chunks grid1 = (num_chunks, B, Hkv) _partial_decode_kernel[grid1]( query, kv_cache, block_table, seq_lens, partial, B, H, Hkv, D, max_blocks, P, pages_per_chunk, query.stride(0), query.stride(1), kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), block_table.stride(0), partial.stride(0), partial.stride(1), partial.stride(2), BLOCK_D=D, group_size=group_size, ) # Pass 2: reduce partial results across chunks grid2 = (H, B) _reduce_partial_kernel[grid2]( partial, out, num_chunks, B, H, D, partial.stride(0), partial.stride(1), partial.stride(2), out.stride(0), out.stride(1), BLOCK_D=D, ) return out # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- class Model(nn.Module): """Single-query paged attention decode — Triton two-pass kernel.""" def __init__( self, batch: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_len: int, page_size: int, ): super().__init__() assert num_heads % num_kv_heads == 0 self.batch = batch self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.seq_len = seq_len self.page_size = page_size self.scale = 1.0 / math.sqrt(head_dim) self.register_buffer( "_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False ) def forward( self, query: torch.Tensor, kv_cache: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor, ) -> torch.Tensor: return _paged_attention_decode( query.contiguous(), kv_cache.contiguous(), block_table.contiguous(), seq_lens.contiguous(), ) 
def get_inputs():
    """Build one random decode batch sized by the module-level constants.

    Returns ``[query, kv_cache, block_table, seq_lens]`` on the CPU:
      * query       -- ``(BATCH, NUM_HEADS, HEAD_DIM)`` bf16, scaled by 0.1
      * kv_cache    -- ``(pool, PAGE_SIZE, NUM_KV_HEADS, 2 * HEAD_DIM)`` bf16,
                       scaled by 0.1, where ``pool`` leaves at least 8 spare
                       pages and is never below 64
      * block_table -- ``(BATCH, pages_per_seq)`` int32 of distinct page ids
                       drawn from a random permutation of the pool
      * seq_lens    -- ``(BATCH,)`` int32, every entry equal to ``SEQ_LEN``
    """
    batch, heads, kv_heads = BATCH, NUM_HEADS, NUM_KV_HEADS
    dim, length, page = HEAD_DIM, SEQ_LEN, PAGE_SIZE

    pages_per_seq = -(-length // page)  # ceiling division
    pool = max(batch * pages_per_seq + 8, 64)

    # RNG draw order (query, cache, permutation) is kept fixed so seeded runs
    # reproduce the same inputs.
    query = torch.randn(batch, heads, dim, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(pool, page, kv_heads, 2 * dim, dtype=torch.bfloat16) * 0.1

    shuffled_pages = torch.randperm(pool)
    chosen = shuffled_pages[: batch * pages_per_seq]
    block_table = chosen.view(batch, pages_per_seq).to(torch.int32).contiguous()

    seq_lens = torch.full((batch,), length, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    init_args = (BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE)
    return list(init_args)