"""Kimi Delta Attention forward (chunk form) — custom Triton kernels for SM120. Chunk-parallel design (FLA-style 3-kernel decomposition), all bf16 tensor cores, launched once via a CUDA graph bound to the input tensors: Kernel A (prepare, parallel over b*h*chunk): build the intra-chunk WY transform. - gc = cumsum(g) within chunk (done as a lower-triangular ones matmul on TCs) - A0 = strict-lower(-beta[c] * (k*e^gc) @ (k*e^-gc)^T) - Tinv = (I - A0)^{-1} via Neumann doubling (ITERS iters; A0 nilpotent and its high powers are decay-suppressed, so 4 iters covers tolerance with wide margin) - w = Tinv @ (beta * e^gc * k), u = Tinv @ (beta * v) - also precomputes kd^T = (e^last * k * e^-gc)^T and decay = e^last so the scan's hot path carries no exp/cumsum/transpose. Kernel B (state scan, parallel over b*h*v-block, sequential over chunks): the only sequential pass. Keeps recurrent state S (K x BV) in registers, emits per-chunk start state h_n and the corrected values v_new_n = u_n - w_n @ h_n. Kernel C (output, fully parallel over b*h*chunk*v-block): the heavy compute. - o = (q*scale*e^gc) @ h_n + tril(Aqk) @ v_new_n """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl @triton.jit def _kda_prepare_kernel( k_ptr, v_ptr, g_ptr, beta_ptr, w_ptr, u_ptr, kd_ptr, decay_ptr, B, T, H, NT, K: tl.constexpr, V: tl.constexpr, C: tl.constexpr, ITERS: tl.constexpr, ): pid_n = tl.program_id(0) pid_bh = tl.program_id(1) b = pid_bh // H h = pid_bh % H sb_k = T * H * K sb_v = T * H * V base_k = k_ptr + b * sb_k + h * K base_g = g_ptr + b * sb_k + h * K base_v = v_ptr + b * sb_v + h * V base_w = w_ptr + b * sb_k + h * K base_u = u_ptr + b * sb_v + h * V p_k = tl.make_block_ptr(base_k, (T, K), (H * K, 1), (pid_n * C, 0), (C, K), (1, 0)) p_g = tl.make_block_ptr(base_g, (T, K), (H * K, 1), (pid_n * C, 0), (C, K), (1, 0)) p_v = tl.make_block_ptr(base_v, (T, V), (H * V, 1), (pid_n * C, 0), (C, V), (1, 0)) k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) offs_c = tl.arange(0, C) p_beta = beta_ptr + b * (T * H) + (pid_n * C + offs_c) * H + h beta = tl.load(p_beta).to(tl.float32) Ltri = tl.where(offs_c[:, None] >= offs_c[None, :], 1.0, 0.0) gc = tl.dot(Ltri, g, input_precision="tf32") # cumulative sum via tri-matmul last = tl.sum(g, axis=0) # gc at last row (K,) egc = tl.exp(gc) inv_egc = 1.0 / egc # = exp(-gc) decay_vec = tl.exp(last) # (K,) kg = k * egc kng = k * inv_egc Kgg = tl.dot(kg.to(tl.bfloat16), tl.trans(kng).to(tl.bfloat16)) # (C, C) row = offs_c[:, None] col = offs_c[None, :] A0 = tl.where(row > col, -beta[:, None] * Kgg, 0.0) M = tl.where(row == col, 1.0, 0.0) P = A0 for i in tl.static_range(ITERS): M = M + tl.dot(P.to(tl.bfloat16), M.to(tl.bfloat16)) if i < ITERS - 1: P = tl.dot(P.to(tl.bfloat16), P.to(tl.bfloat16)) beta_kg = (beta[:, None] * kg).to(tl.bfloat16) beta_v = (beta[:, None] * v).to(tl.bfloat16) Mb = M.to(tl.bfloat16) w = tl.dot(Mb, beta_kg) u = tl.dot(Mb, beta_v) # state-scan precompute: kd = e^(last-gc)*k = e^last * (k*e^-gc) = decay * kng # store transposed (K, C) so the sequential scan avoids tl.trans on its hot path kdt = tl.trans(decay_vec[None, :] * kng) # (K, C) p_w = tl.make_block_ptr(base_w, (T, K), (H * K, 1), (pid_n * C, 0), (C, K), (1, 0)) p_u = tl.make_block_ptr(base_u, (T, V), (H * V, 1), (pid_n * C, 0), (C, V), (1, 0)) kdt_base = kd_ptr + (pid_bh * NT + pid_n) * K * C p_kdt = tl.make_block_ptr(kdt_base, (K, C), (C, 1), (0, 0), (K, C), (1, 0)) tl.store(p_w, w.to(w_ptr.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_u, u.to(u_ptr.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_kdt, kdt.to(kd_ptr.dtype.element_ty), boundary_check=(0, 1)) offs_k = tl.arange(0, K) tl.store(decay_ptr + (pid_bh * NT + pid_n) * K + offs_k, decay_vec) @triton.jit def _kda_state_kernel( w_ptr, u_ptr, kd_ptr, decay_ptr, h_ptr, vnew_ptr, B, T, H, NT, K: tl.constexpr, V: tl.constexpr, C: tl.constexpr, BV: tl.constexpr, ): pid_bh = tl.program_id(0) pid_v = tl.program_id(1) b = pid_bh // H h = pid_bh % H v0 = pid_v * BV sb_k = T * H * K sb_v = T * H * V base_w = w_ptr + b * sb_k + h * K base_u = u_ptr + b * sb_v + h * V base_vn = vnew_ptr + b * sb_v + h * V sb_h = H * NT * K * V offs_k = tl.arange(0, K) S = tl.zeros((K, BV), dtype=tl.float32) for n in range(NT): toff = n * C p_w = tl.make_block_ptr(base_w, (T, K), (H * K, 1), (toff, 0), (C, K), (1, 0)) kdt_base = kd_ptr + (pid_bh * NT + n) * K * C p_kdt = tl.make_block_ptr(kdt_base, (K, C), (C, 1), (0, 0), (K, C), (1, 0)) p_u = tl.make_block_ptr(base_u, (T, V), (H * V, 1), (toff, v0), (C, BV), (1, 0)) w = tl.load(p_w, boundary_check=(0, 1)) kdt = tl.load(p_kdt, boundary_check=(0, 1)) u = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32) decay = tl.load(decay_ptr + (pid_bh * NT + n) * K + offs_k) # store start-of-chunk state h_n h_base = h_ptr + (b * sb_h + (h * NT + n) * K * V) p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v0), (K, BV), (1, 0)) tl.store(p_h, S.to(h_ptr.dtype.element_ty), boundary_check=(0, 1)) v_new = u - tl.dot(w, S.to(w.dtype), input_precision="tf32") p_vn = tl.make_block_ptr(base_vn, (T, V), (H * V, 1), (toff, v0), (C, BV), (1, 0)) tl.store(p_vn, v_new.to(vnew_ptr.dtype.element_ty), boundary_check=(0, 1)) S = decay[:, None] * S + tl.dot(kdt, v_new.to(kdt.dtype), input_precision="tf32") @triton.jit def _kda_output_kernel( q_ptr, k_ptr, g_ptr, h_ptr, vnew_ptr, o_ptr, B, T, H, NT, scale, K: tl.constexpr, V: tl.constexpr, C: tl.constexpr, BV: tl.constexpr, ): pid_bh = tl.program_id(0) pid_n = tl.program_id(1) pid_v = tl.program_id(2) b = pid_bh // H h = pid_bh % H v0 = pid_v * BV toff = pid_n * C sb_k = T * H * K sb_v = T * H * V base_q = q_ptr + b * sb_k + h * K base_k = k_ptr + b * sb_k + h * K base_g = g_ptr + b * sb_k + h * K base_vn = vnew_ptr + b * sb_v + h * V base_o = o_ptr + b * sb_v + h * V sb_h = H * NT * K * V p_q = tl.make_block_ptr(base_q, (T, K), (H * K, 1), (toff, 0), (C, K), (1, 0)) p_k = tl.make_block_ptr(base_k, (T, K), (H * K, 1), (toff, 0), (C, K), (1, 0)) p_g = tl.make_block_ptr(base_g, (T, K), (H * K, 1), (toff, 0), (C, K), (1, 0)) q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) offs_c = tl.arange(0, C) Ltri = tl.where(offs_c[:, None] >= offs_c[None, :], 1.0, 0.0) gc = tl.dot(Ltri, g, input_precision="tf32") egc = tl.exp(gc) qg = (q * scale) * egc kng = k * (1.0 / egc) qgb = qg.to(tl.bfloat16) Aqk = tl.dot(qgb, tl.trans(kng).to(tl.bfloat16)) Aqk = tl.where(offs_c[:, None] >= offs_c[None, :], Aqk, 0.0).to(tl.bfloat16) h_base = h_ptr + (b * sb_h + (h * NT + pid_n) * K * V) p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v0), (K, BV), (1, 0)) h_state = tl.load(p_h, boundary_check=(0, 1)) p_vn = tl.make_block_ptr(base_vn, (T, V), (H * V, 1), (toff, v0), (C, BV), (1, 0)) v_new = tl.load(p_vn, boundary_check=(0, 1)) o = tl.dot(qgb, h_state) + tl.dot(Aqk, v_new) p_o = tl.make_block_ptr(base_o, (T, V), (H * V, 1), (toff, v0), (C, BV), (1, 0)) tl.store(p_o, o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) import os as _os _WA = int(_os.environ.get("KDA_WA", "8")); _SA = int(_os.environ.get("KDA_SA", "2")) _NITER = int(_os.environ.get("KDA_NITER", "4")) _WB = int(_os.environ.get("KDA_WB", "4")); _SB = int(_os.environ.get("KDA_SB", "2")) _WC = int(_os.environ.get("KDA_WC", "4")); _SC = int(_os.environ.get("KDA_SC", "2")) def _launch(bufs, B, T, H, K, V, C, NT, scale, BV_STATE, BV_OUT): q, k, v, g, beta, w, u, kd, decay, vnew, hstates, o = bufs _kda_prepare_kernel[(NT, B * H)]( k, v, g, beta, w, u, kd, decay, B, T, H, NT, K, V, C, _NITER, num_warps=_WA, num_stages=_SA, ) _kda_state_kernel[(B * H, V // BV_STATE)]( w, u, kd, decay, hstates, vnew, B, T, H, NT, K, V, C, BV_STATE, num_warps=_WB, num_stages=_SB, ) _kda_output_kernel[(B * H, NT, V // BV_OUT)]( q, k, g, hstates, vnew, o, B, T, H, NT, scale, K, V, C, BV_OUT, num_warps=_WC, num_stages=_SC, ) class Model(nn.Module): def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) self._graph = None # None=not tried, False=disabled, else CUDAGraph self._inter = None self._cap_ptrs = None import os self.BV_STATE = int(os.environ.get("KDA_BVS", "16")) nprog_out = B * H * (T // chunk_size) default_bvo = 64 if nprog_out <= 256 else 128 self.BV_OUT = int(os.environ.get("KDA_BVO", str(default_bvo))) def _alloc(self, device): B, T, H, K, V = self.B, self.T, self.H, self.K, self.V C = self.chunk_size NT = T // C bf = torch.bfloat16 w = torch.empty((B, T, H, K), dtype=bf, device=device) u = torch.empty((B, T, H, V), dtype=bf, device=device) kd = torch.empty((B, T, H, K), dtype=bf, device=device) decay = torch.empty((B * H, NT, K), dtype=torch.float32, device=device) vnew = torch.empty((B, T, H, V), dtype=bf, device=device) hstates = torch.empty((B, H, NT, K, V), dtype=bf, device=device) o = torch.empty((B, T, H, V), dtype=bf, device=device) self._inter = (w, u, kd, decay, vnew, hstates, o) self._out = o self._launch_args = (B, T, H, K, V, C, NT, self.scale, self.BV_STATE, self.BV_OUT) def _capture(self, q, k, v, g, beta): bufs = (q, k, v, g, beta, *self._inter) s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): _launch(bufs, *self._launch_args) torch.cuda.current_stream().wait_stream(s) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): _launch(bufs, *self._launch_args) self._graph = graph def forward(self, q, k, v, g, beta): if self._inter is None: self._alloc(q.device) ptrs = (q.data_ptr(), k.data_ptr(), v.data_ptr(), g.data_ptr(), beta.data_ptr()) if self._graph not in (None, False) and ptrs == self._cap_ptrs: self._graph.replay() return self._out if self._graph is None: try: self._capture(q, k, v, g, beta) self._cap_ptrs = ptrs self._graph.replay() return self._out except Exception: self._graph = False _launch((q, k, v, g, beta, *self._inter), *self._launch_args) return self._out