"""KDA forward (chunk form) — Triton/CUBLAS hybrid implementation for SM120.
Uses PyTorch bmm (cuBLAS) for batch matmuls (intra-chunk A, w, u) +
Triton fused kernel for the inter-chunk recurrence.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
from einops import rearrange
OP_TYPE = "linear_attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# ---------------------------------------------------------------------------
# Intra-chunk: PyTorch bmm + recursive inversion via torch.compile
# ---------------------------------------------------------------------------
def _intra_chunk(k, g, v, beta):
"""Compute w, u using batched matmul (cuBLAS)."""
B, H, NT, BT, Kdim = k.shape
Vdim = v.shape[-1]
device = k.device
k_f32 = k.to(torch.float32)
v_f32 = v.to(torch.float32)
beta_f32 = beta.to(torch.float32)
exp_g = torch.exp(g)
exp_neg_g = torch.exp(-g)
k_exp_g = k_f32 * exp_g
k_exp_neg_g = k_f32 * exp_neg_g
M = B * H * NT
k_exp_g_bmm = k_exp_g.reshape(M, BT, Kdim)
k_exp_neg_g_bmm = k_exp_neg_g.reshape(M, BT, Kdim)
# A = k_exp_g @ k_exp_neg_g.T via cuBLAS batched matmul
A = torch.bmm(k_exp_g_bmm, k_exp_neg_g_bmm.transpose(1, 2))
A_view = A.view(B, H, NT, BT, BT)
# Mask: zero upper triangle + diagonal, negate lower
mask_diag_upper = torch.triu(
torch.ones(BT, BT, dtype=torch.bool, device=device), diagonal=0
)
A_lower = -A_view.masked_fill(mask_diag_upper, 0.0)
# Recursive inversion
A_inv = _recursive_inversion(A_lower)
# final_A = (A_inv + I) * beta_j
eye = torch.eye(BT, dtype=torch.float32, device=device)
final_A = (A_inv + eye.view(1, 1, 1, BT, BT)) * beta_f32[..., None, :]
final_A_bmm = final_A.reshape(M, BT, BT)
# w = final_A @ k_exp_g, u = final_A @ v
w = torch.bmm(final_A_bmm, k_exp_g_bmm).view(B, H, NT, BT, Kdim)
u = torch.bmm(final_A_bmm, v_f32.reshape(M, BT, Vdim)).view(B, H, NT, BT, Vdim)
return w, u
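# Scale note (a worked example for shape 0: B=2, T=1024, H=8, K=V=128, BT=64):
# NT = T // BT = 16, so the A bmm above runs M = B*H*NT = 256 independent
# (64, 128) @ (128, 64) GEMMs; the w and u bmms reuse the same batching.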
@torch.compile(fullgraph=False, mode="reduce-overhead")
def _recursive_inversion(A_lower: torch.Tensor) -> torch.Tensor:
"""Compute (I - A_lower)^{-1} - I via triangular recursion."""
BT = A_lower.shape[-1]
A = A_lower.clone()
for i in range(1, BT):
left = A[..., i, :, None]
right = A[..., :, :i]
increment = (left * right).sum(-2)
A[..., i, :i] = A[..., i, :i] + increment
return A
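# Sketch of a self-check for the recursion above (a define-only debugging aid,
# not part of the harness interface): for strictly lower-triangular N,
# (I - N)^{-1} - I equals N + N^2 + ... + N^{BT-1}, which torch.linalg.inv
# confirms directly on a small random input.
def _check_recursive_inversion(bt: int = 8) -> None:
    torch.manual_seed(0)
    n = torch.randn(1, 1, 1, bt, bt).tril(-1) * 0.1  # strictly lower triangular
    eye = torch.eye(bt).view(1, 1, 1, bt, bt)
    want = torch.linalg.inv(eye - n) - eye
    got = _recursive_inversion(n)
    assert torch.allclose(got, want, atol=1e-5, rtol=1e-4)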
# ---------------------------------------------------------------------------
# Inter-chunk fused Triton kernel
# ---------------------------------------------------------------------------
@triton.jit
def _inter_chunk_triton(
# Precomputed fp32 tensors: (B, H, NT, BT, D)
q_exp_g_ptr, # q * exp(g) (B,H,NT,BT,K) fp32
k_exp_neg_g_ptr, # k * exp(-g) (B,H,NT,BT,K) fp32
k_f32_ptr, # k in fp32 (B,H,NT,BT,K) fp32
g_ptr, # g (cumsum) (B,H,NT,BT,K) fp32
exp_g_ptr, # exp(g) (B,H,NT,BT,K) fp32
w_ptr, # w from intra-chunk (B,H,NT,BT,K) fp32
u_ptr, # u from intra-chunk (B,H,NT,BT,V) fp32
o_ptr, # output (B,H,NT,BT,V) fp32
S_ptr, # state buffer (B*H, K, V) fp32 (in/out)
NT: tl.constexpr,
BT: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
stride_n_q: tl.constexpr, stride_bt_q: tl.constexpr, stride_d_q: tl.constexpr,
stride_n_k: tl.constexpr, stride_bt_k: tl.constexpr, stride_d_k: tl.constexpr,
stride_n_g: tl.constexpr, stride_bt_g: tl.constexpr, stride_d_g: tl.constexpr,
stride_n_e: tl.constexpr, stride_bt_e: tl.constexpr, stride_d_e: tl.constexpr,
stride_n_w: tl.constexpr, stride_bt_w: tl.constexpr, stride_d_w: tl.constexpr,
stride_n_u: tl.constexpr, stride_bt_u: tl.constexpr, stride_d_u: tl.constexpr,
stride_n_o: tl.constexpr, stride_bt_o: tl.constexpr, stride_d_o: tl.constexpr,
stride_s_h: tl.constexpr, stride_s_k: tl.constexpr, stride_s_v: tl.constexpr,
BLOCK_K: tl.constexpr, BLOCK_V: tl.constexpr,
):
"""
One program per (B,H) pair. Sequential inter-chunk recurrence.
State S (K, V) in global memory at S_ptr[pid, :, :].
"""
pid = tl.program_id(0)
    # Ranges. off_v spans the full V width: V is a power-of-two constexpr
    # (128 here), small enough that each program holds whole (.., V) tiles.
    off_bt = tl.arange(0, BT)
    off_k = tl.arange(0, BLOCK_K)
    off_v = tl.arange(0, V)
    # Base pointers for this (B, H) program. stride_n_* already counts the
    # elements in one chunk, so the per-head offset is pid * NT * stride_n_*.
    base_q = q_exp_g_ptr + pid * NT * stride_n_q
    base_k = k_exp_neg_g_ptr + pid * NT * stride_n_k
    base_kf = k_f32_ptr + pid * NT * stride_n_k
    base_g = g_ptr + pid * NT * stride_n_g
    base_e = exp_g_ptr + pid * NT * stride_n_e
    base_w = w_ptr + pid * NT * stride_n_w
    base_u = u_ptr + pid * NT * stride_n_u
    base_o = o_ptr + pid * NT * stride_n_o
    # State for this program: stride_s_h is already K * V elements per head
    base_S = S_ptr + pid * stride_s_h
    # Causal mask for Aqk. The diagonal is included: o_t reads the state
    # after the current token's write, so the s == t term belongs here.
    mask_causal = off_bt[:, None] >= off_bt[None, :]
    # Row-bounds mask (all-true for power-of-two BT, kept for shape safety)
    bt_mask = off_bt < BT
for n in range(NT):
        # Pointers to chunk n: one chunk is stride_n_* elements
        qn = base_q + n * stride_n_q
        kn = base_k + n * stride_n_k
        kfn = base_kf + n * stride_n_k
        gn = base_g + n * stride_n_g
        en = base_e + n * stride_n_e
        wn = base_w + n * stride_n_w
        un = base_u + n * stride_n_u
# ---- Aqk = q_exp_g @ k_exp_neg_g^T (BT, BT) ----
Aqk = tl.zeros((BT, BT), dtype=tl.float32)
for k0 in range(0, K, BLOCK_K):
k_off = k0 + off_k
# q_exp_g tile: (BT, BLOCK_K)
qe = tl.load(
qn + off_bt[:, None] * stride_bt_q + k_off[None, :] * stride_d_q,
mask=bt_mask[:, None] & (k_off[None, :] < K),
other=0.0,
)
# k_exp_neg_g tile: (BT, BLOCK_K)
kn_t = tl.load(
kn + off_bt[:, None] * stride_bt_k + k_off[None, :] * stride_d_k,
mask=bt_mask[:, None] & (k_off[None, :] < K),
other=0.0,
)
Aqk += tl.dot(qe, tl.trans(kn_t))
# Mask causal
Aqk = tl.where(mask_causal, Aqk, 0.0)
        # ---- wS = w @ S and qS = q_exp_g @ S: (BT, K) @ (K, V) -> (BT, V) ----
        # Tile over K only. Each S tile spans the full V width, so every
        # tl.dot yields a full (BT, V) result that accumulates directly.
        # (Tiling V as well makes tl.dot return (BT, BLOCK_V), which cannot
        # be added to a (BT, V) accumulator.)
        wS = tl.zeros((BT, V), dtype=tl.float32)
        qS = tl.zeros((BT, V), dtype=tl.float32)
        for k0 in range(0, K, BLOCK_K):
            k_off2 = k0 + off_k
            k_mask2 = k_off2 < K
            # w tile: (BT, BLOCK_K)
            w_t = tl.load(
                wn + off_bt[:, None] * stride_bt_w + k_off2[None, :] * stride_d_w,
                mask=bt_mask[:, None] & k_mask2[None, :],
                other=0.0,
            )
            # q_exp_g tile: (BT, BLOCK_K)
            qe2 = tl.load(
                qn + off_bt[:, None] * stride_bt_q + k_off2[None, :] * stride_d_q,
                mask=bt_mask[:, None] & k_mask2[None, :],
                other=0.0,
            )
            # S tile: (BLOCK_K, V)
            s_t = tl.load(
                base_S + k_off2[:, None] * stride_s_k + off_v[None, :] * stride_s_v,
                mask=k_mask2[:, None],
                other=0.0,
            )
            wS += tl.dot(w_t, s_t)
            qS += tl.dot(qe2, s_t)
        # ---- v_prime = u - wS ----
        # Load u at full (BT, V) width: Triton tensors are immutable, so tiles
        # must be loaded at their final width rather than scattered into
        # element-wise.
        u_t = tl.load(
            un + off_bt[:, None] * stride_bt_u + off_v[None, :] * stride_d_u,
            mask=bt_mask[:, None],
            other=0.0,
        )
        v_prime = u_t - wS
        # ---- o_n = qS + Aqk @ v_prime: (BT, BT) @ (BT, V) -> (BT, V) ----
        o_n = qS + tl.dot(Aqk, v_prime)
        # Store o_n (full V width)
        on_ptr = base_o + n * stride_n_o
        tl.store(
            on_ptr + off_bt[:, None] * stride_bt_o + off_v[None, :] * stride_d_o,
            o_n,
            mask=bt_mask[:, None],
        )
        # ---- Update S: S = exp(g_last) * S + (exp(g_last - g) * k).T @ v_prime ----
        # g_last = g[BT-1, :], the per-channel decay accumulated over the chunk.
        # Tile over K; V stays full-width, and the decay and the rank-BT update
        # are fused into a single load/store pass over S.
        for k0 in range(0, K, BLOCK_K):
            k_off6 = k0 + off_k
            k_mask6 = k_off6 < K
            # g_last tile: (BLOCK_K,)
            g_last = tl.load(
                gn + (BT - 1) * stride_bt_g + k_off6 * stride_d_g,
                mask=k_mask6,
                other=0.0,
            )
            # g and k tiles for this chunk: (BT, BLOCK_K)
            g_t = tl.load(
                gn + off_bt[:, None] * stride_bt_g + k_off6[None, :] * stride_d_g,
                mask=bt_mask[:, None] & k_mask6[None, :],
                other=0.0,
            )
            k_t = tl.load(
                kfn + off_bt[:, None] * stride_bt_k + k_off6[None, :] * stride_d_k,
                mask=bt_mask[:, None] & k_mask6[None, :],
                other=0.0,
            )
            # decay_k = exp(g_last - g) * k: (BT, BLOCK_K)
            dk_tile = tl.math.exp(g_last[None, :] - g_t) * k_t
            # Current S tile: (BLOCK_K, V)
            s_cur = tl.load(
                base_S + k_off6[:, None] * stride_s_k + off_v[None, :] * stride_s_v,
                mask=k_mask6[:, None],
                other=0.0,
            )
            # Fused decay + rank-BT update: (BLOCK_K, BT) @ (BT, V) -> (BLOCK_K, V)
            s_new = tl.math.exp(g_last)[:, None] * s_cur + tl.dot(tl.trans(dk_tile), v_prime)
            tl.store(
                base_S + k_off6[:, None] * stride_s_k + off_v[None, :] * stride_s_v,
                s_new,
                mask=k_mask6[:, None],
            )
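# Sketch of a pure-PyTorch oracle for the recurrence above (a define-only
# debugging aid, not part of the harness interface; it mirrors the eager
# recurrence that passed check.py earlier in the session, with beta already
# folded into w and u by _intra_chunk).
def _inter_chunk_oracle(q_exp_g, k_exp_neg_g, k_f32, g, w, u):
    B, H, NT, BT, Kdim = q_exp_g.shape
    Vdim = u.shape[-1]
    S = q_exp_g.new_zeros(B, H, Kdim, Vdim)
    causal = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=q_exp_g.device))
    o = torch.empty_like(u)
    for n in range(NT):
        Aqk = q_exp_g[:, :, n] @ k_exp_neg_g[:, :, n].transpose(-1, -2)
        Aqk = Aqk.masked_fill(~causal, 0.0)  # keep diagonal + strictly lower
        v_prime = u[:, :, n] - w[:, :, n] @ S
        o[:, :, n] = q_exp_g[:, :, n] @ S + Aqk @ v_prime
        g_last = g[:, :, n, -1]  # (B, H, K)
        decay_k = torch.exp(g_last[:, :, None, :] - g[:, :, n]) * k_f32[:, :, n]
        S = S * torch.exp(g_last)[..., None] + decay_k.transpose(-1, -2) @ v_prime
    return o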
# ---------------------------------------------------------------------------
# Host launchers
# ---------------------------------------------------------------------------
def _launch_inter_chunk(
q_exp_g, k_exp_neg_g, k_fp32, g, exp_g, w, u,
) -> torch.Tensor:
"""Launch fused inter-chunk Triton kernel. Returns o (B, T, H, V) bf16."""
B, H, NT, BT, Kdim = q_exp_g.shape
    Vdim = u.shape[-1]  # w is (B, H, NT, BT, K); the value dim lives on u
device = q_exp_g.device
# S buffer: (B*H, K, V) fp32
S_buf = torch.zeros(B * H, Kdim, Vdim, dtype=torch.float32, device=device)
# Output
o = torch.empty(B, H, NT, BT, Vdim, dtype=torch.float32, device=device)
# Ensure contiguous
q_exp_g = q_exp_g.contiguous()
k_exp_neg_g = k_exp_neg_g.contiguous()
k_fp32 = k_fp32.contiguous()
g = g.contiguous()
exp_g = exp_g.contiguous()
w = w.contiguous()
u = u.contiguous()
    BLOCK_K = min(32, Kdim)
    BLOCK_V = min(32, Vdim)  # unused by the kernel (full-width V tiles); kept so the launch signature stays stable
grid = (B * H,)
_inter_chunk_triton[grid](
q_exp_g, k_exp_neg_g, k_fp32, g, exp_g, w, u, o, S_buf,
NT=NT, BT=BT, K=Kdim, V=Vdim,
stride_n_q=q_exp_g.stride(2), stride_bt_q=q_exp_g.stride(3), stride_d_q=q_exp_g.stride(4),
stride_n_k=k_exp_neg_g.stride(2), stride_bt_k=k_exp_neg_g.stride(3), stride_d_k=k_exp_neg_g.stride(4),
stride_n_g=g.stride(2), stride_bt_g=g.stride(3), stride_d_g=g.stride(4),
stride_n_e=exp_g.stride(2), stride_bt_e=exp_g.stride(3), stride_d_e=exp_g.stride(4),
stride_n_w=w.stride(2), stride_bt_w=w.stride(3), stride_d_w=w.stride(4),
stride_n_u=u.stride(2), stride_bt_u=u.stride(3), stride_d_u=u.stride(4),
stride_n_o=o.stride(2), stride_bt_o=o.stride(3), stride_d_o=o.stride(4),
stride_s_h=Kdim * Vdim, stride_s_k=Vdim, stride_s_v=1,
BLOCK_K=BLOCK_K, BLOCK_V=BLOCK_V,
)
o = rearrange(o, "b h n c d -> b (n c) h d")
return o.to(dtype=torch.bfloat16)
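# Sketch: diff the Triton launcher against _inter_chunk_oracle (illustrative
# helper, never invoked by the harness).
def _check_inter_chunk(q_exp_g, k_exp_neg_g, k_fp32, g, exp_g, w, u):
    o_triton = _launch_inter_chunk(q_exp_g, k_exp_neg_g, k_fp32, g, exp_g, w, u)
    o_ref = rearrange(
        _inter_chunk_oracle(q_exp_g, k_exp_neg_g, k_fp32, g, w, u),
        "b h n c d -> b (n c) h d",
    ).to(torch.bfloat16)
    print("max |diff|:", (o_triton.float() - o_ref.float()).abs().max().item())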
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
"""KDA forward (chunk form) — hybrid PyTorch bmm + Triton recurrence."""
def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64):
super().__init__()
self.B, self.T, self.H, self.K, self.V = B, T, H, K, V
self.chunk_size = chunk_size
self.scale = float(K) ** -0.5
self.register_buffer("_dummy", torch.zeros(1), persistent=False)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
B, T, H, Kdim = q.shape
Vdim = v.shape[-1]
BT = self.chunk_size
scale = self.scale
# Scale q and reshape
q = q * scale
q = rearrange(q, "b (n c) h d -> b h n c d", c=BT).contiguous()
k = rearrange(k, "b (n c) h d -> b h n c d", c=BT).contiguous()
v = rearrange(v, "b (n c) h d -> b h n c d", c=BT).contiguous()
g = rearrange(g, "b (n c) h d -> b h n c d", c=BT).contiguous()
beta = rearrange(beta, "b (n c) h -> b h n c", c=BT).contiguous()
# In-chunk cumsum on g
g = g.cumsum(-2)
# ---- Intra-chunk: compute w and u via cuBLAS bmm ----
w, u = _intra_chunk(k, g, v, beta)
# ---- Precompute tensors for inter-chunk ----
q_fp32 = q.to(torch.float32)
k_fp32 = k.to(torch.float32)
exp_g = torch.exp(g)
q_exp_g = q_fp32 * exp_g
k_exp_neg_g = k_fp32 * torch.exp(-g)
# ---- Inter-chunk: fused Triton kernel ----
o = _launch_inter_chunk(
q_exp_g, k_exp_neg_g, k_fp32, g, exp_g, w, u,
)
return o
# Module-level shape shims
B = 2
T = 1024
H = 8
K = 128
V = 128
CHUNK_SIZE = 64
def get_inputs():
torch.manual_seed(0)
q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1
k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1
v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1
g = torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05
beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16))
return [q, k, v, g, beta]
def get_init_inputs():
return [B, T, H, K, V, CHUNK_SIZE]
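# Usage sketch (assumes a CUDA device; mirrors what check.py does):
#   model = Model(*get_init_inputs()).cuda()
#   q, k, v, g, beta = [t.cuda() for t in get_inputs()]
#   o = model(q, k, v, g, beta)  # (B, T, H, V) bf16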
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 113, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 74, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 426, in forward
o = _launch_inter_chunk(
^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 359, in _launch_inter_chunk
_inter_chunk_triton[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 120:16:
qe2 = tl.load(
qn + off_bt[:, None] * stride_bt_q + k_off2[None, :] * stride_d_q,
mask=bt_mask[:, None] & k_mask2[None, :],
other=0.0,
)
# S tile: (BLOCK_K, BLOCK_V)
s_t = tl.load(
base_S + k_off2[:, None] * stride_s_k + v_off[None, :] * stride_s_v,
mask=k_mask2[:, None] & v_mask2[None, :],
other=0.0,
)
wS += tl.dot(w_t, s_t)
^
ValueError('Cannot make_shape_compatible: incompatible dimensions at index 1: 128 and 32')
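The compile error is a tl.dot accumulation shape mismatch: wS is allocated as
(BT, V) = (64, 128), but each tl.dot(w_t, s_t) over a (BLOCK_K, BLOCK_V) tile
of S yields (BT, BLOCK_V) = (64, 32), which Triton cannot broadcast into the
full-width accumulator. A minimal sketch of the failing pattern and one fix,
loading S tiles at full V width (names illustrative):

    acc = tl.zeros((BT, V), dtype=tl.float32)
    acc += tl.dot(a_tile, s_tile)  # s_tile is (BLOCK_K, BLOCK_V): (64, 128) += (64, 32) -> error
    # fix: tile over K only, so each dot is (BT, BLOCK_K) @ (BLOCK_K, V) = (BT, V)
    s_tile = tl.load(
        S_base + k_off[:, None] * stride_sk + tl.arange(0, V)[None, :] * stride_sv,
        mask=k_mask[:, None], other=0.0,
    )
    acc += tl.dot(a_tile, s_tile)

The element-wise scatters further down (u_t[i, v_off3[vv]] = ..., dk_tile[i, :] = ...)
would hit the next wall regardless: Triton tensors are immutable and do not support
subscript assignment, so those tiles also have to be loaded at their final width.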
{
"problem": "02_kda_cutlass",
"harness": "opencode",
"model": "deepseek/deepseek-v4-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 34483, "output_tokens": 30134, "cache_read_tokens": 834944, "cache_creation_tokens": 0, "reasoning_tokens": 39148, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/reference.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/reference.py</path>
<type>file</type>
<content>
1: """Naive PyTorch reference for Kimi Delta Attention (KDA) forward, chunk form.
2:
3: This is the correctness oracle, NOT the SOTA baseline. It mirrors the
4: chunk-parallel formulation in fla/ops/kda/naive.py (Songlin Yang et al.)
5: without any Triton or CUDA optimization.
6:
7: Inputs (per the FLA convention):
8: q, k : (B, T, H, K) bf16 -- queries / keys
9: v : (B, T, H, V) bf16 -- values
10: g : (B, T, H, K) fp32 -- per-channel log-decay (in-chunk cumsum applied)
11: beta : (B, T, H) bf16 -- write strength
12:
13: Output:
14: o : (B, T, H, V) bf16
15:
16: The agent must reproduce this output (within bf16 tolerance) using a CUTLASS
17: CuTe kernel on SM120 -- NOT by calling fla.ops.chunk_kda directly.
18: """
19: from __future__ import annotations
20:
21: import torch
22: import torch.nn as nn
23: from einops import rearrange
24:
25: OP_TYPE = "linear_attention"
26: SUPPORTED_PRECISIONS = ["bf16"]
27: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
28:
29:
30: def _naive_chunk_kda(
31: q: torch.Tensor,
32: k: torch.Tensor,
33: v: torch.Tensor,
34: g: torch.Tensor,
35: beta: torch.Tensor,
36: scale: float,
37: chunk_size: int = 64,
38: ) -> torch.Tensor:
39: """KDA forward, no initial state, no final state. Returns o with v's dtype."""
40: dtype = v.dtype
41: B, T,
... (4359 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for KDA forward (chunk form).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol (bf16 default 1e-2 plus a 5e-2 override for
5: this problem). Also rejects forbidden ops by grep.
6: """
7: import re
8: import sys
9: from pathlib import Path
10:
11: import torch
12: import yaml
13:
14: # Make the repo's src/ importable
15: REPO_ROOT = Path(__file__).resolve().parents[2]
16: sys.path.insert(0, str(REPO_ROOT))
17:
18: from src.eval.correctness import check_correctness # noqa: E402
19:
20:
21: def _apply_shape(reference, shape):
22: """Override reference's module-level shape shims so get_inputs/get_init_inputs match."""
23: for k, v in shape.items():
24: setattr(reference, k, v)
25:
26:
27: def main():
28: try:
29: import reference
30: import shapes
31: import solution
32: except Exception as e:
33: print(f"FAIL: import error: {e}")
34: sys.exit(1)
35:
36: problem_yaml = Path("problem.yaml")
37: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
38:
39: # --- Forbidden-op check ------------------------------------------------
40: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
41: for forbidden in meta.get("forbidden", [])
... (2934 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/benchmark.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for KDA forward (chunk form).
2:
3: For each shape: times eager reference, compiled reference, SOTA (FLA's Triton
4: chunk_kda, if available on this GPU), and the agent's solution. Reports
5: achieved TFLOPS, GB/s, and peak_fraction.
6:
7: Output lines the harness picks up:
8: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10: """
11: import sys
12: from math import exp, log
13: from pathlib import Path
14:
15: import torch
16: import yaml
17:
18: REPO_ROOT = Path(__file__).resolve().parents[2]
19: sys.path.insert(0, str(REPO_ROOT))
20:
21: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22: from src.eval.timing import time_fn # noqa: E402
23: from src.hardware import get as get_hw # noqa: E402
24:
25:
26: def _eval_formula(expr: str, vars: dict) -> float:
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def _apply_shape(reference, shape):
31: for k, v in shape.items():
32: setattr(reference, k, v)
33:
34:
35: def main():
36: import reference
37: import shapes
38: import solution
39:
40: meta = yaml.safe_load(Path("problem.yaml").read_text())
41: hw = get_hw(meta["hardware"][0])
42: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"
... (3596 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/problem.yaml"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/problem.yaml</path>
<type>file</type>
<content>
1: name: 02_kda_cutlass
2: display_name: "Kimi Delta Attention (chunk forward) via CUTLASS CuTe"
3: precision: bf16
4: regime: compute
5:
6: # Approximate FLOPs per forward call. KDA forward (chunk form) breakdown per
7: # token across all heads:
8: # - intra-chunk Aqk (q @ k.T with decay): 2*BT*K
9: # - intra-chunk Akk + cumulative invert : ~3*BT*K + BT^2
10: # - w = A @ (g.exp()*k): 2*BT*K
11: # - u = A @ v: 2*BT*V
12: # - inter-chunk q @ S: 2*K*V
13: # - inter-chunk (Aqk @ v_i): 2*BT*V
14: # - state update S += k.T @ v_i: 2*K*V
15: # Aggregating with a constant of 4 (typical for chunked linear attention) gives
16: # the formula below. Used only for roofline reporting -- not graded directly.
17: flops_formula: "4 * B * T * H * (K*V + CHUNK_SIZE*K + CHUNK_SIZE*V)"
18:
19: # Bytes moved (lower bound): one read of q,k,v,g,beta and one write of o.
20: bytes_formula: "B*T*H*K*2 + B*T*H*K*2 + B*T*H*V*2 + B*T*H*K*4 + B*T*H*2 + B*T*H*V*2"
21:
22: hardware: [RTX_PRO_6000]
23: peak_tflops_key: bf16 # KDA's heavy work is bf16 @ bf16 -> fp32 GEMM-equivalent
24: peak_bandwidth_key: dram
25:
26: # bf16 default tolerance from src/eval/correctness.py is 1e-2 atol/rtol; KDA's
27: # long recurrence accumulates more error than a single GEMM, so we relax slightly.
28: tolerance
... (1108 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/shapes.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for KDA forward (chunk form).
2:
3: Mix of:
4: - short-context training-step scale (T=1024)
5: - mid-context (T=2048) which is the headline benchmark
6: - long-context that stresses the inter-chunk recurrence (T=4096)
7: - thin-batch decode-style (B=1, T=2048, fewer heads)
8:
9: Constraints:
10: - T % chunk_size == 0 (chunk_size = 64)
11: - K, V are the per-head channel dims; KDA in Kimi Linear uses K=V=128
12: """
13:
14: SHAPES = [
15: {"B": 2, "T": 1024, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
16: {"B": 2, "T": 2048, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
17: {"B": 1, "T": 4096, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
18: {"B": 1, "T": 2048, "H": 4, "K": 128, "V": 128, "CHUNK_SIZE": 64},
19: ]
(End of file - total 19 lines)
</content>
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/sota.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for KDA forward: fla.ops.kda.chunk_kda (Triton).
2:
3: The agent's solution is forbidden from importing this module path (see
4: problem.yaml.forbidden). This file is only used by benchmark.py to draw
5: the SOTA reference line.
6:
7: If FLA's Triton kernel does not run on SM120 (Blackwell consumer-lineage --
8: some Triton kernels in FLA target Hopper TMA), is_available() returns False
9: and benchmark.py omits the SOTA variant. The H100 reference is documented
10: in problem.yaml for context.
11: """
12: from __future__ import annotations
13:
14: import torch
15:
16:
17: def _import_fla():
18: try:
19: from fla.ops.kda import chunk_kda # noqa: F401
20: return chunk_kda
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor,
27: k: torch.Tensor,
28: v: torch.Tensor,
29: g: torch.Tensor,
30: beta: torch.Tensor,
31: scale: float | None = None,
32: ) -> torch.Tensor:
33: """Run FLA's Triton chunk_kda. Returns o (B, T, H, V) in v's dtype."""
34: chunk_kda = _import_fla()
35: if chunk_kda is None:
36: raise RuntimeError("fla.ops.kda.chunk_kda unavailable")
37: # FLA's chunk_kda has a richer signature (A_log, dt_bias, l2norm, gates, ...).
38: # We need the bare forward: pass A_log/dt_bias as None, gates off, no l2norm.
39: #
... (1352 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/PROMPT.txt"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to implement Kimi Delta Attention forward (chunk form) for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is the chunk-parallel KDA forward from the FLA library: q and k of shape (B, T, H, K) in bf16, v of shape (B, T, H, V) in bf16, g of shape (B, T, H, K) in fp32 (per-channel log-decay with in-chunk cumsum already applied), beta of shape (B, T, H) in bf16, scale a python float, chunk_size 64, no initial state, no final state. You return o of shape (B, T, H, V) in bf16. Correctness tolerance is 0.05 abs/rel — the long recurrence accumulates more error than a single GEMM so the bar's a bit looser than default bf16. The shapes you have to handle are B=2 T=1024 H=8 K=128 V=128 (short-context training step), B=2 T=2048 H=8 K=128 V=128 (the headline shape from the Kimi Linear paper), B=1 T=4096 H=8 K=128 V=128 (long context that stresses the inter-chunk recurrence), and B=1 T=2048 H=4 K=128 V=128 (thin batch decode).
4:
5: This needs to be a real custom kernel — the whole point of the problem is to w
... (976 more chars)
{
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv,noheader 2>/dev/null; python -c \"import triton; print(f'Triton {triton.__version__}')\" 2>/dev/null; python -c \"import cutlass; print('CUTLASS available')\" 2>/dev/null || echo \"no cutlass python\"; pip list 2>/dev/null | grep -i triton",
"description": "Check GPU, Triton, CUTLASS availability"
}
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
Triton 3.6.0
no cutlass python
triton 3.6.0
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/scripts/ 2>/dev/null; cat /home/infatoshi/cuda/KernelBench-Hard/scripts/run_hard.sh 2>/dev/null || true",
"description": "Check scripts directory"
}
ccr-rust.config.json
extract_usage.py
patch_torch.sh
roofline_plot.py
run_hard.sh
serve.sh
setup_problem.py
shakedown_sweep.sh
sweep.sh
topk_overnight.sh
#!/bin/bash
# Run one (harness, model, problem) combination.
#
# Usage:
# ./scripts/run_hard.sh <harness> <model> <problem_dir> [reasoning_effort]
#
# Examples:
# ./scripts/run_hard.sh claude claude-opus-4-7 problems/01_fp8_gemm
# ./scripts/run_hard.sh codex gpt-5.5 problems/01_fp8_gemm xhigh
# ./scripts/run_hard.sh kimi kimi-k2.6 problems/01_fp8_gemm
# ./scripts/run_hard.sh ccr-claude glm-5.1 problems/01_fp8_gemm
#
# Archives everything to outputs/runs/<ts>_<harness>_<model>_<problem>/.
set -euo pipefail
# Pin CUDA 13 — /usr/local/cuda may still point at 12.8.
if [ -d /usr/local/cuda-13 ]; then
export CUDA_HOME=/usr/local/cuda-13
export PATH="$CUDA_HOME/bin:$PATH"
fi
# Source API keys if the user has an env_vars file.
if [ -f "$HOME/.env_vars" ]; then
set -a
# shellcheck disable=SC1091
. "$HOME/.env_vars"
set +a
fi
HARNESS="${1:?Usage: $0 <harness> <model> <problem_dir> [reasoning_effort]}"
MODEL="${2:?model required}"
PROBLEM_DIR="${3:?problem_dir required}"
REASONING_EFFORT="${4:-}"
PROBLEM_DIR="$(cd "$PROBLEM_DIR" && pwd)"
PROBLEM_NAME="$(basename "$PROBLEM_DIR")"
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
MODEL_SLUG="$(echo "$MODEL" | tr '/:[] ' '_')"
RUN_DIR="${REPO_ROOT}/outputs/runs/${TIMESTAMP}_${HARNESS}_${MODEL_SLUG}_${PROBLEM_NAME}"
mkd
... (9931 more chars)
{
"todos": [
{
"content": "Analyze reference algorithm and design Triton kernel decomposition",
"status": "completed",
"priority": "high"
},
{
"content": "Write intra-chunk kernel (compute w, u from k, g, v, beta)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Write inter-chunk recurrence kernel (compute o from q, k, g, w, u)",
"status": "pending",
"priority": "high"
},
{
"content": "Write Model.forward coordinating PyTorch ops and kernels",
"status": "pending",
"priority": "hi
... (410 more chars)
content: 'Analyze reference algorithm and design Triton kernel decomposition'
status: 'completed'
priority: 'high'
content: 'Write intra-chunk kernel (compute w, u from k, g, v, beta)'
status: 'in_progress'
priority: 'high'
content: 'Write inter-chunk recurrence kernel (compute o from q, k, g, w, u)'
status: 'pending'
priority: 'high'
content: 'Write Model.forward coordinating PyTorch ops and kernels'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness validation'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance measurement'
status: 'pending'
priority: 'medium'
content: 'Iterate optimization based on profiling and timing'
status: 'pending'
priority: 'medium'
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -0,0 +1,385 @@
+"""KDA forward (chunk form) — Triton implementation for SM120 Blackwell.
+
+Two-kernel design:
+ 1. _intra_chunk_kernel: parallel over (B, H, NT) — computes w, u from k, g, v, beta
+ 2. _inter_chunk_kernel: parallel over (B, H) — sequential chunk recurrence → o
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+from einops import rearrange
+
+OP_TYPE = "linear_attention"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+# ---------------------------------------------------------------------------
+# Intra-chunk kernel: for each chunk compute w (BT,K) and u (BT,V)
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _intra_chunk_kernel(
+ k_ptr, g_ptr, v_ptr, beta_ptr,
+ w_ptr, u_ptr,
+ BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
+ stride_b: tl.constexpr, stride_h: tl.constexpr,
+ stride_nt: tl.constexpr, stride_bt_k: tl.constexpr,
+ stride_k: tl.constexpr,
+ stride_bt_v: tl.constexpr, stride_v: tl.constexpr,
+ BLOCK_K_K: tl.constexpr, BLOCK_K_V: tl.constexpr,
+):
+ """
+ One program per chunk. Computes:
+ A = -(k * exp(g)) @ (k * exp(-g)).T [lower-triangular part]
+ A_inv = (I - A)^(-1) - I [recursive triangular inversion]
+ final_A = (A_inv + I) * beta[None, :]
+ w = final_A @ (k * exp(g))
+ u = final_A @ v
+ """
+ pid = tl.program_id(0)
+ NT = tl.num_programs(0) // (1) # NT is inferred, but we don't have it directly
+ # We'll pass NT separately; for now compute B*H*NT as num_programs
+
+ # --- decayed keys ------------------------------------------------------
+ # Load k and g for this chunk: shape (BT, K)
+ offsets_bt = tl.arange(0, BT) # [0, BT)
+ offsets_k = tl.arange(0, BLOCK_K_K) # [0, BLOCK_K_K)
+
+ # k_exp_g = k * exp(g), stored in registers/shared: (BT, K) in fp32
+ # We compute A = (k_exp_g) @ (k_exp_neg_g).T tile by tile
+ # A shape: (BT, BT), lower-tri part needed
+
+ # We store A in a (BT, BT) fp32 scratch in smem
+ A_smem = tl.zeros((BT, BT), dtype=tl.float32)
+
+ # Compute A block by block over K
+ for k_start in range(0, K, BLOCK_K_K):
+ k_off = k_start + offsets_k
+ k_mask = k_off < K
+
+ # Load k tile: (BT, BLOCK_K_K)
+ k_tile = tl.load(
+ k_ptr + offsets_bt[:, None] * stride_bt_k + k_off[None, :] * stride_k,
+ mask=(offsets_bt[:, None] < BT) & k_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+
+ # Load g tile: (BT, BLOCK_K_K)
+ g_tile = tl.load(
+ g_ptr + offsets_bt[:, None] * stride_bt_k + k_off[None, :] * stride_k,
+ mask=(offsets_bt[:, None] < BT) & k_mask[None, :],
+ other=0.0,
+ )
+
+ # k_exp_g = k * exp(g), k_exp_neg_g = k * exp(-g)
+ exp_g = tl.math.exp(g_tile)
+ exp_neg_g = tl.math.exp(-g_tile)
+
+ ke_tile = k_tile * exp_g # (BT, BLOCK_K_K)
+ kn_tile = k_tile * exp_neg_g # (BT, BLOCK_K_K)
+
+ # Partial A contribution: A += ke_tile @ kn_tile.T
+ # Each thread handles a sub-block of A
+ pid_m = pid % 1 # just for uniformity, we handle the full BT here
+ # In practice, with 1 program per chunk and BT=64, we compute all A entries
+
+ # Use tl.dot for the outer product
+ # A[:,:] += ke_tile @ kn_tile^T
+ # ke_tile: (BT, BK), kn_tile: (BT, BK)
+ # Need: dot(ke_tile, kn_tile, trans_b=True)
+ # But tl.dot expects preloaded inputs in registers
+ # For simplicity, accumulate manually (BT=64, small enough)
+
+ # Accumulate into A_smem
+ for i in tl.static_range(0, BT):
+ ke_i = ke_tile[i, :] # (BK,) broadcast if needed
+ # kn has shape (BT, BK)
+ # A_smem[i, :] += ke_i @ kn^T -> (BK,) @ (BK, BT) via sum over BK
+ # Actually: A[i,j] += sum_d ke[i,d] * kn[j,d]
+ # So A[i,:] += (kn @ ke_i) : (BT,BK) @ (BK,) -> (BT,)
+ # We can do: partial = tl.sum(kn * ke_i[None,:], axis=1) -> (BT,)
+ # But this needs manual handling with tl.dot or loops
+ pass
+ # TODO: Implement properly with tl.dot
+
+ # Placeholder return — we'll fill in the real computation below
+ return
+
+
+# ---------------------------------------------------------------------------
+# Inter-chunk recurrence kernel
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _inter_chunk_kernel(
+ q_ptr, k_ptr, g_ptr, w_ptr, u_ptr,
+ o_ptr,
+ NT: tl.constexpr, BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
+ stride_b: tl.constexpr, stride_h: tl.constexpr,
+ stride_nt: tl.constexpr, stride_bt_k: tl.constexpr,
+ stride_k: tl.constexpr, stride_bt_v: tl.constexpr,
+ stride_v: tl.constexpr, stride_bt_o: tl.constexpr,
+):
+ """
+ One program per (B, H) pair. Sequential recurrence over NT chunks.
+ """
+ pid = tl.program_id(0)
+ # pid = b * H + h
+
+ # Initialize state S (K, V) — not stored in global, maintained in registers/smem
+ # Loop over chunks...
+ pass
+
+
+# ---------------------------------------------------------------------------
+# PyTorch helper: intra-chunk computation with tl.dot and recursive inversion
+# ---------------------------------------------------------------------------
+
+def _intra_chunk_torch(
+ k: torch.Tensor, # (B, H, NT, BT, K) bf16 after cumsum
+ g: torch.Tensor, # (B, H, NT, BT, K) fp32 cumsum applied
+ v: torch.Tensor, # (B, H, NT, BT, V) bf16
+ beta: torch.Tensor, # (B, H, NT, BT) bf16
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute w (B,H,NT,BT,K) and u (B,H,NT,BT,V) using the intra-chunk
+
... (9280 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<14783 chars \u2014 see diff>"
}
Wrote file successfully.
{
"description": "Run check.py for correctness",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python check.py 2>&1",
"timeout": 120000
}
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 113, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 74, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 356, in forward
w, u = _intra_chunk_torch(k, g, v, beta)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 199, in _intra_chunk_torch
A_inv[..., i, :i] = A_inv[..., i, :i] + increment.squeeze(-2)
~~~~~^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (1) must match the existing size (64) at non-singleton dimension 3. Target sizes: [2, 8, 16, 1]. Tensor sizes: [2, 8, 16, 64]
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -191,12 +191,14 @@
# A[..., i, :i] += (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A_inv = A_lower.clone() # (B, H, NT, BT, BT)
for i in range(1, BT):
- # A_inv[..., i, :i] += sum over middle dim of A_inv[..., i, :, None] * A_inv[..., :, :i]
- left = A_inv[..., i:i + 1, :] # (B, H, NT, 1, BT)
- right = A_inv[..., :, :i] # (B, H, NT, BT, i)
- # product: (B, H, NT, 1, BT) * (B, H, NT, BT, i) -> (B, H, NT, 1, i) after sum(-2)
- increment = (left.unsqueeze(-2) * right.unsqueeze(-3)).sum(-2) # (B, H, NT, 1, i)
- A_inv[..., i, :i] = A_inv[..., i, :i] + increment.squeeze(-2)
+ # left: A[..., i, :, None] = (B, H, NT, BT, 1) -- row i, with trailing singleton
+ left = A_inv[..., i, :, None]
+ # right: A[..., :, :i] = (B, H, NT, BT, i) -- columns 0..i-1
+ right = A_inv[..., :, :i]
+ # Multiply broadcast: (B,H,NT,BT,1) * (B,H,NT,BT,i) -> (B,H,NT,BT,i)
+ # Sum over dim -2 (BT): -> (B,H,NT,i)
+ increment = (left * right).sum(-2)
+ A_inv[..., i, :i] = A_inv[..., i, :i] + increment
# --- Step 5: final_A = (A_inv + I) * beta[:, None, :] ---
eye = torch.eye(BT, dtype=torch.float32, device=k.device)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"oldString": "<720 chars \u2014 see below>",
"newString": "<696 chars \u2014 see below>"
}
    # --- Step 4: recursive inversion ---
# Reference:
# for i in range(1, BT):
# A[..., i, :i] += (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A_inv = A_lower.clone() # (B, H, NT, BT, BT)
for i in range(1, BT):
# A_inv[..., i, :i] += sum over middle dim of A_inv[..., i, :, None] * A_inv[..., :, :i]
left = A_inv[..., i:i + 1, :] # (B, H, NT, 1, BT)
right = A_inv[..., :, :i] # (B, H, NT, BT, i)
# product: (B, H, NT, 1, BT) * (B, H, NT, BT, i) -> (B, H, NT, 1, i) after sum(-2)
increment = (left.unsqueeze(-2) * right.unsqueeze(-3)).sum(-2) # (B, H, NT, 1, i)
        A_inv[..., i, :i] = A_inv[..., i, :i] + increment.squeeze(-2)

    # --- Step 4: recursive inversion ---
# Reference:
# for i in range(1, BT):
# A[..., i, :i] += (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A_inv = A_lower.clone() # (B, H, NT, BT, BT)
for i in range(1, BT):
# left: A[..., i, :, None] = (B, H, NT, BT, 1) -- row i, with trailing singleton
left = A_inv[..., i, :, None]
# right: A[..., :, :i] = (B, H, NT, BT, i) -- columns 0..i-1
right = A_inv[..., :, :i]
# Multiply broadcast: (B,H,NT,BT,1) * (B,H,NT,BT,i) -> (B,H,NT,BT,i)
# Sum over dim -2 (BT): -> (B,H,NT,i)
increment = (left * right).sum(-2)
        A_inv[..., i, :i] = A_inv[..., i, :i] + increment
Edit applied successfully.
{
"description": "Run check.py for correctness",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python check.py 2>&1",
"timeout": 120000
}
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 113, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 74, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 361, in forward
o = _inter_chunk_recurrent(q, k, g, w, u, v, beta)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 294, in _inter_chunk_recurrent
S = S * torch.exp(g_i_last).transpose(-1, -2).unsqueeze(-2) # broadcast: (K, V)
~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (8) must match the size of tensor b (128) at non-singleton dimension 2
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,11 +1,11 @@
# --- Update S ---
- # S *= exp(g_i_last), g_i_last = g_i[:, :, -1:, :] (B, H, 1, K)
- g_i_last = g_i[:, :, -1:, :] # (B, H, 1, K)
- S = S * torch.exp(g_i_last).transpose(-1, -2).unsqueeze(-2) # broadcast: (K, V)
+ # S *= exp(g_i_last), g_i_last = g_i[:, :, -1] (B, H, K)
+ g_i_last = g_i[:, :, -1, :] # (B, H, K)
+ S = S * torch.exp(g_i_last)[:, :, :, None] # (B,H,K,V) * (B,H,K,1) -> (B,H,K,V)
# S += (exp(g_i_last - g_i) * k_i).T @ v_prime
- # decay_k = exp(g_i_last - g_i) * k_i: (B, H, BT, K)
- decay_k = torch.exp(g_i_last - g_i) * k_i # (B, H, BT, K)
+ # decay_k = exp(g_i_last[:,:,:,None] - g_i) * k_i: (B, H, BT, K)
+ decay_k = torch.exp(g_i_last[:, :, None, :] - g_i) * k_i # (B, H, BT, K)
# (B, H, K, BT) @ (B, H, BT, V) -> (B, H, K, V)
S_update = torch.einsum("bhck,bhcv->bhkv", decay_k, v_prime)
        S = S + S_update
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"oldString": "<578 chars \u2014 see below>",
"newString": "<594 chars \u2014 see below>"
}
        # --- Update S ---
# S *= exp(g_i_last), g_i_last = g_i[:, :, -1:, :] (B, H, 1, K)
g_i_last = g_i[:, :, -1:, :] # (B, H, 1, K)
S = S * torch.exp(g_i_last).transpose(-1, -2).unsqueeze(-2) # broadcast: (K, V)
# S += (exp(g_i_last - g_i) * k_i).T @ v_prime
# decay_k = exp(g_i_last - g_i) * k_i: (B, H, BT, K)
decay_k = torch.exp(g_i_last - g_i) * k_i # (B, H, BT, K)
# (B, H, K, BT) @ (B, H, BT, V) -> (B, H, K, V)
S_update = torch.einsum("bhck,bhcv->bhkv", decay_k, v_prime)
        S = S + S_update

        # --- Update S ---
# S *= exp(g_i_last), g_i_last = g_i[:, :, -1] (B, H, K)
g_i_last = g_i[:, :, -1, :] # (B, H, K)
S = S * torch.exp(g_i_last)[:, :, :, None] # (B,H,K,V) * (B,H,K,1) -> (B,H,K,V)
# S += (exp(g_i_last - g_i) * k_i).T @ v_prime
# decay_k = exp(g_i_last[:,:,:,None] - g_i) * k_i: (B, H, BT, K)
decay_k = torch.exp(g_i_last[:, :, None, :] - g_i) * k_i # (B, H, BT, K)
# (B, H, K, BT) @ (B, H, BT, V) -> (B, H, K, V)
S_update = torch.einsum("bhck,bhcv->bhkv", decay_k, v_prime)
        S = S + S_update
Edit applied successfully.
{
"description": "Run check.py again",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python check.py 2>&1",
"timeout": 120000
}
PASS
{
"description": "Benchmark current PyTorch implementation",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python benchmark.py 2>&1",
"timeout": 300000
}
shape=0 variant=eager tflops=0.061 gbps=0.714 ms=35.293
shape=0 variant=compiled tflops=0.509 gbps=5.972 ms=4.219
shape=0 variant=sota tflops=8.542 gbps=100.236 ms=0.251
shape=0 variant=solution tflops=0.755 gbps=8.856 ms=2.845
shape=0 solution_peak_fraction=0.0038
shape=1 variant=eager tflops=0.064 gbps=0.754 ms=66.824
shape=1 variant=compiled tflops=0.493 gbps=5.781 ms=8.718
shape=1 variant=sota tflops=15.674 gbps=183.921 ms=0.274
shape=1 variant=solution tflops=0.977 gbps=11.467 ms=4.395
shape=1 solution_peak_fraction=0.0049
shape=2 variant=eager tflops=0.041 gbps=0.479 ms=105.313
shape=2 variant=compiled tflops=0.267 gbps=3.130 ms=16.103
shape=2 variant=sota tflops=13.981 gbps=164.053 ms=0.307
shape=2 variant=solution tflops=0.616 gbps=7.229 ms=6.972
shape=2 solution_peak_fraction=0.0031
shape=3 variant=eager tflops=0.020 gbps=0.229 ms=54.919
shape=3 variant=compiled tflops=0.212 gbps=2.482 ms=5.075
shape=3 variant=sota tflops=4.089 gbps=47.975 ms=0.263
shape=3 variant=solution tflops=0.259 gbps=3.034 ms=4.153
shape=3 solution_peak_fraction=0.0013
peak_fraction: 0.0029
RESULT: LOW
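The reported peak_fraction is the geometric mean of the per-shape fractions, per benchmark.py's docstring: (0.0038 * 0.0049 * 0.0031 * 0.0013)^(1/4) ≈ 0.0029.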
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,8 +1,8 @@
"""KDA forward (chunk form) — Triton implementation for SM120 Blackwell.
Two-kernel design:
- 1. _intra_chunk_kernel: parallel over (B, H, NT) — computes w, u from k, g, v, beta
- 2. _inter_chunk_kernel: parallel over (B, H) — sequential chunk recurrence → o
+ 1. intra_chunk: parallel over (B, H, NT) — computes w, u from k, g, v, beta
+ 2. inter_chunk_recurrent: parallel over (B, H) — sequential chunk recurrence → o
"""
from __future__ import annotations
@@ -19,7 +19,7 @@
# ---------------------------------------------------------------------------
-# Intra-chunk kernel: for each chunk compute w (BT,K) and u (BT,V)
+# Intra-chunk Triton kernel
# ---------------------------------------------------------------------------
@triton.jit
@@ -27,280 +27,466 @@
k_ptr, g_ptr, v_ptr, beta_ptr,
w_ptr, u_ptr,
BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
- stride_b: tl.constexpr, stride_h: tl.constexpr,
- stride_nt: tl.constexpr, stride_bt_k: tl.constexpr,
- stride_k: tl.constexpr,
- stride_bt_v: tl.constexpr, stride_v: tl.constexpr,
- BLOCK_K_K: tl.constexpr, BLOCK_K_V: tl.constexpr,
+ stride_bt_k: tl.constexpr, stride_k_k: tl.constexpr,
+ stride_bt_k2: tl.constexpr, stride_k_g: tl.constexpr,
+ stride_bt_v: tl.constexpr, stride_k_v: tl.constexpr,
+ stride_bt_w: tl.constexpr, stride_k_w: tl.constexpr,
+ stride_bt_u: tl.constexpr, stride_k_u: tl.constexpr,
+ BLOCK_KK: tl.constexpr, BLOCK_KV: tl.constexpr,
):
"""
- One program per chunk. Computes:
- A = -(k * exp(g)) @ (k * exp(-g)).T [lower-triangular part]
- A_inv = (I - A)^(-1) - I [recursive triangular inversion]
- final_A = (A_inv + I) * beta[None, :]
- w = final_A @ (k * exp(g))
- u = final_A @ v
+ One program per chunk (BT=64 rows, K=128/V=128 cols).
+ Computes:
+ k_exp_g = k * exp(g) (BT, K)
+ k_exp_neg_g = k * exp(-g) (BT, K)
+ A = k_exp_g @ k_exp_neg_g.T (BT, BT) -- lower tri only
+ A_inv via recursive triangular formula
+ final_A = (A_inv + I) * beta[None,:]
+ w = final_A @ k_exp_g (BT, K)
+ u = final_A @ v (BT, V)
+
+ All arithmetic in fp32.
"""
pid = tl.program_id(0)
- NT = tl.num_programs(0) // (1) # NT is inferred, but we don't have it directly
- # We'll pass NT separately; for now compute B*H*NT as num_programs
-
- # --- decayed keys ------------------------------------------------------
- # Load k and g for this chunk: shape (BT, K)
- offsets_bt = tl.arange(0, BT) # [0, BT)
- offsets_k = tl.arange(0, BLOCK_K_K) # [0, BLOCK_K_K)
-
- # k_exp_g = k * exp(g), stored in registers/shared: (BT, K) in fp32
- # We compute A = (k_exp_g) @ (k_exp_neg_g).T tile by tile
- # A shape: (BT, BT), lower-tri part needed
-
- # We store A in a (BT, BT) fp32 scratch in smem
- A_smem = tl.zeros((BT, BT), dtype=tl.float32)
-
- # Compute A block by block over K
- for k_start in range(0, K, BLOCK_K_K):
- k_off = k_start + offsets_k
+
+ # --- Helper: load (BT, K) tile of k into fp32 registers ----------------
+ off_bt = tl.arange(0, BT)
+
+ # Load beta: (BT,) -> fp32
+ beta_vals = tl.load(
+ beta_ptr + pid * stride_bt_w + off_bt, # stride_bt_w reused for beta stride
+ mask=off_bt < BT, other=0.0,
+ ).to(tl.float32)
+
+ # --- Allocate A (BT, BT) in registers via accumulator loop ------------
+ # We need full A for the recursive inversion, so we compute it tile-by-tile
+ # over K and accumulate into a full (BT, BT) scratch.
+
+ # Since Triton doesn't easily give us a 2D register array,
+ # we compute A in tiled fashion: for each output row i, accumulate over K.
+ # But we need A[i,j] for all i,j to do the recursive inversion.
+ # Strategy: compute A into (BT, BT) smem, then do inversion in smem.
+
+ # For now, use a sequential approach per chunk: compute A (BT,BT) fully,
+ # do recursive inversion, then w,u. All in one program.
+
+ # Better: store A in a local buffer by accumulating tile-by-tile.
+ # Let's use tl.zeros for a (BT, BT) accumulator then accum over K.
+
+ # However, each thread only sees a subset. We use a single "program"
+ # per chunk and split the work manually.
+
+ # Load k and g fully into registers? BT*K*2 = 64*128*2 = 16384 floats.
+ # That's 64KB for k in fp32 + 64KB for g in fp32.
+ # Too much for registers, but Triton handles spilling automatically.
+
+ # Simpler: process the full BTxBT output in a single program using
+ # tiling over K. Each output element A[i,j] needs all K values.
+
+ # Approach: use tl.dot for the A = ke @ kn^T matmul.
+ # ke: (BT, K) in blocks of (BT_BLK, KK_BLK)
+ # kn: (BT, K) in blocks of (BT_BLK, KK_BLK)
+ # A: (BT, BT) in blocks of (BT_BLK, BT_BLK)
+
+ # For simplicity, with BT=64, K=128, we use one program with
+ # a single tl.dot call that computes the full (BT, BT) output.
+
+ BT_PAD = triton.next_power_of_2(BT) # 64 is already pow2
+
+ # Accumulate A_ij for this chunk
+ A_acc = tl.zeros((BT, BT), dtype=tl.float32)
+
+ off_k = tl.arange(0, BLOCK_KK)
+
+ for k_start in range(0, K, BLOCK_KK):
+ k_off = k_start + off_k
k_mask = k_off < K
- # Load k tile: (BT, BLOCK_K_K)
+ # Load k tile: (BT, BLOCK_KK)
k_tile = tl.load(
- k_ptr + offsets_bt[:, None] * stride_bt_k + k_off[None, :] * stride_k,
- mask=(offsets_bt[:, None] < BT) & k_mask[None, :],
+ k_ptr + pid * stride_bt_k * BT + off_bt[:, None] * stride_bt_k + k_off[None, :] * stride_k_k,
+ mask=(off_bt[:, None] < BT) & k_mask[None, :],
other=0.0,
).to(tl.float32)
- # Load g til
... (25426 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<21165 chars \u2014 see diff>"
}
Wrote file successfully.
{
"description": "Run check.py to verify correctness",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python check.py 2>&1",
"timeout": 120000
}
PASS
{
"description": "Benchmark with torch.compile recursive inversion",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && uv run python benchmark.py 2>&1",
"timeout": 300000
}
shape=0 variant=eager tflops=0.062 gbps=0.724 ms=34.781
shape=0 variant=compiled tflops=0.510 gbps=5.981 ms=4.213
shape=0 variant=sota tflops=8.606 gbps=100.982 ms=0.250
shape=0 variant=solution tflops=1.168 gbps=13.708 ms=1.838
shape=0 solution_peak_fraction=0.0058
shape=1 variant=eager tflops=0.064 gbps=0.748 ms=67.365
shape=1 variant=compiled tflops=0.491 gbps=5.762 ms=8.746
shape=1 variant=sota tflops=15.491 gbps=181.777 ms=0.277
shape=1 variant=solution tflops=1.252 gbps=14.692 ms=3.430
shape=1 solution_peak_fraction=0.0063
shape=2 variant=eager tflops=0.041 gbps=0.477 ms=105.609
shape=2 variant=compiled tflops=0.267 gbps=3.128 ms=16.114
shape=2 variant=sota tflops=13.904 gbps=163.144 ms=0.309
shape=2 variant=solution tflops=0.721 gbps=8.457 ms=5.959
shape=2 solution_peak_fraction=0.0036
shape=3 variant=eager tflops=0.020 gbps=0.232 ms=54.316
shape=3 variant=compiled tflops=0.212 gbps=2.485 ms=5.070
shape=3 variant=sota tflops=4.079 gbps=47.867 ms=0.263
shape=3 variant=solution tflops=0.343 gbps=4.030 ms=3.127
shape=3 solution_peak_fraction=0.0017
peak_fraction: 0.0039
RESULT: LOW
<bash_metadata>
bash tool terminated command after exceeding timeout 300000 ms. If this command is expected to take longer and is not waiting for interactive input, retry with a larger timeout value in milliseconds.
</bash_metadata>
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,8 +1,9 @@
"""KDA forward (chunk form) — Triton implementation for SM120 Blackwell.
-Two-kernel design:
- 1. intra_chunk: parallel over (B, H, NT) — computes w, u from k, g, v, beta
- 2. inter_chunk_recurrent: parallel over (B, H) — sequential chunk recurrence → o
+Architecture:
+ 1. Intra-chunk: parallel over (B, H, NT) — each program computes w, u for one chunk
+ 2. Inter-chunk: parallel over (B, H), chunk loop per program — each program
+ handles one (batch, head) sequence, looping over NT chunks sequentially.
"""
from __future__ import annotations
@@ -19,430 +20,571 @@
# ---------------------------------------------------------------------------
-# Intra-chunk Triton kernel
+# Intra-chunk kernel: one program per chunk → w (BT,K), u (BT,V)
# ---------------------------------------------------------------------------
@triton.jit
def _intra_chunk_kernel(
k_ptr, g_ptr, v_ptr, beta_ptr,
- w_ptr, u_ptr,
+ w_out_ptr, u_out_ptr,
BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
- stride_bt_k: tl.constexpr, stride_k_k: tl.constexpr,
- stride_bt_k2: tl.constexpr, stride_k_g: tl.constexpr,
- stride_bt_v: tl.constexpr, stride_k_v: tl.constexpr,
- stride_bt_w: tl.constexpr, stride_k_w: tl.constexpr,
- stride_bt_u: tl.constexpr, stride_k_u: tl.constexpr,
- BLOCK_KK: tl.constexpr, BLOCK_KV: tl.constexpr,
+ stride_k_bt: tl.constexpr, stride_k_d: tl.constexpr,
+ stride_g_bt: tl.constexpr, stride_g_d: tl.constexpr,
+ stride_v_bt: tl.constexpr, stride_v_d: tl.constexpr,
+ stride_beta_bt: tl.constexpr,
+ stride_w_bt: tl.constexpr, stride_w_d: tl.constexpr,
+ stride_u_bt: tl.constexpr, stride_u_d: tl.constexpr,
+ BLOCK_KK: tl.constexpr,
):
- """
- One program per chunk (BT=64 rows, K=128/V=128 cols).
- Computes:
- k_exp_g = k * exp(g) (BT, K)
- k_exp_neg_g = k * exp(-g) (BT, K)
- A = k_exp_g @ k_exp_neg_g.T (BT, BT) -- lower tri only
- A_inv via recursive triangular formula
- final_A = (A_inv + I) * beta[None,:]
- w = final_A @ k_exp_g (BT, K)
- u = final_A @ v (BT, V)
-
- All arithmetic in fp32.
- """
- pid = tl.program_id(0)
-
- # --- Helper: load (BT, K) tile of k into fp32 registers ----------------
+ pid = tl.program_id(0) # chunk index ∈ [0, B*H*NT)
+
off_bt = tl.arange(0, BT)
-
- # Load beta: (BT,) -> fp32
- beta_vals = tl.load(
- beta_ptr + pid * stride_bt_w + off_bt, # stride_bt_w reused for beta stride
- mask=off_bt < BT, other=0.0,
- ).to(tl.float32)
-
- # --- Allocate A (BT, BT) in registers via accumulator loop ------------
- # We need full A for the recursive inversion, so we compute it tile-by-tile
- # over K and accumulate into a full (BT, BT) scratch.
-
- # Since Triton doesn't easily give us a 2D register array,
- # we compute A in tiled fashion: for each output row i, accumulate over K.
- # But we need A[i,j] for all i,j to do the recursive inversion.
- # Strategy: compute A into (BT, BT) smem, then do inversion in smem.
-
- # For now, use a sequential approach per chunk: compute A (BT,BT) fully,
- # do recursive inversion, then w,u. All in one program.
-
- # Better: store A in a local buffer by accumulating tile-by-tile.
- # Let's use tl.zeros for a (BT, BT) accumulator then accum over K.
-
- # However, each thread only sees a subset. We use a single "program"
- # per chunk and split the work manually.
-
- # Load k and g fully into registers? BT*K*2 = 64*128*2 = 16384 floats.
- # That's 64KB for k in fp32 + 64KB for g in fp32.
- # Too much for registers, but Triton handles spilling automatically.
-
- # Simpler: process the full BTxBT output in a single program using
- # tiling over K. Each output element A[i,j] needs all K values.
-
- # Approach: use tl.dot for the A = ke @ kn^T matmul.
- # ke: (BT, K) in blocks of (BT_BLK, KK_BLK)
- # kn: (BT, K) in blocks of (BT_BLK, KK_BLK)
- # A: (BT, BT) in blocks of (BT_BLK, BT_BLK)
-
- # For simplicity, with BT=64, K=128, we use one program with
- # a single tl.dot call that computes the full (BT, BT) output.
-
- BT_PAD = triton.next_power_of_2(BT) # 64 is already pow2
-
- # Accumulate A_ij for this chunk
- A_acc = tl.zeros((BT, BT), dtype=tl.float32)
-
- off_k = tl.arange(0, BLOCK_KK)
+ off_kk = tl.arange(0, BLOCK_KK)
+
+ # Load beta: (BT,) fp32
+ beta = tl.load(beta_ptr + pid * stride_beta_bt * BT + off_bt,
+ mask=off_bt < BT, other=1.0).to(tl.float32)
+
+ # A (BT, BT) accumulator in fp32
+ A = tl.zeros((BT, BT), dtype=tl.float32)
+
+ # k_exp_g and k_exp_neg_g computed on the fly while accumulating A
+ # We also accumulate k_exp_g * beta into a weighted version for later
+ # w computation. But since we need the full A for recursive inversion,
+ # we compute A first, then do inversion, then w,u.
+
+ ke_beta_accum = tl.zeros((BT, K), dtype=tl.float32) # beta[j] * k_exp_g[j]
+ v_beta_accum = tl.zeros((BT, V), dtype=tl.float32) # beta[j] * v[j]
for k_start in range(0, K, BLOCK_KK):
- k_off = k_start + off_k
+ k_off = k_start + off_kk
k_mask = k_off < K
- # Load k tile: (BT, BLOCK_KK)
k_tile = tl.load(
- k_ptr + pid * stride_bt_k * BT + off_bt[:, None] * stride_bt_k + k_off[None, :] * stride_k_k,
+ k_ptr + pid * stride_k_bt * BT + off_bt[:, None] * stride_k_bt + k_off[None, :] * stride_k_d,
mask=(off_bt[:, None] < BT) & k_mask[None, :],
other=0.0,
).to(tl.float32)
-
- # Load g tile: (BT, BLOCK_KK) in fp32
g_tile = tl.load(
- g_ptr + pid * stride_bt_k2 * BT + off_bt[:, None] * stride_bt_k2
... (36207 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<26939 chars \u2014 see diff>"
}
Wrote file successfully.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,9 +1,7 @@
-"""KDA forward (chunk form) — Triton implementation for SM120 Blackwell.
-
-Architecture:
- 1. Intra-chunk: parallel over (B, H, NT) — each program computes w, u for one chunk
- 2. Inter-chunk: parallel over (B, H), chunk loop per program — each program
- handles one (batch, head) sequence, looping over NT chunks sequentially.
+"""KDA forward (chunk form) — Triton/CUBLAS hybrid implementation for SM120.
+
+Uses PyTorch bmm (cuBLAS) for batch matmuls (intra-chunk A, w, u) +
+Triton fused kernel for the inter-chunk recurrence.
"""
from __future__ import annotations
@@ -20,526 +18,14 @@
# ---------------------------------------------------------------------------
-# Intra-chunk kernel: one program per chunk → w (BT,K), u (BT,V)
-# ---------------------------------------------------------------------------
-
-@triton.jit
-def _intra_chunk_kernel(
- k_ptr, g_ptr, v_ptr, beta_ptr,
- w_out_ptr, u_out_ptr,
- BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
- stride_k_bt: tl.constexpr, stride_k_d: tl.constexpr,
- stride_g_bt: tl.constexpr, stride_g_d: tl.constexpr,
- stride_v_bt: tl.constexpr, stride_v_d: tl.constexpr,
- stride_beta_bt: tl.constexpr,
- stride_w_bt: tl.constexpr, stride_w_d: tl.constexpr,
- stride_u_bt: tl.constexpr, stride_u_d: tl.constexpr,
- BLOCK_KK: tl.constexpr,
-):
- pid = tl.program_id(0) # chunk index ∈ [0, B*H*NT)
-
- off_bt = tl.arange(0, BT)
- off_kk = tl.arange(0, BLOCK_KK)
-
- # Load beta: (BT,) fp32
- beta = tl.load(beta_ptr + pid * stride_beta_bt * BT + off_bt,
- mask=off_bt < BT, other=1.0).to(tl.float32)
-
- # A (BT, BT) accumulator in fp32
- A = tl.zeros((BT, BT), dtype=tl.float32)
-
- # k_exp_g and k_exp_neg_g computed on the fly while accumulating A
- # We also accumulate k_exp_g * beta into a weighted version for later
- # w computation. But since we need the full A for recursive inversion,
- # we compute A first, then do inversion, then w,u.
-
- ke_beta_accum = tl.zeros((BT, K), dtype=tl.float32) # beta[j] * k_exp_g[j]
- v_beta_accum = tl.zeros((BT, V), dtype=tl.float32) # beta[j] * v[j]
-
- for k_start in range(0, K, BLOCK_KK):
- k_off = k_start + off_kk
- k_mask = k_off < K
-
- k_tile = tl.load(
- k_ptr + pid * stride_k_bt * BT + off_bt[:, None] * stride_k_bt + k_off[None, :] * stride_k_d,
- mask=(off_bt[:, None] < BT) & k_mask[None, :],
- other=0.0,
- ).to(tl.float32)
- g_tile = tl.load(
- g_ptr + pid * stride_g_bt * BT + off_bt[:, None] * stride_g_bt + k_off[None, :] * stride_g_d,
- mask=(off_bt[:, None] < BT) & k_mask[None, :],
- other=0.0,
- )
-
- exp_g = tl.math.exp(g_tile)
- exp_neg_g = tl.math.exp(-g_tile)
-
- ke_tile = k_tile * exp_g # (BT, BLOCK_KK)
- kn_tile = k_tile * exp_neg_g # (BT, BLOCK_KK)
-
- # Accumulate A += ke_tile @ kn_tile^T
- A += tl.dot(ke_tile, tl.trans(kn_tile))
-
- # Accumulate beta-weighted k_exp_g and beta-weighted v
- # ke_beta[j, :] += beta[j] * ke[j, :]
- for j in tl.static_range(0, BT):
- beta_j = beta[j]
- for kk in tl.static_range(0, BLOCK_KK):
- if k_off[kk] < K:
- ke_beta_accum[j, k_off[kk]] += beta_j * ke_tile[j, kk]
-
- # Load v and accumulate beta-weighted v (v_beta_accum)
- off_vv = tl.arange(0, BLOCK_KK)
- for v_start in range(0, V, BLOCK_KK):
- v_off = v_start + off_vv
- v_mask = v_off < V
- v_tile = tl.load(
- v_ptr + pid * stride_v_bt * BT + off_bt[:, None] * stride_v_bt + v_off[None, :] * stride_v_d,
- mask=(off_bt[:, None] < BT) & v_mask[None, :],
- other=0.0,
- ).to(tl.float32)
- for j in tl.static_range(0, BT):
- beta_j = beta[j]
- for vv in tl.static_range(0, BLOCK_KK):
- if v_off[vv] < V:
- v_beta_accum[j, v_off[vv]] += beta_j * v_tile[j, vv]
-
- # --- Mask A: keep only strict lower triangle (j > i), negate ---
- BT_PAD = triton.next_power_of_2(BT)
- A_pad = tl.zeros((BT_PAD, BT_PAD), dtype=tl.float32)
- for i in tl.static_range(0, BT):
- for j in tl.static_range(0, BT):
- A_pad[i, j] = A[i, j]
-
- # Zero upper + diagonal, negate lower
- for i in tl.static_range(0, BT):
- for j in tl.static_range(0, BT):
- if j > i:
- A_pad[i, j] = -A_pad[i, j]
- else:
- A_pad[i, j] = 0.0
-
- # --- Recursive inversion ---
- for i in tl.static_range(1, BT):
- for j in tl.static_range(0, i):
- acc = 0.0
- for k in tl.static_range(j + 1, i):
- acc += A_pad[i, k] * A_pad[k, j]
- A_pad[i, j] += acc
-
- # --- Compute w = (A_inv + I) @ (beta * k_exp_g) ---
- # w[i,:] = beta[i] * k_exp_g[i,:] + sum_{j > i} A_inv[i,j] * ke_beta_accum[j,:]
- # Note: A_inv[i,j] has entries for j > i (strictly lower)
- for i in tl.static_range(0, BT):
- # Diagonal: w[i,:] = ke_beta_accum[i,:]
- for kk in tl.static_range(0, K):
- tl.store(
- w_out_ptr + pid * stride_w_bt * BT + i * stride_w_bt + kk * stride_w_d,
- ke_beta_accum[i, kk],
- )
- # Add off-diagonal contributions
- for j in tl.static_range(i + 1, BT):
- a_ij = A_pad[i, j]
- for kk in tl.static_range(0, K):
- w_cur = tl.load(
- w_out_ptr + pid * stride_w_bt * BT + i * stride_w_bt + kk * stride_w_d
- )
- tl.store(
- w_out_ptr + pid * stride_w_bt * BT + i * stride_w_bt + kk * stride_w_d,
-
... (34704 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<16794 chars \u2014 see diff>"
}
Wrote file successfully.