"""KDA forward (chunk form) — optimized implementation for SM120 Blackwell. Uses cuBLAS for batched matmuls and torch.compile (inductor) to fuse the inter-chunk recurrence loop. Key optimizations: - Intra-chunk: batched bmm + solve_triangular (cuBLAS batch-GEMM) - Pre-compute all-chunk Aqk in one batched bmm - Fuse w@S and q@S into a single stacked bmm per chunk - torch.compile the inter-chunk loop (max-autotune Triton kernels) - All intermediate compute in fp32 for cuBLAS efficiency """ from __future__ import annotations import torch import torch.nn as nn from einops import rearrange OP_TYPE = "linear_attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] # --------------------------------------------------------------------------- # Intra-chunk # --------------------------------------------------------------------------- def _intra_chunk(q, k, v, g, beta, scale): B, H, NT, BT, K = q.shape V = v.shape[-1] device = q.device q_f = q.float() * scale k_f = k.float() v_f = v.float() g_f = g.float().cumsum(-2) beta_f = beta.float() exp_g = g_f.exp() exp_neg_g = (-g_f).exp() k_exp_g = k_f * exp_g k_exp_neg_g = k_f * exp_neg_g fb = B * H * NT k_exp_g_f = k_exp_g.reshape(fb, BT, K) k_exp_neg_g_f = k_exp_neg_g.reshape(fb, BT, K) beta_fb = beta_f.reshape(fb, BT) # Batched matmul: M_raw = (k*exp(g)) @ (k*exp(-g))^T M_raw = torch.bmm(k_exp_g_f, k_exp_neg_g_f.transpose(1, 2)) M_raw = M_raw * beta_fb.unsqueeze(-1) # Solve triangular: A = (I + tril(M, -1))^{-1} * diag(beta) L = torch.tril(M_raw, diagonal=-1) I_plus_L = torch.eye(BT, dtype=torch.float32, device=device).unsqueeze(0) + L A_fb = torch.linalg.solve_triangular( I_plus_L, torch.diag_embed(beta_fb), upper=False, unitriangular=True, ) w = torch.bmm(A_fb, k_exp_g_f).reshape(B, H, NT, BT, K) u = torch.bmm(A_fb, v_f.reshape(fb, BT, V)).reshape(B, H, NT, BT, V) q_exp_g = q_f * exp_g g_last = g_f[:, :, :, -1, :] exp_g_last = g_last.exp() kg = (g_last.unsqueeze(-2) - g_f).exp() * k_f return w, u, q_exp_g, k_exp_neg_g, exp_g_last, kg # --------------------------------------------------------------------------- # Inter-chunk (compiled) # --------------------------------------------------------------------------- def _make_inter_chunk_fn(): """Build and compile the inter-chunk recurrence function.""" def _inter_chunk_loop(qe_flat, kd_flat, w_flat, u_flat, gl_flat, kg_flat): BH, NT, BT, K = qe_flat.shape V = u_flat.shape[-1] device = qe_flat.device # Pre-compute Aqk for all chunks in one batched bmm qe2d = qe_flat.reshape(BH * NT, BT, K) kd2d = kd_flat.reshape(BH * NT, BT, K) Aqk_all = torch.bmm(qe2d, kd2d.transpose(1, 2)) mask = torch.triu( torch.ones(BT, BT, dtype=torch.bool, device=device), diagonal=1, ) Aqk_all = Aqk_all.masked_fill(mask, 0) Aqk_flat = Aqk_all.reshape(BH, NT, BT, BT) o_flat = torch.empty(BH, NT, BT, V, dtype=torch.float32, device=device) S = torch.zeros(BH, K, V, dtype=torch.float32, device=device) for n in range(NT): # Fused w@S + q@S: concatenated bmm wq = torch.cat([w_flat[:, n], qe_flat[:, n]], dim=1) wqS = torch.bmm(wq, S) wS, qS = wqS[:, :BT, :], wqS[:, BT:, :] vc = u_flat[:, n] - wS o_flat[:, n] = qS + torch.bmm(Aqk_flat[:, n], vc) # State update: S = S * gl + kg^T @ vc S = S * gl_flat[:, n].unsqueeze(-1) S = S + torch.bmm(kg_flat[:, n].transpose(1, 2), vc) return o_flat return torch.compile(_inter_chunk_loop, mode="max-autotune", fullgraph=False) _inter_chunk_compiled = None def _inter_chunk(qe_flat, kd_flat, w_flat, u_flat, gl_flat, kg_flat): global _inter_chunk_compiled if _inter_chunk_compiled is None: _inter_chunk_compiled = _make_inter_chunk_fn() return _inter_chunk_compiled(qe_flat, kd_flat, w_flat, u_flat, gl_flat, kg_flat) # --------------------------------------------------------------------------- # Main forward # --------------------------------------------------------------------------- def _kda_forward(q, k, v, g, beta, scale, chunk_size): B, T, H, K_shape = q.shape V = v.shape[-1] BT = chunk_size NT = T // BT assert T % BT == 0 q_c = rearrange(q, "b (n c) h d -> b h n c d", c=BT) k_c = rearrange(k, "b (n c) h d -> b h n c d", c=BT) v_c = rearrange(v, "b (n c) h d -> b h n c d", c=BT) g_c = rearrange(g, "b (n c) h d -> b h n c d", c=BT) beta_c = rearrange(beta, "b (n c) h -> b h n c", c=BT) w, u, q_exp_g, k_exp_neg_g, exp_g_last, kg = _intra_chunk( q_c, k_c, v_c, g_c, beta_c, scale, ) BH = B * H qe_f = q_exp_g.reshape(BH, NT, BT, K) kd_f = k_exp_neg_g.reshape(BH, NT, BT, K) w_f = w.reshape(BH, NT, BT, K) u_f = u.reshape(BH, NT, BT, V) gl_f = exp_g_last.reshape(BH, NT, K) kg_f = kg.reshape(BH, NT, BT, K) o_f = _inter_chunk(qe_f, kd_f, w_f, u_f, gl_f, kg_f) o_c = o_f.reshape(B, H, NT, BT, V) o = rearrange(o_c, "b h n c d -> b (n c) h d") return o.to(v.dtype) # --------------------------------------------------------------------------- # Module interface # --------------------------------------------------------------------------- B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 class Model(nn.Module): def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward(self, q, k, v, g, beta): return _kda_forward(q, k, v, g, beta, scale=self.scale, chunk_size=self.chunk_size) def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]