"""Paged-attention decode kernel (Triton flash-decoding) for SM120 Blackwell.

Single-query decode. Memory-bound: the KV cache must be streamed exactly once,
reused across each GQA group. Small-batch shapes are parallelized with split-K
(flash-decoding) to fill the 188 SMs, then a cheap combine kernel reduces the
per-split partial softmaxes. When a shape needs no split (S==1) the phase-1
kernel writes the normalized bf16 output directly and the combine is skipped.
"""

import math

import torch
import torch.nn as nn
import triton
import triton.language as tl


# Phase 1 of split-K paged decode.
#
# Each program handles one (batch b, KV head kvh, split) triple, with split
# varying fastest in program_id. It runs an online softmax over the split's
# contiguous page range for the G query heads that share KV head kvh. Those
# heads are padded to GP rows; rows >= G are masked on load and store.
#
# Output modes:
#   NORMALIZE=True  : writes acc / l to o_ptr, addressed as (B, H, D), in bf16.
#   NORMALIZE=False : writes the unnormalized acc to o_ptr, addressed as
#                     (B, H, S, D). Also writes the running max m and running
#                     sum l to m_ptr / l_ptr, addressed as (B, H, S).
#
# Addressing used below:
#   - query is read as a contiguous (B, H, D) array.
#   - Each KV cache row is [page][slot][kv_head][K(D) | V(D)].
#   - block_table has max_blocks entries per batch row.
#
# D, GP and BLOCK_N feed tl.arange, so they must be powers of two.
@triton.jit
def _paged_decode_phase1(
    q_ptr, kv_ptr, bt_ptr, sl_ptr,
    o_ptr,         # partial out (B,H,S,D) OR final (B,H,D) bf16 if NORMALIZE
    m_ptr, l_ptr,  # (B,H,S) fp32 (unused when NORMALIZE)
    scale,
    H: tl.constexpr, Hkv: tl.constexpr, G: tl.constexpr, D: tl.constexpr,
    P: tl.constexpr, max_blocks: tl.constexpr, S: tl.constexpr,
    BLOCK_N: tl.constexpr, GP: tl.constexpr, PAGES_PER_BLOCK: tl.constexpr,
    NORMALIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    split = pid % S
    bkv = pid // S
    b = bkv // Hkv
    kvh = bkv % Hkv

    # Contiguous page range [page_start, page_end) owned by this split.
    # Trailing splits are empty when num_pages is not large enough.
    L = tl.load(sl_ptr + b)
    num_pages = (L + P - 1) // P
    pages_per_split = (num_pages + S - 1) // S
    page_start = split * pages_per_split
    page_end = tl.minimum(page_start + pages_per_split, num_pages)

    # Query rows for the GQA group of kvh, padded from G to GP rows.
    grow = tl.arange(0, GP)
    dcol = tl.arange(0, D)
    q_head = kvh * G + grow
    q_off = b * (H * D) + q_head[:, None] * D + dcol[None, :]
    q_mask = grow[:, None] < G
    q = tl.load(q_ptr + q_off, mask=q_mask, other=0.0)

    # Online-softmax state: running row max, running row sum, unnormalized output.
    m_i = tl.full((GP,), -float("inf"), dtype=tl.float32)
    l_i = tl.zeros((GP,), dtype=tl.float32)
    acc = tl.zeros((GP, D), dtype=tl.float32)

    # One BLOCK_N tile spans PAGES_PER_BLOCK pages.
    # Tile row r maps to (page offset r // P, slot r % P).
    row = tl.arange(0, BLOCK_N)
    local_page = row // P
    within = row % P
    # When P does not divide BLOCK_N, the rows past the last whole page would
    # read the next tile's page. That page is visited again by the next
    # iteration, so those rows would be counted twice. Mask them out.
    row_in_tile = local_page < PAGES_PER_BLOCK
    kv2d = 2 * D
    page_stride = P * Hkv * kv2d
    for pblk in range(page_start, page_end, PAGES_PER_BLOCK):
        gpage = pblk + local_page
        valid_page = row_in_tile & (gpage < page_end)
        page_id = tl.load(bt_ptr + b * max_blocks + gpage, mask=valid_page, other=0)
        token = gpage * P + within
        valid = valid_page & (token < L)
        # Promote to 64 bits: page_id * page_stride overflows int32 once the
        # cache holds more than 2**31 elements.
        base = (
            page_id.to(tl.int64) * page_stride
            + within * (Hkv * kv2d)
            + kvh * kv2d
        )
        k = tl.load(kv_ptr + base[:, None] + dcol[None, :], mask=valid[:, None], other=0.0)
        v = tl.load(kv_ptr + base[:, None] + (D + dcol[None, :]), mask=valid[:, None], other=0.0)
        s = tl.dot(q, tl.trans(k)) * scale
        s = tl.where(valid[None, :], s, -float("inf"))
        m_new = tl.maximum(m_i, tl.max(s, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new[:, None])
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(k.dtype), v)
        m_i = m_new

    out_head = kvh * G + grow
    valid_g = grow < G
    if NORMALIZE:
        o = acc / l_i[:, None]
        o_base = b * (H * D) + out_head[:, None] * D + dcol[None, :]
        tl.store(o_ptr + o_base, o.to(tl.bfloat16), mask=valid_g[:, None])
    else:
        o_base = (
            b * (H * S * D)
            + out_head[:, None] * (S * D)
            + split * D
            + dcol[None, :]
        )
        # tl.store casts to o_ptr's element type, so the partial keeps the
        # precision of the buffer the host allocated.
        tl.store(o_ptr + o_base, acc, mask=valid_g[:, None])
        ml_off = b * (H * S) + out_head * S + split
        tl.store(m_ptr + ml_off, m_i, mask=valid_g)
        tl.store(l_ptr + ml_off, l_i, mask=valid_g)


# Combine step for split-K decode.
#
# One program per (batch, query head). It merges the S partial (acc, m, l)
# triples written by phase 1 into the final bf16 output, addressed as
# (B, H, D). S and D must be powers of two (tl.arange).
#
# NOTE(review): if every split is empty (seq_len 0), m_g is -inf and the
# result is NaN.
@triton.jit
def _paged_decode_combine(
    o_ptr, m_ptr, l_ptr, out_ptr,
    H: tl.constexpr, D: tl.constexpr, S: tl.constexpr,
):
    pid = tl.program_id(0)
    b = pid // H
    h = pid % H
    bh = b * H + h
    sidx = tl.arange(0, S)
    dcol = tl.arange(0, D)
    m = tl.load(m_ptr + bh * S + sidx)
    l = tl.load(l_ptr + bh * S + sidx)
    # Rescale every split to the global max. Empty splits (m = -inf) get weight 0.
    m_g = tl.max(m, axis=0)
    sc = tl.exp(m - m_g)
    l_g = tl.sum(l * sc, axis=0)
    o = tl.load(o_ptr + bh * (S * D) + sidx[:, None] * D + dcol[None, :]).to(tl.float32)
    out = tl.sum(o * sc[:, None], axis=0) / l_g
    tl.store(out_ptr + bh * D + dcol, out.to(tl.bfloat16))


# Per-shape tuned configs keyed by (B, H, Hkv, D, seq_len, P)
# -> (num_splits, BLOCK_N, num_warps, num_stages).
_CONFIGS = {
    (8, 32, 8, 128, 1024, 16): (4, 64, 4, 2),
    (32, 32, 8, 128, 2048, 16): (1, 64, 4, 2),
    (4, 64, 8, 128, 4096, 16): (16, 32, 4, 2),
    (16, 32, 8, 128, 1535, 16): (1, 128, 4, 2),
    (8, 16, 4, 64, 2000, 16): (16, 64, 2, 2),
}

# Fallback tuning for shapes that are not in _CONFIGS.
_TARGET_BLOCKS = 512  # desired number of (b, kv_head, split) phase-1 programs
_BLOCK_N = 64
_NUM_WARPS = 4
_NUM_STAGES = 2


def _next_pow2(n):
    """Return the smallest power of two >= n (1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def _heuristic_num_splits(groups, num_pages):
    """Pick a power-of-two split count.

    The count aims for about _TARGET_BLOCKS programs in total. It is clamped
    to num_pages before rounding up to a power of two.
    """
    s = max(1, _TARGET_BLOCKS // groups)
    s = min(s, num_pages)
    if s <= 1:
        return 1
    return _next_pow2(s)


class Model(nn.Module):
    """Single-query paged-attention decode backed by the Triton kernels above.

    All shapes are fixed at construction. ``forward`` returns an internal
    (batch, num_heads, head_dim) bf16 buffer that the next call overwrites;
    callers that need to keep a result must copy it.
    """

    def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len, page_size):
        super().__init__()
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.scale = 1.0 / math.sqrt(head_dim)
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)

        key = (batch, num_heads, num_kv_heads, head_dim, seq_len, page_size)
        num_pages = (seq_len + page_size - 1) // page_size
        if key in _CONFIGS:
            splits, self._bn, self._warps, self._stages = _CONFIGS[key]
            # The combine kernel uses tl.arange(0, S), so S must stay a power
            # of two after clamping to the page count.
            self._S = _next_pow2(min(splits, num_pages))
        else:
            self._S = _heuristic_num_splits(batch * num_kv_heads, num_pages)
            self._bn, self._warps, self._stages = _BLOCK_N, _NUM_WARPS, _NUM_STAGES
        # A tile must cover at least one whole page. Otherwise
        # PAGES_PER_BLOCK = BLOCK_N // P would be 0 and the page loop's step
        # would be 0.
        self._bn = max(self._bn, _next_pow2(page_size))
        self._GP = max(4, _next_pow2(self.group_size))
        self._bufs = None    # persistent (out, o_partial, m, l)
        self._graphs = {}    # sig -> CUDAGraph, or None if capture failed

    def _alloc(self, device):
        """Allocate the output buffer and, when S > 1, the per-split partials."""
        B, H, D = self.batch, self.num_heads, self.head_dim
        S = self._S
        out = torch.empty(B, H, D, dtype=torch.bfloat16, device=device)
        if S == 1:
            self._bufs = (out, None, None, None)
        else:
            # fp32 partials: phase 1 stores the unnormalized accumulator, which
            # is rescaled and summed in the combine step. Rounding it to bf16
            # first would add a second rounding error.
            o_p = torch.empty(B, H, S, D, dtype=torch.float32, device=device)
            m_p = torch.empty(B, H, S, dtype=torch.float32, device=device)
            l_p = torch.empty(B, H, S, dtype=torch.float32, device=device)
            self._bufs = (out, o_p, m_p, l_p)

    def _launch(self, query, kv_cache, block_table, seq_lens):
        """Launch phase 1 (plus the combine when S > 1) and return ``out``."""
        B, H, D = self.batch, self.num_heads, self.head_dim
        Hkv, G, P = self.num_kv_heads, self.group_size, self.page_size
        max_blocks = block_table.shape[1]
        S, GP, BLOCK_N = self._S, self._GP, self._bn
        warps, stages = self._warps, self._stages
        PAGES_PER_BLOCK = BLOCK_N // P
        out, o_p, m_p, l_p = self._bufs
        if S == 1:
            # No split: phase 1 normalizes and writes ``out`` directly. The
            # m/l pointer slots are unused, so ``out`` fills them.
            _paged_decode_phase1[(B * Hkv,)](
                query, kv_cache, block_table, seq_lens,
                out, out, out,
                self.scale,
                H, Hkv, G, D, P, max_blocks, 1, BLOCK_N, GP, PAGES_PER_BLOCK, True,
                num_warps=warps, num_stages=stages,
            )
        else:
            _paged_decode_phase1[(B * Hkv * S,)](
                query, kv_cache, block_table, seq_lens,
                o_p, m_p, l_p,
                self.scale,
                H, Hkv, G, D, P, max_blocks, S, BLOCK_N, GP, PAGES_PER_BLOCK, False,
                num_warps=warps, num_stages=stages,
            )
            _paged_decode_combine[(B * H,)](
                o_p, m_p, l_p, out, H, D, S,
                num_warps=4,
            )
        return out

    def forward(self, query, kv_cache, block_table, seq_lens):
        """Run decode, replaying a cached CUDA graph when one exists for these inputs.

        Graphs are keyed on the input data pointers and on block_table's width
        (max_blocks is a compile-time constant of the captured kernel). Up to 8
        capture attempts are made. A failed capture is remembered, so those
        inputs run eagerly without retrying.
        """
        if self._bufs is None:
            self._alloc(query.device)
        sig = (
            query.data_ptr(),
            kv_cache.data_ptr(),
            block_table.data_ptr(),
            seq_lens.data_ptr(),
            block_table.shape[1],
        )
        if sig in self._graphs:
            graph = self._graphs[sig]
            if graph is None:
                return self._launch(query, kv_cache, block_table, seq_lens)
            graph.replay()
            return self._bufs[0]

        # Eager run (also serves as compile/warmup), then try to capture a graph
        # bound to these input addresses for fast replay on repeated calls.
        out = self._launch(query, kv_cache, block_table, seq_lens)
        if len(self._graphs) < 8:
            try:
                graph = self._capture(query, kv_cache, block_table, seq_lens)
            except Exception as err:  # best effort: fall back to eager launches
                import warnings

                warnings.warn(
                    f"CUDA graph capture failed; running eagerly for these inputs ({err!r})",
                    RuntimeWarning,
                    stacklevel=2,
                )
                graph = None
            self._graphs[sig] = graph
        return out

    def _capture(self, query, kv_cache, block_table, seq_lens):
        """Warm up on a side stream, then capture one ``_launch`` into a CUDA graph."""
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                self._launch(query, kv_cache, block_table, seq_lens)
        torch.cuda.current_stream().wait_stream(side)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            self._launch(query, kv_cache, block_table, seq_lens)
        return graph