"""Paged attention decode via Triton.

Strategy: Flash-Decoding (split-K) with online softmax. Each program of the
first kernel processes one (batch, kv-head-block, sequence-chunk) triple and
writes a normalized partial output plus its log-sum-exp (LSE). A second kernel
combines the partials into the final output.

GQA grouping: each program handles BLOCK_KV consecutive KV heads. The Q tile
has BLOCK_M = BLOCK_KV * G query heads (all G query heads of each KV head).
K and V are loaded directly in (BLOCK_KV, BLOCK_N, BLOCK_D) layout and the
matmuls are 3D-batched so each (kv-head, group) pair sees the right KV head.

Per-shape tuning: shapes with many KV heads and a short sequence use a smaller
BLOCK_KV so that more programs are launched.
NOTE(review): the original notes claim the matmul still reaches tensor cores
because BLOCK_KV * G >= 16; on the small-shape path the product is 8 for
G in {4, 8}, so confirm this on the target hardware.

Constraints implied by ``tl.arange``: the group size G, ``head_dim`` and
``page_size`` must be powers of two.
"""
from __future__ import annotations

import math

import torch
import triton
import triton.language as tl

# Module-level knobs (overridden by check.py / benchmark.py from shapes.py).
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16


# ---------------------------------------------------------------------------
# Kernels
# ---------------------------------------------------------------------------
@triton.jit
def _paged_attn_splitk_kernel(
    Q_ptr,
    KV_ptr,
    BT_ptr,
    SL_ptr,
    O_partial_ptr,
    LSE_partial_ptr,
    sm_scale,
    # Q strides
    stride_qb,
    stride_qh,
    # KV strides
    stride_kvb,
    stride_kvp,
    stride_kvh,
    # Block-table strides
    stride_btb,
    stride_btblock,
    # O_partial strides
    stride_op_b,
    stride_op_h,
    stride_op_s,
    stride_op_d,
    # LSE_partial strides
    stride_lse_b,
    stride_lse_h,
    stride_lse_s,
    BLOCK_KV: tl.constexpr,  # number of kv heads per program
    G: tl.constexpr,  # group size (query heads per kv head)
    BLOCK_M: tl.constexpr,  # = BLOCK_KV * G
    BLOCK_N: tl.constexpr,  # tokens per loop step (one page)
    BLOCK_D: tl.constexpr,
    D: tl.constexpr,  # offset of V inside the fused K|V last dimension
    P: tl.constexpr,  # page size in tokens
    SPLIT_K: tl.constexpr,
    CHUNK: tl.constexpr,  # tokens per sequence split
):
    """Stage 1: attention over one sequence chunk for a block of KV heads.

    Grid: (batch, kv-head block, split). Writes a softmax-normalized partial
    output (bf16) and the chunk's log-sum-exp (fp32) for every query head in
    the block. The last dimension of Q and of the KV cache is addressed with
    an implicit stride of 1.
    """
    bid = tl.program_id(0)
    hkv_blk = tl.program_id(1)
    split = tl.program_id(2)

    # Token range [start, end) owned by this split, clipped to the sequence.
    seq_len = tl.load(SL_ptr + bid)
    start = split * CHUNK
    end = tl.minimum(start + CHUNK, seq_len)

    hkv_start = hkv_blk * BLOCK_KV
    h_start = hkv_start * G

    offs_m = h_start + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    offs_d_v = D + tl.arange(0, BLOCK_D)  # V follows K in the last dimension
    offs_kv = hkv_start + tl.arange(0, BLOCK_KV)

    # Load Q: (BLOCK_M, BLOCK_D) -- one row per query head.
    q_ptrs = Q_ptr + bid * stride_qb + offs_m[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(q_ptrs)
    # Reshape to (BLOCK_KV, G, BLOCK_D) for the batched matmul.
    q3 = tl.reshape(q, (BLOCK_KV, G, BLOCK_D))

    # Online softmax accumulators.
    m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)

    if start < end:
        first_page = start // P
        last_page = (end - 1) // P
        for page in range(first_page, last_page + 1):
            # Widen to int64 before multiplying by the page stride: an int32
            # product wraps once the cache holds more than 2**31 elements.
            block_id = tl.load(
                BT_ptr + bid * stride_btb + page * stride_btblock
            ).to(tl.int64)

            tok_idx = page * P + offs_n
            valid = (tok_idx >= start) & (tok_idx < end)

            # Address of element 0 of every (kv head, token) row in this page,
            # shaped (BLOCK_KV, BLOCK_N, 1) so no permute is needed on load.
            row_base = (
                KV_ptr
                + block_id * stride_kvb
                + offs_kv[:, None, None] * stride_kvh
                + offs_n[None, :, None] * stride_kvp
            )
            k3 = tl.load(
                row_base + offs_d[None, None, :],
                mask=valid[None, :, None],
                other=0.0,
            )  # (BLOCK_KV, BLOCK_N, BLOCK_D)
            v3 = tl.load(
                row_base + offs_d_v[None, None, :],
                mask=valid[None, :, None],
                other=0.0,
            )  # (BLOCK_KV, BLOCK_N, BLOCK_D)

            # QK^T batched over kv heads: (BLOCK_KV, G, BLOCK_N).
            s3 = tl.dot(q3, tl.permute(k3, (0, 2, 1)))
            s3 = s3 * sm_scale
            s3 = tl.where(valid[None, None, :], s3, -float("inf"))

            # Flatten to (BLOCK_M, BLOCK_N) for the softmax accumulators.
            s = tl.reshape(s3, (BLOCK_M, BLOCK_N))
            m_new = tl.maximum(m_i, tl.max(s, axis=1))
            alpha = tl.exp(m_i - m_new)  # rescale factor for the old state
            probs = tl.exp(s - m_new[:, None])
            l_i = l_i * alpha + tl.sum(probs, axis=1)

            # Back to (BLOCK_KV, G, BLOCK_N) for the second batched matmul.
            p3 = tl.reshape(probs, (BLOCK_KV, G, BLOCK_N))
            o3 = tl.dot(p3.to(tl.bfloat16), v3)  # (BLOCK_KV, G, BLOCK_D)
            o2 = tl.reshape(o3, (BLOCK_M, BLOCK_D))
            acc = acc * alpha[:, None] + o2
            m_i = m_new

        # Guard the division and the log against an all-zero denominator.
        l_i_safe = tl.where(l_i > 0, l_i, 1.0)
        o_partial = (acc / l_i_safe[:, None]).to(tl.bfloat16)
        lse = m_i + tl.log(l_i_safe)
    else:
        # Empty split: the very negative LSE gives it zero weight in stage 2
        # whenever some other split holds real data.
        o_partial = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.bfloat16)
        lse = tl.full([BLOCK_M], -1.0e30, dtype=tl.float32)

    o_ptrs = (
        O_partial_ptr
        + bid * stride_op_b
        + offs_m[:, None] * stride_op_h
        + split * stride_op_s
        + offs_d[None, :] * stride_op_d
    )
    tl.store(o_ptrs, o_partial)

    lse_ptrs = (
        LSE_partial_ptr
        + bid * stride_lse_b
        + offs_m * stride_lse_h
        + split * stride_lse_s
    )
    tl.store(lse_ptrs, lse)


