"""Optimized Kimi Delta Attention (KDA) forward, chunk form.""" from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl from einops import rearrange OP_TYPE = "linear_attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] @triton.jit def recurrence_kernel( Q_ptr, K_ptr, G_ptr, U_ptr, W_ptr, Aqk_ptr, O_ptr, stride_qb, stride_qh, stride_qn, stride_qc, stride_qd, stride_kb, stride_kh, stride_kn, stride_kc, stride_kd, stride_gb, stride_gh, stride_qn_g, stride_qc_g, stride_qd_g, stride_ub, stride_uh, stride_qn_u, stride_qc_u, stride_qd_u, stride_wb, stride_wh, stride_qn_w, stride_qc_w, stride_qd_w, stride_aqkb, stride_aqkh, stride_qn_aqk, stride_qc_aqk, stride_qd_aqk, stride_ob, stride_oh, stride_on, stride_oc, stride_od, B, H, NT, BT: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr ): pid = tl.program_id(0) b = pid // H h = pid % H # Offsets for K and V dimensions offs_k = tl.arange(0, DK) offs_v = tl.arange(0, DV) offs_c = tl.arange(0, BT) # Initialize S to zeros in registers S = tl.zeros((DK, DV), dtype=tl.float32) for i in range(0, NT): # Compute pointers for current chunk i q_offset = b * stride_qb + h * stride_qh + i * stride_qn k_offset = b * stride_kb + h * stride_kh + i * stride_kn g_offset = b * stride_gb + h * stride_gh + i * stride_qn_g u_offset = b * stride_ub + h * stride_uh + i * stride_qn_u w_offset = b * stride_wb + h * stride_wh + i * stride_qn_w aqk_offset = b * stride_aqkb + h * stride_aqkh + i * stride_qn_aqk # Load Q, K, G, U, W, Aqk for chunk i (loaded in bf16 and cast to fp32) Q = tl.load(Q_ptr + q_offset + offs_c[:, None] * stride_qc + offs_k[None, :] * stride_qd).to(tl.float32) K = tl.load(K_ptr + k_offset + offs_c[:, None] * stride_kc + offs_k[None, :] * stride_kd).to(tl.float32) G = tl.load(G_ptr + g_offset + offs_c[:, None] * stride_qc_g + offs_k[None, :] * stride_qd_g).to(tl.float32) U = tl.load(U_ptr + u_offset + offs_c[:, None] * stride_qc_u + offs_v[None, :] * stride_qd_u).to(tl.float32) W = tl.load(W_ptr + w_offset + offs_c[:, None] * stride_qc_w + offs_k[None, :] * stride_qd_w).to(tl.float32) Aqk = tl.load(Aqk_ptr + aqk_offset + offs_c[:, None] * stride_qc_aqk + offs_c[None, :] * stride_qd_aqk).to(tl.float32) # 1. v_i = u_i - w_i @ S W_S = tl.dot(W, S) V_i = U - W_S # 2. o_i = (q_i * exp(g_i)) @ S + Aqk @ v_i Q_decayed = Q * tl.exp(G) Q_S = tl.dot(Q_decayed, S) Aqk_V = tl.dot(Aqk, V_i) O_i = Q_S + Aqk_V # Store O_i to DRAM (cast to bf16) o_offset = b * stride_ob + h * stride_oh + i * stride_on tl.store(O_ptr + o_offset + offs_c[:, None] * stride_oc + offs_v[None, :] * stride_od, O_i.to(tl.bfloat16)) # 3. S_new = S * decay + update g_last = tl.load(G_ptr + g_offset + (BT - 1) * stride_qc_g + offs_k * stride_qd_g).to(tl.float32) decay = tl.exp(g_last)[:, None] S = S * decay g_last_expanded = g_last[None, :] k_decayed = tl.exp(g_last_expanded - G) * K k_decayed_T = tl.trans(k_decayed) update = tl.dot(k_decayed_T, V_i) S = S + update def _triton_recurrence(q, k, g, u, w, Aqk, NT, BT, K, V, dtype): B, H = q.shape[0], q.shape[1] o = torch.empty((B, H, NT, BT, V), dtype=dtype, device=q.device) grid = (B * H,) recurrence_kernel[grid]( q, k, g, u, w, Aqk, o, *q.stride(), *k.stride(), *g.stride(), *u.stride(), *w.stride(), *Aqk.stride(), *o.stride(), B, H, NT, BT, K, V, num_stages=1 ) return o @torch.compile(mode="reduce-overhead", fullgraph=True) def _intra_chunk_pass(q, k, v, g, beta, scale, BT): # Keep activations in bfloat16 dtype = q.dtype B, T, H, K = q.shape V = v.shape[-1] NT = T // BT # Scale query q = q * scale q = rearrange(q, "b (n c) h d -> b h n c d", c=BT) k = rearrange(k, "b (n c) h d -> b h n c d", c=BT) v = rearrange(v, "b (n c) h d -> b h n c d", c=BT) g = rearrange(g, "b (n c) h d -> b h n c d", c=BT) beta = rearrange(beta, "b (n c) h -> b h n c", c=BT) # In-chunk cumsum on g (keeps as fp32) g = g.cumsum(-2) # Convert exponential decays to bfloat16 for matrix operations g_exp = g.exp().to(dtype) g_neg_exp = (-g).exp().to(dtype) # ---- Build A_kk (intra-chunk K-K interaction, lower-triangular w/ diag masked) ---- k_decayed_c = k * g_exp k_decayed_j = k * g_neg_exp A = torch.matmul(k_decayed_c, k_decayed_j.transpose(-1, -2)) A = A * beta[..., None] mask_diag_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=0) A = -A.masked_fill(mask_diag_upper, 0) # ---- Block Inversion of (I - A) ---- M = torch.eye(BT, dtype=dtype, device=A.device) - A X = torch.eye(BT, dtype=dtype, device=A.device).expand(B, H, NT, BT, BT).clone() step = 1 while step < BT: num_blocks = BT // (2 * step) # Extract block-diagonals in parallel X_reshaped = X.view(B, H, NT, num_blocks, 2 * step, num_blocks, 2 * step) X_diag = torch.diagonal(X_reshaped, dim1=3, dim2=5).permute(0, 1, 2, 5, 3, 4) M_reshaped = M.view(B, H, NT, num_blocks, 2 * step, num_blocks, 2 * step) M_diag = torch.diagonal(M_reshaped, dim1=3, dim2=5).permute(0, 1, 2, 5, 3, 4) A_inv = X_diag[..., 0 : step, 0 : step] D_inv = X_diag[..., step : 2*step, step : 2*step] C = M_diag[..., step : 2*step, 0 : step] res = -D_inv @ C @ A_inv for b in range(num_blocks): i = b * 2 * step X[..., i+step : i+2*step, i : i+step] = res[..., b, :, :] step *= 2 A_final = X * beta[..., None, :] w = A_final @ (g_exp * k) u = A_final @ v # ---- Compute Aqk ---- q_decayed_c = q * g_exp Aqk = torch.matmul(q_decayed_c, k_decayed_j.transpose(-1, -2)) mask_strict_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=1) Aqk = Aqk.masked_fill(mask_strict_upper, 0) return q, k, g, u, w, Aqk def _optimized_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, chunk_size: int = 64, ) -> torch.Tensor: """KDA forward, no initial state, no final state. Returns o with v's dtype.""" dtype = v.dtype BT = chunk_size T = q.shape[1] NT = T // BT K = q.shape[-1] V = v.shape[-1] # Run compiled intra-chunk pass q_out, k_out, g_out, u_out, w_out, Aqk_out = _intra_chunk_pass(q, k, v, g, beta, scale, BT) # Run Triton recurrence pass o = _triton_recurrence(q_out, k_out, g_out, u_out, w_out, Aqk_out, NT, BT, K, V, dtype) o = rearrange(o, "b h n c d -> b (n c) h d") return o.to(dtype) class Model(nn.Module): """KDA forward (chunk form). No learned parameters; all inputs are activations.""" def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 # No learned params; declare a dummy buffer so state_dict is well-defined. self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: return _optimized_forward(q, k, v, g, beta, scale=self.scale, chunk_size=self.chunk_size) # Module-level shape shims (overridden by check.py / benchmark.py per shape). B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): """Return a list of activations for one forward call.""" torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 # log-decay: small negative numbers so exp(g) is in (0, 1). g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]