"""Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
Each program instance handles ONE (expert, n_tile, m_tile) triple.
Loops over K tiles to accumulate the full GEMM.
SwiGLU is fused in the epilogue: silu(gate) * up.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
BM = 128
BK = 64
BN = 128
@triton.jit
def _grouped_gemm_swiglu_kernel(
x_ptr, w_gate_ptr, w_up_ptr, out_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr,
):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
start = tl.load(offsets_ptr + expert)
end = tl.load(offsets_ptr + expert + 1)
num_rows = end - start
m_tiles_needed = tl.cdiv(num_rows, BM)
if m_tile_pid >= m_tiles_needed:
return
m_start = start + m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < end
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
expert_w_off = expert * H * I
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
w_ptrs = k_offs[:, None] * I + n_offs[None, :]
wg_ptrs = w_gate_ptr + expert_w_off + w_ptrs
wu_ptrs = w_up_ptr + expert_w_off + w_ptrs
wg_mask = k_mask[:, None] & n_mask[None, :]
wg = tl.load(wg_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
wu = tl.load(wu_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
result = (gate_silu * up_acc).to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
class FusedSwiGLUGroupedGEMM(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, W_gate, W_up, expert_offsets, E, I):
T_perm, H = hidden_states.shape
out = torch.empty((T_perm, I), dtype=torch.bfloat16, device=hidden_states.device)
num_k_tiles = triton.cdiv(H, BK)
num_n_tiles = triton.cdiv(I, BN)
max_tokens_per_expert = int(expert_offsets[1:].diff().max().item()) if T_perm > 0 else 0
max_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
total_programs = E * num_n_tiles * max_m_tiles
if total_programs > 0:
_grouped_gemm_swiglu_kernel[(total_programs,)](
hidden_states, W_gate, W_up, out, expert_offsets,
H=H, I=I, E=E,
num_n_tiles=num_n_tiles, max_m_tiles=max_m_tiles,
BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles,
)
return out
class Model(nn.Module):
def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
super().__init__()
self.T_total = T_total
self.H = H
self.I = I
self.E = E
self.K = K
self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
def forward(self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
return FusedSwiGLUGroupedGEMM.apply(
hidden_states, self.W_gate, self.W_up, expert_offsets, self.E, self.I
)
T_total = 32768
H = 4096
I = 1536
E = 128
K = 8
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
def get_inputs():
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
expert_offsets = _build_routing(T_total, E, K)
return [hidden_states, expert_offsets]
def get_init_inputs():
return [T_total, H, I, E, K]
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 05:26:18.699000 934693 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=32.413 gbps=242.675 ms=25.442
shape=0 variant=compiled tflops=32.645 gbps=244.414 ms=25.260
shape=0 variant=solution tflops=0.594 gbps=4.446 ms=1388.671
shape=0 solution_peak_fraction=0.0030
shape=1 variant=eager tflops=14.029 gbps=260.295 ms=2.449
shape=1 variant=compiled tflops=13.962 gbps=259.052 ms=2.461
shape=1 variant=solution tflops=1.157 gbps=21.474 ms=29.688
shape=1 solution_peak_fraction=0.0058
shape=2 variant=eager tflops=34.794 gbps=237.852 ms=15.800
shape=2 variant=compiled tflops=34.321 gbps=234.619 ms=16.018
shape=2 variant=solution tflops=0.605 gbps=4.133 ms=909.317
shape=2 solution_peak_fraction=0.0030
peak_fraction: 0.0037
RESULT: LOW
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-max-preview",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0037,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 4527194, "output_tokens": 40948, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 20855, "total_cost_usd": null}
}
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"name": "moe-training"
}<skill_content name="moe-training">
# Skill: moe-training
# Mixture-of-Experts Training and Inference
MoE-specific concerns for training, serving, and optimizing sparse models.
## 1. Architecture Basics
### Expert routing
- **Top-k routing**: each token selects k experts (typically k=2) based on gating scores
- **Gating function**: typically softmax over linear projection of hidden state
- **Load balancing loss**: auxiliary loss to prevent expert collapse (all tokens routed to same experts)
- **Capacity factor**: maximum ratio of tokens an expert can process. Overflow tokens are dropped or rerouted
### Scale reference points
- DeepSeek-V3/R1: 671B total, 37B active, 256 experts, top-8
- Kimi K2.5: 1.04T total, 32B active
- Qwen3-Coder-30B-A3B: 30B total, 3B active
- Expert count and active params are independent knobs
## 2. Parallelism Decomposition
### Dimensions
- **FSDP (Data Parallel)**: shard model parameters across DP ranks. Standard.
- **EP (Expert Parallel)**: distribute experts across ranks. Each rank holds subset of experts.
- **CP (Context Parallel)**: split sequence across ranks for long-context. Better than TP for MLA.
- **TP (Tensor Parallel)**: split individual weight matrices. High communication cost.
### Why decouple EP from TP
Original designs (Switch Transformer era) coupled EP and TP to same rank group. Problems:
- Constrains EP degree to TP degree
- Forces activation sharding on dense layers even when memory pressure is modest
- Limits expert-grou
... (11624 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"todos": [
{
"content": "Check GPU availability and environment (nvidia-smi, CUDA, nvcc)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA",
"status": "pending",
"priority": "high"
},
{
"content": "Implement grouped GEMM + fused SwiGLU kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py fo
... (205 more chars)content: 'Check GPU availability and environment (nvidia-smi, CUDA, nvcc)'
status: 'in_progress'
priority: 'high'
content: 'Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA'
status: 'pending'
priority: 'high'
content: 'Implement grouped GEMM + fused SwiGLU kernel'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'pending'
priority: 'high'
content: 'Profile with ncu and optimize hotspots'
status: 'pending'
priority: 'medium'{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB, 595.58.03
{
"command": "nvcc --version 2>/dev/null || echo \"nvcc not found\"",
"description": "Check nvcc version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "uv run python -c \"import torch; print(torch.__version__); print('CUDA:', torch.version.cuda); print('BF16:', torch.cuda.get_device_properties(0))\"",
"description": "Check PyTorch CUDA"
}2.11.0+cu130
CUDA: 13.0
BF16: _CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97227MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
{
"command": "pip list 2>/dev/null | grep -i triton; uv run python -c \"import triton; print(triton.__version__)\" 2>/dev/null || echo \"triton not installed\"",
"description": "Check triton availability"
}triton 3.6.0
3.6.0
{
"command": "ls /usr/local/cuda*/include/cutlass 2>/dev/null || ls /usr/include/cutlass 2>/dev/null || find /home/infatoshi -maxdepth 4 -name \"cutlass\" -type d 2>/dev/null | head -5 || echo \"no cutlass found\"",
"description": "Check CUTLASS availability"
}zsh:1: no matches found: /usr/local/cuda*/include/cutlass
{
"command": "which ncu 2>/dev/null || echo \"ncu not found\"; which nsys 2>/dev/null || echo \"nsys not found\"",
"description": "Check profiling tools"
}/usr/local/cuda-13/bin/ncu
/usr/local/cuda-13/bin/nsys
{
"todos": [
{
"content": "Check GPU availability and environment (nvidia-smi, CUDA, nvcc)",
"status": "completed",
"priority": "high"
},
{
"content": "Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA",
"status": "completed",
"priority": "high"
},
{
"content": "Implement grouped GEMM + fused SwiGLU kernel (v1)",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchm
... (214 more chars)content: 'Check GPU availability and environment (nvidia-smi, CUDA, nvcc)'
status: 'completed'
priority: 'high'
content: 'Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA'
status: 'completed'
priority: 'high'
content: 'Implement grouped GEMM + fused SwiGLU kernel (v1)'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'pending'
priority: 'high'
content: 'Profile with ncu and optimize hotspots'
status: 'pending'
priority: 'medium'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && ./scripts/patch_torch.sh 2>&1 | tail -5",
"description": "Apply torch inductor CSE hotfix"
}Applying torch inductor CSE typing hotfix...
[skip] already patched: .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
[skip] already patched: /home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
Done.
{
"command": "uv run python -c \"import triton; import triton.language as tl; help(tl.dot)\" 2>&1 | head -60",
"description": "Check tl.dot signature"
}Help on function dot in module triton.language.core:
dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32, _semantic=None)
Returns the matrix product of two blocks.
The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
For three-dimensional blocks, `tl.dot` performs the batched matrix product,
where the first dimension of each block represents the batch dimension.
:param input: The first tensor to be multiplied.
:type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second tensor to be multiplied.
:type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
:type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Ava
... (280 more chars){
"command": "uv run python -c \"import triton.language as tl; help(tl.sigmoid)\" 2>&1 | head -15",
"description": "Check tl.sigmoid signature"
}Help on JITFunction in module triton.language.standard:
sigmoid = JITFunction(triton.language.standard:sigmoid)
Computes the element-wise sigmoid of :code:`x`.
:param x: the input values
:type x: Block
This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.sigmoid(...)` instead of
:code:`sigmoid(x, ...)`.
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
correctness.py
report.py
roofline.py
shapes.py
timing.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,205 @@
+"""Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
+
+Each program instance handles one (expert, n_tile) pair.
+Within each program, we loop over K tiles to accumulate the full GEMM.
+SwiGLU is fused in the epilogue: silu(gate) * up.
+
+Inputs:
+ hidden_states: (T_perm, H) bf16, permuted to expert order
+ expert_offsets: (E+1,) int32, prefix sums of token counts per expert
+ W_gate: (E, H, I) bf16
+ W_up: (E, H, I) bf16
+
+Output:
+ gated_up: (T_perm, I) bf16
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+# --- Tunable block sizes ---
+BM = 128
+BK = 64
+BN = 128
+
+
+@triton.jit
+def _grouped_gemm_swiglu_kernel(
+ x_ptr,
+ w_gate_ptr,
+ w_up_ptr,
+ out_ptr,
+ offsets_ptr,
+ T_perm: tl.constexpr,
+ H: tl.constexpr,
+ I: tl.constexpr,
+ E: tl.constexpr,
+ num_n_tiles: tl.constexpr,
+ BM: tl.constexpr,
+ BK: tl.constexpr,
+ BN: tl.constexpr,
+ num_k_tiles: tl.constexpr,
+ stride_xm,
+ stride_xk_unused,
+ stride_wg_h,
+ stride_wg_i,
+ stride_om,
+ stride_oi,
+):
+ """Grouped GEMM with fused SwiGLU epilogue.
+
+ Grid is 1D with total_programs = num_n_tiles * E.
+ Each program handles one (expert, n_tile) pair and loops over K tiles.
+ """
+ pid = tl.program_id(0)
+ expert = pid // num_n_tiles
+ n_tile = pid % num_n_tiles
+
+ m_start = 0
+ n_start = n_tile * BN
+ num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
+
+ n_offs = n_start + tl.arange(0, BN)
+ n_mask = n_offs < I
+
+ k_offs_base = tl.arange(0, BK)
+
+ gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+
+ for k_tile in range(num_k_tiles):
+ k_start = k_tile * BK
+ k_offs = k_start + k_offs_base
+ k_mask = k_offs < H
+
+ row_offs = m_start + tl.arange(0, BM)
+ load_row_mask = row_offs < num_rows
+ row_ptrs = row_offs * stride_xm
+
+ x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
+ load_mask = load_row_mask[:, None] & k_mask[None, :]
+ x = tl.load(x_ptrs, mask=load_mask, other=0.0)
+
+ wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wg_mask = k_mask[:, None] & n_mask[None, :]
+ wg = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+
+ wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wu = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+
+ gate_acc = tl.dot(x.to(tl.bfloat16), wg.to(tl.bfloat16), acc=gate_acc)
+ up_acc = tl.dot(x.to(tl.bfloat16), wu.to(tl.bfloat16), acc=up_acc)
+
+ gate = gate_acc
+ up = up_acc
+
+ gate_silu = gate * tl.sigmoid(gate)
+ out_fp32 = gate_silu * up
+ result = out_fp32.to(tl.bfloat16)
+
+ store_row_mask = row_offs < num_rows
+ out_ptrs = out_ptr + row_offs[:, None] * stride_om + n_offs[None, :] * stride_oi
+ store_mask = store_row_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, result, mask=store_mask)
+
+
+class FusedSwiGLUGroupedGEMM(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_states, W_gate, W_up, expert_offsets, E, I):
+ T_perm, H = hidden_states.shape
+ out = torch.empty((T_perm, I), dtype=torch.bfloat16, device=hidden_states.device)
+
+ num_k_tiles = triton.cdiv(H, BK)
+ num_n_tiles = triton.cdiv(I, BN)
+ total_programs = num_n_tiles * E
+
+ if total_programs > 0:
+ grid = (total_programs,)
+
+ stride_xm = H
+ stride_xk = 1
+ stride_wg_h = I
+ stride_wg_i = 1
+ stride_om = I
+ stride_oi = 1
+
+ _grouped_gemm_swiglu_kernel[grid](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ expert_offsets,
+ T_perm=T_perm,
+ H=H,
+ I=I,
+ E=E,
+ num_n_tiles=num_n_tiles,
+ BM=BM,
+ BK=BK,
+ BN=BN,
+ num_k_tiles=num_k_tiles,
+ stride_xm=stride_xm,
+ stride_xk_unused=stride_xk,
+ stride_wg_h=stride_wg_h,
+ stride_wg_i=stride_wg_i,
+ stride_om=stride_om,
+ stride_oi=stride_oi,
+ )
+
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+ return FusedSwiGLUGroupedGEMM.apply(
+ hidden_states, self.W_gate, self.W_up, expert_offsets, self.E, self.I
+ )
+
+
+T_total = 32768
+H = 4096
+I = 1536
+E = 128
+K = 8
+
+
+def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
+ T_perm = T_total * K
+ base = T_perm // E
+ rem = T_perm - base * E
+ counts = torch.full((E,), base, dtype=torch.int32, device=device)
+ counts[:rem] += 1
+ offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
+ offsets[1:] = torch.cumsum(counts, dim=0)
+ return offsets
+
+
+def get_inputs():
+ T_perm = T_total * K
+ hidden_states = torch.randn(T_perm,
... (189 more chars){
"content": "<5867 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 174, in forward
return FusedSwiGLUGroupedGEMM.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 596, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 130, in forward
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, *
... (1785 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -63,6 +63,10 @@
n_start = n_tile * BN
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
+ row_offs = m_start + tl.arange(0, BM)
+ load_row_mask = row_offs < num_rows
+ row_ptrs = row_offs * stride_xm
+
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
@@ -75,10 +79,6 @@
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
-
- row_offs = m_start + tl.arange(0, BM)
- load_row_mask = row_offs < num_rows
- row_ptrs = row_offs * stride_xm
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<790 chars \u2014 see below>",
"newString": "<778 chars \u2014 see below>"
} m_start = 0
n_start = n_tile * BN
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
x = tl.load(x_ptrs, mask=load_mask, other=0.0) m_start = 0
n_start = n_tile * BN
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
x = tl.load(x_ptrs, mask=load_mask, other=0.0)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -101,8 +101,7 @@
out_fp32 = gate_silu * up
result = out_fp32.to(tl.bfloat16)
- store_row_mask = row_offs < num_rows
- out_ptrs = out_ptr + row_offs[:, None] * stride_om + n_offs[None, :] * stride_oi
+ out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = store_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<234 chars \u2014 see below>",
"newString": "<169 chars \u2014 see below>"
} gate_silu = gate * tl.sigmoid(gate)
out_fp32 = gate_silu * up
result = out_fp32.to(tl.bfloat16)
store_row_mask = row_offs < num_rows
out_ptrs = out_ptr + row_offs[:, None] * stride_om + n_offs[None, :] * stride_oi gate_silu = gate * tl.sigmoid(gate)
out_fp32 = gate_silu * up
result = out_fp32.to(tl.bfloat16)
out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]Edit applied successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check v2"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 173, in forward
return FusedSwiGLUGroupedGEMM.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 596, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 129, in forward
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, *
... (1865 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -102,7 +102,7 @@
result = out_fp32.to(tl.bfloat16)
out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
- store_mask = store_row_mask[:, None] & n_mask[None, :]
+ store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<118 chars \u2014 see below>",
"oldString": "<119 chars \u2014 see below>"
} out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :] out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = store_row_mask[:, None] & n_mask[None, :]Edit applied successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check v3"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.198242 (atol=0.02, rtol=0.02)
{
"command": "<594 chars \u2014 see below>",
"description": "Test tl.dot bf16 accumulation precision"
}uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def test_dot_kernel(out_ptr, BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr):
a = tl.full((BM, BK), 1.0, dtype=tl.bfloat16)
b = tl.full((BK, BN), 1.0, dtype=tl.bfloat16)
acc = tl.zeros((BM, BN), dtype=tl.float32)
c = tl.dot(a, b, acc=acc)
# store first element
tl.store(out_ptr, c[0, 0])
out = torch.empty(1, dtype=torch.float32, device='cuda')
BM, BK, BN = 16, 16, 16
test_dot_kernel[(1,)](out, BM=BM, BK=BK, BN=BN)
print('Result:', out.item(), '(expected:', BK, ')')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars){
"command": "<1206 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug correctness on small shape"
}uv run python -c "
import sys, torch
sys.path.insert(0, '.')
import shapes, reference, solution
# Use smallest shape for speed
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape['T_total'], shape['H'], shape['I'], shape['E'], shape['K']
reference.T_total = T_total
reference.H = H
reference.I = I
reference.E = E
reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
ref_f = ref_out.float()
sol_f = sol_out.float()
diff = (ref_f - sol_f).abs()
print(f' max_abs_diff: {diff.max().item():.6f}')
print(f' mean_abs_diff: {diff.mean().item():.6f}')
print(f' allclose(atol=0.02, rtol=0.02): {torch.allclose(ref_f, sol_f, atol=0.02, rtol=0.02)}')
print(f' shapes: ref={ref_out.shape}, sol={sol_out.shape}')
print(f' ref range: [{ref_f.min():.3f}, {ref_f.max():.3f}]')
print(f' sol range: [{sol_f.min():.3f}, {sol_f.max():.3f}]')
" 2>&1max_abs_diff: 0.067871
mean_abs_diff: 0.002549
allclose(atol=0.02, rtol=0.02): False
shapes: ref=torch.Size([16384, 1024]), sol=torch.Size([16384, 1024])
ref range: [-0.068, 0.060]
sol range: [-0.039, 0.044]
{
"command": "<1059 chars \u2014 see below>",
"timeout": 60000,
"description": "Test tl.dot precision with bf16/fp32 acc"
}cat > /tmp/test_dot_prec.py << 'PYEOF'
import torch
import triton
import triton.language as tl
@triton.jit
def _test_bf16_dot(out_ptr, N: tl.constexpr):
pid = tl.program_id(0)
a_data = tl.full((N, N), 0.001, dtype=tl.bfloat16)
b_data = tl.full((N, N), 0.001, dtype=tl.bfloat16)
acc_bf16 = tl.zeros((N, N), dtype=tl.bfloat16)
acc_fp32 = tl.zeros((N, N), dtype=tl.float32)
c1 = tl.dot(a_data, b_data, acc=acc_bf16)
c1 = tl.dot(a_data, b_data, acc=c1)
c2 = tl.dot(a_data, b_data, acc=acc_fp32)
c2 = tl.dot(a_data, b_data, acc=c2)
if pid == 0:
tl.store(out_ptr, c1[0, 0])
tl.store(out_ptr + 1, c2[0, 0])
out = torch.empty(2, dtype=torch.float32, device='cuda')
_test_bf16_dot[(1,)](out, N=16)
print("bf16 acc result:", out[0].item())
print("fp32 acc result:", out[1].item())
# expected: 2 * 16 * 0.001 * 0.001 = 3.2e-5 (approx)
print("expected:", 2 * 16 * 0.001 * 0.001)
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python /tmp/test_dot_prec.py 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2054, in dot
res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1558, in dot
assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_dot_prec.py", line 24, in <module>
_test_bf16_dot[(1,)](out, N=16)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature,
... (1409 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"limit": 85
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
2:
3: Each program instance handles one (expert, n_tile) pair.
4: Within each program, we loop over K tiles to accumulate the full GEMM.
5: SwiGLU is fused in the epilogue: silu(gate) * up.
6:
7: Inputs:
8: hidden_states: (T_perm, H) bf16, permuted to expert order
9: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
10: W_gate: (E, H, I) bf16
11: W_up: (E, H, I) bf16
12:
13: Output:
14: gated_up: (T_perm, I) bf16
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import triton
21: import triton.language as tl
22:
23:
24: # --- Tunable block sizes ---
25: BM = 128
26: BK = 64
27: BN = 128
28:
29:
30: @triton.jit
31: def _grouped_gemm_swiglu_kernel(
32: x_ptr,
33: w_gate_ptr,
34: w_up_ptr,
35: out_ptr,
36: offsets_ptr,
37: T_perm: tl.constexpr,
38: H: tl.constexpr,
39: I: tl.constexpr,
40: E: tl.constexpr,
41: num_n_tiles: tl.constexpr,
42: BM: tl.constexpr,
43: BK: tl.constexpr,
44: BN: tl.constexpr,
45: num_k_tiles: tl.constexpr,
46: stride_xm,
47: stride_xk_unused,
48: stride_wg_h,
49: stride_wg_i,
50: stride_om,
51: stride_oi,
52: ):
53: """Grouped GEMM with fused SwiGLU e
... (1212 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -4,14 +4,7 @@
Within each program, we loop over K tiles to accumulate the full GEMM.
SwiGLU is fused in the epilogue: silu(gate) * up.
-Inputs:
- hidden_states: (T_perm, H) bf16, permuted to expert order
- expert_offsets: (E+1,) int32, prefix sums of token counts per expert
- W_gate: (E, H, I) bf16
- W_up: (E, H, I) bf16
-
-Output:
- gated_up: (T_perm, I) bf16
+Uses fp32 internal accumulation for precision; casts to bf16 at the end.
"""
from __future__ import annotations
@@ -50,11 +43,6 @@
stride_om,
stride_oi,
):
- """Grouped GEMM with fused SwiGLU epilogue.
-
- Grid is 1D with total_programs = num_n_tiles * E.
- Each program handles one (expert, n_tile) pair and loops over K tiles.
- """
pid = tl.program_id(0)
expert = pid // num_n_tiles
n_tile = pid % num_n_tiles
@@ -82,17 +70,20 @@
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
- x = tl.load(x_ptrs, mask=load_mask, other=0.0)
+ x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
+ x = x_bf16.to(tl.float32)
wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wg_mask = k_mask[:, None] & n_mask[None, :]
- wg = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+ wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+ wg = wg_bf16.to(tl.float32)
wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wu = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+ wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+ wu = wu_bf16.to(tl.float32)
- gate_acc = tl.dot(x.to(tl.bfloat16), wg.to(tl.bfloat16), acc=gate_acc)
- up_acc = tl.dot(x.to(tl.bfloat16), wu.to(tl.bfloat16), acc=up_acc)
+ gate_acc += tl.dot(x, wg)
+ up_acc += tl.dot(x, wu)
gate = gate_acc
up = up_acc
{
"content": "<5445 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check with fp32 accumulation"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.199219 (atol=0.02, rtol=0.02)
{
"command": "<1505 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug per-expert correctness"
}uv run python -c "
import sys, torch
sys.path.insert(0, '.')
import shapes, reference, solution
# Use smallest shape for speed
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape['T_total'], shape['H'], shape['I'], shape['E'], shape['K']
reference.T_total = T_total
reference.H = H
reference.I = I
reference.E = E
reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
hidden_states, expert_offsets = inputs
# Check per-expert correctness
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
for e in range(min(5, E)):
start = int(expert_offsets[e].item())
end = int(expert_offsets[e+1].item())
if end == start:
continue
ref_e = ref_out[start:end]
sol_e = sol_out[start:end]
diff = (ref_e.float() - sol_e.float()).abs()
print(f'expert {e}: M_e={end-start}, max_diff={diff.max():.6f}, mean_diff={diff.mean():.6f}')
# Also check: what's inside each expert's region?
# Do sol outputs look like zeros or random values?
print()
print('ref_out stats:', ref_out.float().mean(), ref_out.float().std())
print('sol_out stats:', sol_out.float().mean(), sol_out.float().std())
print('sol_out zeros?', (sol_out == 0).sum().item(), 'out of', sol_out.numel())
" 2>&1expert 0: M_e=256, max_diff=0.057220, mean_diff=0.003242
expert 1: M_e=256, max_diff=0.053711, mean_diff=0.002557
expert 2: M_e=256, max_diff=0.045410, mean_diff=0.002538
expert 3: M_e=256, max_diff=0.047119, mean_diff=0.002533
expert 4: M_e=256, max_diff=0.045654, mean_diff=0.002529
ref_out stats: tensor(-5.8693e-07, device='cuda:0') tensor(0.0040, device='cuda:0')
sol_out stats: tensor(2.1199e-08, device='cuda:0') tensor(0.0004, device='cuda:0')
sol_out zeros? 16646144 out of 16777216
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,10 +1,9 @@
"""Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
Each program instance handles one (expert, n_tile) pair.
-Within each program, we loop over K tiles to accumulate the full GEMM.
-SwiGLU is fused in the epilogue: silu(gate) * up.
-
-Uses fp32 internal accumulation for precision; casts to bf16 at the end.
+Within each program, we loop over M tiles (for variable token counts per expert)
+and K tiles (to accumulate the full GEMM).
+SwiGLU is fused in the epilogue: silu(gate) * up per M tile.
"""
from __future__ import annotations
@@ -14,7 +13,6 @@
import triton.language as tl
-# --- Tunable block sizes ---
BM = 128
BK = 64
BN = 128
@@ -36,6 +34,7 @@
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
+ max_m_tiles: tl.constexpr,
stride_xm,
stride_xk_unused,
stride_wg_h,
@@ -47,54 +46,60 @@
expert = pid // num_n_tiles
n_tile = pid % num_n_tiles
- m_start = 0
n_start = n_tile * BN
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
-
- row_offs = m_start + tl.arange(0, BM)
- load_row_mask = row_offs < num_rows
- row_ptrs = row_offs * stride_xm
+ m_tiles = tl.cdiv(num_rows, BM)
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
- gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
- up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ for m_tile in range(max_m_tiles):
+ if m_tile >= m_tiles:
+ break
- for k_tile in range(num_k_tiles):
- k_start = k_tile * BK
- k_offs = k_start + k_offs_base
- k_mask = k_offs < H
+ m_start = m_tile * BM
- x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
- load_mask = load_row_mask[:, None] & k_mask[None, :]
- x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
- x = x_bf16.to(tl.float32)
+ row_offs = m_start + tl.arange(0, BM)
+ load_row_mask = row_offs < num_rows
+ row_ptrs = row_offs * stride_xm
- wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wg_mask = k_mask[:, None] & n_mask[None, :]
- wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
- wg = wg_bf16.to(tl.float32)
+ gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ up_acc = tl.zeros((BM, BN), dtype=tl.float32)
- wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
- wu = wu_bf16.to(tl.float32)
+ for k_tile in range(num_k_tiles):
+ k_start = k_tile * BK
+ k_offs = k_start + k_offs_base
+ k_mask = k_offs < H
- gate_acc += tl.dot(x, wg)
- up_acc += tl.dot(x, wu)
+ x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
+ load_mask = load_row_mask[:, None] & k_mask[None, :]
+ x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
+ x = x_bf16.to(tl.float32)
- gate = gate_acc
- up = up_acc
+ wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wg_mask = k_mask[:, None] & n_mask[None, :]
+ wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+ wg = wg_bf16.to(tl.float32)
- gate_silu = gate * tl.sigmoid(gate)
- out_fp32 = gate_silu * up
- result = out_fp32.to(tl.bfloat16)
+ wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+ wu = wu_bf16.to(tl.float32)
- out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
- store_mask = load_row_mask[:, None] & n_mask[None, :]
- tl.store(out_ptrs, result, mask=store_mask)
+ gate_acc += tl.dot(x, wg)
+ up_acc += tl.dot(x, wu)
+
+ gate = gate_acc
+ up = up_acc
+
+ gate_silu = gate * tl.sigmoid(gate)
+ out_fp32 = gate_silu * up
+ result = out_fp32.to(tl.bfloat16)
+
+ out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
+ store_mask = load_row_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, result, mask=store_mask)
class FusedSwiGLUGroupedGEMM(torch.autograd.Function):
@@ -107,15 +112,11 @@
num_n_tiles = triton.cdiv(I, BN)
total_programs = num_n_tiles * E
- if total_programs > 0:
+ max_tokens_per_expert = int(expert_offsets[1:].diff().max().item()) if T_perm > 0 else 0
+ max_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
+
+ if total_programs > 0 and max_m_tiles > 0:
grid = (total_programs,)
-
- stride_xm = H
- stride_xk = 1
- stride_wg_h = I
- stride_wg_i = 1
- stride_om = I
- stride_oi = 1
_grouped_gemm_swiglu_kernel[grid](
hidden_states,
@@ -132,12 +133,13 @@
BK=BK,
BN=BN,
num_k_tiles=num_k_tiles,
- stride_xm=stride_xm,
- stride_xk_unused=stride_xk,
- stride_wg_h=stride_wg_h,
- stride_wg_i=stride_wg_i,
- stride_om=stride_om,
- stride_oi=stride_oi,
+ max_m_tiles=max_m_tiles,
+ stride_xm=H,
+ stride_xk_unused=1,
+ stride_wg_h=I,
+ stride_wg_i=1,
+ stride_om=I,
+ stride_oi=1,
)
return out
{
"content": "<5701 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -56,8 +56,9 @@
k_offs_base = tl.arange(0, BK)
for m_tile in range(max_m_tiles):
- if m_tile >= m_tiles:
- break
+ skip = m_tile >= m_tiles
+ if skip:
+ continue
m_start = m_tile * BM
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<381 chars \u2014 see below>",
"oldString": "<358 chars \u2014 see below>"
} for m_tile in range(max_m_tiles):
skip = m_tile >= m_tiles
if skip:
continue
m_start = m_tile * BM
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32) for m_tile in range(max_m_tiles):
if m_tile >= m_tiles:
break
m_start = m_tile * BM
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -30,6 +30,12 @@
I: tl.constexpr,
E: tl.constexpr,
num_n_tiles: tl.constexpr,
+ BM: tl.constexpr,
+ BK: tl.constexpr,
+ BN: tl.constexpr,
+ num_k_tiles: tl.constexpr,
+ max_m_tiles: tl.constexpr,
+ stride_xm,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<194 chars \u2014 see below>",
"oldString": "<51 chars \u2014 see below>"
} E: tl.constexpr,
num_n_tiles: tl.constexpr,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
max_m_tiles: tl.constexpr,
stride_xm, E: tl.constexpr,
num_n_tiles: tl.constexpr,Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"limit": 55
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
2:
3: Each program instance handles one (expert, n_tile) pair.
4: Within each program, we loop over M tiles (for variable token counts per expert)
5: and K tiles (to accumulate the full GEMM).
6: SwiGLU is fused in the epilogue: silu(gate) * up per M tile.
7: """
8: from __future__ import annotations
9:
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14:
15:
16: BM = 128
17: BK = 64
18: BN = 128
19:
20:
21: @triton.jit
22: def _grouped_gemm_swiglu_kernel(
23: x_ptr,
24: w_gate_ptr,
25: w_up_ptr,
26: out_ptr,
27: offsets_ptr,
28: T_perm: tl.constexpr,
29: H: tl.constexpr,
30: I: tl.constexpr,
31: E: tl.constexpr,
32: num_n_tiles: tl.constexpr,
33: BM: tl.constexpr,
34: BK: tl.constexpr,
35: BN: tl.constexpr,
36: num_k_tiles: tl.constexpr,
37: max_m_tiles: tl.constexpr,
38: stride_xm,
39: BM: tl.constexpr,
40: BK: tl.constexpr,
41: BN: tl.constexpr,
42: num_k_tiles: tl.constexpr,
43: max_m_tiles: tl.constexpr,
44: stride_xm,
45: stride_xk_unused,
46: stride_wg_h,
47: stride_wg_i,
48: stride_om,
49: stride_oi,
50: ):
51: pid = tl.program_id(0)
52: expert = pid // num_n_tiles
53: n_tile = pid % num_n_tiles
54:
5
... (96 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -30,12 +30,6 @@
I: tl.constexpr,
E: tl.constexpr,
num_n_tiles: tl.constexpr,
- BM: tl.constexpr,
- BK: tl.constexpr,
- BN: tl.constexpr,
- num_k_tiles: tl.constexpr,
- max_m_tiles: tl.constexpr,
- stride_xm,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<195 chars \u2014 see below>",
"oldString": "<338 chars \u2014 see below>"
} num_n_tiles: tl.constexpr,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
max_m_tiles: tl.constexpr,
stride_xm,
stride_xk_unused, num_n_tiles: tl.constexpr,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
max_m_tiles: tl.constexpr,
stride_xm,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
max_m_tiles: tl.constexpr,
stride_xm,
stride_xk_unused,Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"offset": 110,
"limit": 50
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
110: out = torch.empty((T_perm, I), dtype=torch.bfloat16, device=hidden_states.device)
111:
112: num_k_tiles = triton.cdiv(H, BK)
113: num_n_tiles = triton.cdiv(I, BN)
114: total_programs = num_n_tiles * E
115:
116: max_tokens_per_expert = int(expert_offsets[1:].diff().max().item()) if T_perm > 0 else 0
117: max_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
118:
119: if total_programs > 0 and max_m_tiles > 0:
120: grid = (total_programs,)
121:
122: _grouped_gemm_swiglu_kernel[grid](
123: hidden_states,
124: W_gate,
125: W_up,
126: out,
127: expert_offsets,
128: T_perm=T_perm,
129: H=H,
130: I=I,
131: E=E,
132: num_n_tiles=num_n_tiles,
133: BM=BM,
134: BK=BK,
135: BN=BN,
136: num_k_tiles=num_k_tiles,
137: max_m_tiles=max_m_tiles,
138: stride_xm=H,
139: stride_xk_unused=1,
140: stride_wg_h=I,
141: stride_wg_i=1,
142: stride_om=I,
143: stride_oi=1,
144: )
145:
146: return out
147:
148:
149: class Model(nn.Module):
150:
... (519 more chars){
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check with M tile loop"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 167, in forward
return FusedSwiGLUGroupedGEMM.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 596, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 122, in forward
_grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, *
... (1772 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"limit": 50,
"offset": 57
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
57:
58: for m_tile in range(max_m_tiles):
59: skip = m_tile >= m_tiles
60: if skip:
61: continue
62:
63: m_start = m_tile * BM
64:
65: row_offs = m_start + tl.arange(0, BM)
66: load_row_mask = row_offs < num_rows
67: row_ptrs = row_offs * stride_xm
68:
69: gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
70: up_acc = tl.zeros((BM, BN), dtype=tl.float32)
71:
72: for k_tile in range(num_k_tiles):
73: k_start = k_tile * BK
74: k_offs = k_start + k_offs_base
75: k_mask = k_offs < H
76:
77: x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
78: load_mask = load_row_mask[:, None] & k_mask[None, :]
79: x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
80: x = x_bf16.to(tl.float32)
81:
82: wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
83: wg_mask = k_mask[:, None] & n_mask[None, :]
84: wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
85: wg = wg_bf16.to(tl.float32)
86:
87: wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
88: wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
89: wu = wu_bf16.to(tl.float32)
90:
... (620 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -56,51 +56,45 @@
k_offs_base = tl.arange(0, BK)
for m_tile in range(max_m_tiles):
- skip = m_tile >= m_tiles
- if skip:
- continue
+ if m_tile < m_tiles:
+ m_start = m_tile * BM
- m_start = m_tile * BM
+ row_offs = m_start + tl.arange(0, BM)
+ load_row_mask = row_offs < num_rows
+ row_ptrs = row_offs * stride_xm
- row_offs = m_start + tl.arange(0, BM)
- load_row_mask = row_offs < num_rows
- row_ptrs = row_offs * stride_xm
+ gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ up_acc = tl.zeros((BM, BN), dtype=tl.float32)
- gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
- up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ for k_tile in range(num_k_tiles):
+ k_start = k_tile * BK
+ k_offs = k_start + k_offs_base
+ k_mask = k_offs < H
- for k_tile in range(num_k_tiles):
- k_start = k_tile * BK
- k_offs = k_start + k_offs_base
- k_mask = k_offs < H
+ x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
+ load_mask = load_row_mask[:, None] & k_mask[None, :]
+ x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
+ x = x_bf16.to(tl.float32)
- x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
- load_mask = load_row_mask[:, None] & k_mask[None, :]
- x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
- x = x_bf16.to(tl.float32)
+ wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wg_mask = k_mask[:, None] & n_mask[None, :]
+ wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+ wg = wg_bf16.to(tl.float32)
- wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wg_mask = k_mask[:, None] & n_mask[None, :]
- wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
- wg = wg_bf16.to(tl.float32)
+ wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+ wu = wu_bf16.to(tl.float32)
- wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
- wu = wu_bf16.to(tl.float32)
+ gate_acc += tl.dot(x, wg)
+ up_acc += tl.dot(x, wu)
- gate_acc += tl.dot(x, wg)
- up_acc += tl.dot(x, wu)
+ gate_silu = gate_acc * tl.sigmoid(gate_acc)
+ out_fp32 = gate_silu * up_acc
+ result = out_fp32.to(tl.bfloat16)
- gate = gate_acc
- up = up_acc
-
- gate_silu = gate * tl.sigmoid(gate)
- out_fp32 = gate_silu * up
- result = out_fp32.to(tl.bfloat16)
-
- out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
- store_mask = load_row_mask[:, None] & n_mask[None, :]
- tl.store(out_ptrs, result, mask=store_mask)
+ out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
+ store_mask = load_row_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, result, mask=store_mask)
class FusedSwiGLUGroupedGEMM(torch.autograd.Function):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"newString": "<1704 chars \u2014 see below>",
"oldString": "<1664 chars \u2014 see below>"
} for m_tile in range(max_m_tiles):
if m_tile < m_tiles:
m_start = m_tile * BM
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
out_fp32 = gate_silu * up_acc
result = out_fp32.to(tl.bfloat16)
out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask) for m_tile in range(max_m_tiles):
skip = m_tile >= m_tiles
if skip:
continue
m_start = m_tile * BM
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs = row_offs * stride_xm
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate = gate_acc
up = up_acc
gate_silu = gate * tl.sigmoid(gate)
out_fp32 = gate_silu * up
result = out_fp32.to(tl.bfloat16)
out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)Edit applied successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check v2"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.177734 (atol=0.02, rtol=0.02)
{
"command": "<1449 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug correctness details"
}uv run python -c "
import sys, torch
sys.path.insert(0, '.')
import shapes, reference, solution
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape['T_total'], shape['H'], shape['I'], shape['E'], shape['K']
reference.T_total = T_total
reference.H = H
reference.I = I
reference.E = E
reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
ref_f = ref_out.float()
sol_f = sol_out.float()
print('ref range:', ref_f.min(), ref_f.max(), 'std:', ref_f.std())
print('sol range:', sol_f.min(), sol_f.max(), 'std:', sol_f.std())
print('sol zeros:', (sol_out == 0).sum().item(), '/ total:', sol_out.numel())
print('allclose(0.02):', torch.allclose(ref_f, sol_f, atol=0.02, rtol=0.02))
# Check per-expert
for e in range(min(5, E)):
start = int(inputs[1][e].item())
end = int(inputs[1][e+1].item())
if end == start: continue
ref_e = ref_out[start:end]
sol_e = sol_out[start:end]
diff = (ref_e.float() - sol_e.float()).abs()
print(f' expert {e} M_e={end-start}: max={diff.max():.6f} mean={diff.mean():.6f} sol_zero={((sol_e==0).sum()/2):.0f}/{end-start}')
" 2>&1ref range: tensor(-0.0669, device='cuda:0') tensor(0.0713, device='cuda:0') std: tensor(0.0040, device='cuda:0')
sol range: tensor(-0.0515, device='cuda:0') tensor(0.0527, device='cuda:0') std: tensor(0.0005, device='cuda:0')
sol zeros: 16515072 / total: 16777216
allclose(0.02): False
expert 0 M_e=256: max=0.062134 mean=0.003256 sol_zero=65536/256
expert 1 M_e=256: max=0.047363 mean=0.003277 sol_zero=65536/256
expert 2 M_e=256: max=0.041504 mean=0.002514 sol_zero=131072/256
expert 3 M_e=256: max=0.052979 mean=0.002544 sol_zero=131072/256
expert 4 M_e=256: max=0.051025 mean=0.002551 sol_zero=131072/256
{
"command": "<2203 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug kernel parameters and per-row stats"
}uv run python -c "
import sys, torch
sys.path.insert(0, '.')
import shapes, reference, solution
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape['T_total'], shape['H'], shape['I'], shape['E'], shape['K']
reference.T_total = T_total
reference.H = H
reference.I = I
reference.E = E
reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
hidden_states, expert_offsets = inputs
# Check kernel parameters
T_perm = hidden_states.shape[0]
tokens_per_expert = expert_offsets[1:].diff()
max_m = int(tokens_per_expert.max().item())
BM, BK_, BN_ = 128, 64, 128
num_k_tiles = (H + BK_ - 1) // BK_
num_n_tiles = (I + BN_ - 1) // BN_
max_m_tiles = (max_m + BM - 1) // BM
print(f'T_perm={T_perm}, H={H}, I={I}, E={E}')
print(f'M per expert: min={tokens_per_expert.min()}, max={tokens_per_expert.max()}, mean={tokens_per_expert.float().mean():.1f}')
print(f'num_k_tiles={num_k_tiles}, num_n_tiles={num_n_tiles}, max_m_tiles={max_m_tiles}')
print(f'total_programs={num_n_tiles * E}')
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
# Check how many elements differ
ref_f = ref_out.float()
sol_f = sol_out.float()
total = sol_out.numel()
zeros = (sol_out == 0).sum().item()
ref_zeros = (ref_out == 0).sum().item()
print(f'sol zeros: {zeros}/{total} ({100*zeros/total:.1f}%)')
print(f'ref zeros: {ref_zeros}/{total} ({100*ref_zeros/total:.1f}%)')
# Check a small region manually
e = 0
start = int(expert_offsets[e].item())
end = int(expert_offsets[e+1].item())
print(f'\\nExpert 0: rows {start}:{end}, M_e={end-start}')
# Check per-row correctness
ref_e = ref_out[start:end].float() # (M_e, I)
sol_e = sol_out[start:end].float()
row_diff = (ref_e - sol_e).abs().max(dim=1).values # (M_e,)
for i in range(0, end-start, 16):
chunk_diff = row_diff[i:min(i+16, end-start)]
print(f' rows {i}-{min(i+15,end-start-1)}: max_diff={chunk_diff.max():.6f}, mean={chunk_diff.mean():.6f}')
" 2>&1T_perm=16384, H=2048, I=1024, E=64
M per expert: min=256, max=256, mean=256.0
num_k_tiles=32, num_n_tiles=8, max_m_tiles=2
total_programs=512
sol zeros: 16515072/16777216 (98.4%)
ref zeros: 6511/16777216 (0.0%)
Expert 0: rows 0:256, M_e=256
rows 0-15: max_diff=0.041992, mean=0.029120
rows 16-31: max_diff=0.057739, mean=0.032727
rows 32-47: max_diff=0.052872, mean=0.027780
rows 48-63: max_diff=0.045668, mean=0.029935
rows 64-79: max_diff=0.036926, mean=0.028873
rows 80-95: max_diff=0.044922, mean=0.027596
rows 96-111: max_diff=0.034058, mean=0.027592
rows 112-127: max_diff=0.034767, mean=0.027769
rows 128-143: max_diff=0.040649, mean=0.029005
rows 144-159: max_diff=0.034180, mean=0.026944
rows 160-175: max_diff=0.043945, mean=0.029494
rows 176-191: max_diff=0.045471, mean=0.029363
rows 192-207: max_diff=0.035645, mean=0.026232
rows 208-223: max_diff=0.038467, mean=0.028862
rows 224-239: max_diff=0.051361, mean=0.028908
rows 240-255: max_diff=0.036377, mean=0.027529
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,9 +1,11 @@
"""Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
-Each program instance handles one (expert, n_tile) pair.
-Within each program, we loop over M tiles (for variable token counts per expert)
-and K tiles (to accumulate the full GEMM).
-SwiGLU is fused in the epilogue: silu(gate) * up per M tile.
+Each program instance handles ONE (expert, n_tile, m_tile) triple.
+Within each program, we loop over K tiles to accumulate the full GEMM.
+SwiGLU is fused in the epilogue: silu(gate) * up.
+
+Grid: E * num_n_tiles * max_m_tiles programs total.
+Program pid encodes (expert, n_tile, m_tile) via stride decomposition.
"""
from __future__ import annotations
@@ -25,16 +27,15 @@
w_up_ptr,
out_ptr,
offsets_ptr,
- T_perm: tl.constexpr,
H: tl.constexpr,
I: tl.constexpr,
E: tl.constexpr,
num_n_tiles: tl.constexpr,
+ max_m_tiles: tl.constexpr,
BM: tl.constexpr,
BK: tl.constexpr,
BN: tl.constexpr,
num_k_tiles: tl.constexpr,
- max_m_tiles: tl.constexpr,
stride_xm,
stride_xk_unused,
stride_wg_h,
@@ -43,58 +44,61 @@
stride_oi,
):
pid = tl.program_id(0)
- expert = pid // num_n_tiles
- n_tile = pid % num_n_tiles
+ expert = pid // (num_n_tiles * max_m_tiles)
+ rem = pid % (num_n_tiles * max_m_tiles)
+ n_tile = rem // max_m_tiles
+ m_tile = rem % max_m_tiles
+
+ num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
+ m_tiles_needed = tl.cdiv(num_rows, BM)
+ if m_tile >= m_tiles_needed:
+ return
+
+ m_start = m_tile * BM
n_start = n_tile * BN
- num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
- m_tiles = tl.cdiv(num_rows, BM)
+
+ row_offs = m_start + tl.arange(0, BM)
+ load_row_mask = row_offs < num_rows
+ row_ptrs = row_offs * stride_xm
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
- for m_tile in range(max_m_tiles):
- if m_tile < m_tiles:
- m_start = m_tile * BM
+ gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ up_acc = tl.zeros((BM, BN), dtype=tl.float32)
- row_offs = m_start + tl.arange(0, BM)
- load_row_mask = row_offs < num_rows
- row_ptrs = row_offs * stride_xm
+ for k_tile in range(num_k_tiles):
+ k_start = k_tile * BK
+ k_offs = k_start + k_offs_base
+ k_mask = k_offs < H
- gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
- up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
+ load_mask = load_row_mask[:, None] & k_mask[None, :]
+ x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
+ x = x_bf16.to(tl.float32)
- for k_tile in range(num_k_tiles):
- k_start = k_tile * BK
- k_offs = k_start + k_offs_base
- k_mask = k_offs < H
+ wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wg_mask = k_mask[:, None] & n_mask[None, :]
+ wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
+ wg = wg_bf16.to(tl.float32)
- x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
- load_mask = load_row_mask[:, None] & k_mask[None, :]
- x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
- x = x_bf16.to(tl.float32)
+ wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
+ wu = wu_bf16.to(tl.float32)
- wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wg_mask = k_mask[:, None] & n_mask[None, :]
- wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
- wg = wg_bf16.to(tl.float32)
+ gate_acc += tl.dot(x, wg)
+ up_acc += tl.dot(x, wu)
- wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
- wu = wu_bf16.to(tl.float32)
+ gate_silu = gate_acc * tl.sigmoid(gate_acc)
+ out_fp32 = gate_silu * up_acc
+ result = out_fp32.to(tl.bfloat16)
- gate_acc += tl.dot(x, wg)
- up_acc += tl.dot(x, wu)
-
- gate_silu = gate_acc * tl.sigmoid(gate_acc)
- out_fp32 = gate_silu * up_acc
- result = out_fp32.to(tl.bfloat16)
-
- out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
- store_mask = load_row_mask[:, None] & n_mask[None, :]
- tl.store(out_ptrs, result, mask=store_mask)
+ out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
+ store_mask = load_row_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, result, mask=store_mask)
class FusedSwiGLUGroupedGEMM(torch.autograd.Function):
@@ -105,12 +109,12 @@
num_k_tiles = triton.cdiv(H, BK)
num_n_tiles = triton.cdiv(I, BN)
- total_programs = num_n_tiles * E
max_tokens_per_expert = int(expert_offsets[1:].diff().max().item()) if T_perm > 0 else 0
max_m_tiles = triton.cdiv(max_tokens_per_expert, BM)
+ total_programs = E * num_n_tiles * max_m_tiles
- if total_programs > 0 and max_m_tiles > 0:
+ if total_programs > 0:
grid = (total_programs,)
_grouped_gemm_swiglu_kernel[grid](
@@ -119,16 +123,15 @@
W_up,
out,
expert_offsets,
- T_perm=T_perm,
H=H,
I=I,
E=E,
... (330 more chars){
"content": "<5620 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "timeout 120 uv run python check.py 2>&1",
"timeout": 180000,
"description": "Run correctness check v3"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.197266 (atol=0.02, rtol=0.02)
{
"command": "<1451 chars \u2014 see below>",
"timeout": 120000,
"description": "Check zeros and diff distribution"
}uv run python -c "
import sys, torch
sys.path.insert(0, '.')
import shapes, reference, solution
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape['T_total'], shape['H'], shape['I'], shape['E'], shape['K']
reference.T_total = T_total
reference.H = H
reference.I = I
reference.E = E
reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
ref_f = ref_out.float()
sol_f = sol_out.float()
total = sol_out.numel()
zeros = (sol_out == 0).sum().item()
print(f'sol zeros: {zeros}/{total} ({100*zeros/total:.1f}%)')
print(f'ref zeros: {(ref_out==0).sum().item()}/{total} ({100*(ref_out==0).sum().item()/total:.1f}%)')
diff = (ref_f - sol_f).abs()
print(f'max_diff: {diff.max():.6f}')
print(f'mean_diff: {diff.mean():.6f}')
print(f'shape: ref={ref_out.shape} sol={sol_out.shape}')
print(f'ref range: [{ref_f.min():.6f}, {ref_f.max():.6f}]')
print(f'sol range: [{sol_f.min():.6f}, {sol_f.max():.6f}]')
print(f'N elements where diff > 0.02: {(diff > 0.02).sum().item()} ({100*(diff>0.02).sum().item()/total:.1f}%)')
print(f'N elements where diff > 0.05: {(diff > 0.05).sum().item()}')
" 2>&1sol zeros: 16515072/16777216 (98.4%)
ref zeros: 6509/16777216 (0.0%)
max_diff: 0.067871
mean_diff: 0.002558
shape: ref=torch.Size([16384, 1024]) sol=torch.Size([16384, 1024])
ref range: [-0.067871, 0.062500]
sol range: [-0.041016, 0.046631]
N elements where diff > 0.02: 40251 (0.2%)
N elements where diff > 0.05: 38
{
"command": "<2134 chars \u2014 see below>",
"timeout": 120000,
"description": "Manually verify grid and kernel"
}uv run python -c "
import sys, torch, triton
sys.path.insert(0, '.')
import shapes, reference
shape = shapes.SHAPES[1]
T_total, H, I, E, K = shape.values()
reference.T_total = T_total; reference.H = H; reference.I = I; reference.E = E; reference.K = K
device = torch.device('cuda:0')
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
hidden_states, expert_offsets = inputs
# Manually check grid calculation
T_perm = hidden_states.shape[0]
BM, BK, BN = 128, 64, 128
num_k_tiles_val = triton.cdiv(H, BK)
num_n_tiles_val = triton.cdiv(I, BN)
max_tokens = int(expert_offsets[1:].diff().max().item())
max_m_tiles_val = triton.cdiv(max_tokens, BM)
total_programs = E * num_n_tiles_val * max_m_tiles_val
print(f'H={H}, I={I}, E={E}, K={K}')
print(f'num_k_tiles={num_k_tiles_val}, num_n_tiles={num_n_tiles_val}, max_m_tiles={max_m_tiles_val}')
print(f'total_programs = {E}*{num_n_tiles_val}*{max_m_tiles_val} = {total_programs}')
# Manually launch the kernel
import solution
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
out = torch.empty((T_perm, I), dtype=torch.bfloat16, device=device)
from solution import _grouped_gemm_swiglu_kernel
_grouped_gemm_swiglu_kernel[(total_programs,)](
hidden_states, ref_model.W_gate, ref_model.W_up, out, expert_offsets,
H=H, I=I, E=E,
num_n_tiles=num_n_tiles_val,
max_m_tiles=max_m_tiles_val,
BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles_val,
stride_xm=H, stride_xk_unused=1, stride_wg_h=I, stride_wg_i=1,
stride_om=I, stride_oi=1,
)
torch.cuda.synchronize()
with torch.no_grad():
ref_out = ref_model(*inputs)
ref_f = ref_out.float()
out_f = out.float()
print(f'sol zeros: {(out==0).sum().item()}/{out.numel()}')
print(f'ref range: [{ref_f.min():.4f}, {ref_f.max():.4f}]')
print(f'sol range: [{out_f.min():.4f}, {out_f.max():.4f}]')
diff = (ref_f - out_f).abs()
print(f'max_diff: {diff.max():.6f}')
print(f'allclose(0.02): {torch.allclose(ref_f, out_f, 0.02, 0.02)}')
" 2>&1H=2048, I=1024, E=64, K=4
num_k_tiles=32, num_n_tiles=8, max_m_tiles=2
total_programs = 64*8*2 = 1024
sol zeros: 16515072/16777216
ref range: [-0.0698, 0.0569]
sol range: [-0.0422, 0.0417]
max_diff: 0.069824
allclose(0.02): False
{
"command": "<2321 chars \u2014 see below>",
"timeout": 60000,
"description": "Debug pid decomposition"
}cat > /tmp/debug_sol.py << 'PYEOF'
import torch
import triton
import triton.language as tl
BM = 128
BK = 64
BN = 128
@triton.jit
def _debug_kernel(out_ptr, offsets_ptr, pid_dbg_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
# Log pid info
idx = pid * 3
tl.store(pid_dbg_ptr + idx, tl.cast(tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert), tl.int32))
tl.store(pid_dbg_ptr + idx + 1, n_tile)
tl.store(pid_dbg_ptr + idx + 2, m_tile_pid)
# Just write a marker value to verify which pid writes where
if pid == 0:
# Write a special marker
tl.store(out_ptr, tl.cast(99.0, tl.bfloat16))
if pid == 250:
# Write another marker
tl.store(out_ptr + 1, tl.cast(88.0, tl.bfloat16))
def test():
E, H, I, num_n_tiles, max_m_tiles = 64, 2048, 1024, 8, 2
total_programs = E * num_n_tiles * max_m_tiles
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
for i in range(E + 1):
offsets[i] = i * 256
out = torch.zeros((16384, I), dtype=torch.bfloat16, device='cuda')
pid_dbg = torch.zeros((total_programs * 3,), dtype=torch.int32, device='cuda')
_debug_kernel[(total_programs,)](
out, offsets, pid_dbg,
H=H, I=I, E=E,
num_n_tiles=num_n_tiles, max_m_tiles=max_m_tiles,
BM=BM, BK=BK, BN=BN,
num_k_tiles=triton.cdiv(H, BK),
)
torch.cuda.synchronize()
# Check pid decomposition
print(f"Grid size: {total_programs}")
dbg = pid_dbg.cpu()
for p in [0, 15, 16, 250, 1023]:
m_e = dbg[p*3].item()
nt = dbg[p*3+1].item()
mt = dbg[p*3+2].item()
print(f" pid {p}: M_e={m_e}, n_tile={nt}, m_tile={mt}")
print(f"\nOutput markers: out[0,0]={out[0,0].item()}, out[0,1]={out[0,1].item()}")
print(f"Non-zero output count: {(out != 0).sum().item()}")
test()
PYEOF
uv run python /tmp/debug_sol.py 2>&1Grid size: 1024
pid 0: M_e=256, n_tile=0, m_tile=0
pid 15: M_e=256, n_tile=7, m_tile=1
pid 16: M_e=256, n_tile=0, m_tile=0
pid 250: M_e=256, n_tile=5, m_tile=0
pid 1023: M_e=256, n_tile=7, m_tile=1
Output markers: out[0,0]=99.0, out[0,1]=88.0
Non-zero output count: 2
{
"command": "<4911 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug kernel with markers"
}cat > /tmp/debug_sol2.py << 'PYEOF'
import torch
import triton
import triton.language as tl
BM = 128
BK = 64
BN = 128
@triton.jit
def _debug_gemm_kernel(out_ptr, offsets_ptr, marker_ptr,
x_ptr, w_gate_ptr, w_up_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr,
stride_xm, stride_xk_unused, stride_wg_h, stride_wg_i,
stride_om, stride_oi):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
# Mark whether this pid enters or skips
early_return = m_tile_pid >= m_tiles_needed
if early_return:
tl.store(marker_ptr + pid, 0)
return
tl.store(marker_ptr + pid, 1)
m_start = m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + expert * stride_xm + row_offs[:, None] + k_offs[None, :] * stride_xk_unused
load_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_ptrs = w_gate_ptr + expert * stride_wg_h * H + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = w_up_ptr + expert * stride_wg_h * H + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
out_fp32 = gate_silu * up_acc
result = out_fp32.to(tl.bfloat16)
row_ptrs = row_offs * stride_om
out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
E, H, I, K_routing = 64, 2048, 1024, 4
T_total, T_perm = 4096, 16384
num_n_tiles = 8
max_m_tiles = 2
num_k_tiles = triton.cdiv(H, BK)
total_programs = E * num_n_tiles * max_m_tiles
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
for i in range(E + 1):
offsets[i] = i * 256 # 256 tokens per expert
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
marker = torch.zeros(total_programs, dtype=torch.int32, device='cuda')
_debug_gemm_kernel[(total_programs,)](
out, offsets, marker, hidden, W_gate, W_up,
H=H, I=I, E=E,
num_n_tiles=num_n_tiles, max_m_tiles=max_m_tiles,
BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles,
stride_xm=H, stride_xk_unused=1, stride_wg_h=I, stride_wg_i=1,
stride_om=I, stride_oi=1,
)
torch.cuda.synchronize()
m_cpu = marker.cpu()
entered = (m_cpu == 1).sum().item()
skipped = (m
... (911 more chars)Pid markers: entered=1024, skipped=0, other=0
Non-zero output: 262144/16777216
First expert output sample: out[0:4, 0:4] = tensor([[-8.4686e-04, 2.3603e-05, -1.9836e-03, 9.0122e-05],
[-8.6670e-03, 6.4392e-03, 3.1128e-03, -3.3112e-03],
[-2.8839e-03, 6.6757e-05, 6.1035e-03, 3.0975e-03],
[-5.0659e-03, -1.9836e-03, 5.9509e-03, 6.1035e-03]], device='cuda:0',
dtype=torch.bfloat16)
max_abs_diff: 0.065430
allclose(0.02): False
{
"command": "<4747 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug with small shapes"
}rm -rf /tmp/triton_* 2>/dev/null; rm -rf __pycache__ 2>/dev/null
cat > /tmp/debug_sol3.py << 'PYEOF'
import torch, triton, triton.language as tl
import torch.nn.functional as F
BM = 64
BK = 64
BN = 128
@triton.jit
def _test_kernel(out_ptr, offsets_ptr, x_ptr, wg_ptr, wu_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
if m_tile_pid >= m_tiles_needed:
return
m_start = m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
row_ptrs_x = row_offs * H
row_ptrs_out = row_offs * I
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
wg_base = expert.to(tl.int64) * H.to(tl.int64) * I.to(tl.int64)
x_base = tl.cast(0, tl.int64) # already included in row_ptrs
x_ptrs = x_ptr + row_ptrs_x[:, None] + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=x_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_ptrs = wg_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = wu_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
combined = gate_silu * up_acc
result = combined.to(tl.bfloat16)
out_ptrs = out_ptr + row_ptrs_out[:, None] + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
E, H, I, K_routing = 4, 256, 128, 2
T_total = 128
T_perm = T_total * K_routing # 256
tokens_per_expert = T_perm // E # 64
# Even split
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
max_m_tiles_val = triton.cdiv(tokens_per_expert, BM) # ceil(64/64)=1
num_n_tiles_val = triton.cdiv(I, BN) # ceil(128/128)=1
num_k_tiles_val = triton.cdiv(H, BK) # ceil(256/64)=4
total = E * num_n_tiles_val * max_m_tiles_val # 4*1*1=4
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
H=H, I=I, E=E, num_n_tiles=num_n_tiles_val,
max_m_tiles=max_m_tiles_val, BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles_val)
torch.cuda.synchronize()
# Reference
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
x_e = hidden[s:e_end].float()
g = x_e @ W_gate[e].float()
u = x_e @ W_up[e].float()
ref[s:e_end] = (F.silu(g) * u).bfloat16()
... (747 more chars){
"command": "<4467 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug with small shapes in local dir"
}cat > debug_sol3.py << 'PYEOF'
import torch, triton, triton.language as tl
import torch.nn.functional as F
BM = 64
BK = 64
BN = 128
@triton.jit
def _test_kernel(out_ptr, offsets_ptr, x_ptr, wg_ptr, wu_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
if m_tile_pid >= m_tiles_needed:
return
m_start = m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=x_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_base = expert * H * I
wg_ptrs = wg_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = wu_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
combined = gate_silu * up_acc
result = combined.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
E, H, I, K_routing = 4, 256, 128, 2
T_total = 128
T_perm = T_total * K_routing
tokens_per_expert = T_perm // E
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
max_m_tiles_val = triton.cdiv(tokens_per_expert, BM)
num_n_tiles_val = triton.cdiv(I, BN)
num_k_tiles_val = triton.cdiv(H, BK)
total = E * num_n_tiles_val * max_m_tiles_val
print(f"E={E}, H={H}, I={I}, T_perm={T_perm}, tokens_per_expert={tokens_per_expert}")
print(f"num_n_tiles={num_n_tiles_val}, max_m_tiles={max_m_tiles_val}, num_k_tiles={num_k_tiles_val}")
print(f"total_programs={total}")
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
H=H, I=I, E=E, num_n_tiles=num_n_tiles_val,
max_m_tiles=max_m_tiles_val, BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles_val)
torch.cuda.synchronize()
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
x_e = hidden[s:e_end].float()
g = x_e @ W_gate[e].float()
u = x_e @ W_up[e].float()
ref[s:e_end] = (F.silu(g) * u).bfloat16()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()}")
for e in range(E):
... (467 more chars)E=4, H=256, I=128, T_perm=256, tokens_per_expert=64
num_n_tiles=1, max_m_tiles=1, num_k_tiles=4
total_programs=4
Non-zero: 8192/32768
Expert 0 [rows 0:64]: max_diff=0.000004, non_zero=True
Expert 1 [rows 64:128]: max_diff=0.005554, non_zero=False
Expert 2 [rows 128:192]: max_diff=0.004089, non_zero=False
Expert 3 [rows 192:256]: max_diff=0.004059, non_zero=False
Overall: max_diff=0.005554, allclose(0.02): True
{
"command": "<4226 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug without constexpr types"
}cat > debug_sol4.py << 'PYEOF'
import torch, triton, triton.language as tl
import torch.nn.functional as F
BM = 64; BK = 64; BN = 128
@triton.jit
def _test_kernel(out_ptr, offsets_ptr, x_ptr, wg_ptr, wu_ptr,
H, I, E, num_n_tiles, max_m_tiles, BM_c, BK_c, BN_c,
num_k_tiles):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert * H + 1) - tl.load(offsets_ptr + expert * H)
m_tiles_needed = tl.cdiv(num_rows, BM_c)
if m_tile_pid >= m_tiles_needed:
return
m_start = m_tile_pid * BM_c
n_start = n_tile * BN_c
row_offs = m_start + tl.arange(0, BM_c)
load_row_mask = row_offs < num_rows
n_offs = n_start + tl.arange(0, BN_c)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK_c)
gate_acc = tl.zeros((BM_c, BN_c), dtype=tl.float32)
up_acc = tl.zeros((BM_c, BN_c), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK_c
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=x_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_base = expert.to(tl.int64) * H.to(tl.int64) * I.to(tl.int64)
wg_ptrs = wg_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = wu_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
combined = gate_silu * up_acc
result = combined.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
E, H, I, K_routing = 4, 256, 128, 2
T_total = 128; T_perm = T_total * K_routing
tokens_per_expert = T_perm // E
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
for i in range(E + 1):
offsets[i] = i * tokens_per_expert
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
max_m_tiles_val = (tokens_per_expert + BM - 1) // BM
num_n_tiles_val = (I + BN - 1) // BN
num_k_tiles_val = (H + BK - 1) // BK
total = E * num_n_tiles_val * max_m_tiles_val
print(f"Grid: {total}, E={E}, num_n={num_n_tiles_val}, max_m={max_m_tiles_val}, num_k={num_k_tiles_val}")
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
H, I, E, num_n_tiles_val, max_m_tiles_val,
BM, BK, BN, num_k_tiles_val)
torch.cuda.synchronize()
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
x_e = hidden[s:e_end].float()
g = x_e @ W_gate[e].float()
u = x_e @ W_up[e].float()
ref[s:e_end] = (F.silu(g) * u).bfloat16()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()} ({100*(out.numel()-zeros)/out.numel():.0f}%)")
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
o = out[s:e_end]
r = ref[s:e_end]
z = (o == 0).sum().item()
d = (o.float() - r.float()).abs()
print(f"
... (226 more chars)Grid: 4, E=4, num_n=1, max_m=1, num_k=4
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol4.py", line 111, in <module>
test()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol4.py", line 86, in test
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.1
... (1606 more chars){
"command": "<4484 chars \u2014 see below>",
"timeout": 120000,
"description": "Debug with 8 experts"
}cat > debug_sol5.py << 'PYEOF'
import torch, triton, triton.language as tl
import torch.nn.functional as F
BM = 64; BK = 64; BN = 128
@triton.jit
def _test_kernel(out_ptr, offsets_ptr, x_ptr, wg_ptr, wu_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
if m_tile_pid >= m_tiles_needed:
return
m_start = m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x_bf16 = tl.load(x_ptrs, mask=x_mask, other=0.0)
x = x_bf16.to(tl.float32)
wg_base = expert.to(tl.int64) * H.to(tl.int64) * I.to(tl.int64)
wg_ptrs = wg_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wg_mask = k_mask[:, None] & n_mask[None, :]
wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
wg = wg_bf16.to(tl.float32)
wu_ptrs = wu_ptr + wg_base + k_offs[:, None] * I + n_offs[None, :]
wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
wu = wu_bf16.to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
combined = gate_silu * up_acc
result = combined.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
E, H, I, K_routing = 8, 256, 128, 2
T_total = 256; T_perm = T_total * K_routing
tokens_per_expert = T_perm // E # 64
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
max_m_tiles_val = (tokens_per_expert + BM - 1) // BM
num_n_tiles_val = (I + BN - 1) // BN
num_k_tiles_val = (H + BK - 1) // BK
total = E * num_n_tiles_val * max_m_tiles_val
print(f"E={E}, H={H}, I={I}, T_perm={T_perm}, M_per_exp={tokens_per_expert}")
print(f"Grid: {total} = {E} * {num_n_tiles_val} * {max_m_tiles_val}")
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
H=H, I=I, E=E, num_n_tiles=num_n_tiles_val,
max_m_tiles=max_m_tiles_val, BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles_val)
torch.cuda.synchronize()
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
x_e = hidden[s:e_end].float()
g = x_e @ W_gate[e].float()
u = x_e @ W_up[e].float()
ref[s:e_end] = (F.silu(g) * u).bfloat16()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()} ({100*(out.numel()-zeros)/out.numel():.0f}%)")
for e
... (484 more chars)File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py", line 105
o = out[s:e_end]\n r = ref[s:e_end]
^
SyntaxError: unexpected character after line continuation character
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
@@ -1 +1 @@
- o = out[s:e_end]\n r = ref[s:e_end]+ o = out[s:e_end]{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py",
"newString": " o = out[s:e_end]",
"oldString": " o = out[s:e_end]\\n r = ref[s:e_end]"
}Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py",
"limit": 20,
"offset": 95
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py</path>
<type>file</type>
<content>
95: s, e_end = int(offsets[e]), int(offsets[e+1])
96: x_e = hidden[s:e_end].float()
97: g = x_e @ W_gate[e].float()
98: u = x_e @ W_up[e].float()
99: ref[s:e_end] = (F.silu(g) * u).bfloat16()
100:
101: zeros = (out == 0).sum().item()
102: print(f"Non-zero: {out.numel() - zeros}/{out.numel()} ({100*(out.numel()-zeros)/out.numel():.0f}%)")
103: for e in range(E):
104: s, e_end = int(offsets[e]), int(offsets[e+1])
105: o = out[s:e_end]
106: d = (o.float() - r.float()).abs()
107: print(f" Expert {e}: rows[{s}:{e_end}], zero={(o==0).sum().item()}/{o.numel()}, diff_max={d.max():.6f}")
108:
109: diff = (out.float() - ref.float()).abs()
110: print(f"\nmax_diff={diff.max():.6f}, allclose(0.02)={torch.allclose(out.float(), ref.float(), 0.02, 0.02)}")
111:
112: test()
(End of file - total 112 lines)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
@@ -1,2 +1,3 @@
o = out[s:e_end]
+ r = ref[s:e_end]
d = (o.float() - r.float()).abs(){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py",
"newString": "<91 chars \u2014 see below>",
"oldString": "<66 chars \u2014 see below>"
} o = out[s:e_end]
r = ref[s:e_end]
d = (o.float() - r.float()).abs() o = out[s:e_end]
d = (o.float() - r.float()).abs()Edit applied successfully.{
"command": "uv run python debug_sol5.py 2>&1",
"timeout": 120000,
"description": "Debug with 8 experts"
}E=8, H=256, I=128, T_perm=512, M_per_exp=64
Grid: 8 = 8 * 1 * 1
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py", line 113, in <module>
test()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py", line 87, in test
_test_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^
... (948 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py
@@ -1 +1 @@
- wg_base = expert.to(tl.int64) * H.to(tl.int64) * I.to(tl.int64)+ wg_base = expert.to(tl.int64) * H * I{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol5.py",
"newString": " wg_base = expert.to(tl.int64) * H * I",
"oldString": " wg_base = expert.to(tl.int64) * H.to(tl.int64) * I.to(tl.int64)"
}Edit applied successfully.{
"command": "uv run python debug_sol5.py 2>&1",
"timeout": 120000,
"description": "Debug with 8 experts v2"
}E=8, H=256, I=128, T_perm=512, M_per_exp=64
Grid: 8 = 8 * 1 * 1
Non-zero: 8192/65536 (12%)
Expert 0: rows[0:64], zero=0/8192, diff_max=0.004250
Expert 1: rows[64:128], zero=8192/8192, diff_max=0.004547
Expert 2: rows[128:192], zero=8192/8192, diff_max=0.004211
Expert 3: rows[192:256], zero=8192/8192, diff_max=0.003906
Expert 4: rows[256:320], zero=8192/8192, diff_max=0.004608
Expert 5: rows[320:384], zero=8192/8192, diff_max=0.004303
Expert 6: rows[384:448], zero=8192/8192, diff_max=0.003937
Expert 7: rows[448:512], zero=8192/8192, diff_max=0.004517
max_diff=0.004608, allclose(0.02)=True
{
"command": "<4468 chars \u2014 see below>",
"timeout": 180000,
"description": "Test corrected kernel with expert offset"
}cat > debug_sol6.py << 'PYEOF'
import torch, triton, triton.language as tl
import torch.nn.functional as F
BM = 64; BK = 64; BN = 128
@triton.jit
def _correct_kernel(out_ptr, offsets_ptr, x_ptr, wg_ptr, wu_ptr,
H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
if m_tile_pid >= m_tiles_needed:
return
m_start = m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
load_row_mask = row_offs < num_rows
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
expert_weight_off = expert.to(tl.int64) * H * I
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
x_mask = load_row_mask[:, None] & k_mask[None, :]
x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
wg_ptrs = wg_ptr + expert_weight_off + k_offs[:, None] * I + n_offs[None, :]
wg_mask = k_mask[:, None] & n_mask[None, :]
wg = tl.load(wg_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
wu_ptrs = wu_ptr + expert_weight_off + k_offs[:, None] * I + n_offs[None, :]
wu = tl.load(wu_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
result = (gate_silu * up_acc).to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
def test():
for E, H, I, T_total, K_r in [(64, 2048, 1024, 4096, 4),
(128, 4096, 1536, 32768, 8)]:
T_perm = T_total * K_r
base = T_perm // E
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * base
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
max_m = (base + BM - 1) // BM
num_n = (I + BN - 1) // BN
num_k = (H + BK - 1) // BK
total = E * num_n * max_m
print(f"\n=== E={E}, H={H}, I={I}, T_perm={T_perm} ===")
print(f"Grid: {total}, num_n={num_n}, max_m={max_m}, num_k={num_k}")
_correct_kernel[(total,)](out, offsets, hidden, W_gate, W_up,
H=H, I=I, E=E, num_n_tiles=num_n,
max_m_tiles=max_m, BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k)
torch.cuda.synchronize()
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
x_e = hidden[s:e_end].float()
g = x_e @ W_gate[e].float()
u = x_e @ W_up[e].float()
ref[s:e_end] = (F.silu(g) * u).bfloat16()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()} ({100*(out.numel()-zeros)/out.numel():.1f}%)")
# Check first 3 experts
for e in ra
... (468 more chars)=== E=64, H=2048, I=1024, T_perm=16384 ===
Grid: 2048, num_n=8, max_m=4, num_k=32
Non-zero: 262144/16777216 (1.6%)
Expert 0: diff_max=0.059387, zeros=0/262144
Expert 1: diff_max=0.052246, zeros=262144/262144
Expert 2: diff_max=0.060791, zeros=262144/262144
max_diff=0.067871 allclose(0.02)=False
=== E=128, H=4096, I=1536, T_perm=262144 ===
Grid: 49152, num_n=12, max_m=32, num_k=64
Non-zero: 3145728/402653184 (0.8%)
Expert 0: diff_max=0.144714, zeros=0/3145728
Expert 1: diff_max=0.127930, zeros=3145728/3145728
Expert 2: diff_max=0.127930, zeros=3145728/3145728
max_diff=0.188477 allclose(0.02)=False
{
"command": "<1565 chars \u2014 see below>",
"timeout": 60000,
"description": "Test Triton grid size with markers"
}cat > debug_sol7.py << 'PYEOF'
"""Test if Triton skips programs with pid > some limit"""
import torch, triton, triton.language as tl
@triton.jit
def _marker_kernel(markers, num_n: tl.constexpr, max_m: tl.constexpr):
pid = tl.program_id(0)
expert = pid // (num_n * max_m)
rem = pid % (num_n * max_m)
n_tile = rem // max_m
m_tile = rem % max_m
# Write pid info
idx = pid * 4
markers[idx + 0] = pid
markers[idx + 1] = expert
markers[idx + 2] = n_tile
markers[idx + 3] = m_tile
# Test with same constexprs as E=64 case
num_n = 8; max_m = 4; E = 64
total = E * num_n * max_m
markers = torch.zeros(total * 4, dtype=torch.int32, device='cuda')
print(f"Testing pid decomposition: E={E}, num_n={num_n}, max_m={max_m}, total={total}")
_marker_kernel[(total,)](markers, num_n, max_m)
torch.cuda.synchronize()
m = markers.cpu()
# Check all pids were processed
pid_set = set()
for i in range(0, total * 4, 4):
p, e, n, mt = int(m[i]), int(m[i+1]), int(m[i+2]), int(m[i+3])
pid_set.add(p)
if p < 5 or p > total - 5 or e == 1:
pass # print(f"pid {p}: expert={e}, n_tile={n}, m_tile={mt}")
# Count per-expert
exp_count = {}
for i in range(0, total * 4, 4):
e = int(m[i+1])
exp_count[e] = exp_count.get(e, 0) + 1
print(f"Unique pids: {len(pid_set)}/{total}")
print(f"Experts covered: {len(exp_count)}/64")
print(f"Expert 0 programs: {exp_count.get(0, 0)}")
print(f"Expert 1 programs: {exp_count.get(1, 0)}")
print(f"Expert 63 programs: {exp_count.get(63, 0)}")
PYEOF
uv run python debug_sol7.py 2>&1Testing pid decomposition: E=64, num_n=8, max_m=4, total=2048
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol7.py", line 24, in <module>
_marker_kernel[(total,)](markers, num_n, max_m)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages
... (645 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol7.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol7.py
@@ -1,6 +1,6 @@
# Write pid info
idx = pid * 4
- markers[idx + 0] = pid
- markers[idx + 1] = expert
- markers[idx + 2] = n_tile
- markers[idx + 3] = m_tile+ tl.store(markers + idx, pid)
+ tl.store(markers + idx + 1, expert)
+ tl.store(markers + idx + 2, n_tile)
+ tl.store(markers + idx + 3, m_tile){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol7.py",
"newString": "<191 chars \u2014 see below>",
"oldString": "<155 chars \u2014 see below>"
} # Write pid info
idx = pid * 4
tl.store(markers + idx, pid)
tl.store(markers + idx + 1, expert)
tl.store(markers + idx + 2, n_tile)
tl.store(markers + idx + 3, m_tile) # Write pid info
idx = pid * 4
markers[idx + 0] = pid
markers[idx + 1] = expert
markers[idx + 2] = n_tile
markers[idx + 3] = m_tileEdit applied successfully.{
"command": "uv run python debug_sol7.py 2>&1",
"timeout": 60000,
"description": "Test Triton grid size"
}Testing pid decomposition: E=64, num_n=8, max_m=4, total=2048
Unique pids: 2048/2048
Experts covered: 64/64
Expert 0 programs: 32
Expert 1 programs: 32
Expert 63 programs: 32
{
"command": "<2684 chars \u2014 see below>",
"timeout": 60000,
"description": "Test early return behavior"
}cat > debug_sol8.py << 'PYEOF'
"""Test if Triton correctly evaluates early return based on program pid"""
import torch, triton, triton.language as tl
@triton.jit
def _early_return_kernel(out_ptr, offsets_ptr, marker_ptr,
BM: tl.constexpr, max_m_tiles: tl.constexpr):
pid = tl.program_id(0)
num_n_tiles: tl.constexpr = 8
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
m_tile_pid = rem % max_m_tiles
num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
m_tiles_needed = tl.cdiv(num_rows, BM)
# Log what happens for each pid
marker_ptr[pid * 3 + 0] = tl.cast(num_rows, tl.int32)
marker_ptr[pid * 3 + 1] = tl.cast(m_tiles_needed, tl.int32)
marker_ptr[pid * 3 + 2] = tl.cast(m_tile_pid, tl.int32)
if m_tile_pid >= m_tiles_needed:
marker_ptr[pid * 3 + 2] = 1000 + m_tile_pid # mark as "returned early"
out_ptr[pid] = -99 # sentinel for early return
return
out_ptr[pid] = pid
E = 64; num_n = 8; max_m = 4; BM_val = 64
total = E * num_n * max_m # 2048
base = 256
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * base
out = torch.zeros(total, dtype=torch.int32, device='cuda')
markers = torch.zeros(total * 3, dtype=torch.int32, device='cuda')
print(f"Grid: {total}, BM={BM_val}")
_early_return_kernel[(total,)](out, offsets, markers, BM=BM_val, max_m_tiles=max_m)
torch.cuda.synchronize()
o = out.cpu()
m = markers.cpu()
# Check how many returned early
early_ret = (o == -99).sum().item()
no_ret = ((o != -99) & (o != 0)).sum().item()
zero_vals = (o == 0).sum().item()
print(f"Results: early_return={early_ret}, no_ret={no_ret}, zero_vals={zero_vals}")
# Check first few and a few expert boundaries
for p in [0, 1, 2, 3, 31, 32, 33, 63, 64, 127, 128]:
mrows = m[p*3].item()
mneeded = m[p*3+1].item()
mtile = m[p*3+2].item()
oval = o[p].item()
print(f" pid={p}: num_rows={mrows}, m_needed={mneeded}, m_tile={mtile}, out={oval}")
# Count per expert
exp_data = {}
for p in range(total):
e = p // (num_n * max_m)
if e not in exp_data:
exp_data[e] = {'early': 0, 'proc': 0}
if o[p].item() == -99:
exp_data[e]['early'] += 1
else:
exp_data[e]['proc'] += 1
print(f"\nPer-expert processing:")
for e in sorted(exp_data.keys())[:3]:
print(f" Expert {e}: early={exp_data[e]['early']}, proc={exp_data[e]['proc']}")
print(f" ...")
e_last = sorted(exp_data.keys())[-1]
print(f" Expert {e_last}: early={exp_data[e_last]['early']}, proc={exp_data[e_last]['proc']}")
PYEOF
uv run python debug_sol8.py 2>&1Grid: 2048, BM=64
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol8.py", line 39, in <module>
_early_return_kernel[(total,)](out, offsets, markers, BM=BM_val, max_m_tiles=max_m)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/
... (769 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol8.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol8.py
@@ -1,11 +1,11 @@
# Log what happens for each pid
- marker_ptr[pid * 3 + 0] = tl.cast(num_rows, tl.int32)
- marker_ptr[pid * 3 + 1] = tl.cast(m_tiles_needed, tl.int32)
- marker_ptr[pid * 3 + 2] = tl.cast(m_tile_pid, tl.int32)
+ tl.store(marker_ptr + pid * 3, tl.cast(num_rows, tl.int32))
+ tl.store(marker_ptr + pid * 3 + 1, tl.cast(m_tiles_needed, tl.int32))
+ tl.store(marker_ptr + pid * 3 + 2, tl.cast(m_tile_pid, tl.int32))
if m_tile_pid >= m_tiles_needed:
- marker_ptr[pid * 3 + 2] = 1000 + m_tile_pid # mark as "returned early"
- out_ptr[pid] = -99 # sentinel for early return
+ tl.store(marker_ptr + pid * 3 + 2, 1000 + m_tile_pid)
+ tl.store(out_ptr + pid, -99)
return
- out_ptr[pid] = pid+ tl.store(out_ptr + pid, pid){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_sol8.py",
"newString": "<429 chars \u2014 see below>",
"oldString": "<430 chars \u2014 see below>"
} # Log what happens for each pid
tl.store(marker_ptr + pid * 3, tl.cast(num_rows, tl.int32))
tl.store(marker_ptr + pid * 3 + 1, tl.cast(m_tiles_needed, tl.int32))
tl.store(marker_ptr + pid * 3 + 2, tl.cast(m_tile_pid, tl.int32))
if m_tile_pid >= m_tiles_needed:
tl.store(marker_ptr + pid * 3 + 2, 1000 + m_tile_pid)
tl.store(out_ptr + pid, -99)
return
tl.store(out_ptr + pid, pid) # Log what happens for each pid
marker_ptr[pid * 3 + 0] = tl.cast(num_rows, tl.int32)
marker_ptr[pid * 3 + 1] = tl.cast(m_tiles_needed, tl.int32)
marker_ptr[pid * 3 + 2] = tl.cast(m_tile_pid, tl.int32)
if m_tile_pid >= m_tiles_needed:
marker_ptr[pid * 3 + 2] = 1000 + m_tile_pid # mark as "returned early"
out_ptr[pid] = -99 # sentinel for early return
return
out_ptr[pid] = pidEdit applied successfully.{
"command": "uv run python debug_sol8.py 2>&1",
"timeout": 60000,
"description": "Test early return behavior"
}Grid: 2048, BM=64
Results: early_return=0, no_ret=2047, zero_vals=1
pid=0: num_rows=256, m_needed=4, m_tile=0, out=0
pid=1: num_rows=256, m_needed=4, m_tile=1, out=1
pid=2: num_rows=256, m_needed=4, m_tile=2, out=2
pid=3: num_rows=256, m_needed=4, m_tile=3, out=3
pid=31: num_rows=256, m_needed=4, m_tile=3, out=31
pid=32: num_rows=256, m_needed=4, m_tile=0, out=32
pid=33: num_rows=256, m_needed=4, m_tile=1, out=33
pid=63: num_rows=256, m_needed=4, m_tile=3, out=63
pid=64: num_rows=256, m_needed=4, m_tile=0, out=64
pid=127: num_rows=256, m_needed=4, m_tile=3, out=127
pid=128: num_rows=256, m_needed=4, m_tile=0, out=128
Per-expert processing:
Expert 0: early=0, proc=32
Expert 1: early=0, proc=32
Expert 2: early=0, proc=32
...
Expert 63: early=0, proc=32
{
"command": "<4104 chars \u2014 see below>",
"timeout": 120000,
"description": "Test single and expert-indexed GEMM"
}cat > debug_sol9.py << 'PYEOF'
"""Minimal test: single program GEMM with known data"""
import torch, triton, triton.language as tl
import torch.nn.functional as F
@triton.jit
def _single_gemm(out_ptr, x_ptr, w_ptr,
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
BM: tl.constexpr = M
BK: tl.constexpr = 64
BN: tl.constexpr = N
row_offs = tl.arange(0, BM)
n_offs = tl.arange(0, BN)
k_offs_base = tl.arange(0, BK)
num_k_tiles = K // BK
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
x_ptrs = x_ptr + row_offs[:, None] * K + k_offs[None, :]
w_ptrs = w_ptr + k_offs[:, None] * N + n_offs[None, :]
x = tl.load(x_ptrs).to(tl.float32)
w = tl.load(w_ptrs).to(tl.float32)
acc += tl.dot(x, w)
result = acc.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * N + n_offs[None, :]
tl.store(out_ptrs, result)
def test_single_gemm():
M, K, N = 64, 256, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') * 0.1
w = torch.randn(K, N, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
_single_gemm[(1,)](out, x, w, M=M, K=K, N=N)
torch.cuda.synchronize()
ref = (x.float() @ w.float()).bfloat16()
diff = (out.float() - ref.float()).abs()
print(f"M={M}, K={K}, N={N}")
print(f"max_diff={diff.max():.6f}, mean_diff={diff.mean():.6f}")
print(f"allclose(0.02)={torch.allclose(out.float(), ref.float(), 0.02, 0.02)}")
test_single_gemm()
# Now test with expert indexing
@triton.jit
def _expert_gemm(out_ptr, x_ptr, w_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr):
pid = tl.program_id(0)
BM: tl.constexpr = 64
BK: tl.constexpr = 64
BN: tl.constexpr = 128
num_k_tiles = H // BK
num_rows = tl.load(offsets_ptr + pid + 1) - tl.load(offsets_ptr + pid)
m_start = 0
row_offs = m_start + tl.arange(0, BM)
n_offs = tl.arange(0, BN)
k_offs_base = tl.arange(0, BK)
acc = tl.zeros((BM, BN), dtype=tl.float32)
expert_off = pid.to(tl.int64) * H * I
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
w_ptrs = w_ptr + expert_off + k_offs[:, None] * I + n_offs[None, :]
x = tl.load(x_ptrs).to(tl.float32)
w = tl.load(w_ptrs).to(tl.float32)
acc += tl.dot(x, w)
result = acc.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result)
def test_expert_gemm():
E, H, I = 4, 256, 128
tokens_per_expert = 64
total = E * tokens_per_expert
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
x = torch.randn(total, H, dtype=torch.bfloat16, device='cuda') * 0.1
w = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros(total, I, dtype=torch.bfloat16, device='cuda')
_expert_gemm[(E,)](out, x, w, offsets, H=H, I=I)
torch.cuda.synchronize()
ref = torch.empty(total, I, dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
ref[s:e_end] = (x[s:e_end].float() @ w[e].float()).bfloat16()
zeros = (out == 0).sum().item()
print(f"\nExpert GEMM: E={E}, H={H}, I={I}, M_per_exp={tokens_per_expert}")
print(f"Non-zero: {out.numel() - zeros}/{out.numel()}")
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
d = (out[s:e_end].float() - ref[s:e_end].float()).abs()
z = (out[s:e_end] == 0).sum().item()
print(f" Expert {e}: diff_max={d.max():.6f} zeros={z}/{(e_end-s)*I}")
diff = (out.float() - ref.float()).abs()
print(f"Overall: max_diff={diff.max():.6f} allclose(0.02)={torch.allc
... (104 more chars)M=64, K=256, N=128
max_diff=0.000122, mean_diff=0.000000
allclose(0.02)=True
Expert GEMM: E=4, H=256, I=128, M_per_exp=64
Non-zero: 8192/32768
Expert 0: diff_max=0.201172 zeros=0/8192
Expert 1: diff_max=0.120605 zeros=8192/8192
Expert 2: diff_max=0.120117 zeros=8192/8192
Expert 3: diff_max=0.125000 zeros=8192/8192
Overall: max_diff=0.201172 allclose(0.02)=False
{
"command": "<3778 chars \u2014 see below>",
"timeout": 120000,
"description": "Test expert indexing with pid"
}cat > debug_sol10.py << 'PYEOF'
import torch, triton, triton.language as tl
@triton.jit
def _expert_gemm(out_ptr, x_ptr, w_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr):
pid = tl.program_id(0)
BM: tl.constexpr = 64
BK: tl.constexpr = 64
BN: tl.constexpr = 128
num_k_tiles = H // BK
num_rows = tl.load(offsets_ptr + pid + 1) - tl.load(offsets_ptr + pid)
row_offs = tl.arange(0, BM)
n_offs = tl.arange(0, BN)
k_offs_base = tl.arange(0, BK)
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
# Expert weight offset using pid
expert_weight_offset = pid * H * I
w_ptrs = w_ptr + expert_weight_offset + k_offs[:, None] * I + n_offs[None, :]
x = tl.load(x_ptrs).to(tl.float32)
w = tl.load(w_ptrs).to(tl.float32)
acc += tl.dot(x, w)
result = acc.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result)
E, H, I = 4, 256, 128
tokens_per_expert = 64
total_rows = E * tokens_per_expert
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
x = torch.randn(total_rows, H, dtype=torch.bfloat16, device='cuda') * 0.1
w = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros(total_rows, I, dtype=torch.bfloat16, device='cuda')
import torch.nn.functional as F
ref = torch.empty(total_rows, I, dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
ref[s:e_end] = (x[s:e_end].float() @ w[e].float()).bfloat16()
_expert_gemm[(E,)](out, x, w, offsets, H=H, I=I)
torch.cuda.synchronize()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()}")
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
d = (out[s:e_end].float() - ref[s:e_end].float()).abs()
z = (out[s:e_end] == 0).sum().item()
print(f" Expert {e}: diff_max={d.max():.6f} zeros={z}")
# Also test with pid as constexpr parameter
@triton.jit
def _expert_gemm_param(out_ptr, x_ptr, w_ptr, expert_id,
H: tl.constexpr, I: tl.constexpr):
BM: tl.constexpr = 64; BK: tl.constexpr = 64; BN: tl.constexpr = 128
row_offs = tl.arange(0, BM); n_offs = tl.arange(0, BN); k_offs_base = tl.arange(0, BK)
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(H // BK):
k_offs = k_tile * BK + k_offs_base
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
expert_w_off = expert_id * H * I
w_ptrs = w_ptr + expert_w_off + k_offs[:, None] * I + n_offs[None, :]
acc += tl.dot(tl.load(x_ptrs).to(tl.float32), tl.load(w_ptrs).to(tl.float32))
result = acc.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result)
out2 = torch.zeros(total_rows, I, dtype=torch.bfloat16, device='cuda')
# Launch each expert separately with expert_id as constexpr
for e in range(E):
_expert_gemm_param[(1,)](out2, x, w, e, H=H, I=I)
torch.cuda.synchronize()
zeros2 = (out2 == 0).sum().item()
print(f"\nWith param: Non-zero: {out2.numel() - zeros2}/{out2.numel()}")
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
d = (out2[s:e_end].float() - ref[s:e_end].float()).abs()
print(f" Expert {e}: diff_max={d.max():.6f} zeros={(out2[s:e_end]==0).sum().item()}")
diff = (out2.float() - ref.float()).abs()
print(f"Overall: max_diff={diff.max():.6f} allclose(0.02)={torch.allclose(out2.float(), ref.float(), 0.02, 0.02)}")
PYEOF
uv run python debug_sol10.py 2>&1Non-zero: 8192/32768
Expert 0: diff_max=0.182129 zeros=0
Expert 1: diff_max=0.136719 zeros=8192
Expert 2: diff_max=0.123047 zeros=8192
Expert 3: diff_max=0.136719 zeros=8192
With param: Non-zero: 8192/32768
Expert 0: diff_max=0.199707 zeros=0
Expert 1: diff_max=0.136719 zeros=8192
Expert 2: diff_max=0.123047 zeros=8192
Expert 3: diff_max=0.136719 zeros=8192
Overall: max_diff=0.199707 allclose(0.02)=False
{
"command": "<2264 chars \u2014 see below>",
"timeout": 120000,
"description": "Test double-bug fix"
}cat > debug_sol11.py << 'PYEOF'
"""Test GEMM with CORRECT x pointer offset using offsets array"""
import torch, triton, triton.language as tl
import torch.nn.functional as F
@triton.jit
def _fixed_kernel(out_ptr, x_ptr, w_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr):
pid = tl.program_id(0)
BM: tl.constexpr = 64; BK: tl.constexpr = 64; BN: tl.constexpr = 128
start = tl.load(offsets_ptr + pid)
end = tl.load(offsets_ptr + pid + 1)
row_offs = start + tl.arange(0, BM)
n_offs = tl.arange(0, BN)
k_offs_base = tl.arange(0, BK)
acc = tl.zeros((BM, BN), dtype=tl.float32)
expert_w_off = pid * H * I
for k_tile in range(H // BK):
k_offs = k_tile * BK + k_offs_base
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
w_ptrs = w_ptr + expert_w_off + k_offs[:, None] * I + n_offs[None, :]
acc += tl.dot(tl.load(x_ptrs).to(tl.float32), tl.load(w_ptrs).to(tl.float32))
result = acc.to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result)
E, H, I = 4, 256, 128
tokens_per_expert = 64
total_rows = E * tokens_per_expert
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * tokens_per_expert
x = torch.randn(total_rows, H, dtype=torch.bfloat16, device='cuda') * 0.1
w = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros(total_rows, I, dtype=torch.bfloat16, device='cuda')
ref = torch.empty(total_rows, I, dtype=torch.bfloat16, device='cuda')
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
ref[s:e_end] = (x[s:e_end].float() @ w[e].float()).bfloat16()
_fixed_kernel[(E,)](out, x, w, offsets, H=H, I=I)
torch.cuda.synchronize()
zeros = (out == 0).sum().item()
print(f"Non-zero: {out.numel() - zeros}/{out.numel()}")
for e in range(E):
s, e_end = int(offsets[e]), int(offsets[e+1])
d = (out[s:e_end].float() - ref[s:e_end].float()).abs()
print(f" Expert {e}: diff_max={d.max():.6f} zeros={(out[s:e_end]==0).sum().item()}")
diff = (out.float() - ref.float()).abs()
print(f"Overall: max_diff={diff.max():.6f} allclose(0.02)={torch.allclose(out.float(), ref.float(), 0.02, 0.02)}")
PYEOF
uv run python debug_sol11.py 2>&1Non-zero: 32768/32768
Expert 0: diff_max=0.000000 zeros=0
Expert 1: diff_max=0.000015 zeros=0
Expert 2: diff_max=0.000000 zeros=0
Expert 3: diff_max=0.000008 zeros=0
Overall: max_diff=0.000015 allclose(0.02)=True
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,11 +1,8 @@
"""Triton grouped GEMM + fused SwiGLU for MoE up-projection (SM120 Blackwell).
Each program instance handles ONE (expert, n_tile, m_tile) triple.
-Within each program, we loop over K tiles to accumulate the full GEMM.
+Loops over K tiles to accumulate the full GEMM.
SwiGLU is fused in the epilogue: silu(gate) * up.
-
-Grid: E * num_n_tiles * max_m_tiles programs total.
-Program pid encodes (expert, n_tile, m_tile) via stride decomposition.
"""
from __future__ import annotations
@@ -22,45 +19,30 @@
@triton.jit
def _grouped_gemm_swiglu_kernel(
- x_ptr,
- w_gate_ptr,
- w_up_ptr,
- out_ptr,
- offsets_ptr,
- H: tl.constexpr,
- I: tl.constexpr,
- E: tl.constexpr,
- num_n_tiles: tl.constexpr,
- max_m_tiles: tl.constexpr,
- BM: tl.constexpr,
- BK: tl.constexpr,
- BN: tl.constexpr,
+ x_ptr, w_gate_ptr, w_up_ptr, out_ptr, offsets_ptr,
+ H: tl.constexpr, I: tl.constexpr, E: tl.constexpr,
+ num_n_tiles: tl.constexpr, max_m_tiles: tl.constexpr,
+ BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr,
num_k_tiles: tl.constexpr,
- stride_xm,
- stride_xk_unused,
- stride_wg_h,
- stride_wg_i,
- stride_om,
- stride_oi,
):
pid = tl.program_id(0)
-
expert = pid // (num_n_tiles * max_m_tiles)
rem = pid % (num_n_tiles * max_m_tiles)
n_tile = rem // max_m_tiles
- m_tile = rem % max_m_tiles
+ m_tile_pid = rem % max_m_tiles
- num_rows = tl.load(offsets_ptr + expert + 1) - tl.load(offsets_ptr + expert)
+ start = tl.load(offsets_ptr + expert)
+ end = tl.load(offsets_ptr + expert + 1)
+ num_rows = end - start
m_tiles_needed = tl.cdiv(num_rows, BM)
- if m_tile >= m_tiles_needed:
+ if m_tile_pid >= m_tiles_needed:
return
- m_start = m_tile * BM
+ m_start = start + m_tile_pid * BM
n_start = n_tile * BN
row_offs = m_start + tl.arange(0, BM)
- load_row_mask = row_offs < num_rows
- row_ptrs = row_offs * stride_xm
+ load_row_mask = row_offs < end
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
@@ -70,33 +52,31 @@
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
+ expert_w_off = expert * H * I
+
for k_tile in range(num_k_tiles):
k_start = k_tile * BK
k_offs = k_start + k_offs_base
k_mask = k_offs < H
- x_ptrs = x_ptr + row_ptrs[:, None] + k_offs[None, :] * stride_xk_unused
- load_mask = load_row_mask[:, None] & k_mask[None, :]
- x_bf16 = tl.load(x_ptrs, mask=load_mask, other=0.0)
- x = x_bf16.to(tl.float32)
+ x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
+ x_mask = load_row_mask[:, None] & k_mask[None, :]
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
- wg_ptrs = w_gate_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
+ w_ptrs = k_offs[:, None] * I + n_offs[None, :]
+ wg_ptrs = w_gate_ptr + expert_w_off + w_ptrs
+ wu_ptrs = w_up_ptr + expert_w_off + w_ptrs
wg_mask = k_mask[:, None] & n_mask[None, :]
- wg_bf16 = tl.load(wg_ptrs, mask=wg_mask, other=0.0)
- wg = wg_bf16.to(tl.float32)
-
- wu_ptrs = w_up_ptr + k_offs[:, None] * stride_wg_h + n_offs[None, :] * stride_wg_i
- wu_bf16 = tl.load(wu_ptrs, mask=wg_mask, other=0.0)
- wu = wu_bf16.to(tl.float32)
+ wg = tl.load(wg_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
+ wu = tl.load(wu_ptrs, mask=wg_mask, other=0.0).to(tl.float32)
gate_acc += tl.dot(x, wg)
up_acc += tl.dot(x, wu)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
- out_fp32 = gate_silu * up_acc
- result = out_fp32.to(tl.bfloat16)
+ result = (gate_silu * up_acc).to(tl.bfloat16)
- out_ptrs = out_ptr + row_ptrs[:, None] + n_offs[None, :]
+ out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
store_mask = load_row_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, result, mask=store_mask)
@@ -115,29 +95,12 @@
total_programs = E * num_n_tiles * max_m_tiles
if total_programs > 0:
- grid = (total_programs,)
-
- _grouped_gemm_swiglu_kernel[grid](
- hidden_states,
- W_gate,
- W_up,
- out,
- expert_offsets,
- H=H,
- I=I,
- E=E,
- num_n_tiles=num_n_tiles,
- max_m_tiles=max_m_tiles,
- BM=BM,
- BK=BK,
- BN=BN,
+ _grouped_gemm_swiglu_kernel[(total_programs,)](
+ hidden_states, W_gate, W_up, out, expert_offsets,
+ H=H, I=I, E=E,
+ num_n_tiles=num_n_tiles, max_m_tiles=max_m_tiles,
+ BM=BM, BK=BK, BN=BN,
num_k_tiles=num_k_tiles,
- stride_xm=H,
- stride_xk_unused=1,
- stride_wg_h=I,
- stride_wg_i=1,
- stride_om=I,
- stride_oi=1,
)
return out
@@ -156,11 +119,7 @@
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
- def forward(
- self,
- hidden_states: torch.Tensor,
- expert_offsets: torch.Tensor,
- ) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
return FusedSwiGLUGroupedGEMM.apply(
hidden_states, self.W_gate, self.W_up, expert_offsets, self.E, self.I
)
{
"content": "<4855 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "timeout 300 uv run python check.py 2>&1",
"timeout": 360000,
"description": "Run correctness check with fixed kernel"
}PASS
{
"todos": [
{
"content": "Check GPU availability and environment (nvidia-smi, CUDA, nvcc)",
"status": "completed",
"priority": "high"
},
{
"content": "Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA",
"status": "completed",
"priority": "high"
},
{
"content": "Implement grouped GEMM + fused SwiGLU kernel (v1)",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchm
... (218 more chars)content: 'Check GPU availability and environment (nvidia-smi, CUDA, nvcc)'
status: 'completed'
priority: 'high'
content: 'Decide approach: CUTLASS grouped GEMM vs Triton vs raw CUDA'
status: 'completed'
priority: 'high'
content: 'Implement grouped GEMM + fused SwiGLU kernel (v1)'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'in_progress'
priority: 'high'
content: 'Profile with ncu and optimize hotspots'
status: 'pending'
priority: 'medium'{
"command": "timeout 300 uv run python benchmark.py 2>&1",
"timeout": 360000,
"description": "Run benchmark"
}W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0429 05:18:29.324000 930557 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0429 0
... (794 more chars){
"command": "<6103 chars \u2014 see below>",
"timeout": 120000,
"description": "Compare bf16 tensor cores vs fp32"
}cat > debug_perf.py << 'PYEOF'
"""Compare kernel performance with different precision settings"""
import torch, triton, triton.language as tl
import torch.nn.functional as F
# Test bf16 tensor cores vs fp32
@triton.jit
def _kernel_bf16(out_ptr, x_ptr, wg_ptr, wu_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr, BM: tl.constexpr,
BK: tl.constexpr, BN: tl.constexpr, num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // tl.cdiv(I, BN)
n_tile = pid % tl.cdiv(I, BN)
start = tl.load(offsets_ptr + expert)
end = tl.load(offsets_ptr + expert + 1)
num_rows = end - start
m_tiles = tl.cdiv(num_rows, BM)
n_start = n_tile * BN
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
for m_tile_id in range(m_tiles):
m_start = start + m_tile_id * BM
row_offs = m_start + tl.arange(0, BM)
row_mask = row_offs < end
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_offs = k_tile * BK + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
wg_ptrs = wg_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
wu_ptrs = wu_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
mask2d = row_mask[:, None] & k_mask[None, :]
maskw = k_mask[:, None] & n_mask[None, :]
# Keep bf16, accumulate in fp32 via out_dtype
x = tl.load(x_ptrs, mask=mask2d, other=0.0)
wg = tl.load(wg_ptrs, mask=maskw, other=0.0)
wu = tl.load(wu_ptrs, mask=maskw, other=0.0)
gate_acc += tl.dot(x, wg, out_dtype=tl.float32)
up_acc += tl.dot(x, wu, out_dtype=tl.float32)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
result = (gate_silu * up_acc).to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result, mask=row_mask[:, None] & n_mask[None, :])
@triton.jit
def _kernel_mixed(out_ptr, x_ptr, wg_ptr, wu_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr, BM: tl.constexpr,
BK: tl.constexpr, BN: tl.constexpr, num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
expert = pid // tl.cdiv(I, BN)
n_tile = pid % tl.cdiv(I, BN)
start = tl.load(offsets_ptr + expert)
end = tl.load(offsets_ptr + expert + 1)
num_rows = end - start
m_tiles = tl.cdiv(num_rows, BM)
n_start = n_tile * BN
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
for m_tile_id in range(m_tiles):
m_start = start + m_tile_id * BM
row_offs = m_start + tl.arange(0, BM)
row_mask = row_offs < end
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_offs = k_tile * BK + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
wg_ptrs = wg_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
wu_ptrs = wu_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
mask2d = row_mask[:, None] & k_mask[None, :]
maskw = k_mask[:, None] & n_mask[None, :]
x = tl.load(x_ptrs, mask=mask2d, other=0.0)
wg = tl.load(wg_ptrs, mask=maskw, other=0.0)
wu = tl.load(wu_ptrs, mask=maskw, other=0.0)
# Use large k blocks (BK=128) with fp32, but try tf32 for speed
gate_acc += tl.dot(x.to(tl.float32), wg.to(tl.float32), input_precision="tf32")
up_acc += tl.dot(x.to(tl.float32), wu.to(tl.float32), input_precision="tf32")
gate_silu = gate_acc * tl.sigmoid(gate_acc)
... (2103 more chars)=== Shape 1 (small) ===
bf16_tc: BM=128 BK=64 BN=128 1.03ms max_diff=0.000244 ok=True
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_perf.py", line 146, in <module>
test_kernel(_kernel_bf16, f"bf16_tc", BM=128, BK=BK, BN=128)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_perf.py", line 118, in test_kernel
kernel_fn[(total,)](out, hidden, Wg, Wu, offsets, H=H, I=I, BM=BM, BK=BK, BN=BN, num_k_tiles=num_k)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/home/inf
... (287 more chars){
"command": "<4349 chars \u2014 see below>",
"timeout": 180000,
"description": "Test bf16 tensor cores on all shapes"
}cat > debug_perf2.py << 'PYEOF'
"""Test bf16 tensor cores on all 3 shapes"""
import torch, triton, triton.language as tl, torch.nn.functional as F
@triton.jit
def _bf16tc(out_ptr, x_ptr, wg_ptr, wu_ptr, offsets_ptr,
H: tl.constexpr, I: tl.constexpr, BM: tl.constexpr,
BK: tl.constexpr, BN: tl.constexpr, num_k_tiles: tl.constexpr):
pid = tl.program_id(0)
n_per_expert = I // BN if I % BN == 0 else (I + BN - 1) // BN
expert = pid // n_per_expert
n_tile = pid % n_per_expert
start = tl.load(offsets_ptr + expert)
end = tl.load(offsets_ptr + expert + 1)
num_rows = end - start
m_tiles = tl.cdiv(num_rows, BM)
n_start = n_tile * BN
n_offs = n_start + tl.arange(0, BN)
n_mask = n_offs < I
k_offs_base = tl.arange(0, BK)
for m_tile_id in range(m_tiles):
m_start = start + m_tile_id * BM
row_offs = m_start + tl.arange(0, BM)
row_mask = row_offs < end
gate_acc = tl.zeros((BM, BN), dtype=tl.float32)
up_acc = tl.zeros((BM, BN), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_offs = k_tile * BK + k_offs_base
k_mask = k_offs < H
x_ptrs = x_ptr + row_offs[:, None] * H + k_offs[None, :]
wg_ptrs = wg_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
wu_ptrs = wu_ptr + expert * H * I + k_offs[:, None] * I + n_offs[None, :]
maskx = row_mask[:, None] & k_mask[None, :]
maskw = k_mask[:, None] & n_mask[None, :]
x = tl.load(x_ptrs, mask=maskx, other=0.0)
wg = tl.load(wg_ptrs, mask=maskw, other=0.0)
wu = tl.load(wu_ptrs, mask=maskw, other=0.0)
gate_acc += tl.dot(x, wg, out_dtype=tl.float32)
up_acc += tl.dot(x, wu, out_dtype=tl.float32)
gate_silu = gate_acc * tl.sigmoid(gate_acc)
result = (gate_silu * up_acc).to(tl.bfloat16)
out_ptrs = out_ptr + row_offs[:, None] * I + n_offs[None, :]
tl.store(out_ptrs, result, mask=row_mask[:, None] & n_mask[None, :])
def test(shape_idx, BM, BK, BN):
shapes = [
{"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
{"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
{"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
]
s = shapes[shape_idx]
T_total, H, I, E, K = s["T_total"], s["H"], s["I"], s["E"], s["K"]
T_perm = T_total * K
base = T_perm // E
offsets = torch.arange(0, E + 1, dtype=torch.int32, device='cuda') * base
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
Wg = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
Wu = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
out = torch.zeros((T_perm, I), dtype=torch.bfloat16, device='cuda')
num_n = triton.cdiv(I, BN)
num_k = triton.cdiv(H, BK)
total = E * num_n
ref = torch.empty((T_perm, I), dtype=torch.bfloat16, device='cuda')
for e in range(E):
s_e, e_end = int(offsets[e]), int(offsets[e+1])
g = hidden[s_e:e_end].float() @ Wg[e].float()
u = hidden[s_e:e_end].float() @ Wu[e].float()
ref[s_e:e_end] = (F.silu(g) * u).bfloat16()
# Warmup
for _ in range(3):
_bf16tc[(total,)](out, hidden, Wg, Wu, offsets, H=H, I=I, BM=BM, BK=BK, BN=BN, num_k_tiles=num_k)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(5):
_bf16tc[(total,)](out, hidden, Wg, Wu, offsets, H=H, I=I, BM=BM, BK=BK, BN=BN, num_k_tiles=num_k)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / 5
diff = (out.float() - ref.float()).abs()
ok = torch.allclose(out.float(), ref.float(), 0.02, 0.02)
# FLOPs = 2 * T_total * H * (2 * I)
flops = 2 * T_total * H * (2 * I)
tflops = flops / (ms * 1e-3) / 1
... (349 more chars)BF16 tensor cores with fp32 out_dtype:
Shape 0: BM=128 BK=64 BN=128 ms=46.67 tflops=17.67 max_diff=0.000488 ok=True grid=1536 num_k=64 num_n=12
Shape 1: BM=128 BK=64 BN=128 ms=1.05 tflops=32.87 max_diff=0.000244 ok=True grid=512 num_k=32 num_n=8
Shape 2: BM=128 BK=64 BN=128 ms=31.14 tflops=17.65 max_diff=0.000244 ok=True grid=2048 num_k=32 num_n=32