"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
2D grid: (sum_e cdiv(n_e, BM), cdiv(I, BN)) with per-block expert lookup
via binary search on cumulative tile counts.
Block sizes tuned for SM120 shared memory budget (~98KB usable).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _grouped_gemm_swiglu_kernel(
    X_ptr,  # (T_perm, H) bf16, row-major: permuted hidden states
    W_gate_ptr,  # (E, H, I) bf16, row-major: gate-projection weights
    W_up_ptr,  # (E, H, I) bf16, row-major: up-projection weights
    Out_ptr,  # (T_perm, I) bf16, row-major: silu(x @ Wg) * (x @ Wu)
    Offsets_ptr,  # (E+1,) int32: prefix sums of token counts per expert
    CumMTiles_ptr,  # (E+1,) int32: prefix sums of M-tile counts per expert
    H,  # hidden dim (GEMM K dimension)
    I_val,  # intermediate dim (GEMM N dimension)
    BM: tl.constexpr,  # M-tile size
    BN: tl.constexpr,  # N-tile size
    BK: tl.constexpr,  # K-tile size
    NUM_SMS: tl.constexpr,  # NOTE(review): unused in this kernel body
):
    # One program computes one (BM x BN) output tile of one expert's GEMM.
    # Grid axis 0 enumerates M-tiles across all experts; axis 1 tiles I.
    pid = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 1D -> (pid_m, ...) mapping
    pid_m = pid
    # Find expert: largest e such that CumMTiles[e] <= pid_m
    # E is at most 128; binary search
    # NOTE(review): hi=127 means the search may probe CumMTiles_ptr at
    # indices up to 127 even when E < 128, i.e. past the (E+1)-entry table.
    # This is only safe if the launcher allocates the table with >= 128
    # entries (padded with values > any pid_m) -- confirm on the host side.
    lo = 0
    hi = 127
    while lo < hi:
        mid = (lo + hi + 1) >> 1
        if tl.load(CumMTiles_ptr + mid).to(tl.int32) <= pid_m:
            lo = mid
        else:
            hi = mid - 1
    expert = lo
    # Intra-expert M-tile index
    m_tile_idx = pid_m - tl.load(CumMTiles_ptr + expert).to(tl.int32)
    # Expert row range (widened to int64 for pointer arithmetic)
    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
    n_e = row_end - row_start
    # Row offsets within this expert's token slice
    m_offs = m_tile_idx * BM + tl.arange(0, BM)
    m_valid = m_offs < n_e  # tail tile: mask rows past the expert's last token
    row_g = row_start + m_offs  # global row index into X / Out
    # Column offsets
    n_offs = pid_n * BN + tl.arange(0, BN)
    n_valid = n_offs < I_val
    # K offsets
    k_offs = tl.arange(0, BK)
    # fp32 accumulators for both GEMMs; the A tile is loaded once and
    # shared between the gate and up dot products.
    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
    w_base = expert * H * I_val  # flat offset of this expert's (H, I) slab
    for k_start in range(0, H, BK):
        k = k_start + k_offs
        k_mask = k < H
        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
        wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
        wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
        acc_g += tl.dot(a, wg)
        acc_u += tl.dot(a, wu)
    # Fused SwiGLU epilogue in fp32: silu(g) * u = g * sigmoid(g) * u
    out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
def _launch(X, W_gate, W_up, out, expert_offsets):
    """Launch the fused grouped-GEMM + SwiGLU kernel over all experts.

    Args:
        X: (T_perm, H) bf16 hidden states, permuted to expert order.
        W_gate, W_up: (E, H, I) bf16 per-expert weights.
        out: (T_perm, I) bf16 output buffer, written in place.
        expert_offsets: (E+1,) int32 prefix sums of token counts per expert.
    """
    H = X.size(1)
    I_val = out.size(1)
    E = expert_offsets.size(0) - 1
    device = X.device
    # Shared memory: (BM*BK + 2*BK*BN)*2 bytes, num_stages=1.
    # Maximize BM*BN*BK under ~98KB (SM120); all candidates have the same
    # FLOPs/tile (2097152), so the choice is purely aspect ratio:
    #   256x128x64:  65536 B
    #   128x128x128: 98304 B
    if H >= 4096 and I_val <= 4096:
        BM, BN, BK = 256, 128, 64
    else:
        BM, BN, BK = 128, 128, 128
    counts = expert_offsets[1:] - expert_offsets[:-1]
    m_tiles = (counts + BM - 1) // BM  # (E,) M-tiles per expert
    cum = torch.cumsum(m_tiles, dim=0).to(torch.int32)
    total_m_tiles = int(cum[-1])
    # BUGFIX: the kernel's binary search starts with hi=127 and may probe
    # CumMTiles at any index up to 127 regardless of E. For E < 128 the
    # natural (E+1)-entry table would be read out of bounds. Pad the table
    # to >= 128 entries with an INT32_MAX sentinel: sentinel > pid_m always,
    # so padded slots can never be selected as the expert.
    table_len = max(E + 1, 128)
    cum_m_tiles = torch.full(
        (table_len,), torch.iinfo(torch.int32).max,
        dtype=torch.int32, device=device,
    )
    cum_m_tiles[0] = 0
    cum_m_tiles[1:E + 1] = cum
    n_tiles = triton.cdiv(I_val, BN)
    grid = (total_m_tiles, n_tiles)
    _grouped_gemm_swiglu_kernel[grid](
        X, W_gate, W_up, out, expert_offsets, cum_m_tiles,
        H, I_val,
        BM=BM, BN=BN, BK=BK,
        NUM_SMS=128,
        num_stages=1,
        num_warps=8,
    )
class Model(nn.Module):
    """MoE FFN up-projection with fused SwiGLU, backed by a Triton kernel.

    Call signature:
        hidden_states: (T_perm, H) bf16, already permuted to expert order
        expert_offsets: (E+1,) int32 prefix sums of tokens per expert
    Returns (T_perm, I) bf16: silu(x @ W_gate[e]) * (x @ W_up[e]) per expert.
    """

    def __init__(self, T_total, H, I, E, K):  # noqa: E741
        super().__init__()
        self.I = I
        self.H = H
        self.E = E
        # Per-expert gate/up projection weights, bf16, init N(0, 0.02^2).
        gate = torch.empty(E, H, I, dtype=torch.bfloat16)
        up = torch.empty(E, H, I, dtype=torch.bfloat16)
        self.W_gate = nn.Parameter(gate)
        self.W_up = nn.Parameter(up)
        nn.init.normal_(self.W_gate, std=0.02)
        nn.init.normal_(self.W_up, std=0.02)

    def forward(self, hidden_states, expert_offsets):
        n_rows = hidden_states.size(0)
        result = torch.empty(
            n_rows, self.I, dtype=torch.bfloat16, device=hidden_states.device
        )
        # Kernel writes the fused result into `result` in place.
        _launch(hidden_states, self.W_gate, self.W_up, result, expert_offsets)
        return result
# Module-level shape shims (rewritten per shape by the harness).
T_total = 32768
H = 4096
I = 1536  # noqa: E741
E = 128
K = 8


def get_inputs():
    """Build one (hidden_states, expert_offsets) pair for the default shape."""
    from reference import _build_routing

    permuted_rows = T_total * K
    hidden = torch.randn(permuted_rows, H, dtype=torch.bfloat16) * 0.1
    offsets = _build_routing(T_total, E, K)
    return [hidden, offsets]


