"""KDA forward (chunk form) — custom Triton kernels (SM120).""" from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "linear_attention" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] RCP_LN2 = 1.4426950408889634 SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") @triton.jit def exp2(x): return tl.exp2(x) def prepare_chunk_indices(cu_seqlens, chunk_size): return None def prepare_chunk_offsets(cu_seqlens, chunk_size): return None def check_shared_mem(arch="ampere"): return True IS_GATHER_SUPPORTED = False NUM_WARPS = [2, 4, 8, 16] # === cumsum.py === # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # For a list of all contributors, visit: # https://github.com/fla-org/flash-linear-attention/graphs/contributors import torch import triton import triton.language as tl @triton.jit(do_not_specialize=['T']) def chunk_local_cumsum_scalar_kernel( s, o, scale, cu_seqlens, chunk_indices, T, B: tl.constexpr, H: tl.constexpr, BT: tl.constexpr, REVERSE: tl.constexpr, HAS_SCALE: tl.constexpr, IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if HEAD_FIRST: p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) else: p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) # [BT] b_s = 
tl.load(p_s, boundary_check=(0,)).to(tl.float32) b_o = tl.cumsum(b_s, axis=0) if REVERSE: b_z = tl.sum(b_s, axis=0) b_o = -b_o + b_z[None] + b_s if HAS_SCALE: b_o *= scale tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) @triton.jit(do_not_specialize=['T']) def chunk_local_cumsum_vector_kernel( s, o, scale, cu_seqlens, chunk_indices, T, B: tl.constexpr, H: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr, REVERSE: tl.constexpr, HAS_SCALE: tl.constexpr, IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if HEAD_FIRST: p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) else: p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) # [BT, BS] b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) if REVERSE: b_o = tl.cumsum(b_s, axis=0, reverse=True) else: b_o = tl.cumsum(b_s, axis=0) if HAS_SCALE: b_o *= scale tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_global_cumsum_scalar_kernel( s, o, scale, cu_seqlens, T, B: tl.constexpr, H: tl.constexpr, BT: tl.constexpr, REVERSE: tl.constexpr, HAS_SCALE: tl.constexpr, IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_nh = tl.program_id(0) i_n, i_h = i_nh // H, i_nh % H if IS_VARLEN: bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), 
tl.load(cu_seqlens + i_n + 1).to(tl.int32) else: bos, eos = i_n * T, i_n * T + T T = eos - bos b_z = tl.zeros([], dtype=tl.float32) NT = tl.cdiv(T, BT) for i_c in range(NT): i_t = NT - 1 - i_c if REVERSE else i_c if HEAD_FIRST: p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) else: p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) b_o = tl.cumsum(b_s, axis=0) b_ss = tl.sum(b_s, 0) if REVERSE: b_o = -b_o + b_ss + b_s b_o += b_z if i_c >= 0: b_z += b_ss if HAS_SCALE: b_o *= scale tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) @triton.jit(do_not_specialize=['T']) def chunk_global_cumsum_vector_kernel( s, o, scale, cu_seqlens, T, B: tl.constexpr, H: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr, REVERSE: tl.constexpr, HAS_SCALE: tl.constexpr, IS_VARLEN: tl.constexpr, HEAD_FIRST: tl.constexpr, ): i_s, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H if IS_VARLEN: bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) else: bos, eos = i_n * T, i_n * T + T T = eos - bos b_z = tl.zeros([BS], dtype=tl.float32) NT = tl.cdiv(T, BT) for i_c in range(NT): i_t = NT - 1 - i_c if REVERSE else i_c if HEAD_FIRST: p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) else: p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) # [BT, BS] b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) if REVERSE: 
b_c = b_z[None, :] + tl.cumsum(b_s, axis=0, reverse=True) else: b_c = b_z[None, :] + tl.cumsum(b_s, axis=0) if HAS_SCALE: b_c *= scale tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1)) b_z += tl.sum(b_s, 0) def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, chunk_indices: torch.LongTensor | None = None, ) -> torch.Tensor: if head_first: B, H, T = g.shape else: B, T, H = g.shape assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (NT, B * H) chunk_local_cumsum_scalar_kernel[grid]( s=g_org, o=g, scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse, ) return g def chunk_local_cumsum_vector( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, chunk_indices: torch.LongTensor | None = None, ) -> torch.Tensor: if head_first: B, H, T, S = g.shape else: B, T, H, S = g.shape BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) BS = 64 chunk_local_cumsum_vector_kernel[(triton.cdiv(S, BS), NT, B * H)]( s=g_org, o=g, scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, B=B, H=H, S=S, 
BT=BT, HEAD_FIRST=head_first, REVERSE=reverse, BS=BS, HAS_SCALE=scale is not None, IS_VARLEN=cu_seqlens is not None, ) return g def chunk_global_cumsum_scalar( s: torch.Tensor, reverse: bool = False, cu_seqlens: torch.Tensor | None = None, scale: float = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T = s.shape else: B, T, H = s.shape N = len(cu_seqlens) - 1 if cu_seqlens is not None else B z = torch.empty_like(s, dtype=output_dtype or s.dtype) grid = (N * H,) chunk_global_cumsum_scalar_kernel[grid]( s=s, o=z, scale=scale, cu_seqlens=cu_seqlens, T=T, B=B, H=H, HEAD_FIRST=head_first, REVERSE=reverse, ) return z def chunk_global_cumsum_vector( s: torch.Tensor, reverse: bool = False, cu_seqlens: torch.Tensor | None = None, scale: float = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T, S = s.shape else: B, T, H, S = s.shape N = len(cu_seqlens) - 1 if cu_seqlens is not None else B BS = min(32, triton.next_power_of_2(S)) z = torch.empty_like(s, dtype=output_dtype or s.dtype) grid = (triton.cdiv(S, BS), N * H) chunk_global_cumsum_vector_kernel[grid]( s=s, o=z, scale=scale, cu_seqlens=cu_seqlens, T=T, B=B, H=H, S=S, BS=BS, HEAD_FIRST=head_first, REVERSE=reverse, ) return z def chunk_global_cumsum( s: torch.Tensor, reverse: bool = False, cu_seqlens: torch.Tensor | None = None, scale: float = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if cu_seqlens is not None: assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" if len(s.shape) == 3: return chunk_global_cumsum_scalar( s=s, reverse=reverse, cu_seqlens=cu_seqlens, scale=scale, head_first=head_first, output_dtype=output_dtype, ) elif len(s.shape) == 4: return chunk_global_cumsum_vector( s=s, reverse=reverse, cu_seqlens=cu_seqlens, scale=scale, head_first=head_first, output_dtype=output_dtype, 
) else: raise ValueError( f"Unsupported input shape {s.shape}, " f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " f"or [B, H, T]/[B, H, T, D] otherwise", ) def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, chunk_indices: torch.LongTensor | None = None, **kwargs, ) -> torch.Tensor: if cu_seqlens is not None: assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" if len(g.shape) == 3: return chunk_local_cumsum_scalar( g=g, chunk_size=chunk_size, reverse=reverse, scale=scale, cu_seqlens=cu_seqlens, head_first=head_first, output_dtype=output_dtype, chunk_indices=chunk_indices, ) elif len(g.shape) == 4: return chunk_local_cumsum_vector( g=g, chunk_size=chunk_size, reverse=reverse, scale=scale, cu_seqlens=cu_seqlens, head_first=head_first, output_dtype=output_dtype, chunk_indices=chunk_indices, ) else: raise ValueError( f"Unsupported input shape {g.shape}, " f"which should be (B, T, H, D) if `head_first=False` " f"or (B, H, T, D) otherwise", ) # === chunk_intra_token_parallel.py === # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
# For a list of all contributors, visit: # https://github.com/fla-org/flash-linear-attention/graphs/contributors # Token-parallel implementation of KDA intra chunk kernel import torch import triton import triton.language as tl @triton.jit(do_not_specialize=['T', 'N']) def kda_fwd_kernel_intra_token_parallel( q, k, g, beta, Aqk, Akk, scale, cu_seqlens, N, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BH: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_tg, i_hg = tl.program_id(0), tl.program_id(1) if IS_VARLEN: i_n = 0 left, right = 0, N # Unrolled binary search (max B=2^32) # We can limit iterations based on expected max batch size if needed # 20 iterations covers B=1M, usually enough for _ in range(20): if left < right: mid = (left + right) // 2 if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): right = mid else: left = mid + 1 i_n = left bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos i_t = i_tg - bos else: bos = (i_tg // T) * T i_t = i_tg % T if i_t >= T: return i_c = i_t // BT i_s = (i_t % BT) // BC i_tc = i_c * BT i_ts = i_tc + i_s * BC G: tl.constexpr = HV // H q += bos * H*K k += bos * H*K g += bos * HV*K Aqk += bos * HV*BT Akk += bos * HV*BC beta += bos * HV BK: tl.constexpr = triton.next_power_of_2(K) o_hv = i_hg * BH + tl.arange(0, BH) o_h = o_hv // G o_k = tl.arange(0, BK) m_hv = o_hv < HV m_k = o_k < K m_hk = m_hv[:, None] & m_k[None, :] # q/k: [B, T, H, K], manual load via mapped qk head index p_qk = o_h[:, None] * K + o_k[None, :] b_q = tl.load(q + i_t * H * K + p_qk, mask=m_hk, other=0).to(tl.float32) b_k = tl.load(k + i_t * H * K + p_qk, mask=m_hk, other=0).to(tl.float32) # g: [B, T, HV, K], beta: [B, T, HV] p_g = tl.make_block_ptr(g + i_t * HV * K, (HV, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) p_beta = tl.make_block_ptr(beta + i_t * HV, (HV,), (1,), (i_hg * BH,), (BH,), (0,)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) b_k = b_k 
* tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): b_kj = tl.load(k + j * H * K + p_qk, mask=m_hk, other=0).to(tl.float32) p_gj = tl.make_block_ptr(g + j * HV * K, (HV, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) b_kgj = tl.where(m_k[None, :], b_kj * exp2(b_g - b_gj), 0.0) b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) tl.store(Aqk + i_t * HV * BT + o_hv * BT + j % BT, b_Aqk.to(Aqk.dtype.element_ty), mask=m_hv) tl.store(Akk + i_t * HV * BC + o_hv * BC + j - i_ts, b_Akk.to(Akk.dtype.element_ty), mask=m_hv) def kda_fwd_intra_token_parallel( q: torch.Tensor, k: torch.Tensor, gk: torch.Tensor, beta: torch.Tensor, Aqk: torch.Tensor, Akk: torch.Tensor, scale: float, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, sub_chunk_size: int = 16, ) -> None: """ Token-parallel implementation: each token gets its own thread block. Supports both fixed-length and variable-length sequences. Reduces wasted computation on padding. Writes directly to Aqk and Akk tensors (in-place). 
Args: q: [B, T, H, K] k: [B, T, H, K] gk: [B, T, HV, K] cumsum of gates (HV >= H for GVA) beta: [B, T, HV] Aqk: [B, T, HV, BT] output tensor to write to Akk: [B, T, HV, BC] output tensor for diagonal blocks (fp32) scale: attention scale chunk_size: BT (default 64) sub_chunk_size: BC (default 16) """ B, T, H, K, HV = *q.shape, gk.shape[2] N = len(cu_seqlens) - 1 if cu_seqlens is not None else B BT = chunk_size BC = sub_chunk_size BH = 8 kda_fwd_kernel_intra_token_parallel[(B * T, triton.cdiv(HV, BH))]( q=q, k=k, g=gk, beta=beta, Aqk=Aqk, Akk=Akk, scale=scale, cu_seqlens=cu_seqlens, N=N, T=T, H=H, HV=HV, K=K, BT=BT, BC=BC, BH=BH, IS_VARLEN=cu_seqlens is not None, ) return Aqk, Akk # === chunk_intra.py === # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # For a list of all contributors, visit: # https://github.com/fla-org/flash-linear-attention/graphs/contributors import torch import triton import triton.language as tl ################################################################################ # Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass ################################################################################ @triton.jit(do_not_specialize=['T']) def kda_fwd_kernel_inter_solve_fused( q, k, g, beta, Aqk, Akkd, Akk, scale, cu_seqlens, chunk_indices, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, NC: tl.constexpr, BK: tl.constexpr, IS_VARLEN: tl.constexpr, USE_SAFE_GATE: tl.constexpr, ): """ Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. Prerequisite: token_parallel has already computed diagonal Akk blocks in Akkd. This kernel: 1. Computes off-diagonal Aqk blocks -> writes to global 2. Computes off-diagonal Akk blocks -> keeps in registers 3. Loads diagonal Akk blocks from Akkd (fp32) 4. 
Does forward substitution on diagonals 5. Computes merged Akk_inv 6. Writes Akk_inv to Akk """ i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if i_t * BT >= T: return i_tc0 = i_t * BT i_tc1 = i_t * BT + BC i_tc2 = i_t * BT + 2 * BC i_tc3 = i_t * BT + 3 * BC q += (bos * H + i_h) * K k += (bos * H + i_h) * K g += (bos * HV + i_hv) * K Aqk += (bos * HV + i_hv) * BT Akk += (bos * HV + i_hv) * BT Akkd += (bos * HV + i_hv) * BC o_i = tl.arange(0, BC) m_tc1 = (i_tc1 + o_i) < T m_tc2 = (i_tc2 + o_i) < T m_tc3 = (i_tc3 + o_i) < T b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) ################################################################################ # off-diagonal blocks ################################################################################ for i_k in range(tl.cdiv(K, BK)): o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K p_k0 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) p_g0 = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) b_k0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) b_g0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) if i_tc1 < T: p_q1 = 
tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) p_k1 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) p_g1 = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) # [BK] b_gn1 = tl.load(g + i_tc1 * HV*K + o_k, mask=m_k, other=0).to(tl.float32) # [BC, BK] b_gqn = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) # [BK, BC] b_kgt = tl.trans(b_k0 * exp2(b_gn1[None, :] - b_g0)) # [BC, BC] b_Aqk10 += tl.dot(b_q1 * b_gqn, b_kgt) b_Akk10 += tl.dot(b_k1 * b_gqn, b_kgt) if NC >= 3 and i_tc2 < T: p_q2 = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) p_k2 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) p_g2 = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) # [BK] b_gn2 = tl.load(g + i_tc2 * HV*K + o_k, mask=m_k, other=0).to(tl.float32) # [BC, BK] b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) b_qg2 = b_q2 * b_gqn2 b_kg2 = b_k2 * b_gqn2 # [BK, BC] b_kgt = tl.trans(b_k0 * exp2(b_gn2[None, :] - b_g0)) b_Aqk20 += tl.dot(b_qg2, b_kgt) b_Akk20 += tl.dot(b_kg2, b_kgt) # [BC, BC] b_kgt = tl.trans(b_k1 * exp2(b_gn2[None, :] - b_g1)) # [BC, BC] b_Aqk21 += tl.dot(b_qg2, b_kgt) b_Akk21 += tl.dot(b_kg2, b_kgt) if NC >= 4 and i_tc3 < T: p_q3 = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) p_k3 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) p_g3 = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_q3 = tl.load(p_q3, 
boundary_check=(0, 1)).to(tl.float32) b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) # [BK] b_gn3 = tl.load(g + i_tc3 * HV*K + o_k, mask=m_k, other=0).to(tl.float32) # [BC, BK] b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) b_qg3 = b_q3 * b_gqn3 b_kg3 = b_k3 * b_gqn3 # [BK, BC] b_kgt = tl.trans(b_k0 * exp2(b_gn3[None, :] - b_g0)) # [BC, BC] b_Aqk30 += tl.dot(b_qg3, b_kgt) b_Akk30 += tl.dot(b_kg3, b_kgt) # [BK, BC] b_kgt = tl.trans(b_k1 * exp2(b_gn3[None, :] - b_g1)) # [BC, BC] b_Aqk31 += tl.dot(b_qg3, b_kgt) b_Akk31 += tl.dot(b_kg3, b_kgt) # [BK, BC] b_kgt = tl.trans(b_k2 * exp2(b_gn3[None, :] - b_g2)) # [BC, BC] b_Aqk32 += tl.dot(b_qg3, b_kgt) b_Akk32 += tl.dot(b_kg3, b_kgt) ################################################################################ # save off-diagonal Aqk blocks and prepare Akk ################################################################################ if i_tc1 < T: p_Aqk10 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) tl.store(p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) p_b1 = tl.make_block_ptr(beta + bos * HV + i_hv, (T,), (HV,), (i_tc1,), (BC,), (0,)) b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) b_Akk10 = b_Akk10 * b_b1[:, None] if NC >= 3 and i_tc2 < T: p_Aqk20 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) p_Aqk21 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) tl.store(p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) p_b2 = tl.make_block_ptr(beta + bos * HV + i_hv, (T,), (HV,), (i_tc2,), (BC,), (0,)) b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) b_Akk20 = b_Akk20 * b_b2[:, None] b_Akk21 = b_Akk21 * b_b2[:, None] if NC >= 4 and i_tc3 < T: p_Aqk30 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), 
(i_tc3, 0), (BC, BC), (1, 0)) p_Aqk31 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) p_Aqk32 = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_tc3, 2*BC), (BC, BC), (1, 0)) tl.store(p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) p_b3 = tl.make_block_ptr(beta + bos * HV + i_hv, (T,), (HV,), (i_tc3,), (BC,), (0,)) b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) b_Akk30 = b_Akk30 * b_b3[:, None] b_Akk31 = b_Akk31 * b_b3[:, None] b_Akk32 = b_Akk32 * b_b3[:, None] p_Akk00 = tl.make_block_ptr(Akkd, (T, BC), (HV*BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) p_Akk11 = tl.make_block_ptr(Akkd, (T, BC), (HV*BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) if NC >= 3: p_Akk22 = tl.make_block_ptr(Akkd, (T, BC), (HV*BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) if NC >= 4: p_Akk33 = tl.make_block_ptr(Akkd, (T, BC), (HV*BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) ################################################################################ # forward substitution on diagonals ################################################################################ if not USE_SAFE_GATE: m_A = o_i[:, None] > o_i[None, :] m_I = o_i[:, None] == o_i[None, :] b_Ai00 = -tl.where(m_A, b_Ai00, 0) b_Ai11 = -tl.where(m_A, b_Ai11, 0) if NC >= 3: b_Ai22 = -tl.where(m_A, b_Ai22, 0) if NC >= 4: b_Ai33 = -tl.where(m_A, b_Ai33, 0) for i in range(2, min(BC, T - i_tc0)): b_a00 = -tl.load(Akkd + (i_tc0 + i) * HV*BC + o_i) b_a00 = tl.where(o_i < i, b_a00, 0.) 
b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) for i in range(BC + 2, min(2*BC, T - i_tc0)): b_a11 = -tl.load(Akkd + (i_tc0 + i) * HV*BC + o_i) b_a11 = tl.where(o_i < i - BC, b_a11, 0.) b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) if NC >= 3: for i in range(2*BC + 2, min(3*BC, T - i_tc0)): b_a22 = -tl.load(Akkd + (i_tc0 + i) * HV*BC + o_i) b_a22 = tl.where(o_i < i - 2*BC, b_a22, 0.) b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) b_Ai22 = tl.where((o_i == i - 2*BC)[:, None], b_a22, b_Ai22) if NC >= 4: for i in range(3*BC + 2, min(4*BC, T - i_tc0)): b_a33 = -tl.load(Akkd + (i_tc0 + i) * HV*BC + o_i) b_a33 = tl.where(o_i < i - 3*BC, b_a33, 0.) b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) b_Ai33 = tl.where((o_i == i - 3*BC)[:, None], b_a33, b_Ai33) b_Ai00 += m_I b_Ai11 += m_I if NC >= 3: b_Ai22 += m_I if NC >= 4: b_Ai33 += m_I ################################################################################ # compute merged inverse using off-diagonals ################################################################################ # we used tf32 to maintain matrix inverse's precision whenever possible. 
b_Ai10 = -tl.dot( tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION ) if NC >= 3: b_Ai21 = -tl.dot( tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION ) b_Ai20 = -tl.dot( b_Ai22, tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), input_precision=SOLVE_TRIL_DOT_PRECISION ) if NC >= 4: b_Ai32 = -tl.dot( tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai22, input_precision=SOLVE_TRIL_DOT_PRECISION ) b_Ai31 = -tl.dot( b_Ai33, tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), input_precision=SOLVE_TRIL_DOT_PRECISION ) b_Ai30 = -tl.dot( b_Ai33, tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), input_precision=SOLVE_TRIL_DOT_PRECISION ) ################################################################################ # store full Akk_inv to Akk ################################################################################ p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) p_Akk11 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) if NC >= 3: p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) p_Akk21 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) p_Akk22 = tl.make_block_ptr(Akk, (T, BT), 
(HV*BT, 1), (i_tc2, 2*BC), (BC, BC), (1, 0)) tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) if NC >= 4: p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) p_Akk31 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) p_Akk32 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc3, 2*BC), (BC, BC), (1, 0)) p_Akk33 = tl.make_block_ptr(Akk, (T, BT), (HV*BT, 1), (i_tc3, 3*BC), (BC, BC), (1, 0)) tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['B', 'T']) def kda_bwd_kernel_intra( q, k, g, beta, dAqk, dAkk, dq, dq2, dk, dk2, dg, dg2, db, cu_seqlens, chunk_indices, B, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, SAFE_GATE: tl.constexpr, USE_GATHER: tl.constexpr, ): i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) i_k, i_i = i_kc // NC, i_kc % NC all = B * T if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T T = eos - bos i_ti = i_t * BT + i_i * BC if i_ti >= T: return o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K q += (bos * H + i_h) * K k += (bos * H + i_h) * K g += (bos * HV + i_hv) * K beta += bos * HV + i_hv dAqk += (bos * HV + i_hv) * BT dAkk += (bos * HV + i_hv) * BT dq += (bos * HV + 
i_hv) * K dq2 += (bos * HV + i_hv) * K dk += (bos * HV + i_hv) * K dk2 += (bos * HV + i_hv) * K dg += (bos * HV + i_hv) * K dg2 += (bos * HV + i_hv) * K db += (i_k * all + bos) * HV + i_hv p_g = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: p_gn = g + i_ti * HV*K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(0, i_i): p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_gk = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV*BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV*BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_kg = b_k * exp2(b_gn - b_gk) # [BC, BC] b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) # [BC, BK] b_dq2 += tl.dot(b_dAqk, b_kg) b_dk2 += tl.dot(b_dAkk, b_kg) b_gqn = exp2(b_g - b_gn) b_dq2 *= b_gqn b_dk2 *= b_gqn o_i = tl.arange(0, BC) m_dA = (i_ti + o_i) < T o_dA = (i_ti + o_i) * HV*BT + i_i * BC p_kj = k + i_ti * H*K + o_k p_gkj = g + i_ti * HV*K + o_k p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) if SAFE_GATE: if USE_GATHER: b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) else: p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * HV*K + o_k b_gn = tl.load(p_gn, mask=m_k, 
other=0)[None, :] p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (HV*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (HV*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) b_dAqk_diag_qk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_qk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) m_i_diag_qk = (o_i[:, None] >= o_i[None, :]) & ((i_ti + o_i[:, None]) < T) & ((i_ti + o_i[None, :]) < T) m_j_diag_qk = (i_ti + o_i[:, None]) < T b_dAqk_diag_qk = tl.where(m_i_diag_qk, b_dAqk_diag_qk, 0.) b_dAkk_diag_qk = tl.where(m_i_diag_qk, b_dAkk_diag_qk, 0.) b_g_diag_qk = tl.where(m_j_diag_qk, b_g - b_gn, 0.) exp_b_g_diag_qk = tl.where(m_j_diag_qk, exp2(b_g_diag_qk), 0.) exp_neg_b_g_diag_qk = tl.where(m_j_diag_qk, exp2(-b_g_diag_qk), 0.) b_k_exp_diag_qk = b_k * exp_neg_b_g_diag_qk b_dq2 += tl.dot(b_dAqk_diag_qk, b_k_exp_diag_qk) * exp_b_g_diag_qk b_dk2 += tl.dot(b_dAkk_diag_qk, b_k_exp_diag_qk) * exp_b_g_diag_qk else: for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC] b_dAqk = tl.load(dAqk + o_dA + j, mask=m_dA, other=0) b_dAkk = tl.load(dAkk + o_dA + j, mask=m_dA, other=0) # [BK] b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) # [BC, BK] m_i = o_i[:, None] >= j # [BC, BK] b_gqk = exp2(b_g - b_gkj[None, :]) b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kj[None, :] * b_gqk, 0.) b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kj[None, :] * b_gqk, 0.) 
p_kj += H*K p_gkj += HV*K b_db = tl.sum(b_dk2 * b_k, 1) b_dk2 *= b_b[:, None] p_dq = tl.make_block_ptr(dq, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_dq2 = tl.make_block_ptr(dq2, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_db = tl.make_block_ptr(db, (T,), (HV,), (i_ti,), (BC,), (0,)) b_dg2 = b_q * b_dq2 b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) tl.debug_barrier() b_dkt = tl.zeros([BC, BK], dtype=tl.float32) NC = min(NC, tl.cdiv(T - i_t * BT, BC)) if i_i < NC - 1: p_gn = g + (min(i_ti + BC, T) - 1) * HV*K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] for i_j in range(i_i + 1, NC): p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_gk = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_t * BT + i_j * BC,), (BC,), (0,)) p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) # [BC] b_b = tl.load(p_b, boundary_check=(0,)) # [BC, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_kb = tl.load(p_k, boundary_check=(0, 1)) * b_b[:, None] b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) # [BC, BC] b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) o_j = i_t * BT + i_j * BC + o_i m_j = o_j < T # [BC, BK] b_gkn = exp2(b_gk - b_gn) b_qg = b_q * tl.where(m_j[:, None], b_gkn, 0) b_kbg = b_kb * tl.where(m_j[:, None], b_gkn, 0) # [BC, BK] # (SY 09/17) important to not use bf16 here to have a good precision. 
b_dkt += tl.dot(b_dAqk, b_qg) b_dkt += tl.dot(b_dAkk, b_kbg) b_dkt *= exp2(b_gn - b_g) o_dA = i_ti * HV*BT + i_i * BC + o_i p_qj = q + i_ti * H*K + o_k p_kj = k + i_ti * H*K + o_k p_gkj = g + i_ti * HV*K + o_k p_bj = beta + i_ti * HV if SAFE_GATE: if USE_GATHER: b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) else: p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * HV*K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) b_b = tl.load(p_b, boundary_check=(0,)) p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, HV*BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, HV*BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) b_dAqk_diag_kk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) b_dAkk_diag_kk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) m_i_diag_kk = (o_i[:, None] <= o_i[None, :]) & ((i_ti + o_i[:, None]) < T) & ((i_ti + o_i[None, :]) < T) m_j_diag_kk = (i_ti + o_i[:, None]) < T b_dAqk_diag_kk = tl.where(m_i_diag_kk, b_dAqk_diag_kk, 0.) b_dAkk_diag_kk = tl.where(m_i_diag_kk, b_dAkk_diag_kk, 0.) # ensure numerical stability b_g_diag_kk = tl.where(m_j_diag_kk, b_g - b_gn, 0.) exp_b_g_diag_kk = tl.where(m_j_diag_kk, exp2(b_g_diag_kk), 0.) exp_neg_b_g_diag_kk = tl.where(m_j_diag_kk, exp2(-b_g_diag_kk), 0.) 
b_q_exp = b_q * exp_b_g_diag_kk b_kb_exp = b_k * b_b[:, None] * exp_b_g_diag_kk b_dkt += tl.dot(b_dAqk_diag_kk, b_q_exp) * exp_neg_b_g_diag_kk b_dkt += tl.dot(b_dAkk_diag_kk, b_kb_exp) * exp_neg_b_g_diag_kk else: for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] b_dAqk = tl.load(dAqk + o_dA + j * HV*BT) b_dAkk = tl.load(dAkk + o_dA + j * HV*BT) # [BK,] b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) # [BC, BK] m_i = o_i[:, None] <= j b_gkq = exp2(b_gkj[None, :] - b_g) b_dkt += tl.where(m_i, b_dAqk[:, None] * b_qj[None, :] * b_gkq, 0.) b_dkt += tl.where(m_i, b_dAkk[:, None] * b_kbj[None, :] * b_gkq, 0.) p_qj += H*K p_kj += H*K p_gkj += HV*K p_bj += HV p_dk = tl.make_block_ptr(dk, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_dk2 = tl.make_block_ptr(dk2, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_dg = tl.make_block_ptr(dg, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) p_dg2 = tl.make_block_ptr(dg2, (T, K), (HV*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) b_dk2 += b_dkt tl.store(p_dk2, b_dk2.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dg2, b_dg2.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def kda_fwd_kernel_intra_sub_chunk( q, k, g, beta, Aqk, Akk, scale, cu_seqlens, chunk_indices, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, IS_VARLEN: tl.constexpr, USE_GATHER: tl.constexpr, ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + 
i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T i_ti = i_t * BT + i_i * BC if i_ti >= T: return o_c = i_ti + tl.arange(0, BC) m_c = o_c < T q = q + (bos * H + i_h) * K k = k + (bos * H + i_h) * K g = g + (bos * HV + i_hv) * K beta = beta + bos * HV + i_hv Aqk = Aqk + (bos * HV + i_hv) * BT Akk = Akk + (bos * HV + i_hv) * BC p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, 0), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_ti, 0), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_ti, 0), (BC, BK), (1, 0)) p_beta = tl.make_block_ptr(beta, (T,), (HV,), (i_ti,), (BC,), (0,)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) b_beta = tl.load(p_beta, boundary_check=(0,)) if USE_GATHER: b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) else: # caculate offset p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * HV*K + tl.arange(0, BK) b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0) b_gn = b_gn[None, :] # current block, keep numerical stability by subtracting the left boundary # less than 85 to avoid overflow in exp2 b_gm = (b_g - b_gn).to(tl.float32) b_gq = tl.where(m_c[:, None], exp2(b_gm), 0.) b_gk = tl.where(m_c[:, None], exp2(-b_gm), 0.) 
b_kgt = tl.trans(b_k * b_gk) b_Aqk = tl.dot(b_q * b_gq, b_kgt) * scale b_Akk = tl.dot(b_k * b_gq, b_kgt) * b_beta[:, None] o_i = tl.arange(0, BC) m_Aqk = o_i[:, None] >= o_i[None, :] m_Akk = o_i[:, None] > o_i[None, :] m_I = o_i[:, None] == o_i[None, :] b_Aqk = tl.where(m_Aqk, b_Aqk, 0.0) b_Akk = tl.where(m_Akk, b_Akk, 0.0) p_Aqk = tl.make_block_ptr(Aqk, (T, BT), (HV*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) p_Akk = tl.make_block_ptr(Akk, (T, BC), (HV*BC, 1), (i_ti, 0), (BC, BC), (1, 0)) tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1)) tl.debug_barrier() ################################################################################ # forward substitution ################################################################################ b_Ai = -b_Akk for i in range(2, min(BC, T - i_ti)): b_a = -tl.load(Akk + (i_ti + i) * HV*BC + o_i) b_a = tl.where(o_i < i, b_a, 0.) b_a += tl.sum(b_a[:, None] * b_Ai, 0) b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai) b_Ai += m_I tl.store(p_Akk, b_Ai.to(Akk.dtype.element_ty), boundary_check=(0, 1)) def kda_fwd_intra( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gk: torch.Tensor | None = None, beta: torch.Tensor | None = None, scale: float | None = None, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, safe_gate: bool = False, disable_recompute: bool = False, ): B, T, H, K, HV = *k.shape, gk.shape[2] BT = chunk_size if BT not in (32, 64): raise ValueError(f"KDA intra chunk kernel only supports chunk_size 32 or 64, got {BT}.") BC = 16 if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NC = triton.cdiv(BT, BC) Aqk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) # Akk must be zero-initialized - kernel only writes lower triangular Akk = 
torch.zeros(B, T, HV, BT, device=k.device, dtype=k.dtype) # Separate fp32 buffer for diagonal 16x16 blocks (for precision in solve_tril) Akkd = torch.empty(B, T, HV, BC, device=k.device, dtype=torch.float32) # Step 1: Run token_parallel first to compute diagonal blocks into Akkd (fp32) # Step 1: compute diagonal blocks into Akk_diag (fp32) if safe_gate: grid = (NT, NC, B * HV) BK = triton.next_power_of_2(K) kda_fwd_kernel_intra_sub_chunk[grid]( q=q, k=k, g=gk, beta=beta, Aqk=Aqk, Akk=Akkd, scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, HV=HV, K=K, BT=BT, BC=BC, BK=BK, USE_GATHER=False, ) else: Aqk, Akkd = kda_fwd_intra_token_parallel( q=q, k=k, gk=gk, beta=beta, Aqk=Aqk, Akk=Akkd, scale=scale, cu_seqlens=cu_seqlens, chunk_size=BT, sub_chunk_size=BC, ) # Step 2: Fused inter + solve_tril (works for both fixed-len and varlen) grid = (NT, B * HV) kda_fwd_kernel_inter_solve_fused[grid]( q=q, k=k, g=gk, beta=beta, Aqk=Aqk, Akkd=Akkd, Akk=Akk, scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, HV=HV, K=K, BT=BT, BC=BC, NC=NC, USE_SAFE_GATE=safe_gate, ) w, u, qg, kg = recompute_w_u_fwd( k=k, v=v, beta=beta, A=Akk, q=q if disable_recompute else None, gk=gk, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, ) return w, u, qg, kg, Aqk, Akk def kda_bwd_intra( q: torch.Tensor, k: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, dAqk: torch.Tensor, dAkk: torch.Tensor, dq: torch.Tensor, dk: torch.Tensor, db: torch.Tensor, dg: torch.Tensor, cu_seqlens: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, chunk_size: int = 64, safe_gate: bool = False, ): B, T, H, K, HV = *k.shape, g.shape[2] BT = chunk_size BC = min(16, BT) BK = min(32, triton.next_power_of_2(K)) if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) dq2 = 
torch.empty_like(dq) dk2 = torch.empty_like(dk) db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) dg2 = torch.empty_like(dg, dtype=torch.float) grid = (NK * NC, NT, B * HV) kda_bwd_kernel_intra[grid]( q=q, k=k, g=g, beta=beta, dAqk=dAqk, dAkk=dAkk, dq=dq, dq2=dq2, dk=dk, dk2=dk2, dg=dg, dg2=dg2, db=db2, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, B=B, T=T, H=H, HV=HV, K=K, BT=BT, BC=BC, BK=BK, NC=NC, SAFE_GATE=safe_gate, USE_GATHER=False, ) dq = dq2 dk = dk2 db = db2.sum(0).add_(db) dg = dg2 return dq, dk, db, dg # === wy_fast.py === # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # For a list of all contributors, visit: # https://github.com/fla-org/flash-linear-attention/graphs/contributors import torch import triton import triton.language as tl @triton.jit(do_not_specialize=['T']) def recompute_w_u_fwd_kda_kernel( q, k, qg, kg, v, beta, w, u, A, gk, cu_seqlens, chunk_indices, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, STORE_QG: tl.constexpr, STORE_KG: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T k += (bos * H + i_h) * K v += (bos * HV + i_hv) * V u += (bos * HV + i_hv) * V w += (bos * HV + i_hv) * K gk += (bos * HV + i_hv) * K beta += bos * HV + i_hv A += (bos * HV + i_hv) * BT if STORE_QG: q += (bos * H + i_h) * K qg += (bos * HV + i_hv) * K if STORE_KG: kg += (bos * HV + i_hv) * K p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_t * BT,), (BT,), (0,)) b_b = 
tl.load(p_b, boundary_check=(0,)) p_A = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) b_A = tl.load(p_A, boundary_check=(0, 1)) for i_v in range(tl.cdiv(V, BV)): p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_u = tl.make_block_ptr(u, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_v = tl.load(p_v, boundary_check=(0, 1)) b_vb = (b_v * b_b[:, None]).to(b_v.dtype) b_u = tl.dot(b_A, b_vb) tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) for i_k in range(tl.cdiv(K, BK)): p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = b_k * b_b[:, None] p_gk = tl.make_block_ptr(gk, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) b_kb *= exp2(b_gk) if STORE_QG: p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_qg = tl.make_block_ptr(qg, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_qg = b_q * exp2(b_gk) tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) if STORE_KG: last_idx = min(i_t * BT + BT, T) - 1 o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K b_gn = tl.load(gk + last_idx * HV*K + o_k, mask=m_k, other=0.).to(tl.float32) b_kg = b_k * tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], exp2(b_gn[None, :] - b_gk), 0) p_kg = tl.make_block_ptr(kg, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def prepare_wy_repr_bwd_kda_kernel( k, v, beta, gk, A, dA, dw, du, dk, dk2, dv, db, dg, dg2, cu_seqlens, chunk_indices, T, H: 
tl.constexpr, HV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T k += (bos * H + i_h) * K v += (bos * HV + i_hv) * V beta += bos * HV + i_hv gk += (bos * HV + i_hv) * K A += (bos * HV + i_hv) * BT dA += (bos * HV + i_hv) * BT dk += (bos * HV + i_hv) * K dk2 += (bos * HV + i_hv) * K dw += (bos * HV + i_hv) * K du += (bos * HV + i_hv) * V dv += (bos * HV + i_hv) * V db += bos * HV + i_hv dg += (bos * HV + i_hv) * K dg2 += (bos * HV + i_hv) * K p_b = tl.make_block_ptr(beta, (T,), (HV,), (i_t * BT,), (BT,), (0,)) p_db = tl.make_block_ptr(db, (T,), (HV,), (i_t * BT,), (BT,), (0,)) p_A = tl.make_block_ptr(A, (BT, T), (1, HV*BT), (0, i_t * BT), (BT, BT), (0, 1)) b_b = tl.load(p_b, boundary_check=(0,)) b_db = tl.zeros([BT], dtype=tl.float32) b_A = tl.load(p_A, boundary_check=(0, 1)) b_dA = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk2 = tl.make_block_ptr(dk2, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dw = tl.make_block_ptr(dw, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg = tl.make_block_ptr(dg, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg2 = tl.make_block_ptr(dg2, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) # [BT, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) p_gk = tl.make_block_ptr(gk, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_gk_exp 
= exp2(tl.load(p_gk, boundary_check=(0, 1))) b_kbg = b_k * b_b[:, None] * b_gk_exp b_dw = tl.load(p_dw, boundary_check=(0, 1)) b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) b_dkbg = tl.dot(b_A, b_dw) b_dk = b_dkbg * b_gk_exp * b_b[:, None] + tl.load(p_dk, boundary_check=(0, 1)) b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) b_dg = b_kbg * b_dkbg + tl.load(p_dg, boundary_check=(0, 1)) tl.store(p_dk2, b_dk.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dg2, b_dg.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) for i_v in range(tl.cdiv(V, BV)): p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_du = tl.make_block_ptr(du, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_v = tl.load(p_v, boundary_check=(0, 1)) b_vb = (b_v * b_b[:, None]).to(b_v.dtype) b_du = tl.load(p_du, boundary_check=(0, 1)) b_dA += tl.dot(b_du, tl.trans(b_vb)) b_dvb = tl.dot(b_A, b_du) b_dv = b_dvb * b_b[:, None] b_db += tl.sum(b_dvb * b_v, 1) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) o_t = i_t * BT + tl.arange(0, BT) m_t = o_t < T m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) b_dA = tl.where(m_A, b_dA, 0) b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) b_dA = tl.where(m_A, -b_dA, 0) p_dA = tl.make_block_ptr(dA, (T, BT), (HV*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) def recompute_w_u_fwd( k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor, A: torch.Tensor, gk: torch.Tensor, q: torch.Tensor | None = None, cu_seqlens: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] BT = 
A.shape[-1] BK = 64 BV = 64 if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) w = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) u = torch.empty_like(v) qg = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) if q is not None else None kg = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) recompute_w_u_fwd_kda_kernel[(NT, B*HV)]( q=q, k=k, qg=qg, kg=kg, v=v, beta=beta, w=w, u=u, A=A, gk=gk, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, HV=HV, K=K, V=V, BT=BT, BK=BK, BV=BV, STORE_QG=q is not None, STORE_KG=True, IS_VARLEN=cu_seqlens is not None, ) return w, u, qg, kg def prepare_wy_repr_bwd( k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor, gk: torch.Tensor, A: torch.Tensor, dk: torch.Tensor, dw: torch.Tensor, du: torch.Tensor, dg: torch.Tensor, cu_seqlens: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] BT = A.shape[-1] if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) dk2 = torch.empty_like(dk, dtype=torch.float) dv = torch.empty_like(v) dg2 = torch.empty_like(gk, dtype=torch.float) dA = torch.empty_like(A, dtype=torch.float) db = torch.empty_like(beta, dtype=torch.float) prepare_wy_repr_bwd_kda_kernel[(NT, B * HV)]( k=k, v=v, beta=beta, gk=gk, A=A, dA=dA, dw=dw, du=du, dk=dk, dk2=dk2, dv=dv, db=db, dg=dg, dg2=dg2, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, HV=HV, K=K, V=V, BT=BT, BK=BK, BV=BV, ) dk = dk2 dg = dg2 return dk, dv, db, dg, dA # === chunk_delta_h.py === # 
Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# For a list of all contributors, visit:
# https://github.com/fla-org/flash-linear-attention/graphs/contributors

import torch
import triton
import triton.language as tl


# Forward chunk recurrence of the (gated) delta rule.
#
# Grid: program_id(0) = i_v picks a BV-wide tile of V; program_id(1) = i_nh
# = sequence * HV + value head.  The K dimension is held in up to four 64-wide
# fp32 tiles b_h1..b_h4, so K is presumably <= 256 -- confirm at the launch site.
# STATE_V_FIRST stores every state tile as (V, K) instead of (K, V).
#
# For each chunk i_t (BT tokens) the kernel:
#   1. stores the current state into h[chunk i_t], i.e. the state *before* this
#      chunk's update;
#   2. forms v_new = v - w @ h and, with SAVE_NEW_VALUE, writes it to v_new
#      (before any gating is applied);
#   3. applies the decays: USE_G scales v_new rows by exp2(g_last - g) (padded
#      rows zeroed) and the state by exp2(g_last); USE_GK scales the state along
#      K by exp2(gk at the chunk's last valid row);
#   4. accumulates state += k^T @ v_new.
# With STORE_FINAL_STATE the last state is written to ht.  Gates go through exp2,
# so g/gk are presumably log2-scaled -- confirm against the caller.
# k is indexed with H heads (k head = i_h // (HV // H)); v, w, v_new, g, gk, h use HV heads.
@triton.jit(do_not_specialize=['T'])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    STATE_V_FIRST: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // HV, i_nh % HV
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        # index of this sequence's first chunk in h
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # fp32 state tiles, one per 64-wide slice of K
    if STATE_V_FIRST:
        b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
        if K > 64:
            b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
        if K > 128:
            b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
        if K > 192:
            b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
    else:
        b_h1 = tl.zeros([64, BV], dtype=tl.float32)
        if K > 64:
            b_h2 = tl.zeros([64, BV], dtype=tl.float32)
        if K > 128:
            b_h3 = tl.zeros([64, BV], dtype=tl.float32)
        if K > 192:
            b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset (int64 so large B*T*HV*K*V products do not overflow)
    h += (boh * HV + i_h).to(tl.int64) * K*V
    v += (bos * HV + i_h).to(tl.int64) * V
    k += (bos * H + i_h // (HV // H)).to(tl.int64) * K
    w += (bos * HV + i_h).to(tl.int64) * K
    if SAVE_NEW_VALUE:
        v_new += (bos * HV + i_h).to(tl.int64) * V

    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K*V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K*V

    # load initial state
    if USE_INITIAL_STATE:
        if STATE_V_FIRST:
            p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        else:
            p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            if STATE_V_FIRST:
                p_h0_2 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0))
            else:
                p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            if STATE_V_FIRST:
                p_h0_3 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0))
            else:
                p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            if STATE_V_FIRST:
                p_h0_4 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0))
            else:
                p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        i_t_int64 = i_t.to(tl.int64)
        # 1. save the state entering this chunk
        if STATE_V_FIRST:
            p_h1 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        else:
            p_h1 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            if STATE_V_FIRST:
                p_h2 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0))
            else:
                p_h2 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            if STATE_V_FIRST:
                p_h3 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0))
            else:
                p_h3 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            if STATE_V_FIRST:
                p_h4 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0))
            else:
                p_h4 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        # 2. b_v = w @ h, accumulated over the K tiles
        p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        if STATE_V_FIRST:
            b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
        else:
            b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
            else:
                b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
            else:
                b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
            else:
                b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
        # v_new = v - w @ h
        p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

        if SAVE_NEW_VALUE:
            p_v = tl.make_block_ptr(v_new, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        # 3. decays, relative to the chunk's last valid token
        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            m_t = (i_t * BT + tl.arange(0, BT)) < T
            b_g_last = tl.load(g + (bos * HV + last_idx * HV + i_h).to(tl.int64)).to(tl.float32)
            p_g = tl.make_block_ptr(g + (bos * HV + i_h).to(tl.int64), (T,), (HV,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
            b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None]
            b_g_last = exp2(b_g_last)
            b_h1 *= b_g_last
            if K > 64:
                b_h2 *= b_g_last
            if K > 128:
                b_h3 *= b_g_last
            if K > 192:
                b_h4 *= b_g_last

        if USE_GK:
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
            if STATE_V_FIRST:
                b_h1 *= exp2(b_gk_last1)[None, :]
            else:
                b_h1 *= exp2(b_gk_last1)[:, None]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32)
                if STATE_V_FIRST:
                    b_h2 *= exp2(b_gk_last2)[None, :]
                else:
                    b_h2 *= exp2(b_gk_last2)[:, None]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32)
                if STATE_V_FIRST:
                    b_h3 *= exp2(b_gk_last3)[None, :]
                else:
                    b_h3 *= exp2(b_gk_last3)[:, None]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32)
                if STATE_V_FIRST:
                    b_h4 *= exp2(b_gk_last4)[None, :]
                else:
                    b_h4 *= exp2(b_gk_last4)[:, None]
        b_v = b_v.to(k.dtype.element_ty)

        # 4. state += k^T @ v_new (k loaded through a transposed (K, T) view)
        p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        if STATE_V_FIRST:
            b_h1 += tl.trans(tl.dot(b_k, b_v))
        else:
            b_h1 += tl.dot(b_k, b_v)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_h2 += tl.trans(tl.dot(b_k, b_v))
            else:
                b_h2 += tl.dot(b_k, b_v)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_h3 += tl.trans(tl.dot(b_k, b_v))
            else:
                b_h3 += tl.dot(b_k, b_v)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if STATE_V_FIRST:
                b_h4 += tl.trans(tl.dot(b_k, b_v))
            else:
                b_h4 += tl.dot(b_k, b_v)

    # epilogue: write the state after the last chunk
    if STORE_FINAL_STATE:
        if STATE_V_FIRST:
            p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        else:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            if STATE_V_FIRST:
                p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0))
            else:
                p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            if STATE_V_FIRST:
                p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0))
            else:
                p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            if STATE_V_FIRST:
                p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0))
            else:
                p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


@triton.jit(do_not_specialize=['T']) def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( q, k, w, g, gk, dht, dh0, do, dh, dv, dv2, cu_seqlens, chunk_offsets, scale, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, USE_GK: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, STATE_V_FIRST: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // HV, i_nh % HV if IS_VARLEN: bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) boh = tl.load(chunk_offsets + i_n).to(tl.int32) else: bos, eos = i_n * T, i_n * T + T NT = tl.cdiv(T, BT) boh = i_n * NT if STATE_V_FIRST: b_dh1 = tl.zeros([BV, 64], dtype=tl.float32) if K > 64: b_dh2 = tl.zeros([BV, 64], dtype=tl.float32) if K > 128: b_dh3 = tl.zeros([BV, 64], dtype=tl.float32) if K > 192: b_dh4 = tl.zeros([BV, 64], dtype=tl.float32)
else: b_dh1 = tl.zeros([64, BV], dtype=tl.float32) if K > 64: b_dh2 = tl.zeros([64, BV], dtype=tl.float32) if K > 128: b_dh3 = tl.zeros([64, BV], dtype=tl.float32) if K > 192: b_dh4 = tl.zeros([64, BV], dtype=tl.float32) # calculate offset q += (bos * H + i_h // (HV // H)).to(tl.int64) * K k += (bos * H + i_h // (HV // H)).to(tl.int64) * K w += (bos * HV + i_h).to(tl.int64) * K do += (bos * HV + i_h).to(tl.int64) * V dv += (bos * HV + i_h).to(tl.int64) * V dv2 += (bos * HV + i_h).to(tl.int64) * V dh += (boh * HV + i_h).to(tl.int64) * K*V if USE_GK: gk += (bos * HV + i_h).to(tl.int64) * K if USE_INITIAL_STATE: dh0 += i_nh * K*V if USE_FINAL_STATE_GRADIENT: dht += i_nh * K*V if USE_FINAL_STATE_GRADIENT: if STATE_V_FIRST: p_dht1 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) else: p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) if K > 64: if STATE_V_FIRST: p_dht2 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) else: p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) if K > 128: if STATE_V_FIRST: p_dht3 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) else: p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) if K > 192: if STATE_V_FIRST: p_dht4 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) else: p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) for i_t in range(NT - 1, -1, -1): i_t_int64 = i_t.to(tl.int64) if STATE_V_FIRST: p_dh1 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) else: p_dh1 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh1, 
b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) if K > 64: if STATE_V_FIRST: p_dh2 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) else: p_dh2 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) if K > 128: if STATE_V_FIRST: p_dh3 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) else: p_dh3 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) if K > 192: if STATE_V_FIRST: p_dh4 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) else: p_dh4 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) last_idx = min((i_t + 1) * BT, T) - 1 if USE_G: bg_last = tl.load(g + (bos + last_idx) * HV + i_h).to(tl.float32) p_g = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) bg_last_exp = exp2(bg_last) b_g_exp = exp2(b_g) p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_dv2 = tl.make_block_ptr(dv2, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_do = tl.make_block_ptr(do, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_do = tl.load(p_do, boundary_check=(0, 1)) # Update dv p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k1 = tl.arange(0, 64) b_gk_last1 = tl.load(gk + last_idx * HV*K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32) if STATE_V_FIRST: b_dv = tl.dot(b_k, tl.trans(b_dh1).to(b_k.dtype)) else: b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) if K > 64: p_k = 
tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k2 = 64 + o_k1 b_gk_last2 = tl.load(gk + last_idx * HV*K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32) if STATE_V_FIRST: b_dv += tl.dot(b_k, tl.trans(b_dh2).to(b_k.dtype)) else: b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) if K > 128: p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k3 = 128 + o_k1 b_gk_last3 = tl.load(gk + last_idx * HV*K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32) if STATE_V_FIRST: b_dv += tl.dot(b_k, tl.trans(b_dh3).to(b_k.dtype)) else: b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) if K > 192: p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k4 = 192 + o_k1 b_gk_last4 = tl.load(gk + last_idx * HV*K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32) if STATE_V_FIRST: b_dv += tl.dot(b_k, tl.trans(b_dh4).to(b_k.dtype)) else: b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) if USE_G: m_t = (i_t * BT + tl.arange(0, BT)) < T b_dv *= tl.where(m_t, exp2(bg_last - b_g), 0)[:, None] b_dv += tl.load(p_dv, boundary_check=(0, 1)) tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) # Update dh p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (0, i_t * BT), (64, BT), (0, 1)) p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) if USE_G: b_dh1 *= bg_last_exp b_q = b_q * b_g_exp[None, :] if USE_GK: if STATE_V_FIRST: b_dh1 *= exp2(b_gk_last1)[None, :] else: b_dh1 *= exp2(b_gk_last1[:, None]) if STATE_V_FIRST: b_dh1 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) else: b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 64: p_q = tl.make_block_ptr(q, 
(K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (64, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: b_dh2 *= bg_last_exp b_q = b_q * b_g_exp[None, :] if USE_GK: if STATE_V_FIRST: b_dh2 *= exp2(b_gk_last2)[None, :] else: b_dh2 *= exp2(b_gk_last2[:, None]) if STATE_V_FIRST: b_dh2 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) else: b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 128: p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (128, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: b_dh3 *= bg_last_exp b_q = b_q * b_g_exp[None, :] if USE_GK: if STATE_V_FIRST: b_dh3 *= exp2(b_gk_last3)[None, :] else: b_dh3 *= exp2(b_gk_last3[:, None]) if STATE_V_FIRST: b_dh3 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) else: b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 192: p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (192, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: b_dh4 *= bg_last_exp b_q = b_q * b_g_exp[None, :] if USE_GK: if STATE_V_FIRST: b_dh4 *= exp2(b_gk_last4)[None, :] else: b_dh4 *= exp2(b_gk_last4[:, None]) if STATE_V_FIRST: b_dh4 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) else: b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if USE_INITIAL_STATE: if STATE_V_FIRST: p_dh0 = tl.make_block_ptr(dh0, (V, K), (K, 1), 
(i_v * BV, 0), (BV, 64), (1, 0)) else: p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) if K > 64: if STATE_V_FIRST: p_dh1 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) else: p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) if K > 128: if STATE_V_FIRST: p_dh2 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) else: p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) if K > 192: if STATE_V_FIRST: p_dh3 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) else: p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) return h, v_new, final_state def chunk_gated_delta_rule_bwd_dhu( q: torch.Tensor, k: torch.Tensor, w: torch.Tensor, do: torch.Tensor, dv: torch.Tensor, g: torch.Tensor | None = None, gk: torch.Tensor | None = None, h0: torch.Tensor | None = None, dht: torch.Tensor | None = None, scale: float | None = None, state_v_first: bool = False, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, H, K, V, HV = *q.shape, do.shape[-1], do.shape[2] # N: the actual number of sequences in the batch with either equal or variable lengths BT = chunk_size assert K <= 256, "current kernel does not support head dimension being larger than 256." 
if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) if state_v_first: dh = q.new_empty(B, NT, HV, V, K) else: dh = q.new_empty(B, NT, HV, K, V) dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None dv2 = torch.empty_like(dv) def grid(meta): return (triton.cdiv(V, meta['BV']), N*HV) chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( q=q, k=k, w=w, g=g, gk=gk, dht=dht, dh0=dh0, do=do, dh=dh, dv=dv, dv2=dv2, cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, scale=scale, T=T, H=H, HV=HV, K=K, V=V, BT=BT, STATE_V_FIRST=state_v_first, ) return dh, dh0, dv2 # === chunk.py === # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
# For a list of all contributors, visit: # https://github.com/fla-org/flash-linear-attention/graphs/contributors import torch import triton import triton.language as tl @triton.jit(do_not_specialize=['T']) def chunk_gla_fwd_A_kernel_intra_sub_inter( q, k, g, A, cu_seqlens, chunk_indices, scale, T, H: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_i, i_j = i_c // NC, i_c % NC if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if i_t * BT + i_i * BC >= T: return if i_i <= i_j: return b_A = tl.zeros([BC, BC], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) # [BC, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) b_qg = b_q * exp2(b_g - b_gn[None, :]) * scale # [BK, BC] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_kg = b_k * exp2(b_gn[:, None] - b_gk) # [BC, BC] using tf32 to improve precision here. 
b_A += tl.dot(b_qg, b_kg) p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_fwd_A_kernel_intra_sub_intra( q, k, g, A, cu_seqlens, chunk_indices, scale, T, H: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_j = i_i if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if i_t * BT + i_i * BC >= T: return o_i = tl.arange(0, BC) o_k = tl.arange(0, BK) o_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_j * BC m_k = o_k < K m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T q += (bos * H + i_h) * K k += (bos * H + i_h) * K g += (bos * H + i_h) * K A += (bos * H + i_h) * BT p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) p_k = k + (i_t * BT + i_j * BC) * H*K + o_k p_gk = g + (i_t * BT + i_j * BC) * H*K + o_k for j in range(0, min(BC, T - i_t * BT - i_i * BC)): b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) b_A = tl.sum(b_q * b_k[None, :] * exp2(b_g - b_gk[None, :]), 1) * scale tl.store(A + o_A + j, b_A, mask=m_A) p_k += H*K p_gk += H*K tl.debug_barrier() b_A = tl.zeros([BC, BC], dtype=tl.float32) tl.store(A + o_A[:, None] + o_i, b_A, mask=m_A[:, None] & (o_i[:, None] < o_i)) @triton.jit(do_not_specialize=['T']) def 
chunk_gla_fwd_A_kernel_intra_sub_intra_split( q, k, g, A, cu_seqlens, chunk_indices, scale, T, B: tl.constexpr, H: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_tc // NC, i_tc % NC i_j = i_i if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) all = T T = eos - bos else: bos, eos = i_b * T, i_b * T + T all = B * T if i_t * BT + i_i * BC >= T: return o_i = tl.arange(0, BC) o_k = i_k * BK + tl.arange(0, BK) o_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC m_k = o_k < K m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T q += (bos * H + i_h) * K k += (bos * H + i_h) * K g += (bos * H + i_h) * K A += ((i_k * all + bos) * H + i_h) * BC p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) p_k = k + (i_t * BT + i_j * BC) * H*K + o_k p_gk = g + (i_t * BT + i_j * BC) * H*K + o_k for j in range(0, min(BC, T - i_t * BT - i_i * BC)): b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) b_A = tl.sum(b_q * b_k[None, :] * exp2(b_g - b_gk[None, :]), 1) * scale tl.store(A + o_A + j, b_A, mask=m_A) p_k += H*K p_gk += H*K tl.debug_barrier() b_A = tl.zeros([BC, BC], dtype=tl.float32) tl.store(A + o_A[:, None] + o_i, b_A, mask=m_A[:, None] & (o_i[:, None] < o_i)) @triton.jit(do_not_specialize=['T']) def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( A, A2, cu_seqlens, chunk_indices, T, B: tl.constexpr, H: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, 
NK: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) all = T T = eos - bos else: bos, eos = i_b * T, i_b * T + T all = B * T if i_t * BT + i_c * BC >= T: return b_A = tl.zeros([BC, BC], dtype=tl.float32) for i_k in range(0, NK): p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) b_A += tl.load(p_A, boundary_check=(0, 1)) p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_fwd_kernel_o( q, v, g, h, o, A, cu_seqlens, chunk_indices, scale, T, H: tl.constexpr, HV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, STATE_V_FIRST: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_hv = i_bh // HV, i_bh % HV i_h = i_hv // (HV // H) if IS_VARLEN: i_tg = i_t.to(tl.int64) i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) T = eos - bos NT = tl.cdiv(T, BT) else: NT = tl.cdiv(T, BT) i_tg = (i_b * NT + i_t).to(tl.int64) bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64) m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] q += (bos * H + i_h) * K g += (bos * HV + i_hv) * K v += (bos * HV + i_hv) * V o += (bos * HV + i_hv) * V h += (i_tg * HV + i_hv).to(tl.int64) * K * V A += (bos * HV + i_hv) * BT b_o = tl.zeros([BT, BV], dtype=tl.float32) for i_k 
in range(tl.cdiv(K, BK)): p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_g = tl.make_block_ptr(g, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) if STATE_V_FIRST: p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) else: p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) # [BT, BK] b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) # [BT, BK] b_qg = (b_q * exp2(b_g)).to(b_q.dtype) b_h = tl.load(p_h, boundary_check=(0, 1)) if i_k >= 0: if STATE_V_FIRST: b_o += tl.dot(b_qg, tl.trans(b_h).to(b_qg.dtype)) else: b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) b_o *= scale p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_o = tl.make_block_ptr(o, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_A = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) # [BT, BT] b_A = tl.load(p_A, boundary_check=(0, 1)) b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype) b_o += tl.dot(b_A, b_v) tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_bwd_kernel_intra( q, k, g, dA, dq, dk, cu_seqlens, chunk_indices, T, H: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_k, i_i = i_kc // NC, i_kc % NC if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T T = eos - bos if i_t * BT + i_i * BC >= T: return o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K p_g = tl.make_block_ptr(g 
+ (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_g = tl.load(p_g, boundary_check=(0, 1)) b_dq = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(0, i_i): p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_kg = b_k * exp2(b_gn[None, :] - b_gk) # [BC, BC] b_dA = tl.load(p_dA, boundary_check=(0, 1)) b_dq += tl.dot(b_dA, b_kg) b_dq *= exp2(b_g - b_gn[None, :]) o_i = tl.arange(0, BC) m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) # [BK,] b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) # [BC, BK] m_i = o_i[:, None] >= j # [BC, BK] # (SY 09/17) important to not use bf16 here to have a good precision. b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp2(b_g - b_gkj[None, :]), 0.) 
p_kj += H*K p_gkj += H*K tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.debug_barrier() # [BC, BK] b_dk = tl.zeros([BC, BK], dtype=tl.float32) NC = min(NC, tl.cdiv(T - i_t * BT, BC)) if i_i < NC - 1: p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(i_i + 1, NC): p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) o_j = i_t * BT + i_j * BC + o_i m_j = o_j < T # [BC, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_gq = tl.load(p_gq, boundary_check=(0, 1)) b_qg = b_q * tl.where(m_j[:, None], exp2(b_gq - b_gn[None, :]), 0) # [BC, BC] b_dA = tl.load(p_dA, boundary_check=(0, 1)) # [BC, BK] # (SY 09/17) important to not use bf16 here to have a good precision. b_dk += tl.dot(b_dA, b_qg) b_dk *= exp2(b_gn[None, :] - b_g) o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] b_dA = tl.load(dA + o_dA + j * H*BT) # [BK,] b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) # [BC, BK] m_i = o_i[:, None] <= j b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp2(b_gqj[None, :] - b_g), 0.) 
p_qj += H*K p_gqj += H*K tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_bwd_kernel_dA( v, do, dA, cu_seqlens, chunk_indices, scale, T, H: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T T = eos - bos b_dA = tl.zeros([BT, BT], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_dA += tl.dot(b_do, b_v) p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] b_dA = tl.where(m_s, b_dA * scale, 0.) 
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_bwd_kernel_dv( k, g, A, do, dh, dv, cu_seqlens, chunk_indices, T, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, STATE_V_FIRST: tl.constexpr, ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) else: NT = tl.cdiv(T, BT) i_tg = i_b * NT + i_t bos, eos = i_b * T, i_b * T + T p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_A = tl.load(p_A, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.) # (SY 09/17) important to disallow tf32 here to maintain a good precision. 
b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) for i_k in range(tl.cdiv(K, BK)): o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k if STATE_V_FIRST: # dh stored as [V, K]; read a logical [BK, BV] tile via on-the-fly transpose p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (1, K), (i_k * BK, i_v * BV), (BK, BV), (0, 1)) else: p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) b_dh = tl.load(p_dh, boundary_check=(0, 1)) b_gn = exp2(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) b_k = (b_k * b_gn).to(b_k.dtype) # [BT, BV] # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) @triton.jit(do_not_specialize=['T']) def chunk_gla_bwd_kernel_inter( q, k, v, g, h, do, dh, dq, dk, dq2, dk2, dg, cu_seqlens, chunk_indices, scale, T, H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, IS_VARLEN: tl.constexpr, STATE_V_FIRST: tl.constexpr, ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: i_tg = i_t i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) T = eos - bos NT = tl.cdiv(T, BT) else: NT = tl.cdiv(T, BT) i_tg = i_b * NT + i_t bos, eos = i_b * T, i_b * T + T o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K q += (bos * H + i_h) * K k += (bos * 
H + i_h) * K v += (bos * H + i_h) * V g += (bos * H + i_h) * K h += (i_tg * H + i_h) * K*V do += (bos * H + i_h) * V dh += (i_tg * H + i_h) * K*V dq += (bos * H + i_h) * K dk += (bos * H + i_h) * K dq2 += (bos * H + i_h) * K dk2 += (bos * H + i_h) * K dg += (bos * H + i_h) * K p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) p_gn = g + (min(T, i_t * BT + BT) - 1) * H*K + o_k b_gn = tl.load(p_gn, mask=m_k, other=0) b_dq = tl.zeros([BT, BK], dtype=tl.float32) b_dk = tl.zeros([BT, BK], dtype=tl.float32) b_dgk = tl.zeros([BK], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) if STATE_V_FIRST: # h / dh stored as [V, K] -- the [BV, BK] tile is now a contiguous read p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) p_dh = tl.make_block_ptr(dh, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) else: p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BK] b_h = tl.load(p_h, boundary_check=(0, 1)) b_dh = tl.load(p_dh, boundary_check=(0, 1)) # [BK] b_dgk += tl.sum(b_h * b_dh, axis=0) # [BT, BK] b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) b_dgk *= exp2(b_gn) b_dq *= scale b_dq = b_dq * exp2(b_gk) b_dk = b_dk * exp2(b_gn[None, :] - b_gk) p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), 
(i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_dgk += tl.sum(b_dk * b_k, axis=0) b_dq += tl.load(p_dq, boundary_check=(0, 1)) b_dk += tl.load(p_dk, boundary_check=(0, 1)) b_dg = b_q * b_dq - b_k * b_dk # tl.debug_barrier() b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] # Buggy due to strange triton compiler issue. # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] p_dq = tl.make_block_ptr(dq2, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk2, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) def chunk_gla_fwd_intra_gk( q: torch.Tensor, k: torch.Tensor, g: torch.Tensor, scale: float, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, ): B, T, H, K = k.shape BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BC = min(16, BT) NC = triton.cdiv(BT, BC) A = q.new_empty(B, T, H, BT, dtype=torch.float) grid = (NT, NC * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( q=q, k=k, g=g, A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, scale=scale, T=T, H=H, K=K, BT=BT, BC=BC, NC=NC, ) grid = (NT, NC, B * H) # load the entire [BC, K] blocks into SRAM at once if K <= 256: BK = max(triton.next_power_of_2(K), 16) chunk_gla_fwd_A_kernel_intra_sub_intra[grid]( q=q, k=k, g=g, A=A, cu_seqlens=cu_seqlens, 
chunk_indices=chunk_indices, scale=scale, T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, ) # split then merge else: BK = min(128, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) grid = (NK, NT * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( q=q, k=k, g=g, A=A_intra, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, scale=scale, T=T, B=B, H=H, K=K, BT=BT, BC=BC, BK=BK, NC=NC, ) grid = (NT, NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid]( A=A_intra, A2=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, B=B, H=H, BT=BT, BC=BC, NK=NK, ) return A def chunk_gla_fwd_o_gk( q: torch.Tensor, v: torch.Tensor, g: torch.Tensor, A: torch.Tensor, h: torch.Tensor, scale: float, state_v_first: bool = False, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, chunk_indices: torch.LongTensor | None = None, ): B, T, H, K, HV, V = *q.shape, v.shape[2], v.shape[-1] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # Please ensure zeros, since vllm will use padding v o = torch.zeros_like(v) BV = 32 chunk_gla_fwd_kernel_o[(triton.cdiv(V, BV), NT, B * HV)]( q=q, v=v, g=g, h=h, o=o, A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, scale=scale, T=T, H=H, HV=HV, K=K, V=V, BT=BT, STATE_V_FIRST=state_v_first, BK=64, BV=BV, IS_VARLEN=cu_seqlens is not None, ) return o def chunk_gated_delta_rule_fwd_h( k, w, u, g=None, gk=None, initial_state=None, output_final_state=False, chunk_size=64, save_new_value=True, state_v_first=False, cu_seqlens=None, cu_seqlens_cpu=None, chunk_indices=None, ): B, T, H, K, V, HV = *k.shape, u.shape[-1], u.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, 
BT), None else: N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) if state_v_first: h = k.new_empty(B, NT, HV, V, K) final_state = k.new_zeros(N, HV, V, K, dtype=torch.float32) if output_final_state else None else: h = k.new_empty(B, NT, HV, K, V) final_state = k.new_zeros(N, HV, K, V, dtype=torch.float32) if output_final_state else None v_new = torch.empty_like(u) if save_new_value else None BV = 32 chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * HV)]( k=k, v=u, w=w, v_new=v_new, g=g, gk=gk, h=h, h0=initial_state, ht=final_state, cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, T=T, H=H, HV=HV, K=K, V=V, BT=BT, STATE_V_FIRST=state_v_first, BV=BV, USE_G=False, USE_GK=gk is not None, USE_INITIAL_STATE=initial_state is not None, STORE_FINAL_STATE=output_final_state, SAVE_NEW_VALUE=save_new_value, IS_VARLEN=cu_seqlens is not None, ) return h, v_new, final_state def kda_forward(q, k, v, g, beta, scale: float, chunk_size: int = 64): B, T, H, K = q.shape HV = v.shape[2] gk = chunk_local_cumsum_vector( g, chunk_size=chunk_size, scale=RCP_LN2, output_dtype=torch.float32, ) BC = 16 BT = chunk_size NC = triton.cdiv(BT, BC) NT = triton.cdiv(T, BT) Aqk = torch.empty(B, T, HV, BT, device=q.device, dtype=q.dtype) Akkd = torch.empty(B, T, HV, BC, device=q.device, dtype=torch.float32) Akk = torch.zeros(B, T, HV, BT, device=q.device, dtype=q.dtype) Aqk, Akkd = kda_fwd_intra_token_parallel( q=q, k=k, gk=gk, beta=beta, Aqk=Aqk, Akk=Akkd, scale=scale, cu_seqlens=None, chunk_size=BT, sub_chunk_size=BC, ) kda_fwd_kernel_inter_solve_fused[(NT, B * HV)]( q=q, k=k, g=gk, beta=beta, Aqk=Aqk, Akkd=Akkd, Akk=Akk, scale=scale, cu_seqlens=None, chunk_indices=None, T=T, H=H, HV=HV, K=K, BT=BT, BC=BC, NC=NC, BK=32, USE_SAFE_GATE=False, IS_VARLEN=False, num_warps=2, ) w, u, _, kg = recompute_w_u_fwd( k=k, v=v, beta=beta, A=Akk, gk=gk, q=None, cu_seqlens=None, chunk_indices=None, ) h, v_new, _ = 
chunk_gated_delta_rule_fwd_h( k=kg, w=w, u=u, g=None, gk=gk, initial_state=None, output_final_state=False, chunk_size=BT, save_new_value=True, state_v_first=False, cu_seqlens=None, cu_seqlens_cpu=None, chunk_indices=None, ) return chunk_gla_fwd_o_gk( q=q, v=v_new, g=gk, A=Aqk, h=h, scale=scale, state_v_first=False, cu_seqlens=None, chunk_size=BT, chunk_indices=None, ) class Model(nn.Module): def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64): super().__init__() self.B, self.T, self.H, self.K, self.V = B, T, H, K, V self.chunk_size = chunk_size self.scale = float(K) ** -0.5 self.register_buffer("_dummy", torch.zeros(1), persistent=False) def forward(self, q, k, v, g, beta): return kda_forward(q, k, v, g, beta, scale=self.scale, chunk_size=self.chunk_size) B = 2 T = 1024 H = 8 K = 128 V = 128 CHUNK_SIZE = 64 def get_inputs(): torch.manual_seed(0) q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1 v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1 g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05) beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16)) return [q, k, v, g, beta] def get_init_inputs(): return [B, T, H, K, V, CHUNK_SIZE]