from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl _SOLVE_EXT = None def _solve_a_bf_cuda(raw: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: global _SOLVE_EXT if _SOLVE_EXT is None: from torch.utils.cpp_extension import load_inline cuda_src = r""" #include #include #include __global__ void solve_a_kernel(const __nv_bfloat16* __restrict__ raw, const float* __restrict__ beta, __nv_bfloat16* __restrict__ out, int m_count) { int m = blockIdx.x; int j = threadIdx.x; if (m >= m_count || j >= 64) return; __shared__ float a[64][64]; __shared__ float beta_s[64]; __shared__ float raw_row[64]; int bbase = m * 64; int base = m * 64 * 64; beta_s[j] = beta[bbase + j]; __syncthreads(); #pragma unroll for (int i = 0; i < 64; ++i) { raw_row[j] = __bfloat162float(raw[base + i * 64 + j]); __syncthreads(); float beta_i = beta_s[i]; float acc = (j < i) ? (-beta_i * raw_row[j]) : 0.0f; #pragma unroll for (int p = 0; p < 64; ++p) { if (p < i) { float coeff = -beta_i * raw_row[p]; acc = fmaf(coeff, a[p][j], acc); } } a[i][j] = acc; __syncthreads(); } float beta_j = beta_s[j]; #pragma unroll for (int i = 0; i < 64; ++i) { float val = (a[i][j] + (i == j ? 1.0f : 0.0f)) * beta_j; out[base + i * 64 + j] = __float2bfloat16_rn(val); } } torch::Tensor solve_a_bf(torch::Tensor raw, torch::Tensor beta) { auto out = torch::empty_like(raw); int m_count = raw.size(0); auto stream = at::cuda::getCurrentCUDAStream(); solve_a_kernel<<>>( reinterpret_cast(raw.data_ptr()), beta.data_ptr(), reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), m_count); return out; } """ _SOLVE_EXT = load_inline( name="kda_solve_a_bf_ext", cpp_sources="torch::Tensor solve_a_bf(torch::Tensor raw, torch::Tensor beta);", cuda_sources=cuda_src, functions=["solve_a_bf"], extra_cuda_cflags=["-O3", "--use_fast_math"], verbose=False, ) return _SOLVE_EXT.solve_a_bf(raw, beta) @triton.jit def _raw_aqk_kernel(xk, xq, yk, raw, aqk): m = tl.program_id(0) offs_c = tl.arange(0, 64) offs_k = tl.arange(0, 128) base_k = m * 64 * 128 base_c = m * 64 * 64 left = offs_c[:, None] * 128 + offs_k[None, :] right = offs_k[:, None] + offs_c[None, :] * 128 ykt = tl.load(yk + base_k + right) xk_blk = tl.load(xk + base_k + left) xq_blk = tl.load(xq + base_k + left) offs_out = base_c + offs_c[:, None] * 64 + offs_c[None, :] tl.store(raw + offs_out, tl.dot(xk_blk, ykt)) tl.store(aqk + offs_out, tl.dot(xq_blk, ykt)) @triton.jit def _solve_a_kernel(raw, bf, amat): m = tl.program_id(0) offs = tl.arange(0, 64) base = m * 64 * 64 bbase = m * 64 for i in range(0, 64): beta_i = tl.load(bf + bbase + i) raw_row = tl.load(raw + base + i * 64 + offs) acc = tl.where(offs < i, -beta_i * raw_row, 0.0) for j in range(0, 64): if j < i: raw_ij = tl.load(raw + base + i * 64 + j) coeff = -beta_i * raw_ij prev = tl.load(amat + base + j * 64 + offs) acc += coeff * prev tl.store(amat + base + i * 64 + offs, acc) beta_cols = tl.load(bf + bbase + offs) for i in range(0, 64): row = tl.load(amat + base + i * 64 + offs) row = (row + tl.where(offs == i, 1.0, 0.0)) * beta_cols tl.store(amat + base + i * 64 + offs, row) @triton.jit def _preprocess_kernel( q, k, v, g, beta, xq, xk, yk, gamma, vf, bf, scale: tl.constexpr, NT: tl.constexpr, H: tl.constexpr, BK: tl.constexpr, ): m = tl.program_id(0) pid_d = tl.program_id(1) offs_c = tl.arange(0, 64) offs_d = pid_d * BK + tl.arange(0, BK) b = m // (H * NT) rem = m - b * (H * NT) h = rem // NT n = rem - h * NT t = n * 64 + offs_c qkv_offsets = ((b * (NT * 64) + t[:, None]) * H + h) * 128 + offs_d[None, :] g_vals = tl.load(g + qkv_offsets) gc = tl.cumsum(g_vals, 0) e = tl.exp(gc) k_vals = tl.load(k + qkv_offsets).to(tl.float32) q_vals = tl.load(q + qkv_offsets).to(tl.float32) v_vals = tl.load(v + qkv_offsets).to(tl.float32) out_offsets = m * 64 * 128 + offs_c[:, None] * 128 + offs_d[None, :] tl.store(xk + out_offsets, k_vals * e) tl.store(yk + out_offsets, k_vals * tl.exp(-gc)) tl.store(xq + out_offsets, q_vals * e * scale) tl.store(vf + out_offsets, v_vals) e_last = tl.sum(tl.where(offs_c[:, None] == 63, e, 0.0), axis=0) tl.store(gamma + m * 128 + offs_d, e_last) beta_offsets = (b * (NT * 64) + t) * H + h beta_vals = tl.load(beta + beta_offsets).to(tl.float32) tl.store(bf + m * 64 + offs_c, beta_vals) @triton.jit def _recurrent_kernel( xq, yk, gamma, w, u, aqk, out, NT: tl.constexpr, H: tl.constexpr, BV: tl.constexpr, ): pid_bh = tl.program_id(0) pid_vb = tl.program_id(1) offs_k = tl.arange(0, 128) offs_c = tl.arange(0, 64) offs_v = pid_vb * BV + tl.arange(0, BV) state = tl.zeros((128, BV), tl.float32) base_bh = pid_bh * NT b = pid_bh // H h = pid_bh - b * H for n in range(0, NT): m = base_bh + n base_m_k = m * 64 * 128 base_m_c = m * 64 * 64 c_by_k = offs_c[:, None] * 128 + offs_k[None, :] w_blk = tl.load(w + base_m_k + c_by_k) u_blk = tl.load(u + base_m_k + offs_c[:, None] * 128 + offs_v[None, :]) vi = u_blk - tl.dot(w_blk, state, input_precision="tf32") xq_blk = tl.load(xq + base_m_k + c_by_k) acc = tl.dot(xq_blk, state, input_precision="tf32") aqk_blk = tl.load(aqk + base_m_c + offs_c[:, None] * 64 + offs_c[None, :]) aqk_blk = tl.where(offs_c[:, None] >= offs_c[None, :], aqk_blk, 0.0) acc += tl.dot(aqk_blk, vi, input_precision="tf32") t = n * 64 + offs_c out_offs = ((b * (NT * 64) + t[:, None]) * H + h) * 128 + offs_v[None, :] tl.store(out + out_offs, acc) ykt = tl.load(yk + base_m_k + offs_c[None, :] * 128 + offs_k[:, None]) upd = tl.dot(ykt, vi, input_precision="tf32") gamma_vals = tl.load(gamma + m * 128 + offs_k)[:, None] state = (state + upd) * gamma_vals @triton.jit def _recurrent_bf_kernel( xq, yk, gamma, w, u, aqk, out, NT: tl.constexpr, H: tl.constexpr, BV: tl.constexpr, ): pid_bh = tl.program_id(0) pid_vb = tl.program_id(1) offs_k = tl.arange(0, 128) offs_c = tl.arange(0, 64) offs_v = pid_vb * BV + tl.arange(0, BV) state = tl.zeros((128, BV), tl.float32) base_bh = pid_bh * NT b = pid_bh // H h = pid_bh - b * H for n in range(0, NT): m = base_bh + n base_m_k = m * 64 * 128 base_m_c = m * 64 * 64 c_by_k = offs_c[:, None] * 128 + offs_k[None, :] state_bf = state.to(tl.bfloat16) w_blk = tl.load(w + base_m_k + c_by_k) u_blk = tl.load(u + base_m_k + offs_c[:, None] * 128 + offs_v[None, :]).to(tl.float32) vi = u_blk - tl.dot(w_blk, state_bf) vi_bf = vi.to(tl.bfloat16) xq_blk = tl.load(xq + base_m_k + c_by_k) acc = tl.dot(xq_blk, state_bf) aqk_blk = tl.load(aqk + base_m_c + offs_c[:, None] * 64 + offs_c[None, :]) aqk_blk = tl.where(offs_c[:, None] >= offs_c[None, :], aqk_blk, 0.0) acc += tl.dot(aqk_blk, vi_bf) t = n * 64 + offs_c out_offs = ((b * (NT * 64) + t[:, None]) * H + h) * 128 + offs_v[None, :] tl.store(out + out_offs, acc) ykt = tl.load(yk + base_m_k + offs_c[None, :] * 128 + offs_k[:, None]) upd = tl.dot(ykt, vi_bf) gamma_vals = tl.load(gamma + m * 128 + offs_k)[:, None] state = (state + upd) * gamma_vals @triton.jit def _recurrent_fused_kernel( xq, xk, yk, gamma, amat, vf, aqk, out, NT: tl.constexpr, H: tl.constexpr, BV: tl.constexpr, ): pid_bh = tl.program_id(0) pid_vb = tl.program_id(1) offs_k = tl.arange(0, 128) offs_c = tl.arange(0, 64) offs_v = pid_vb * BV + tl.arange(0, BV) state = tl.zeros((128, BV), tl.float32) base_bh = pid_bh * NT b = pid_bh // H h = pid_bh - b * H for n in range(0, NT): m = base_bh + n base_m_k = m * 64 * 128 base_m_c = m * 64 * 64 c_by_k = offs_c[:, None] * 128 + offs_k[None, :] c_by_c = offs_c[:, None] * 64 + offs_c[None, :] state_bf = state.to(tl.bfloat16) xk_blk = tl.load(xk + base_m_k + c_by_k) vf_blk = tl.load(vf + base_m_k + offs_c[:, None] * 128 + offs_v[None, :]).to(tl.float32) residual = vf_blk - tl.dot(xk_blk, state_bf) a_blk = tl.load(amat + base_m_c + c_by_c) vi = tl.dot(a_blk, residual.to(tl.bfloat16)) vi_bf = vi.to(tl.bfloat16) xq_blk = tl.load(xq + base_m_k + c_by_k) acc = tl.dot(xq_blk, state_bf) aqk_blk = tl.load(aqk + base_m_c + c_by_c) aqk_blk = tl.where(offs_c[:, None] >= offs_c[None, :], aqk_blk, 0.0) acc += tl.dot(aqk_blk, vi_bf) t = n * 64 + offs_c out_offs = ((b * (NT * 64) + t[:, None]) * H + h) * 128 + offs_v[None, :] tl.store(out + out_offs, acc) ykt = tl.load(yk + base_m_k + offs_c[None, :] * 128 + offs_k[:, None]) upd = tl.dot(ykt, vi_bf) gamma_vals = tl.load(gamma + m * 128 + offs_k)[:, None] state = (state + upd) * gamma_vals def _forward_torch( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, block: int, ) -> torch.Tensor: dtype = v.dtype B, T, H, K = q.shape V = v.shape[-1] NT = T // block M = B * H * NT xq = torch.empty((M, block, K), dtype=torch.bfloat16, device=q.device) xk = torch.empty((M, block, K), dtype=torch.bfloat16, device=q.device) yk = torch.empty((M, block, K), dtype=torch.bfloat16, device=q.device) gamma = torch.empty((M, K), dtype=torch.float32, device=q.device) vf = torch.empty((M, block, V), dtype=torch.bfloat16, device=q.device) bf = torch.empty((M, block), dtype=torch.float32, device=q.device) _preprocess_kernel[(M, triton.cdiv(K, 32))]( q, k, v, g, beta, xq, xk, yk, gamma, vf, bf, scale, NT, H, BK=32, num_warps=8 ) raw = torch.empty((M, block, block), dtype=torch.bfloat16, device=q.device) aqk = torch.empty((M, block, block), dtype=torch.bfloat16, device=q.device) _raw_aqk_kernel[(M,)](xk, xq, yk, raw, aqk, num_warps=4, num_stages=1) amat_bf = _solve_a_bf_cuda(raw, bf) out = torch.empty((B, T, H, V), dtype=dtype, device=q.device) bv = 16 _recurrent_fused_kernel[(B * H, triton.cdiv(V, bv))]( xq, xk, yk, gamma, amat_bf, vf, aqk, out, NT, H, BV=bv, num_warps=4, num_stages=1 ) return out class Model(nn.Module): def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self._graph = None self._graph_key = None self._graph_out = None self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: if not q.is_cuda: return _forward_torch(q, k, v, g, beta, self.scale, self.chunk_size) key = (q.data_ptr(), k.data_ptr(), v.data_ptr(), g.data_ptr(), beta.data_ptr()) if self._graph is None or self._graph_key != key: _forward_torch(q, k, v, g, beta, self.scale, self.chunk_size) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): self._graph_out = _forward_torch(q, k, v, g, beta, self.scale, self.chunk_size) self._graph = graph self._graph_key = key self._graph.replay() return self._graph_out B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05 beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]