"""Paged-attention decode kernel for SM120 (RTX PRO 6000 Blackwell).

Flash-decoding style split-K CUDA kernel (built with
torch.utils.cpp_extension.load_inline):

  - One threadblock per (batch, kv_head, split). Each block streams its chunk
    of the KV cache exactly once (a token's K and V for one kv head are
    adjacent: 2*D bf16 = 512 B at D=128), gathered via the page table staged
    in shared memory. It computes the online softmax for the
    G = num_heads/num_kv_heads grouped query heads with 8/16 B vector loads
    and a PF-stage (2/3/4/6) register prefetch pipeline.
  - Split partials (bf16 split-normalized O plus fp32 running max m and sum l)
    are merged by the LAST finishing block of each (batch, kv_head) group. That
    block is detected with an auto-resetting atomicInc semaphore, so the whole
    decode is a single kernel launch.
"""
import math
import os

import torch
import torch.nn as nn

OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

# --- Shape knobs (kept for interface parity with reference.py) -------------
BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16

_CPP_SRC = r"""
#include <torch/extension.h>

void paged_decode(at::Tensor q, at::Tensor kv, at::Tensor bt, at::Tensor sl,
                  at::Tensor out, at::Tensor o_part, at::Tensor ml, at::Tensor sem,
                  int64_t S, int64_t chunk, int64_t nwarps, int64_t pf, double qscale);
"""

_CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_bf16.h>
#include <cstdint>

using bf16 = __nv_bfloat16;
using bf162 = __nv_bfloat162;

#define DEVINL __device__ __forceinline__

// Raw vector type for an N-byte load (one lane's q, K or V fragment).
template <int N> struct VecT;
template <> struct VecT<8> { using type = uint2; };
template <> struct VecT<16> { using type = uint4; };

// Convert E packed bf16 values (E even) at `src` into fp32 `dst`.
template <int E>
DEVINL void unpack_bf16(const void* src, float* dst) {
#pragma unroll
  for (int i = 0; i < E / 2; ++i) {
    const float2 t = __bfloat1622float2(reinterpret_cast<const bf162*>(src)[i]);
    dst[2 * i] = t.x;
    dst[2 * i + 1] = t.y;
  }
}

// Maximum pages a single chunk may span (chunk <= 4096 tokens, page 16).
#define MAX_CHUNK_PAGES 256
#define NEG_INF (-1e30f)
// Statically sized __shared__ arrays are capped at 48 KB per block.
#define STATIC_SMEM_LIMIT (48 * 1024)

