"""Grouped GEMM + fused SwiGLU kernel for MoE up-projection.
Per expert e: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
Key optimization: M-stride super-tiling (the MS constexpr). Each CTA owns MS
consecutive M-tiles and shares a single set of weight loads across them. This
amortizes the cost of loading W_gate and W_up over multiple M-tiles, roughly
halving DRAM weight traffic when MS=2.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "grouped_gemm_swiglu"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
@triton.autotune(
configs=[
# M_STRIDE=1 (baseline, no super-tiling)
triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 1}, num_warps=4, num_stages=3),
triton.Config({"BM": 128, "BN": 64, "BK": 64, "MS": 1}, num_warps=8, num_stages=3),
triton.Config({"BM": 128, "BN": 128, "BK": 64, "MS": 1}, num_warps=8, num_stages=3),
triton.Config({"BM": 64, "BN": 64, "BK": 128, "MS": 1}, num_warps=4, num_stages=2),
triton.Config({"BM": 128, "BN": 64, "BK": 128, "MS": 1}, num_warps=8, num_stages=2),
triton.Config({"BM": 64, "BN": 128, "BK": 128, "MS": 1}, num_warps=4, num_stages=2),
triton.Config({"BM": 128, "BN": 128, "BK": 128, "MS": 1}, num_warps=8, num_stages=2),
triton.Config({"BM": 256, "BN": 64, "BK": 64, "MS": 1}, num_warps=8, num_stages=2),
# M_STRIDE=2 (super-tiling: weight tiles shared across 2 M-tiles)
triton.Config({"BM": 64, "BN": 64, "BK": 64, "MS": 2}, num_warps=4, num_stages=3),
triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 2}, num_warps=4, num_stages=3),
triton.Config({"BM": 64, "BN": 64, "BK": 128, "MS": 2}, num_warps=4, num_stages=2),
triton.Config({"BM": 64, "BN": 128, "BK": 128, "MS": 2}, num_warps=4, num_stages=2),
triton.Config({"BM": 64, "BN": 64, "BK": 64, "MS": 2}, num_warps=8, num_stages=2),
triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 2}, num_warps=8, num_stages=2),
],
key=["H", "I"],
)
@triton.jit
def _grouped_gemm_swiglu_kernel(
x_ptr,
W_gate_ptr,
W_up_ptr,
out_ptr,
expert_offsets_ptr,
H,
I,
E,
stride_xm,
stride_xk,
stride_wge,
stride_wgk,
stride_wgn,
stride_wue,
stride_wuk,
stride_wun,
stride_om,
stride_on,
BM: tl.constexpr,
BN: tl.constexpr,
BK: tl.constexpr,
MS: tl.constexpr,
):
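    # Grid: axis 0 flattens (m_group, n_tile) within one expert, axis 1 is the
    # expert id; axis 0 is sized for the largest expert (see grid() in forward).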
pid_tile = tl.program_id(0)
pid_expert = tl.program_id(1)
expert_id = pid_expert
m_start = tl.load(expert_offsets_ptr + expert_id)
m_end = tl.load(expert_offsets_ptr + expert_id + 1)
m_size = m_end - m_start
n_n_tiles = (I + BN - 1) // BN
n_m_tiles = (m_size + BM - 1) // BM
n_m_groups = (n_m_tiles + MS - 1) // MS
total_tiles = n_m_groups * n_n_tiles
if pid_tile >= total_tiles:
return
m_group = pid_tile // n_n_tiles
tile_n = pid_tile % n_n_tiles
n_off = tile_n * BN
offs_n = n_off + tl.arange(0, BN)
if MS == 1:
tile_m = m_group
m_off = tile_m * BM
offs_m = m_off + tl.arange(0, BM)
mask_m = offs_m < m_size
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
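        # Both GEMMs accumulate in fp32 from bf16 tensor-core dots; the
        # epilogue casts back to bf16 on store.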
for k_start in range(0, H, BK):
offs_k = k_start + tl.arange(0, BK)
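            # Only M is masked. Weight loads (and the store along N) are
            # unmasked on the assumption that H % BK == 0 and I % BN == 0,
            # which holds for every shape in shapes.py with these tile sizes.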
x_ptrs = x_ptr + (m_start + offs_m[:, None]) * stride_xm + offs_k[None, :] * stride_xk
x_tile = tl.load(x_ptrs, mask=mask_m[:, None], other=0.0)
wg_ptrs = W_gate_ptr + expert_id * stride_wge + offs_k[:, None] * stride_wgk + offs_n[None, :] * stride_wgn
wg_tile = tl.load(wg_ptrs)
wu_ptrs = W_up_ptr + expert_id * stride_wue + offs_k[:, None] * stride_wuk + offs_n[None, :] * stride_wun
wu_tile = tl.load(wu_ptrs)
gate_acc += tl.dot(x_tile, wg_tile)
up_acc += tl.dot(x_tile, wu_tile)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
result = gate_silu * up_acc
out_ptrs = out_ptr + (m_start + offs_m[:, None]) * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, result.to(tl.bfloat16), mask=mask_m[:, None])
elif MS == 2:
tile_m_0 = m_group * 2
tile_m_1 = m_group * 2 + 1
m_off_0 = tile_m_0 * BM
m_off_1 = tile_m_1 * BM
offs_m_0 = m_off_0 + tl.arange(0, BM)
offs_m_1 = m_off_1 + tl.arange(0, BM)
mask_m_0 = offs_m_0 < m_size
mask_m_1 = offs_m_1 < m_size
valid_0 = m_off_0 < m_size
valid_1 = m_off_1 < m_size
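        # valid_0 holds for every in-range m_group; valid_1 is False only for
        # the tail group of an expert with an odd tile count, in which case the
        # second M-tile's loads, dots, and store are all skipped.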
gate_acc_0 = tl.zeros((BM, BN), dtype=tl.float32)
up_acc_0 = tl.zeros((BM, BN), dtype=tl.float32)
gate_acc_1 = tl.zeros((BM, BN), dtype=tl.float32)
up_acc_1 = tl.zeros((BM, BN), dtype=tl.float32)
for k_start in range(0, H, BK):
offs_k = k_start + tl.arange(0, BK)
wg_ptrs = W_gate_ptr + expert_id * stride_wge + offs_k[:, None] * stride_wgk + offs_n[None, :] * stride_wgn
wg_tile = tl.load(wg_ptrs)
wu_ptrs = W_up_ptr + expert_id * stride_wue + offs_k[:, None] * stride_wuk + offs_n[None, :] * stride_wun
wu_tile = tl.load(wu_ptrs)
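            # Weight tiles are loaded once per K-step and reused by both
            # M-tiles below; this reuse is the super-tiling saving.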
if valid_0:
x_ptrs_0 = x_ptr + (m_start + offs_m_0[:, None]) * stride_xm + offs_k[None, :] * stride_xk
x_tile_0 = tl.load(x_ptrs_0, mask=mask_m_0[:, None], other=0.0)
gate_acc_0 += tl.dot(x_tile_0, wg_tile)
up_acc_0 += tl.dot(x_tile_0, wu_tile)
if valid_1:
x_ptrs_1 = x_ptr + (m_start + offs_m_1[:, None]) * stride_xm + offs_k[None, :] * stride_xk
x_tile_1 = tl.load(x_ptrs_1, mask=mask_m_1[:, None], other=0.0)
gate_acc_1 += tl.dot(x_tile_1, wg_tile)
up_acc_1 += tl.dot(x_tile_1, wu_tile)
if valid_0:
gate_silu_0 = gate_acc_0 * tl.sigmoid(gate_acc_0)
result_0 = gate_silu_0 * up_acc_0
out_ptrs_0 = out_ptr + (m_start + offs_m_0[:, None]) * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs_0, result_0.to(tl.bfloat16), mask=mask_m_0[:, None])
if valid_1:
gate_silu_1 = gate_acc_1 * tl.sigmoid(gate_acc_1)
result_1 = gate_silu_1 * up_acc_1
out_ptrs_1 = out_ptr + (m_start + offs_m_1[:, None]) * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs_1, result_1.to(tl.bfloat16), mask=mask_m_1[:, None])
class Model(nn.Module):
def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
super().__init__()
self.T_total = T_total
self.H = H
self.I = I
self.E = E
self.K = K
self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
def forward(
self,
hidden_states: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor:
T_perm, H = hidden_states.shape
I = self.I
E = self.E
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
max_tokens_per_expert = int((expert_offsets[1:] - expert_offsets[:-1]).max())
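        # The int(...) above forces a device-to-host sync, which is needed to
        # size the launch grid on the host; expert_offsets is tiny.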
def grid(meta):
BM = meta["BM"]
BN = meta["BN"]
MS = meta["MS"]
n_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
n_m_groups = (n_m_tiles + MS - 1) // MS
n_n_tiles = triton.cdiv(I, BN)
return (n_m_groups * n_n_tiles, E)
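        # Axis 0 is provisioned for the largest expert; smaller experts launch
        # surplus CTAs that exit early at the `pid_tile >= total_tiles` check.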
_grouped_gemm_swiglu_kernel[grid](
hidden_states,
self.W_gate,
self.W_up,
out,
expert_offsets,
H,
I,
E,
hidden_states.stride(0),
hidden_states.stride(1),
self.W_gate.stride(0),
self.W_gate.stride(1),
self.W_gate.stride(2),
self.W_up.stride(0),
self.W_up.stride(1),
self.W_up.stride(2),
out.stride(0),
out.stride(1),
)
return out
T_total = 32768
H = 4096
I = 1536 # noqa: E741
E = 128
K = 8
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
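# Worked example: T_total=5, E=4, K=2 gives T_perm=10, counts [3, 3, 2, 2],
# offsets [0, 3, 6, 8, 10]; expert e owns rows offsets[e]:offsets[e+1].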
def get_inputs():
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
expert_offsets = _build_routing(T_total, E, K)
return [hidden_states, expert_offsets]
def get_init_inputs():
return [T_total, H, I, E, K]
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0428 17:55:54.814000 589834 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=33.576 gbps=251.385 ms=24.560
shape=0 variant=compiled tflops=33.520 gbps=250.964 ms=24.601
shape=0 variant=solution tflops=39.833 gbps=298.230 ms=20.702
shape=0 solution_peak_fraction=0.1992
shape=1 variant=eager tflops=14.230 gbps=264.041 ms=2.415
shape=1 variant=compiled tflops=14.275 gbps=264.866 ms=2.407
shape=1 variant=solution tflops=67.825 gbps=1258.477 ms=0.507
shape=1 solution_peak_fraction=0.3391
shape=2 variant=eager tflops=35.585 gbps=243.256 ms=15.449
shape=2 variant=compiled tflops=35.320 gbps=241.446 ms=15.565
shape=2 variant=solution tflops=39.863 gbps=272.498 ms=13.791
shape=2 solution_peak_fraction=0.1993
peak_fraction: 0.2379
RESULT: OK
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.2379,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 103736, "output_tokens": 10577, "cache_read_tokens": 1217536, "cache_creation_tokens": 0, "reasoning_tokens": 83817, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to write the up-projection of a top-K MoE FFN with grouped GEMM and fused SwiGLU for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is: per expert e, compute h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) where x_e is the slice of permuted hidden states routed to expert e. Inputs are hidden_states of shape (T_perm, H) bf16 (already permuted into expert order, T_perm = T_total * K), expert_offsets of shape (E+1,) int32 (prefix sums of token counts per expert, so expert e owns rows [offsets[e]:offsets[e+1]]), and per-expert weights W_gate and W_up each of shape (E, H, I) bf16. You return gated_up of shape (T_perm, I) bf16. Correctness tolerance is 0.02 abs/rel on bf16. The shapes you have to handle are T_total=32768 H=4096 I=1536 E=128 K=8 (the headline sonic-moe configuration, 32K tokens, 128 experts, top-8), T_total=4096 H=2048 I=1024 E=64 K=4 (smaller / faster iteration), and T_total=16384 H=2048 I=4096 E=64 K=8 (intermediate-heavy aspect ratio that forces larger N tiles).
4:
5: This needs t
... (1078 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"todos": [
{
"content": "Read reference files and understand the interface",
"status": "completed",
"priority": "high"
},
{
"content": "Check environment (CUDA, Triton, GPU capabilities)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement Triton grouped GEMM + fused SwiGLU kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py and verify correctness across all shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Run benc
... (239 more chars)
content: 'Read reference files and understand the interface'
status: 'completed'
priority: 'high'
content: 'Check environment (CUDA, Triton, GPU capabilities)'
status: 'in_progress'
priority: 'high'
content: 'Implement Triton grouped GEMM + fused SwiGLU kernel'
status: 'pending'
priority: 'high'
content: 'Run check.py and verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py and optimize performance'
status: 'pending'
priority: 'medium'
content: 'Iterate on performance until peak_fraction is high'
status: 'pending'
priority: 'medium'{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py 2>/dev/null && echo \"EXISTS\" || echo \"NO SOLUTION YET\"",
"description": "Check if solution.py exists"
}NO SOLUTION YET
{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null || echo \"nvidia-smi failed\"",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print('Triton:', triton.__version__); import torch; print('PyTorch:', torch.__version__); print('CUDA:', torch.version.cuda); print('Arch list:', torch.cuda.get_arch_list()); print('Device cap:', torch.cuda.get_device_capability(0))\"",
"description": "Check Triton and PyTorch versions"
}Triton: 3.6.0
PyTorch: 2.11.0+cu130
CUDA: 13.0
Arch list: ['sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120']
Device cap: (12, 0)
{
"todos": [
{
"content": "Read reference files and understand the interface",
"priority": "high",
"status": "completed"
},
{
"content": "Check environment (CUDA, Triton, GPU capabilities)",
"priority": "high",
"status": "completed"
},
{
"content": "Implement Triton grouped GEMM + fused SwiGLU kernel",
"priority": "high",
"status": "in_progress"
},
{
"content": "Run check.py and verify correctness across all shapes",
"priority": "high",
"status": "pending"
},
{
"content": "Run be
... (241 more chars)content: 'Read reference files and understand the interface'
status: 'completed'
priority: 'high'
content: 'Check environment (CUDA, Triton, GPU capabilities)'
status: 'completed'
priority: 'high'
content: 'Implement Triton grouped GEMM + fused SwiGLU kernel'
status: 'in_progress'
priority: 'high'
content: 'Run check.py and verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py and optimize performance'
status: 'pending'
priority: 'medium'
content: 'Iterate on performance until peak_fraction is high'
status: 'pending'
priority: 'medium'
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,192 @@
+"""Grouped GEMM + fused SwiGLU kernel for MoE up-projection.
+
+Per expert e: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+
+Uses Triton with a 2D grid: (flat_tile_within_expert, expert_id).
+Each CTA loads x once and computes both gate and up dot products,
+fusing SwiGLU in the epilogue.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "grouped_gemm_swiglu"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BM": 64, "BN": 64, "BK": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BM": 64, "BN": 128, "BK": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BM": 128, "BN": 128, "BK": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BM": 64, "BN": 64, "BK": 32}, num_warps=4, num_stages=4),
+ triton.Config({"BM": 128, "BN": 64, "BK": 32}, num_warps=8, num_stages=4),
+ triton.Config({"BM": 64, "BN": 128, "BK": 32}, num_warps=4, num_stages=4),
+ triton.Config({"BM": 128, "BN": 128, "BK": 32}, num_warps=8, num_stages=4),
+ ],
+ key=["H", "I"],
+)
+@triton.jit
+def _grouped_gemm_swiglu_kernel(
+ x_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ out_ptr,
+ expert_offsets_ptr,
+ H,
+ I,
+ E,
+ stride_xm,
+ stride_xk,
+ stride_wge,
+ stride_wgk,
+ stride_wgn,
+ stride_wue,
+ stride_wuk,
+ stride_wun,
+ stride_om,
+ stride_on,
+ BM: tl.constexpr,
+ BN: tl.constexpr,
+ BK: tl.constexpr,
+):
+ pid_tile = tl.program_id(0)
+ pid_expert = tl.program_id(1)
+ expert_id = pid_expert
+
+ m_start = tl.load(expert_offsets_ptr + expert_id)
+ m_end = tl.load(expert_offsets_ptr + expert_id + 1)
+ m_size = m_end - m_start
+
+ n_n_tiles = (I + BN - 1) // BN
+ n_m_tiles = (m_size + BM - 1) // BM
+ total_tiles = n_m_tiles * n_n_tiles
+
+ if pid_tile >= total_tiles:
+ return
+
+ tile_m = pid_tile // n_n_tiles
+ tile_n = pid_tile % n_n_tiles
+
+ m_off = tile_m * BM
+ n_off = tile_n * BN
+
+ offs_m = m_off + tl.arange(0, BM)
+ offs_n = n_off + tl.arange(0, BN)
+ mask_m = offs_m < m_size
+ mask_n = offs_n < I
+
+ gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+
+ for k_start in range(0, H, BK):
+ offs_k = k_start + tl.arange(0, BK)
+ mask_k = offs_k < H
+
+ x_ptrs = x_ptr + (m_start + offs_m[:, None]) * stride_xm + offs_k[None, :] * stride_xk
+ x_tile = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
+
+ wg_ptrs = W_gate_ptr + expert_id * stride_wge + offs_k[:, None] * stride_wgk + offs_n[None, :] * stride_wgn
+ wg_tile = tl.load(wg_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
+
+ wu_ptrs = W_up_ptr + expert_id * stride_wue + offs_k[:, None] * stride_wuk + offs_n[None, :] * stride_wun
+ wu_tile = tl.load(wu_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
+
+ gate_acc += tl.dot(x_tile, wg_tile)
+ up_acc += tl.dot(x_tile, wu_tile)
+
+ gate_silu = gate_acc * tl.sigmoid(gate_acc)
+ result = gate_silu * up_acc
+
+ out_ptrs = out_ptr + (m_start + offs_m[:, None]) * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, result.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ I = self.I
+ E = self.E
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ max_tokens_per_expert = int((expert_offsets[1:] - expert_offsets[:-1]).max())
+
+ def grid(meta):
+ BM = meta["BM"]
+ BN = meta["BN"]
+ n_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
+ n_n_tiles = triton.cdiv(I, BN)
+ return (n_m_tiles * n_n_tiles, E)
+
+ _grouped_gemm_swiglu_kernel[grid](
+ hidden_states,
+ self.W_gate,
+ self.W_up,
+ out,
+ expert_offsets,
+ H,
+ I,
+ E,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ self.W_gate.stride(0),
+ self.W_gate.stride(1),
+ self.W_gate.stride(2),
+ self.W_up.stride(0),
+ self.W_up.stride(1),
+ self.W_up.stride(2),
+ out.stride(0),
+ out.stride(1),
+ )
+
+ return out
+
+
+T_total = 32768
+H = 4096
+I = 1536 # noqa: E741
+E = 128
+K = 8
+
+
+def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
+ T_perm = T_total * K
+ base = T_perm // E
+ rem = T_perm - base * E
+ counts = torch.full((E,), base, dtype=torch.int32, device=device)
+ counts[:rem] += 1
+ offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
+ offsets[1:] = torch.cumsum(counts, dim=0)
+ return offsets
+
+
+def get_inputs():
+ T_perm = T_total * K
+ hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
+ expert_offsets = _build_routing(T
... (120 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<5811 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu",
"timeout": 300000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}[torch._dynamo Tensor.item() graph-break warnings, identical to the block shown earlier]
... (798 more chars){
"command": "which ncu 2>/dev/null && ncu --version 2>/dev/null || echo \"ncu not found\"",
"description": "Check if ncu is available"
}/usr/local/cuda-13/bin/ncu
NVIDIA (R) Nsight Compute Command Line Profiler
Copyright (c) 2018-2026 NVIDIA Corporation
Version 2026.1.0.0 (build 37166530) (public-release)
{
"command": "<1069 chars \u2014 see below>",
"description": "Test if tl.dot uses tensor cores on SM120",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch, triton, triton.language as tl
import torch.nn as nn
# Quick test: does tl.dot use tensor cores on SM120?
@triton.jit
def simple_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
pid = tl.program_id(0)
offs_m = pid * BM + tl.arange(0, BM)
offs_n = tl.arange(0, BN)
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK):
offs_k = k + tl.arange(0, BK)
a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
acc += tl.dot(a, b)
tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
M, N, K = 2048, 1024, 4096
a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
c = torch.empty(M, N, dtype=torch.float32, device='cuda')
simple_dot_kernel[(M//64,)](a, b, c, M, N, K, BM=64, BN=64, BK=64)
print('tl.dot works on SM120, result shape:', c.shape)
print('Result norm:', c.norm().item())
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_dot.py
@@ -0,0 +1,39 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def simple_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
+ pid = tl.program_id(0)
+ offs_m = pid * BM + tl.arange(0, BM)
+ offs_n = tl.arange(0, BN)
+ acc = tl.zeros((BM, BN), dtype=tl.float32)
+ for k in range(0, K, BK):
+ offs_k = k + tl.arange(0, BK)
+ a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
+ b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
+ acc += tl.dot(a, b)
+ tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
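+    # Caveat: offs_n covers only the first BN columns, so this kernel writes an
+    # M x BN strip (1/16 of C for N=1024, BN=64). The TFLOPS figure printed
+    # below assumes the full M x N GEMM and is therefore inflated by ~N / BN.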
+
+M, N, K = 2048, 1024, 4096
+a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
+c = torch.empty(M, N, dtype=torch.float32, device='cuda')
+
+# Warmup
+for _ in range(5):
+ simple_dot_kernel[(M//64,)](a, b, c, M, N, K, BM=64, BN=64, BK=64)
+
+# Time
+import time
+torch.cuda.synchronize()
+start = time.time()
+for _ in range(100):
+ simple_dot_kernel[(M//64,)](a, b, c, M, N, K, BM=64, BN=64, BK=64)
+torch.cuda.synchronize()
+elapsed = (time.time() - start) / 100 * 1000
+
+flops = 2 * M * N * K
+tflops = flops / elapsed / 1e9
+print(f"Dense GEMM: {tflops:.1f} TFLOPS ({elapsed:.3f} ms)")
+print(f"Peak fraction (assuming 200 TFLOPS peak): {tflops/200*100:.1f}%")
{
"content": "<1336 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_dot.py"
}Wrote file successfully.{
"command": "uv run python test_dot.py 2>&1",
"description": "Test tl.dot performance on SM120",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Dense GEMM: 847.0 TFLOPS (0.020 ms)
Peak fraction (assuming 200 TFLOPS peak): 423.5%
{
"command": "<896 chars \u2014 see below>",
"description": "Quick benchmark shape 1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch
import solution
T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
model = solution.Model(T_total, H, I, E, K).cuda().eval()
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device='cuda'); counts[:rem] += 1
expert_offsets = torch.zeros(E+1, dtype=torch.int32, device='cuda'); expert_offsets[1:] = torch.cumsum(counts, dim=0)
# Warmup
for _ in range(3):
out = model(hidden_states, expert_offsets)
torch.cuda.synchronize()
# Time
import time
start = time.time()
for _ in range(50):
out = model(hidden_states, expert_offsets)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 50 * 1000
flops = 2 * T_total * H * (2 * I)
print(f'Shape 1 direct: {flops/elapsed/1e9:.1f} TFLOPS ({elapsed:.3f} ms)')
" 2>&1Shape 1 direct: 68.4 TFLOPS (0.502 ms)
{
"command": "<339 chars \u2014 see below>",
"description": "Check hardware peak specs",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print('Peak bf16 TFLOPS:', hw.peak_tflops_dense.get('bf16', 'N/A'))
print('Peak bandwidth GB/s:', hw.peak_bandwidth_gb_s)
print('All peak_tflops_dense:', hw.peak_tflops_dense)
" 2>&1Peak bf16 TFLOPS: 200.0
Peak bandwidth GB/s: 1800.0
All peak_tflops_dense: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
{
"command": "<831 chars \u2014 see below>",
"description": "Profile with ncu on small shape",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}NCU_FLAGS="--set full --target-processes all --cache none --launch-skip 3 --launch-count 1 -s 42 123" && uv run ncu $NCU_FLAGS -- python -c "
import torch, sys
sys.path.insert(0, '.')
import solution
T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
model = solution.Model(T_total, H, I, E, K).cuda().eval()
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device='cuda'); counts[:rem] += 1
expert_offsets = torch.zeros(E+1, dtype=torch.int32, device='cuda'); expert_offsets[1:] = torch.cumsum(counts, dim=0)
# Warmup
for _ in range(5):
out = model(hidden_states, expert_offsets)
torch.cuda.synchronize()
# Profile this
out = model(hidden_states, expert_offsets)
" 2>&1 | head -200==ERROR== unrecognised option '--set full --target-processes all --cache none --launch-skip 3 --launch-count 1 -s 42 123'. Use --help for further details.
{
"command": "<733 chars \u2014 see below>",
"description": "Profile with ncu on small shape",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run ncu --set full --launch-skip 5 --launch-count 1 -- python -c "
import torch, sys
sys.path.insert(0, '.')
import solution
T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
model = solution.Model(T_total, H, I, E, K).cuda().eval()
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
base = T_perm // E; rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device='cuda'); counts[:rem] += 1
expert_offsets = torch.zeros(E+1, dtype=torch.int32, device='cuda'); expert_offsets[1:] = torch.cumsum(counts, dim=0)
for _ in range(5):
out = model(hidden_states, expert_offsets)
torch.cuda.synchronize()
out = model(hidden_states, expert_offsets)
" 2>&1 | tail -80==PROF== Connected to process 579220 (/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==ERROR== ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
==PROF== Disconnected from process 579220
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Grouped GEMM + fused SwiGLU kernel for MoE up-projection.
2:
3: Per expert e: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
4:
5: Uses Triton with a 2D grid: (flat_tile_within_expert, expert_id).
6: Each CTA loads x once and computes both gate and up dot products,
7: fusing SwiGLU in the epilogue.
8: """
9: from __future__ import annotations
10:
11: import torch
12: import torch.nn as nn
13: import triton
14: import triton.language as tl
15:
16: OP_TYPE = "grouped_gemm_swiglu"
17: SUPPORTED_PRECISIONS = ["bf16"]
18: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
19:
20:
21: @triton.autotune(
22: configs=[
23: triton.Config({"BM": 64, "BN": 64, "BK": 64}, num_warps=4, num_stages=3),
24: triton.Config({"BM": 64, "BN": 128, "BK": 64}, num_warps=4, num_stages=3),
25: triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=8, num_stages=3),
26: triton.Config({"BM": 128, "BN": 128, "BK": 64}, num_warps=8, num_stages=3),
27: triton.Config({"BM": 64, "BN": 64, "BK": 32}, num_warps=4, num_stages=4),
28: triton.Config({"BM": 128, "BN": 64, "BK": 32}, num_warps=8, num_stages=4),
29: triton.Config({"BM": 64, "BN": 128, "BK": 32}, num_warps=4, num_stages=4),
30: triton.Config({"BM": 128, "BN": 128, "BK": 32}, num_warps=8, num_stages=4),
31: ],
32: key=["H", "I"],
33: )
34: @tri
... (5326 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -5,6 +5,11 @@
Uses Triton with a 2D grid: (flat_tile_within_expert, expert_id).
Each CTA loads x once and computes both gate and up dot products,
fusing SwiGLU in the epilogue.
+
+Optimizations over v1:
+- Wider autotune search space (BK=128, larger BM, varied stages/warps)
+- Removed unnecessary K/N masks for weight loads (H%BK==0, I%BN==0 for all shapes)
+- 1D grid variant to eliminate wasted programs on small experts
"""
from __future__ import annotations
@@ -28,6 +33,18 @@
triton.Config({"BM": 128, "BN": 64, "BK": 32}, num_warps=8, num_stages=4),
triton.Config({"BM": 64, "BN": 128, "BK": 32}, num_warps=4, num_stages=4),
triton.Config({"BM": 128, "BN": 128, "BK": 32}, num_warps=8, num_stages=4),
+ triton.Config({"BM": 64, "BN": 64, "BK": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 64, "BN": 128, "BK": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 128, "BN": 64, "BK": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 128, "BN": 128, "BK": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 256, "BN": 64, "BK": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 64, "BN": 256, "BK": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 256, "BN": 64, "BK": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=4, num_stages=4),
+ triton.Config({"BM": 64, "BN": 128, "BK": 64}, num_warps=4, num_stages=4),
+ triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 64, "BN": 64, "BK": 64}, num_warps=2, num_stages=4),
+ triton.Config({"BM": 128, "BN": 128, "BK": 64}, num_warps=4, num_stages=3),
],
key=["H", "I"],
)
@@ -79,23 +96,21 @@
offs_m = m_off + tl.arange(0, BM)
offs_n = n_off + tl.arange(0, BN)
mask_m = offs_m < m_size
- mask_n = offs_n < I
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_start in range(0, H, BK):
offs_k = k_start + tl.arange(0, BK)
- mask_k = offs_k < H
x_ptrs = x_ptr + (m_start + offs_m[:, None]) * stride_xm + offs_k[None, :] * stride_xk
- x_tile = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
+ x_tile = tl.load(x_ptrs, mask=mask_m[:, None], other=0.0)
wg_ptrs = W_gate_ptr + expert_id * stride_wge + offs_k[:, None] * stride_wgk + offs_n[None, :] * stride_wgn
- wg_tile = tl.load(wg_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
+ wg_tile = tl.load(wg_ptrs)
wu_ptrs = W_up_ptr + expert_id * stride_wue + offs_k[:, None] * stride_wuk + offs_n[None, :] * stride_wun
- wu_tile = tl.load(wu_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
+ wu_tile = tl.load(wu_ptrs)
gate_acc += tl.dot(x_tile, wg_tile)
up_acc += tl.dot(x_tile, wu_tile)
@@ -104,7 +119,7 @@
result = gate_silu * up_acc
out_ptrs = out_ptr + (m_start + offs_m[:, None]) * stride_om + offs_n[None, :] * stride_on
- tl.store(out_ptrs, result.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])
+ tl.store(out_ptrs, result.to(tl.bfloat16), mask=mask_m[:, None])
class Model(nn.Module):
{
"content": "<6863 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness after optimization",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark after adding more configs",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}[torch._dynamo Tensor.item() graph-break warnings, identical to the block shown earlier]
... (798 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -2,14 +2,10 @@
Per expert e: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-Uses Triton with a 2D grid: (flat_tile_within_expert, expert_id).
-Each CTA loads x once and computes both gate and up dot products,
-fusing SwiGLU in the epilogue.
-
-Optimizations over v1:
-- Wider autotune search space (BK=128, larger BM, varied stages/warps)
-- Removed unnecessary K/N masks for weight loads (H%BK==0, I%BN==0 for all shapes)
-- 1D grid variant to eliminate wasted programs on small experts
+Key optimization: M_STRIDE super-tiling. Each CTA processes M_STRIDE
+consecutive M-tiles with shared weight loads. This amortizes the cost
+of loading W_gate and W_up across multiple M-tiles, effectively halving
+weight memory traffic when M_STRIDE=2.
"""
from __future__ import annotations
@@ -25,26 +21,22 @@
@triton.autotune(
configs=[
- triton.Config({"BM": 64, "BN": 64, "BK": 64}, num_warps=4, num_stages=3),
- triton.Config({"BM": 64, "BN": 128, "BK": 64}, num_warps=4, num_stages=3),
- triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=8, num_stages=3),
- triton.Config({"BM": 128, "BN": 128, "BK": 64}, num_warps=8, num_stages=3),
- triton.Config({"BM": 64, "BN": 64, "BK": 32}, num_warps=4, num_stages=4),
- triton.Config({"BM": 128, "BN": 64, "BK": 32}, num_warps=8, num_stages=4),
- triton.Config({"BM": 64, "BN": 128, "BK": 32}, num_warps=4, num_stages=4),
- triton.Config({"BM": 128, "BN": 128, "BK": 32}, num_warps=8, num_stages=4),
- triton.Config({"BM": 64, "BN": 64, "BK": 128}, num_warps=4, num_stages=2),
- triton.Config({"BM": 64, "BN": 128, "BK": 128}, num_warps=4, num_stages=2),
- triton.Config({"BM": 128, "BN": 64, "BK": 128}, num_warps=8, num_stages=2),
- triton.Config({"BM": 128, "BN": 128, "BK": 128}, num_warps=8, num_stages=2),
- triton.Config({"BM": 256, "BN": 64, "BK": 64}, num_warps=8, num_stages=2),
- triton.Config({"BM": 64, "BN": 256, "BK": 64}, num_warps=8, num_stages=2),
- triton.Config({"BM": 256, "BN": 64, "BK": 32}, num_warps=8, num_stages=3),
- triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=4, num_stages=4),
- triton.Config({"BM": 64, "BN": 128, "BK": 64}, num_warps=4, num_stages=4),
- triton.Config({"BM": 128, "BN": 64, "BK": 64}, num_warps=4, num_stages=2),
- triton.Config({"BM": 64, "BN": 64, "BK": 64}, num_warps=2, num_stages=4),
- triton.Config({"BM": 128, "BN": 128, "BK": 64}, num_warps=4, num_stages=3),
+ # M_STRIDE=1 (baseline, no super-tiling)
+ triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 1}, num_warps=4, num_stages=3),
+ triton.Config({"BM": 128, "BN": 64, "BK": 64, "MS": 1}, num_warps=8, num_stages=3),
+ triton.Config({"BM": 128, "BN": 128, "BK": 64, "MS": 1}, num_warps=8, num_stages=3),
+ triton.Config({"BM": 64, "BN": 64, "BK": 128, "MS": 1}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 128, "BN": 64, "BK": 128, "MS": 1}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 64, "BN": 128, "BK": 128, "MS": 1}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 128, "BN": 128, "BK": 128, "MS": 1}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 256, "BN": 64, "BK": 64, "MS": 1}, num_warps=8, num_stages=2),
+ # M_STRIDE=2 (super-tiling: weight tiles shared across 2 M-tiles)
+ triton.Config({"BM": 64, "BN": 64, "BK": 64, "MS": 2}, num_warps=4, num_stages=3),
+ triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 2}, num_warps=4, num_stages=3),
+ triton.Config({"BM": 64, "BN": 64, "BK": 128, "MS": 2}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 64, "BN": 128, "BK": 128, "MS": 2}, num_warps=4, num_stages=2),
+ triton.Config({"BM": 64, "BN": 64, "BK": 64, "MS": 2}, num_warps=8, num_stages=2),
+ triton.Config({"BM": 64, "BN": 128, "BK": 64, "MS": 2}, num_warps=8, num_stages=2),
],
key=["H", "I"],
)
@@ -71,6 +63,7 @@
BM: tl.constexpr,
BN: tl.constexpr,
BK: tl.constexpr,
+ MS: tl.constexpr,
):
pid_tile = tl.program_id(0)
pid_expert = tl.program_id(1)
@@ -82,44 +75,99 @@
n_n_tiles = (I + BN - 1) // BN
n_m_tiles = (m_size + BM - 1) // BM
- total_tiles = n_m_tiles * n_n_tiles
+ n_m_groups = (n_m_tiles + MS - 1) // MS
+ total_tiles = n_m_groups * n_n_tiles
if pid_tile >= total_tiles:
return
- tile_m = pid_tile // n_n_tiles
+ m_group = pid_tile // n_n_tiles
tile_n = pid_tile % n_n_tiles
- m_off = tile_m * BM
n_off = tile_n * BN
-
- offs_m = m_off + tl.arange(0, BM)
offs_n = n_off + tl.arange(0, BN)
- mask_m = offs_m < m_size
-
- gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
- up_acc = tl.zeros((BM, BN), dtype=tl.float32)
-
- for k_start in range(0, H, BK):
- offs_k = k_start + tl.arange(0, BK)
-
- x_ptrs = x_ptr + (m_start + offs_m[:, None]) * stride_xm + offs_k[None, :] * stride_xk
- x_tile = tl.load(x_ptrs, mask=mask_m[:, None], other=0.0)
-
- wg_ptrs = W_gate_ptr + expert_id * stride_wge + offs_k[:, None] * stride_wgk + offs_n[None, :] * stride_wgn
- wg_tile = tl.load(wg_ptrs)
-
- wu_ptrs = W_up_ptr + expert_id * stride_wue + offs_k[:, None] * stride_wuk + offs_n[None, :] * stride_wun
- wu_tile = tl.load(wu_ptrs)
-
- gate_acc += tl.dot(x_tile, wg_tile)
- up_acc += tl.dot(x_tile, wu_tile)
-
- gate_silu = gate_acc * tl.sigmoid(gate_acc)
- result = gate_silu * up_acc
-
- out_ptrs = out_ptr + (m_start + offs_m[:, None]) * stride_om + offs_n[None, :] * stride_on
- tl.store(out_ptrs, result.to(tl.bfloat16), mask=mask_m[:, None])
+
+ if MS == 1:
+ tile_m = m_group
+ m_off = tile_m * BM
+ offs_m = m_off
... (3972 more chars){
"content": "<9030 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.