"""Kimi Delta Attention (chunk form) forward, Triton implementation. Multi-kernel design: kernel 1 (wu_kernel): per-chunk w, u compute (parallel over B*H*NT) kernel 2 (aqk_kernel): per-chunk Aqk compute (parallel over B*H*NT) kernel 3 (o_kernel): inter-chunk output pass with V-tiling (one program per B*H, sequential over NT) """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl # ------------------------------------------------------------------------- # Kernel 1: per-chunk w, u compute. # Uses IN-PLACE Neumann to match the reference's fp32 precision. # ------------------------------------------------------------------------- @triton.jit def _wu_kernel( K_ptr, V_ptr, G_ptr, BETA_ptr, W_ptr, U_ptr, stride_kb, stride_kt, stride_kh, stride_kk, stride_vb, stride_vt, stride_vh, stride_vv, stride_gb, stride_gt, stride_gh, stride_gk, stride_bb, stride_bt, stride_bh, stride_wb, stride_wt, stride_wh, stride_wk, stride_ub, stride_ut, stride_uh, stride_uv, H: tl.constexpr, NT: tl.constexpr, BT: tl.constexpr, K_C: tl.constexpr, V_C: tl.constexpr, ): pid_bh = tl.program_id(0) chunk_idx = tl.program_id(1) b = pid_bh // H h = pid_bh % H i_idx = tl.arange(0, BT) j_idx = tl.arange(0, BT) k_idx = tl.arange(0, K_C) v_idx = tl.arange(0, V_C) mask_lt = i_idx[:, None] > j_idx[None, :] eye_mask = i_idx[:, None] == j_idx[None, :] t_start = chunk_idx * BT k_off = ( K_ptr + b * stride_kb + h * stride_kh + (t_start + i_idx)[:, None] * stride_kt + k_idx[None, :] * stride_kk ) k = tl.load(k_off).to(tl.float32) v_off = ( V_ptr + b * stride_vb + h * stride_vh + (t_start + i_idx)[:, None] * stride_vt + v_idx[None, :] * stride_vv ) v = tl.load(v_off).to(tl.float32) g_off = ( G_ptr + b * stride_gb + h * stride_gh + (t_start + i_idx)[:, None] * stride_gt + k_idx[None, :] * stride_gk ) g = tl.load(g_off) g = tl.cumsum(g, axis=0) beta_off = ( BETA_ptr + b * stride_bb + h * stride_bh + (t_start + i_idx) * stride_bt ) beta = 
tl.load(beta_off).to(tl.float32) g_exp = tl.exp(g) g_neg_exp = tl.exp(-g) K_ng = k * g_neg_exp T = k * g_exp # A = -K_ng @ T.T, strict lower triangular, multiplied by beta on rows A = tl.dot(K_ng, tl.trans(T), input_precision="ieee") A = tl.where(mask_lt, -A, 0.0) A = A * beta[:, None] # In-place Neumann: A[i, :i] += A[i, :] @ A[:, :i] for i = 1..BT-1 # This computes A + A^2 + A^3 + ... + A^{i-1} for the i-th row. # After this, A[i, j] for j < i is sum_{k=1}^{i-j} A^k[i, j] (the partial sum up to length i-j). for i in tl.static_range(1, BT): # Extract row i A_row_i = tl.sum(tl.where(i_idx[None, :] == i, A, 0.0), axis=1) # [BT] # Compute matvec A_row_i @ A (note: A is being updated) # The matvec gives [BT], where entry j is sum_l A_row_i[l] * A[l, j]. # For j >= i, this is 0 (since A is strictly lower). # For j < i, this is the update we want to add to A[i, j]. update = tl.sum(A_row_i[:, None] * A, axis=0) # [BT] # Add to row i new_row_i = A_row_i + update # Update A: replace row i A = tl.where(i_idx[None, :] == i, new_row_i[None, :], A) # Add I and multiply by beta on columns A = A + tl.where(eye_mask, 1.0, 0.0) A = A * beta[None, :] w = tl.dot(A, T, input_precision="ieee") u = tl.dot(A, v, input_precision="ieee") w_off = ( W_ptr + b * stride_wb + h * stride_wh + (t_start + i_idx)[:, None] * stride_wt + k_idx[None, :] * stride_wk ) tl.store(w_off, w) u_off = ( U_ptr + b * stride_ub + h * stride_uh + (t_start + i_idx)[:, None] * stride_ut + v_idx[None, :] * stride_uv ) tl.store(u_off, u) # ------------------------------------------------------------------------- # Kernel 2: per-chunk Aqk compute. 
# -------------------------------------------------------------------------
# One program per (batch*head, chunk):
#     Aqk[i, j] = scale * sum_d q_i[d] * k_j[d] * exp(g_i[d] - g_j[d])   (j <= i)
#     Aqk[i, j] = 0                                                        (j >  i)
# The diagonal is kept on purpose: output row i must include token i's own
# write, and the inter-chunk pass only supplies state from previous chunks.
@triton.jit
def _aqk_kernel(
    Q_ptr, K_ptr, G_ptr, AQK_ptr,
    scale,
    stride_qb, stride_qt, stride_qh, stride_qk,
    stride_kb, stride_kt, stride_kh, stride_kk,
    stride_gb, stride_gt, stride_gh, stride_gk,
    stride_ab, stride_at, stride_ah, stride_aq,
    H: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    K_C: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    chunk_idx = tl.program_id(1)
    b = pid_bh // H
    h = pid_bh % H

    i_idx = tl.arange(0, BT)
    j_idx = tl.arange(0, BT)
    k_idx = tl.arange(0, K_C)
    mask_causal = i_idx[:, None] >= j_idx[None, :]  # lower triangle incl. diagonal

    rows = (chunk_idx * BT + i_idx)[:, None]  # [BT, 1] time indices of this chunk

    q = tl.load(
        Q_ptr + b * stride_qb + h * stride_qh
        + rows * stride_qt + k_idx[None, :] * stride_qk
    ).to(tl.float32) * scale
    k = tl.load(
        K_ptr + b * stride_kb + h * stride_kh
        + rows * stride_kt + k_idx[None, :] * stride_kk
    ).to(tl.float32)
    # Cast before the cumsum so a low-precision gate does not accumulate error.
    g = tl.load(
        G_ptr + b * stride_gb + h * stride_gh
        + rows * stride_gt + k_idx[None, :] * stride_gk
    ).to(tl.float32)
    g = tl.cumsum(g, axis=0)

    # Query (row) side carries exp(+g_i), key (column) side exp(-g_j), so entry
    # (i, j) carries the in-chunk decay exp(g_i - g_j).
    q_pos = q * tl.exp(g)
    k_neg = k * tl.exp(-g)

    Aqk = tl.dot(q_pos, tl.trans(k_neg), input_precision="ieee")
    Aqk = tl.where(mask_causal, Aqk, 0.0)

    tl.store(
        AQK_ptr + b * stride_ab + h * stride_ah
        + rows * stride_at + j_idx[None, :] * stride_aq,
        Aqk,
    )


# -------------------------------------------------------------------------
# Kernel 3: inter-chunk output pass with V-tiling.
# -------------------------------------------------------------------------
# One program per (batch*head, V-tile); chunks are processed sequentially.
# Each program owns the [K_C, BV] slice S[:, tile] of the recurrent state.
# V-tiles never interact (every update below is column-wise in v), so they run
# as independent programs. Per chunk:
#     v_new = u - w @ S
#     o     = (q * exp(g)) @ S + Aqk @ v_new
#     S     = exp(g_last)[:, None] * S + (k * exp(g_last - g))^T @ v_new
@triton.jit
def _o_kernel(
    Q_ptr, K_ptr, G_ptr, W_ptr, U_ptr, AQK_ptr, O_ptr,
    scale,
    stride_qb, stride_qt, stride_qh, stride_qk,
    stride_kb, stride_kt, stride_kh, stride_kk,
    stride_gb, stride_gt, stride_gh, stride_gk,
    stride_wb, stride_wt, stride_wh, stride_wk,
    stride_ub, stride_ut, stride_uh, stride_uv,
    stride_ab, stride_at, stride_ah, stride_aq,
    stride_ob, stride_ot, stride_oh, stride_ov,
    H: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    K_C: tl.constexpr,
    BV: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_v = tl.program_id(1)
    b = pid_bh // H
    h = pid_bh % H

    i_idx = tl.arange(0, BT)
    k_idx = tl.arange(0, K_C)
    v_idx = pid_v * BV + tl.arange(0, BV)  # V columns owned by this program

    # Per-(b, h) base pointers; only the time offset changes per chunk.
    q_base = Q_ptr + b * stride_qb + h * stride_qh
    k_base = K_ptr + b * stride_kb + h * stride_kh
    g_base = G_ptr + b * stride_gb + h * stride_gh
    w_base = W_ptr + b * stride_wb + h * stride_wh
    u_base = U_ptr + b * stride_ub + h * stride_uh
    a_base = AQK_ptr + b * stride_ab + h * stride_ah
    o_base = O_ptr + b * stride_ob + h * stride_oh

    S = tl.zeros((K_C, BV), dtype=tl.float32)
    for chunk_i in range(NT):
        rows = (chunk_i * BT + i_idx)[:, None]  # [BT, 1]

        q = tl.load(
            q_base + rows * stride_qt + k_idx[None, :] * stride_qk
        ).to(tl.float32) * scale
        k = tl.load(
            k_base + rows * stride_kt + k_idx[None, :] * stride_kk
        ).to(tl.float32)
        g = tl.load(
            g_base + rows * stride_gt + k_idx[None, :] * stride_gk
        ).to(tl.float32)
        g = tl.cumsum(g, axis=0)
        w = tl.load(w_base + rows * stride_wt + k_idx[None, :] * stride_wk)
        u = tl.load(u_base + rows * stride_ut + v_idx[None, :] * stride_uv)
        Aqk = tl.load(a_base + rows * stride_at + i_idx[None, :] * stride_aq)

        # Cumulative gate at the last token of the chunk, [K_C].
        g_last = tl.sum(tl.where(i_idx[:, None] == BT - 1, g, 0.0), axis=0)
        qg = q * tl.exp(g)
        k_decay = k * tl.exp(g_last[None, :] - g)

        v_new = u - tl.dot(w, S, input_precision="ieee")
        o = tl.dot(qg, S, input_precision="ieee") + tl.dot(
            Aqk, v_new, input_precision="ieee"
        )
        S = S * tl.exp(g_last)[:, None] + tl.dot(
            tl.trans(k_decay), v_new, input_precision="ieee"
        )

        # Store in the output buffer's own dtype instead of a fixed bf16.
        tl.store(
            o_base + rows * stride_ot + v_idx[None, :] * stride_ov,
            o.to(O_ptr.dtype.element_ty),
        )


def kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Chunked Kimi Delta Attention forward pass.

    Args:
        q, k: [B, T, H, K] queries / keys.
        v: [B, T, H, V] values; also determines the output dtype.
        g: [B, T, H, K] per-token, per-key gate. Inside the kernels each
            chunk's values are cumulatively summed over time and exponentiated.
        beta: [B, T, H] per-token scalar weight.
        scale: multiplier applied to q.
        chunk_size: tokens per chunk (BT).

    Returns:
        [B, T, H, V] tensor with dtype ``v.dtype``.

    Raises:
        ValueError: on inconsistent input shapes, or if chunk_size, K or V is
            not a power of two >= 16 (required by ``tl.arange`` / ``tl.dot``).

    T need not be a multiple of chunk_size: the time axis is zero-padded
    internally and the padding is stripped from the result.
    """
    B, T, H, K_dim = q.shape
    V_dim = v.shape[-1]

    if k.shape != q.shape or g.shape != q.shape:
        raise ValueError(
            f"q, k and g must share shape [B, T, H, K]; got {tuple(q.shape)}, "
            f"{tuple(k.shape)}, {tuple(g.shape)}"
        )
    if v.dim() != 4 or v.shape[:3] != q.shape[:3]:
        raise ValueError(
            f"v must have shape [{B}, {T}, {H}, V]; got {tuple(v.shape)}"
        )
    if beta.shape != q.shape[:3]:
        raise ValueError(
            f"beta must have shape [{B}, {T}, {H}]; got {tuple(beta.shape)}"
        )
    for name, n in (("chunk_size", chunk_size), ("K", K_dim), ("V", V_dim)):
        if n < 16 or n & (n - 1):
            raise ValueError(f"{name}={n} must be a power of two >= 16")

    device = q.device
    dtype = v.dtype
    if B * T * H == 0:
        return torch.empty(B, T, H, V_dim, dtype=dtype, device=device)

    # Zero-pad time up to a multiple of chunk_size. The recurrence is causal,
    # so padded tokens (k = v = g = beta = 0) cannot affect any real position;
    # their outputs are sliced away at the end.
    pad = (-T) % chunk_size
    if pad:
        q, k, v, g = [
            nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) for x in (q, k, v, g)
        ]
        beta = nn.functional.pad(beta, (0, 0, 0, pad))
    T_pad = T + pad
    NT = T_pad // chunk_size
    # V tile per program; V is a power of two >= 16, so this divides it.
    BV = min(64, V_dim)

    w_buf = torch.empty(B, T_pad, H, K_dim, dtype=torch.float32, device=device)
    u_buf = torch.empty(B, T_pad, H, V_dim, dtype=torch.float32, device=device)
    aqk_buf = torch.empty(
        B, T_pad, H, chunk_size, dtype=torch.float32, device=device
    )
    o = torch.empty(B, T_pad, H, V_dim, dtype=dtype, device=device)

    _wu_kernel[(B * H, NT)](
        k, v, g, beta, w_buf, u_buf,
        *k.stride(), *v.stride(), *g.stride(), *beta.stride(),
        *w_buf.stride(), *u_buf.stride(),
        H=H, NT=NT, BT=chunk_size, K_C=K_dim, V_C=V_dim,
        num_warps=4, num_stages=1,
    )
    _aqk_kernel[(B * H, NT)](
        q, k, g, aqk_buf,
        scale,
        *q.stride(), *k.stride(), *g.stride(), *aqk_buf.stride(),
        H=H, NT=NT, BT=chunk_size, K_C=K_dim,
        num_warps=4, num_stages=1,
    )
    _o_kernel[(B * H, V_dim // BV)](
        q, k, g, w_buf, u_buf, aqk_buf, o,
        scale,
        *q.stride(), *k.stride(), *g.stride(),
        *w_buf.stride(), *u_buf.stride(), *aqk_buf.stride(), *o.stride(),
        H=H, NT=NT, BT=chunk_size, K_C=K_dim, BV=BV,
        num_warps=2, num_stages=1,
    )
    return o[:, :T].contiguous() if pad else o


class Model(nn.Module):
    """nn.Module wrapper around ``kda_fwd``.

    Args:
        B, T, H, K, V: problem sizes, stored as attributes. ``forward`` takes
            the actual sizes from its input tensors; only K feeds ``scale``.
        chunk_size: tokens per chunk forwarded to ``kda_fwd``.
    """

    def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64):
        super().__init__()
        self.B, self.T, self.H, self.K, self.V = B, T, H, K, V
        self.chunk_size = chunk_size
        # 1/sqrt(K) scale applied to q.
        self.scale = float(K) ** -0.5
        # NOTE(review): not read by forward; presumably kept so the module owns
        # a buffer (e.g. for device tracking) -- confirm before removing.
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        return kda_fwd(q, k, v, g, beta, scale=self.scale, chunk_size=self.chunk_size)


def get_inputs():
    """Return a seeded example input list ``[q, k, v, g, beta]`` (CPU tensors).

    Shapes: q, k, v, g are [2, 1024, 8, 128]; beta is [2, 1024, 8].
    """
    torch.manual_seed(0)
    q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16) * 0.1
    k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16) * 0.1
    v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16) * 0.1
    g = (torch.randn(2, 1024, 8, 128, dtype=torch.float32) * 0.1 - 0.05)
    beta = torch.sigmoid(torch.randn(2, 1024, 8, dtype=torch.bfloat16))
    return [q, k, v, g, beta]


def get_init_inputs():
    """Return ``Model`` constructor args: [B, T, H, K, V, chunk_size]."""
    return [2, 1024, 8, 128, 128, 64]