// Grid: (split, kv_head, batch) = (blockIdx.x, blockIdx.y, blockIdx.z).
// Tokens at positions >= S * chunk are not covered by any split.
template <int D, int G, int NWARPS, int PF>
__global__ void __launch_bounds__(NWARPS * 32)
decode_split_kernel(const bf16* __restrict__ q,           // (B, H, D)
                    const bf16* __restrict__ kvc,         // (pages, 16, Hkv, 2*D)
                    const int* __restrict__ block_table,  // (B, max_blocks)
                    const int* __restrict__ seq_lens,     // (B,)
                    bf16* __restrict__ o_part,            // (B, Hkv, S, G, D) normalized
                    float* __restrict__ ml_part,          // (B, Hkv, S, G, 2)
                    unsigned* __restrict__ sem,           // (B, Hkv)
                    bf16* __restrict__ out,               // (B, H, D)
                    const int Hkv, const int S, const int chunk,
                    const int max_blocks, const float qscale) {
  constexpr int TG = (G * D) / 32;  // lanes cooperating on one token
  constexpr int E = D / TG;         // elems per lane per K (or V) vector
  constexpr int TPW = 32 / TG;      // tokens per warp per iteration
  constexpr int NS = NWARPS * TPW;  // token streams per block
  using KVec = typename VecT<E * 2>::type;  // E bf16 values = 2*E bytes
  static_assert(E == 4 || E == 8, "bad config");
  static_assert(sizeof(int) * MAX_CHUNK_PAGES +
                        sizeof(float) * NWARPS * G * (D + 2) + sizeof(unsigned) <=
                    STATIC_SMEM_LIMIT,
                "static shared memory exceeds 48 KB: use fewer warps");

  const int split = blockIdx.x;
  const int kvh = blockIdx.y;
  const int b = blockIdx.z;
  const int H = Hkv * G;
  const int len = seq_lens[b];

  if (len <= 0) {
    // Attention over zero keys: write zeros (matching the PyTorch fallback)
    // instead of leaving stale data in `out`. Split 0 owns the write.
    if (split == 0) {
      bf16* ob = out + ((int64_t)b * H + (int64_t)kvh * G) * D;
      for (int idx = threadIdx.x; idx < G * D; idx += NWARPS * 32)
        ob[idx] = __float2bfloat16(0.f);
    }
    return;
  }

  // Number of splits that own at least one token of this sequence.
  const int Sb = min(S, (len + chunk - 1) / chunk);
  const int t0 = split * chunk;
  if (t0 >= len) return;
  const int tend = min(t0 + chunk, len);

  __shared__ int sm_pages[MAX_CHUNK_PAGES];
  __shared__ float sm_m[NWARPS][G];
  __shared__ float sm_l[NWARPS][G];
  __shared__ float sm_o[NWARPS][G][D];
  __shared__ unsigned sm_last;

  const int warp = threadIdx.x >> 5;
  const int lane = threadIdx.x & 31;
  const int grp = lane / TG;       // token stream within this warp
  const int gl = lane - grp * TG;  // lane index inside the stream's group
  const int stream = warp * TPW + grp;

  // Stage this chunk's page ids (chunk is a multiple of 16 => t0 page aligned).
  {
    const int p0 = t0 >> 4;
    const int np = ((tend + 15) >> 4) - p0;
    const int* bt = block_table + (int64_t)b * max_blocks + p0;
    for (int i = threadIdx.x; i < np; i += NWARPS * 32) sm_pages[i] = bt[i];
  }
  __syncthreads();

  // Query fragment for this lane (scaled by softmax scale * log2(e)).
  float qr[G][E];
  {
    const bf16* qb = q + ((int64_t)b * H + (int64_t)kvh * G) * D + gl * E;
#pragma unroll
    for (int g = 0; g < G; ++g) {
      const KVec v = *reinterpret_cast<const KVec*>(qb + g * D);
      float tmp[E];
      unpack_bf16<E>(&v, tmp);
#pragma unroll
      for (int e = 0; e < E; ++e) qr[g][e] = tmp[e] * qscale;
    }
  }

  float m[G], l[G], o[G][E];
#pragma unroll
  for (int g = 0; g < G; ++g) {
    m[g] = NEG_INF;
    l[g] = 0.f;
#pragma unroll
    for (int e = 0; e < E; ++e) o[g][e] = 0.f;
  }

  const int64_t slot_stride = (int64_t)Hkv * (2 * D);
  const int64_t kvh_off = (int64_t)kvh * (2 * D) + gl * E;
  // (TG & 31) keeps the shift in range when TG == 32 (branch unused then).
  const unsigned grp_mask =
      (TG == 32) ? 0xffffffffu : (((1u << (TG & 31)) - 1u) << (grp * TG));

  // Address of token `tok`'s K vector for this lane; V follows at +D.
  auto addr = [&](int tok) {
    return kvc + (int64_t)sm_pages[(tok - t0) >> 4] * (16 * slot_stride) +
           (tok & 15) * slot_stride + kvh_off;
  };

  // PF-stage software pipeline: K/V for token t prefetched PF*NS ahead.
  KVec kb[PF], vb[PF];
  {
    int tp = t0 + stream;
#pragma unroll
    for (int p = 0; p < PF; ++p, tp += NS) {
      if (tp < tend) {
        const bf16* kp = addr(tp);
        kb[p] = __ldcs(reinterpret_cast<const KVec*>(kp));
        vb[p] = __ldcs(reinterpret_cast<const KVec*>(kp + D));
      }
    }
  }

  int t = t0 + stream;
  while (t < tend) {
#pragma unroll
    for (int p = 0; p < PF; ++p, t += NS) {
      if (t >= tend) break;
      const KVec ck = kb[p], cv = vb[p];
      const int tn = t + PF * NS;
      if (tn < tend) {
        const bf16* kp = addr(tn);
        kb[p] = __ldcs(reinterpret_cast<const KVec*>(kp));
        vb[p] = __ldcs(reinterpret_cast<const KVec*>(kp + D));
      }
      float kf[E], vf[E];
      unpack_bf16<E>(&ck, kf);
      unpack_bf16<E>(&cv, vf);

      float s[G];
#pragma unroll
      for (int g = 0; g < G; ++g) {
        float acc = 0.f;
#pragma unroll
        for (int e = 0; e < E; ++e) acc = fmaf(qr[g][e], kf[e], acc);
        s[g] = acc;
      }
      // Sum the partial dot products over the TG lanes of this stream.
#pragma unroll
      for (int off = TG / 2; off > 0; off >>= 1) {
#pragma unroll
        for (int g = 0; g < G; ++g) s[g] += __shfl_xor_sync(grp_mask, s[g], off);
      }
      // Online softmax update in base 2 (qscale already folds in log2(e)).
#pragma unroll
      for (int g = 0; g < G; ++g) {
        const float mn = fmaxf(m[g], s[g]);
        const float sc = exp2f(m[g] - mn);
        const float pr = exp2f(s[g] - mn);
        l[g] = fmaf(l[g], sc, pr);
        m[g] = mn;
#pragma unroll
        for (int e = 0; e < E; ++e) o[g][e] = fmaf(o[g][e], sc, pr * vf[e]);
      }
    }
  }

  // Merge the TPW token streams within each warp (all lanes redundantly).
#pragma unroll
  for (int off = TG; off < 32; off <<= 1) {
#pragma unroll
    for (int g = 0; g < G; ++g) {
      const float mo = __shfl_xor_sync(0xffffffffu, m[g], off);
      const float lo = __shfl_xor_sync(0xffffffffu, l[g], off);
      const float M = fmaxf(m[g], mo);
      const float wa = exp2f(m[g] - M);
      const float wb = exp2f(mo - M);
      l[g] = fmaf(l[g], wa, lo * wb);
      m[g] = M;
#pragma unroll
      for (int e = 0; e < E; ++e) {
        const float oo = __shfl_xor_sync(0xffffffffu, o[g][e], off);
        o[g][e] = fmaf(o[g][e], wa, oo * wb);
      }
    }
  }

  // First lane group of each warp publishes the warp's state.
  if (grp == 0) {
#pragma unroll
    for (int g = 0; g < G; ++g) {
#pragma unroll
      for (int e = 0; e < E; ++e) sm_o[warp][g][gl * E + e] = o[g][e];
      if (gl == 0) {
        sm_m[warp][g] = m[g];
        sm_l[warp][g] = l[g];
      }
    }
  }
  __syncthreads();

  // Merge warps; write final output (single split) or partials.
  for (int idx = threadIdx.x; idx < G * D; idx += NWARPS * 32) {
    const int g = idx / D;
    const int d = idx - g * D;
    float M = NEG_INF;
#pragma unroll
    for (int w = 0; w < NWARPS; ++w) M = fmaxf(M, sm_m[w][g]);
    float L = 0.f, O = 0.f;
#pragma unroll
    for (int w = 0; w < NWARPS; ++w) {
      const float wgt = exp2f(sm_m[w][g] - M);
      L = fmaf(sm_l[w][g], wgt, L);
      O = fmaf(sm_o[w][g][d], wgt, O);
    }
    if (Sb == 1) {
      out[((int64_t)b * H + kvh * G + g) * D + d] = __float2bfloat16(O / L);
    } else {
      // Store the split-local softmax output (normalized) in bf16 to
      // halve partial traffic; merge weights come from (M, L) in fp32.
      o_part[((((int64_t)b * Hkv + kvh) * S + split) * G + g) * D + d] =
          __float2bfloat16(O / L);
      if (d == 0) {
        float* mlp = ml_part + ((((int64_t)b * Hkv + kvh) * S + split) * G + g) * 2;
        mlp[0] = M;
        mlp[1] = L;
      }
    }
  }
  if (Sb == 1) return;

  // Semaphore: the last block of this (b, kvh) group merges all partials.
  // atomicInc with val = Sb-1 wraps back to 0, so no reset pass is needed.
  __threadfence();  // release: partials visible device-wide before the count
  __syncthreads();
  if (threadIdx.x == 0) {
    const unsigned prev = atomicInc(&sem[b * Hkv + kvh], (unsigned)(Sb - 1));
    sm_last = (prev == (unsigned)(Sb - 1));
    __threadfence();  // acquire side: order the partial reads after the count
  }
  __syncthreads();
  if (!sm_last) return;

  // Partials were written by other blocks: read them through L2 (__ldcg).
  const bf16* op = o_part + ((int64_t)b * Hkv + kvh) * S * G * D;
  const float* mlp = ml_part + ((int64_t)b * Hkv + kvh) * S * G * 2;
  for (int idx = threadIdx.x; idx < G * D; idx += NWARPS * 32) {
    const int g = idx / D;
    const int d = idx - g * D;
    float M = NEG_INF;
    for (int s = 0; s < Sb; ++s) M = fmaxf(M, __ldcg(mlp + (s * G + g) * 2));
    float W = 0.f, O = 0.f;
    for (int s = 0; s < Sb; ++s) {
      const float w = exp2f(__ldcg(mlp + (s * G + g) * 2) - M) *
                      __ldcg(mlp + (s * G + g) * 2 + 1);
      W += w;
      O = fmaf(w, __bfloat162float(__ldcg(op + ((int64_t)s * G + g) * D + d)), O);
    }
    out[((int64_t)b * H + kvh * G + g) * D + d] = __float2bfloat16(O / W);
  }
}

