"""Kimi Delta Attention (KDA) forward, chunk form — custom Triton kernels for SM120 (RTX PRO 6000 Blackwell). No library calls; the chunk-parallel KDA math is implemented from scratch. Math (per chunk, in-chunk cumsummed gate g_cs = cumsum(g) over the BT tokens): k_g = k*exp(g_cs); k_ng = k*exp(-g_cs); q_g = (scale*q)*exp(g_cs) gram = k_g @ k_ng^T (decayed K-K gram, lower-tri used) N = beta_row * gram (strictly lower) Tinv = (I + N)^{-1} (block tril-solve) A = Tinv * beta_col w = A @ k_g ; u = A @ v Aqk = lower_incl_diag(q_g @ k_ng^T) inter-chunk recurrence (state S [K,V], S_0 = 0): v_i = u - w @ S o = q_g @ S + Aqk @ v_i S = exp(g_cs[BT-1]) * (S + k_ng^T @ v_i) Two-kernel split: 1) intra kernel — grid (B*H*NT,). One program per (b, h, chunk). Builds N, solves Tinv via a *blocked* forward substitution (BT=64 split into NB=4 blocks of BC=16: four 16x16 unit-lower inverses + off-diagonal matmuls via tl.dot), then computes w, u (block-wise, exploiting triangularity), Aqk. w/u/Aqk/q_g/k_ng/g_last are stored to HBM in bf16 to cut the recurrence's redundant per-V-tile traffic. 2) recurrence kernel — grid (V/BV, B*H). Sequential over chunks (BV=16 keeps enough blocks live for occupancy; num_stages=2 software-pipelines the chunk loop to hide load latency behind the carried state S). Moving the (sequential, expensive) tril solve out of the recurrence into the embarrassingly-parallel intra kernel is what restores occupancy on the 240-SM GPU; bf16 intermediates + V-tile + pipelining keep the sequential recurrence near its memory floor. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl # --------------------------------------------------------------------------- # # blocked tril-solve helpers (BT=64 split into NB=4 blocks of BC=16) # --------------------------------------------------------------------------- # @triton.jit def _inv16(Nii, BC: tl.constexpr): """Inverse of I+Nii for a strictly-lower BC x BC tile, via row-scan.""" A0 = -Nii offs = tl.arange(0, BC) for ii in range(1, BC): r_ii = (offs == ii) rvec = tl.sum(tl.where(r_ii[:, None], A0, 0.0), axis=0) contrib = tl.sum(rvec[:, None] * A0, axis=0) upd = r_ii[:, None] & (offs[None, :] < ii) A0 = tl.where(upd, A0 + contrib[None, :], A0) return tl.where(offs[:, None] == offs[None, :], 1.0, A0) @triton.jit def _blk4(N4, bi, bk, NB: tl.constexpr): """Extract the [BC,BC] block (bi,bk) from a [NB,BC,NB,BC] reshaped tile.""" sel = (tl.arange(0, NB)[:, None, None, None] == bi) & \ (tl.arange(0, NB)[None, None, :, None] == bk) return tl.sum(tl.sum(tl.where(sel, N4, 0.0), axis=0), axis=1) @triton.jit def _blkrow(M4, bi, NB: tl.constexpr): """Extract block-row bi [BC,K] from a [NB,BC,K] reshaped tile.""" sel = (tl.arange(0, NB)[:, None, None] == bi) return tl.sum(tl.where(sel, M4, 0.0), axis=0) # --------------------------------------------------------------------------- # # intra kernel: per (b, h, chunk) # --------------------------------------------------------------------------- # @triton.jit(do_not_specialize=["B", "T", "H", "scale"]) def _kda_intra_kernel( q_ptr, k_ptr, v_ptr, g_ptr, beta_ptr, w_ptr, u_ptr, Aqk_ptr, qg_ptr, kng_ptr, glast_ptr, scale, B, T, H, NT: tl.constexpr, BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BV: tl.constexpr, BC: tl.constexpr, NB: tl.constexpr, PREC: tl.constexpr, PSOLVE: tl.constexpr, ): pid = tl.program_id(0) i_b = pid // (H * NT) rem = pid % (H * NT) i_h = rem // NT i_n = rem % NT HK = H * K HV = H * V offs_r = tl.arange(0, BT) offs_k = tl.arange(0, K) rr = offs_r[:, None] cc = offs_r[None, :] t_idx = i_n * BT + offs_r qk_row = (i_b * T + t_idx) * HK + i_h * K v_row = (i_b * T + t_idx) * HV + i_h * V k = tl.load(k_ptr + qk_row[:, None] + offs_k[None, :]).to(tl.float32) q = tl.load(q_ptr + qk_row[:, None] + offs_k[None, :]).to(tl.float32) * scale g = tl.load(g_ptr + qk_row[:, None] + offs_k[None, :]).to(tl.float32) beta = tl.load(beta_ptr + (i_b * T + t_idx) * H + i_h).to(tl.float32) g_cs = tl.cumsum(g, axis=0) g_last = tl.sum(g, axis=0) eg = tl.exp(g_cs) k_g = k * eg k_ng = k * tl.exp(-g_cs) q_g = q * eg gram = tl.dot(k_g, tl.trans(k_ng), input_precision=PREC) N = tl.where(rr > cc, gram, 0.0) * beta[:, None] # strictly lower # ---- blocked forward-substitution: Tinv = (I + N)^{-1} ---- # NB=4 diagonal 16x16 inverses, then off-diagonal blocks via matmul. # w[bi] = sum_k Tinv[bi][k] @ (beta*k_g)[k] (computed block-wise). N4 = tl.reshape(N, (NB, BC, NB, BC)) bg = beta[:, None] * k_g # [BT, K] bg4 = tl.reshape(bg, (NB, BC, K)) d0 = _inv16(_blk4(N4, 0, 0, NB), BC) d1 = _inv16(_blk4(N4, 1, 1, NB), BC) d2 = _inv16(_blk4(N4, 2, 2, NB), BC) d3 = _inv16(_blk4(N4, 3, 3, NB), BC) n10 = _blk4(N4, 1, 0, NB) n20 = _blk4(N4, 2, 0, NB); n21 = _blk4(N4, 2, 1, NB) n30 = _blk4(N4, 3, 0, NB); n31 = _blk4(N4, 3, 1, NB); n32 = _blk4(N4, 3, 2, NB) t10 = -tl.dot(d1, tl.dot(n10, d0, input_precision=PSOLVE), input_precision=PSOLVE) t20 = -tl.dot(d2, tl.dot(n20, d0, input_precision=PSOLVE) + tl.dot(n21, t10, input_precision=PSOLVE), input_precision=PSOLVE) t21 = -tl.dot(d2, tl.dot(n21, d1, input_precision=PSOLVE), input_precision=PSOLVE) t30 = -tl.dot(d3, tl.dot(n30, d0, input_precision=PSOLVE) + tl.dot(n31, t10, input_precision=PSOLVE) + tl.dot(n32, t20, input_precision=PSOLVE), input_precision=PSOLVE) t31 = -tl.dot(d3, tl.dot(n31, d1, input_precision=PSOLVE) + tl.dot(n32, t21, input_precision=PSOLVE), input_precision=PSOLVE) t32 = -tl.dot(d3, tl.dot(n32, d2, input_precision=PSOLVE), input_precision=PSOLVE) bg0 = _blkrow(bg4, 0, NB); bg1 = _blkrow(bg4, 1, NB) bg2 = _blkrow(bg4, 2, NB); bg3 = _blkrow(bg4, 3, NB) w0 = tl.dot(d0, bg0, input_precision=PSOLVE) w1 = tl.dot(t10, bg0, input_precision=PSOLVE) + tl.dot(d1, bg1, input_precision=PSOLVE) w2 = tl.dot(t20, bg0, input_precision=PSOLVE) + tl.dot(t21, bg1, input_precision=PSOLVE) + tl.dot(d2, bg2, input_precision=PSOLVE) w3 = tl.dot(t30, bg0, input_precision=PSOLVE) + tl.dot(t31, bg1, input_precision=PSOLVE) + tl.dot(t32, bg2, input_precision=PSOLVE) + tl.dot(d3, bg3, input_precision=PSOLVE) ob = tl.arange(0, BC) wdt = tl.bfloat16 tl.store(w_ptr + (pid * BT + 0 * BC + ob)[:, None] * K + offs_k[None, :], w0.to(wdt)) tl.store(w_ptr + (pid * BT + 1 * BC + ob)[:, None] * K + offs_k[None, :], w1.to(wdt)) tl.store(w_ptr + (pid * BT + 2 * BC + ob)[:, None] * K + offs_k[None, :], w2.to(wdt)) tl.store(w_ptr + (pid * BT + 3 * BC + ob)[:, None] * K + offs_k[None, :], w3.to(wdt)) Aqk_full = tl.dot(q_g, tl.trans(k_ng), input_precision=PREC) Aqk = tl.where(rr >= cc, Aqk_full, 0.0) base = pid * BT + offs_r # [BT] tl.store(Aqk_ptr + base[:, None] * BT + offs_r[None, :], Aqk.to(wdt)) tl.store(qg_ptr + base[:, None] * K + offs_k[None, :], q_g.to(wdt)) tl.store(kng_ptr + base[:, None] * K + offs_k[None, :], k_ng.to(wdt)) tl.store(glast_ptr + pid * K + offs_k, g_last) # u = Tinv @ (beta*v), tiled over V; reuse the same Tinv blocks. for i_v in range(0, V, BV): offs_v = i_v + tl.arange(0, BV) v_tile = tl.load(v_ptr + v_row[:, None] + offs_v[None, :]).to(tl.float32) bv = beta[:, None] * v_tile # [BT, BV] bv4 = tl.reshape(bv, (NB, BC, BV)) bv0 = _blkrow(bv4, 0, NB); bv1 = _blkrow(bv4, 1, NB) bv2 = _blkrow(bv4, 2, NB); bv3 = _blkrow(bv4, 3, NB) u0 = tl.dot(d0, bv0, input_precision=PSOLVE) u1 = tl.dot(t10, bv0, input_precision=PSOLVE) + tl.dot(d1, bv1, input_precision=PSOLVE) u2 = tl.dot(t20, bv0, input_precision=PSOLVE) + tl.dot(t21, bv1, input_precision=PSOLVE) + tl.dot(d2, bv2, input_precision=PSOLVE) u3 = tl.dot(t30, bv0, input_precision=PSOLVE) + tl.dot(t31, bv1, input_precision=PSOLVE) + tl.dot(t32, bv2, input_precision=PSOLVE) + tl.dot(d3, bv3, input_precision=PSOLVE) tl.store(u_ptr + (pid * BT + 0 * BC + ob)[:, None] * V + offs_v[None, :], u0.to(tl.bfloat16)) tl.store(u_ptr + (pid * BT + 1 * BC + ob)[:, None] * V + offs_v[None, :], u1.to(tl.bfloat16)) tl.store(u_ptr + (pid * BT + 2 * BC + ob)[:, None] * V + offs_v[None, :], u2.to(tl.bfloat16)) tl.store(u_ptr + (pid * BT + 3 * BC + ob)[:, None] * V + offs_v[None, :], u3.to(tl.bfloat16)) # --------------------------------------------------------------------------- # # recurrence kernel: per (v_tile, b, h), sequential over chunks # --------------------------------------------------------------------------- # @triton.jit(do_not_specialize=["B", "T", "H"]) def _kda_rec_kernel( w_ptr, u_ptr, Aqk_ptr, qg_ptr, kng_ptr, glast_ptr, o_ptr, B, T, H, NT: tl.constexpr, BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BV: tl.constexpr, PREC: tl.constexpr, ): i_v = tl.program_id(0) i_nh = tl.program_id(1) i_b = i_nh // H i_h = i_nh % H offs_r = tl.arange(0, BT) offs_k = tl.arange(0, K) offs_v = i_v * BV + tl.arange(0, BV) rr = offs_r[:, None] cc = offs_r[None, :] S = tl.zeros([K, BV], dtype=tl.float32) HV = H * V nh_off = i_nh * NT # chunk-0 intra pid for this (b, h) for i_n in range(0, NT): pid = nh_off + i_n base = pid * BT + offs_r # [BT] w = tl.load(w_ptr + base[:, None] * K + offs_k[None, :]).to(tl.float32) u = tl.load(u_ptr + base[:, None] * V + offs_v[None, :]).to(tl.float32) Aqk = tl.load(Aqk_ptr + base[:, None] * BT + offs_r[None, :]).to(tl.float32) qg = tl.load(qg_ptr + base[:, None] * K + offs_k[None, :]).to(tl.float32) kng = tl.load(kng_ptr + base[:, None] * K + offs_k[None, :]).to(tl.float32) glast = tl.load(glast_ptr + pid * K + offs_k) v_i = u - tl.dot(w, S, input_precision=PREC) o = tl.dot(qg, S, input_precision=PREC) + tl.dot(Aqk, v_i, input_precision=PREC) t_idx = i_n * BT + offs_r v_row = (i_b * T + t_idx) * HV + i_h * V tl.store(o_ptr + v_row[:, None] + offs_v[None, :], o.to(o_ptr.dtype.element_ty)) kn = tl.dot(tl.trans(kng), v_i, input_precision=PREC) # [K, BV] S = tl.exp(glast)[:, None] * (S + kn) def _kda_fwd(q, k, v, g, beta, scale, chunk_size=64): B, T, H, K = q.shape V = v.shape[-1] BT = chunk_size assert T % BT == 0 NT = T // BT device, dtype = q.device, q.dtype NBH = B * H * NT # Intermediates laid out flat as (B*H*NT, BT, D). # V-independent w/q_g/k_ng/A_qk stored in bf16 to halve HBM traffic (the # recurrence re-reads them per V-tile); compute stays fp32/tf32. w = torch.empty(NBH * BT * K, device=device, dtype=torch.bfloat16) u = torch.empty(NBH * BT * V, device=device, dtype=torch.bfloat16) Aqk = torch.empty(NBH * BT * BT, device=device, dtype=torch.bfloat16) qg = torch.empty(NBH * BT * K, device=device, dtype=torch.bfloat16) kng = torch.empty(NBH * BT * K, device=device, dtype=torch.bfloat16) glast = torch.empty(NBH * K, device=device, dtype=torch.float32) o = torch.empty_like(v) PREC = "tf32" PSOLVE = "tf32" # Decouple V-tile sizes: the intra u=A@v GEMM wants a large tile (fewer, # bigger dots); the recurrence wants a small tile (more blocks). streams = B * H # smaller V-tile for fewer-stream shapes (more blocks -> better occupancy). BV_REC = 8 if streams <= 4 else 16 BV_INTRA = V # no u-tiling: one [BC,16]@[16,V] dot per row-block BC = 16 NB = BT // BC _kda_intra_kernel[(NBH,)]( q, k, v, g, beta, w, u, Aqk, qg, kng, glast, scale, B, T, H, NT=NT, BT=BT, K=K, V=V, BV=BV_INTRA, BC=BC, NB=NB, PREC=PREC, PSOLVE=PSOLVE, num_warps=8, num_stages=1, ) _kda_rec_kernel[(triton.cdiv(V, BV_REC), B * H)]( w, u, Aqk, qg, kng, glast, o, B, T, H, NT=NT, BT=BT, K=K, V=V, BV=BV_REC, PREC=PREC, num_warps=4, num_stages=2, ) return o class Model(nn.Module): """KDA forward (chunk form). No learned parameters; all inputs are activations.""" def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward(self, q, k, v, g, beta): return _kda_fwd(q, k, v, g, beta, scale=self.scale, chunk_size=self.chunk_size) # Module-level shape shims (overridden by check.py / benchmark.py per shape). B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]