"""Kimi Delta Attention forward (chunk form) — custom Triton kernels for SM120. Pipeline (chunk size BT=64, K=V=128): K1 (parallel over chunks): g cumsum -> Akk/Aqk via factored exp2 bf16 GEMMs, (I + tril(Akk))^{-1} via 16x16 fp32 forward-substitution + block merge, then w/u/kg/qg precomputation. One program per (chunk, b*h). K2 (sequential over chunks): inter-chunk state recurrence, parallel over (b*h, V-blocks). Stores per-chunk state h and corrected values vnew. K3 (parallel over chunks): o = qg @ h + tril(Aqk) @ vnew. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "linear_attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] RCP_LN2 = tl.constexpr(1.4426950408889634) @triton.jit def kda_prep_kernel( q, k, v, g, beta, Aqk, w, u, kg, qg, gexp, SA, SM, scale, T, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H NT = T // BT o_t = tl.arange(0, BT) o_k = tl.arange(0, K) o_v = tl.arange(0, V) tok = i_b * T + i_t * BT # first global token of this chunk # input row pointers: layout (B, T, H, *) qk_rows = ((tok + o_t) * H + i_h) * K v_rows = ((tok + o_t) * H + i_h) * V b_g = tl.load(g + qk_rows[:, None] + o_k[None, :]) g2 = tl.cumsum(b_g, axis=0) * RCP_LN2 # reference row (middle) and last row, via masked reductions ref = tl.sum(tl.where(o_t[:, None] == BT // 2 - 1, g2, 0.0), 0) # (K,) gl = tl.sum(tl.where(o_t[:, None] == BT - 1, g2, 0.0), 0) # (K,) tl.store(gexp + (i_bh * NT + i_t) * K + o_k, tl.exp2(gl)) b_q = tl.load(q + qk_rows[:, None] + o_k[None, :]) b_k = tl.load(k + qk_rows[:, None] + o_k[None, :]) b_beta = tl.load(beta + (tok + o_t) * H + i_h).to(tl.float32) e_pos = tl.exp2(g2 - ref[None, :]) e_neg = tl.exp2(ref[None, :] - g2) b_kpos = (b_k * e_pos).to(tl.bfloat16) b_kneg = (b_k * e_neg).to(tl.bfloat16) b_qpos = (b_q * (scale * e_pos)).to(tl.bfloat16) b_kneg_t = tl.trans(b_kneg) b_Aqk = tl.dot(b_qpos, b_kneg_t) b_Akk = tl.dot(b_kpos, b_kneg_t) * b_beta[:, None] m_lower = o_t[:, None] >= o_t[None, :] m_strict = o_t[:, None] > o_t[None, :] b_Aqk = tl.where(m_lower, b_Aqk, 0.0) b_Akk = tl.where(m_strict, b_Akk, 0.0) # store Aqk (bf16) and Akk scratch (fp32) aqk_base = Aqk + (i_bh * NT + i_t) * BT * BT tl.store(aqk_base + o_t[:, None] * BT + o_t[None, :], b_Aqk.to(tl.bfloat16)) sa = SA + (i_bh * NT + i_t) * BT * BT tl.store(sa + o_t[:, None] * BT + o_t[None, :], b_Akk) tl.debug_barrier() # ---- invert (I + Akk) : batched 16x16 diagonal forward substitution ---- BC: tl.constexpr = 16 NC: tl.constexpr = BT // BC o_b = tl.arange(0, NC) o_i = tl.arange(0, BC) # load diag blocks (NC, BC, BC) d_ptr = sa + (o_b[:, None, None] * BC + o_i[None, :, None]) * BT \ + o_b[:, None, None] * BC + o_i[None, None, :] b_Ai = -tl.load(d_ptr) for i in tl.static_range(2, BC): r_ptr = sa + (o_b[:, None] * BC + i) * BT + o_b[:, None] * BC + o_i[None, :] b_a = -tl.load(r_ptr) b_a = tl.where(o_i[None, :] < i, b_a, 0.0) b_a += tl.sum(b_a[:, :, None] * b_Ai, 1) b_Ai = tl.where((o_i == i)[None, :, None], b_a[:, None, :], b_Ai) b_Ai += (o_i[:, None] == o_i[None, :])[None, :, :].to(tl.float32) sm = SM + (i_bh * NT + i_t) * BT * BT m_ptr = sm + (o_b[:, None, None] * BC + o_i[None, :, None]) * BT \ + o_b[:, None, None] * BC + o_i[None, None, :] tl.store(m_ptr, b_Ai) tl.debug_barrier() # ---- block merge: M[i][j] = -Mi[i] @ (sum_k Akk[i][k] @ M[k][j]) ---- oc = tl.arange(0, BC) for bi in tl.static_range(1, NC): mi_ptr = sm + (bi * BC + oc[:, None]) * BT + bi * BC + oc[None, :] b_mi = tl.load(mi_ptr) for bj in tl.static_range(0, NC - 1): if bj < bi: acc = tl.zeros([BC, BC], dtype=tl.float32) for bk in tl.static_range(0, NC - 1): if (bk >= bj) and (bk < bi): a_ptr = sa + (bi * BC + oc[:, None]) * BT + bk * BC + oc[None, :] mkj_ptr = sm + (bk * BC + oc[:, None]) * BT + bj * BC + oc[None, :] b_ab = tl.load(a_ptr) b_mkj = tl.load(mkj_ptr) acc += tl.dot(b_ab, b_mkj, input_precision="tf32") b_mij = -tl.dot(b_mi, acc, input_precision="tf32") mij_ptr = sm + (bi * BC + oc[:, None]) * BT + bj * BC + oc[None, :] tl.store(mij_ptr, b_mij) tl.debug_barrier() # ---- load assembled M (mask upper garbage), compute w/u, store kg/qg ---- b_M = tl.load(sm + o_t[:, None] * BT + o_t[None, :]) b_M = tl.where(m_lower, b_M, 0.0).to(tl.bfloat16) e_g = tl.exp2(g2) b_kbg = (b_k * e_g * b_beta[:, None]).to(tl.bfloat16) b_w = tl.dot(b_M, b_kbg) wukg_rows = (i_bh * T + i_t * BT + o_t) tl.store(w + wukg_rows[:, None] * K + o_k[None, :], b_w.to(tl.bfloat16)) b_v = tl.load(v + v_rows[:, None] + o_v[None, :]) b_vb = (b_v.to(tl.float32) * b_beta[:, None]).to(tl.bfloat16) b_u = tl.dot(b_M, b_vb) tl.store(u + wukg_rows[:, None] * V + o_v[None, :], b_u.to(tl.bfloat16)) b_kg = b_k * tl.exp2(gl[None, :] - g2) tl.store(kg + wukg_rows[:, None] * K + o_k[None, :], b_kg.to(tl.bfloat16)) b_qg = b_q * (scale * e_g) tl.store(qg + wukg_rows[:, None] * K + o_k[None, :], b_qg.to(tl.bfloat16)) @triton.jit def kda_h_kernel( w, u, kg, gexp, hbuf, vnew, T, NT, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, ): i_v, i_bh = tl.program_id(0), tl.program_id(1) o_t = tl.arange(0, BT) o_k = tl.arange(0, K) o_v = i_v * BV + tl.arange(0, BV) b_h = tl.zeros([K, BV], dtype=tl.float32) for n in range(0, NT): rows = i_bh * T + n * BT + o_t # store state at chunk start h_ptr = hbuf + (i_bh * NT + n) * K * V + o_k[:, None] * V + o_v[None, :] b_hb = b_h.to(tl.bfloat16) tl.store(h_ptr, b_hb) b_w = tl.load(w + rows[:, None] * K + o_k[None, :]) b_u = tl.load(u + rows[:, None] * V + o_v[None, :]).to(tl.float32) b_vn = b_u - tl.dot(b_w, b_hb) b_vnb = b_vn.to(tl.bfloat16) tl.store(vnew + rows[:, None] * V + o_v[None, :], b_vnb) b_kg = tl.load(kg + rows[:, None] * K + o_k[None, :]) b_gexp = tl.load(gexp + (i_bh * NT + n) * K + o_k) b_h = b_h * b_gexp[:, None] + tl.dot(tl.trans(b_kg), b_vnb) @triton.jit def kda_o_kernel( qg, Aqk, hbuf, vnew, o, T, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H NT = T // BT o_t = tl.arange(0, BT) o_k = tl.arange(0, K) o_v = tl.arange(0, V) rows = i_bh * T + i_t * BT + o_t b_qg = tl.load(qg + rows[:, None] * K + o_k[None, :]) h_ptr = hbuf + (i_bh * NT + i_t) * K * V + o_k[:, None] * V + o_v[None, :] b_h = tl.load(h_ptr) b_o = tl.dot(b_qg, b_h) aqk_base = Aqk + (i_bh * NT + i_t) * BT * BT b_A = tl.load(aqk_base + o_t[:, None] * BT + o_t[None, :]) b_vn = tl.load(vnew + rows[:, None] * V + o_v[None, :]) b_o += tl.dot(b_A, b_vn) tok = i_b * T + i_t * BT out_rows = ((tok + o_t) * H + i_h) * V tl.store(o + out_rows[:, None] + o_v[None, :], b_o.to(tl.bfloat16)) class _Workspace: def __init__(self, B, T, H, K, V, BT, device): NT = T // BT BH = B * H self.Aqk = torch.empty(BH * NT, BT, BT, dtype=torch.bfloat16, device=device) self.w = torch.empty(BH * T, K, dtype=torch.bfloat16, device=device) self.u = torch.empty(BH * T, V, dtype=torch.bfloat16, device=device) self.kg = torch.empty(BH * T, K, dtype=torch.bfloat16, device=device) self.qg = torch.empty(BH * T, K, dtype=torch.bfloat16, device=device) self.gexp = torch.empty(BH * NT, K, dtype=torch.float32, device=device) self.SA = torch.empty(BH * NT, BT, BT, dtype=torch.float32, device=device) self.SM = torch.empty(BH * NT, BT, BT, dtype=torch.float32, device=device) self.hbuf = torch.empty(BH * NT, K, V, dtype=torch.bfloat16, device=device) self.vnew = torch.empty(BH * T, V, dtype=torch.bfloat16, device=device) self.o = torch.empty(B, T, H, V, dtype=torch.bfloat16, device=device) def kda_fwd(q, k, v, g, beta, scale, BT, ws): B, T, H, K = q.shape V = v.shape[-1] NT = T // BT BH = B * H kda_prep_kernel[(NT, BH)]( q, k, v, g, beta, ws.Aqk, ws.w, ws.u, ws.kg, ws.qg, ws.gexp, ws.SA, ws.SM, scale, T, H=H, K=K, V=V, BT=BT, num_warps=8, ) BV = 64 kda_h_kernel[(V // BV, BH)]( ws.w, ws.u, ws.kg, ws.gexp, ws.hbuf, ws.vnew, T, NT, K=K, V=V, BT=BT, BV=BV, num_warps=4, ) kda_o_kernel[(NT, BH)]( ws.qg, ws.Aqk, ws.hbuf, ws.vnew, ws.o, T, H=H, K=K, V=V, BT=BT, num_warps=8, ) return ws.o class Model(nn.Module): """KDA forward (chunk form). No learned parameters; all inputs are activations.""" def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) self._ws = {} def _workspace(self, B, T, H, K, V, device): key = (B, T, H, K, V, device) ws = self._ws.get(key) if ws is None: ws = _Workspace(B, T, H, K, V, self.chunk_size, device) self._ws[key] = ws return ws def forward(self, q, k, v, g, beta): B, T, H, K = q.shape V = v.shape[-1] ws = self._workspace(B, T, H, K, V, q.device) return kda_fwd(q, k, v, g, beta, self.scale, self.chunk_size, ws) # Module-level shape shims (overridden by check.py / benchmark.py per shape). B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]