from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "grouped_gemm_swiglu"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
@triton.jit
def _grouped_gemm_swiglu_3d_kernel(
    x_ptr,
    W_gate_ptr,
    W_up_ptr,
    out_ptr,
    expert_offsets_ptr,
    H,
    I,
    E,
    num_m_tiles,
    stride_xm,
    stride_xk,
    stride_wge,
    stride_wgk,
    stride_wgn,
    stride_wue,
    stride_wuk,
    stride_wun,
    stride_om,
    stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fused grouped GEMM + SwiGLU; one K loop feeds both the gate and up GEMMs.

    Each program computes one (BLOCK_M, BLOCK_N) output tile for one expert e:

        out[rows, cols] = silu(x[rows] @ W_gate[e][:, cols]) * (x[rows] @ W_up[e][:, cols])

    Expert e owns rows [expert_offsets[e], expert_offsets[e + 1]) of x/out.
    Program ids: axis 0 -> N (column) tile, axis 1 -> M (row) tile within the
    expert, axis 2 -> expert index. Programs whose M tile starts at or beyond
    the expert's end row return immediately.

    Loads are masked with other=0.0, so H and I need not be multiples of
    BLOCK_K / BLOCK_N, and the expert's tail rows are handled. Both
    accumulators are fp32; the result is cast to bf16 when stored.
    `E` and `num_m_tiles` are accepted but not read by this kernel.
    """
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_e = tl.program_id(2)
    expert_start = tl.load(expert_offsets_ptr + pid_e)
    expert_end = tl.load(expert_offsets_ptr + pid_e + 1)
    m_start = expert_start + pid_m * BLOCK_M
    # The grid's M axis is sized for the largest expert; smaller experts have
    # surplus programs that exit here.
    if m_start >= expert_end:
        return
    # NOTE(review): these offsets are presumably 32-bit (program ids, int32
    # offsets, small int strides). If T_perm * H, T_perm * I or E * H * I can
    # reach 2**31, consider casting to tl.int64 -- verify for new shapes.
    m_offsets = m_start + tl.arange(0, BLOCK_M)
    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, H, BLOCK_K):
        k_off = k_start + tl.arange(0, BLOCK_K)
        # The x tile is loaded once per K step and reused by both dots below.
        x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
        # Rows past this expert's end and columns past H read as zero.
        x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
        wg_ptrs = W_gate_ptr + pid_e * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
        wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
        wu_ptrs = W_up_ptr + pid_e * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
        wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
        gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
        up_acc = tl.dot(x_tile, wu_tile, acc=up_acc, out_dtype=tl.float32)
    # SwiGLU epilogue: silu(g) = g * sigmoid(g), computed in fp32.
    out_tile = gate_acc * tl.sigmoid(gate_acc) * up_acc
    out_ptrs = out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
    out_mask = (m_offsets[:, None] < expert_end) & (n_offsets[None, :] < I)
    tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=out_mask)
@triton.jit
def _grouped_gemm_swiglu_2pass_kernel(
    x_ptr,
    W_gate_ptr,
    W_up_ptr,
    out_ptr,
    expert_offsets_ptr,
    H,
    I,
    E,
    num_m_tiles,
    stride_xm,
    stride_xk,
    stride_wge,
    stride_wgk,
    stride_wgn,
    stride_wue,
    stride_wuk,
    stride_wun,
    stride_om,
    stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fused grouped GEMM + SwiGLU using two separate K loops.

    Each program computes one (BLOCK_M, BLOCK_N) output tile for one expert e:

        out[rows, cols] = silu(x[rows] @ W_gate[e][:, cols]) * (x[rows] @ W_up[e][:, cols])

    Expert e owns rows [expert_offsets[e], expert_offsets[e + 1]) of x/out.
    Program ids: axis 0 -> N (column) tile, axis 1 -> M (row) tile within the
    expert, axis 2 -> expert index. Programs whose M tile starts at or beyond
    the expert's end row return immediately.

    The first loop accumulates the gate GEMM and the second the up GEMM, so
    each loop streams only one weight matrix. The x tile is loaded again in
    the second loop. Loads are masked with other=0.0; accumulators are fp32
    and the result is cast to bf16 when stored.
    `E` and `num_m_tiles` are accepted but not read by this kernel.
    """
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_e = tl.program_id(2)
    expert_start = tl.load(expert_offsets_ptr + pid_e)
    expert_end = tl.load(expert_offsets_ptr + pid_e + 1)
    m_start = expert_start + pid_m * BLOCK_M
    # The grid's M axis is sized for the largest expert; smaller experts have
    # surplus programs that exit here.
    if m_start >= expert_end:
        return
    # NOTE(review): these offsets are presumably 32-bit (program ids, int32
    # offsets, small int strides). If T_perm * H, T_perm * I or E * H * I can
    # reach 2**31, consider casting to tl.int64 -- verify for new shapes.
    m_offsets = m_start + tl.arange(0, BLOCK_M)
    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Pass 1: gate GEMM.
    gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, H, BLOCK_K):
        k_off = k_start + tl.arange(0, BLOCK_K)
        x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
        x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
        wg_ptrs = W_gate_ptr + pid_e * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
        wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
        gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
    # Pass 2: up GEMM (x tiles are re-read from memory).
    up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, H, BLOCK_K):
        k_off = k_start + tl.arange(0, BLOCK_K)
        x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
        x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
        wu_ptrs = W_up_ptr + pid_e * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
        wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
        up_acc = tl.dot(x_tile, wu_tile, acc=up_acc, out_dtype=tl.float32)
    # SwiGLU epilogue: silu(g) = g * sigmoid(g), computed in fp32.
    out_tile = gate_acc * tl.sigmoid(gate_acc) * up_acc
    out_ptrs = out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
    out_mask = (m_offsets[:, None] < expert_end) & (n_offsets[None, :] < I)
    tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=out_mask)
