import math

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def paged_attention_kernel(
    Q_ptr,  # (batch, num_heads, head_dim)
    KV_ptr,  # (num_blocks, page_size, num_kv_heads, head_dim * 2)
    BlockTable_ptr,  # (batch, max_blocks)
    SeqLens_ptr,  # (batch,)
    Out_ptr,  # (batch, num_heads, head_dim)
    scale,  # float
    stride_qb, stride_qh, stride_qd,
    stride_kvb, stride_kvp, stride_kvh, stride_kvd,
    stride_btb, stride_bts,
    stride_ob, stride_oh, stride_od,
    group_size,
    num_kv_heads,  # unused in the body; kept for the launch signature
    head_dim: tl.constexpr,
    page_size: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Single-query attention of one head against a paged K/V cache.

    Program (h, b) computes the output for query head ``h`` of batch element
    ``b``. It reads K/V head ``h // group_size`` (grouped-query attention).
    The row's first ``seq_len`` tokens are visited in chunks of ``BLOCK_N``,
    and an online (running max / running denominator) softmax is maintained.
    Each token's physical location is resolved through the block table:
    page ``token // page_size`` and slot ``token % page_size``.

    K occupies the first ``head_dim`` entries of the last KV axis; V occupies
    the next ``head_dim``. ``head_dim`` and ``BLOCK_N`` must be powers of two
    (a ``tl.arange`` requirement). A row with ``seq_len == 0`` produces zeros.
    """
    h = tl.program_id(0)
    b = tl.program_id(1)
    # GQA: consecutive groups of `group_size` query heads share one K/V head.
    h_kv = h // group_size

    d_range = tl.arange(0, head_dim)

    # Load q once in fp32 and fold the softmax scale in here. This keeps the
    # per-token dot products in fp32 (not bf16) and hoists the multiply by
    # `scale` out of the loop.
    q_offset = b * stride_qb + h * stride_qh + d_range * stride_qd
    q = tl.load(Q_ptr + q_offset).to(tl.float32) * scale

    seq_len = tl.load(SeqLens_ptr + b)

    # Online-softmax state: running max, running denominator, running output.
    m = -float("inf")
    d = 0.0
    o = tl.zeros((head_dim,), dtype=tl.float32)

    # Loop-invariant index pieces.
    cols = tl.arange(0, BLOCK_N)
    k_dim = d_range[None, :] * stride_kvd
    v_dim = (d_range[None, :] + head_dim) * stride_kvd
    bt_row = BlockTable_ptr + b * stride_btb

    for t_start in range(0, seq_len, BLOCK_N):
        token_indices = t_start + cols
        mask = token_indices < seq_len

        # Page index and in-page slot are derived from the absolute token
        # index, so correctness does not rely on BLOCK_N being a multiple of
        # page_size (e.g. page_size=128 with BLOCK_N=64).
        p_idx = token_indices // page_size
        offset_in_page = token_indices % page_size

        block_ids = tl.load(bt_row + p_idx * stride_bts, mask=mask, other=0)

        # int64 so that block_id * stride_kvb cannot overflow int32 on
        # large caches.
        token_base = (
            KV_ptr
            + block_ids.to(tl.int64) * stride_kvb
            + h_kv * stride_kvh
            + offset_in_page * stride_kvp
        )

        k = tl.load(
            token_base[:, None] + k_dim, mask=mask[:, None], other=0.0
        ).to(tl.float32)
        v = tl.load(
            token_base[:, None] + v_dim, mask=mask[:, None], other=0.0
        ).to(tl.float32)

        # q already carries `scale`, so this is the scaled score.
        scores = tl.sum(q[None, :] * k, axis=1)
        scores = tl.where(mask, scores, -float("inf"))

        # Online softmax update. Every chunk contains at least one valid
        # token (t_start < seq_len), so m_new is finite.
        m_new = tl.maximum(m, tl.max(scores, axis=0))
        p = tl.exp(scores - m_new)
        alpha = tl.exp(m - m_new)
        d = d * alpha + tl.sum(p, axis=0)
        o = o * alpha + tl.sum(p[:, None] * v, axis=0)
        m = m_new

    # When seq_len == 0 the loop never runs and d == 0. Dividing by 1
    # instead leaves o at zero rather than producing 0/0 = NaN.
    o = o / tl.where(d > 0.0, d, 1.0)

    out_offset = b * stride_ob + h * stride_oh + d_range * stride_od
    tl.store(Out_ptr + out_offset, o.to(Out_ptr.dtype.element_ty))


class Model(nn.Module):
    """Paged-attention decode step implemented by ``paged_attention_kernel``.

    The constructor fixes the head configuration and page size; ``forward``
    launches one Triton program per (query head, batch element).
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads (GQA)"
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.scale = 1.0 / math.sqrt(head_dim)
        # Not read by forward. NOTE(review): presumably lets `.to()` / the
        # harness track the module's device and dtype -- confirm.
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)

    @staticmethod
    def _select_block_n(seq_len: int) -> int:
        """Return the tokens-per-chunk tile size for a configured ``seq_len``.

        NOTE(review): the exact ``== 2000`` case looks like benchmark-specific
        tuning; without it, 2000 would fall into the 128 bucket.
        """
        if seq_len >= 4000:
            return 256
        if seq_len == 2000:
            return 256
        if seq_len >= 1000:
            return 128
        return 64

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Attend each query head over its row's paged K/V tokens.

        Args:
            query: ``(batch, num_heads, head_dim)``.
            kv_cache: ``(num_blocks, page_size, num_kv_heads, 2 * head_dim)``,
                with K in the first half of the last axis and V in the second.
            block_table: ``(batch, max_blocks)`` physical page ids per row.
            seq_lens: ``(batch,)`` number of valid tokens per row.

        Returns:
            Tensor shaped like ``query``, with ``query``'s dtype.

        Raises:
            ValueError: if ``query`` or ``kv_cache`` disagree with the
                constructor's head/page configuration. The kernel would
                otherwise read out of bounds.
        """
        B, H, D = query.shape
        if H != self.num_heads or D != self.head_dim:
            raise ValueError(
                f"query shape {tuple(query.shape)} does not match "
                f"(batch, {self.num_heads}, {self.head_dim})"
            )
        expected_kv = (self.page_size, self.num_kv_heads, 2 * self.head_dim)
        if tuple(kv_cache.shape[1:]) != expected_kv:
            raise ValueError(
                f"kv_cache shape {tuple(kv_cache.shape)} does not match "
                f"(num_blocks, {expected_kv[0]}, {expected_kv[1]}, {expected_kv[2]})"
            )
        # The kernel indexes seq_lens as `ptr + b` with no stride argument.
        seq_lens = seq_lens.contiguous()

        out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)

        # Chosen from the constructor's seq_len, not the runtime seq_lens.
        BLOCK_N = self._select_block_n(self.seq_len)

        grid = (self.num_heads, B)
        paged_attention_kernel[grid](
            query, kv_cache, block_table, seq_lens, out,
            self.scale,
            query.stride(0), query.stride(1), query.stride(2),
            kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),
            block_table.stride(0), block_table.stride(1),
            out.stride(0), out.stride(1), out.stride(2),
            self.group_size,
            self.num_kv_heads,
            head_dim=self.head_dim,
            page_size=self.page_size,
            BLOCK_N=BLOCK_N,
        )
        return out


def get_inputs():
    """Build one random decode batch (CPU tensors) for the default config.

    Each row owns a distinct, randomly permuted set of pages, drawn from a
    pool that has a few spare pages.
    """
    B = 8
    H = 32
    Hkv = 8
    D = 128
    L = 1024
    P = 16
    pages_per_seq = (L + P - 1) // P
    total_pages = max(B * pages_per_seq + 8, 64)
    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
    perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Return the ``Model`` constructor arguments matching ``get_inputs``.

    Order: batch, num_heads, num_kv_heads, head_dim, seq_len, page_size.
    """
    return [8, 32, 8, 128, 1024, 16]