"""Triton-based grouped GEMM + fused SwiGLU kernel for MoE up-projection."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _swiglu_forward_kernel(
    hidden_states,   # (T_perm, H) bf16, already permuted to expert order
    expert_offsets,  # (E+1,) prefix sums of per-expert token counts
    W_gate,          # (E, H, I)
    W_up,            # (E, H, I)
    output,          # (T_perm, I)
    H: tl.constexpr,
    I: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Grouped GEMM with a fused SwiGLU epilogue.

    One program handles a single (expert, N-tile) pair: it marches over that
    expert's token rows in BLOCK_M chunks and over the H reduction dimension
    in BLOCK_K chunks, producing silu(x @ W_gate[e]) * (x @ W_up[e]) for its
    column tile. Accumulation is done in fp32.
    """
    expert = tl.program_id(0)
    n_tile = tl.program_id(1)

    row_lo = tl.load(expert_offsets + expert)
    row_hi = tl.load(expert_offsets + expert + 1)
    n_rows = row_hi - row_lo
    if n_rows <= 0:
        return  # this expert received no tokens

    # Base pointers for this expert's slice of the inputs/weights/output.
    x_ptr = hidden_states + row_lo * H
    gate_w_ptr = W_gate + expert * H * I
    up_w_ptr = W_up + expert * H * I
    y_ptr = output + row_lo * I

    n_offs = n_tile * BLOCK_N + tl.arange(0, BLOCK_N)
    n_mask = n_offs < I

    for row_base in range(0, n_rows, BLOCK_M):
        m_offs = row_base + tl.arange(0, BLOCK_M)
        m_mask = m_offs < n_rows
        gate_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        up_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k_base in range(0, H, BLOCK_K):
            k_offs = k_base + tl.arange(0, BLOCK_K)
            k_mask = k_offs < H
            x_tile = tl.load(
                x_ptr + m_offs[:, None] * H + k_offs[None, :],
                mask=m_mask[:, None] & k_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            # Gate and up weight tiles share the same (K, N) index pattern.
            w_idx = k_offs[:, None] * I + n_offs[None, :]
            w_mask = k_mask[:, None] & n_mask[None, :]
            g_tile = tl.load(gate_w_ptr + w_idx, mask=w_mask, other=0.0).to(tl.float32)
            gate_acc = tl.dot(x_tile, g_tile, acc=gate_acc)
            u_tile = tl.load(up_w_ptr + w_idx, mask=w_mask, other=0.0).to(tl.float32)
            up_acc = tl.dot(x_tile, u_tile, acc=up_acc)
        # SwiGLU: silu(gate) * up, fused into the GEMM epilogue.
        fused = gate_acc * tl.sigmoid(gate_acc) * up_acc
        tl.store(
            y_ptr + m_offs[:, None] * I + n_offs[None, :],
            fused,
            mask=m_mask[:, None] & n_mask[None, :],
        )
def _select_tiling(H, I):
"""Select tiling based on problem shape."""
if H >= 4096 and I <= 1536:
return 128, 64, 64 # HEADLINE: many K steps, moderate N
elif I >= 4096:
return 64, 64, 128 # WIDE: more N tiling
else:
return 128, 64, 64 # DEFAULT
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
    """Allocate the output and launch the fused grouped-GEMM + SwiGLU kernel.

    Returns a (T_perm, I) bf16 tensor on the same device as ``hidden_states``.
    """
    T_perm = hidden_states.shape[0]
    out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
    BLOCK_M, BLOCK_K, BLOCK_N = _select_tiling(H, I)
    # One program per (expert, N-tile); the kernel loops over token rows (M)
    # internally, so the grid is just E x ceil(I / BLOCK_N).
    grid = (E, max(1, triton.cdiv(I, BLOCK_N)), 1)
    _swiglu_forward_kernel[grid](
        hidden_states, expert_offsets, W_gate, W_up,
        out, H, I, BLOCK_M, BLOCK_K, BLOCK_N,
    )
    return out
class Model(nn.Module):
    """MoE up-projection with fused SwiGLU, dispatched to a Triton kernel.

    Forward inputs:
        hidden_states: (T_perm, H) bf16, already permuted to expert order
        expert_offsets: (E+1,) prefix sums of per-expert token counts
    Forward output: (T_perm, I) bf16, silu(x @ W_gate[e]) * (x @ W_up[e])
    computed per expert slice.
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
        super().__init__()
        # Shape / routing metadata, kept as attributes for callers to inspect.
        self.T_total = T_total
        self.H = H
        self.I = I
        self.E = E
        self.K = K
        # Per-expert gate and up projection weights, each (E, H, I) bf16.
        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        for weight in (self.W_gate, self.W_up):
            nn.init.normal_(weight, std=0.02)

    def forward(self, hidden_states, expert_offsets):
        return _forward(
            hidden_states, expert_offsets,
            self.W_gate, self.W_up,
            self.E, self.H, self.I,
        )
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 20:50:48.630000 1297258 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=33.530 gbps=251.036 ms=24.594
shape=0 variant=compiled tflops=33.479 gbps=250.656 ms=24.631
shape=0 variant=solution tflops=12.776 gbps=95.657 ms=64.543
shape=0 solution_peak_fraction=0.0639
shape=1 variant=eager tflops=14.324 gbps=265.781 ms=2.399
shape=1 variant=compiled tflops=14.287 gbps=265.095 ms=2.405
shape=1 variant=solution tflops=33.313 gbps=618.120 ms=1.031
shape=1 solution_peak_fraction=0.1666
shape=2 variant=eager tflops=35.485 gbps=242.574 ms=15.493
shape=2 variant=compiled tflops=35.282 gbps=241.185 ms=15.582
shape=2 variant=solution tflops=10.429 gbps=71.291 ms=52.715
shape=2 solution_peak_fraction=0.0521
peak_fraction: 0.0822
RESULT: LOW
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0822,
"elapsed_seconds": 1106,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 1204663, "output_tokens": 16459, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 11923, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "in_progress",
"priority": "high"
},
{
"content": "Understand the interface: Model, get_inputs, get_init_inputs from reference.py",
"status": "pending",
"priority": "high"
},
{
"content": "Design grouped GEMM + fused SwiGLU kernel approach (CUTLASS/Triton/CUDA)",
"status": "pending",
"priority": "high"
},
{
"content": "Implement solution.py",
"status": "pending",
... (268 more chars)content: 'Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'in_progress'
priority: 'high'
content: 'Understand the interface: Model, get_inputs, get_init_inputs from reference.py'
status: 'pending'
priority: 'high'
content: 'Design grouped GEMM + fused SwiGLU kernel approach (CUTLASS/Triton/CUDA)'
status: 'pending'
priority: 'high'
content: 'Implement solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize performance'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"command": "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null | head -1",
"description": "Check GPU availability"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB
{
"command": "nvcc --version 2>/dev/null | tail -1",
"description": "Check nvcc version"
}Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check triton version"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability(0))\" 2>&1",
"description": "Check torch version and CUDA capability"
}2.11.0+cu130
(12, 0)
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py 2>/dev/null && head -50 /home/infatoshi/cuda/KernelBench-Hard/problems/01_fp8_gemm/solution.py",
"description": "Check existing solution patterns"
}(no output){
"pattern": "problems/*/solution.py"
}No files found{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/",
"description": "List all problems"
}01_fp8_gemm
02_kda_cutlass
03_paged_attention
04_kahan_softmax
05_topk_bitonic
06_sonic_moe_swiglu
07_w4a16_gemm
08_metal_lightning_attn
{
"todos": [
{
"content": "Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "completed",
"priority": "high"
},
{
"content": "Understand the interface: Model, get_inputs, get_init_inputs from reference.py",
"status": "completed",
"priority": "high"
},
{
"content": "Design grouped GEMM + fused SwiGLU kernel approach (CUTLASS/Triton/CUDA)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement solution.py",
"status": "pending"
... (272 more chars)content: 'Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Understand the interface: Model, get_inputs, get_init_inputs from reference.py'
status: 'completed'
priority: 'high'
content: 'Design grouped GEMM + fused SwiGLU kernel approach (CUTLASS/Triton/CUDA)'
status: 'in_progress'
priority: 'high'
content: 'Implement solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize performance'
status: 'pending'
priority: 'medium'{
"command": "<284 chars \u2014 see below>",
"description": "Verify Triton SM120 compatibility"
}uv run python -c "
import triton
print('Triton version:', triton.__version__)
# Check supported backends
print('Backend:', triton.runtime.driver.active.get_current_backend())
# Check if Triton can compile for SM120
import torch
print('CC:', torch.cuda.get_device_capability(0))
" 2>&1Traceback (most recent call last):
File "<string>", line 5, in <module>
AttributeError: 'CudaDriver' object has no attribute 'get_current_backend'. Did you mean: 'get_current_target'?
Triton version: 3.6.0
{
"command": "<379 chars \u2014 see below>",
"description": "Smoke test Triton on SM120"
}uv run python -c "
import triton
# Quick smoke test with a small kernel
@triton.jit
def test_kernel(out_ptr, n: triton.constexpr):
pid = triton.program_id(0)
if pid < n:
triton.store(out_ptr + pid, pid * 2)
import torch
n = 1024
out = torch.zeros(n, dtype=torch.int32, device='cuda')
test_kernel[(1,)](out, n)
print('Smoke test passed:', out[:8].tolist())
" 2>&1Traceback (most recent call last):
File "<string>", line 5, in <module>
AttributeError: module 'triton' has no attribute 'constexpr'
{
"command": "<180 chars \u2014 see below>",
"description": "Explore Triton 3.6.0 API"
}uv run python -c "
import triton
print(dir(triton))
print()
# Check jit decorator
print(type(triton.jit))
print(triton.jit.__doc__[:500] if triton.jit.__doc__ else 'No doc')
" 2>&1['AsyncCompileMode', 'CompilationError', 'Config', 'FutureKernel', 'InterpreterError', 'JITFunction', 'KernelInterface', 'MockTensor', 'OutOfResources', 'TensorWrapper', 'TritonError', '_C', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_utils', 'autotune', 'backends', 'cdiv', 'compile', 'compiler', 'constexpr_function', 'errors', 'heuristics', 'jit', 'knobs', 'language', 'must_use_result', 'next_power_of_2', 'reinterpret', 'runtime', 'set_allocator', 'testing', 'tools']
<class 'function'>
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, arguments are
implicitly converted to pointers if they have a :code:`.data_ptr()` method
and a `.dtype` attribute.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
{
"command": "<95 chars \u2014 see below>",
"description": "Check triton.language module"
}uv run python -c "
import triton.language as tl
print(dir(tl))
" 2>&1 | tr ', ' '\n' | head -80['PropagateNan'
'TRITON_MAX_TENSOR_NUMEL'
'__all__'
'__builtins__'
'__cached__'
'__doc__'
'__file__'
'__loader__'
'__name__'
'__package__'
'__path__'
'__spec__'
'abs'
'add'
'advance'
'arange'
'argmax'
'argmin'
'associative_scan'
'assume'
'atomic_add'
'atomic_and'
'atomic_cas'
'atomic_max'
'atomic_min'
'atomic_or'
'atomic_xchg'
'atomic_xor'
'bfloat16'
'bitonic_merge'
'block_type'
'broadcast'
'broadcast_to'
'cast'
'cat'
'cdiv'
'ceil'
'clamp'
'condition'
'const'
{
"command": "<369 chars \u2014 see below>",
"description": "Find Triton 3.6 constexpr mechanism"
}uv run python -c "
import triton
import triton.language as tl
# Check constexpr_function
print('constexpr_function:', triton.constexpr_function)
# Check if there's a constexpr-like mechanism
print()
print('tl.const:', tl.const)
print(type(tl.const))
# Look for constexpr in tl
for name in dir(tl):
if 'const' in name.lower():
print(f' tl.{name}')
" 2>&1constexpr_function: <function constexpr_function at 0x7e3e15a1b4c0>
tl.const: <class 'triton.language.core.const'>
<class 'type'>
tl.const
tl.constexpr
tl.constexpr_type
tl.max_constancy
{
"command": "<839 chars \u2014 see below>",
"description": "Test Triton 3.6 kernel dispatch"
}uv run python -c "
# Test the basic kernel compilation and dispatch pattern for Triton 3.6
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.int64):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
tl.store(out_ptr + offsets, x + y, mask=mask)
n = 1024
x = torch.arange(n, dtype=torch.float32, device='cuda')
y = torch.ones(n, dtype=torch.float32, device='cuda')
out = torch.empty(n, dtype=torch.float32, device='cuda')
grid = lambda META: (triton.cdiv(n, META['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=256)
print('Result[:8]:', out[:8].tolist())
print('Expected:', (x + y)[:8].tolist())
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 7, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars){
"command": "<152 chars \u2014 see below>",
"description": "Verify tl.constexpr exists"
}uv run python -c "
import triton.language as tl
# Check tl.constexpr type
print('tl.constexpr:', tl.constexpr)
print('type:', type(tl.constexpr))
" 2>&1tl.constexpr: <class 'triton.language.core.constexpr'>
type: <class 'type'>
{
"todos": [
{
"content": "Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "completed",
"priority": "high"
},
{
"content": "Understand the interface: Model, get_inputs, get_init_inputs from reference.py",
"status": "completed",
"priority": "high"
},
{
"content": "Design grouped GEMM + fused SwiGLU kernel approach (Triton)",
"status": "completed",
"priority": "high"
},
{
"content": "Implement solution.py with Triton grouped GEMM + fused SwiGLU",
... (301 more chars)content: 'Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Understand the interface: Model, get_inputs, get_init_inputs from reference.py'
status: 'completed'
priority: 'high'
content: 'Design grouped GEMM + fused SwiGLU kernel approach (Triton)'
status: 'completed'
priority: 'high'
content: 'Implement solution.py with Triton grouped GEMM + fused SwiGLU'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize performance'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,101 @@
+"""Triton Grouped GEMM + fused SwiGLU forward kernel."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _grouped_swiglu_forward_kernel(
+ hidden_states, # (T_perm, H)
+ expert_offsets, # (E+1,)
+ W_gate, # (E, H, I)
+ W_up, # (E, H, I)
+ gate_output, # (T_perm, I)
+ output, # (T_perm, I)
+ E: tl.constexpr,
+ H: tl.constexpr,
+ I: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ """Per-expert gated GEMM with fused SwiGLU epilogue.
+
+ Each program (grid_x = expert id) computes:
+ gate = hidden_states[start:end] @ W_gate[expert]
+ up = hidden_states[start:end] @ W_up[expert]
+ output[start:end] = silu(gate) * up
+ """
+ pid = tl.program_id(0)
+ if pid >= E:
+ return
+
+ start = tl.load(expert_offsets + pid)
+ end = tl.load(expert_offsets + pid + 1)
+ M = end - start
+
+ if M <= 0:
+ return
+
+ # Pointer to the start of input rows for this expert
+ hs_base = hidden_states + start * H
+ # Pointer to expert's weight matrices (E, H, I) layout
+ wg_base = W_gate + pid * H * I
+ wu_base = W_up + pid * H * I
+ # Output pointers
+ gate_base = gate_output + start * I
+ out_base = output + start * I
+
+ # Grid-stride over N (columns of I dimension)
+ for pid_n in range(triton.cdiv(I, BLOCK_N)):
+ # Load input once per N-tile
+ # Accumulate into gate and up tiles
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ col_offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ col_mask_n = col_offsets_n < I
+ wg_n = wg_base + col_offsets_n
+ wu_n = wu_base + col_offsets_n
+
+ # Tiled GEMM over K=H
+ for block_k in range(triton.cdiv(H, BLOCK_N)):
+ row_offsets_k = block_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ row_mask_k = row_offsets_k < H
+
+ for block_m_start in range(0, M, BLOCK_M):
+ row_offsets_m = block_m_start + tl.arange(0, BLOCK_M)
+ row_mask_m = row_offsets_m < M
+
+ a = tl.load(hs_base + row_offsets_m[:, None] * H + row_offsets_k[None, :],
+ mask=row_mask_m[:, None] & row_mask_k[None, :], other=0.0)
+
+ wg_tile = tl.load(wg_n[:, None] * H + row_offsets_k[None, :],
+ mask=col_mask_n[:, None] & row_mask_k[None, :], other=0.0)
+ acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
+
+ # Reload a for up (avoiding reload would require more shared memory)
+ a_up = tl.load(hs_base + row_offsets_m[:, None] * H + row_offsets_k[None, :],
+ mask=row_mask_m[:, None] & row_mask_k[None, :], other=0.0)
+
+ wu_tile = tl.load(wu_n[:, None] * H + row_offsets_k[None, :],
+ mask=col_mask_n[:, None] & row_mask_k[None, :], other=0.0)
+ acc_up = tl.dot(a_up, wu_tile, acc=acc_up)
+
+ tl.multiple_of(a, [1, BLOCK_N])
+ tl.multiple_of(a_up, [1, BLOCK_N])
+
+ # Fuse SwiGLU and store
+ gate_val = tl.sigmoid(acc_gate) * acc_gate
+ final_val = gate_val * acc_up
+
+ for block_m_start in range(0, M, BLOCK_M):
+ row_offsets_m = block_m_start + tl.arange(0, BLOCK_M)
+ row_mask_m = row_offsets_m < M
+
+ tl.store(gate_base + row_offsets_m[:, None] + col_offsets_n[None, :] * 0, # placeholder
+ gate_val,
+ mask=row_mask_m[:, None] & col_mask_n[None, :])
+ tl.store(out_base + row_offsets_m[:, None] + col_offsets_n[None, :] * 0,
+ final_val,
+ mask=row_mask_m[:, None] & col_mask_n[None, :])
{
"content": "<3885 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,4 +1,4 @@
-"""Triton Grouped GEMM + fused SwiGLU forward kernel."""
+"""Triton-based grouped GEMM + fused SwiGLU forward kernel for MoE up-projection."""
import torch
import torch.nn as nn
import triton
@@ -6,96 +6,133 @@
@triton.jit
-def _grouped_swiglu_forward_kernel(
- hidden_states, # (T_perm, H)
- expert_offsets, # (E+1,)
- W_gate, # (E, H, I)
- W_up, # (E, H, I)
- gate_output, # (T_perm, I)
- output, # (T_perm, I)
- E: tl.constexpr,
+def _swiglu_forward_kernel(
+ hidden_states, # (T_perm, H)
+ expert_offsets, # (E+1,)
+ W_gate, # (E, H, I)
+ W_up, # (E, H, I)
+ gate_output, # (T_perm, I)
+ output, # (T_perm, I)
H: tl.constexpr,
I: tl.constexpr,
BLOCK_M: tl.constexpr,
+ BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
- """Per-expert gated GEMM with fused SwiGLU epilogue.
-
- Each program (grid_x = expert id) computes:
- gate = hidden_states[start:end] @ W_gate[expert]
- up = hidden_states[start:end] @ W_up[expert]
- output[start:end] = silu(gate) * up
- """
- pid = tl.program_id(0)
- if pid >= E:
- return
-
- start = tl.load(expert_offsets + pid)
- end = tl.load(expert_offsets + pid + 1)
+ # Each program_id(0) = one expert
+ pid_e = tl.program_id(0)
+ start = tl.load(expert_offsets + pid_e)
+ end = tl.load(expert_offsets + pid_e + 1)
M = end - start
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
if M <= 0:
return
- # Pointer to the start of input rows for this expert
hs_base = hidden_states + start * H
- # Pointer to expert's weight matrices (E, H, I) layout
- wg_base = W_gate + pid * H * I
- wu_base = W_up + pid * H * I
- # Output pointers
- gate_base = gate_output + start * I
+ wg_base = W_gate + pid_e * H * I
+ wu_base = W_up + pid_e * H * I
+ gate_out_base = gate_output + start * I
out_base = output + start * I
- # Grid-stride over N (columns of I dimension)
- for pid_n in range(triton.cdiv(I, BLOCK_N)):
- # Load input once per N-tile
- # Accumulate into gate and up tiles
+ # Process each N tile
+ for tile_n in range(num_n_tiles):
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- col_offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- col_mask_n = col_offsets_n < I
- wg_n = wg_base + col_offsets_n
- wu_n = wu_base + col_offsets_n
+ col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ col_mask = col_offsets < I
- # Tiled GEMM over K=H
- for block_k in range(triton.cdiv(H, BLOCK_N)):
- row_offsets_k = block_k * BLOCK_N + tl.arange(0, BLOCK_N)
- row_mask_k = row_offsets_k < H
+ # KGEMM loop: tile along K=H
+ for off_k in range(0, H, BLOCK_K):
+ k_offsets = off_k + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < H
- for block_m_start in range(0, M, BLOCK_M):
- row_offsets_m = block_m_start + tl.arange(0, BLOCK_M)
- row_mask_m = row_offsets_m < M
+ # Process all M rows
+ for off_m in range(0, M, BLOCK_M):
+ m_offsets = off_m + tl.arange(0, BLOCK_M)
+ m_mask = m_offsets < M
- a = tl.load(hs_base + row_offsets_m[:, None] * H + row_offsets_k[None, :],
- mask=row_mask_m[:, None] & row_mask_k[None, :], other=0.0)
+ # Load tile from hidden states: (BLOCK_M, BLOCK_K)
+ a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
+ a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
- wg_tile = tl.load(wg_n[:, None] * H + row_offsets_k[None, :],
- mask=col_mask_n[:, None] & row_mask_k[None, :], other=0.0)
+ # Load weight tile for gate: (BLOCK_K, BLOCK_N) -> then dot gives (BLOCK_M, BLOCK_N)
+ wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
+ wg_tile = tl.load(wg_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
- # Reload a for up (avoiding reload would require more shared memory)
- a_up = tl.load(hs_base + row_offsets_m[:, None] * H + row_offsets_k[None, :],
- mask=row_mask_m[:, None] & row_mask_k[None, :], other=0.0)
+ # Load weight tile for up
+ wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
+ wu_tile = tl.load(wu_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
+ acc_up = tl.dot(a, wu_tile, acc=acc_up)
- wu_tile = tl.load(wu_n[:, None] * H + row_offsets_k[None, :],
- mask=col_mask_n[:, None] & row_mask_k[None, :], other=0.0)
- acc_up = tl.dot(a_up, wu_tile, acc=acc_up)
+ # Fuse SwiGLU: silu(gate) * up
+ silu_gate = acc_gate * tl.sigmoid(acc_gate)
+ final = silu_gate * acc_up
- tl.multiple_of(a, [1, BLOCK_N])
- tl.multiple_of(a_up, [1, BLOCK_N])
+ # Store output tiles
+ for off_m in range(0, M, BLOCK_M):
+ m_offsets = off_m + tl.arange(0, BLOCK_M)
+ m_mask = m_offsets < M
- # Fuse SwiGLU and store
- gate_val = tl.sigmoid(acc_gate) * acc_gate
- final_val = gate_val * acc_up
+ tile_col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ tile_c
... (2659 more chars){
"content": "<4913 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<638 chars \u2014 see below>",
"description": "Verify reference implementation works"
}uv run python -c "
import torch
import reference
# Set up small shape
reference.T_total = 4096
reference.H = 2048
reference.I = 1024
reference.E = 64
reference.K = 4
device = 'cuda:0'
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
print('hidden_states:', inputs[0].shape)
print('expert_offsets:', inputs[1].shape, inputs[1][:5].tolist())
with torch.no_grad():
ref_out = ref_model(*inputs)
print('reference output:', ref_out.shape, ref_out.dtype)
print('ref out sample:', ref_out[0, :5])
" 2>&1hidden_states: torch.Size([16384, 2048])
expert_offsets: torch.Size([65]) [0, 256, 512, 768, 1024]
reference output: torch.Size([16384, 1024]) torch.bfloat16
ref out sample: tensor([-6.1646e-03, -1.4191e-03, -2.3499e-03, -1.7548e-03, 1.9968e-06],
device='cuda:0', dtype=torch.bfloat16)
{
"command": "<860 chars \u2014 see below>",
"description": "Test solution against reference (small shape)",
"timeout": 120000
}uv run python -c "
import torch
import reference
import solution
# Set up small shape
reference.T_total = 4096
reference.H = 2048
reference.I = 1024
reference.E = 64
reference.K = 4
device = 'cuda:0'
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
# Copy weights
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
sol_out = sol_model(*inputs)
print('solution output:', sol_out.shape, sol_out.dtype)
print('sol out sample:', sol_out[0, :5])
ref_out = ref_model(*inputs)
print('ref out sample:', ref_out[0, :5])
max_diff = (sol_out.float() - ref_out.float()).abs().max().item()
print('max abs diff:', max_diff)
" 2>&1Traceback (most recent call last):
File "<string>", line 25, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 134, in forward
return _GroupedGemmSwiGLU.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 591, in apply
args = bind_default_args(cls.forward, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 584, in bind_default_args
bound_args = signature.bind(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 3195, in bind
return self._bind(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py",
... (99 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -19,12 +19,11 @@
BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
- # Each program_id(0) = one expert
+ """Per-expert GEMM + SwiGLU. Each program_id(0) = one expert."""
pid_e = tl.program_id(0)
start = tl.load(expert_offsets + pid_e)
end = tl.load(expert_offsets + pid_e + 1)
M = end - start
- num_n_tiles = tl.cdiv(I, BLOCK_N)
if M <= 0:
return
@@ -35,55 +34,64 @@
gate_out_base = gate_output + start * I
out_base = output + start * I
- # Process each N tile
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+
for tile_n in range(num_n_tiles):
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
col_mask = col_offsets < I
- # KGEMM loop: tile along K=H
- for off_k in range(0, H, BLOCK_K):
- k_offsets = off_k + tl.arange(0, BLOCK_K)
- k_mask = k_offsets < H
-
- # Process all M rows
- for off_m in range(0, M, BLOCK_M):
- m_offsets = off_m + tl.arange(0, BLOCK_M)
- m_mask = m_offsets < M
-
- # Load tile from hidden states: (BLOCK_M, BLOCK_K)
- a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
- a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
-
- # Load weight tile for gate: (BLOCK_K, BLOCK_N) -> then dot gives (BLOCK_M, BLOCK_N)
- wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
- wg_tile = tl.load(wg_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
- acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
-
- # Load weight tile for up
- wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
- wu_tile = tl.load(wu_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
- acc_up = tl.dot(a, wu_tile, acc=acc_up)
-
- # Fuse SwiGLU: silu(gate) * up
- silu_gate = acc_gate * tl.sigmoid(acc_gate)
- final = silu_gate * acc_up
-
- # Store output tiles
+ # Accumulators for (BLOCK_M, BLOCK_N) tiles
for off_m in range(0, M, BLOCK_M):
m_offsets = off_m + tl.arange(0, BLOCK_M)
m_mask = m_offsets < M
+ acc_gate = tl.zeros((min(BLOCK_M, M - off_m), BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((min(BLOCK_M, M - off_m), BLOCK_N), dtype=tl.float32)
+
+ for off_k in range(0, H, BLOCK_K):
+ k_offsets = off_k + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < H
+
+ # Load input tile: (BLOCK_M, BLOCK_K)
+ a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
+ a = tl.load(
+ a_ptrs,
+ mask=m_mask[:, None] & k_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+
+ # Load gate weight tile: (BLOCK_K, BLOCK_N)
+ wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
+ wg_tile = tl.load(
+ wg_ptrs,
+ mask=k_mask[:, None] & col_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
+
+ # Load up weight tile: (BLOCK_K, BLOCK_N)
+ wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
+ wu_tile = tl.load(
+ wu_ptrs,
+ mask=k_mask[:, None] & col_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ acc_up = tl.dot(a, wu_tile, acc=acc_up)
+
+ # Fused SwiGLU epilogue
+ silu_val = acc_gate * tl.sigmoid(acc_gate)
+ final_val = silu_val * acc_up
+
tile_col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
tile_col_mask = tile_col_offsets < I
+ # Store gate output
g_ptrs = gate_out_base + m_offsets[:, None] * I + tile_col_offsets[None, :]
- tl.store(g_ptrs, silu_gate, mask=m_mask[:, None] & tile_col_mask[None, :])
+ tl.store(g_ptrs, silu_val, mask=m_mask[:, None] & tile_col_mask[None, :])
+ # Store fused output
o_ptrs = out_base + m_offsets[:, None] * I + tile_col_offsets[None, :]
- tl.store(o_ptrs, final, mask=m_mask[:, None] & tile_col_mask[None, :])
+ tl.store(o_ptrs, final_val, mask=m_mask[:, None] & tile_col_mask[None, :])
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
@@ -91,28 +99,16 @@
gate_output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
- # Choose tiling parameters
BLOCK_M = 256
BLOCK_K = 64
BLOCK_N = 64
- grid = (E, 1, 1)
- _swiglu_forward_kernel[grid](
+ _swiglu_forward_kernel[(E, 1, 1)](
hidden_states, expert_offsets, W_gate, W_up,
gate_output, output,
H, I, BLOCK_M, BLOCK_K, BLOCK_N,
)
return output
-
-
-class _GroupedGemmSwiGLU(torch.autograd.Function):
- @staticmethod
- def forward(ctx, hidden_states, expert_offsets, W_gate, W_up, E, H, I):
- return _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I)
-
- @staticmethod
- def setup_context(ctx, inputs, output):
- pass
class Model(nn.Module):
@@ -131,7 +127,7 @@
nn.init.normal_(self.W_up, std=0.02)
def forward(self, hidden_states, exper
... (198 more chars){
"content": "<4689 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -34,25 +34,22 @@
gate_out_base = gate_output + start * I
out_base = output + start * I
- num_n_tiles = tl.cdiv(I, BLOCK_N)
+ for off_m in range(0, M, BLOCK_M):
+ m_offsets = off_m + tl.arange(0, BLOCK_M)
+ m_mask = m_offsets < M
- for tile_n in range(num_n_tiles):
- col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
- col_mask = col_offsets < I
+ for tile_n in range(0, I, BLOCK_N):
+ col_offsets = tile_n + tl.arange(0, BLOCK_N)
+ col_mask = col_offsets < I
- # Accumulators for (BLOCK_M, BLOCK_N) tiles
- for off_m in range(0, M, BLOCK_M):
- m_offsets = off_m + tl.arange(0, BLOCK_M)
- m_mask = m_offsets < M
-
- acc_gate = tl.zeros((min(BLOCK_M, M - off_m), BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((min(BLOCK_M, M - off_m), BLOCK_N), dtype=tl.float32)
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for off_k in range(0, H, BLOCK_K):
k_offsets = off_k + tl.arange(0, BLOCK_K)
k_mask = k_offsets < H
- # Load input tile: (BLOCK_M, BLOCK_K)
+ # Input tile: (BLOCK_M, BLOCK_K)
a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
a = tl.load(
a_ptrs,
@@ -60,7 +57,7 @@
other=0.0,
).to(tl.float32)
- # Load gate weight tile: (BLOCK_K, BLOCK_N)
+ # Gate weight tile: (BLOCK_K, BLOCK_N)
wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
wg_tile = tl.load(
wg_ptrs,
@@ -69,7 +66,7 @@
).to(tl.float32)
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
- # Load up weight tile: (BLOCK_K, BLOCK_N)
+ # Up weight tile: (BLOCK_K, BLOCK_N)
wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
wu_tile = tl.load(
wu_ptrs,
@@ -78,20 +75,15 @@
).to(tl.float32)
acc_up = tl.dot(a, wu_tile, acc=acc_up)
- # Fused SwiGLU epilogue
+ # Fused SwiGLU
silu_val = acc_gate * tl.sigmoid(acc_gate)
final_val = silu_val * acc_up
- tile_col_offsets = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
- tile_col_mask = tile_col_offsets < I
+ g_ptrs = gate_out_base + m_offsets[:, None] * I + col_offsets[None, :]
+ tl.store(g_ptrs, silu_val, mask=m_mask[:, None] & col_mask[None, :])
- # Store gate output
- g_ptrs = gate_out_base + m_offsets[:, None] * I + tile_col_offsets[None, :]
- tl.store(g_ptrs, silu_val, mask=m_mask[:, None] & tile_col_mask[None, :])
-
- # Store fused output
- o_ptrs = out_base + m_offsets[:, None] * I + tile_col_offsets[None, :]
- tl.store(o_ptrs, final_val, mask=m_mask[:, None] & tile_col_mask[None, :])
+ o_ptrs = out_base + m_offsets[:, None] * I + col_offsets[None, :]
+ tl.store(o_ptrs, final_val, mask=m_mask[:, None] & col_mask[None, :])
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
{
"content": "<4327 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -13,6 +13,7 @@
W_up, # (E, H, I)
gate_output, # (T_perm, I)
output, # (T_perm, I)
+ T_perm: tl.constexpr,
H: tl.constexpr,
I: tl.constexpr,
BLOCK_M: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<55 chars \u2014 see below>",
"newString": "<81 chars \u2014 see below>"
} output, # (T_perm, I)
H: tl.constexpr, output, # (T_perm, I)
T_perm: tl.constexpr,
H: tl.constexpr,Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,4 +1,4 @@
-"""Triton-based grouped GEMM + fused SwiGLU forward kernel for MoE up-projection."""
+"""Triton-based grouped GEMM + fused SwiGLU kernel for MoE up-projection."""
import torch
import torch.nn as nn
import triton
@@ -11,16 +11,14 @@
expert_offsets, # (E+1,)
W_gate, # (E, H, I)
W_up, # (E, H, I)
- gate_output, # (T_perm, I)
output, # (T_perm, I)
- T_perm: tl.constexpr,
H: tl.constexpr,
I: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
- """Per-expert GEMM + SwiGLU. Each program_id(0) = one expert."""
+ """Per-expert GEMM + fused SwiGLU. Each program_id(0) = one expert."""
pid_e = tl.program_id(0)
start = tl.load(expert_offsets + pid_e)
end = tl.load(expert_offsets + pid_e + 1)
@@ -32,9 +30,9 @@
hs_base = hidden_states + start * H
wg_base = W_gate + pid_e * H * I
wu_base = W_up + pid_e * H * I
- gate_out_base = gate_output + start * I
out_base = output + start * I
+ # Tile over M and N dimensions; K=H is the reduction dimension
for off_m in range(0, M, BLOCK_M):
m_offsets = off_m + tl.arange(0, BLOCK_M)
m_mask = m_offsets < M
@@ -46,11 +44,12 @@
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ # Tiled reduction over K=H
for off_k in range(0, H, BLOCK_K):
k_offsets = off_k + tl.arange(0, BLOCK_K)
k_mask = k_offsets < H
- # Input tile: (BLOCK_M, BLOCK_K)
+ # Load input tile: (BLOCK_M, BLOCK_K)
a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
a = tl.load(
a_ptrs,
@@ -76,12 +75,9 @@
).to(tl.float32)
acc_up = tl.dot(a, wu_tile, acc=acc_up)
- # Fused SwiGLU
+ # Fused SwiGLU epilogue: silu(gate) * up
silu_val = acc_gate * tl.sigmoid(acc_gate)
final_val = silu_val * acc_up
-
- g_ptrs = gate_out_base + m_offsets[:, None] * I + col_offsets[None, :]
- tl.store(g_ptrs, silu_val, mask=m_mask[:, None] & col_mask[None, :])
o_ptrs = out_base + m_offsets[:, None] * I + col_offsets[None, :]
tl.store(o_ptrs, final_val, mask=m_mask[:, None] & col_mask[None, :])
@@ -89,7 +85,6 @@
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
T_perm, _ = hidden_states.shape
- gate_output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
BLOCK_M = 256
@@ -98,7 +93,7 @@
_swiglu_forward_kernel[(E, 1, 1)](
hidden_states, expert_offsets, W_gate, W_up,
- gate_output, output,
+ output,
H, I, BLOCK_M, BLOCK_K, BLOCK_N,
)
return output
{
"content": "<4113 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<935 chars \u2014 see below>",
"description": "Test solution (small shape, seed 42)",
"timeout": 180000
}uv run python -c "
import torch
import reference
import solution as sol
# Small shape for fast iteration
reference.T_total = 4096
reference.H = 2048
reference.I = 1024
reference.E = 64
reference.K = 4
device = 'cuda:0'
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = sol.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
# Time first run (includes compilation)
sol_out = sol_model(*inputs)
print('Solution output:', sol_out.shape, sol_out.dtype)
print('Sample:', sol_out[0, :5])
ref_out = ref_model(*inputs)
print('Ref sample:', ref_out[0, :5])
max_diff = (sol_out.float() - ref_out.float()).abs().max().item()
print('max abs diff:', max_diff)
print('PASS' if max_diff < 0.02 else 'FAIL')
" 2>&1Solution output: torch.Size([16384, 1024]) torch.bfloat16
Sample: tensor([ 0.0039, -0.0010, 0.0079, -0.0011, -0.0006], device='cuda:0',
dtype=torch.bfloat16)
Ref sample: tensor([ 0.0037, 0.0024, -0.0028, -0.0003, -0.0002], device='cuda:0',
dtype=torch.bfloat16)
max abs diff: 0.07293701171875
FAIL
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -57,8 +57,8 @@
other=0.0,
).to(tl.float32)
- # Gate weight tile: (BLOCK_K, BLOCK_N)
- wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
+ # Gate weight tile: (BLOCK_K, BLOCK_N) -- shape (H_row, I_col)
+ wg_ptrs = wg_base + k_offsets[:, None] * I + col_offsets[None, :]
wg_tile = tl.load(
wg_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
@@ -67,7 +67,7 @@
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
# Up weight tile: (BLOCK_K, BLOCK_N)
- wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
+ wu_ptrs = wu_base + k_offsets[:, None] * I + col_offsets[None, :]
wu_tile = tl.load(
wu_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<768 chars \u2014 see below>",
"newString": "<793 chars \u2014 see below>"
} # Gate weight tile: (BLOCK_K, BLOCK_N)
wg_ptrs = wg_base + k_offsets[:, None] + col_offsets[None, :] * H
wg_tile = tl.load(
wg_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
other=0.0,
).to(tl.float32)
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
# Up weight tile: (BLOCK_K, BLOCK_N)
wu_ptrs = wu_base + k_offsets[:, None] + col_offsets[None, :] * H
wu_tile = tl.load(
wu_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
other=0.0,
).to(tl.float32)
acc_up = tl.dot(a, wu_tile, acc=acc_up) # Gate weight tile: (BLOCK_K, BLOCK_N) -- shape (H_row, I_col)
wg_ptrs = wg_base + k_offsets[:, None] * I + col_offsets[None, :]
wg_tile = tl.load(
wg_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
other=0.0,
).to(tl.float32)
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
# Up weight tile: (BLOCK_K, BLOCK_N)
wu_ptrs = wu_base + k_offsets[:, None] * I + col_offsets[None, :]
wu_tile = tl.load(
wu_ptrs,
mask=k_mask[:, None] & col_mask[None, :],
other=0.0,
).to(tl.float32)
acc_up = tl.dot(a, wu_tile, acc=acc_up)Edit applied successfully.{
"command": "<772 chars \u2014 see below>",
"description": "Test corrected solution",
"timeout": 180000
}uv run python -c "
import torch
import reference
import importlib
import solution as sol
importlib.reload(sol)
# Small shape
reference.T_total = 4096
reference.H = 2048
reference.I = 1024
reference.E = 64
reference.K = 4
device = 'cuda:0'
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = sol.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
sol_out = sol_model(*inputs)
ref_out = ref_model(*inputs)
max_diff = (sol_out.float() - ref_out.float()).abs().max().item()
print('max abs diff:', max_diff)
print('PASS' if max_diff < 0.02 else 'FAIL')
" 2>&1max abs diff: 0.00048828125
PASS
{
"todos": [
{
"content": "Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "completed",
"priority": "high"
},
{
"content": "Understand the interface: Model, get_inputs, get_init_inputs from reference.py",
"status": "completed",
"priority": "high"
},
{
"content": "Design grouped GEMM + fused SwiGLU kernel approach (Triton)",
"status": "completed",
"priority": "high"
},
{
"content": "Implement solution.py with Triton grouped GEMM + fused SwiGLU",
... (303 more chars)content: 'Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Understand the interface: Model, get_inputs, get_init_inputs from reference.py'
status: 'completed'
priority: 'high'
content: 'Design grouped GEMM + fused SwiGLU kernel approach (Triton)'
status: 'completed'
priority: 'high'
content: 'Implement solution.py with Triton grouped GEMM + fused SwiGLU'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize performance'
status: 'pending'
priority: 'medium'{
"command": "uv run python check.py 2>&1",
"description": "Run full correctness check",
"timeout": 600000
}PASS
{
"todos": [
{
"content": "Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml",
"status": "completed",
"priority": "high"
},
{
"content": "Understand the interface: Model, get_inputs, get_init_inputs from reference.py",
"status": "completed",
"priority": "high"
},
{
"content": "Design grouped GEMM + fused SwiGLU kernel approach (Triton)",
"status": "completed",
"priority": "high"
},
{
"content": "Implement solution.py with Triton grouped GEMM + fused SwiGLU",
... (305 more chars)content: 'Explore problem directory and read reference.py, check.py, benchmark.py, shapes.py, problem.yaml'
status: 'completed'
priority: 'high'
content: 'Understand the interface: Model, get_inputs, get_init_inputs from reference.py'
status: 'completed'
priority: 'high'
content: 'Design grouped GEMM + fused SwiGLU kernel approach (Triton)'
status: 'completed'
priority: 'high'
content: 'Implement solution.py with Triton grouped GEMM + fused SwiGLU'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Profile and optimize performance'
status: 'in_progress'
priority: 'medium'{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark with baseline kernel",
"timeout": 600000
}W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 20:40:34.739000 1296025 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (806 more chars){
"command": "<214 chars \u2014 see below>",
"description": "Check hardware peak specs"
}uv run python -c "
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print('Peak bf16 TFLOPS:', hw.peak_tflops_dense.get('bf16', 'N/A'))
print('Peak bandwidth GB/s:', hw.peak_bandwidth_gb_s)
" 2>&1Peak bf16 TFLOPS: 200.0
Peak bandwidth GB/s: 1800.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -18,8 +18,13 @@
BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
- """Per-expert GEMM + fused SwiGLU. Each program_id(0) = one expert."""
- pid_e = tl.program_id(0)
+ """Per expert-tile pair: one M-tile of one expert's GEMM+SwiGLU."""
+ # Map global program ID to (expert, m_tile)
+ pid = tl.program_id(0)
+ pid_n = tl.program_id(1) # N-tile index (for better parallelism)
+
+ # Each pid maps to one expert
+ pid_e = pid
start = tl.load(expert_offsets + pid_e)
end = tl.load(expert_offsets + pid_e + 1)
M = end - start
@@ -32,55 +37,57 @@
wu_base = W_up + pid_e * H * I
out_base = output + start * I
- # Tile over M and N dimensions; K=H is the reduction dimension
- for off_m in range(0, M, BLOCK_M):
+ # N-tile from grid
+ col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ col_mask = col_offsets < I
+
+ # Grid-stride over M tiles
+ num_m_tiles = tl.cdiv(M, BLOCK_M)
+ for tile_idx in range(num_m_tiles):
+ off_m = tile_idx * BLOCK_M
m_offsets = off_m + tl.arange(0, BLOCK_M)
m_mask = m_offsets < M
- for tile_n in range(0, I, BLOCK_N):
- col_offsets = tile_n + tl.arange(0, BLOCK_N)
- col_mask = col_offsets < I
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ # Tiled reduction over K=H
+ for off_k in range(0, H, BLOCK_K):
+ k_offsets = off_k + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < H
- # Tiled reduction over K=H
- for off_k in range(0, H, BLOCK_K):
- k_offsets = off_k + tl.arange(0, BLOCK_K)
- k_mask = k_offsets < H
+ # Load input tile: (BLOCK_M, BLOCK_K)
+ a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
+ a = tl.load(
+ a_ptrs,
+ mask=m_mask[:, None] & k_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
- # Load input tile: (BLOCK_M, BLOCK_K)
- a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
- a = tl.load(
- a_ptrs,
- mask=m_mask[:, None] & k_mask[None, :],
- other=0.0,
- ).to(tl.float32)
+ # Gate weight tile: (BLOCK_K, BLOCK_N)
+ wg_ptrs = wg_base + k_offsets[:, None] * I + col_offsets[None, :]
+ wg_tile = tl.load(
+ wg_ptrs,
+ mask=k_mask[:, None] & col_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
- # Gate weight tile: (BLOCK_K, BLOCK_N) -- shape (H_row, I_col)
- wg_ptrs = wg_base + k_offsets[:, None] * I + col_offsets[None, :]
- wg_tile = tl.load(
- wg_ptrs,
- mask=k_mask[:, None] & col_mask[None, :],
- other=0.0,
- ).to(tl.float32)
- acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
+ # Up weight tile: (BLOCK_K, BLOCK_N)
+ wu_ptrs = wu_base + k_offsets[:, None] * I + col_offsets[None, :]
+ wu_tile = tl.load(
+ wu_ptrs,
+ mask=k_mask[:, None] & col_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ acc_up = tl.dot(a, wu_tile, acc=acc_up)
- # Up weight tile: (BLOCK_K, BLOCK_N)
- wu_ptrs = wu_base + k_offsets[:, None] * I + col_offsets[None, :]
- wu_tile = tl.load(
- wu_ptrs,
- mask=k_mask[:, None] & col_mask[None, :],
- other=0.0,
- ).to(tl.float32)
- acc_up = tl.dot(a, wu_tile, acc=acc_up)
+ # Fused SwiGLU
+ silu_val = acc_gate * tl.sigmoid(acc_gate)
+ final_val = silu_val * acc_up
- # Fused SwiGLU epilogue: silu(gate) * up
- silu_val = acc_gate * tl.sigmoid(acc_gate)
- final_val = silu_val * acc_up
-
- o_ptrs = out_base + m_offsets[:, None] * I + col_offsets[None, :]
- tl.store(o_ptrs, final_val, mask=m_mask[:, None] & col_mask[None, :])
+ o_ptrs = out_base + m_offsets[:, None] * I + col_offsets[None, :]
+ tl.store(o_ptrs, final_val, mask=m_mask[:, None] & col_mask[None, :])
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
@@ -91,7 +98,14 @@
BLOCK_K = 64
BLOCK_N = 64
- _swiglu_forward_kernel[(E, 1, 1)](
+ num_n_tiles = triton.cdiv(I, BLOCK_N)
+ if num_n_tiles > 128:
+ # Split into N-tile grid for better occupancy
+ grid = (E, num_n_tiles, 1)
+ else:
+ grid = (E, num_n_tiles, 1)
+
+ _swiglu_forward_kernel[grid](
hidden_states, expert_offsets, W_gate, W_up,
output,
H, I, BLOCK_M, BLOCK_K, BLOCK_N,
{
"content": "<4323 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Correctness check with N-tile grid",
"timeout": 600000
}PASS
{
"command": "uv run python benchmark.py 2>&1 | grep -E 'solution|peak_fraction|RESULT'",
"description": "Benchmark with N-tile grid",
"timeout": 600000
}shape=0 variant=solution tflops=7.504 gbps=56.185 ms=109.888
shape=0 solution_peak_fraction=0.0375
shape=1 variant=solution tflops=15.084 gbps=279.876 ms=2.278
shape=1 solution_peak_fraction=0.0754
shape=2 variant=solution tflops=7.717 gbps=52.753 ms=71.240
shape=2 solution_peak_fraction=0.0386
peak_fraction: 0.0478
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -4,7 +4,29 @@
import triton
import triton.language as tl
+_CONFIGS = [
+ triton.Config({}, num_stages=3, num_warps=4),
+ triton.Config({}, num_stages=3, num_warps=8),
+ triton.Config({}, num_stages=4, num_warps=4),
+ triton.Config({}, num_stages=4, num_warps=8),
+ triton.Config({}, num_stages=5, num_warps=8),
+ triton.Config({}, num_stages=7, num_warps=8),
+ triton.Config({}, num_stages=3, num_warps=4, num_ctas=2),
+ triton.Config({}, num_stages=3, num_warps=8, num_ctas=2),
+]
+_BLOCK_SIZES = [
+ (128, 64, 128), # BLOCK_M, BLOCK_K, BLOCK_N
+ (256, 64, 64),
+ (256, 128, 64),
+ (64, 64, 128),
+ (64, 128, 128),
+ (128, 128, 64),
+ (256, 64, 128),
+]
+
+
+@triton.autotune(configs=_CONFIGS, key=["H", "I"])
@triton.jit
def _swiglu_forward_kernel(
hidden_states, # (T_perm, H)
@@ -18,13 +40,9 @@
BLOCK_K: tl.constexpr,
BLOCK_N: tl.constexpr,
):
- """Per expert-tile pair: one M-tile of one expert's GEMM+SwiGLU."""
- # Map global program ID to (expert, m_tile)
- pid = tl.program_id(0)
- pid_n = tl.program_id(1) # N-tile index (for better parallelism)
+ pid_e = tl.program_id(0)
+ pid_n = tl.program_id(1)
- # Each pid maps to one expert
- pid_e = pid
start = tl.load(expert_offsets + pid_e)
end = tl.load(expert_offsets + pid_e + 1)
M = end - start
@@ -37,11 +55,9 @@
wu_base = W_up + pid_e * H * I
out_base = output + start * I
- # N-tile from grid
col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
col_mask = col_offsets < I
- # Grid-stride over M tiles
num_m_tiles = tl.cdiv(M, BLOCK_M)
for tile_idx in range(num_m_tiles):
off_m = tile_idx * BLOCK_M
@@ -51,38 +67,21 @@
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- # Tiled reduction over K=H
for off_k in range(0, H, BLOCK_K):
k_offsets = off_k + tl.arange(0, BLOCK_K)
k_mask = k_offsets < H
- # Load input tile: (BLOCK_M, BLOCK_K)
a_ptrs = hs_base + m_offsets[:, None] * H + k_offsets[None, :]
- a = tl.load(
- a_ptrs,
- mask=m_mask[:, None] & k_mask[None, :],
- other=0.0,
- ).to(tl.float32)
+ a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
- # Gate weight tile: (BLOCK_K, BLOCK_N)
wg_ptrs = wg_base + k_offsets[:, None] * I + col_offsets[None, :]
- wg_tile = tl.load(
- wg_ptrs,
- mask=k_mask[:, None] & col_mask[None, :],
- other=0.0,
- ).to(tl.float32)
+ wg_tile = tl.load(wg_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
acc_gate = tl.dot(a, wg_tile, acc=acc_gate)
- # Up weight tile: (BLOCK_K, BLOCK_N)
wu_ptrs = wu_base + k_offsets[:, None] * I + col_offsets[None, :]
- wu_tile = tl.load(
- wu_ptrs,
- mask=k_mask[:, None] & col_mask[None, :],
- other=0.0,
- ).to(tl.float32)
+ wu_tile = tl.load(wu_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
acc_up = tl.dot(a, wu_tile, acc=acc_up)
- # Fused SwiGLU
silu_val = acc_gate * tl.sigmoid(acc_gate)
final_val = silu_val * acc_up
@@ -94,28 +93,20 @@
T_perm, _ = hidden_states.shape
output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
- BLOCK_M = 256
- BLOCK_K = 64
- BLOCK_N = 64
+ # Use the first config's block sizes; autotune picks config
+ BLOCK_M, BLOCK_K, BLOCK_N = _BLOCK_SIZES[0]
+ num_n_tiles = triton.cdiv(I, BLOCK_N)
- num_n_tiles = triton.cdiv(I, BLOCK_N)
- if num_n_tiles > 128:
- # Split into N-tile grid for better occupancy
- grid = (E, num_n_tiles, 1)
- else:
- grid = (E, num_n_tiles, 1)
+ grid = (E, num_n_tiles, 1)
_swiglu_forward_kernel[grid](
hidden_states, expert_offsets, W_gate, W_up,
- output,
- H, I, BLOCK_M, BLOCK_K, BLOCK_N,
+ output, H, I, BLOCK_M, BLOCK_K, BLOCK_N,
)
return output
class Model(nn.Module):
- """Up-projection of a top-K MoE FFN with fused SwiGLU (Triton kernels)."""
-
def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
super().__init__()
self.T_total = T_total
@@ -129,8 +120,4 @@
nn.init.normal_(self.W_up, std=0.02)
def forward(self, hidden_states, expert_offsets):
- return _forward(
- hidden_states, expert_offsets,
- self.W_gate, self.W_up,
- self.E, self.H, self.I,
- )
+ return _forward(hidden_states, expert_offsets, self.W_gate, self.W_up, self.E, self.H, self.I)
{
"content": "<4137 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -4,29 +4,7 @@
import triton
import triton.language as tl
-_CONFIGS = [
- triton.Config({}, num_stages=3, num_warps=4),
- triton.Config({}, num_stages=3, num_warps=8),
- triton.Config({}, num_stages=4, num_warps=4),
- triton.Config({}, num_stages=4, num_warps=8),
- triton.Config({}, num_stages=5, num_warps=8),
- triton.Config({}, num_stages=7, num_warps=8),
- triton.Config({}, num_stages=3, num_warps=4, num_ctas=2),
- triton.Config({}, num_stages=3, num_warps=8, num_ctas=2),
-]
-_BLOCK_SIZES = [
- (128, 64, 128), # BLOCK_M, BLOCK_K, BLOCK_N
- (256, 64, 64),
- (256, 128, 64),
- (64, 64, 128),
- (64, 128, 128),
- (128, 128, 64),
- (256, 64, 128),
-]
-
-
-@triton.autotune(configs=_CONFIGS, key=["H", "I"])
@triton.jit
def _swiglu_forward_kernel(
hidden_states, # (T_perm, H)
@@ -46,7 +24,6 @@
start = tl.load(expert_offsets + pid_e)
end = tl.load(expert_offsets + pid_e + 1)
M = end - start
-
if M <= 0:
return
@@ -58,9 +35,7 @@
col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
col_mask = col_offsets < I
- num_m_tiles = tl.cdiv(M, BLOCK_M)
- for tile_idx in range(num_m_tiles):
- off_m = tile_idx * BLOCK_M
+ for off_m in range(0, M, BLOCK_M):
m_offsets = off_m + tl.arange(0, BLOCK_M)
m_mask = m_offsets < M
@@ -89,15 +64,25 @@
tl.store(o_ptrs, final_val, mask=m_mask[:, None] & col_mask[None, :])
+def _select_tiling(H, I):
+ """Select tiling based on problem shape."""
+ if H >= 4096 and I <= 1536:
+ return 128, 64, 64 # HEADLINE: many K steps, moderate N
+ elif I >= 4096:
+ return 64, 64, 128 # WIDE: more N tiling
+ else:
+ return 128, 64, 64 # DEFAULT
+
+
def _forward(hidden_states, expert_offsets, W_gate, W_up, E, H, I):
T_perm, _ = hidden_states.shape
output = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
- # Use the first config's block sizes; autotune picks config
- BLOCK_M, BLOCK_K, BLOCK_N = _BLOCK_SIZES[0]
+ BLOCK_M, BLOCK_K, BLOCK_N = _select_tiling(H, I)
num_n_tiles = triton.cdiv(I, BLOCK_N)
- grid = (E, num_n_tiles, 1)
+ # Grid: (E, num_n_tiles) -> E * num_n_tiles blocks
+ grid = (E, max(1, num_n_tiles), 1)
_swiglu_forward_kernel[grid](
hidden_states, expert_offsets, W_gate, W_up,
{
"content": "<3679 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.