def _compute_grid(hidden_states, W_gate, expert_offsets, BLOCK_M, BLOCK_N):
T_perm, H = hidden_states.shape
E = expert_offsets.shape[0] - 1
I = W_gate.shape[2]
num_n_tiles = (I + BLOCK_N - 1) // BLOCK_N
offsets_cpu = expert_offsets.cpu()
max_m_tiles = 0
for e in range(E):
n_e = int(offsets_cpu[e + 1]) - int(offsets_cpu[e])
nm = (n_e + BLOCK_M - 1) // BLOCK_M
max_m_tiles = max(max_m_tiles, nm)
return num_n_tiles, max_m_tiles
def _launch_kernel(
    kernel,
    hidden_states: torch.Tensor,
    W_gate: torch.Tensor,
    W_up: torch.Tensor,
    expert_offsets: torch.Tensor,
    BLOCK_M: int = 64,
    BLOCK_N: int = 128,
    BLOCK_K: int = 64,
    num_warps: int = 8,
    num_stages: int = 2,
) -> torch.Tensor:
    """Launch one grouped-GEMM SwiGLU ``kernel`` and return its output.

    The launch grid is ``(num_n_tiles, max_m_tiles, E)`` as computed by
    ``_compute_grid``. The output is a ``(T_perm, I)`` bf16 tensor on
    ``hidden_states.device``.

    If either grid dimension is zero, the kernel is skipped and an all-zero
    tensor is returned. Otherwise the output is allocated uninitialized and
    only the rows the kernel writes (those inside some expert's row range)
    are defined.
    """
    rows, hidden = hidden_states.shape
    num_experts = expert_offsets.shape[0] - 1
    inter = W_gate.shape[2]
    n_tiles, m_tiles = _compute_grid(hidden_states, W_gate, expert_offsets, BLOCK_M, BLOCK_N)

    alloc = dict(dtype=torch.bfloat16, device=hidden_states.device)
    if 0 in (m_tiles, n_tiles):
        return torch.zeros(rows, inter, **alloc)
    out = torch.empty(rows, inter, **alloc)

    # Positional stride order must match the kernel signature:
    # x (m, k), W_gate (e, k, n), W_up (e, k, n), out (m, n).
    strides = (
        *(hidden_states.stride(d) for d in range(2)),
        *(W_gate.stride(d) for d in range(3)),
        *(W_up.stride(d) for d in range(3)),
        *(out.stride(d) for d in range(2)),
    )
    launch_grid = (n_tiles, m_tiles, num_experts)
    kernel[launch_grid](
        hidden_states,
        W_gate,
        W_up,
        out,
        expert_offsets,
        hidden,
        inter,
        num_experts,
        m_tiles,
        *strides,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out
# Maps (T_perm, H, I, E, device) -> fastest launch config found for it:
# (kernel, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages).
_TUNE_CACHE: dict = {}


def _tuned_launch(
    hidden_states: torch.Tensor,
    W_gate: torch.Tensor,
    W_up: torch.Tensor,
    expert_offsets: torch.Tensor,
) -> torch.Tensor:
    """Run the grouped GEMM + SwiGLU with a launch config chosen by timing.

    On the first call for a problem key, each candidate config is launched
    3 times as warm-up and then timed over 10 launches. The fastest config
    that ran without raising is cached in ``_TUNE_CACHE`` and reused by later
    calls with the same key.

    The key covers shapes and device but not the routing distribution, so a
    config tuned for one offset pattern is reused for others of the same shape.

    Raises:
        RuntimeError: if no candidate config could be launched. The failure is
            not cached, so a later call retries tuning.
    """
    import time

    T_perm, H = hidden_states.shape
    E = expert_offsets.shape[0] - 1
    I = W_gate.shape[2]  # noqa: E741
    device = hidden_states.device
    # Key on the device too, so a config tuned on one GPU is not reused on another.
    cache_key = (T_perm, H, I, E, device)
    best_cfg = _TUNE_CACHE.get(cache_key)
    if best_cfg is None:
        configs = [
            # (kernel, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)
            (_grouped_gemm_swiglu_3d_kernel, 128, 64, 64, 4, 2),
            (_grouped_gemm_swiglu_3d_kernel, 128, 64, 64, 4, 3),
            (_grouped_gemm_swiglu_3d_kernel, 64, 128, 64, 4, 2),
            (_grouped_gemm_swiglu_3d_kernel, 64, 128, 64, 4, 3),
            (_grouped_gemm_swiglu_3d_kernel, 64, 128, 64, 8, 2),
            (_grouped_gemm_swiglu_3d_kernel, 64, 128, 64, 8, 3),
            (_grouped_gemm_swiglu_3d_kernel, 128, 128, 64, 8, 2),
            (_grouped_gemm_swiglu_3d_kernel, 128, 128, 64, 8, 3),
            (_grouped_gemm_swiglu_2pass_kernel, 128, 64, 64, 4, 2),
            (_grouped_gemm_swiglu_2pass_kernel, 128, 64, 64, 4, 3),
            (_grouped_gemm_swiglu_2pass_kernel, 64, 128, 64, 4, 2),
            (_grouped_gemm_swiglu_2pass_kernel, 64, 128, 64, 4, 3),
            (_grouped_gemm_swiglu_2pass_kernel, 128, 128, 64, 8, 2),
            (_grouped_gemm_swiglu_2pass_kernel, 128, 128, 64, 8, 3),
            (_grouped_gemm_swiglu_2pass_kernel, 64, 64, 64, 4, 2),
            (_grouped_gemm_swiglu_2pass_kernel, 64, 64, 64, 4, 3),
        ]
        best_seconds = float("inf")
        last_error = None
        for cfg in configs:
            kernel, BM, BN, BK, nw, ns = cfg
            try:
                # Warm-up launches keep one-time costs (e.g. Triton compiling
                # the kernel on first launch) out of the timed loop.
                for _ in range(3):
                    _launch_kernel(kernel, hidden_states, W_gate, W_up, expert_offsets, BM, BN, BK, nw, ns)
                torch.cuda.synchronize(device)
                t0 = time.perf_counter()
                for _ in range(10):
                    _launch_kernel(kernel, hidden_states, W_gate, W_up, expert_offsets, BM, BN, BK, nw, ns)
                torch.cuda.synchronize(device)
                seconds = (time.perf_counter() - t0) / 10
            except Exception as err:  # best-effort: skip configs that fail to compile or launch
                last_error = err
                continue
            if seconds < best_seconds:
                best_seconds = seconds
                best_cfg = cfg
        if best_cfg is None:
            # Previously None was cached here and every later call failed
            # with an opaque unpacking TypeError.
            raise RuntimeError(
                f"grouped GEMM SwiGLU tuning found no working launch config for key {cache_key}"
            ) from last_error
        _TUNE_CACHE[cache_key] = best_cfg
    kernel, BM, BN, BK, nw, ns = best_cfg
    return _launch_kernel(kernel, hidden_states, W_gate, W_up, expert_offsets, BM, BN, BK, nw, ns)
class Model(nn.Module):
    """MoE up-projection with fused SwiGLU over expert-permuted tokens.

    Owns per-expert ``W_gate`` and ``W_up`` parameters, each of shape
    ``(E, H, I)`` in bf16 and initialized from a normal distribution with
    mean 0 and std 0.02. ``T_total`` and ``K`` are stored as attributes but
    are not used by ``forward``.
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):  # noqa: E741
        super().__init__()
        self.T_total, self.H, self.I, self.E, self.K = T_total, H, I, E, K
        weight_shape = (E, H, I)
        # Create-then-init each weight in turn; torch.empty draws no random
        # numbers, so the RNG is consumed gate first, then up.
        self.W_gate = nn.Parameter(torch.empty(weight_shape, dtype=torch.bfloat16))
        nn.init.normal_(self.W_gate, std=0.02)
        self.W_up = nn.Parameter(torch.empty(weight_shape, dtype=torch.bfloat16))
        nn.init.normal_(self.W_up, std=0.02)

    def forward(
        self,
        hidden_states: torch.Tensor,
        expert_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """Compute ``silu(x_e @ W_gate[e]) * (x_e @ W_up[e])`` for every expert.

        ``expert_offsets[e]:expert_offsets[e + 1]`` selects expert e's rows of
        ``hidden_states``. Dispatches to the timing-tuned Triton launcher.
        """
        return _tuned_launch(hidden_states, self.W_gate, self.W_up, expert_offsets)
# Default problem size consumed by get_inputs() and get_init_inputs():
# T_total tokens, hidden size H, intermediate size I, E experts, top-K routing.
T_total = 32_768
H = 4_096
I = 1_536  # noqa: E741
E = 128
K = 8
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
def get_inputs():
    """Build example forward inputs for the module-level default shape.

    Returns ``[hidden_states, expert_offsets]``:
    ``hidden_states`` is a ``(T_total * K, H)`` bf16 tensor of standard-normal
    samples scaled by 0.1, created on torch's default device.
    ``expert_offsets`` holds the balanced routing offsets from ``_build_routing``.
    """
    permuted_rows = T_total * K
    activations = torch.randn(permuted_rows, H, dtype=torch.bfloat16).mul(0.1)
    offsets = _build_routing(T_total, E, K)
    return [activations, offsets]
def get_init_inputs():
    """Return the ``Model`` constructor arguments ``[T_total, H, I, E, K]``."""
    init_args = (T_total, H, I, E, K)
    return list(init_args)
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 15:57:46.209000 3998989 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=23.158 gbps=173.384 ms=35.609
shape=0 variant=compiled tflops=23.312 gbps=174.533 ms=35.374
shape=0 variant=solution tflops=37.368 gbps=279.774 ms=22.068
shape=0 solution_peak_fraction=0.1868
shape=1 variant=eager tflops=10.210 gbps=189.443 ms=3.365
shape=1 variant=compiled tflops=10.959 gbps=203.343 ms=3.135
shape=1 variant=solution tflops=56.045 gbps=1039.901 ms=0.613
shape=1 solution_peak_fraction=0.2802
shape=2 variant=eager tflops=29.866 gbps=204.164 ms=18.407
shape=2 variant=compiled tflops=29.639 gbps=202.610 ms=18.548
shape=2 variant=solution tflops=38.173 gbps=260.945 ms=14.402
shape=2 solution_peak_fraction=0.1909
peak_fraction: 0.2154
RESULT: OK
Uninstalled 1 package in 0.19ms
Installed 1 package in 2ms
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.2154,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 186442, "output_tokens": 13862, "cache_read_tokens": 1889664, "cache_creation_tokens": 0, "reasoning_tokens": 74675, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria",
"status": "in_progress",
"priority": "high"
},
{
"content": "Read existing solution.py if any",
"status": "pending",
"priority": "high"
},
{
"content": "Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py to verify correctness across all shapes",
"status": "pending",
... (294 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria'
status: 'in_progress'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'pending'
priority: 'high'
content: 'Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)'
status: 'pending'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate until PASS on check.py and good performance'
status: 'pending'
priority: 'high'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to write the up-projection of a top-K MoE FFN with grouped GEMM and fused SwiGLU for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is: per expert e, compute h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) where x_e is the slice of permuted hidden states routed to expert e. Inputs are hidden_states of shape (T_perm, H) bf16 (already permuted into expert order, T_perm = T_total * K), expert_offsets of shape (E+1,) int32 (prefix sums of token counts per expert, so expert e owns rows [offsets[e]:offsets[e+1]]), and per-expert weights W_gate and W_up each of shape (E, H, I) bf16. You return gated_up of shape (T_perm, I) bf16. Correctness tolerance is 0.02 abs/rel on bf16. The shapes you have to handle are T_total=32768 H=4096 I=1536 E=128 K=8 (the headline sonic-moe configuration, 32K tokens, 128 experts, top-8), T_total=4096 H=2048 I=1024 E=64 K=4 (smaller / faster iteration), and T_total=16384 H=2048 I=4096 E=64 K=8 (intermediate-heavy aspect ratio that forces larger N tiles).
4:
5: This needs t
... (1078 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU status and available tools (Triton version, CUDA version)",
"priority": "high",
"status": "in_progress"
},
{
"content": "Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)",
"priority": "high",
"status": "pending"
},
{
"content": "Run check.py to verify correctness across all sh
... (331 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria'
status: 'completed'
priority: 'high'
content: 'Check GPU status and available tools (Triton version, CUDA version)'
status: 'in_progress'
priority: 'high'
content: 'Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)'
status: 'pending'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate until PASS on check.py and good performance'
status: 'pending'
priority: 'high'{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>&1 | head -5",
"description": "Check GPU status"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check Triton version"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__, torch.cuda.get_device_capability(0))\" 2>&1",
"description": "Check PyTorch and CUDA capability"
}2.11.0+cu130 (12, 0)
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py 2>&1",
"description": "Check if solution.py exists"
}ls: cannot access '/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py': No such file or directory
{
"command": "uv run python -c \"import triton; import triton.language as tl; print('tl.dot available:', hasattr(tl, 'dot'))\" 2>&1",
"description": "Check Triton dot availability"
}tl.dot available: True
{
"command": "<280 chars \u2014 see below>",
"description": "Check Triton dot capabilities"
}uv run python -c "
import triton
import triton.language as tl
# Check if we can use tl.dot with bf16
print('Triton language version:', triton.language.__version__ if hasattr(triton.language, '__version__') else 'unknown')
# Check available dot modes
help(tl.dot)
" 2>&1 | head -30Triton language version: unknown
Help on function dot in module triton.language.core:
dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32, _semantic=None)
Returns the matrix product of two blocks.
The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
For three-dimensional blocks, `tl.dot` performs the batched matrix product,
where the first dimension of each block represents the batch dimension.
:param input: The first tensor to be multiplied.
:type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second tensor to be multiplied.
:type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
:type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"iee
... (313 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU status and available tools (Triton version, CUDA version)",
"priority": "high",
"status": "completed"
},
{
"content": "Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)",
"priority": "high",
"status": "in_progress"
},
{
"content": "Run check.py to verify correctness across all
... (333 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria'
status: 'completed'
priority: 'high'
content: 'Check GPU status and available tools (Triton version, CUDA version)'
status: 'completed'
priority: 'high'
content: 'Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate until PASS on check.py and good performance'
status: 'pending'
priority: 'high'{
"command": "uv run python -c \"import triton; print(dir(triton))\" 2>&1",
"description": "Check triton module contents"
}Uninstalled 1 package in 0.18ms
Installed 1 package in 2ms
['AsyncCompileMode', 'CompilationError', 'Config', 'FutureKernel', 'InterpreterError', 'JITFunction', 'KernelInterface', 'MockTensor', 'OutOfResources', 'TensorWrapper', 'TritonError', '_C', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_utils', 'autotune', 'backends', 'cdiv', 'compile', 'compiler', 'constexpr_function', 'errors', 'heuristics', 'jit', 'knobs', 'language', 'must_use_result', 'next_power_of_2', 'reinterpret', 'runtime', 'set_allocator', 'testing', 'tools']
{
"command": "<135 chars \u2014 see below>",
"description": "Check correctness utility source",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
from src.eval.correctness import check_correctness
import inspect
print(inspect.getsource(check_correctness))
" 2>&1def check_correctness(
reference_out: torch.Tensor,
solution_out: torch.Tensor,
dtype: torch.dtype | None = None,
override: dict | None = None,
) -> tuple[bool, str]:
"""Return (passed, message). Integer comparisons are bitwise; floats use atol/rtol."""
if reference_out.shape != solution_out.shape:
return False, f"shape mismatch: ref={tuple(reference_out.shape)} sol={tuple(solution_out.shape)}"
if torch.isnan(solution_out).any():
return False, "solution contains NaN"
if torch.isinf(solution_out).any():
return False, "solution contains Inf"
dtype = dtype or reference_out.dtype
tol = tolerance_for_dtype(dtype, override)
# Cast both to fp32 for the comparison to avoid dtype-specific allclose quirks
ref_f = reference_out.float()
sol_f = solution_out.float()
if tol["atol"] == 0 and tol["rtol"] == 0:
if torch.equal(ref_f, sol_f):
return True, "ok (exact)"
n_diff = (ref_f != sol_f).sum().item()
return False, f"exact match required; {n_diff} elements differ"
if torch.allclose(ref_f, sol_f, atol=tol["atol"], rtol=tol["rtol"]):
return True, f"ok (atol={tol['atol']}, rtol={tol['rtol']})"
max_diff = (ref_f - sol_f).abs().max().item()
return False, f"tolerance exceeded: max_abs_diff={max_diff:.6g} (atol={tol['atol']}, rtol={tol['rtol']})"
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "grouped_gemm_swiglu"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton.jit
+def _grouped_gemm_swiglu_kernel(
+ x_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ out_ptr,
+ expert_offsets_ptr,
+ tiles_prefix_sum_ptr,
+ H,
+ I,
+ E,
+ num_n_tiles,
+ stride_xm,
+ stride_xk,
+ stride_wge,
+ stride_wgk,
+ stride_wgn,
+ stride_wue,
+ stride_wuk,
+ stride_wun,
+ stride_om,
+ stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ lo = 0
+ hi = E - 1
+ while lo < hi:
+ mid = (lo + hi) >> 1
+ if tl.load(tiles_prefix_sum_ptr + mid + 1) <= pid:
+ lo = mid + 1
+ else:
+ hi = mid
+ expert = lo
+
+ expert_tile_start = tl.load(tiles_prefix_sum_ptr + expert)
+ local_tile = pid - expert_tile_start
+ pid_m = local_tile // num_n_tiles
+ pid_n = local_tile % num_n_tiles
+
+ expert_start = tl.load(expert_offsets_ptr + expert)
+ expert_end = tl.load(expert_offsets_ptr + expert + 1)
+ m_start = expert_start + pid_m * BLOCK_M
+ n_start = pid_n * BLOCK_N
+
+ if m_start >= expert_end:
+ return
+
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+
+ gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+
+ x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
+ x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
+ x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
+
+ wg_ptrs = W_gate_ptr + expert * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
+ wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
+
+ wu_ptrs = W_up_ptr + expert * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
+ wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
+
+ gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
+ up_acc = tl.dot(x_tile, wu_tile, acc=up_acc, out_dtype=tl.float32)
+
+ out_tile = gate_acc * tl.sigmoid(gate_acc) * up_acc
+
+ out_ptrs = out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
+ out_mask = (m_offsets[:, None] < expert_end) & (n_offsets[None, :] < I)
+ tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=out_mask)
+
+
+def _launch_grouped_gemm_swiglu(
+ hidden_states: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ BLOCK_M: int = 64,
+ BLOCK_N: int = 128,
+ BLOCK_K: int = 64,
+) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ E = expert_offsets.shape[0] - 1
+ I = W_gate.shape[2]
+
+ num_n_tiles = (I + BLOCK_N - 1) // BLOCK_N
+
+ offsets_cpu = expert_offsets.cpu()
+ tiles_prefix_sum = torch.zeros(E + 1, dtype=torch.int32)
+ for e in range(E):
+ n_e = int(offsets_cpu[e + 1]) - int(offsets_cpu[e])
+ num_m_tiles = (n_e + BLOCK_M - 1) // BLOCK_M
+ tiles_prefix_sum[e + 1] = tiles_prefix_sum[e] + num_m_tiles * num_n_tiles
+
+ total_tiles = int(tiles_prefix_sum[E].item())
+ if total_tiles == 0:
+ return torch.zeros(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ tiles_prefix_sum = tiles_prefix_sum.to(device=hidden_states.device)
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ grid = (total_tiles,)
+ _grouped_gemm_swiglu_kernel[grid](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ expert_offsets,
+ tiles_prefix_sum,
+ H,
+ I,
+ E,
+ num_n_tiles,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ W_gate.stride(0),
+ W_gate.stride(1),
+ W_gate.stride(2),
+ W_up.stride(0),
+ W_up.stride(1),
+ W_up.stride(2),
+ out.stride(0),
+ out.stride(1),
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ num_warps=8,
+ )
+
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+ return _launch_grouped_gemm_swiglu(hidden_states, self.W_gate, self.W_up, expert_offsets)
+
+
+T_total = 32768
+H = 4096
+I = 1536 # noqa: E741
+E = 128
+K = 8
+
+
+def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
+ T_perm = T_total * K
+ base = T_perm // E
+ rem = T_perm - base * E
+ counts = torch.full((E,), base, dtype=torch.int32, device=device)
+ counts[:rem] += 1
+ offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
+ offsets[1:] = torch.cumsum(counts, dim=0)
+ return offsets
+
+
+def get_inputs():
+ T_perm = T_total * K
+ hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
+ expert_offsets = _build_routing(T_total, E, K)
+ re
... (99 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<5780 chars \u2014 see diff>"
}Wrote file successfully.{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU status and available tools (Triton version, CUDA version)",
"priority": "high",
"status": "completed"
},
{
"content": "Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)",
"priority": "high",
"status": "completed"
},
{
"content": "Run check.py to verify correctness across all sh
... (335 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem.yaml to understand interface and correctness criteria'
status: 'completed'
priority: 'high'
content: 'Check GPU status and available tools (Triton version, CUDA version)'
status: 'completed'
priority: 'high'
content: 'Design and implement custom kernel (Triton grouped GEMM with fused SwiGLU)'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate until PASS on check.py and good performance'
status: 'pending'
priority: 'high'{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu",
"timeout": 300000
}Uninstalled 1 package in 0.17ms
Installed 1 package in 1ms
PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0508 15:15:48.801000 3956788 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (807 more chars){
"command": "<1748 chars \u2014 see below>",
"description": "Benchmark different tile configurations",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import time
import sys
sys.path.insert(0, '.')
from solution import Model, _launch_grouped_gemm_swiglu
device = torch.device('cuda:0')
# Use headline shape
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device).eval()
T_perm = T_total * K
torch.manual_seed(42)
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
configs = [
(64, 128, 64, 8, 2),
(64, 128, 64, 8, 3),
(128, 128, 64, 8, 2),
(128, 128, 64, 8, 3),
(64, 128, 128, 8, 2),
(64, 128, 128, 8, 3),
(128, 64, 64, 4, 2),
(64, 64, 64, 4, 2),
(128, 128, 128, 8, 2),
(128, 128, 128, 8, 3),
(64, 256, 64, 8, 2),
(128, 256, 64, 8, 3),
(32, 128, 64, 4, 3),
(32, 256, 64, 8, 3),
(64, 64, 128, 4, 3),
]
for BM, BN, BK, nw, ns in configs:
try:
# warmup
for _ in range(3):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
# bench
t0 = time.perf_counter()
for _ in range(20):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 20
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {ms*1000:.3f} ms')
except Exception as e:
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED ({e})')
" 2>&1Uninstalled 1 package in 0.19ms
Installed 1 package in 1ms
BM= 64 BN=128 BK= 64 nw=8 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 64 BN=128 BK= 64 nw=8 ns=3: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN=128 BK= 64 nw=8 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN=128 BK= 64 nw=8 ns=3: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 64 BN=128 BK=128 nw=8 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 64 BN=128 BK=128 nw=8 ns=3: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN= 64 BK= 64 nw=4 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 64 BN= 64 BK= 64 nw=4 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN=128 BK=128 nw=8 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN=128 BK=128 nw=8 ns=3: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 64 BN=256 BK= 64 nw=8 ns=2: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM=128 BN=256 BK= 64 nw=8 ns=3: FAILED (_launch_grouped_gemm_swiglu() got an unexpected keyword argument 'num_warps')
BM= 32 BN=128 BK= 64 nw=4
... (329 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"offset": 60,
"limit": 50
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
60:
61: if m_start >= expert_end:
62: return
63:
64: m_offsets = m_start + tl.arange(0, BLOCK_M)
65: n_offsets = n_start + tl.arange(0, BLOCK_N)
66:
67: gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
68: up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
69:
70: for k_start in range(0, H, BLOCK_K):
71: k_off = k_start + tl.arange(0, BLOCK_K)
72:
73: x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
74: x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
75: x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
76:
77: w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
78:
79: wg_ptrs = W_gate_ptr + expert * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
80: wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
81:
82: wu_ptrs = W_up_ptr + expert * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
83: wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
84:
85: gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
86: up_acc = tl.dot(x_tile, wu_tile, acc=up_acc, out_dtype=tl.float32)
87:
88: out_tile = gate_acc * tl.sigmoid(gate_acc) * up_acc
89:
90: out_ptrs = out_ptr + m_offsets[:, None] * stride_om
... (715 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -100,6 +100,8 @@
BLOCK_M: int = 64,
BLOCK_N: int = 128,
BLOCK_K: int = 64,
+ num_warps: int = 8,
+ num_stages: int = 2,
) -> torch.Tensor:
T_perm, H = hidden_states.shape
E = expert_offsets.shape[0] - 1
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<238 chars \u2014 see below>",
"newString": "<287 chars \u2014 see below>"
}def _launch_grouped_gemm_swiglu(
hidden_states: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
expert_offsets: torch.Tensor,
BLOCK_M: int = 64,
BLOCK_N: int = 128,
BLOCK_K: int = 64,
) -> torch.Tensor:def _launch_grouped_gemm_swiglu(
hidden_states: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
expert_offsets: torch.Tensor,
BLOCK_M: int = 64,
BLOCK_N: int = 128,
BLOCK_K: int = 64,
num_warps: int = 8,
num_stages: int = 2,
) -> torch.Tensor:Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -148,7 +148,8 @@
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
- num_warps=8,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<101 chars \u2014 see below>",
"newString": "<140 chars \u2014 see below>"
} BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
num_warps=8,
) BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
num_warps=num_warps,
num_stages=num_stages,
)Edit applied successfully.{
"command": "<1714 chars \u2014 see below>",
"description": "Benchmark tile configurations",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import time
import sys
sys.path.insert(0, '.')
from solution import Model, _launch_grouped_gemm_swiglu
device = torch.device('cuda:0')
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device).eval()
T_perm = T_total * K
torch.manual_seed(42)
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
configs = [
(64, 128, 64, 8, 2),
(64, 128, 64, 8, 3),
(128, 128, 64, 8, 2),
(128, 128, 64, 8, 3),
(64, 128, 128, 8, 2),
(64, 128, 128, 8, 3),
(128, 64, 64, 4, 2),
(64, 64, 64, 4, 2),
(128, 128, 128, 8, 2),
(128, 128, 128, 8, 3),
(64, 256, 64, 8, 2),
(128, 256, 64, 8, 3),
(32, 128, 64, 4, 3),
(32, 256, 64, 8, 3),
(64, 64, 128, 4, 3),
]
for BM, BN, BK, nw, ns in configs:
try:
for _ in range(3):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(20):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 20
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {ms*1000:.3f} ms')
except Exception as e:
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED ({type(e).__name__}: {e})')
" 2>&1BM= 64 BN=128 BK= 64 nw=8 ns=2: 24.469 ms
BM= 64 BN=128 BK= 64 nw=8 ns=3: 23.714 ms
BM=128 BN=128 BK= 64 nw=8 ns=2: 24.585 ms
BM=128 BN=128 BK= 64 nw=8 ns=3: 23.396 ms
BM= 64 BN=128 BK=128 nw=8 ns=2: 32.817 ms
BM= 64 BN=128 BK=128 nw=8 ns=3: FAILED (OutOfResources: out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.)
BM=128 BN= 64 BK= 64 nw=4 ns=2: 22.045 ms
BM= 64 BN= 64 BK= 64 nw=4 ns=2: 26.336 ms
BM=128 BN=128 BK=128 nw=8 ns=2: 25.007 ms
BM=128 BN=128 BK=128 nw=8 ns=3: FAILED (OutOfResources: out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.)
BM= 64 BN=256 BK= 64 nw=8 ns=2: 26.078 ms
BM=128 BN=256 BK= 64 nw=8 ns=3: FAILED (OutOfResources: out of resource: shared memory, Required: 163840, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.)
BM= 32 BN=128 BK= 64 nw=4 ns=3: 33.343 ms
BM= 32 BN=256 BK= 64 nw=8 ns=3: FAILED (OutOfResources: out of resource: shared memory, Required: 139264, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.)
BM= 64 BN= 64 BK=128 nw=4 ns=3: 27.193 ms
{
"command": "<1749 chars \u2014 see below>",
"description": "Benchmark more tile configurations",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import time
import sys
sys.path.insert(0, '.')
from solution import Model, _launch_grouped_gemm_swiglu
device = torch.device('cuda:0')
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device).eval()
T_perm = T_total * K
torch.manual_seed(42)
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
configs = [
# Best so far: BM=128 BN=64 BK=64 nw=4 ns=2
(128, 64, 64, 4, 1),
(128, 64, 64, 4, 2),
(128, 64, 64, 4, 3),
(128, 64, 64, 8, 2),
(128, 64, 32, 4, 2),
(128, 64, 32, 4, 3),
(128, 64, 128, 4, 2),
(64, 64, 64, 4, 2),
(64, 64, 64, 4, 3),
(64, 64, 64, 2, 2),
(128, 32, 64, 4, 2),
(64, 128, 64, 4, 2),
(64, 128, 64, 4, 3),
(128, 128, 64, 4, 2),
(128, 128, 64, 4, 3),
]
for BM, BN, BK, nw, ns in configs:
try:
for _ in range(3):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(20):
_launch_grouped_gemm_swiglu(hs, model.W_gate, model.W_up, offsets, BM, BN, BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 20
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {ms*1000:.3f} ms')
except Exception as e:
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED ({type(e).__name__})')
" 2>&1BM=128 BN= 64 BK= 64 nw=4 ns=1: 27.091 ms
BM=128 BN= 64 BK= 64 nw=4 ns=2: 22.168 ms
BM=128 BN= 64 BK= 64 nw=4 ns=3: 22.547 ms
BM=128 BN= 64 BK= 64 nw=8 ns=2: 28.016 ms
BM=128 BN= 64 BK= 32 nw=4 ns=2: 22.489 ms
BM=128 BN= 64 BK= 32 nw=4 ns=3: 22.620 ms
BM=128 BN= 64 BK=128 nw=4 ns=2: 27.747 ms
BM= 64 BN= 64 BK= 64 nw=4 ns=2: 26.482 ms
BM= 64 BN= 64 BK= 64 nw=4 ns=3: 25.840 ms
BM= 64 BN= 64 BK= 64 nw=2 ns=2: 26.985 ms
BM=128 BN= 32 BK= 64 nw=4 ns=2: 26.292 ms
BM= 64 BN=128 BK= 64 nw=4 ns=2: 23.537 ms
BM= 64 BN=128 BK= 64 nw=4 ns=3: 23.418 ms
BM=128 BN=128 BK= 64 nw=4 ns=2: 1322.613 ms
BM=128 BN=128 BK= 64 nw=4 ns=3: 977.685 ms
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -11,17 +11,16 @@
@triton.jit
-def _grouped_gemm_swiglu_kernel(
+def _grouped_gemm_swiglu_3d_kernel(
x_ptr,
W_gate_ptr,
W_up_ptr,
out_ptr,
expert_offsets_ptr,
- tiles_prefix_sum_ptr,
H,
I,
E,
- num_n_tiles,
+ num_m_tiles,
stride_xm,
stride_xk,
stride_wge,
@@ -36,33 +35,19 @@
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- pid = tl.program_id(0)
-
- lo = 0
- hi = E - 1
- while lo < hi:
- mid = (lo + hi) >> 1
- if tl.load(tiles_prefix_sum_ptr + mid + 1) <= pid:
- lo = mid + 1
- else:
- hi = mid
- expert = lo
-
- expert_tile_start = tl.load(tiles_prefix_sum_ptr + expert)
- local_tile = pid - expert_tile_start
- pid_m = local_tile // num_n_tiles
- pid_n = local_tile % num_n_tiles
-
- expert_start = tl.load(expert_offsets_ptr + expert)
- expert_end = tl.load(expert_offsets_ptr + expert + 1)
+ pid_n = tl.program_id(0)
+ pid_m = tl.program_id(1)
+ pid_e = tl.program_id(2)
+
+ expert_start = tl.load(expert_offsets_ptr + pid_e)
+ expert_end = tl.load(expert_offsets_ptr + pid_e + 1)
m_start = expert_start + pid_m * BLOCK_M
- n_start = pid_n * BLOCK_N
if m_start >= expert_end:
return
m_offsets = m_start + tl.arange(0, BLOCK_M)
- n_offsets = n_start + tl.arange(0, BLOCK_N)
+ n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -76,10 +61,10 @@
w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
- wg_ptrs = W_gate_ptr + expert * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
+ wg_ptrs = W_gate_ptr + pid_e * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
- wu_ptrs = W_up_ptr + expert * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
+ wu_ptrs = W_up_ptr + pid_e * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
@@ -92,7 +77,96 @@
tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=out_mask)
-def _launch_grouped_gemm_swiglu(
+@triton.jit
+def _grouped_gemm_swiglu_2pass_kernel(
+ x_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ out_ptr,
+ expert_offsets_ptr,
+ H,
+ I,
+ E,
+ num_m_tiles,
+ stride_xm,
+ stride_xk,
+ stride_wge,
+ stride_wgk,
+ stride_wgn,
+ stride_wue,
+ stride_wuk,
+ stride_wun,
+ stride_om,
+ stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ pid_m = tl.program_id(1)
+ pid_e = tl.program_id(2)
+
+ expert_start = tl.load(expert_offsets_ptr + pid_e)
+ expert_end = tl.load(expert_offsets_ptr + pid_e + 1)
+ m_start = expert_start + pid_m * BLOCK_M
+
+ if m_start >= expert_end:
+ return
+
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, H, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+
+ x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
+ x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
+ x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
+ wg_ptrs = W_gate_ptr + pid_e * stride_wge + k_off[:, None] * stride_wgk + n_offsets[None, :] * stride_wgn
+ wg_tile = tl.load(wg_ptrs, mask=w_mask, other=0.0)
+
+ gate_acc = tl.dot(x_tile, wg_tile, acc=gate_acc, out_dtype=tl.float32)
+
+ up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, H, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+
+ x_ptrs = x_ptr + m_offsets[:, None] * stride_xm + k_off[None, :] * stride_xk
+ x_mask = (m_offsets[:, None] < expert_end) & (k_off[None, :] < H)
+ x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ w_mask = (k_off[:, None] < H) & (n_offsets[None, :] < I)
+ wu_ptrs = W_up_ptr + pid_e * stride_wue + k_off[:, None] * stride_wuk + n_offsets[None, :] * stride_wun
+ wu_tile = tl.load(wu_ptrs, mask=w_mask, other=0.0)
+
+ up_acc = tl.dot(x_tile, wu_tile, acc=up_acc, out_dtype=tl.float32)
+
+ out_tile = gate_acc * tl.sigmoid(gate_acc) * up_acc
+
+ out_ptrs = out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
+ out_mask = (m_offsets[:, None] < expert_end) & (n_offsets[None, :] < I)
+ tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=out_mask)
+
+
+def _compute_grid(hidden_states, W_gate, expert_offsets, BLOCK_M, BLOCK_N):
+ T_perm, H = hidden_states.shape
+ E = expert_offsets.shape[0] - 1
+ I = W_gate.shape[2]
+ num_n_tiles = (I + BLOCK_N - 1) // BLOCK_N
+ offsets_cpu = expert_offsets.cpu()
+ max_m_tiles = 0
+ for e in range(E):
+ n_e = int(offsets_cpu[e + 1]) - int(offsets_cpu[e])
+ nm = (n_e + BLOCK_M - 1) // BLOCK_M
+ max_m_tiles = max(max_m_tiles, nm)
+ return num_n_tiles, max_m_tiles
+
+
+def _launch_kernel(
+ kernel,
hidden_states: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
@@ -107,34 +181,24 @@
E = expert_offsets.shape[0] - 1
I = W_gate.shape[2]
- num_n_tiles = (I + BLOCK_N - 1) // BLOCK_N
-
- offsets_cpu = expert_offsets.cpu()
- tiles_prefix_sum = torch.zeros(E + 1, dtype=torc
... (4205 more chars){
"content": "<10524 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness after refactor",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark with autotuned kernel",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0508 15:32:25.812000 3975579 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars){
"command": "<1213 chars \u2014 see below>",
"description": "Check autotuning cache selections",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch, sys
sys.path.insert(0, '.')
from solution import Model, _tuned_launch, _TUNE_CACHE
device = torch.device('cuda:0')
# Shape 0
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device)
T_perm = T_total * K
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
result = _tuned_launch(hs, model.W_gate, model.W_up, offsets)
print('Shape 0 cache:', _TUNE_CACHE)
# Shape 1
T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
model = Model(T_total, H, I, E, K).to(device)
T_perm = T_total * K
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
result = _tuned_launch(hs, model.W_gate, model.W_up, offsets)
print('After shape 1 cache:', _TUNE_CACHE)
" 2>&1Shape 0 cache: {(262144, 4096, 1536, 128): (JITFunction(solution:_grouped_gemm_swiglu_3d_kernel), 128, 64, 64, 4, 2)}
After shape 1 cache: {(262144, 4096, 1536, 128): (JITFunction(solution:_grouped_gemm_swiglu_3d_kernel), 128, 64, 64, 4, 2), (16384, 2048, 1024, 64): (JITFunction(solution:_grouped_gemm_swiglu_3d_kernel), 128, 64, 64, 4, 2)}
{
"command": "<3221 chars \u2014 see below>",
"description": "Comprehensive config sweep",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import time
import sys
sys.path.insert(0, '.')
from solution import Model, _launch_kernel, _grouped_gemm_swiglu_3d_kernel, _grouped_gemm_swiglu_2pass_kernel
device = torch.device('cuda:0')
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device)
T_perm = T_total * K
torch.manual_seed(42)
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
# Focus on configs with smaller tiles for higher occupancy
configs = [
# (kernel, BM, BN, BK, nw, ns)
(_grouped_gemm_swiglu_3d_kernel, 128, 64, 64, 4, 2), # baseline
(_grouped_gemm_swiglu_3d_kernel, 64, 64, 64, 4, 1), # smaller, ns=1
(_grouped_gemm_swiglu_3d_kernel, 64, 64, 64, 4, 2), # smaller, ns=2
(_grouped_gemm_swiglu_3d_kernel, 64, 64, 64, 4, 3), # smaller, ns=3
(_grouped_gemm_swiglu_3d_kernel, 64, 128, 32, 4, 2), # small BK
(_grouped_gemm_swiglu_3d_kernel, 64, 128, 32, 4, 3),
(_grouped_gemm_swiglu_3d_kernel, 64, 64, 32, 4, 2),
(_grouped_gemm_swiglu_3d_kernel, 64, 64, 32, 4, 3),
(_grouped_gemm_swiglu_3d_kernel, 32, 64, 32, 4, 2),
(_grouped_gemm_swiglu_3d_kernel, 32, 64, 32, 4, 3),
(_grouped_gemm_swiglu_3d_kernel, 32, 128, 32, 4, 2),
(_grouped_gemm_swiglu_3d_kernel, 32, 128, 32, 4, 3),
(_grouped_gemm_swiglu_2pass_kernel, 64, 64, 64, 4, 2),
(_grouped_gemm_swiglu_2pass_kernel, 64, 64, 64, 4, 3),
(_grouped_gemm_swiglu_2pass_kernel, 64, 64, 32, 4, 2),
(_grouped_gemm_swiglu_2pass_kernel, 64, 64, 32, 4, 3),
(_grouped_gemm_swiglu_2pass_kernel, 128, 64, 64, 4, 2),
(_grouped_gemm_swiglu_2pass_kernel, 128, 64, 64, 4, 3),
(_grouped_gemm_swiglu_2pass_kernel, 128, 64, 32, 4, 3),
(_grouped_gemm_swiglu_2pass_kernel, 128, 64, 32, 4, 2),
(_grouped_gemm_swiglu_3d_kernel, 128, 128, 32, 8, 2),
(_grouped_gemm_swiglu_3d_kernel, 128, 128, 32, 8, 3),
(_grouped_gemm_swiglu_3d_kernel, 64, 256, 64, 8, 2),
(_grouped_gemm_swiglu_3d_kernel, 64, 256, 64, 8, 3),
(_grouped_gemm_swiglu_3d_kernel, 128, 256, 64, 8, 2),
(_grouped_gemm_swiglu_3d_kernel, 128, 256, 32, 8, 2),
]
for ker, BM, BN, BK, nw, ns in configs:
try:
for _ in range(3):
_launch_kernel(ker, hs, model.W_gate, model.W_up, offsets, BM, BN, BK, nw, ns)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(20):
_launch_kernel(ker, hs, model.W_gate, model.W_up, offsets, BM, BN, BK, nw, ns)
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 20
name = '1p' if ker is _grouped_gemm_swiglu_3d_kernel else '2p'
print(f'{name} BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {ms*1000:.3f} ms')
except Exception as e:
name = '1p' if ker is _grouped_gemm_swiglu_3d_kernel else '2p'
print(f'{name} BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED ({type(e).__name__})')
" 2>&11p BM=128 BN= 64 BK= 64 nw=4 ns=2: 21.839 ms
1p BM= 64 BN= 64 BK= 64 nw=4 ns=1: 30.991 ms
1p BM= 64 BN= 64 BK= 64 nw=4 ns=2: 26.378 ms
1p BM= 64 BN= 64 BK= 64 nw=4 ns=3: 25.787 ms
1p BM= 64 BN=128 BK= 32 nw=4 ns=2: 23.453 ms
1p BM= 64 BN=128 BK= 32 nw=4 ns=3: 22.770 ms
1p BM= 64 BN= 64 BK= 32 nw=4 ns=2: 26.699 ms
1p BM= 64 BN= 64 BK= 32 nw=4 ns=3: 26.290 ms
1p BM= 32 BN= 64 BK= 32 nw=4 ns=2: 36.728 ms
1p BM= 32 BN= 64 BK= 32 nw=4 ns=3: 36.411 ms
1p BM= 32 BN=128 BK= 32 nw=4 ns=2: 33.409 ms
1p BM= 32 BN=128 BK= 32 nw=4 ns=3: 32.611 ms
2p BM= 64 BN= 64 BK= 64 nw=4 ns=2: 32.204 ms
2p BM= 64 BN= 64 BK= 64 nw=4 ns=3: 30.961 ms
2p BM= 64 BN= 64 BK= 32 nw=4 ns=2: 31.885 ms
2p BM= 64 BN= 64 BK= 32 nw=4 ns=3: 31.908 ms
2p BM=128 BN= 64 BK= 64 nw=4 ns=2: 27.088 ms
2p BM=128 BN= 64 BK= 64 nw=4 ns=3: 25.847 ms
2p BM=128 BN= 64 BK= 32 nw=4 ns=3: 26.501 ms
2p BM=128 BN= 64 BK= 32 nw=4 ns=2: 27.664 ms
1p BM=128 BN=128 BK= 32 nw=8 ns=2: 31.502 ms
1p BM=128 BN=128 BK= 32 nw=8 ns=3: 27.252 ms
1p BM= 64 BN=256 BK= 64 nw=8 ns=2: 37.472 ms
1p BM= 64 BN=256 BK= 64 nw=8 ns=3: FAILED (OutOfResources)
1p BM=128 BN=256 BK= 64 nw=8 ns=2: 1356.410 ms
1p BM=128 BN=256 BK= 32 nw=8 ns=2: 1760.793 ms
{
"command": "<812 chars \u2014 see below>",
"description": "Benchmark peak matmul throughput",
"timeout": 120000
}uv run python -c "
import torch
import time
device = torch.device('cuda:0')
# Simple dense matmul benchmark to check peak
M, K, N = 8192, 4096, 4096
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(K, N, dtype=torch.bfloat16, device=device)
# warmup
for _ in range(5):
c = a @ b
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
c = a @ b
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 100
flops = 2 * M * K * N
tflops = flops / (ms * 1e12)
print(f'Dense GEMM {M}x{K}x{N}: {ms*1000:.3f} ms, {tflops:.1f} TFLOPS')
# Check SM count
props = torch.cuda.get_device_properties(0)
print(f'SM count: {props.multi_processor_count}')
print(f'Registers per SM: {props.regs_per_multiprocessor}')
print(f'Shared mem per SM: {props.total_memory}')
" 2>&1Dense GEMM 8192x4096x4096: 0.849 ms, 323.6 TFLOPS
SM count: 188
Registers per SM: 65536
Shared mem per SM: 101949898752
{
"command": "<552 chars \u2014 see below>",
"description": "Check GPU hardware limits"
}uv run python -c "
import torch
props = torch.cuda.get_device_properties(0)
print(f'SMs: {props.multi_processor_count}')
print(f'Max threads per SM: {props.max_threads_per_multiprocessor}')
print(f'Max threads per block: {props.max_threads_per_block}')
print(f'Max shared mem per block: {props.max_shared_memory_per_block}')
print(f'Max shared mem per SM: {props.max_shared_memory_per_multiprocessor}')
print(f'Regs per block: {props.regs_per_block}')
print(f'Regs per SM: {props.regs_per_multiprocessor}')
print(f'Warp size: {props.warp_size}')
" 2>&1Traceback (most recent call last):
File "<string>", line 5, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'max_threads_per_multiprocessor'. Did you mean: 'max_threads_per_multi_processor'?
SMs: 188
{
"command": "<306 chars \u2014 see below>",
"description": "List all GPU properties"
}uv run python -c "
import torch
props = torch.cuda.get_device_properties(0)
for attr in sorted(dir(props)):
if not attr.startswith('_'):
try:
val = getattr(props, attr)
if not callable(val):
print(f'{attr}: {val}')
except:
pass
" 2>&1L2_cache_size: 134217728
clock_rate: 2617000
gcnArchName: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
is_integrated: 0
is_multi_gpu_board: 0
major: 12
max_threads_per_block: 1024
max_threads_per_multi_processor: 1536
memory_bus_width: 512
memory_clock_rate: 14001000
minor: 0
multi_processor_count: 188
name: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
pci_bus_id: 1
pci_device_id: 0
pci_domain_id: 0
regs_per_multiprocessor: 65536
shared_memory_per_block: 49152
shared_memory_per_block_optin: 101376
shared_memory_per_multiprocessor: 102400
total_memory: 101949898752
uuid: d594fb7a-abcb-bad5-1fc4-675b478a47df
warp_size: 32
{
"command": "<2280 chars \u2014 see below>",
"description": "Sweep configs with more warps",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import time
import sys
sys.path.insert(0, '.')
from solution import Model, _launch_kernel, _grouped_gemm_swiglu_3d_kernel
device = torch.device('cuda:0')
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
model = Model(T_total, H, I, E, K).to(device)
T_perm = T_total * K
torch.manual_seed(42)
hs = torch.randn(T_perm, H, dtype=torch.bfloat16, device=device) * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E+1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
configs = [
# BM BN BK nw ns - focus on higher warp counts and smaller BK for less SMEM
(128, 64, 64, 4, 2), # baseline: ~22ms
(128, 64, 32, 8, 2), # more warps, less SMEM
(128, 64, 32, 8, 3),
(128, 64, 32, 4, 2), # same warps, less SMEM
(128, 64, 32, 4, 3),
(64, 128, 32, 8, 2),
(64, 128, 32, 8, 3),
(64, 64, 32, 8, 2),
(64, 64, 32, 8, 3),
(128, 64, 64, 8, 2), # more warps, same BK
(128, 128, 32, 8, 2),
(128, 128, 32, 8, 3),
(64, 64, 64, 8, 2),
(64, 64, 64, 8, 3),
(128, 64, 32, 16, 2), # even more warps
(128, 64, 32, 16, 3),
(64, 128, 32, 16, 2),
(64, 128, 32, 16, 3),
]
for BM, BN, BK, nw, ns in configs:
try:
for _ in range(3):
_launch_kernel(_grouped_gemm_swiglu_3d_kernel, hs, model.W_gate, model.W_up, offsets, BM, BN, BK, nw, ns)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(20):
_launch_kernel(_grouped_gemm_swiglu_3d_kernel, hs, model.W_gate, model.W_up, offsets, BM, BN, BK, nw, ns)
torch.cuda.synchronize()
ms = (time.perf_counter() - t0) / 20
# Estimate SMEM: ns * (BM*BK*2 + BK*BN*2 + BK*BN*2)
smem = ns * (BM*BK*2 + BK*BN*2*2)
blocks = min(102400 // max(smem,1), 1536 // (nw*32))
occ = min(blocks * nw * 32, 1536) / 1536 * 100
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw:2d} ns={ns}: {ms*1000:.3f} ms SMEM={smem//1024}KB occ={occ:.0f}%')
except Exception as e:
print(f'BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw:2d} ns={ns}: FAILED ({type(e).__name__})')
" 2>&1BM=128 BN= 64 BK= 64 nw= 4 ns=2: 21.394 ms SMEM=64KB occ=8%
BM=128 BN= 64 BK= 32 nw= 8 ns=2: 30.145 ms SMEM=32KB occ=50%
BM=128 BN= 64 BK= 32 nw= 8 ns=3: 22.839 ms SMEM=48KB occ=33%
BM=128 BN= 64 BK= 32 nw= 4 ns=2: 21.566 ms SMEM=32KB occ=25%
BM=128 BN= 64 BK= 32 nw= 4 ns=3: 21.508 ms SMEM=48KB occ=17%
BM= 64 BN=128 BK= 32 nw= 8 ns=2: 24.464 ms SMEM=40KB occ=33%
BM= 64 BN=128 BK= 32 nw= 8 ns=3: 23.955 ms SMEM=60KB occ=17%
BM= 64 BN= 64 BK= 32 nw= 8 ns=2: 59.934 ms SMEM=24KB occ=67%
BM= 64 BN= 64 BK= 32 nw= 8 ns=3: 41.719 ms SMEM=36KB occ=33%
BM=128 BN= 64 BK= 64 nw= 8 ns=2: 27.758 ms SMEM=64KB occ=17%
BM=128 BN=128 BK= 32 nw= 8 ns=2: 30.262 ms SMEM=48KB occ=33%
BM=128 BN=128 BK= 32 nw= 8 ns=3: 26.612 ms SMEM=72KB occ=17%
BM= 64 BN= 64 BK= 64 nw= 8 ns=2: 47.119 ms SMEM=48KB occ=33%
BM= 64 BN= 64 BK= 64 nw= 8 ns=3: 42.646 ms SMEM=72KB occ=17%
BM=128 BN= 64 BK= 32 nw=16 ns=2: 52.286 ms SMEM=32KB occ=100%
BM=128 BN= 64 BK= 32 nw=16 ns=3: 49.089 ms SMEM=48KB occ=67%
BM= 64 BN=128 BK= 32 nw=16 ns=2: 38.049 ms SMEM=40KB occ=67%
BM= 64 BN=128 BK= 32 nw=16 ns=3: 27.951 ms SMEM=60KB occ=33%
{
"command": "nvcc --version 2>&1 && nvcc --list-gpu-arch 2>&1 | tail -5",
"description": "Check nvcc version and supported archs"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
compute_100
compute_110
compute_103
compute_120
compute_121
{
"command": "<142 chars \u2014 see below>",
"description": "Check default CUDA arch flags"
}uv run python -c "
import torch.utils.cpp_extension as cpp_ext
print('CUDA arch list:', cpp_ext.CUDAExtension._get_cuda_arch_flags([]))
" 2>&1Traceback (most recent call last):
File "<string>", line 3, in <module>
AttributeError: 'function' object has no attribute '_get_cuda_arch_flags'
{
"command": "<681 chars \u2014 see below>",
"description": "Test CUDA compilation for SM120",
"timeout": 120000
}uv run python -c "
import torch
print('CUDA arch:', torch.cuda.get_device_capability())
print('Torch CUDA version:', torch.version.cuda)
import torch.utils.cpp_extension as ce
# Test compilation with SM120
src = '''
__global__ void test_kernel(float* out, const float* in, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = in[idx] * 2.0f;
}
'''
mod = ce.load_inline(name='test_mod', cpp_sources=['void test(float* out, const float* in, int n) { test_kernel<<<(n+255)/256, 256>>>(out, in, n); }'], cuda_sources=[src], extra_cuda_cflags=['-gencode=arch=compute_120,code=sm_120'], verbose=True)
print('Compilation succeeded!')
" 2>&1 | tail -20File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 13, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2268, in _jit_compile
_write_ninja_file_and_build_library(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2429, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2791, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'test_mod'
{
"command": "<499 chars \u2014 see below>",
"description": "Test CUDA compilation with -arch=sm_120",
"timeout": 120000
}uv run python -c "
import torch.utils.cpp_extension as ce
src = '''
__global__ void test_kernel(float* out, const float* in, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = in[idx] * 2.0f;
}
'''
mod = ce.load_inline(name='test_mod2', cpp_sources=['void test(float* out, const float* in, int n) { test_kernel<<<(n+255)/256, 256>>>(out, in, n); }'], cuda_sources=[src], extra_cuda_cflags=['-arch=sm_120'], verbose=False)
print('Compilation succeeded!')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 9, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2268, in _jit_compile
_write_ninja_file_and_build_library(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2429, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2791, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'test_mod2': [1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_mod2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem
... (2860 more chars)