@triton.jit
def _reduce_kernel(
    O_partial_ptr,
    LSE_partial_ptr,
    O_ptr,
    stride_op_b,
    stride_op_h,
    stride_op_s,
    stride_op_d,
    stride_lse_b,
    stride_lse_h,
    stride_lse_s,
    stride_ob,
    stride_oh,
    stride_od,
    H: tl.constexpr,
    BLOCK_D: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Stage 2: merge the SPLIT_K partial outputs of one (batch, head) pair.

    Each partial is weighted by exp(lse - max(lse)) and the weighted sum is
    divided by the sum of the weights. SPLIT_K must be a power of two because
    it is used as a ``tl.arange`` extent.
    """
    bid = tl.program_id(0)
    h = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_D)
    offs_s = tl.arange(0, SPLIT_K)

    lse_ptrs = (
        LSE_partial_ptr
        + bid * stride_lse_b
        + h * stride_lse_h
        + offs_s * stride_lse_s
    )
    lse = tl.load(lse_ptrs)  # (SPLIT_K,)

    # Subtract the maximum before exponentiating for numerical stability.
    m_max = tl.max(lse)
    w = tl.exp(lse - m_max)
    w_sum = tl.sum(w)

    o_ptrs = (
        O_partial_ptr
        + bid * stride_op_b
        + h * stride_op_h
        + offs_s[:, None] * stride_op_s
        + offs_d[None, :] * stride_op_d
    )
    o = tl.load(o_ptrs)  # (SPLIT_K, BLOCK_D)
    o_total = tl.sum(o.to(tl.float32) * w[:, None], axis=0) / w_sum

    out_ptrs = O_ptr + bid * stride_ob + h * stride_oh + offs_d * stride_od
    tl.store(out_ptrs, o_total.to(tl.bfloat16))


# ---------------------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------------------
def _next_pow2(x: int) -> int:
    """Return the smallest power of two >= x (1 for any x <= 0)."""
    return (1 << (x - 1).bit_length()) if x > 0 else 1


def _pick_block_kv(num_kv_heads: int, target: int) -> int:
    """Return the largest power of two <= target that divides num_kv_heads."""
    block = 1 << (max(1, target).bit_length() - 1)
    while num_kv_heads % block:
        block //= 2
    return block


class Model(torch.nn.Module):
    """Triton paged-attention decode (Flash-Decoding)."""

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        """Store the problem shape and derive the kernel tile sizes.

        Args:
            batch: Batch size (stored, not used for tiling).
            num_heads: Number of query heads; must be a multiple of
                ``num_kv_heads``.
            num_kv_heads: Number of KV heads.
            head_dim: Per-head feature size; used as BLOCK_D.
            seq_len: Sequence length used to size the number of splits.
            page_size: Tokens per KV-cache page; used as BLOCK_N.
        """
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.scale = 1.0 / math.sqrt(head_dim)

        # Pick BLOCK_KV to balance parallelism vs per-program work.
        # Default target is BLOCK_KV * G == 16 (G=4 -> 4, G=8 -> 2).
        G = self.group_size
        HKV = self.num_kv_heads
        if HKV >= 8 and self.seq_len <= 2048:
            # Many kv heads but a short sequence: halve the tile to launch
            # more programs (2 for G=4, 1 for G=8).
            target = max(1, 16 // G // 2)
        else:
            target = max(1, 16 // G)
        target = min(target, HKV)
        # The kernel needs a power-of-two BLOCK_KV (tl.arange) that divides
        # the kv-head count (grid tiling); shrink the target until both hold.
        BLOCK_KV = _pick_block_kv(HKV, target)

        self.BLOCK_KV = BLOCK_KV
        self.BLOCK_M = BLOCK_KV * G
        self.BLOCK_D = head_dim
        self.BLOCK_N = page_size
        self.CHUNK = 128

        self.register_buffer(
            "_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False
        )

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Run the two-stage split-K decode and return the attention output.

        Args:
            query: (B, H, D) tensor of query vectors.
            kv_cache: Paged cache indexed as (page, slot, kv_head, 2 * D),
                with K in the first D features and V in the last D.
            block_table: Per-sequence page ids, indexed (batch, logical page).
            seq_lens: Per-sequence token counts, one per batch row.

        Returns:
            Tensor of shape (B, H, D) with ``query``'s dtype and device.

        Note:
            The number of splits is derived from ``self.seq_len``; tokens
            beyond SPLIT_K * CHUNK are not visited by the kernel.
        """
        B, H, D = query.shape
        Hkv = self.num_kv_heads
        G = self.group_size
        P = self.page_size
        BLOCK_M = self.BLOCK_M
        BLOCK_KV = self.BLOCK_KV
        BLOCK_D = self.BLOCK_D
        BLOCK_N = self.BLOCK_N
        CHUNK = self.CHUNK

        assert Hkv * G == H
        assert Hkv % BLOCK_KV == 0
        assert D == self.head_dim

        # The kernel addresses the feature dimension with an implicit stride
        # of 1, so make sure that really holds for the inputs.
        if query.stride(-1) != 1:
            query = query.contiguous()
        if kv_cache.stride(-1) != 1:
            kv_cache = kv_cache.contiguous()

        HKV_BLKS = Hkv // BLOCK_KV
        # Power of two because the reduce kernel uses it as a tl.arange extent.
        SPLIT_K = max(1, _next_pow2((self.seq_len + CHUNK - 1) // CHUNK))

        o_partial = torch.empty(
            B, H, SPLIT_K, D, dtype=torch.bfloat16, device=query.device
        )
        lse_partial = torch.empty(
            B, H, SPLIT_K, dtype=torch.float32, device=query.device
        )
        out = torch.empty(B, H, D, dtype=query.dtype, device=query.device)

        # Stage 1: split-K attention
        grid = (B, HKV_BLKS, SPLIT_K)
        _paged_attn_splitk_kernel[grid](
            query,
            kv_cache,
            block_table,
            seq_lens,
            o_partial,
            lse_partial,
            self.scale,
            query.stride(0),
            query.stride(1),
            kv_cache.stride(0),
            kv_cache.stride(1),
            kv_cache.stride(2),
            block_table.stride(0),
            block_table.stride(1),
            o_partial.stride(0),
            o_partial.stride(1),
            o_partial.stride(2),
            o_partial.stride(3),
            lse_partial.stride(0),
            lse_partial.stride(1),
            lse_partial.stride(2),
            BLOCK_KV=BLOCK_KV,
            G=G,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_D=BLOCK_D,
            D=D,
            P=P,
            SPLIT_K=SPLIT_K,
            CHUNK=CHUNK,
            num_warps=2,
            num_stages=3,
        )

        # Stage 2: reduce partial outputs
        grid_red = (B, H)
        _reduce_kernel[grid_red](
            o_partial,
            lse_partial,
            out,
            o_partial.stride(0),
            o_partial.stride(1),
            o_partial.stride(2),
            o_partial.stride(3),
            lse_partial.stride(0),
            lse_partial.stride(1),
            lse_partial.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            H=H,
            BLOCK_D=BLOCK_D,
            SPLIT_K=SPLIT_K,
            num_warps=1,
            num_stages=1,
        )
        return out


def get_inputs():
    """Build random CPU inputs for ``Model.forward`` from the module knobs.

    Returns:
        ``[query, kv_cache, block_table, seq_lens]`` where every sequence has
        length SEQ_LEN and its pages are a random, non-repeating subset of the
        cache pages.
    """
    B = BATCH
    H = NUM_HEADS
    Hkv = NUM_KV_HEADS
    D = HEAD_DIM
    L = SEQ_LEN
    P = PAGE_SIZE

    pages_per_seq = (L + P - 1) // P
    # A few spare pages so the block table is not a full permutation.
    total_pages = max(B * pages_per_seq + 8, 64)

    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1

    perm = (
        torch.randperm(total_pages)[: B * pages_per_seq]
        .reshape(B, pages_per_seq)
        .int()
    )
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)

    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]