def get_init_inputs():
    """Constructor arguments for Model, in positional order."""
    return [T_total, H, I, E, K]
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=32.879 gbps=246.163 ms=25.081
shape=0 variant=compiled tflops=32.402 gbps=242.591 ms=25.450
shape=0 variant=solution tflops=0.488 gbps=3.653 ms=1690.150
shape=0 solution_peak_fraction=0.0024
shape=1 variant=eager tflops=13.959 gbps=259.002 ms=2.462
shape=1 variant=compiled tflops=14.003 gbps=259.826 ms=2.454
shape=1 variant=solution tflops=43.275 gbps=802.956 ms=0.794
shape=1 solution_peak_fraction=0.2164
shape=2 variant=eager tflops=34.766 gbps=237.657 ms=15.813
shape=2 variant=compiled tflops=34.578 gbps=236.376 ms=15.899
shape=2 variant=solution tflops=24.664 gbps=168.604 ms=22.289
shape=2 solution_peak_fraction=0.1233
peak_fraction: 0.0402
RESULT: LOW
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-plus",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0402,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 1928295, "output_tokens": 27957, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 21104, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"pattern": "**/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}No files found{
"prompt": "<545 chars \u2014 see below>",
"description": "Explore existing solutions",
"subagent_type": "general"
}Search the KernelBench-Hard codebase for:
1. All existing solution.py files that use Triton - find 2-3 examples and read them
2. Check if Triton is installed and what version (look in pyproject.toml or constraints files)
3. Check if CUTTLE is available and what version
4. Look at any existing grouped GEMM implementations in the repo
5. Look at src/eval/correctness.py to understand the check_correctness function
Return: Triton version, CUTLASS availability, code patterns from 2-3 triton solutions, and the check_correctness tolerance logic.task_id: ses_2266ee687ffe4ve73pZR7wQNzM (for resuming to continue this task if needed)
<task_result>
**Results:**
**1. No existing solution.py files** — none have been generated yet (agents haven't produced them). There are no import triton/`from triton` statements anywhere in the repo.
**2. Triton version:** 3.6.0 (installed and working with `triton.language` module)
**3. CUTLASS:** NOT available — not installed at `/usr/local/cuda/include/cutlass` or anywhere under the default CUDA path.
**4. Grouped GEMM reference** (`problems/06_sonic_moe_swiglu/reference.py`):
- `reference.py:61-70` — naive Python loop: `for e in range(E)`, slices `hidden_states[start:end]`, does `x_e @ W_gate[e]` and `x_e @ W_up[e]`, then `F.silu(gate) * up`.
- No existing Triton grouped GEMM implementations exist in the repo.
**5. check_correctness tolerances** (`src/eval/correctness.py`):
- fp32: atol=1e-4, rtol=1e-4
- fp16/bf16: atol=1e-2, rtol=1e-2
- fp8 (e4m3/e5m2): atol=0.1, rtol=0.1
- int8/32/64: exact match (atol=0, rtol=0)
- Casts both tensors to fp32 before comparison (line 80-81)
</task_result>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml</path>
<type>file</type>
<content>
1: [project]
2: name = "kernelbench-hard"
3: version = "0.1.0"
4: description = "Surgical GPU kernel benchmark. Small deck, frontier agents, roofline metric, link-don't-spoil problem briefs."
5: requires-python = ">=3.11"
6: dependencies = [
7: "torch>=2.11",
8: "numpy>=1.26",
9: "pyyaml>=6.0",
10: "rich>=13.7",
11: "matplotlib>=3.9",
12: "pytest>=8.0",
13: "ruff>=0.6",
14: "einops>=0.8.2",
15: "requests<2.33",
16: ]
17:
18: [build-system]
19: requires = ["hatchling"]
20: build-backend = "hatchling.build"
21:
22: [tool.hatch.build.targets.wheel]
23: packages = ["src"]
24:
25: [tool.ruff]
26: line-length = 100
27: target-version = "py311"
28:
29: [tool.ruff.lint]
30: select = ["E", "F", "W", "I", "B", "UP"]
31: ignore = ["E501"]
(End of file - total 31 lines)
</content>{
"pattern": "**/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}No files found{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard"
}<path>/home/infatoshi/cuda/KernelBench-Hard</path>
<type>directory</type>
<entries>
.cursorrules
.git/
.gitignore
.pytest_cache/
.python-version
.ruff_cache/
.venv/
AGENTS.md
CLAUDE.md
DEVLOG.md
docs/
LICENSE
outputs/
problems/
pyproject.toml
README.md
scripts/
SPEC.md
src/
tests/
uv.lock
(21 entries)
</entries>{
"tool": "",
"error": "Model tried to call unavailable tool ''. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill."
}The arguments provided to the tool are invalid: Model tried to call unavailable tool ''. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"command": "nvidia-smi --query-gpu=name,count,driver_version --format=csv",
"description": "Check GPU info"
}name, count, driver_version
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 1, 595.58.03
{
"command": "uv run python -c \"import triton; print(triton.__version__); print(triton.runtime.driver.active.get_current_target())\"",
"description": "Check Triton version and target"
}3.6.0
GPUTarget(backend='cuda', arch=120, warp_size=32)
{
"command": "uv run python -c \"import triton; from triton.language._experimental import *; print('Triton works')\"",
"description": "Verify Triton basic imports"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'triton.language._experimental'
{
"command": "<191 chars \u2014 see below>",
"description": "Test basic BF16 matmul on device"
}uv run python -c "
import torch
x = torch.randn(4, 4, device='cuda', dtype=torch.bfloat16)
y = torch.randn(4, 4, device='cuda', dtype=torch.bfloat16)
print(x @ y)
print('BF16 matmul works')
"tensor([[-0.7734, 2.4375, 0.5195, -2.5156],
[ 1.0625, 2.9688, -0.8477, 1.4453],
[ 1.1797, 1.9219, -0.4219, 1.2891],
[ 1.2500, -0.2461, 1.2656, -1.0000]], device='cuda:0',
dtype=torch.bfloat16)
BF16 matmul works
{
"todos": [
{
"content": "Implement basic Triton grouped GEMM + SwiGLU kernel with Python loop over experts",
"status": "in_progress",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark to get timing numbers",
"status": "pending",
"priority": "medium"
}
]
}content: 'Implement basic Triton grouped GEMM + SwiGLU kernel with Python loop over experts'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Run benchmark to get timing numbers'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,190 @@
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
+
+Per-expert dense GEMM with SwiGLU fused in the epilogue. Loops over experts
+in Python (each expert is a dense tile); the Triton kernel computes both
+x @ W_gate and x @ W_up simultaneously and applies silu(gate) * up.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel: fused dual GEMM + SwiGLU
+# ---------------------------------------------------------------------------
+# C[m,n] = silu(A[m,k] @ W_gate[k,n]) * (A[m,k] @ W_up[k,n])
+#
+# A is (M, K), W_gate is (K, N), W_up is (K, N), out is (M, N).
+# ---------------------------------------------------------------------------
+@triton.jit
+def _grouped_gemm_swiglu_kernel(
+ A_ptr, # base ptr to (n_e, H)
+ W_gate_ptr, # base ptr to (H, I) for this expert
+ W_up_ptr, # base ptr to (H, I) for this expert
+ Out_ptr, # base ptr to (n_e, I) for this expert
+ M: tl.constexpr,
+ K: tl.constexpr,
+ N: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ m_mask = m_offsets < M
+ n_mask = n_offsets < N
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k in range(0, K, BLOCK_K):
+ k_offsets = k + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
+
+ # Load A block: (BLOCK_M, BLOCK_K), col-major read row-major store
+ # A is row-major with stride K
+ a_ptrs = A_ptr + m_offsets[:, None] * K + k_offsets[None, :]
+ a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ # Load W_gate block: (BLOCK_K, BLOCK_N), row-major stride N
+ wg_ptrs = W_gate_ptr + k_offsets[:, None] * N + n_offsets[None, :]
+ wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+ # Load W_up block: (BLOCK_K, BLOCK_N), row-major stride N
+ wu_ptrs = W_up_ptr + k_offsets[:, None] * N + n_offsets[None, :]
+ wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+ acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
+ acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+
+ # SwiGLU: silu(gate) * up
+ # silu(x) = x * sigmoid(x)
+ acc_gate = acc_gate.astype(tl.bfloat16)
+ acc_gate_fp32 = acc_gate.to(tl.float32)
+ sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
+ silu = acc_gate.to(tl.float32) * sigmoid
+ out_fp32 = silu * acc_up
+
+ out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+ out = out_fp32.to(tl.bfloat16)
+ tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
+
+
+def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
+ """Launch fused dual GEMM + SwiGLU for one expert.
+
+ x: (n_tokens, H) bf16
+ w_gate: (H, I) bf16
+ w_up: (H, I) bf16
+ n_tokens: int
+ out: (n_tokens, I) bf16
+ """
+ H = x.shape[1]
+ I = w_gate.shape[1]
+
+ # Heuristic block sizes
+ # SM120 has many SMs, try larger tiles for larger matrices
+ if n_tokens >= 1024 and I >= 1024:
+ BLOCK_M = 256
+ BLOCK_N = 256
+ BLOCK_K = 128
+ elif n_tokens >= 512:
+ BLOCK_M = 128
+ BLOCK_N = 128
+ BLOCK_K = 128
+ elif n_tokens >= 128:
+ BLOCK_M = 128
+ BLOCK_N = 64
+ BLOCK_K = 64
+ else:
+ BLOCK_M = 64
+ BLOCK_N = 64
+ BLOCK_K = 64
+
+ grid = (
+ triton.cdiv(n_tokens, BLOCK_M),
+ triton.cdiv(I, BLOCK_N),
+ )
+
+ _grouped_gemm_swiglu_kernel[grid](
+ x, w_gate, w_up, out,
+ M=n_tokens,
+ K=H,
+ N=I,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ )
+
+
+class Model(nn.Module):
+ """Up-projection of a top-K MoE FFN with fused SwiGLU.
+
+ Inputs at call time:
+ hidden_states: (T_perm, H) bf16, already permuted to expert order
+ expert_offsets: (E+1,) int32, prefix sums of token counts per expert
+
+ Output:
+ gated_up: (T_perm, I) bf16
+ """
+
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, # (T_perm, H) bf16
+ expert_offsets: torch.Tensor, # (E+1,) int32
+ ) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ for e in range(self.E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ n_e = end - start
+ x_e = hidden_states[start:end].contiguous()
+ w_gate_e = self.W_gate[e].contiguous()
+ w_up_e = self.W_up[e].contiguous()
+ out_e = out[start:end]
+ grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
+
+ return out
+
+
+# Module-level shape shims rewritten by check.py / benchmark.py per shape.
+T_total = 32768
+H
... (368 more chars){
"content": "<6061 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 168, in forward
grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 116, in grouped_gemm_swiglu_launch
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/trito
... (1762 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -65,11 +65,9 @@
acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
# SwiGLU: silu(gate) * up
- # silu(x) = x * sigmoid(x)
- acc_gate = acc_gate.astype(tl.bfloat16)
- acc_gate_fp32 = acc_gate.to(tl.float32)
- sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
- silu = acc_gate.to(tl.float32) * sigmoid
+ # silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
+ sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
+ silu = acc_gate * sigmoid
out_fp32 = silu * acc_up
out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<444 chars \u2014 see below>",
"newString": "<379 chars \u2014 see below>"
} # SwiGLU: silu(gate) * up
# silu(x) = x * sigmoid(x)
acc_gate = acc_gate.astype(tl.bfloat16)
acc_gate_fp32 = acc_gate.to(tl.float32)
sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
silu = acc_gate.to(tl.float32) * sigmoid
out_fp32 = silu * acc_up
out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
out = out_fp32.to(tl.bfloat16)
tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :]) # SwiGLU: silu(gate) * up
# silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
silu = acc_gate * sigmoid
out_fp32 = silu * acc_up
out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
out = out_fp32.to(tl.bfloat16)
tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check again"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 166, in forward
grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 114, in grouped_gemm_swiglu_launch
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/trito
... (883 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -12,25 +12,27 @@
import triton.language as tl
-# ---------------------------------------------------------------------------
-# Triton kernel: fused dual GEMM + SwiGLU
-# ---------------------------------------------------------------------------
-# C[m,n] = silu(A[m,k] @ W_gate[k,n]) * (A[m,k] @ W_up[k,n])
-#
-# A is (M, K), W_gate is (K, N), W_up is (K, N), out is (M, N).
-# ---------------------------------------------------------------------------
@triton.jit
def _grouped_gemm_swiglu_kernel(
A_ptr, # base ptr to (n_e, H)
W_gate_ptr, # base ptr to (H, I) for this expert
W_up_ptr, # base ptr to (H, I) for this expert
Out_ptr, # base ptr to (n_e, I) for this expert
- M: tl.constexpr,
- K: tl.constexpr,
- N: tl.constexpr,
+ M,
+ K,
+ N,
+ stride_am,
+ stride_ak,
+ stride_wg_k,
+ stride_wg_n,
+ stride_wu_k,
+ stride_wu_n,
+ stride_om,
+ stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
+ NUM_STAGES: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -44,67 +46,63 @@
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k in range(0, K, BLOCK_K):
- k_offsets = k + tl.arange(0, BLOCK_K)
- k_mask = k_offsets < K
+ for k_start in range(0, K, BLOCK_K * NUM_STAGES):
+ for stage in range(NUM_STAGES):
+ k = k_start + stage * BLOCK_K
+ if k >= K:
+ break
+ k_offsets = k + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
- # Load A block: (BLOCK_M, BLOCK_K), col-major read row-major store
- # A is row-major with stride K
- a_ptrs = A_ptr + m_offsets[:, None] * K + k_offsets[None, :]
- a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+ a_ptrs = A_ptr + m_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak
+ a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
- # Load W_gate block: (BLOCK_K, BLOCK_N), row-major stride N
- wg_ptrs = W_gate_ptr + k_offsets[:, None] * N + n_offsets[None, :]
- wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+ wg_ptrs = W_gate_ptr + k_offsets[:, None] * stride_wg_k + n_offsets[None, :] * stride_wg_n
+ wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
- # Load W_up block: (BLOCK_K, BLOCK_N), row-major stride N
- wu_ptrs = W_up_ptr + k_offsets[:, None] * N + n_offsets[None, :]
- wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+ wu_ptrs = W_up_ptr + k_offsets[:, None] * stride_wu_k + n_offsets[None, :] * stride_wu_n
+ wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
- acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
- acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+ acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
+ acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
# SwiGLU: silu(gate) * up
- # silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
silu = acc_gate * sigmoid
out_fp32 = silu * acc_up
- out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+ out_ptrs = Out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
out = out_fp32.to(tl.bfloat16)
- tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
+ tl.store(out_ptrs, out, mask=m_mask[:, None] & n_mask[None, :])
def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
- """Launch fused dual GEMM + SwiGLU for one expert.
-
- x: (n_tokens, H) bf16
- w_gate: (H, I) bf16
- w_up: (H, I) bf16
- n_tokens: int
- out: (n_tokens, I) bf16
- """
+ """Launch fused dual GEMM + SwiGLU for one expert."""
H = x.shape[1]
I = w_gate.shape[1]
- # Heuristic block sizes
- # SM120 has many SMs, try larger tiles for larger matrices
- if n_tokens >= 1024 and I >= 1024:
- BLOCK_M = 256
- BLOCK_N = 256
- BLOCK_K = 128
- elif n_tokens >= 512:
+ # Constrain block sizes to fit in shared memory on SM120 (~101KB limit)
+ # Each tile at bf16: 2 bytes per element
+ # Shared memory needed per stage: A(BLOCK_M, BLOCK_K) + W_gate(BLOCK_K, BLOCK_N) + W_up(BLOCK_K, BLOCK_N)
+ # For 128x128x64: 2*(128*64 + 128*128 + 128*128) = 2*(8192 + 16384 + 16384) = 81664 bytes
+ # Still large with num_stages>1. Let's use 64x64x64 or 128x64x64.
+ # Actually, Triton doesn't necessarily load all stages at once for non-pipelined kernels.
+ # Let's use moderate sizes.
+ if n_tokens >= 512 and I >= 512:
BLOCK_M = 128
BLOCK_N = 128
- BLOCK_K = 128
- elif n_tokens >= 128:
+ BLOCK_K = 64
+ NUM_STAGES = 2
+ elif n_tokens >= 256 and I >= 256:
BLOCK_M = 128
BLOCK_N = 64
BLOCK_K = 64
+ NUM_STAGES = 2
else:
BLOCK_M = 64
BLOCK_N = 64
BLOCK_K = 64
+ NUM_STAGES = 2
grid = (
triton.cdiv(n_tokens, BLOCK_M),
@@ -116,22 +114,23 @@
M=n_tokens,
K=H,
N=I,
+ stride_am=x.stride(0),
+ stride_ak=x.stride(1),
+ stride_wg_k=w_gate.stride(0),
+ stride_wg_n=w_gate.stride(1),
+ stride_wu_k=w_up.stride(0),
+ stride_wu_n=w_up.stride(1),
+ stride_om=out.stride(0),
+ stride_on=out.stride(1),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
+ NUM_STAGES=NUM_
... (1432 more chars){
"content": "<6019 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 167, in forward
grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_slice)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 112, in grouped_gemm_swiglu_launch
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/t
... (1577 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,9 +1,4 @@
-"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-
-Per-expert dense GEMM with SwiGLU fused in the epilogue. Loops over experts
-in Python (each expert is a dense tile); the Triton kernel computes both
-x @ W_gate and x @ W_up simultaneously and applies silu(gate) * up.
-"""
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection."""
from __future__ import annotations
import torch
@@ -14,158 +9,94 @@
@triton.jit
def _grouped_gemm_swiglu_kernel(
- A_ptr, # base ptr to (n_e, H)
- W_gate_ptr, # base ptr to (H, I) for this expert
- W_up_ptr, # base ptr to (H, I) for this expert
- Out_ptr, # base ptr to (n_e, I) for this expert
+ A_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ Out_ptr,
M,
K,
N,
- stride_am,
- stride_ak,
- stride_wg_k,
- stride_wg_n,
- stride_wu_k,
- stride_wu_n,
- stride_om,
- stride_on,
+ stride_a_row,
+ stride_a_col,
+ stride_w_row,
+ stride_w_col,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
- NUM_STAGES: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
- m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ k_offs = tl.arange(0, BLOCK_K)
- m_mask = m_offsets < M
- n_mask = n_offsets < N
+ a_base = A_ptr + m_offs[:, None] * stride_a_row
+ wg_base = W_gate_ptr + n_offs[None, :] * stride_w_col
+ wu_base = W_up_ptr + n_offs[None, :] * stride_w_col
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K * NUM_STAGES):
- for stage in range(NUM_STAGES):
- k = k_start + stage * BLOCK_K
- if k >= K:
- break
- k_offsets = k + tl.arange(0, BLOCK_K)
- k_mask = k_offsets < K
+ for _ in range(0, K, BLOCK_K):
+ a_ptrs = a_base + k_offs[None, :] * stride_a_col
+ wg_ptrs = wg_base + k_offs[:, None] * stride_w_row
+ wu_ptrs = wu_base + k_offs[:, None] * stride_w_row
- a_ptrs = A_ptr + m_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak
- a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+ a = tl.load(a_ptrs, mask=(m_offs[:, None] < M) & (k_offs[None, :] < K), other=0.0)
+ wg = tl.load(wg_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
+ wu = tl.load(wu_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
- wg_ptrs = W_gate_ptr + k_offsets[:, None] * stride_wg_k + n_offsets[None, :] * stride_wg_n
- wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+ acc_gate += tl.dot(a, wg)
+ acc_up += tl.dot(a, wu)
- wu_ptrs = W_up_ptr + k_offsets[:, None] * stride_wu_k + n_offsets[None, :] * stride_wu_n
- wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+ a_base += BLOCK_K * stride_a_col
+ wg_base += BLOCK_K * stride_w_row
+ wu_base += BLOCK_K * stride_w_row
- acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
- acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+ # SwiGLU
+ silu = acc_gate * tl.sigmoid(acc_gate)
+ out_fp = silu * acc_up
+ out_bf = out_fp.to(tl.bfloat16)
- # SwiGLU: silu(gate) * up
- sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
- silu = acc_gate * sigmoid
- out_fp32 = silu * acc_up
-
- out_ptrs = Out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
- out = out_fp32.to(tl.bfloat16)
- tl.store(out_ptrs, out, mask=m_mask[:, None] & n_mask[None, :])
+ out_ptrs = Out_ptr + m_offs[:, None] * N + n_offs[None, :]
+ tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
-def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
- """Launch fused dual GEMM + SwiGLU for one expert."""
- H = x.shape[1]
- I = w_gate.shape[1]
-
- # Constrain block sizes to fit in shared memory on SM120 (~101KB limit)
- # Each tile at bf16: 2 bytes per element
- # Shared memory needed per stage: A(BLOCK_M, BLOCK_K) + W_gate(BLOCK_K, BLOCK_N) + W_up(BLOCK_K, BLOCK_N)
- # For 128x128x64: 2*(128*64 + 128*128 + 128*128) = 2*(8192 + 16384 + 16384) = 81664 bytes
- # Still large with num_stages>1. Let's use 64x64x64 or 128x64x64.
- # Actually, Triton doesn't necessarily load all stages at once for non-pipelined kernels.
- # Let's use moderate sizes.
- if n_tokens >= 512 and I >= 512:
- BLOCK_M = 128
- BLOCK_N = 128
- BLOCK_K = 64
- NUM_STAGES = 2
- elif n_tokens >= 256 and I >= 256:
- BLOCK_M = 128
- BLOCK_N = 64
- BLOCK_K = 64
- NUM_STAGES = 2
- else:
- BLOCK_M = 64
- BLOCK_N = 64
- BLOCK_K = 64
- NUM_STAGES = 2
-
- grid = (
- triton.cdiv(n_tokens, BLOCK_M),
- triton.cdiv(I, BLOCK_N),
- )
-
+def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
+ grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
_grouped_gemm_swiglu_kernel[grid](
- x, w_gate, w_up, out,
- M=n_tokens,
- K=H,
- N=I,
- stride_am=x.stride(0),
- stride_ak=x.stride(1),
- stride_wg_k=w_gate.stride(0),
- stride_wg_n=w_gate.stride(1),
- stride_wu_k=w_up.stride(0),
- stride_wu_n=w_up.stride(1),
- stride_om=out.stride(0),
- stride_on=out.stride(1),
- BLOCK_M=BLOCK_M,
- BLOC
... (2368 more chars){
"content": "<3695 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 99, in forward
_launch(x_e, self.W_gate[e], self.W_up[e], n_e, out_e)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 68, in _launch
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743,
... (855 more chars){
"command": "<705 chars \u2014 see below>",
"description": "Calculate shared memory usage"
}uv run python -c "
# Check shared memory requirements
bm, bn, bk = 64, 64, 64
# num_stages=1
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'64x64x64 num_stages=1: {total} bytes')
bm, bn, bk = 64, 64, 128
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'64x64x128 num_stages=1: {total} bytes')
bm, bn, bk = 128, 64, 128
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'128x64x128 num_stages=1: {total} bytes')
# num_stages=2
bm, bn, bk = 64, 64, 128
shm = (bm * bk + 2 * bk * bn) * 2 * 2
print(f'64x64x128 num_stages=2: {shm} bytes')
"
64x64x64 num_stages=1: 24576 bytes
64x64x128 num_stages=1: 49152 bytes
128x64x128 num_stages=1: 65536 bytes
64x64x128 num_stages=2: 98304 bytes
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -63,14 +63,23 @@
tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
-def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
- grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
+def _launch(x, w_gate_e, w_up_e, n_e, out_e):
+ I = out_e.shape[1]
+ # 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
+ if n_e >= 512 and I >= 512:
+ bm, bn, bk = 64, 128, 128
+ elif n_e >= 128:
+ bm, bn, bk = 64, 64, 128
+ else:
+ bm, bn, bk = 32, 32, 64
+ grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
_grouped_gemm_swiglu_kernel[grid](
x, w_gate_e, w_up_e, out_e,
- M=n_e, K=x.shape[1], N=out_e.shape[1],
+ M=n_e, K=x.shape[1], N=I,
stride_a_row=x.stride(0), stride_a_col=x.stride(1),
stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
+ num_stages=1,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<442 chars \u2014 see below>",
"newString": "<684 chars \u2014 see below>"
}def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
_grouped_gemm_swiglu_kernel[grid](
x, w_gate_e, w_up_e, out_e,
M=n_e, K=x.shape[1], N=out_e.shape[1],
stride_a_row=x.stride(0), stride_a_col=x.stride(1),
stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
)def _launch(x, w_gate_e, w_up_e, n_e, out_e):
I = out_e.shape[1]
# 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
if n_e >= 512 and I >= 512:
bm, bn, bk = 64, 128, 128
elif n_e >= 128:
bm, bn, bk = 64, 64, 128
else:
bm, bn, bk = 32, 32, 64
grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
_grouped_gemm_swiglu_kernel[grid](
x, w_gate_e, w_up_e, out_e,
M=n_e, K=x.shape[1], N=I,
stride_a_row=x.stride(0), stride_a_col=x.stride(1),
stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
num_stages=1,
)Edit applied successfully.{
"command": "<313 chars \u2014 see below>",
"description": "Check shared memory for different block sizes"
}# Check shared memory for 64x128x128 num_stages=1
uv run python -c "
bm, bn, bk = 64, 128, 128
shm = (bm * bk + bk * bn + bk * bn) * 2
print(f'64x128x128 num_stages=1: {shm} bytes (limit ~101KB)')
bm, bn, bk = 128, 64, 128
shm = (bm * bk + bk * bn + bk * bn) * 2
print(f'128x64x128 num_stages=1: {shm} bytes')
"
64x128x128 num_stages=1: 81920 bytes (limit ~101KB)
128x64x128 num_stages=1: 65536 bytes
{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}PASS
{
"todos": [
{
"content": "Implement basic Triton grouped GEMM + fused SwiGLU kernel with Python loop over experts",
"status": "completed",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "completed",
"priority": "high"
},
{
"content": "Profile and optimize kernel for performance",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run benchmark to get timing numbers",
"status": "pending",
"priority": "medium"
}
]
}content: 'Implement basic Triton grouped GEMM + fused SwiGLU kernel with Python loop over experts'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize kernel for performance'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark to get timing numbers'
status: 'pending'
priority: 'medium'{
"command": "./scripts/patch_torch.sh 2>&1 | tail -3 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}zsh:1: no such file or directory: ./scripts/patch_torch.sh
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 08:23:18.551000 1027626 .venv/lib/python3.11
... (864 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,4 +1,9 @@
-"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection."""
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
+
+Single kernel handles all experts in one launch to avoid Python loop overhead.
+Grid is (E, cdiv(M_expert, BLOCK_M), cdiv(N, BLOCK_N)) where M_expert varies
+per expert (computed from expert_offsets).
+"""
from __future__ import annotations
import torch
@@ -8,104 +13,146 @@
@triton.jit
-def _grouped_gemm_swiglu_kernel(
- A_ptr,
- W_gate_ptr,
- W_up_ptr,
- Out_ptr,
- M,
- K,
- N,
- stride_a_row,
- stride_a_col,
- stride_w_row,
- stride_w_col,
- BLOCK_M: tl.constexpr,
+def _grouped_gemm_swiglu_single_kernel(
+ X_ptr, # (T_perm, H), row-major
+ W_gate_ptr, # (E, H, I), row-major
+ W_up_ptr, # (E, H, I), row-major
+ Out_ptr, # (T_perm, I), row-major
+ Offsets_ptr, # (E+1,) int32
+ T_perm,
+ H,
+ I,
+ E,
+ stride_x_row,
+ stride_w_H,
+ stride_w_I,
+ stride_o_row,
+ BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
+ # Grid: (E, cdiv(max_tokens, BLOCK_M), cdiv(I, BLOCK_N))
+ pid_e = tl.program_id(0)
+ pid_bm = tl.program_id(1)
+ pid_n = tl.program_id(2)
- m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ # Read expert offset bounds
+ off_start = tl.load(Offsets_ptr + pid_e)
+ off_end = tl.load(Offsets_ptr + pid_e + 1)
+ n_e = off_end - off_start
+
+ # M dimension for this tile
+ m_base = pid_bm * 256 # We fix BLOCK_M = 256 for simplicity
+ m_offs = m_base + tl.arange(0, 256)
+
+ # Global row indices within hidden_states
+ row_global = off_start + m_offs
+ m_valid = m_offs < n_e
+
+ # N (output) dimension
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- k_offs = tl.arange(0, BLOCK_K)
+ n_valid = n_offs < I
- a_base = A_ptr + m_offs[:, None] * stride_a_row
- wg_base = W_gate_ptr + n_offs[None, :] * stride_w_col
- wu_base = W_up_ptr + n_offs[None, :] * stride_w_col
+ # K (hidden) dimension
+ k_offs = tl.arange(0, BLOCK_H)
+ k_valid = k_offs < H
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ # Pointers
+ # X[row, k] = X_ptr + row * stride_x_row + k
+ # W_gate[e, k, n] = W_gate_ptr + e*stride_w_H*H + k*stride_w_I + n
- for _ in range(0, K, BLOCK_K):
- a_ptrs = a_base + k_offs[None, :] * stride_a_col
- wg_ptrs = wg_base + k_offs[:, None] * stride_w_row
- wu_ptrs = wu_base + k_offs[:, None] * stride_w_row
+ acc_gate = tl.zeros((256, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((256, BLOCK_N), dtype=tl.float32)
- a = tl.load(a_ptrs, mask=(m_offs[:, None] < M) & (k_offs[None, :] < K), other=0.0)
- wg = tl.load(wg_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
- wu = tl.load(wu_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
+ # Load base W pointers (add expert offset once)
+ w_e_start = pid_e * stride_w_H * H # offset for this expert's weights
+ wg_base = W_gate_ptr + w_e_start
+ wu_base = W_up_ptr + w_e_start
- acc_gate += tl.dot(a, wg)
- acc_up += tl.dot(a, wu)
+ for h_start in range(0, H, BLOCK_H):
+ h_offset = h_start + k_offs
+ h_mask = h_offset < H
- a_base += BLOCK_K * stride_a_col
- wg_base += BLOCK_K * stride_w_row
- wu_base += BLOCK_K * stride_w_row
+ # Load X block: (256, BLOCK_H)
+ x_ptrs = X_ptr + row_global[:, None] * stride_x_row + h_offset[None, :]
+ x = tl.load(x_ptrs, mask=m_valid[:, None] & h_mask[None, :], other=0.0)
+
+ # Load W_gate block: (BLOCK_H, BLOCK_N) for expert e
+ wg_ptrs = wg_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
+ wg = tl.load(wg_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+
+ # Load W_up block: (BLOCK_H, BLOCK_N) for expert e
+ wu_ptrs = wu_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
+ wu = tl.load(wu_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+
+ acc_gate += tl.dot(x, wg)
+ acc_up += tl.dot(x, wu)
# SwiGLU
- silu = acc_gate * tl.sigmoid(acc_gate)
- out_fp = silu * acc_up
- out_bf = out_fp.to(tl.bfloat16)
+ gate_fp32 = acc_gate
+ silu = gate_fp32 * tl.sigmoid(gate_fp32)
+ out_fp32 = silu * acc_up
- out_ptrs = Out_ptr + m_offs[:, None] * N + n_offs[None, :]
- tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
-
-
-def _launch(x, w_gate_e, w_up_e, n_e, out_e):
- I = out_e.shape[1]
- # 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
- if n_e >= 512 and I >= 512:
- bm, bn, bk = 64, 128, 128
- elif n_e >= 128:
- bm, bn, bk = 64, 64, 128
- else:
- bm, bn, bk = 32, 32, 64
- grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
- _grouped_gemm_swiglu_kernel[grid](
- x, w_gate_e, w_up_e, out_e,
- M=n_e, K=x.shape[1], N=I,
- stride_a_row=x.stride(0), stride_a_col=x.stride(1),
- stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
- BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
- num_stages=1,
- )
+ # Store
+ out_ptrs = Out_ptr + row_global[:, None] * stride_o_row + n_offs[None, :]
+ out = out_fp32.to(tl.bfloat16)
+ tl.store(out_ptrs, out, mask=m_valid[:, None] & n_valid[None, :])
class Model(nn.Module):
def __init__(self, T_total, H, I, E, K): # noqa: E741
super().__init__()
+ self.T_total = T_total
+ self.H = H
self.I = I
+ self.E = E
+
... (2141 more chars){
"content": "<5293 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 144, in forward
_grouped_gemm_swiglu_single_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
... (695 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,8 +1,9 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-Single kernel handles all experts in one launch to avoid Python loop overhead.
-Grid is (E, cdiv(M_expert, BLOCK_M), cdiv(N, BLOCK_N)) where M_expert varies
-per expert (computed from expert_offsets).
+Single kernel handles all experts in one launch. The 3D grid is
+(E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
+(E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
+We keep tiles small enough for shared memory.
"""
from __future__ import annotations
@@ -13,100 +14,131 @@
@triton.jit
-def _grouped_gemm_swiglu_single_kernel(
- X_ptr, # (T_perm, H), row-major
- W_gate_ptr, # (E, H, I), row-major
- W_up_ptr, # (E, H, I), row-major
- Out_ptr, # (T_perm, I), row-major
- Offsets_ptr, # (E+1,) int32
- T_perm,
+def _grouped_gemm_swiglu_kernel(
+ X_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ Out_ptr,
+ Offsets_ptr,
H,
I,
- E,
stride_x_row,
- stride_w_H,
- stride_w_I,
+ stride_w_row,
+ stride_w_col,
stride_o_row,
- BLOCK_H: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
+ BM: tl.constexpr,
+ BN: tl.constexpr,
+ BK: tl.constexpr,
):
- # Grid: (E, cdiv(max_tokens, BLOCK_M), cdiv(I, BLOCK_N))
+ pid_mn = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ # We encode expert id in the upper bits of pid_mn for small E,
+ # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
+ # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
+ # We need to know which expert each tile belongs to and the intra-expert row.
+ pass
+
+
+# Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
+# The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
+# Shared memory: BM * BK + 2 * BK * BN bf16 values.
+# For BM=128, BN=128, BK=64: 128*64 + 2*64*128 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
+# For BM=64, BN=64, BK=128: 64*128 + 2*128*64 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
+
+@triton.jit
+def _gemma_gemm_kernel(
+ X_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ Out_ptr,
+ Offsets_ptr,
+ H,
+ I,
+ stride_x_row,
+ stride_w_row,
+ stride_w_col,
+ stride_o_row,
+ BM: tl.constexpr,
+ BN: tl.constexpr,
+ BK: tl.constexpr,
+):
pid_e = tl.program_id(0)
pid_bm = tl.program_id(1)
pid_n = tl.program_id(2)
- # Read expert offset bounds
- off_start = tl.load(Offsets_ptr + pid_e)
- off_end = tl.load(Offsets_ptr + pid_e + 1)
- n_e = off_end - off_start
+ # Expert bounds
+ os = tl.load(Offsets_ptr + pid_e).to(tl.int64)
+ oe = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+ n_e = oe - os
- # M dimension for this tile
- m_base = pid_bm * 256 # We fix BLOCK_M = 256 for simplicity
- m_offs = m_base + tl.arange(0, 256)
+ # Local M offset
+ m_local = pid_bm * BM + tl.arange(0, BM)
+ m_valid = m_local < n_e
- # Global row indices within hidden_states
- row_global = off_start + m_offs
- m_valid = m_offs < n_e
+ # Global row
+ row_g = os + m_local
- # N (output) dimension
- n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- n_valid = n_offs < I
+ # N offsets
+ n_local = pid_n * BN + tl.arange(0, BN)
+ n_valid = n_local < I
- # K (hidden) dimension
- k_offs = tl.arange(0, BLOCK_H)
- k_valid = k_offs < H
+ # K offsets
+ k_local = tl.arange(0, BK)
- # Pointers
- # X[row, k] = X_ptr + row * stride_x_row + k
- # W_gate[e, k, n] = W_gate_ptr + e*stride_w_H*H + k*stride_w_I + n
+ # Accumulators
+ acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+ acc_u = tl.zeros((BM, BN), dtype=tl.float32)
- acc_gate = tl.zeros((256, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((256, BLOCK_N), dtype=tl.float32)
+ # Weight base for this expert
+ wg_base = W_gate_ptr + pid_e * stride_w_row * H
+ wu_base = W_up_ptr + pid_e * stride_w_row * H
- # Load base W pointers (add expert offset once)
- w_e_start = pid_e * stride_w_H * H # offset for this expert's weights
- wg_base = W_gate_ptr + w_e_start
- wu_base = W_up_ptr + w_e_start
+ for k_start in range(0, H, BK):
+ k_offs = k_start + k_local
+ k_mask = k_offs < H
- for h_start in range(0, H, BLOCK_H):
- h_offset = h_start + k_offs
- h_mask = h_offset < H
+ # A[m, k]
+ a_ptrs = X_ptr + row_g[:, None] * stride_x_row + k_offs[None, :]
+ a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
- # Load X block: (256, BLOCK_H)
- x_ptrs = X_ptr + row_global[:, None] * stride_x_row + h_offset[None, :]
- x = tl.load(x_ptrs, mask=m_valid[:, None] & h_mask[None, :], other=0.0)
+ # W_gate[e][k, n]
+ wg_ptrs = wg_base + k_offs[:, None] * stride_w_row + n_local[None, :] * stride_w_col
+ wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- # Load W_gate block: (BLOCK_H, BLOCK_N) for expert e
- wg_ptrs = wg_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
- wg = tl.load(wg_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+ # W_up[e][k, n]
+ wu_ptrs = wu_base + k_offs[:, None] * stride_w_row + n_local[None, :] * stride_w_col
+ wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- # Load W_up block: (BLOCK_H, BLOCK_N) for expert e
- wu_ptrs = wu_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
- wu = tl.load(wu_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
-
- acc_gate += tl.dot(x, wg)
- acc_up += tl.dot(x, wu)
+ acc_g += tl.dot(a, wg)
+ acc_u += tl.dot(a, wu)
#
... (3445 more chars){
"content": "<5405 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}(no output)
<bash_metadata>
bash tool terminated command after exceeding timeout 300000 ms. If this command is expected to take longer and is not waiting for interactive input, retry with a larger timeout value in milliseconds.
</bash_metadata>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
2:
3: Single kernel handles all experts in one launch. The 3D grid is
4: (E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
5: (E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
6: We keep tiles small enough for shared memory.
7: """
8: from __future__ import annotations
9:
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14:
15:
16: @triton.jit
17: def _grouped_gemm_swiglu_kernel(
18: X_ptr,
19: W_gate_ptr,
20: W_up_ptr,
21: Out_ptr,
22: Offsets_ptr,
23: H,
24: I,
25: stride_x_row,
26: stride_w_row,
27: stride_w_col,
28: stride_o_row,
29: BM: tl.constexpr,
30: BN: tl.constexpr,
31: BK: tl.constexpr,
32: ):
33: pid_mn = tl.program_id(0)
34: pid_n = tl.program_id(1)
35:
36: # We encode expert id in the upper bits of pid_mn for small E,
37: # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
38: # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
39: # We need to know which expert each tile belongs to and the intra-expert row.
40: pass
41:
42:
43: # Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
44: # The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
45: # Sha
... (4865 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,9 +1,7 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-Single kernel handles all experts in one launch. The 3D grid is
-(E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
-(E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
-We keep tiles small enough for shared memory.
+Flattened grid approach: grid dimension 0 covers all row-blocks across all experts.
+A per-block expert-lookup uses binary search on offsets (or a cumulative tile-count array).
"""
from __future__ import annotations
@@ -11,6 +9,23 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+# We use a simple approach: grid is (total_m_blocks, cdiv(I, BN), E)
+# Actually, let's keep it as (E, cdiv(n_tokens_for_expert_e, BM), cdiv(I, BN))
+# but compute the grid size from actual per-expert tile counts.
+
+# The issue with 3D grid: the grid size in dimension 1 is not uniform across experts.
+# Standard solution: flatten dimension 0+1 into a single dimension.
+#
+# For each expert e, num_tiles_e = cdiv(n_e, BM)
+# cumulative_tiles[e] = sum(num_tiles_e' for e' < e)
+# total_tiles = cumulative_tiles[E]
+#
+# Grid: (total_tiles, cdiv(I, BN))
+# Given pid_0 in range [cumulative_tiles[e], cumulative_tiles[e+1]),
+# expert_id = e
+# pid_bm = pid_0 - cumulative_tiles[e]
@triton.jit
@@ -20,94 +35,91 @@
W_up_ptr,
Out_ptr,
Offsets_ptr,
+ TileOffs_ptr, # (E+1,) int32, cumulative start tile indices
H,
I,
+ BM,
+ BN,
stride_x_row,
stride_w_row,
stride_w_col,
stride_o_row,
- BM: tl.constexpr,
- BN: tl.constexpr,
- BK: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
):
- pid_mn = tl.program_id(0)
+ pid_tile = tl.program_id(0)
pid_n = tl.program_id(1)
- # We encode expert id in the upper bits of pid_mn for small E,
- # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
- # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
- # We need to know which expert each tile belongs to and the intra-expert row.
- pass
-
-
-# Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
-# The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
-# Shared memory: BM * BK + 2 * BK * BN bf16 values.
-# For BM=128, BN=128, BK=64: 128*64 + 2*64*128 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
-# For BM=64, BN=64, BK=128: 64*128 + 2*128*64 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
-
-@triton.jit
-def _gemma_gemm_kernel(
- X_ptr,
- W_gate_ptr,
- W_up_ptr,
- Out_ptr,
- Offsets_ptr,
- H,
- I,
- stride_x_row,
- stride_w_row,
- stride_w_col,
- stride_o_row,
- BM: tl.constexpr,
- BN: tl.constexpr,
- BK: tl.constexpr,
-):
- pid_e = tl.program_id(0)
- pid_bm = tl.program_id(1)
- pid_n = tl.program_id(2)
-
- # Expert bounds
- os = tl.load(Offsets_ptr + pid_e).to(tl.int64)
- oe = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
- n_e = oe - os
-
- # Local M offset
- m_local = pid_bm * BM + tl.arange(0, BM)
+ # Find which expert this tile belongs to via binary search on TileOffs.
+ # TileOffs has E+1 entries; we need the largest e such that TileOffs[e] <= pid_tile.
+ # Since E is small (64 or 128), linear search is fine.
+ lo = 0
+ hi = Offsets_ptr.numel() - 2 # E-1, but we don't know E at compile time
+ # We'll just do a simple binary search
+ expert = 0
+
+ # Binary search: find largest lo such that TileOffs[lo] <= pid_tile
+ left = 0
+ right = 128 # upper bound on E
+ while left < right:
+ mid = (left + right + 1) // 2
+ val = tl.load(TileOffs_ptr + mid).to(tl.int32)
+ if val <= pid_tile:
+ left = mid
+ else:
+ right = mid - 1
+ expert = left
+
+ tile_base = tl.load(TileOffs_ptr + expert).to(tl.int32)
+ pid_bm = pid_tile - tile_base
+
+ # Expert row bounds
+ row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+ row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+ n_e = row_end - row_start
+
+ # Local row offsets within expert
+ m_local = pid_bm * BLOCK_M + tl.arange(0, BLOCK_M)
m_valid = m_local < n_e
- # Global row
- row_g = os + m_local
-
- # N offsets
- n_local = pid_n * BN + tl.arange(0, BN)
+ # Global rows
+ row_g = row_start + m_local
+
+ # Column offsets
+ n_local = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_valid = n_local < I
# K offsets
- k_local = tl.arange(0, BK)
+ k_local = tl.arange(0, BLOCK_K)
# Accumulators
- acc_g = tl.zeros((BM, BN), dtype=tl.float32)
- acc_u = tl.zeros((BM, BN), dtype=tl.float32)
-
- # Weight base for this expert
- wg_base = W_gate_ptr + pid_e * stride_w_row * H
- wu_base = W_up_ptr + pid_e * stride_w_row * H
-
- for k_start in range(0, H, BK):
+ acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # Weight base for this expert: W is (E, H, I), row-major
+ # stride_w_row = W.stride(1) = I, stride_w_col = W.stride(2) = 1
+ # W[e, k, n] = W_ptr + e*H*I + k*I + n
+ wg_base = W_gate_ptr + expert * H * I.to(tl.int64)
+ wu_base = W_up_ptr + expert * H * I.to(tl.int64)
+
+ for k_start in range(0, H, BLOCK_K):
k_offs = k_start + k_local
k_mask = k_offs < H
- # A[m, k]
- a_ptrs = X_ptr + row_g[:, None] * stride_x_row + k_offs[None, :]
+ # A[m, k]: (BLOCK_M, BLOCK_K)
+ a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :] * stride_x_row // BLOCK_M
+ # Actually, stride_x_row might not be H if the tensor isn't contiguous...
+ # Let me just use stride_x_row directly
+
... (4233 more chars){
"content": "<7051 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-Flattened grid approach: grid dimension 0 covers all row-blocks across all experts.
-A per-block expert-lookup uses binary search on offsets (or a cumulative tile-count array).
+Flattened 2D grid: (cdiv(T_perm, BM), cdiv(I, BN)) with expert lookup
+via binary search on offsets.
"""
from __future__ import annotations
@@ -11,23 +11,6 @@
import triton.language as tl
-# We use a simple approach: grid is (total_m_blocks, cdiv(I, BN), E)
-# Actually, let's keep it as (E, cdiv(n_tokens_for_expert_e, BM), cdiv(I, BN))
-# but compute the grid size from actual per-expert tile counts.
-
-# The issue with 3D grid: the grid size in dimension 1 is not uniform across experts.
-# Standard solution: flatten dimension 0+1 into a single dimension.
-#
-# For each expert e, num_tiles_e = cdiv(n_e, BM)
-# cumulative_tiles[e] = sum(num_tiles_e' for e' < e)
-# total_tiles = cumulative_tiles[E]
-#
-# Grid: (total_tiles, cdiv(I, BN))
-# Given pid_0 in range [cumulative_tiles[e], cumulative_tiles[e+1]),
-# expert_id = e
-# pid_bm = pid_0 - cumulative_tiles[e]
-
-
@triton.jit
def _grouped_gemm_swiglu_kernel(
X_ptr,
@@ -35,132 +18,97 @@
W_up_ptr,
Out_ptr,
Offsets_ptr,
- TileOffs_ptr, # (E+1,) int32, cumulative start tile indices
+ T_perm,
H,
- I,
- BM,
- BN,
- stride_x_row,
- stride_w_row,
- stride_w_col,
- stride_o_row,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
+ I_val,
+ E,
+ BM: tl.constexpr,
+ BN: tl.constexpr,
+ BK: tl.constexpr,
):
- pid_tile = tl.program_id(0)
+ pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
- # Find which expert this tile belongs to via binary search on TileOffs.
- # TileOffs has E+1 entries; we need the largest e such that TileOffs[e] <= pid_tile.
- # Since E is small (64 or 128), linear search is fine.
+ # Find expert for the first row in this block
+ row_0 = pid_m * BM
+ # Binary search for expert: find smallest e such that Offsets[e+1] > row_0
lo = 0
- hi = Offsets_ptr.numel() - 2 # E-1, but we don't know E at compile time
- # We'll just do a simple binary search
- expert = 0
-
- # Binary search: find largest lo such that TileOffs[lo] <= pid_tile
- left = 0
- right = 128 # upper bound on E
- while left < right:
- mid = (left + right + 1) // 2
- val = tl.load(TileOffs_ptr + mid).to(tl.int32)
- if val <= pid_tile:
- left = mid
+ hi = E - 1
+ while lo < hi:
+ mid = (lo + hi) // 2
+ off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
+ if off <= row_0:
+ lo = mid + 1
else:
- right = mid - 1
- expert = left
+ hi = mid
+ expert = lo
- tile_base = tl.load(TileOffs_ptr + expert).to(tl.int32)
- pid_bm = pid_tile - tile_base
-
- # Expert row bounds
row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
- n_e = row_end - row_start
- # Local row offsets within expert
- m_local = pid_bm * BLOCK_M + tl.arange(0, BLOCK_M)
- m_valid = m_local < n_e
+ # M offset (intra-expert rows)
+ m_local = pid_m * BM + tl.arange(0, BM) - row_start
+ m_valid = (pid_m * BM + tl.arange(0, BM)) < row_end
# Global rows
- row_g = row_start + m_local
+ row_g = pid_m * BM + tl.arange(0, BM)
+ row_valid = row_g < T_perm & m_valid
- # Column offsets
- n_local = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- n_valid = n_local < I
+ # N offset
+ n_offs = pid_n * BN + tl.arange(0, BN)
+ n_valid = n_offs < I_val
- # K offsets
- k_local = tl.arange(0, BLOCK_K)
+ # K offset
+ k_offs = tl.arange(0, BK)
- # Accumulators
- acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+ acc_u = tl.zeros((BM, BN), dtype=tl.float32)
- # Weight base for this expert: W is (E, H, I), row-major
- # stride_w_row = W.stride(1) = I, stride_w_col = W.stride(2) = 1
- # W[e, k, n] = W_ptr + e*H*I + k*I + n
- wg_base = W_gate_ptr + expert * H * I.to(tl.int64)
- wu_base = W_up_ptr + expert * H * I.to(tl.int64)
+ # Weight base: W is (E, H, I), row major
+ # W[e, k, n] offset = e * H * I + k * I + n
+ w_base = expert * H * I_val
- for k_start in range(0, H, BLOCK_K):
- k_offs = k_start + k_local
- k_mask = k_offs < H
+ for k_start in range(0, H, BK):
+ k_mask = k_start + k_offs < H
+ k = k_start + k_offs
- # A[m, k]: (BLOCK_M, BLOCK_K)
- a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :] * stride_x_row // BLOCK_M
- # Actually, stride_x_row might not be H if the tensor isn't contiguous...
- # Let me just use stride_x_row directly
- a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :]
- a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
+ # A[row, k]
+ a_ptrs = X_ptr + row_g[:, None] * H + k[None, :]
+ a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
- # W_gate: (BLOCK_K, BLOCK_N) -> W[e, k, n]
- wg_ptrs = wg_base + k_offs[:, None].to(tl.int64) * stride_w_row + n_local[None, :].to(tl.int64) * stride_w_col
+ # W_gate[e, k, n]
+ wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- # W_up
- wu_ptrs = wu_base + k_offs[:, None].to(tl.int64) * stride_w_row + n_local[None, :].to(tl.int64) * stride_
... (2914 more chars){
"content": "<4100 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 127, in forward
_launch(hidden_states, self.W_gate, self.W_up, out, expert_offsets)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 108, in _launch
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit
... (870 more chars){
"command": "<401 chars \u2014 see below>",
"description": "Find valid block sizes within shared memory limit"
}uv run python -c "
# Shared memory for BM x BK A + 2 x BK x BN W (num_stages=1)
for BM in [32, 64, 128]:
for BN in [32, 64, 128]:
for BK in [32, 64, 128]:
shm = (BM * BK + BK * BN + BK * BN) * 2 # bf16 = 2 bytes
if shm <= 101376:
flops = BM * BN * BK # fma per tile
print(f'BM={BM} BN={BN} BK={BK} shm={shm} flops/tile={flops}')
"
BM=32 BN=32 BK=32 shm=6144 flops/tile=32768
BM=32 BN=32 BK=64 shm=12288 flops/tile=65536
BM=32 BN=32 BK=128 shm=24576 flops/tile=131072
BM=32 BN=64 BK=32 shm=10240 flops/tile=65536
BM=32 BN=64 BK=64 shm=20480 flops/tile=131072
BM=32 BN=64 BK=128 shm=40960 flops/tile=262144
BM=32 BN=128 BK=32 shm=18432 flops/tile=131072
BM=32 BN=128 BK=64 shm=36864 flops/tile=262144
BM=32 BN=128 BK=128 shm=73728 flops/tile=524288
BM=64 BN=32 BK=32 shm=8192 flops/tile=65536
BM=64 BN=32 BK=64 shm=16384 flops/tile=131072
BM=64 BN=32 BK=128 shm=32768 flops/tile=262144
BM=64 BN=64 BK=32 shm=12288 flops/tile=131072
BM=64 BN=64 BK=64 shm=24576 flops/tile=262144
BM=64 BN=64 BK=128 shm=49152 flops/tile=524288
BM=64 BN=128 BK=32 shm=20480 flops/tile=262144
BM=64 BN=128 BK=64 shm=40960 flops/tile=524288
BM=64 BN=128 BK=128 shm=81920 flops/tile=1048576
BM=128 BN=32 BK=32 shm=12288 flops/tile=131072
BM=128 BN=32 BK=64 shm=24576 flops/tile=262144
BM=128 BN=32 BK=128 shm=49152 flops/tile=524288
BM=128 BN=64 BK=32 shm=16384 flops/tile=262144
BM=128 BN=64 BK=64 shm=32768 flops/tile=524288
BM=128 BN=64 BK=128 shm=65536 flops/tile=1048576
BM=128 BN=128 BK=32 shm=24576 flops/tile=524288
BM=128 BN=128 BK=64 shm=49152 flops/tile=1048576
BM=128 BN=128 BK=128 shm=98304 flops/tile=2097152
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-Flattened 2D grid: (cdiv(T_perm, BM), cdiv(I, BN)) with expert lookup
-via binary search on offsets.
+Single kernel launch, flattened 2D grid (cdiv(T_perm, BM), cdiv(I, BN)).
+Each kernel block determines its owning expert by binary-searching offsets.
"""
from __future__ import annotations
@@ -21,7 +21,6 @@
T_perm,
H,
I_val,
- E,
BM: tl.constexpr,
BN: tl.constexpr,
BK: tl.constexpr,
@@ -29,67 +28,69 @@
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
- # Find expert for the first row in this block
- row_0 = pid_m * BM
- # Binary search for expert: find smallest e such that Offsets[e+1] > row_0
+ # First row this block covers
+ row_start_block = pid_m * BM
+
+ # Binary search for expert: find smallest e where Offsets[e+1] > row_start_block
lo = 0
- hi = E - 1
+ hi = 127 # max E is 128
while lo < hi:
- mid = (lo + hi) // 2
+ mid = (lo + hi) >> 1
off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
- if off <= row_0:
+ if off <= row_start_block:
lo = mid + 1
else:
hi = mid
expert = lo
+ # Read expert boundaries
row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
- # M offset (intra-expert rows)
- m_local = pid_m * BM + tl.arange(0, BM) - row_start
- m_valid = (pid_m * BM + tl.arange(0, BM)) < row_end
+ # Local row offsets within this expert
+ m_local = pid_m * BM - row_start + tl.arange(0, BM)
+ m_valid = m_local >= 0 and m_local < (row_end - row_start)
+ row_valid = row_start_block + tl.arange(0, BM) < T_perm
- # Global rows
row_g = pid_m * BM + tl.arange(0, BM)
- row_valid = row_g < T_perm & m_valid
# N offset
n_offs = pid_n * BN + tl.arange(0, BN)
n_valid = n_offs < I_val
- # K offset
+ # K
k_offs = tl.arange(0, BK)
+ # Accumulators
acc_g = tl.zeros((BM, BN), dtype=tl.float32)
acc_u = tl.zeros((BM, BN), dtype=tl.float32)
- # Weight base: W is (E, H, I), row major
- # W[e, k, n] offset = e * H * I + k * I + n
- w_base = expert * H * I_val
+ # Weight base: expert * H * I
+ w_base = expert.to(tl.int64) * H * I_val
+ # Main GEMM loop
for k_start in range(0, H, BK):
- k_mask = k_start + k_offs < H
k = k_start + k_offs
+ k_mask = k < H
- # A[row, k]
- a_ptrs = X_ptr + row_g[:, None] * H + k[None, :]
+ # X[row, k]: stride is H (row-major contiguous)
+ a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
- # W_gate[e, k, n]
+ # W_gate[expert, k, n]: (H, I) slice
wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- # W_up[e, k, n]
+ # W_up[expert, k, n]
wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- acc_g = tl.dot(a, wg, acc_g, out_dtype=tl.float32)
- acc_u = tl.dot(a, wu, acc_u, out_dtype=tl.float32)
+ acc_g += tl.dot(a, wg)
+ acc_u += tl.dot(a, wu)
# SwiGLU
out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
- out_ptrs = Out_ptr + row_g[:, None] * I_val + n_offs[None, :]
+ out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=row_valid[:, None] & n_valid[None, :])
@@ -97,17 +98,16 @@
T_perm = X.size(0)
H = X.size(1)
I_val = out.size(1)
- E_val = expert_offsets.size(0) - 1
- BM = 256
- BN = 128
+ BM = 128
+ BN = 64
BK = 128
grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
- T_perm, H, I_val, E_val,
+ T_perm, H, I_val,
BM=BM, BN=BN, BK=BK,
)
{
"content": "<4167 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check"
}/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py:52: UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
m_valid = m_local >= 0 and m_local < (row_end - row_start)
UserWarning: Enable tracemalloc to get the object allocation traceback
/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py:52: UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
m_valid = m_local >= 0 and m_local < (row_end - row_start)
UserWarning: Enable tracemalloc to get the object allocation traceback
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_s
... (1530 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -105,6 +105,7 @@
grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
+ # num_stages=1 to stay within shared memory limit on SM120
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
T_perm, H, I_val,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<100 chars \u2014 see below>",
"newString": "<163 chars \u2014 see below>"
} grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
_grouped_gemm_swiglu_kernel[grid]( grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
# num_stages=1 to stay within shared memory limit on SM120
_grouped_gemm_swiglu_kernel[grid](Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -105,11 +105,11 @@
grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
- # num_stages=1 to stay within shared memory limit on SM120
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
T_perm, H, I_val,
BM=BM, BN=BN, BK=BK,
+ num_stages=1, num_warps=8,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<242 chars \u2014 see below>",
"oldString": "<270 chars \u2014 see below>"
} grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
T_perm, H, I_val,
BM=BM, BN=BN, BK=BK,
num_stages=1, num_warps=8,
) grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
# num_stages=1 to stay within shared memory limit on SM120
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
T_perm, H, I_val,
BM=BM, BN=BN, BK=BK,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -49,7 +49,7 @@
# Local row offsets within this expert
m_local = pid_m * BM - row_start + tl.arange(0, BM)
- m_valid = m_local >= 0 and m_local < (row_end - row_start)
+ m_valid = (m_local >= 0) & (m_local < (row_end - row_start))
row_valid = row_start_block + tl.arange(0, BM) < T_perm
row_g = pid_m * BM + tl.arange(0, BM)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": " m_valid = (m_local >= 0) & (m_local < (row_end - row_start))",
"oldString": " m_valid = m_local >= 0 and m_local < (row_end - row_start)"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-Single kernel launch, flattened 2D grid (cdiv(T_perm, BM), cdiv(I, BN)).
-Each kernel block determines its owning expert by binary-searching offsets.
+3D grid (E, cdiv(MAX_M, BM), cdiv(I, BN)) - no binary search needed.
+Blocks whose range falls outside the expert's row range are no-ops.
"""
from __future__ import annotations
@@ -18,71 +18,68 @@
W_up_ptr,
Out_ptr,
Offsets_ptr,
- T_perm,
+ MAX_M,
H,
I_val,
+ stride_x_row,
+ stride_o_row,
BM: tl.constexpr,
BN: tl.constexpr,
BK: tl.constexpr,
):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
+ pid_e = tl.program_id(0)
+ pid_bm = tl.program_id(1)
+ pid_n = tl.program_id(2)
- # First row this block covers
- row_start_block = pid_m * BM
+ # Expert row bounds
+ row_start = tl.load(Offsets_ptr + pid_e).to(tl.int64)
+ row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+ n_e = row_end - row_start
- # Binary search for expert: find smallest e where Offsets[e+1] > row_start_block
- lo = 0
- hi = 127 # max E is 128
- while lo < hi:
- mid = (lo + hi) >> 1
- off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
- if off <= row_start_block:
- lo = mid + 1
- else:
- hi = mid
- expert = lo
+ # First row of this block
+ block_first = pid_bm * BM
- # Read expert boundaries
- row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
- row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+ # If this block is entirely outside the expert's range, skip
+ if block_first >= n_e:
+ return
- # Local row offsets within this expert
- m_local = pid_m * BM - row_start + tl.arange(0, BM)
- m_valid = (m_local >= 0) & (m_local < (row_end - row_start))
- row_valid = row_start_block + tl.arange(0, BM) < T_perm
+ # Local row offsets within expert
+ m_local = block_first + tl.arange(0, BM)
+ m_valid = m_local < n_e
- row_g = pid_m * BM + tl.arange(0, BM)
+ # Global row indices
+ row_g = row_start + m_local
- # N offset
+ # Column offsets
n_offs = pid_n * BN + tl.arange(0, BN)
n_valid = n_offs < I_val
- # K
+ # Reduction dimension offsets
k_offs = tl.arange(0, BK)
# Accumulators
acc_g = tl.zeros((BM, BN), dtype=tl.float32)
acc_u = tl.zeros((BM, BN), dtype=tl.float32)
- # Weight base: expert * H * I
- w_base = expert.to(tl.int64) * H * I_val
+ # Weight base for this expert: W is (E, H, I), row-major
+ wg_base = W_gate_ptr + pid_e * H * I_val
+ wu_base = W_up_ptr + pid_e * H * I_val
- # Main GEMM loop
+ # Main GEMM loop over K=H dimension
for k_start in range(0, H, BK):
k = k_start + k_offs
k_mask = k < H
- # X[row, k]: stride is H (row-major contiguous)
- a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
- a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
+ # Load A: (BM, BK)
+ a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * stride_x_row + k[None, :]
+ a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
- # W_gate[expert, k, n]: (H, I) slice
- wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
+ # Load W_gate[e, k, n]: (BK, BN)
+ wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
- # W_up[expert, k, n]
- wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
+ # Load W_up[e, k, n]: (BK, BN)
+ wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
acc_g += tl.dot(a, wg)
@@ -90,26 +87,52 @@
# SwiGLU
out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
- out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
- tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=row_valid[:, None] & n_valid[None, :])
+ out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
+ tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
def _launch(X, W_gate, W_up, out, expert_offsets):
T_perm = X.size(0)
H = X.size(1)
I_val = out.size(1)
+ E = expert_offsets.size(0) - 1
- BM = 128
- BN = 64
- BK = 128
+ # Heuristic: pick block sizes that maximize FLOPS/tile while fitting in shared memory
+ # Shared memory: (BM * BK + 2 * BK * BN) * 2 bytes (bf16, num_stages=1)
+ # For BM=256, BN=256, BK=64: (256*64 + 2*64*256)*2 = (16384+32768)*2 = 98304 < 101k OK
+ # For BM=256, BN=128, BK=128: (256*128 + 2*128*128)*2 = (32768+32768)*2 = 131072 - too big
+ # For BM=128, BN=128, BK=128: (128*128 + 2*128*128)*2 = (16384+32768)*2 = 98304 < 101k OK
+ if H >= 2048:
+ BM = 128
+ BN = 128
+ BK = 128
+ else:
+ BM = 64
+ BN = 128
+ BK = 64
- grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
+ # For SM120, try larger blocks when possible
+ # 128x256x64: (128*64 + 2*64*256)*2 = (8192+32768)*2 = 81920 - OK
+ # 256x128x64: (256*64 + 2*64*128)*2 = (16384+16384)*2 = 65536 - OK
+ if I_val >= 1024 and H >= 4096:
+ BM = 256
+ BN = 128
+ BK = 64
+ elif I_val >= 4096:
+ BM = 128
+ BN = 128
+ BK = 128
+
+ grid = (E, triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
_grouped_gemm_swiglu_kernel[grid](
X, W_gate, W_up, out, expert_offsets,
T_perm, H, I_val,
+ stride_x_row=X.stride(0),
+ stride_o_row=out.stride(0),
... (495 more chars){
"content": "<5072 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-3D grid (E, cdiv(MAX_M, BM), cdiv(I, BN)) - no binary search needed.
-Blocks whose range falls outside the expert's row range are no-ops.
+3D grid (E, cdiv(n_tokens_for_e, BM), cdiv(I, BN)) - grid size is
+computed from actual expert token counts, no wasted blocks.
"""
from __future__ import annotations
@@ -18,7 +18,6 @@
W_up_ptr,
Out_ptr,
Offsets_ptr,
- MAX_M,
H,
I_val,
stride_x_row,
@@ -36,15 +35,8 @@
row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
n_e = row_end - row_start
- # First row of this block
- block_first = pid_bm * BM
-
- # If this block is entirely outside the expert's range, skip
- if block_first >= n_e:
- return
-
# Local row offsets within expert
- m_local = block_first + tl.arange(0, BM)
+ m_local = pid_bm * BM + tl.arange(0, BM)
m_valid = m_local < n_e
# Global row indices
@@ -88,6 +80,129 @@
# SwiGLU
out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
+ tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
+
+
+def _expert_grid(expert_offsets, BM, BN, I):
+ """Build 3D grid from expert_offsets: (E_list, M_list, N_list) where
+ E_list is just [0,1,...,E-1] repeated, and M_list/N_list are per-expert tile counts."""
+ E = len(expert_offsets) - 1
+ counts = expert_offsets[1:] - expert_offsets[:-1] # tokens per expert
+ m_tiles = (counts + BM - 1) // BM # cdiv per expert
+ n_tiles = (I + BN - 1) // BN
+
+ # Create grid as tensor of shape (total_work_items, 3)
+ # Actually, Triton doesn't support tensor grids. We need to use a flat list approach.
+ # Instead, let's use a 1D grid and decode (e, m, n) from the flat index.
+ total_work = int(m_tiles.sum()) * n_tiles
+ return E, int(m_tiles.sum()), n_tiles, m_tiles, total_work
+
+
+@triton.jit
+def _grouped_gemm_1d_kernel(
+ X_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ Out_ptr,
+ Offsets_ptr,
+ MTilesPrefix_ptr, # cumulative sum of m_tiles: (E+1,) int32, prefix sum of m_tiles
+ H,
+ I_val,
+ total_m_tiles,
+ BN: tl.constexpr,
+ stride_x_row,
+ stride_o_row,
+ BM: tl.constexpr,
+ BLOCK_BN: tl.constexpr,
+ BK: tl.constexpr,
+):
+ pid_work = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ pid_bm_in_expert = pid_work # each pid_work handles one M-tile for one expert
+ # But we need to know WHICH expert AND which M-tile.
+ # Actually, let's restructure: grid = (total_m_tiles, cdiv(I, BN))
+ # Then for each pid_m in total_m_tiles, find which expert it belongs to.
+ pass
+
+
+# Simplest approach: 2D flattened grid (total_tiles_M, cdiv(I, BN))
+# where total_tiles_M = sum_e cdiv(n_tokens_e, BM)
+# Use binary search on cumulative_M_tiles to find expert index.
+
+
+@triton.jit
+def _grouped_gemm_flat_kernel(
+ X_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ Out_ptr,
+ Offsets_ptr,
+ CumMTiles_ptr, # (E+1,) int32, cumulative M-tiles: CumMTiles[e] = sum of cdiv(n_tokens_e', BM) for e' < e
+ H,
+ I_val,
+ n_experts,
+ BM: tl.constexpr,
+ BN: tl.constexpr,
+ BK: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ # Find expert for this M-tile: largest e such that CumMTiles[e] <= pid_m
+ lo = 0
+ hi = n_experts - 1
+ while lo < hi:
+ mid = (lo + hi + 1) >> 1
+ val = tl.load(CumMTiles_ptr + mid).to(tl.int32)
+ if val <= pid_m:
+ lo = mid
+ else:
+ hi = mid - 1
+ expert = lo
+
+ # M-tile index within this expert
+ tile_off = tl.load(CumMTiles_ptr + expert).to(tl.int32)
+ pid_bm = pid_m - tile_off
+
+ # Expert row bounds
+ row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+ row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+ n_e = row_end - row_start
+
+ m_local = pid_bm * BM + tl.arange(0, BM)
+ m_valid = m_local < n_e
+
+ row_g = row_start + m_local
+
+ n_offs = pid_n * BN + tl.arange(0, BN)
+ n_valid = n_offs < I_val
+
+ k_offs = tl.arange(0, BK)
+
+ acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+ acc_u = tl.zeros((BM, BN), dtype=tl.float32)
+
+ wg_base = W_gate_ptr + expert * H * I_val
+ wu_base = W_up_ptr + expert * H * I_val
+
+ for k_start in range(0, H, BK):
+ k = k_start + k_offs
+ k_mask = k < H
+
+ a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
+ a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
+
+ wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
+ wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
+
+ wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
+ wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
+
+ acc_g += tl.dot(a, wg)
+ acc_u += tl.dot(a, wu)
+
+ out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
+ out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
@@ -96,40 +211,42 @@
H = X.size(1)
I_val = out.size(1)
E = expert_offsets.size(0) - 1
-
- # Heuristic: pick block sizes that maximize FLOPS/tile while fitting in shared memory
- # Shared memory: (BM * BK + 2 * BK * BN) * 2 bytes (bf16, num_stages=1)
- # For BM=256, BN=256, BK=64: (256*64 + 2*64*256)*2 = (16384+32768)*2 = 98304 < 101k OK
- # For BM=256, BN=128, BK=128: (256*128 + 2*128*128)*2 = (32768+32768)*2 = 131072 - too big
- # For BM=128, BN=128, BK=128: (128*128 + 2*128*128)*2 = (16384+32768)*2 = 98304 < 101k OK
- if H >=
... (1858 more chars){
"content": "<8717 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,9 @@
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-3D grid (E, cdiv(n_tokens_for_e, BM), cdiv(I, BN)) - grid size is
-computed from actual expert token counts, no wasted blocks.
+2D grid: (sum_e cdiv(n_e, BM), cdiv(I, BN)) with per-block expert lookup
+via binary search on cumulative tile counts.
+
+Block sizes tuned for SM120 shared memory budget (~98KB usable).
"""
from __future__ import annotations
@@ -13,177 +15,61 @@
@triton.jit
def _grouped_gemm_swiglu_kernel(
- X_ptr,
- W_gate_ptr,
- W_up_ptr,
- Out_ptr,
- Offsets_ptr,
+ X_ptr, # (T_perm, H), row-major
+ W_gate_ptr, # (E, H, I), row-major
+ W_up_ptr, # (E, H, I), row-major
+ Out_ptr, # (T_perm, I), row-major
+ Offsets_ptr, # (E+1,) int32
+ CumMTiles_ptr, # (E+1,) int32
H,
I_val,
- stride_x_row,
- stride_o_row,
BM: tl.constexpr,
BN: tl.constexpr,
BK: tl.constexpr,
+ NUM_SMS: tl.constexpr,
):
- pid_e = tl.program_id(0)
- pid_bm = tl.program_id(1)
- pid_n = tl.program_id(2)
+ pid = tl.program_id(0)
+ pid_n = tl.program_id(1)
- # Expert row bounds
- row_start = tl.load(Offsets_ptr + pid_e).to(tl.int64)
- row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+ # 1D -> (pid_m, ...) mapping
+ pid_m = pid
+
+ # Find expert: largest e such that CumMTiles[e] <= pid_m
+ # E is at most 128; binary search
+ lo = 0
+ hi = 127
+ while lo < hi:
+ mid = (lo + hi + 1) >> 1
+ if tl.load(CumMTiles_ptr + mid).to(tl.int32) <= pid_m:
+ lo = mid
+ else:
+ hi = mid - 1
+ expert = lo
+
+ # Intra-expert M-tile index
+ m_tile_idx = pid_m - tl.load(CumMTiles_ptr + expert).to(tl.int32)
+
+ # Expert row range
+ row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+ row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
n_e = row_end - row_start
- # Local row offsets within expert
- m_local = pid_bm * BM + tl.arange(0, BM)
- m_valid = m_local < n_e
-
- # Global row indices
- row_g = row_start + m_local
+ # Row offsets
+ m_offs = m_tile_idx * BM + tl.arange(0, BM)
+ m_valid = m_offs < n_e
+ row_g = row_start + m_offs
# Column offsets
n_offs = pid_n * BN + tl.arange(0, BN)
n_valid = n_offs < I_val
- # Reduction dimension offsets
- k_offs = tl.arange(0, BK)
-
- # Accumulators
- acc_g = tl.zeros((BM, BN), dtype=tl.float32)
- acc_u = tl.zeros((BM, BN), dtype=tl.float32)
-
- # Weight base for this expert: W is (E, H, I), row-major
- wg_base = W_gate_ptr + pid_e * H * I_val
- wu_base = W_up_ptr + pid_e * H * I_val
-
- # Main GEMM loop over K=H dimension
- for k_start in range(0, H, BK):
- k = k_start + k_offs
- k_mask = k < H
-
- # Load A: (BM, BK)
- a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * stride_x_row + k[None, :]
- a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
-
- # Load W_gate[e, k, n]: (BK, BN)
- wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
- wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
-
- # Load W_up[e, k, n]: (BK, BN)
- wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
- wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
-
- acc_g += tl.dot(a, wg)
- acc_u += tl.dot(a, wu)
-
- # SwiGLU
- out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
- out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
- tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
-
-
-def _expert_grid(expert_offsets, BM, BN, I):
- """Build 3D grid from expert_offsets: (E_list, M_list, N_list) where
- E_list is just [0,1,...,E-1] repeated, and M_list/N_list are per-expert tile counts."""
- E = len(expert_offsets) - 1
- counts = expert_offsets[1:] - expert_offsets[:-1] # tokens per expert
- m_tiles = (counts + BM - 1) // BM # cdiv per expert
- n_tiles = (I + BN - 1) // BN
-
- # Create grid as tensor of shape (total_work_items, 3)
- # Actually, Triton doesn't support tensor grids. We need to use a flat list approach.
- # Instead, let's use a 1D grid and decode (e, m, n) from the flat index.
- total_work = int(m_tiles.sum()) * n_tiles
- return E, int(m_tiles.sum()), n_tiles, m_tiles, total_work
-
-
-@triton.jit
-def _grouped_gemm_1d_kernel(
- X_ptr,
- W_gate_ptr,
- W_up_ptr,
- Out_ptr,
- Offsets_ptr,
- MTilesPrefix_ptr, # cumulative sum of m_tiles: (E+1,) int32, prefix sum of m_tiles
- H,
- I_val,
- total_m_tiles,
- BN: tl.constexpr,
- stride_x_row,
- stride_o_row,
- BM: tl.constexpr,
- BLOCK_BN: tl.constexpr,
- BK: tl.constexpr,
-):
- pid_work = tl.program_id(0)
- pid_n = tl.program_id(1)
- pid_bm_in_expert = pid_work # each pid_work handles one M-tile for one expert
- # But we need to know WHICH expert AND which M-tile.
- # Actually, let's restructure: grid = (total_m_tiles, cdiv(I, BN))
- # Then for each pid_m in total_m_tiles, find which expert it belongs to.
- pass
-
-
-# Simplest approach: 2D flattened grid (total_tiles_M, cdiv(I, BN))
-# where total_tiles_M = sum_e cdiv(n_tokens_e, BM)
-# Use binary search on cumulative_M_tiles to find expert index.
-
-
-@triton.jit
-def _grouped_gemm_flat_kernel(
- X_ptr,
- W_gate_ptr,
- W_up_ptr,
- Out_ptr,
- Offsets_ptr,
- CumMTiles_ptr, # (E+1,) int32, cumulative M-tiles: CumMTiles[e] = sum of cdiv(n_tokens_e', BM) for e' < e
- H,
- I_val,
- n_experts,
- BM: tl.constexpr,
- BN: tl.constexpr,
- BK: tl.constexpr,
-):
- pid_m = tl.program_id(0)
-
... (3921 more chars){
"content": "<4949 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars){
"command": "<1027 chars \u2014 see below>",
"timeout": 60000,
"description": "Debug timing issue"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '.')
import solution
import reference
from shapes import SHAPES
import time
shape = SHAPES[0]
reference.T_total = shape['T_total']
reference.H = shape['H']
reference.I = shape['I']
reference.E = shape['E']
reference.K = shape['K']
init = reference.get_init_inputs()
model = solution.Model(*init).cuda().eval()
model.W_gate.data = torch.randn_like(model.W_gate)
model.W_up.data = torch.randn_like(model.W_up)
torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]
# Single forward
torch.cuda.synchronize()
start = time.time()
out = model(*inputs)
torch.cuda.synchronize()
print(f'Single forward: {(time.time()-start)*1000:.1f} ms')
print(f'Output shape: {out.shape}')
print(f'Input T_perm: {inputs[0].shape}')
# Debug grid
HM = shape['H']
IM = shape['I']
BM = 256
BN = 128
counts = inputs[1][1:] - inputs[1][:-1]
m_tiles = (counts + BM - 1) // BM
print(f'Total M tiles: {m_tiles.sum().item()}')
print(f'N tiles: {(IM + BN - 1) // BN}')
" 2>&1Single forward: 1920.5 ms
Output shape: torch.Size([262144, 1536])
Input T_perm: torch.Size([262144, 4096])
Total M tiles: 1024
N tiles: 12