"""KDA forward solution with custom CUDA and Triton kernels."""
from __future__ import annotations
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# CUDA forward substitution kernel (loaded inline)
# ---------------------------------------------------------------------------
_CUDA_FWD_SUBST_SRC = """
__global__ void fwd_subst_kernel(float* __restrict__ A, int C, int batch) {
int b = blockIdx.x;
if (b >= batch) return;
int tid = threadIdx.x;
__shared__ float s_A[64][64];
__shared__ float s_row[64];
float* mat = A + b * C * C;
for (int idx = tid; idx < C * C; idx += blockDim.x) {
s_A[idx / C][idx % C] = mat[idx];
}
__syncthreads();
for (int i = 1; i < C; i++) {
if (tid < i) s_row[tid] = s_A[i][tid];
__syncthreads();
if (tid < i) {
float sum = 0.0f;
for (int k = 0; k < i; k++) {
sum += s_row[k] * s_A[k][tid];
}
s_A[i][tid] += sum;
}
__syncthreads();
}
for (int idx = tid; idx < C * C; idx += blockDim.x) {
mat[idx] = s_A[idx / C][idx % C];
}
}
torch::Tensor fwd_subst(torch::Tensor A) {
int batch = A.size(0);
int C = A.size(1);
fwd_subst_kernel<<<batch, 64>>>(A.data_ptr<float>(), C, batch);
return A;
}
"""
_fwd_subst_mod = load_inline(
name="kda_fwd_subst_v2",
cpp_sources="torch::Tensor fwd_subst(torch::Tensor A);",
cuda_sources=_CUDA_FWD_SUBST_SRC,
functions=["fwd_subst"],
extra_cuda_cflags=["-O3"],
verbose=False,
)
# ---------------------------------------------------------------------------
# Triton kernel: inter-chunk recurrence + output
# ---------------------------------------------------------------------------
@triton.jit
def kda_inter_chunk_kernel(
q_ptr,
k_ptr,
u_ptr,
w_ptr,
g_ptr,
Aqk_ptr,
o_ptr,
B: tl.constexpr,
T: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
scale: tl.constexpr,
):
i_v = tl.program_id(0)
i_bh = tl.program_id(1)
i_b = i_bh // H
i_h = i_bh % H
NT = T // BT
b_S1 = tl.zeros((64, BV), dtype=tl.float32)
b_S2 = tl.zeros((64, BV), dtype=tl.float32)
q_base = q_ptr + (i_b * T * H + i_h) * K
k_base = k_ptr + (i_b * T * H + i_h) * K
u_base = u_ptr + (i_b * T * H + i_h) * V
w_base = w_ptr + (i_b * T * H + i_h) * K
g_base = g_ptr + (i_b * T * H + i_h) * K
Aqk_base = Aqk_ptr + (i_b * T * H + i_h) * BT
o_base = o_ptr + (i_b * T * H + i_h) * V
for i_t in range(NT):
t0 = i_t * BT
p_u = tl.make_block_ptr(
u_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32)
p_w = tl.make_block_ptr(
w_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v -= tl.dot(b_w, b_S1.to(b_w.dtype))
p_w = tl.make_block_ptr(
w_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v -= tl.dot(b_w, b_S2.to(b_w.dtype))
b_o = tl.zeros((BT, BV), dtype=tl.float32)
p_q = tl.make_block_ptr(
q_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
p_g = tl.make_block_ptr(
g_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
b_o += tl.dot(b_qg, b_S1.to(b_qg.dtype))
p_q = tl.make_block_ptr(
q_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
p_g = tl.make_block_ptr(
g_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
b_o += tl.dot(b_qg, b_S2.to(b_qg.dtype))
b_o *= scale
p_Aqk = tl.make_block_ptr(
Aqk_base, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0)
)
b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_Aqk = tl.where(m_s, b_Aqk, 0.0).to(b_v.dtype)
b_o += tl.dot(b_Aqk, b_v.to(b_Aqk.dtype))
p_o = tl.make_block_ptr(
o_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
)
tl.store(p_o, b_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
last_idx = t0 + BT - 1
p_glast1 = g_base + last_idx * H * K + tl.arange(0, 64)
p_glast2 = g_base + last_idx * H * K + 64 + tl.arange(0, 64)
b_glast1 = tl.load(p_glast1, mask=tl.arange(0, 64) < 64, other=0.0).to(tl.float32)
b_glast2 = tl.load(p_glast2, mask=tl.arange(0, 64) < 64, other=0.0).to(tl.float32)
b_decay1 = tl.exp(b_glast1)
b_decay2 = tl.exp(b_glast2)
b_S1 *= b_decay1[:, None]
b_S2 *= b_decay2[:, None]
p_k = tl.make_block_ptr(
k_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
p_g = tl.make_block_ptr(
g_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_kg = b_k * tl.exp(b_glast1[None, :] - b_g)
b_S1 += tl.dot(b_kg.to(b_v.dtype).T, b_v.to(b_k.dtype))
p_k = tl.make_block_ptr(
k_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
p_g = tl.make_block_ptr(
g_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_kg = b_k * tl.exp(b_glast2[None, :] - b_g)
b_S2 += tl.dot(b_kg.to(b_v.dtype).T, b_v.to(b_k.dtype))
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Model(nn.Module):
def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64):
super().__init__()
self.B, self.T, self.H, self.K, self.V = B, T, H, K, V
self.chunk_size = chunk_size
self.scale = float(K) ** -0.5
self.register_buffer("_dummy", torch.zeros(1), persistent=False)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
dtype = v.dtype
B, T, H, K = q.shape
V = v.shape[-1]
BT = self.chunk_size
assert T % BT == 0
NT = T // BT
qf, kf, vf, gf, bf = (x.to(torch.float32) for x in (q, k, v, g, beta))
qf = qf * self.scale
# Chunked layout contiguous
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
# ---- Intra-chunk A_kk ----
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B * H * NT
kg_flat = kg.reshape(BHN, BT, K)
kgn_flat = kgn.reshape(BHN, BT, K)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1))
A = A.view(B, H, NT, BT, BT)
mask_diag_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_diag_upper, 0)
A_flat = A.view(BHN, BT, BT)
A_flat = _fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B, H, NT, BT, BT)
I = torch.eye(BT, dtype=torch.float32, device=q.device)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1))
Aqk = Aqk.view(B, H, NT, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
# ---- Inter-chunk recurrence via Triton kernel ----
# Use original (B, T, H, K/V) layout for the kernel to avoid extra copies
# qf_flat is already standard layout after permute+reshape
qf_flat = qf.permute(0, 2, 3, 4, 1).reshape(B, T, H, K)
kf_flat = kf.permute(0, 2, 3, 4, 1).reshape(B, T, H, K)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B, T, H, V)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B, T, H, K)
gf_flat = gf.permute(0, 2, 3, 4, 1).reshape(B, T, H, K)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B, T, H, BT)
o = torch.empty_like(v)
BV = 32
grid = (triton.cdiv(V, BV), B * H)
kda_inter_chunk_kernel[grid](
qf_flat,
kf_flat,
u_flat,
w_flat,
gf_flat,
Aqk_flat,
o,
B=B,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
scale=self.scale,
num_warps=4,
num_stages=1,
)
return o.to(dtype)
# Module-level shape shims
B = 2
T = 1024
H = 8
K = 128
V = 128
CHUNK_SIZE = 64
def get_inputs():
torch.manual_seed(0)
q = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1
k = torch.randn(B, T, H, K, dtype=torch.bfloat16) * 0.1
v = torch.randn(B, T, H, V, dtype=torch.bfloat16) * 0.1
g = (torch.randn(B, T, H, K, dtype=torch.float32) * 0.1 - 0.05)
beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.bfloat16))
return [q, k, v, g, beta]
def get_init_inputs():
return [B, T, H, K, V, CHUNK_SIZE]
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
shape=0 variant=eager tflops=0.050 gbps=0.583 ms=43.216
shape=0 variant=compiled tflops=0.506 gbps=5.942 ms=4.241
shape=0 variant=sota tflops=7.629 gbps=89.514 ms=0.282
shape=0 variant=solution tflops=5.971 gbps=70.068 ms=0.360
shape=0 solution_peak_fraction=0.0299
shape=1 variant=eager tflops=0.052 gbps=0.612 ms=82.394
shape=1 variant=compiled tflops=0.487 gbps=5.718 ms=8.814
shape=1 variant=sota tflops=13.533 gbps=158.801 ms=0.317
shape=1 variant=solution tflops=6.097 gbps=71.540 ms=0.704
shape=1 solution_peak_fraction=0.0305
shape=2 variant=eager tflops=0.033 gbps=0.391 ms=128.748
shape=2 variant=compiled tflops=0.262 gbps=3.079 ms=16.367
shape=2 variant=sota tflops=12.732 gbps=149.401 ms=0.337
shape=2 variant=solution tflops=4.416 gbps=51.819 ms=0.973
shape=2 solution_peak_fraction=0.0221
shape=3 variant=eager tflops=0.016 gbps=0.192 ms=65.730
shape=3 variant=compiled tflops=0.209 gbps=2.457 ms=5.129
shape=3 variant=sota tflops=3.758 gbps=44.098 ms=0.286
shape=3 variant=solution tflops=2.250 gbps=26.400 ms=0.477
shape=3 solution_peak_fraction=0.0112
peak_fraction: 0.0218
RESULT: LOW
PASS
{
"problem": "02_kda_cutlass",
"harness": "kimi",
"model": "kimi-k2.6",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0218,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": null, "output_tokens": null, "cache_read_tokens": null, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/reference.py"
}{
"raw": "{\"path\": \"/home/infatos"
}<system>143 lines read from file starting from line 1. Total lines in file: 143. End of file reached.</system>
1 """Naive PyTorch reference for Kimi Delta Attention (KDA) forward, chunk form.
2
3 This is the correctness oracle, NOT the SOTA baseline. It mirrors the
4 chunk-parallel formulation in fla/ops/kda/naive.py (Songlin Yang et al.)
5 without any Triton or CUDA optimization.
6
7 Inputs (per the FLA convention):
8 q, k : (B, T, H, K) bf16 -- queries / keys
9 v : (B, T, H, V) bf16 -- values
10 g : (B, T, H, K) fp32 -- per-channel log-decay (in-chunk cumsum applied)
11 beta : (B, T, H) bf16 -- write strength
12
13 Output:
14 o : (B, T, H, V) bf16
15
16 The agent must reproduce this output (within bf16 tolerance) using a CUTLASS
17 CuTe kernel on SM120 -- NOT by calling fla.ops.chunk_kda directly.
18 """
19 from __future__ import annotations
20
21 import torch
22 import torch.nn as nn
23 from einops import rearrange
24
25 OP_TYPE = "linear_attention"
26 SUPPORTED_PRECISIONS = ["bf16"]
27 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
28
29
30 def _naive_chunk_kda(
31 q: torch.Tensor,
32 k: torch.Tensor,
33 v: torch.Tensor,
34 g: torch.Tensor,
35 beta: torch.Tensor,
36 scale: float,
37 chunk_size: int = 64,
38 ) -> torch.Tensor:
39
... (4705 more chars){
"raw": "{\"path\": \"/home/infatos"
}<system>113 lines read from file starting from line 1. Total lines in file: 113. End of file reached.</system>
1 """Correctness runner for KDA forward (chunk form).
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol (bf16 default 1e-2 plus a 5e-2 override for
5 this problem). Also rejects forbidden ops by grep.
6 """
7 import re
8 import sys
9 from pathlib import Path
10
11 import torch
12 import yaml
13
14 # Make the repo's src/ importable
15 REPO_ROOT = Path(__file__).resolve().parents[2]
16 sys.path.insert(0, str(REPO_ROOT))
17
18 from src.eval.correctness import check_correctness # noqa: E402
19
20
21 def _apply_shape(reference, shape):
22 """Override reference's module-level shape shims so get_inputs/get_init_inputs match."""
23 for k, v in shape.items():
24 setattr(reference, k, v)
25
26
27 def main():
28 try:
29 import reference
30 import shapes
31 import solution
32 except Exception as e:
33 print(f"FAIL: import error: {e}")
34 sys.exit(1)
35
36 problem_yaml = Path("problem.yaml")
37 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
38
39 # --- Forbidden-op check ------------------------------------------------
40 sol_
... (3224 more chars){
"raw": "{\"path\": \"/home/infatos"
}<system>133 lines read from file starting from line 1. Total lines in file: 133. End of file reached.</system>
1 """Roofline benchmark for KDA forward (chunk form).
2
3 For each shape: times eager reference, compiled reference, SOTA (FLA's Triton
4 chunk_kda, if available on this GPU), and the agent's solution. Reports
5 achieved TFLOPS, GB/s, and peak_fraction.
6
7 Output lines the harness picks up:
8 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
9 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
10 """
11 import sys
12 from math import exp, log
13 from pathlib import Path
14
15 import torch
16 import yaml
17
18 REPO_ROOT = Path(__file__).resolve().parents[2]
19 sys.path.insert(0, str(REPO_ROOT))
20
21 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
22 from src.eval.timing import time_fn # noqa: E402
23 from src.hardware import get as get_hw # noqa: E402
24
25
26 def _eval_formula(expr: str, vars: dict) -> float:
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def _apply_shape(reference, shape):
31 for k, v in shape.items():
32 setattr(reference, k, v)
33
34
35 def main():
36 import reference
37 import shapes
38 import solution
39
40 meta = yaml.safe_load(Path("problem.yaml")
... (3922 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/Kernel"
}<system>54 lines read from file starting from line 1. Total lines in file: 54. End of file reached.</system>
1 name: 02_kda_cutlass
2 display_name: "Kimi Delta Attention (chunk forward) via CUTLASS CuTe"
3 precision: bf16
4 regime: compute
5
6 # Approximate FLOPs per forward call. KDA forward (chunk form) breakdown per
7 # token across all heads:
8 # - intra-chunk Aqk (q @ k.T with decay): 2*BT*K
9 # - intra-chunk Akk + cumulative invert : ~3*BT*K + BT^2
10 # - w = A @ (g.exp()*k): 2*BT*K
11 # - u = A @ v: 2*BT*V
12 # - inter-chunk q @ S: 2*K*V
13 # - inter-chunk (Aqk @ v_i): 2*BT*V
14 # - state update S += k.T @ v_i: 2*K*V
15 # Aggregating with a constant of 4 (typical for chunked linear attention) gives
16 # the formula below. Used only for roofline reporting -- not graded directly.
17 flops_formula: "4 * B * T * H * (K*V + CHUNK_SIZE*K + CHUNK_SIZE*V)"
18
19 # Bytes moved (lower bound): one read of q,k,v,g,beta and one write of o.
20 bytes_formula: "B*T*H*K*2 + B*T*H*K*2 + B*T*H*V*2 + B*T*H*K*4 + B*T*H*2 + B*T*H*V*2"
21
22 hardware: [RTX_PRO_6000]
23 peak_tflops_key: bf16 # KDA's heavy work is bf16 @ bf16 -> fp32 GEMM-equivalent
24 peak_bandwidth_key: dram
25
26 # bf16 default tolerance from src/eval/correctness.py is 1e-2 atol/rtol; KDA's
27 # long recurr
... (1230 more chars)<system>19 lines read from file starting from line 1. Total lines in file: 19. End of file reached.</system>
1 """Canonical shape sweep for KDA forward (chunk form).
2
3 Mix of:
4 - short-context training-step scale (T=1024)
5 - mid-context (T=2048) which is the headline benchmark
6 - long-context that stresses the inter-chunk recurrence (T=4096)
7 - thin-batch decode-style (B=1, T=2048, fewer heads)
8
9 Constraints:
10 - T % chunk_size == 0 (chunk_size = 64)
11 - K, V are the per-head channel dims; KDA in Kimi Linear uses K=V=128
12 """
13
14 SHAPES = [
15 {"B": 2, "T": 1024, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
16 {"B": 2, "T": 2048, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
17 {"B": 1, "T": 4096, "H": 8, "K": 128, "V": 128, "CHUNK_SIZE": 64},
18 {"B": 1, "T": 2048, "H": 4, "K": 128, "V": 128, "CHUNK_SIZE": 64},
19 ]
{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader"
}{
"raw": "{\"command\":"
}<system>Command executed successfully.</system>
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "which nvcc && nvcc --version | head -3"
}{
"raw": "{\"command\":"
}<system>Command executed successfully.</system>
/usr/local/cuda-13/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
<system>Command executed successfully.</system>
3.6.0
<system>Command executed successfully.</system>
2.11.0+cu130
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
(12, 0)
{
"command": "find /home/infatoshi -path \"*/fla/ops/kda*\" -type f 2>/dev/null | head -20"
}{
"command": "pip show flash-linear-attention 2>/dev/null | head -5 || echo \"FLA not installed\""
}{
"raw": "{\"command\": \"python -c \\\"import cutlass; print(cutlass.__version__)\\\" "
}<system>Command executed successfully.</system>
Name: flash-linear-attention
Version: 0.4.2
Summary: Fast linear attention models and layers
Home-page:
Author:
{
"raw": "{\"command\": \""
}<system>Command executed successfully.</system>
CUTLASS python bindings not installed
<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/backends/__init__.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/backends/flashkda.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/backends/tilelang/__init__.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/backends/tilelang/chunk_bwd_dqkg.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/__init__.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk_intra_token_parallel.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/naive.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk_bwd.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/fused_recurrent.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/gate.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk_fwd.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk_intra.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/wy_fast.py
/home/infatoshi/.cache/uv/archive-v0/iFZrSVxXzSlD8QXgcEf3G/vllm/model_executor/layers/fla/ops/kda.py
/home/infatoshi/.cache/uv/archive-v0/pD3XxLT5AwqAQ7PonqFRK/vllm/model_executor/layers/fla/ops/kda.py
/home/inf
... (453 more chars)<system>Command executed successfully.</system>{
"path": "/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/chunk_fwd.py"
}{
"raw": "{\"path\": \"/home/infatoshi/."
}<system>135 lines read from file starting from line 1. Total lines in file: 135. End of file reached.</system>
1 # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
2 #
3 # This source code is licensed under the MIT license found in the
4 # LICENSE file in the root directory of this source tree.
5 # For a list of all contributors, visit:
6 # https://github.com/fla-org/flash-linear-attention/graphs/contributors
7
8 import torch
9
10 from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
11 from fla.ops.cp import FLACPContext
12 from fla.ops.cp.chunk_delta_h import chunk_gated_delta_rule_fwd_h_pre_process, compress_h0
13 from fla.ops.gla.chunk import chunk_gla_fwd_o_gk
14 from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra
15 from fla.ops.kda.gate import kda_gate_chunk_cumsum
16 from fla.ops.utils import chunk_local_cumsum
17 from fla.ops.utils.constant import RCP_LN2
18
19
20 def chunk_kda_fwd(
21 q: torch.Tensor,
22 k: torch.Tensor,
23 v: torch.Tensor,
24 g: torch.Tensor,
25 beta: torch.Tensor,
26 scale: float,
27 initial_state: torch.Tensor,
28 output_final_state: bool,
29 cu_seqlens: torch.LongTensor | None = None,
30 cu_seqlens_cpu: torch.LongTensor | None = None,
31 chunk_indices: torch.LongTensor | None = None,
32 chunk_size: int = 64,
33 safe_gat
... (3918 more chars){
"raw": "{\"path\": \"/home/infatoshi/."
}<system>914 lines read from file starting from line 1. Total lines in file: 914. End of file reached.</system>
1 # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
2 #
3 # This source code is licensed under the MIT license found in the
4 # LICENSE file in the root directory of this source tree.
5 # For a list of all contributors, visit:
6 # https://github.com/fla-org/flash-linear-attention/graphs/contributors
7
8 import torch
9 import triton
10 import triton.language as tl
11
12 from fla.ops.kda.chunk_intra_token_parallel import chunk_kda_fwd_intra_token_parallel
13 from fla.ops.kda.wy_fast import recompute_w_u_fwd
14 from fla.ops.utils import prepare_chunk_indices
15 from fla.ops.utils.op import exp2, gather
16 from fla.utils import IS_GATHER_SUPPORTED, IS_TF32_SUPPORTED, autotune_cache_kwargs
17
18 if IS_TF32_SUPPORTED:
19 SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')
20 else:
21 SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
22
23 ################################################################################
24 # Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass
25 ################################################################################
26
27
28 @triton.heuristics({
29 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
30 })
31 @triton.autotune(
32 config
... (41025 more chars)<system>385 lines read from file starting from line 1. Total lines in file: 385. End of file reached.</system>
1 # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
2 #
3 # This source code is licensed under the MIT license found in the
4 # LICENSE file in the root directory of this source tree.
5 # For a list of all contributors, visit:
6 # https://github.com/fla-org/flash-linear-attention/graphs/contributors
7
8 # Related files are modified and supported by the Moonshot AI Team
9
10 import torch
11
12 from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
13 from fla.ops.backends import dispatch
14 from fla.ops.cp import FLACPContext
15 from fla.ops.kda.chunk_bwd import chunk_kda_bwd
16 from fla.ops.kda.chunk_fwd import chunk_kda_fwd
17 from fla.ops.utils.index import prepare_chunk_indices
18 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
19
20
21 class ChunkKDAFunction(torch.autograd.Function):
22 @staticmethod
23 @input_guard
24 @autocast_custom_fwd
25 def forward(
26 ctx,
27 q: torch.Tensor,
28 k: torch.Tensor,
29 v: torch.Tensor,
30 g: torch.Tensor,
31 beta: torch.Tensor,
32 A_log: torch.Tensor,
33 dt_bias: torch.Tensor,
34 scale: float,
35 initial_state: torch.Tensor,
36 o
... (18190 more chars){
"command": "<239 chars \u2014 see below>"
}python -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print('Peak TFLOPS:', hw.peak_tflops_dense)
print('Peak GB/s:', hw.peak_bandwidth_gb_s)
"<system>Command executed successfully.</system>
Peak TFLOPS: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
Peak GB/s: 1800.0
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/sota.py"
}<system>71 lines read from file starting from line 1. Total lines in file: 71. End of file reached.</system>
1 """SOTA reference for KDA forward: fla.ops.kda.chunk_kda (Triton).
2
3 The agent's solution is forbidden from importing this module path (see
4 problem.yaml.forbidden). This file is only used by benchmark.py to draw
5 the SOTA reference line.
6
7 If FLA's Triton kernel does not run on SM120 (Blackwell consumer-lineage --
8 some Triton kernels in FLA target Hopper TMA), is_available() returns False
9 and benchmark.py omits the SOTA variant. The H100 reference is documented
10 in problem.yaml for context.
11 """
12 from __future__ import annotations
13
14 import torch
15
16
17 def _import_fla():
18 try:
19 from fla.ops.kda import chunk_kda # noqa: F401
20 return chunk_kda
21 except Exception:
22 return None
23
24
25 def sota_forward(
26 q: torch.Tensor,
27 k: torch.Tensor,
28 v: torch.Tensor,
29 g: torch.Tensor,
30 beta: torch.Tensor,
31 scale: float | None = None,
32 ) -> torch.Tensor:
33 """Run FLA's Triton chunk_kda. Returns o (B, T, H, V) in v's dtype."""
34 chunk_kda = _import_fla()
35 if chunk_kda is None:
36 raise RuntimeError("fla.ops.kda.chunk_kda unavailable")
37 # FLA's chunk_kda has a richer signature (A_log, dt_bias
... (1530 more chars){
"command": "<139 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import sota
print('SOTA available:', sota.is_available())
"<system>Command executed successfully.</system>
SOTA available: True
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -0,0 +1,76 @@
+"""Initial naive solution based on reference.py for baseline."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, T: int, H: int, K: int, V: int, chunk_size: int = 64):
+ super().__init__()
+ self.B, self.T, self.H, self.K, self.V = B, T, H, K, V
+ self.chunk_size = chunk_size
+ self.scale = float(K) ** -0.5
+ self.register_buffer("_dummy", torch.zeros(1), persistent=False)
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ ) -> torch.Tensor:
+ dtype = v.dtype
+ B, T, H, K = q.shape
+ V = v.shape[-1]
+ BT = self.chunk_size
+ assert T % BT == 0
+ NT = T // BT
+
+ q, k, v, g, beta = (x.to(torch.float32) for x in (q, k, v, g, beta))
+ q = q * self.scale
+
+ q = rearrange(q, "b (n c) h d -> b h n c d", c=BT)
+ k = rearrange(k, "b (n c) h d -> b h n c d", c=BT)
+ v = rearrange(v, "b (n c) h d -> b h n c d", c=BT)
+ g = rearrange(g, "b (n c) h d -> b h n c d", c=BT)
+ beta = rearrange(beta, "b (n c) h -> b h n c", c=BT)
+
+ g = g.cumsum(-2)
+
+ mask_diag_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=0)
+ A = torch.zeros(*q.shape[:-1], BT, dtype=torch.float32, device=q.device)
+ for i in range(BT):
+ k_i = k[..., i, :]
+ g_i = g[..., i:i + 1, :]
+ A[..., i] = torch.einsum("... c d, ... d -> ... c", k * (g - g_i).exp(), k_i)
+ A = A * beta[..., None]
+ A = -A.masked_fill(mask_diag_upper, 0)
+
+ for i in range(1, BT):
+ A[..., i, :i] = A[..., i, :i].clone() + (A[..., i, :, None].clone() * A[..., :, :i].clone()).sum(-2)
+ A = (A + torch.eye(BT, dtype=torch.float32, device=q.device)) * beta[..., None, :]
+
+ w = A @ (g.exp() * k)
+ u = A @ v
+
+ S = q.new_zeros(B, H, K, V)
+ o = torch.zeros_like(v)
+ mask_strict_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=1)
+ for i in range(NT):
+ q_i, k_i, u_i, g_i, w_i = q[:, :, i], k[:, :, i], u[:, :, i], g[:, :, i], w[:, :, i]
+ Aqk = torch.zeros(B, H, BT, BT, dtype=torch.float32, device=q.device)
+ for j in range(BT):
+ k_j = k[:, :, i, j]
+ g_j = g[:, :, i, j:j + 1, :]
+ Aqk[..., j] = torch.einsum("... c d, ... d -> ... c", q_i * (g_i - g_j).exp(), k_j)
+ Aqk = Aqk.masked_fill(mask_strict_upper, 0)
+ v_i = u_i - w_i @ S
+ o[:, :, i] = (q_i * g_i.exp()) @ S + Aqk @ v_i
+ S = S * rearrange(g_i[:, :, -1].exp(), "b h k -> b h k 1")
+ S = S + rearrange((g_i[:, :, -1:] - g_i).exp() * k_i, "b h c k -> b h k c") @ v_i
+
+ o = rearrange(o, "b h n c d -> b (n c) h d")
+ return o.to(dtype)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<3005 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 3005 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python check.py",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python benchmark.py",
"timeout": 120
}<system>ERROR: Command killed by timeout (120s)</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
{
"command": "<788 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
shape = shapes.SHAPES[1] # B=2, T=2048
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
_ = sol_model(*inputs)
torch.cuda.synchronize()
import time
start = time.time()
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 10 * 1000
print(f'Mean ms: {elapsed:.3f}')
"<system>Command executed successfully.</system>
Mean ms: 77.779
{
"command": "<757 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import sota
import shapes
shape = shapes.SHAPES[1] # B=2, T=2048
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
scale = float(shape['K']) ** -0.5
# Warmup
with torch.no_grad():
_ = sota.sota_forward(*inputs, scale=scale)
torch.cuda.synchronize()
import time
start = time.time()
with torch.no_grad():
for _ in range(50):
_ = sota.sota_forward(*inputs, scale=scale)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 50 * 1000
print(f'SOTA mean ms: {elapsed:.3f}')
"<system>Command executed successfully.</system>
SOTA mean ms: 0.202
{
"command": "find /home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops -name \"chunk_delta_h.py\" -o -name \"chunk_gla*\" | head -10"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/common/chunk_delta_h.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/cp/chunk_delta_h.py
{
"command": "<3210 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
# Reshape
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4) # B, H, N, C, K
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 2, 1, 3) # B, H, N, C
gf = gf.cumsum(-2)
# Build A_kk using einsum
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
# B,H,N,C,K @ B,H,N,C,K -> B,H,N,C,C
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
A = A * bf.unsqueeze(-1)
A = -A.masked_fill(mask_upper, 0)
# Forward substitution
for i in range(1, BT):
A[:, :, :, i, :i] = A[:, :, :, i, :i] + (A[:, :, :, i, :, None] * A[:, :, :, :, :i]).sum(-2)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
A = (A + I) * bf.unsqueeze(-2)
w = torch.matmul(A, (gf.exp() * kf)) # B,H,N,C,K
u = torch.matmul(A, vf) # B,H,N,C,V
# Aqk
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
# Inter-chunk
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i] # B,H,C,K
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S) # B,H,C,V
qS = torch.matmul(q_i * g_i.exp(), S) # B,H,C,V
Av = torch.matmul(Aqk[:, :, i], v_i) # B,H,C,V
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1) # B,H,K,1
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i # B,H,C,K
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
o = o.permute(0, 2, 3, 1, 4).contiguous().view(B, T, H, V)
print('Output shape:', o.shape)
print('Output dtype:', o.dtype)
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(20):
# just the inter-chunk loop
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
torch.cuda.synchronize()
print(f'Inter-chunk only: {(time.time()-start)/20*1000:.3f} ms')
"<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 43, in <module>
RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 3
{
"command": "<4556 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
# Reshape: B,T,H,K -> B,H,N,C,K
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2) # B,H,N,C
gf = gf.cumsum(-2)
# Build A_kk using einsum
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
# B,H,N,C,K @ B,H,N,C,K -> B,H,N,C,C
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
A = A * bf.unsqueeze(-1) # B,H,N,C,1
A = -A.masked_fill(mask_upper, 0)
# Forward substitution
for i in range(1, BT):
A[:, :, :, i, :i] = A[:, :, :, i, :i] + (A[:, :, :, i, :, None] * A[:, :, :, :, :i]).sum(-2)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
A = (A + I) * bf.unsqueeze(-2) # B,H,N,1,C
w = torch.matmul(A, (gf.exp() * kf)) # B,H,N,C,K
u = torch.matmul(A, vf) # B,H,N,C,V
# Aqk
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
# Inter-chunk
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i] # B,H,C,K
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S) # B,H,C,V
qS = torch.matmul(q_i * g_i.exp(), S) # B,H,C,V
Av = torch.matmul(Aqk[:, :, i], v_i) # B,H,C,V
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1) # B,H,K,1
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i # B,H,C,K
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
o = o.permute(0, 2, 3, 1, 4).contiguous().view(B, T, H, V)
print('Output shape:', o.shape)
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(20):
S = qf.new_zeros(B, H, K, V)
out = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
out[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
torch.cuda.synchronize()
print(f'Inter-chunk only: {(time.time()-start)/20*1000:.3f} ms')
# Full timing
torch.cuda.synchronize()
start = time.time()
for _ in range(20):
gf = g.float().view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf.unsqueeze(-1)
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[:, :, :, i, :i] = A[:, :, :, i, :i] + (A[:, :, :, i, :, None] * A[:, :, :, :, :i]).sum(-2)
A = (A + I) * bf.unsqueeze(-2)
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
S = qf.new_zeros(B, H, K, V)
out = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
... (556 more chars)<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 80, in <module>
RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 1
{
"command": "<3947 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
# Reshape: B,T,H,K -> B,H,N,C,K
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2) # B,H,N,C
print('bf shape:', bf.shape)
gf = gf.cumsum(-2)
# Build A_kk using einsum
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
# B,H,N,C,K @ B,H,N,C,K -> B,H,N,C,C
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
print('A shape:', A.shape)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None] # B,H,N,C,1
print('After beta:', A.shape)
A = -A.masked_fill(mask_upper, 0)
# Forward substitution
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
A = (A + I) * bf[..., None, :] # B,H,N,1,C
print('A final shape:', A.shape)
w = torch.matmul(A, (gf.exp() * kf)) # B,H,N,C,K
print('w shape:', w.shape)
u = torch.matmul(A, vf) # B,H,N,C,V
print('u shape:', u.shape)
# Aqk
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
print('Aqk shape:', Aqk.shape)
# Inter-chunk
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i] # B,H,C,K
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S) # B,H,C,V
qS = torch.matmul(q_i * g_i.exp(), S) # B,H,C,V
Av = torch.matmul(Aqk[:, :, i], v_i) # B,H,C,V
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1) # B,H,K,1
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i # B,H,C,K
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
o = o.permute(0, 2, 3, 1, 4).contiguous().view(B, T, H, V)
print('o shape:', o.shape)
# Timing
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(50):
gf = g.float().view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
S = qf.new_zeros(B, H, K, V)
out = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
out[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
torch.cuda.synchronize()
print(f'Full vectorized: {(time.time()-start)/50*1000:.3f} ms')
"<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 87, in <module>
RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 1
bf shape: torch.Size([2, 8, 32, 64])
A shape: torch.Size([2, 8, 32, 64, 64])
After beta: torch.Size([2, 8, 32, 64, 64])
A final shape: torch.Size([2, 8, 32, 64, 64])
w shape: torch.Size([2, 8, 32, 64, 128])
u shape: torch.Size([2, 8, 32, 64, 128])
Aqk shape: torch.Size([2, 8, 32, 64, 64])
{
"command": "<2403 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
# Reshape: B,T,H,K -> B,H,N,C,K
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2) # B,H,N,C
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
S = qf.new_zeros(B, H, K, V)
print('S shape:', S.shape)
for i in range(NT):
print(f'Chunk {i}')
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
print('q_i:', q_i.shape)
print('k_i:', k_i.shape)
print('u_i:', u_i.shape)
print('g_i:', g_i.shape)
print('w_i:', w_i.shape)
v_i = u_i - torch.matmul(w_i, S)
print('v_i:', v_i.shape)
qS = torch.matmul(q_i * g_i.exp(), S)
print('qS:', qS.shape)
Av = torch.matmul(Aqk[:, :, i], v_i)
print('Av:', Av.shape)
decay = g_i[:, :, -1].exp().unsqueeze(-1)
print('decay:', decay.shape)
S = S * decay
kg_decay = (g_i[:, :, -1:].unsqueeze(-2) - g_i).exp() * k_i
print('kg_decay:', kg_decay.shape)
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
print('S:', S.shape)
print('Done')
"<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 80, in <module>
RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 1
S shape: torch.Size([2, 8, 128, 128])
Chunk 0
q_i: torch.Size([2, 8, 64, 128])
k_i: torch.Size([2, 8, 64, 128])
u_i: torch.Size([2, 8, 64, 128])
g_i: torch.Size([2, 8, 64, 128])
w_i: torch.Size([2, 8, 64, 128])
v_i: torch.Size([2, 8, 64, 128])
qS: torch.Size([2, 8, 64, 128])
Av: torch.Size([2, 8, 64, 128])
decay: torch.Size([2, 8, 128, 1])
{
"command": "<3756 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
# Reshape: B,T,H,K -> B,H,N,C,K
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2) # B,H,N,C
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1) # B,H,K,1
S = S * decay
# g_i[:, :, -1:] is (B,H,1,128), g_i is (B,H,C,128)
kg_decay = (g_i[:, :, -1:] - g_i).exp() * k_i # B,H,C,K
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
o = o.permute(0, 2, 3, 1, 4).contiguous().view(B, T, H, V)
# Check correctness
ref_model = reference.Model(B, T, H, K, V, BT).cuda()
with torch.no_grad():
ref_out = ref_model(*inputs)
print('Max diff:', (ref_out.float() - o.to(ref_out.dtype).float()).abs().max().item())
# Timing
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(50):
gf = g.float().view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
S = qf.new_zeros(B, H, K, V)
out = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
out[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:] - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
torch.cuda.synchronize()
print(f'Full vectorized: {(time.time()-start)/50*1000:.3f} ms')
"<system>Command executed successfully.</system>
Max diff: 3.0517578125e-05
Full vectorized: 3.934 ms
{
"path": "/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/common/chunk_delta_h.py"
}<system>784 lines read from file starting from line 1. Total lines in file: 784. End of file reached.</system>
1 # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
2 #
3 # This source code is licensed under the MIT license found in the
4 # LICENSE file in the root directory of this source tree.
5 # For a list of all contributors, visit:
6 # https://github.com/fla-org/flash-linear-attention/graphs/contributors
7
8 import torch
9 import triton
10 import triton.language as tl
11
12 from fla.ops.backends import dispatch
13 from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
14 from fla.ops.utils.op import exp, exp2
15 from fla.utils import IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem
16
17 NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8, 16]
18
19
20 @triton.heuristics({
21 'USE_G': lambda args: args['g'] is not None,
22 'USE_GK': lambda args: args['gk'] is not None,
23 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
24 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
25 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None,
26 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
27 })
28 @triton.autotune(
29 configs=[
30 triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
31 for num_warps in
... (36828 more chars){
"command": "find /home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla -name \"*.py\" | head -10"
}<system>Command executed successfully.</system>
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/fused_chunk.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/__init__.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/naive.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/fused_recurrent.py
/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/chunk.py
{
"command": "grep -n \"chunk_gla_fwd_o_gk\" /home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/chunk.py | head -10"
}<system>Command executed successfully.</system>
877:def chunk_gla_fwd_o_gk(
1127: o = chunk_gla_fwd_o_gk(
{
"path": "/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/chunk.py",
"line_offset": 877,
"n_lines": 200
}<system>200 lines read from file starting from line 877. Total lines in file: 1391.</system>
877 def chunk_gla_fwd_o_gk(
878 q: torch.Tensor,
879 v: torch.Tensor,
880 g: torch.Tensor,
881 A: torch.Tensor,
882 h: torch.Tensor,
883 scale: float,
884 cu_seqlens: torch.LongTensor | None = None,
885 chunk_size: int = 64,
886 chunk_indices: torch.LongTensor | None = None,
887 use_exp2: bool = False,
888 transpose_state_layout: bool = False,
889 ):
890 B, T, H, K, HV, V = *q.shape, v.shape[2], v.shape[-1]
891 BT = chunk_size
892
893 if chunk_indices is None and cu_seqlens is not None:
894 chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
895 NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
896
897 # Please ensure zeros, since vllm will use padding v
898 o = torch.zeros_like(v)
899 def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HV)
900 chunk_gla_fwd_kernel_o[grid](
901 q=q,
902 v=v,
903 g=g,
904 h=h,
905 o=o,
906 A=A,
907 cu_seqlens=cu_seqlens,
908 chunk_indices=chunk_indices,
909 scale=scale,
910 T=T,
911 H=H,
912 HV=HV,
913 K=K,
914 V=V,
915 BT=BT,
916 USE_EXP2=use_exp2,
917 TRANSPOSE_STATE=transpo
... (5019 more chars){
"command": "grep -n \"chunk_gla_fwd_kernel_o\" /home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/chunk.py | head -5"
}<system>Command executed successfully.</system>
313:def chunk_gla_fwd_kernel_o(
900: chunk_gla_fwd_kernel_o[grid](
{
"path": "/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/gla/chunk.py",
"line_offset": 313,
"n_lines": 200
}<system>200 lines read from file starting from line 313. Total lines in file: 1391.</system>
313 def chunk_gla_fwd_kernel_o(
314 q,
315 v,
316 g,
317 h,
318 o,
319 A,
320 cu_seqlens,
321 chunk_indices,
322 scale,
323 T,
324 H: tl.constexpr,
325 HV: tl.constexpr,
326 K: tl.constexpr,
327 V: tl.constexpr,
328 BT: tl.constexpr,
329 BK: tl.constexpr,
330 BV: tl.constexpr,
331 USE_EXP2: tl.constexpr,
332 TRANSPOSE_STATE: tl.constexpr,
333 IS_VARLEN: tl.constexpr,
334 ):
335 i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
336 i_b, i_hv = i_bh // HV, i_bh % HV
337 i_h = i_hv // (HV // H)
338 if IS_VARLEN:
339 i_tg = i_t.to(tl.int64)
340 i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
341 bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
342 T = eos - bos
343 NT = tl.cdiv(T, BT)
344 else:
345 NT = tl.cdiv(T, BT)
346 i_tg = (i_b * NT + i_t).to(tl.int64)
347 bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)
348
349 m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
350
351 q += (bos * H + i_h) * K
352 g += (bos * HV + i_hv) * K
353 v += (bos
... (7470 more chars){
"command": "<2962 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
gf = gf.cumsum(-2)
# Time intra-chunk only
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(100):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
torch.cuda.synchronize()
print(f'Intra-chunk: {(time.time()-start)/100*1000:.3f} ms')
# Precompute for inter-chunk timing
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
S = qf.new_zeros(B, H, K, V)
out = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
out[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:] - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
torch.cuda.synchronize()
print(f'Inter-chunk: {(time.time()-start)/100*1000:.3f} ms')
"<system>Command executed successfully.</system>
Intra-chunk: 2.355 ms
Inter-chunk: 2.563 ms
{
"command": "<2645 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
gf = gf.cumsum(-2)
# Precompute masks once
# Time with bmm instead of einsum
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(200):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
# Using bmm
BHN = B * H * NT
kg_flat = kg.reshape(BHN, BT, K)
kgn_flat = kgn.reshape(BHN, BT, K)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B, H, NT, BT, BT)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B, H, NT, BT, BT)
Aqk = Aqk.masked_fill(mask_strict, 0)
torch.cuda.synchronize()
print(f'With bmm: {(time.time()-start)/200*1000:.3f} ms')
# Original einsum
torch.cuda.synchronize()
start = time.time()
for _ in range(200):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
torch.cuda.synchronize()
print(f'With einsum: {(time.time()-start)/200*1000:.3f} ms')
"<system>Command executed successfully.</system>
With bmm: 1.808 ms
With einsum: 1.327 ms
{
"command": "<720 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import time
BHN = 512 # B*H*N for T=2048
C = 64
A = torch.randn(BHN, C, C, device='cuda')
# Time forward substitution
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
A_copy = A.clone()
for i in range(1, C):
A_copy[:, i, :i] = A_copy[:, i, :i] + (A_copy[:, i, :, None] * A_copy[:, :, :i]).sum(-2)
torch.cuda.synchronize()
print(f'Forward substitution: {(time.time()-start)/1000*1000:.3f} ms')
# Compare to full matmul
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
_ = torch.bmm(A, A)
torch.cuda.synchronize()
print(f'BMM: {(time.time()-start)/1000*1000:.3f} ms')
"<system>Command executed successfully.</system>
Forward substitution: 1.222 ms
BMM: 0.049 ms
{
"command": "<1941 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import time
def tri_inv_dc(A):
'''A is (..., C, C) strictly lower triangular. Returns (I - A)^{-1}.'''
C = A.shape[-1]
result = torch.eye(C, device=A.device, dtype=A.dtype).expand_as(A).clone()
bs = 1
while bs < C:
num_pairs = C // (2 * bs)
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
top = result[..., top_start:top_end, top_start:top_end]
bot = result[..., bot_start:bot_end, bot_start:bot_end]
off = A[..., bot_start:bot_end, top_start:top_end]
result[..., bot_start:bot_end, top_start:top_end] = torch.matmul(torch.matmul(bot, off), top)
bs *= 2
return result
# Test correctness
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = A * mask
# Reference: sequential forward substitution
ref = torch.eye(C, device='cuda').expand_as(A).clone()
ref[:, 0, 0] = 1
for i in range(1, C):
ref[:, i, :i] = A[:, i, :i] + (A[:, i, :, None] * ref[:, :, :i]).sum(-2)
ref[:, i, i] = 1
# DC approach
dc = tri_inv_dc(A)
print('Max diff:', (ref - dc).abs().max().item())
# Time
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
_ = tri_inv_dc(A)
torch.cuda.synchronize()
print(f'DC time: {(time.time()-start)/1000*1000:.3f} ms')
# Time reference
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
ref = torch.eye(C, device='cuda').expand_as(A).clone()
for i in range(1, C):
ref[:, i, :i] = A[:, i, :i] + (A[:, i, :, None] * ref[:, :, :i]).sum(-2)
ref[:, i, i] = 1
torch.cuda.synchronize()
print(f'Loop time: {(time.time()-start)/1000*1000:.3f} ms')
"<system>Command executed successfully.</system>
Max diff: 0.5737575888633728
DC time: 1.347 ms
Loop time: 1.471 ms
{
"command": "<1382 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
def reference_fwd_subst(A):
C = A.shape[-1]
for i in range(1, C):
A[..., i, :i] = A[..., i, :i].clone() + (A[..., i, :, None].clone() * A[..., :, :i].clone()).sum(-2)
return A
def tri_inv_dc(A):
C = A.shape[-1]
result = torch.eye(C, device=A.device, dtype=A.dtype).expand_as(A).clone()
bs = 1
while bs < C:
num_pairs = C // (2 * bs)
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
top = result[..., top_start:top_end, top_start:top_end]
bot = result[..., bot_start:bot_end, bot_start:bot_end]
off = A[..., bot_start:bot_end, top_start:top_end]
result[..., bot_start:bot_end, top_start:top_end] = torch.matmul(torch.matmul(bot, off), top)
bs *= 2
return result
# Test with small C
for C in [2, 4, 8, 16, 32, 64]:
A = torch.randn(4, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = A * mask
ref = reference_fwd_subst(A.clone()) + torch.eye(C, device='cuda')
dc = tri_inv_dc(A)
print(f'C={C}: max_diff={(ref - dc).abs().max().item():.6f}')
"<system>Command executed successfully.</system>
C=2: max_diff=0.000000
C=4: max_diff=0.000000
C=8: max_diff=0.000000
C=16: max_diff=0.000000
C=32: max_diff=0.000000
C=64: max_diff=0.000000
{
"command": "<2517 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import time
def tri_inv_dc(A):
C = A.shape[-1]
result = torch.eye(C, device=A.device, dtype=A.dtype).expand_as(A).clone()
bs = 1
while bs < C:
num_pairs = C // (2 * bs)
# Batch all pairs into a single bmm
top_list = []
bot_list = []
off_list = []
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
top_list.append(result[..., top_start:top_end, top_start:top_end])
bot_list.append(result[..., bot_start:bot_end, bot_start:bot_end])
off_list.append(A[..., bot_start:bot_end, top_start:top_end])
if num_pairs > 0:
# Stack and bmm
tops = torch.stack(top_list, dim=-4) # (num_pairs, ..., bs, bs)
bots = torch.stack(bot_list, dim=-4)
offs = torch.stack(off_list, dim=-4)
# Reshape to merge batch dims
batch_shape = A.shape[:-2]
tops_flat = tops.view(num_pairs, -1, bs, bs)
bots_flat = bots.view(num_pairs, -1, bs, bs)
offs_flat = offs.view(num_pairs, -1, bs, bs)
# bot @ off
tmp = torch.bmm(bots_flat, offs_flat)
# tmp @ top
out = torch.bmm(tmp, tops_flat)
# Unstack back
out = out.view(num_pairs, *batch_shape, bs, bs)
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
result[..., bot_start:bot_end, top_start:top_end] = out[p]
bs *= 2
return result
# Performance test
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = A * mask
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
_ = tri_inv_dc(A)
torch.cuda.synchronize()
print(f'Batched DC: {(time.time()-start)/1000*1000:.3f} ms')
# Reference
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
ref = A.clone()
for i in range(1, C):
ref[:, i, :i] = ref[:, i, :i] + (ref[:, i, :, None] * ref[:, :, :i]).sum(-2)
torch.cuda.synchronize()
print(f'Loop: {(time.time()-start)/1000*1000:.3f} ms')
"<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 65, in <module>
File "<string>", line 38, in tri_inv_dc
RuntimeError: batch1 must be a 3D tensor
{
"command": "<2350 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import time
def tri_inv_dc(A):
C = A.shape[-1]
result = torch.eye(C, device=A.device, dtype=A.dtype).expand_as(A).clone()
batch_shape = A.shape[:-2]
batch_numel = batch_shape.numel()
bs = 1
while bs < C:
num_pairs = C // (2 * bs)
if num_pairs > 0:
# Stack all pairs
tops = []
bots = []
offs = []
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
tops.append(result[..., top_start:top_end, top_start:top_end])
bots.append(result[..., bot_start:bot_end, bot_start:bot_end])
offs.append(A[..., bot_start:bot_end, top_start:top_end])
# Shape: (num_pairs, batch_numel, bs, bs)
tops_t = torch.stack(tops, dim=0).view(num_pairs * batch_numel, bs, bs)
bots_t = torch.stack(bots, dim=0).view(num_pairs * batch_numel, bs, bs)
offs_t = torch.stack(offs, dim=0).view(num_pairs * batch_numel, bs, bs)
tmp = torch.bmm(bots_t, offs_t)
out = torch.bmm(tmp, tops_t)
out = out.view(num_pairs, *batch_shape, bs, bs)
for p in range(num_pairs):
top_start = p * 2 * bs
top_end = top_start + bs
bot_start = top_end
bot_end = bot_start + bs
result[..., bot_start:bot_end, top_start:top_end] = out[p]
bs *= 2
return result
# Performance test
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = A * mask
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
_ = tri_inv_dc(A)
torch.cuda.synchronize()
print(f'Batched DC: {(time.time()-start)/1000*1000:.3f} ms')
# Reference
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
ref = A.clone()
for i in range(1, C):
ref[:, i, :i] = ref[:, i, :i] + (ref[:, i, :, None] * ref[:, :, :i]).sum(-2)
torch.cuda.synchronize()
print(f'Loop: {(time.time()-start)/1000*1000:.3f} ms')
"<system>Command executed successfully.</system>
Batched DC: 0.874 ms
Loop: 1.221 ms
{
"command": "<1590 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_source = '''
__global__ void fwd_subst_kernel(float* __restrict__ A, int C, int batch) {
int b = blockIdx.x;
if (b >= batch) return;
float* mat = A + b * C * C;
for (int i = 1; i < C; i++) {
for (int j = 0; j < i; j++) {
float sum = 0.0f;
for (int k = 0; k < i; k++) {
sum += mat[i * C + k] * mat[k * C + j];
}
mat[i * C + j] += sum;
}
}
}
torch::Tensor fwd_subst(torch::Tensor A) {
int batch = A.size(0);
int C = A.size(1);
fwd_subst_kernel<<<batch, 1>>>(A.data_ptr<float>(), C, batch);
return A;
}
'''
module = load_inline(
name='fwd_subst',
cpp_sources='torch::Tensor fwd_subst(torch::Tensor A);',
cuda_sources=cuda_source,
functions=['fwd_subst'],
extra_cuda_cflags=['-O3'],
verbose=False,
)
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = (A * mask).contiguous()
# Correctness
ref = A.clone()
for i in range(1, C):
ref[:, i, :i] = ref[:, i, :i] + (ref[:, i, :, None] * ref[:, :, :i]).sum(-2)
out = module.fwd_subst(A.clone())
print('Max diff:', (ref - out).abs().max().item())
# Timing
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(10000):
_ = module.fwd_subst(A.clone())
torch.cuda.synchronize()
print(f'CUDA kernel: {(time.time()-start)/10000*1000:.3f} ms')
"<system>Command executed successfully.</system>
Max diff: 1.1920928955078125e-07
CUDA kernel: 0.979 ms
{
"command": "<2230 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_source = '''
__global__ void fwd_subst_kernel(float* __restrict__ A, int C, int batch) {
int b = blockIdx.x;
if (b >= batch) return;
int tid = threadIdx.x;
__shared__ float s_A[64][64];
float* mat = A + b * C * C;
// Load matrix into shared memory
for (int idx = tid; idx < C * C; idx += blockDim.x) {
s_A[idx / C][idx % C] = mat[idx];
}
__syncthreads();
// Save original row i in a shared buffer
__shared__ float s_row[64];
for (int i = 1; i < C; i++) {
// Save original row i
if (tid < i) {
s_row[tid] = s_A[i][tid];
}
__syncthreads();
// Each thread computes one j
if (tid < i) {
float sum = 0.0f;
for (int k = 0; k < i; k++) {
sum += s_row[k] * s_A[k][tid];
}
s_A[i][tid] += sum;
}
__syncthreads();
}
// Store back
for (int idx = tid; idx < C * C; idx += blockDim.x) {
mat[idx] = s_A[idx / C][idx % C];
}
}
torch::Tensor fwd_subst(torch::Tensor A) {
int batch = A.size(0);
int C = A.size(1);
fwd_subst_kernel<<<batch, 64>>>(A.data_ptr<float>(), C, batch);
return A;
}
'''
module = load_inline(
name='fwd_subst2',
cpp_sources='torch::Tensor fwd_subst(torch::Tensor A);',
cuda_sources=cuda_source,
functions=['fwd_subst'],
extra_cuda_cflags=['-O3'],
verbose=False,
)
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = (A * mask).contiguous()
# Correctness
ref = A.clone()
for i in range(1, C):
ref[:, i, :i] = ref[:, i, :i] + (ref[:, i, :, None] * ref[:, :, :i]).sum(-2)
out = module.fwd_subst(A.clone())
print('Max diff:', (ref - out).abs().max().item())
# Timing
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(10000):
_ = module.fwd_subst(A.clone())
torch.cuda.synchronize()
print(f'CUDA kernel (64 threads): {(time.time()-start)/10000*1000:.3f} ms')
"<system>Command executed successfully.</system>
Max diff: 1.1920928955078125e-07
CUDA kernel (64 threads): 0.034 ms
{
"path": "/home/infatoshi/.cache/uv/archive-v0/rnS-5wxCI5ghhBOYKv7Tc/fla/ops/kda/wy_fast.py"
}<system>352 lines read from file starting from line 1. Total lines in file: 352. End of file reached.</system>
1 # Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li
2 #
3 # This source code is licensed under the MIT license found in the
4 # LICENSE file in the root directory of this source tree.
5 # For a list of all contributors, visit:
6 # https://github.com/fla-org/flash-linear-attention/graphs/contributors
7
8 import torch
9 import triton
10 import triton.language as tl
11
12 from fla.ops.utils import prepare_chunk_indices
13 from fla.ops.utils.op import exp2
14 from fla.utils import autotune_cache_kwargs, check_shared_mem
15
16
17 @triton.heuristics({
18 'STORE_QG': lambda args: args['qg'] is not None,
19 'STORE_KG': lambda args: args['kg'] is not None,
20 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
21 })
22 @triton.autotune(
23 configs=[
24 triton.Config({}, num_warps=num_warps, num_stages=num_stages)
25 for num_warps in [2, 4, 8]
26 for num_stages in [2, 3, 4]
27 ],
28 key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
29 **autotune_cache_kwargs,
30 )
31 @triton.jit(do_not_specialize=['T'])
32 def recompute_w_u_fwd_kda_kernel(
33 q,
34 k,
35 qg,
36 kg,
37 v,
38 beta,
39 w,
40 u,
4
... (13031 more chars){
"command": "<132 chars \u2014 see below>"
}python -c "
import triton.language as tl
print('exp available:', hasattr(tl, 'exp'))
print('exp2 available:', hasattr(tl, 'exp2'))
"<system>Command executed successfully.</system>
exp available: True
exp2 available: True
{
"command": "<3170 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
gf = gf.cumsum(-2)
# Measure with cuda events
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(10):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
torch.cuda.synchronize()
start_event.record()
for _ in range(100):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
end_event.record()
torch.cuda.synchronize()
print(f'Intra-chunk (100 iters): {start_event.elapsed_time(end_event):.3f} ms')
print(f'Intra-chunk per iter: {start_event.elapsed_time(end_event)/100:.3f} ms')
# Measure just einsum + matmul without fwd subst
start_event.record()
for _ in range(100):
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
end_event.record()
torch.cuda.synchronize()
print(f'Without fwd subst (100 iters): {start_event.elapsed_time(end_event):.3f} ms')
print(f'Without fwd subst per iter: {start_event.elapsed_time(end_event)/100:.3f} ms')
"<system>Command executed successfully.</system>
Intra-chunk (100 iters): 136.076 ms
Intra-chunk per iter: 1.361 ms
Without fwd subst (100 iters): 21.846 ms
Without fwd subst per iter: 0.218 ms
{
"command": "<3009 chars \u2014 see below>",
"timeout": 180
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import shapes
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
torch.manual_seed(2026)
inputs = [t.cuda() if hasattr(t, 'cuda') else t for t in reference.get_inputs()]
q, k, v, g, beta = inputs
scale = float(shape['K']) ** -0.5
B, T, H, K = q.shape
V = v.shape[-1]
BT = 64
NT = T // BT
qf = q.float() * scale
kf = k.float()
vf = v.float()
gf = g.float()
bf = beta.float()
qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2)
mask_upper = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=0)
mask_strict = torch.triu(torch.ones(BT, BT, device=q.device, dtype=torch.bool), diagonal=1)
I = torch.eye(BT, device=q.device, dtype=torch.float32)
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
A = torch.einsum('bhnik,bhnjk->bhnij', kg, kgn)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
for i in range(1, BT):
A[..., i, :i] = A[..., i, :i] + (A[..., i, :, None] * A[..., :, :i]).sum(-2)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
Aqk = torch.einsum('bhnik,bhnjk->bhnij', qg, kgn)
Aqk = Aqk.masked_fill(mask_strict, 0)
def inter_chunk(qf, kf, gf, w, u, Aqk):
S = qf.new_zeros(B, H, K, V)
o = torch.zeros_like(vf)
for i in range(NT):
q_i = qf[:, :, i]
k_i = kf[:, :, i]
u_i = u[:, :, i]
g_i = gf[:, :, i]
w_i = w[:, :, i]
v_i = u_i - torch.matmul(w_i, S)
qS = torch.matmul(q_i * g_i.exp(), S)
Av = torch.matmul(Aqk[:, :, i], v_i)
o[:, :, i] = qS + Av
decay = g_i[:, :, -1].exp().unsqueeze(-1)
S = S * decay
kg_decay = (g_i[:, :, -1:] - g_i).exp() * k_i
S = S + torch.matmul(kg_decay.transpose(-2, -1), v_i)
return o
compiled_inter = torch.compile(inter_chunk, mode='max-autotune')
# Warmup compiled
for _ in range(5):
_ = compiled_inter(qf, kf, gf, w, u, Aqk)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(100):
_ = compiled_inter(qf, kf, gf, w, u, Aqk)
end_event.record()
torch.cuda.synchronize()
print(f'Compiled inter-chunk (100 iters): {start_event.elapsed_time(end_event):.3f} ms')
print(f'Compiled inter-chunk per iter: {start_event.elapsed_time(end_event)/100:.3f} ms')
# Baseline
start_event.record()
for _ in range(100):
_ = inter_chunk(qf, kf, gf, w, u, Aqk)
end_event.record()
torch.cuda.synchronize()
print(f'Eager inter-chunk (100 iters): {start_event.elapsed_time(end_event):.3f} ms')
print(f'Eager inter-chunk per iter: {start_event.elapsed_time(end_event)/100:.3f} ms')
"<system>Command executed successfully. Output is truncated to fit in the message.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py:3686: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
current_out_size = out_base.storage().size()
E0428 10:21:41.698000 370917 torch/_inductor/select_algorithm.py:3924] [0/0] Runtime error during autotuning:
E0428 10:21:41.698000 370917 torch/_inductor/select_algorithm.py:3924] [0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_bmm Required: 196608 Hardware limit:101376 Reducing block sizes or `num_stages` may help..
E0428 10:21:41.698000 370917 torch/_inductor/select_algorithm.py:3924] [0/0] Ignoring this choice.
Autotune Choices Stats:
{"num_choices": 19, "num_triton_choices": 18, "best_kernel": "triton_bmm_2", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, n
... (48591 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,11 +1,216 @@
-"""Initial naive solution based on reference.py for baseline."""
+"""KDA forward solution with custom CUDA and Triton kernels."""
from __future__ import annotations
+
+import math
import torch
import torch.nn as nn
-from einops import rearrange
-
-
+import triton
+import triton.language as tl
+
+# ---------------------------------------------------------------------------
+# CUDA forward substitution kernel (loaded inline)
+# ---------------------------------------------------------------------------
+_CUDA_FWD_SUBST_SRC = """
+__global__ void fwd_subst_kernel(float* __restrict__ A, int C, int batch) {
+ int b = blockIdx.x;
+ if (b >= batch) return;
+ int tid = threadIdx.x;
+ __shared__ float s_A[64][64];
+ __shared__ float s_row[64];
+ float* mat = A + b * C * C;
+ for (int idx = tid; idx < C * C; idx += blockDim.x) {
+ s_A[idx / C][idx % C] = mat[idx];
+ }
+ __syncthreads();
+ for (int i = 1; i < C; i++) {
+ if (tid < i) s_row[tid] = s_A[i][tid];
+ __syncthreads();
+ if (tid < i) {
+ float sum = 0.0f;
+ for (int k = 0; k < i; k++) {
+ sum += s_row[k] * s_A[k][tid];
+ }
+ s_A[i][tid] += sum;
+ }
+ __syncthreads();
+ }
+ for (int idx = tid; idx < C * C; idx += blockDim.x) {
+ mat[idx] = s_A[idx / C][idx % C];
+ }
+}
+torch::Tensor fwd_subst(torch::Tensor A) {
+ int batch = A.size(0);
+ int C = A.size(1);
+ fwd_subst_kernel<<<batch, 64>>>(A.data_ptr<float>(), C, batch);
+ return A;
+}
+"""
+
+_fwd_subst_mod = torch.utils.cpp_extension.load_inline(
+ name="kda_fwd_subst",
+ cpp_sources="torch::Tensor fwd_subst(torch::Tensor A);",
+ cuda_sources=_CUDA_FWD_SUBST_SRC,
+ functions=["fwd_subst"],
+ extra_cuda_cflags=["-O3"],
+ verbose=False,
+)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel: inter-chunk recurrence + output
+# ---------------------------------------------------------------------------
+@triton.jit
+def kda_inter_chunk_kernel(
+ q_ptr,
+ k_ptr,
+ u_ptr,
+ w_ptr,
+ g_ptr,
+ Aqk_ptr,
+ o_ptr,
+ B: tl.constexpr,
+ T: tl.constexpr,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BV: tl.constexpr,
+ scale: tl.constexpr,
+):
+ i_v = tl.program_id(0)
+ i_bh = tl.program_id(1)
+ i_b = i_bh // H
+ i_h = i_bh % H
+
+ NT = T // BT
+
+ # State tiles: K is fixed at 128, split into two 64-wide tiles
+ # S1: (64, BV), S2: (64, BV)
+ b_S1 = tl.zeros((64, BV), dtype=tl.float32)
+ b_S2 = tl.zeros((64, BV), dtype=tl.float32)
+
+ # Pointers base
+ q_base = q_ptr + (i_b * T * H + i_h) * K
+ k_base = k_ptr + (i_b * T * H + i_h) * K
+ u_base = u_ptr + (i_b * T * H + i_h) * V
+ w_base = w_ptr + (i_b * T * H + i_h) * K
+ g_base = g_ptr + (i_b * T * H + i_h) * K
+ Aqk_base = Aqk_ptr + (i_b * T * H + i_h) * BT
+ o_base = o_ptr + (i_b * T * H + i_h) * V
+
+ for i_t in range(NT):
+ t0 = i_t * BT
+
+ # Load u_i tile (BT, BV)
+ p_u = tl.make_block_ptr(
+ u_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
+ )
+ b_v = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32)
+
+ # v_i = u_i - w_i @ S
+ # First K tile (0:64)
+ p_w = tl.make_block_ptr(
+ w_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
+ )
+ b_w = tl.load(p_w, boundary_check=(0, 1))
+ b_v -= tl.dot(b_w, b_S1.to(b_w.dtype))
+
+ # Second K tile (64:128)
+ p_w = tl.make_block_ptr(
+ w_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
+ )
+ b_w = tl.load(p_w, boundary_check=(0, 1))
+ b_v -= tl.dot(b_w, b_S2.to(b_w.dtype))
+
+ # Compute o_i = (q_i * exp(g_i)) @ S + Aqk @ v_i
+ b_o = tl.zeros((BT, BV), dtype=tl.float32)
+
+ # qS from first K tile
+ p_q = tl.make_block_ptr(
+ q_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
+ )
+ p_g = tl.make_block_ptr(
+ g_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
+ )
+ b_q = tl.load(p_q, boundary_check=(0, 1))
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
+ b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
+ b_o += tl.dot(b_qg, b_S1.to(b_qg.dtype))
+
+ # qS from second K tile
+ p_q = tl.make_block_ptr(
+ q_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
+ )
+ p_g = tl.make_block_ptr(
+ g_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
+ )
+ b_q = tl.load(p_q, boundary_check=(0, 1))
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
+ b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
+ b_o += tl.dot(b_qg, b_S2.to(b_qg.dtype))
+
+ b_o *= scale
+
+ # Aqk @ v_i
+ p_Aqk = tl.make_block_ptr(
+ Aqk_base, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0)
+ )
+ b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
+ # Mask to strict lower triangular
+ o_i = tl.arange(0, BT)
+ m_s = o_i[:, None] >= o_i[None, :]
+ b_Aqk = tl.where(m_s, b_Aqk, 0.0).to(b_v.dtype)
+ b_o += tl.dot(b_Aqk, b_v.to(b_Aqk.dtype))
+
+ # Store o_i
+ p_o = tl.make_block_ptr(
+ o_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
+ )
+ tl.store(p_o, b_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
+
+ # Update state S
+ # decay = exp(g_i[:, -1])
+ last_idx = t0 + BT - 1
+ p_glast1 = g_base + last_idx * H * K + tl.arange(0, 64)
+ p_glast2 = g_base + last_idx * H * K + 64 + tl.arange(0, 64)
+ b_glast1
... (7153 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<11019 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 11019 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python check.py",
"timeout": 120
}<system>ERROR: Command failed with exit code: 1.</system>
FAIL: import error: module 'torch.utils' has no attribute 'cpp_extension'
{
"command": "python -c \"import torch.utils.cpp_extension; print('available')\" 2>&1 || echo \"not available\""
}<system>Command executed successfully.</system>
available
{
"command": "<156 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
from torch.utils.cpp_extension import load_inline
print('OK')
"<system>Command executed successfully.</system>
OK
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c \"import solution; print('OK')\" 2>&1"
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 50, in <module>
_fwd_subst_mod = torch.utils.cpp_extension.load_inline(
^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.utils' has no attribute 'cpp_extension'
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"edit": {
"old": "import torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl",
"new": "import torch\nimport torch.nn as nn\nfrom torch.utils.cpp_extension import load_inline\nimport triton\nimport triton.language as tl"
}
}<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"edit": {
"old": "_fwd_subst_mod = torch.utils.cpp_extension.load_inline(",
"new": "_fwd_subst_mod = load_inline("
}
}<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python check.py",
"timeout": 180
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 113, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/check.py", line 74, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py", line 298, in forward
kda_inter_chunk_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.
... (579 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"edit": {
"old": " grid = (triton.cdiv(V, BV), B * H)\n kda_inter_chunk_kernel[grid](\n q_flat,\n k_flat,\n u_flat,\n w_flat,\n g_flat,\n Aqk_flat,\n o,\n B=B,\n T=T,\n H=H,\n K=K,\n V=V,\n BT=BT,\n BV=BV,\n scale=self.scale,\n )",
"new": " grid = (triton.cdiv(V, BV), B * H)\n kda_inter_chunk_kernel[
... (382 more chars)<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python check.py",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python benchmark.py",
"timeout": 180
}<system>ERROR: Command killed by timeout (180s)</system>{
"command": "<1395 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(5):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
for _ in range(50):
_ = sol_model(*inputs)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 50
# FLOPs
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
CHUNK_SIZE = shape['CHUNK_SIZE']
flops = 4 * B * T * H * (K * V + CHUNK_SIZE * K + CHUNK_SIZE * V)
tflops = flops / (ms * 1e-3) / 1e12
peak_frac = tflops / 200.0
print(f'Shape {shape_idx}: {ms:.3f} ms, {tflops:.2f} TFLOPS, peak_fraction={peak_frac:.4f}')
"<system>Command executed successfully.</system>
Shape 0: 0.452 ms, 4.75 TFLOPS, peak_fraction=0.0238
Shape 1: 0.900 ms, 4.77 TFLOPS, peak_fraction=0.0239
Shape 2: 1.391 ms, 3.09 TFLOPS, peak_fraction=0.0154
Shape 3: 0.583 ms, 1.84 TFLOPS, peak_fraction=0.0092
{
"command": "<846 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Profile with torch.profiler
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=20))
"<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
kda_inter_chunk_kernel 0.00% 0.000us 0.00% 0.000us 0.000us 4.808ms 53.89% 4.808ms 480.761us 10
aten::copy_ 3.57% 4.716ms 7.42% 9.802ms 57.660us 1.990ms 22.30% 2.015ms 11.855us 170
aten::clone 0.40% 527.595us 1.66% 2.194ms 16.874us 0.000us 0.00% 1.555ms 11.963us
... (3634 more chars){
"command": "<4892 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Time with different num_warps by modifying the kernel launch
# Actually, let's just test num_warps=8 directly
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 64
# Get the intermediate tensors
with torch.no_grad():
# Run forward once to get tensors
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4)
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2)
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
# Test with num_warps=8
grid = (triton.cdiv(V_, BV), B_ * H_)
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=8, num_stages=1,
)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=8, num_stages=1,
)
end.record()
torch.cuda.synchronize()
print(f'num_warps=8: {start.elapsed_time(end)/100:.4f} ms')
# Test with num_warps=4, BV=32
BV2 = 32
o2 = torch.empty_like(inputs[2])
grid2 = (triton.cdiv(V_, BV2), B_ * H_)
kda_inter_chunk_kernel[grid2](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o2,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV2, scale=s
... (892 more chars)<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "<string>", line 87, in <module>
NameError: name 'triton' is not defined
{
"command": "<4708 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 64
# Get the intermediate tensors
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4)
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4)
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2)
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
# Test with num_warps=8
grid = (triton.cdiv(V_, BV), B_ * H_)
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=8, num_stages=1,
)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=8, num_stages=1,
)
end.record()
torch.cuda.synchronize()
print(f'num_warps=8: {start.elapsed_time(end)/100:.4f} ms')
# Test with BV=32
BV2 = 32
o2 = torch.empty_like(inputs[2])
grid2 = (triton.cdiv(V_, BV2), B_ * H_)
kda_inter_chunk_kernel[grid2](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o2,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV2, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
start.record()
for _ in range(100):
kda_inter_chunk_
... (708 more chars)<system>Command executed successfully.</system>
num_warps=8: 0.5711 ms
BV=32, num_warps=4: 0.2474 ms
BV=32, num_warps=8: 4.2533 ms
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"edit": {
"old": " BV = 64\n grid = (triton.cdiv(V, BV), B * H)\n kda_inter_chunk_kernel[grid](\n q_flat,\n k_flat,\n u_flat,\n w_flat,\n g_flat,\n Aqk_flat,\n o,\n B=B,\n T=T,\n H=H,\n K=K,\n V=V,\n BT=BT,\n BV=BV,\n scale=self.scale,\n num_warps=4,\n num_stages=1,\n )",
"new": "
... (469 more chars)<system>File successfully edited. Applied 1 edit(s) with 1 total replacement(s).</system>{
"command": "<1334 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
for _ in range(50):
_ = sol_model(*inputs)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 50
# FLOPs
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
CHUNK_SIZE = shape['CHUNK_SIZE']
flops = 4 * B * T * H * (K * V + CHUNK_SIZE * K + CHUNK_SIZE * V)
tflops = flops / (ms * 1e-3) / 1e12
peak_frac = tflops / 200.0
print(f'Shape {shape_idx}: {ms:.3f} ms, {tflops:.2f} TFLOPS, peak_fraction={peak_frac:.4f}')
"<system>Command executed successfully.</system>
Shape 0: 0.377 ms, 5.70 TFLOPS, peak_fraction=0.0285
Shape 1: 0.709 ms, 6.06 TFLOPS, peak_fraction=0.0303
Shape 2: 0.976 ms, 4.40 TFLOPS, peak_fraction=0.0220
Shape 3: 0.396 ms, 2.71 TFLOPS, peak_fraction=0.0136
{
"command": "<816 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=25))
"<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
kda_inter_chunk_kernel 0.00% 0.000us 0.00% 0.000us 0.000us 2.814ms 40.63% 2.814ms 281.415us 10
aten::copy_ 3.30% 4.172ms 7.98% 10.084ms 59.318us 1.993ms 28.77% 2.019ms 11.875us 170
aten::clone 0.17% 209.729us 1.35% 1.710ms 13.154us 0.000us 0.00% 1.561ms 12.007us
... (4624 more chars){
"command": "which ncu || which nv-nsight-cu-cli || echo \"ncu not found\""
}<system>Command executed successfully.</system>
/usr/local/cuda-13/bin/ncu
{
"command": "<3242 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(5):
_ = sol_model(*inputs)
# Now run just the inter-chunk kernel with ncu
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 32
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
grid = (triton.cdiv(V_, BV), B_ * H_)
# Run many times for profiling
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
print('Done')
" 2>&1 | tail -5<system>Command executed successfully.</system>
Done
{
"command": "<3246 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && ncu --kernel-name kda_inter_chunk_kernel --metrics sm__warps_active.avg.pct_of_peak_sustained_elapsed,gpu__time_duration.avg,launch__block_size python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 32
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
grid = (triton.cdiv(V_, BV), B_ * H_)
for _ in range(10):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
" 2>&1 | grep -E "(Duration|Warp|block_size|Kernel)" | head -20<system>Command executed successfully.</system>{
"command": "<3126 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && ncu --kernel-name-regex "kda_inter" --metrics gpu__time_duration.avg python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 32
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
grid = (triton.cdiv(V_, BV), B_ * H_)
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
" 2>&1 | tail -30<system>Command executed successfully.</system>
==ERROR== unrecognised option '--kernel-name-regex'. Use --help for further details.
{
"command": "<3109 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && ncu -k "kda_inter" --metrics gpu__time_duration.avg python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 32
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
grid = (triton.cdiv(V_, BV), B_ * H_)
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
" 2>&1 | tail -30<system>Command executed successfully.</system>
==PROF== Connected to process 380195 (/usr/bin/python3.12)
==PROF== Disconnected from process 380195
==WARNING== No kernels were profiled.
Available Kernels:
1. Kernel2
2. elementwise_kernel
3. fwd_subst_kernel
4. kda_inter_chunk_kernel
5. tensor_kernel_scan_outer_dim
6. triu_tril_kernel
7. unrolled_elementwise_kernel
8. vectorized_elementwise_kernel
{
"command": "<3218 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && ncu -k "kda_inter_chunk_kernel" --metrics gpu__time_duration.avg,sm__warps_active.avg.pct_of_peak_sustained_elapsed python -c "
import torch
import triton
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
from solution import kda_inter_chunk_kernel
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
BT = shape['CHUNK_SIZE']
BV = 32
with torch.no_grad():
qf = inputs[0].float() * (float(K) ** -0.5)
kf = inputs[1].float()
vf = inputs[2].float()
gf = inputs[3].float()
bf = inputs[4].float()
B_, T_, H_, K_ = qf.shape
V_ = vf.shape[-1]
NT_ = T_ // BT
qf = qf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
kf = kf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
vf = vf.view(B_, NT_, BT, H_, V_).permute(0, 3, 1, 2, 4).contiguous()
gf = gf.view(B_, NT_, BT, H_, K_).permute(0, 3, 1, 2, 4).contiguous()
bf = bf.view(B_, NT_, BT, H_).permute(0, 3, 1, 2).contiguous()
gf = gf.cumsum(-2)
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
BHN = B_ * H_ * NT_
kg_flat = kg.reshape(BHN, BT, K_)
kgn_flat = kgn.reshape(BHN, BT, K_)
A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_upper = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_upper, 0)
A_flat = A.view(BHN, BT, BT).contiguous()
A_flat = solution._fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B_, H_, NT_, BT, BT)
I = torch.eye(BT, device=device, dtype=torch.float32)
A = (A + I) * bf[..., None, :]
w = torch.matmul(A, (gf.exp() * kf))
u = torch.matmul(A, vf)
qg = qf * gf.exp()
qg_flat = qg.reshape(BHN, BT, K_)
Aqk = torch.bmm(qg_flat, kgn_flat.transpose(-2, -1)).view(B_, H_, NT_, BT, BT)
mask_strict = torch.triu(torch.ones(BT, BT, device=device, dtype=torch.bool), diagonal=1)
Aqk = Aqk.masked_fill(mask_strict, 0)
q_flat = qf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
k_flat = kf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
u_flat = u.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, V_)
w_flat = w.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
g_flat = gf.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, K_)
Aqk_flat = Aqk.permute(0, 2, 3, 4, 1).reshape(B_, T_, H_, BT)
o = torch.empty_like(inputs[2])
scale = float(K_) ** -0.5
grid = (triton.cdiv(V_, BV), B_ * H_)
for _ in range(100):
kda_inter_chunk_kernel[grid](
q_flat, k_flat, u_flat, w_flat, g_flat, Aqk_flat, o,
B=B_, T=T_, H=H_, K=K_, V=V_, BT=BT, BV=BV, scale=scale,
num_warps=4, num_stages=1,
)
torch.cuda.synchronize()
" 2>&1 | grep -E "(Duration|Warp|Kernel|kda_inter)" | head -20<system>Command executed successfully.</system>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py
@@ -1,10 +1,9 @@
"""KDA forward solution with custom CUDA and Triton kernels."""
from __future__ import annotations
-import math
-
import torch
import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
import triton
import triton.language as tl
@@ -47,8 +46,8 @@
}
"""
-_fwd_subst_mod = torch.utils.cpp_extension.load_inline(
- name="kda_fwd_subst",
+_fwd_subst_mod = load_inline(
+ name="kda_fwd_subst_v2",
cpp_sources="torch::Tensor fwd_subst(torch::Tensor A);",
cuda_sources=_CUDA_FWD_SUBST_SRC,
functions=["fwd_subst"],
@@ -85,12 +84,9 @@
NT = T // BT
- # State tiles: K is fixed at 128, split into two 64-wide tiles
- # S1: (64, BV), S2: (64, BV)
b_S1 = tl.zeros((64, BV), dtype=tl.float32)
b_S2 = tl.zeros((64, BV), dtype=tl.float32)
- # Pointers base
q_base = q_ptr + (i_b * T * H + i_h) * K
k_base = k_ptr + (i_b * T * H + i_h) * K
u_base = u_ptr + (i_b * T * H + i_h) * V
@@ -102,31 +98,25 @@
for i_t in range(NT):
t0 = i_t * BT
- # Load u_i tile (BT, BV)
p_u = tl.make_block_ptr(
u_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_u, boundary_check=(0, 1)).to(tl.float32)
- # v_i = u_i - w_i @ S
- # First K tile (0:64)
p_w = tl.make_block_ptr(
w_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v -= tl.dot(b_w, b_S1.to(b_w.dtype))
- # Second K tile (64:128)
p_w = tl.make_block_ptr(
w_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v -= tl.dot(b_w, b_S2.to(b_w.dtype))
- # Compute o_i = (q_i * exp(g_i)) @ S + Aqk @ v_i
b_o = tl.zeros((BT, BV), dtype=tl.float32)
- # qS from first K tile
p_q = tl.make_block_ptr(
q_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
@@ -138,7 +128,6 @@
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
b_o += tl.dot(b_qg, b_S1.to(b_qg.dtype))
- # qS from second K tile
p_q = tl.make_block_ptr(
q_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
@@ -152,25 +141,20 @@
b_o *= scale
- # Aqk @ v_i
p_Aqk = tl.make_block_ptr(
Aqk_base, (T, BT), (H * BT, 1), (t0, 0), (BT, BT), (1, 0)
)
b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
- # Mask to strict lower triangular
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_Aqk = tl.where(m_s, b_Aqk, 0.0).to(b_v.dtype)
b_o += tl.dot(b_Aqk, b_v.to(b_Aqk.dtype))
- # Store o_i
p_o = tl.make_block_ptr(
o_base, (T, V), (H * V, 1), (t0, i_v * BV), (BT, BV), (1, 0)
)
tl.store(p_o, b_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
- # Update state S
- # decay = exp(g_i[:, -1])
last_idx = t0 + BT - 1
p_glast1 = g_base + last_idx * H * K + tl.arange(0, 64)
p_glast2 = g_base + last_idx * H * K + 64 + tl.arange(0, 64)
@@ -181,9 +165,6 @@
b_S1 *= b_decay1[:, None]
b_S2 *= b_decay2[:, None]
- # S += k_decay.T @ v_i
- # k_decay = exp(g_last - g_i) * k_i
- # First K tile
p_k = tl.make_block_ptr(
k_base, (T, K), (H * K, 1), (t0, 0), (BT, 64), (1, 0)
)
@@ -195,7 +176,6 @@
b_kg = b_k * tl.exp(b_glast1[None, :] - b_g)
b_S1 += tl.dot(b_kg.to(b_v.dtype).T, b_v.to(b_k.dtype))
- # Second K tile
p_k = tl.make_block_ptr(
k_base, (T, K), (H * K, 1), (t0, 64), (BT, 64), (1, 0)
)
@@ -237,69 +217,66 @@
qf, kf, vf, gf, bf = (x.to(torch.float32) for x in (q, k, v, g, beta))
qf = qf * self.scale
- # Reshape to chunk form: (B, H, N, C, D)
- qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
- kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
- vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4)
- gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4)
- bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2)
-
- # In-chunk cumsum
+ # Chunked layout contiguous
+ qf = qf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
+ kf = kf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
+ vf = vf.view(B, NT, BT, H, V).permute(0, 3, 1, 2, 4).contiguous()
+ gf = gf.view(B, NT, BT, H, K).permute(0, 3, 1, 2, 4).contiguous()
+ bf = bf.view(B, NT, BT, H).permute(0, 3, 1, 2).contiguous()
+
gf = gf.cumsum(-2)
# ---- Intra-chunk A_kk ----
kg = kf * gf.exp()
kgn = kf * (-gf).exp()
- # B,H,N,C,K -> BHN,C,K for bmm
BHN = B * H * NT
kg_flat = kg.reshape(BHN, BT, K)
kgn_flat = kgn.reshape(BHN, BT, K)
- A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1)) # (BHN, C, C)
+ A = torch.bmm(kg_flat, kgn_flat.transpose(-2, -1))
A = A.view(B, H, NT, BT, BT)
mask_diag_upper = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=0)
A = A * bf[..., None]
A = -A.masked_fill(mask_diag_upper, 0)
- # Forward substitution via custom CUDA kernel
- A_flat = A.view(BHN, BT, BT).contiguous()
+ A_flat = A.view(BHN, BT, BT)
A_flat = _fwd_subst_mod.fwd_subst(A_flat)
A = A_flat.view(B, H, NT, BT, BT)
I = torch.eye(BT, dtype=torch.float32, device=q.device)
A = (A + I) * bf[..., None, :]
- # w and u
- w = torch.matmul(A, (gf.exp() * kf)) # (B, H, N, C, K)
- u = torch.
... (1956 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass/solution.py",
"content": "<10407 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 10407 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python check.py",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "<1334 chars \u2014 see below>",
"timeout": 120
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
for _ in range(50):
_ = sol_model(*inputs)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 50
# FLOPs
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
CHUNK_SIZE = shape['CHUNK_SIZE']
flops = 4 * B * T * H * (K * V + CHUNK_SIZE * K + CHUNK_SIZE * V)
tflops = flops / (ms * 1e-3) / 1e12
peak_frac = tflops / 200.0
print(f'Shape {shape_idx}: {ms:.3f} ms, {tflops:.2f} TFLOPS, peak_fraction={peak_frac:.4f}')
"<system>Command executed successfully.</system>
Shape 0: 0.341 ms, 6.30 TFLOPS, peak_fraction=0.0315
Shape 1: 0.681 ms, 6.31 TFLOPS, peak_fraction=0.0315
Shape 2: 0.949 ms, 4.53 TFLOPS, peak_fraction=0.0226
Shape 3: 0.384 ms, 2.79 TFLOPS, peak_fraction=0.0140
{
"command": "<835 chars \u2014 see below>",
"timeout": 60
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import time
C = 64
BHN = 512
A = torch.randn(BHN, C, C, device='cuda') * 0.1
mask = torch.tril(torch.ones(C, C, device='cuda'), diagonal=-1)
A = A * mask
I = torch.eye(C, device='cuda').expand_as(A)
# Test solve_triangular
start = time.time()
for _ in range(100):
L = I - A
X = torch.linalg.solve_triangular(L, I, upper=False)
torch.cuda.synchronize()
print(f'solve_triangular: {(time.time()-start)/100*1000:.3f} ms')
# Reference forward substitution
start = time.time()
for _ in range(100):
ref = A.clone()
for i in range(1, C):
ref[:, i, :i] = ref[:, i, :i] + (ref[:, i, :, None] * ref[:, :, :i]).sum(-2)
ref = ref + I
torch.cuda.synchronize()
print(f'fwd subst loop: {(time.time()-start)/100*1000:.3f} ms')
"<system>Command executed successfully.</system>
solve_triangular: 0.489 ms
fwd subst loop: 1.291 ms
{
"command": "<1588 chars \u2014 see below>",
"timeout": 300
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[1]
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
# Compile the forward method
compiled_model = torch.compile(sol_model, mode='max-autotune')
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Warmup
with torch.no_grad():
for _ in range(10):
_ = compiled_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
for _ in range(50):
_ = compiled_model(*inputs)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 50
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
CHUNK_SIZE = shape['CHUNK_SIZE']
flops = 4 * B * T * H * (K * V + CHUNK_SIZE * K + CHUNK_SIZE * V)
tflops = flops / (ms * 1e-3) / 1e12
peak_frac = tflops / 200.0
print(f'Compiled: {ms:.3f} ms, {tflops:.2f} TFLOPS, peak_fraction={peak_frac:.4f}')
# Baseline
start.record()
with torch.no_grad():
for _ in range(50):
_ = sol_model(*inputs)
end.record()
torch.cuda.synchronize()
ms2 = start.elapsed_time(end) / 50
tflops2 = flops / (ms2 * 1e-3) / 1e12
peak_frac2 = tflops2 / 200.0
print(f'Eager: {ms2:.3f} ms, {tflops2:.2f} TFLOPS, peak_fraction={peak_frac2:.4f}')
"<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:2082: UserWarning: Dynamo does not know how to trace the builtin `kda_fwd_subst_v2.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.fwd_subst.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
E0428 10:41:28.605000 382171 torch/_inductor/select_algorithm.py:3541] [0/0_1] Exception No valid triton configs. OutOfMemoryError: out of resource: triton_bmm Required: 131072 Hardware limit:101376 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller
... (10241 more chars){
"command": "<1421 chars \u2014 see below>",
"timeout": 300
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/02_kda_cutlass && python -c "
import torch
import reference
import solution
import shapes
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
for k, v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
compiled_model = torch.compile(sol_model, mode='max-autotune')
torch.manual_seed(2026)
inputs = [t.to(device) if hasattr(t, 'to') else t for t in reference.get_inputs()]
# Extensive warmup
with torch.no_grad():
for _ in range(20):
_ = compiled_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
for _ in range(50):
_ = compiled_model(*inputs)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 50
# FLOPs
B, T, H, K, V = shape['B'], shape['T'], shape['H'], shape['K'], shape['V']
CHUNK_SIZE = shape['CHUNK_SIZE']
flops = 4 * B * T * H * (K * V + CHUNK_SIZE * K + CHUNK_SIZE * V)
tflops = flops / (ms * 1e-3) / 1e12
peak_frac = tflops / 200.0
print(f'Shape {shape_idx}: {ms:.3f} ms, {tflops:.2f} TFLOPS, peak_fraction={peak_frac:.4f}')
"<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:2082: UserWarning: Dynamo does not know how to trace the builtin `kda_fwd_subst_v2.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.fwd_subst.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
E0428 10:42:00.129000 382720 torch/_inductor/select_algorithm.py:3541] [0/0_1] Exception No valid triton configs. OutOfMemoryError: out of resource: triton_bmm Required: 131072 Hardware limit:101376 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller
... (15377 more chars)