"""Kimi Delta Attention (KDA) forward (chunk form) via Triton. Matches reference.py's semantics: inputs are bf16 (g fp32), chunk_size=64, no initial/final state, returns bf16 o. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "linear_attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] @triton.autotune( configs=[ triton.Config({"BK": 64, "BV": 8}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 8}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 16}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 16}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 32}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 32}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 64}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 64}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 8}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 8}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 16}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 16}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 32}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 32}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 64}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 64}, num_warps=8, num_stages=1), ], key=["B", "T", "H", "K", "V"], ) @triton.jit def _kda_intra_kernel( q_ptr, k_ptr, v_ptr, g_ptr, gmid_ptr, beta_ptr, Aqk_ptr, w_ptr, u_ptr, scale, B: tl.constexpr, T: tl.constexpr, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, ): i_n = tl.program_id(0) i_bh = tl.program_id(1) i_b = i_bh // H i_h = i_bh % H q_ptr += (i_b * T * H + i_h) * K k_ptr += (i_b * T * H + i_h) * K g_ptr += (i_b * T * H + i_h) * K gmid_ptr += (i_b * H + i_h) * K beta_ptr += i_b * T * H + i_h v_ptr += (i_b * T * H + i_h) * V Aqk_ptr += (i_b * T * H + i_h) * BT w_ptr += (i_b * T * H + i_h) * K u_ptr += (i_b * T * H + i_h) * V t0 = i_n * BT r = tl.arange(0, BT) c = tl.arange(0, BT) mask_strict = r[:, None] > c[None, :] mask_lower = r[:, None] >= c[None, :] Aqk = tl.zeros([BT, BT], dtype=tl.float32) M = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): off_k = i_k * BK p_q = tl.make_block_ptr( q_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) p_k = tl.make_block_ptr( k_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) p_g = tl.make_block_ptr( g_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) p_gmid = tl.make_block_ptr( gmid_ptr, (H, K), (K, 1), (i_h, off_k), (1, BK), (1, 0) ) b_gmid = tl.load(p_gmid, boundary_check=(0, 1)).to(tl.float32) diff = b_g - b_gmid qg_norm = (b_q * tl.exp(diff)).to(tl.bfloat16) kg_norm = (b_k * tl.exp(diff)).to(tl.bfloat16) kdecay_norm = (b_k * tl.exp(-diff)).to(tl.bfloat16) Aqk += tl.dot(qg_norm, tl.trans(kdecay_norm)) M += tl.dot(kg_norm, tl.trans(kdecay_norm)) p_beta = tl.make_block_ptr( beta_ptr, (T,), (H,), (t0,), (BT,), (0,) ) b_beta = tl.load(p_beta, boundary_check=(0,)).to(tl.float32) Aqk = tl.where(mask_lower, Aqk * scale, 0.0) A0 = tl.where(mask_strict, -M * b_beta[:, None], 0.0) A = A0 for i in range(1, BT): row_i = tl.sum(tl.where(r[:, None] == i, A, 0.0), axis=0) update = tl.sum(row_i[:, None] * A, axis=0) new_row = tl.where(c < i, row_i + update, row_i) A = tl.where((r[:, None] == i), new_row[None, :], A) A = tl.where(r[:, None] == c[None, :], A + 1.0, A) A = A * b_beta[None, :] p_Aqk = tl.make_block_ptr( Aqk_ptr, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0) ) tl.store(p_Aqk, Aqk.to(tl.bfloat16), boundary_check=(0, 1)) A_bf16 = A.to(tl.bfloat16) for i_k in range(tl.cdiv(K, BK)): off_k = i_k * BK p_k = tl.make_block_ptr( k_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) p_g = tl.make_block_ptr( g_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) b_kg = (b_k * tl.exp(b_g)).to(tl.bfloat16) b_w = tl.dot(A_bf16, b_kg) p_w = tl.make_block_ptr( w_ptr, (T, K), (H * K, 1), (t0, off_k), (BT, BK), (1, 0) ) tl.store(p_w, b_w.to(tl.bfloat16), boundary_check=(0, 1)) for i_v in range(tl.cdiv(V, BV)): off_v = i_v * BV p_v = tl.make_block_ptr( v_ptr, (T, V), (H * V, 1), (t0, off_v), (BT, BV), (1, 0) ) b_v = tl.load(p_v, boundary_check=(0, 1)) b_u = tl.dot(A_bf16, b_v) p_u = tl.make_block_ptr( u_ptr, (T, V), (H * V, 1), (t0, off_v), (BT, BV), (1, 0) ) tl.store(p_u, b_u.to(tl.bfloat16), boundary_check=(0, 1)) @triton.autotune( configs=[ triton.Config({"BK": 64, "BV": 8}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 8}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 16}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 16}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 32}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 32}, num_warps=8, num_stages=1), triton.Config({"BK": 64, "BV": 64}, num_warps=4, num_stages=1), triton.Config({"BK": 64, "BV": 64}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 8}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 8}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 16}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 16}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 32}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 32}, num_warps=8, num_stages=1), triton.Config({"BK": 128, "BV": 64}, num_warps=4, num_stages=1), triton.Config({"BK": 128, "BV": 64}, num_warps=8, num_stages=1), ], key=["B", "T", "H", "K", "V"], ) @triton.jit def _kda_inter_kernel( w_ptr, u_ptr, qg_ptr, k_ptr, g_ptr, Aqk_ptr, glast_ptr, o_ptr, B: tl.constexpr, T: tl.constexpr, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, ): i_vb = tl.program_id(0) i_bh = tl.program_id(1) i_b = i_bh // H i_h = i_bh % H w_ptr += (i_b * T * H + i_h) * K u_ptr += (i_b * T * H + i_h) * V qg_ptr += (i_b * T * H + i_h) * K k_ptr += (i_b * T * H + i_h) * K g_ptr += (i_b * T * H + i_h) * K Aqk_ptr += (i_b * T * H + i_h) * BT o_ptr += (i_b * T * H + i_h) * V off_v0 = i_vb * BV NT = T // BT # For K=128 we have either one BK=128 tile or two BK=64 tiles. if tl.constexpr(K == 128 and BK == 128): S = tl.zeros([BK, BV], dtype=tl.float32) for i_n in range(NT): t0 = i_n * BT p_u = tl.make_block_ptr( u_ptr, (T, V), (H * V, 1), (t0, off_v0), (BT, BV), (1, 0) ) b_u = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32) b_v = b_u p_w = tl.make_block_ptr( w_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v -= tl.dot(b_w, S.to(tl.bfloat16)).to(tl.float32) b_o = tl.zeros([BT, BV], dtype=tl.float32) p_qg = tl.make_block_ptr( qg_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_qg = tl.load(p_qg, boundary_check=(0, 1)) b_o += tl.dot(b_qg, S.to(tl.bfloat16)).to(tl.float32) p_Aqk = tl.make_block_ptr( Aqk_ptr, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0) ) b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) b_o += tl.dot(b_Aqk, b_v.to(tl.bfloat16)).to(tl.float32) p_o = tl.make_block_ptr( o_ptr, (T, V), (H * V, 1), (t0, off_v0), (BT, BV), (1, 0) ) tl.store(p_o, b_o.to(tl.bfloat16), boundary_check=(0, 1)) p_glast = glast_ptr + (i_b * H + i_h) * NT * K + i_n * K + tl.arange(0, BK) d = tl.exp(tl.load(p_glast)) p_k = tl.make_block_ptr( k_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_k = tl.load(p_k, boundary_check=(0, 1)) p_g = tl.make_block_ptr( g_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) k_up = (b_k.to(tl.float32) * tl.exp(tl.log(d)[None, :] - b_g)).to(tl.bfloat16) S = S * d[:, None] S += tl.dot(tl.trans(k_up), b_v.to(tl.bfloat16)) else: S0 = tl.zeros([BK, BV], dtype=tl.float32) S1 = tl.zeros([BK, BV], dtype=tl.float32) for i_n in range(NT): t0 = i_n * BT p_u = tl.make_block_ptr( u_ptr, (T, V), (H * V, 1), (t0, off_v0), (BT, BV), (1, 0) ) b_u = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32) b_v = b_u p_w0 = tl.make_block_ptr( w_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_w0 = tl.load(p_w0, boundary_check=(0, 1)) b_v -= tl.dot(b_w0, S0.to(tl.bfloat16)).to(tl.float32) p_w1 = tl.make_block_ptr( w_ptr, (T, K), (H * K, 1), (t0, BK), (BT, BK), (1, 0) ) b_w1 = tl.load(p_w1, boundary_check=(0, 1)) b_v -= tl.dot(b_w1, S1.to(tl.bfloat16)).to(tl.float32) b_o = tl.zeros([BT, BV], dtype=tl.float32) p_qg0 = tl.make_block_ptr( qg_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_qg0 = tl.load(p_qg0, boundary_check=(0, 1)) b_o += tl.dot(b_qg0, S0.to(tl.bfloat16)).to(tl.float32) p_qg1 = tl.make_block_ptr( qg_ptr, (T, K), (H * K, 1), (t0, BK), (BT, BK), (1, 0) ) b_qg1 = tl.load(p_qg1, boundary_check=(0, 1)) b_o += tl.dot(b_qg1, S1.to(tl.bfloat16)).to(tl.float32) p_Aqk = tl.make_block_ptr( Aqk_ptr, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0) ) b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) b_o += tl.dot(b_Aqk, b_v.to(tl.bfloat16)).to(tl.float32) p_o = tl.make_block_ptr( o_ptr, (T, V), (H * V, 1), (t0, off_v0), (BT, BV), (1, 0) ) tl.store(p_o, b_o.to(tl.bfloat16), boundary_check=(0, 1)) p_glast0 = glast_ptr + (i_b * H + i_h) * NT * K + i_n * K + tl.arange(0, BK) d0 = tl.exp(tl.load(p_glast0)) p_glast1 = glast_ptr + (i_b * H + i_h) * NT * K + i_n * K + BK + tl.arange(0, BK) d1 = tl.exp(tl.load(p_glast1)) p_k0 = tl.make_block_ptr( k_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_k0 = tl.load(p_k0, boundary_check=(0, 1)) p_g0 = tl.make_block_ptr( g_ptr, (T, K), (H * K, 1), (t0, 0), (BT, BK), (1, 0) ) b_g0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) k_up0 = (b_k0.to(tl.float32) * tl.exp(tl.log(d0)[None, :] - b_g0)).to(tl.bfloat16) p_k1 = tl.make_block_ptr( k_ptr, (T, K), (H * K, 1), (t0, BK), (BT, BK), (1, 0) ) b_k1 = tl.load(p_k1, boundary_check=(0, 1)) p_g1 = tl.make_block_ptr( g_ptr, (T, K), (H * K, 1), (t0, BK), (BT, BK), (1, 0) ) b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) k_up1 = (b_k1.to(tl.float32) * tl.exp(tl.log(d1)[None, :] - b_g1)).to(tl.bfloat16) S0 = S0 * d0[:, None] S1 = S1 * d1[:, None] S0 += tl.dot(tl.trans(k_up0), b_v.to(tl.bfloat16)) S1 += tl.dot(tl.trans(k_up1), b_v.to(tl.bfloat16)) class Model(nn.Module): def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: B, T, H, K = q.shape V = v.shape[-1] BT = self.chunk_size assert T % BT == 0 NT = T // BT device = q.device q = q.to(torch.bfloat16) k = k.to(torch.bfloat16) v = v.to(torch.bfloat16) g = g.to(torch.float32) beta = beta.to(torch.bfloat16) g_4d = g.view(B, NT, BT, H, K) g_cum = g_4d.cumsum(dim=2).view(B, T, H, K) q_scaled = q * self.scale qg = (q_scaled * g_cum.exp()).to(torch.bfloat16) g_mid = g_cum.view(B, NT, BT, H, K)[:, :, BT // 2, :, :].permute(0, 2, 1, 3).contiguous() glast = g_cum.view(B, NT, BT, H, K)[:, :, -1, :, :].permute(0, 2, 1, 3).contiguous() Aqk = torch.empty(B, T, H, BT, device=device, dtype=torch.bfloat16) w = torch.empty(B, T, H, K, device=device, dtype=torch.bfloat16) u = torch.empty(B, T, H, V, device=device, dtype=torch.bfloat16) # Use 64-wide tiles as a conservative default; autotune explores others. BK = 64 BV = 32 grid1 = (NT, B * H) _kda_intra_kernel[grid1]( q_scaled, k, v, g_cum, g_mid, beta, Aqk, w, u, 1.0, B=B, T=T, H=H, K=K, V=V, BT=BT, ) o = torch.empty(B, T, H, V, device=device, dtype=torch.bfloat16) grid2 = lambda meta: (triton.cdiv(V, meta['BV']), B * H) _kda_inter_kernel[grid2]( w, u, qg, k, g_cum, Aqk, glast, o, B=B, T=T, H=H, K=K, V=V, BT=BT, ) return o B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]