void paged_decode(at::Tensor q, at::Tensor kv, at::Tensor bt, at::Tensor sl,
                  at::Tensor out, at::Tensor o_part, at::Tensor ml, at::Tensor sem,
                  int64_t S, int64_t chunk, int64_t nwarps, int64_t pf, double qscale) {
  // Cheap host-side validation: the kernel indexes raw pointers, so a shape,
  // dtype, device or layout mismatch would otherwise corrupt memory silently.
  TORCH_CHECK(q.dim() == 3 && kv.dim() == 4 && bt.dim() == 2,
              "paged_decode: expected q (B, H, D), kv (pages, 16, Hkv, 2*D), "
              "block_table (B, max_blocks)");
  const int B = q.size(0);
  const int H = q.size(1);
  const int D = q.size(2);
  const int Hkv = kv.size(2);
  TORCH_CHECK(Hkv > 0 && H % Hkv == 0,
              "paged_decode: num_heads must be a multiple of num_kv_heads");
  const int G = H / Hkv;
  const int maxb = bt.size(1);
  TORCH_CHECK(q.is_cuda() && kv.is_cuda() && bt.is_cuda() && sl.is_cuda() &&
                  out.is_cuda() && o_part.is_cuda() && ml.is_cuda() && sem.is_cuda(),
              "paged_decode: all tensors must be CUDA tensors");
  TORCH_CHECK(q.scalar_type() == at::kBFloat16 && kv.scalar_type() == at::kBFloat16 &&
                  out.scalar_type() == at::kBFloat16 &&
                  o_part.scalar_type() == at::kBFloat16,
              "paged_decode: q, kv, out and o_part must be bfloat16");
  TORCH_CHECK(bt.scalar_type() == at::kInt && sl.scalar_type() == at::kInt &&
                  sem.scalar_type() == at::kInt && ml.scalar_type() == at::kFloat,
              "paged_decode: block_table/seq_lens/sem must be int32, ml float32");
  TORCH_CHECK(q.is_contiguous() && kv.is_contiguous() && bt.is_contiguous() &&
                  sl.is_contiguous() && out.is_contiguous() && o_part.is_contiguous() &&
                  ml.is_contiguous() && sem.is_contiguous(),
              "paged_decode: all tensors must be contiguous");
  TORCH_CHECK(reinterpret_cast<std::uintptr_t>(q.data_ptr()) % 16 == 0 &&
                  reinterpret_cast<std::uintptr_t>(kv.data_ptr()) % 16 == 0,
              "paged_decode: q and kv must be 16-byte aligned");
  TORCH_CHECK(kv.size(1) == 16 && kv.size(3) == 2 * D,
              "paged_decode: kv cache must be laid out as (pages, 16, Hkv, 2*D)");
  TORCH_CHECK(bt.size(0) >= B && sl.numel() >= B,
              "paged_decode: block_table/seq_lens have fewer rows than the batch");
  TORCH_CHECK(S >= 1 && chunk > 0 && chunk % 16 == 0 && chunk <= 16 * MAX_CHUNK_PAGES,
              "paged_decode: need S >= 1 and chunk a multiple of 16 in [16, 4096]");
  TORCH_CHECK(out.numel() == (int64_t)B * H * D &&
                  o_part.numel() >= (int64_t)B * Hkv * S * G * D &&
                  ml.numel() >= (int64_t)B * Hkv * S * G * 2 &&
                  sem.numel() >= (int64_t)B * Hkv,
              "paged_decode: output/workspace tensors are too small for this launch");
  if (B == 0) return;  // a zero-sized grid is an invalid launch configuration

  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 grid((unsigned)S, (unsigned)Hkv, (unsigned)B);

#define LAUNCH(D_, G_, W_, PF_)                                              \
  decode_split_kernel<D_, G_, W_, PF_><<<grid, (W_) * 32, 0, stream>>>(      \
      reinterpret_cast<const bf16*>(q.data_ptr()),                           \
      reinterpret_cast<const bf16*>(kv.data_ptr()),                          \
      bt.data_ptr<int>(), sl.data_ptr<int>(),                                \
      reinterpret_cast<bf16*>(o_part.data_ptr()), ml.data_ptr<float>(),      \
      reinterpret_cast<unsigned*>(sem.data_ptr()),                           \
      reinterpret_cast<bf16*>(out.data_ptr()),                               \
      Hkv, (int)S, (int)chunk, maxb, (float)qscale)

#define PF_SWITCH(D_, G_, W_)                  \
  do {                                         \
    if (pf == 3) LAUNCH(D_, G_, W_, 3);        \
    else if (pf == 4) LAUNCH(D_, G_, W_, 4);   \
    else if (pf == 6) LAUNCH(D_, G_, W_, 6);   \
    else LAUNCH(D_, G_, W_, 2);                \
  } while (0)

  if (D == 128 && G == 4) {
    if (nwarps == 4) PF_SWITCH(128, 4, 4);
    else if (nwarps == 16) PF_SWITCH(128, 4, 16);
    else PF_SWITCH(128, 4, 8);
  } else if (D == 128 && G == 8) {
    // No 16-warp variant: its sm_o alone would need 64 KB of static smem.
    if (nwarps == 4) PF_SWITCH(128, 8, 4);
    else PF_SWITCH(128, 8, 8);
  } else if (D == 64 && G == 4) {
    if (nwarps == 2) PF_SWITCH(64, 4, 2);
    else if (nwarps == 8) PF_SWITCH(64, 4, 8);
    else if (nwarps == 16) PF_SWITCH(64, 4, 16);
    else PF_SWITCH(64, 4, 4);
  } else if (D == 64 && G == 8) {
    if (nwarps == 8) PF_SWITCH(64, 8, 8);
    else PF_SWITCH(64, 8, 4);
  } else {
    TORCH_CHECK(false, "unsupported (D, G) = (", D, ", ", G, ")");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef PF_SWITCH
#undef LAUNCH
}
"""

_ext = None


def _get_ext():
    """Build (once per process) and return the JIT-compiled CUDA extension."""
    global _ext
    if _ext is None:
        from torch.utils.cpp_extension import load_inline

        # Target the GPU we are running on. The kernel uses no arch-specific
        # features, and a single target keeps the JIT build short. This also
        # yields a loadable image on the H100/B200 listed in HARDWARE_REQUIRED.
        major, minor = torch.cuda.get_device_capability()
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
        _ext = load_inline(
            name="paged_decode_v3",
            cpp_sources=[_CPP_SRC],
            cuda_sources=[_CUDA_SRC],
            functions=["paged_decode"],
            extra_cuda_cflags=["-O3", "--use_fast_math"],
            verbose=os.environ.get("PD_VERBOSE", "0") == "1",
        )
    return _ext


_LOG2E = 1.4426950408889634

# Tuned launch plans: (B, Hkv, G, D, L) -> (num_splits, nwarps, prefetch_depth)
_PLANS = {
    (8, 8, 4, 128, 1024): (2, 8, 2),
    (32, 8, 4, 128, 2048): (1, 8, 2),
    (4, 8, 8, 128, 4096): (9, 8, 2),
    (16, 8, 4, 128, 1535): (4, 4, 2),
    (8, 4, 4, 64, 2000): (4, 8, 6),
}


def _plan(batch, num_kv_heads, group, head_dim, seq_len):
    """Return ``(num_splits, nwarps, prefetch_depth)`` for a problem shape.

    Each component comes from the first available source: the ``PD_S`` /
    ``PD_NWARPS`` / ``PD_PF`` environment overrides, then the tuned
    ``_PLANS`` entry for this exact shape, then a heuristic default.
    ``num_splits`` is clamped to at least 1.
    """
    base = _PLANS.get((batch, num_kv_heads, group, head_dim, seq_len))
    env_s = os.environ.get("PD_S")
    env_w = os.environ.get("PD_NWARPS")
    env_pf = os.environ.get("PD_PF")

    if env_s:
        S = int(env_s)
    elif base:
        S = base[0]
    else:
        # Enough splits that the grid has about PD_TARGET_BLOCKS blocks.
        target = int(os.environ.get("PD_TARGET_BLOCKS", "576"))
        S = -(-target // (batch * num_kv_heads))

    if env_w:
        nwarps = int(env_w)
    elif base:
        nwarps = base[1]
    else:
        nwarps = 8 if head_dim == 128 else 4

    if env_pf:
        pf = int(env_pf)
    elif base:
        pf = base[2]
    else:
        pf = 2

    return max(1, S), nwarps, pf


class Model(nn.Module):
    """Single-query paged attention decode (matches reference.Model interface).

    The CUDA kernel is used when ``page_size == 16``,
    ``(head_dim, num_heads // num_kv_heads)`` is a compiled configuration and
    CUDA is available. Otherwise a per-sequence PyTorch fallback runs.
    """

    def __init__(self, batch, num_heads, num_kv_heads, head_dim, seq_len,
                 page_size):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        self.group_size = num_heads // num_kv_heads
        self.scale = 1.0 / math.sqrt(head_dim)
        # The kernel uses exp2: exp(x * scale) == exp2(x * scale * log2(e)).
        self.qscale = self.scale * _LOG2E
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16),
                             persistent=False)

        G = self.group_size
        self._supported = page_size == 16 and (head_dim, G) in (
            (128, 4),
            (128, 8),
            (64, 4),
            (64, 8),
        )

        # Token chunk per split: a multiple of the 16-token page, at most
        # 4096 tokens (the kernel's MAX_CHUNK_PAGES = 256 pages). S is then
        # recomputed so that S * chunk >= seq_len, with at least one split.
        S, nwarps, pf = _plan(batch, num_kv_heads, G, head_dim, seq_len)
        chunk = max(16, (-(-seq_len // S) + 15) // 16 * 16)
        chunk = min(chunk, 4096)
        S = max(1, -(-seq_len // chunk))
        self.S = S
        self.chunk = chunk
        self.nwarps = nwarps
        self.pf = pf

        if self._supported and torch.cuda.is_available():
            dev = torch.device("cuda")
            self._out = torch.empty(batch, num_heads, head_dim,
                                    dtype=torch.bfloat16, device=dev)
            # Per-split partials: split-normalized O (bf16) and (max, sum) fp32.
            self._opart = torch.empty(
                batch, num_kv_heads, S, G, head_dim, dtype=torch.bfloat16, device=dev
            )
            self._ml = torch.empty(batch, num_kv_heads, S, G, 2,
                                   dtype=torch.float32, device=dev)
            # Per-(batch, kv_head) counters; the kernel wraps them back to 0.
            self._sem = torch.zeros(batch, num_kv_heads, dtype=torch.int32, device=dev)
            self._fn = _get_ext().paged_decode
            self._rest = (self._out, self._opart, self._ml, self._sem, S, chunk,
                          nwarps, pf, self.qscale)
        else:
            self._out = None
            self._fn = None

    # Hot path: skip nn.Module.__call__ hook dispatch (a few us per call).
    def __call__(self, query, kv_cache, block_table, seq_lens):
        """Decode one query token per sequence.

        On the CUDA path the returned tensor is a persistent buffer that the
        next call overwrites; clone it if the result must outlive that call.
        Keys at positions >= ``self.S * self.chunk`` are not attended there.
        """
        fn = self._fn
        if fn is not None:
            fn(query, kv_cache, block_table, seq_lens, *self._rest)
            return self._out
        return self._fallback(query, kv_cache, block_table, seq_lens)

    forward = __call__

    def _fallback(self, query, kv_cache, block_table, seq_lens):
        """Reference implementation: gather pages per sequence, fp32 softmax."""
        B, H, D = query.shape
        P = self.page_size
        G = self.group_size
        out = torch.empty_like(query)
        for b in range(B):
            L = int(seq_lens[b].item())
            np_ = (L + P - 1) // P
            kv = kv_cache.index_select(0, block_table[b, :np_].long())
            kv = kv.reshape(np_ * P, self.num_kv_heads, 2 * D)[:L]
            k = kv[..., :D].repeat_interleave(G, dim=1).float()
            v = kv[..., D:].repeat_interleave(G, dim=1).float()
            qf = query[b].float()
            scores = torch.einsum("hd,lhd->hl", qf, k) * self.scale
            probs = torch.softmax(scores, dim=-1)
            out[b] = torch.einsum("hl,lhd->hd", probs, v).to(query.dtype)
        return out


def get_inputs():
    """Random decode inputs for the module-level shape knobs (CPU tensors)."""
    B, H, Hkv, D, L, P = BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE
    pages_per_seq = (L + P - 1) // P
    total_pages = max(B * pages_per_seq + 8, 64)
    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
    perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Constructor arguments for ``Model`` matching ``get_inputs``."""
    return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]