from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "grouped_gemm_swiglu"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
@triton.jit
def expert_kernel(
    hidden_states_ptr,       # (T_perm, H) bf16, rows already permuted to expert order
    W_gate_tiled_ptr,        # gate weights pre-tiled to (E, num_nt, H, BLOCK_N) — see Model._make_tiled
    W_up_tiled_ptr,          # up weights, same tiled layout as gate
    expert_offsets_ptr,      # (E+1,) prefix sums: expert e owns rows [offsets[e], offsets[e+1])
    output_ptr,              # (T_perm, I) bf16 output
    H: tl.constexpr,         # hidden (reduction) dimension
    I: tl.constexpr,         # intermediate (output-column) dimension
    stride_hidden_h: tl.constexpr,    # row stride of hidden_states
    tiled_stride_e: tl.constexpr,     # gate tiled strides: expert, tile, H-row, column
    tiled_stride_t: tl.constexpr,
    tiled_stride_h: tl.constexpr,
    tiled_stride_n: tl.constexpr,
    tiled_up_stride_e: tl.constexpr,  # up tiled strides, same order
    tiled_up_stride_t: tl.constexpr,
    tiled_up_stride_h: tl.constexpr,
    tiled_up_stride_n: tl.constexpr,
    stride_out_i: tl.constexpr,       # row stride of output
    BLOCK_M: tl.constexpr,   # token-tile height
    BLOCK_N: tl.constexpr,   # output-column tile width (must match the pre-tiled width)
    BLOCK_K: tl.constexpr,   # reduction tile depth
):
    # One program per expert (grid = (E,)); each program serially walks all
    # N-tiles and M-tiles for its expert's row range.
    pid = tl.program_id(0)
    e = pid
    # Variable-length token slice for this expert.
    start = tl.load(expert_offsets_ptr + e)
    end = tl.load(expert_offsets_ptr + e + 1)
    n_tokens = end - start
    if n_tokens <= 0:
        return
    w_gate_base = W_gate_tiled_ptr + e * tiled_stride_e
    w_up_base = W_up_tiled_ptr + e * tiled_up_stride_e
    num_nt = tl.cdiv(I, BLOCK_N)
    for nt in range(num_nt):
        n_start = nt * BLOCK_N
        n_off = tl.arange(0, BLOCK_N)
        # Mask off padded columns in the last tile (I need not divide BLOCK_N).
        n_mask = n_off < I - n_start
        w_gate_ptr = w_gate_base + nt * tiled_stride_t
        w_up_ptr = w_up_base + nt * tiled_up_stride_t
        for m_start in range(0, n_tokens, BLOCK_M):
            m_off = tl.arange(0, BLOCK_M)
            m_mask = m_start + m_off < n_tokens
            x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
            # fp32 accumulators for both GEMMs of the SwiGLU pair.
            acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            for k_start in range(0, H, BLOCK_K):
                k_off = tl.arange(0, BLOCK_K)
                k_mask = k_start + k_off < H
                # NOTE: column offset is added without a stride factor, so this
                # assumes the hidden inner dim is contiguous (stride 1) — holds
                # for the contiguous inputs built by get_inputs().
                x = tl.load(
                    x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
                    mask=m_mask[:, None] & k_mask[None, :],
                )
                w_gate = tl.load(
                    w_gate_ptr
                    + (k_start + k_off)[:, None] * tiled_stride_h
                    + n_off[None, :],
                    mask=k_mask[:, None] & n_mask[None, :],
                )
                w_up = tl.load(
                    w_up_ptr
                    + (k_start + k_off)[:, None] * tiled_up_stride_h
                    + n_off[None, :],
                    mask=k_mask[:, None] & n_mask[None, :],
                )
                acc_gate = tl.dot(x, w_gate, acc_gate)
                acc_up = tl.dot(x, w_up, acc_up)
            # SwiGLU: silu(gate) * up, with silu(x) = x * sigmoid(x).
            gate_silu = tl.sigmoid(acc_gate) * acc_gate
            result = (gate_silu * acc_up).to(tl.bfloat16)
            out_ptr = output_ptr + (start + m_start) * stride_out_i
            tl.store(
                out_ptr + m_off[:, None] * stride_out_i + (n_start + n_off)[None, :],
                result,
                mask=m_mask[:, None] & n_mask[None, :],
            )
class Model(nn.Module):
def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
super().__init__()
self.T_total = T_total
self.H = H
self.I = I
self.E = E
self.K = K
self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
self._gate_tiled: torch.Tensor | None = None
self._up_tiled: torch.Tensor | None = None
def _make_tiled(self, bn: int) -> tuple[torch.Tensor, torch.Tensor]:
if self._gate_tiled is not None:
return self._gate_tiled, self._up_tiled
E, H, I = self.E, self.H, self.I
num_nt = (I + bn - 1) // bn
pad = num_nt * bn - I
if pad > 0:
wg = torch.nn.functional.pad(self.W_gate, (0, pad))
wu = torch.nn.functional.pad(self.W_up, (0, pad))
else:
wg = self.W_gate
wu = self.W_up
gt = wg.view(E, H, num_nt, bn).permute(0, 2, 1, 3).contiguous()
ut = wu.view(E, H, num_nt, bn).permute(0, 2, 1, 3).contiguous()
self._gate_tiled = gt
self._up_tiled = ut
return gt, ut
def forward(
self,
hidden_states: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor:
T_perm, _ = hidden_states.shape
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
bn = 128
bm = 64
bk = 32
nw = 4
ns = 3
gt, ut = self._make_tiled(bn)
expert_kernel[(self.E,)](
hidden_states,
gt,
ut,
expert_offsets,
out,
self.H,
self.I,
hidden_states.stride(0),
gt.stride(0),
gt.stride(1),
gt.stride(2),
gt.stride(3),
ut.stride(0),
ut.stride(1),
ut.stride(2),
ut.stride(3),
out.stride(0),
BLOCK_M=bm,
BLOCK_N=bn,
BLOCK_K=bk,
num_warps=nw,
num_stages=ns,
)
return out
# Headline sonic-moe shape used by get_inputs()/get_init_inputs():
# 32K tokens, hidden 4096, intermediate 1536, 128 experts, top-8 routing.
T_total = 32768
H = 4096
I = 1536
E = 128
K = 8
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
def get_inputs():
    """Build one sample call: bf16 permuted hidden states plus routing offsets."""
    rows = T_total * K
    hidden_states = torch.randn(rows, H, dtype=torch.bfloat16) * 0.1
    offsets = _build_routing(T_total, E, K)
    return [hidden_states, offsets]
def get_init_inputs():
    """Constructor arguments for Model, matching the module-level shape."""
    return [T_total, H, I, E, K]
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 12:23:13.934000 1161976 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=33.581 gbps=251.421 ms=24.556
shape=0 variant=compiled tflops=33.522 gbps=250.979 ms=24.600
shape=0 variant=solution tflops=10.346 gbps=77.457 ms=79.709
shape=0 solution_peak_fraction=0.0517
shape=1 variant=eager tflops=14.445 gbps=268.016 ms=2.379
shape=1 variant=compiled tflops=14.341 gbps=266.090 ms=2.396
shape=1 variant=solution tflops=26.778 gbps=496.856 ms=1.283
shape=1 solution_peak_fraction=0.1339
shape=2 variant=eager tflops=35.549 gbps=243.009 ms=15.465
shape=2 variant=compiled tflops=35.332 gbps=241.528 ms=15.560
shape=2 variant=solution tflops=16.606 gbps=113.519 ms=33.106
shape=2 solution_peak_fraction=0.0830
peak_fraction: 0.0832
RESULT: LOW
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "deepseek/deepseek-v4-flash",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0832,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 63543, "output_tokens": 67232, "cache_read_tokens": 11005568, "cache_creation_tokens": 0, "reasoning_tokens": 61716, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2:
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8:
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15:
16: import torch
17:
18:
19: def _try_sonic_moe(
20: hidden_states: torch.Tensor,
21: W_gate: torch.Tensor,
22: W_up: torch.Tensor,
23: expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25: try:
26: import sonic_moe # type: ignore # noqa: F401
27: except Exception:
28: return None
29: try:
30: # Public sonic-moe API surface is still stabilizing. The expected entry
31: # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32: # SwiGLU. Adapt to the a
... (1755 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"pattern": "**/src/eval/**/*.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/report.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/shapes.py
/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py{
"pattern": "**/src/hardware/**/*.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/m4_max.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"command": "which nvcc && nvcc --version 2>/dev/null; echo \"---\"; python -c \"import torch; print(f'torch {torch.__version__}'); print(f'CUDA {torch.version.cuda}'); print(torch.cuda.get_device_capability())\" 2>/dev/null; echo \"---\"; python -c \"import triton; print(f'triton {triton.__version__}')\" 2>/dev/null || echo \"no triton\"; echo \"---\"; nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv,noheader 2>/dev/null || echo \"no nvidia-smi\"",
"description": "Check CUDA, torch, triton, GPU info"
}/usr/local/cuda-13/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
---
torch 2.11.0+cu130
CUDA 13.0
(12, 0)
---
triton 3.6.0
---
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
{
"todos": [
{
"content": "Read reference, problem files, and understand the op specification",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton-based solution.py with grouped GEMM + fused SwiGLU",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py to verify correctness across all shapes and seeds",
"status": "pending",
"priority": "high"
},
{
"content": "Profile with benchmark.py and iterate on performance",
"status": "pending",
"priority"
... (297 more chars)content: 'Read reference, problem files, and understand the op specification'
status: 'completed'
priority: 'high'
content: 'Write initial Triton-based solution.py with grouped GEMM + fused SwiGLU'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes and seeds'
status: 'pending'
priority: 'high'
content: 'Profile with benchmark.py and iterate on performance'
status: 'pending'
priority: 'high'
content: 'Optimize tiling, autotuning, and memory access patterns'
status: 'pending'
priority: 'medium'
content: 'Final verification: check.py PASS + benchmark.py'
status: 'pending'
priority: 'high'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "grouped_gemm_swiglu"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+
+def _get_configs(H: int, I: int, M_avg: int) -> list[triton.Config]:
+ """Return heuristic configs based on problem dimensions."""
+ configs = []
+ # M tile: pick based on average tokens-per-expert
+ for bm in [32, 64, 128]:
+ if bm > M_avg and bm != 32:
+ continue
+ for bn in [64, 128, 256]:
+ if bn >= I:
+ bn = I
+ for bk in [32, 64]:
+ for nw in [4, 8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
+ num_warps=nw,
+ num_stages=3,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=[],
+ key=["H", "I"],
+ use_cuda_graph=True,
+)
+@triton.jit
+def grouped_gemm_swiglu_kernel(
+ hidden_states_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ expert_offsets_ptr,
+ output_ptr,
+ H: tl.constexpr,
+ I: tl.constexpr,
+ stride_hidden_h: tl.constexpr,
+ stride_gate_e: tl.constexpr,
+ stride_gate_h: tl.constexpr,
+ stride_gate_i: tl.constexpr,
+ stride_up_e: tl.constexpr,
+ stride_up_h: tl.constexpr,
+ stride_up_i: tl.constexpr,
+ stride_out_i: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ e = pid // num_n_tiles
+ n_tile_idx = pid % num_n_tiles
+
+ start = tl.load(expert_offsets_ptr + e)
+ end = tl.load(expert_offsets_ptr + e + 1)
+ n_tokens = end - start
+ if n_tokens <= 0:
+ return
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ w_gate_base = W_gate_ptr + e * stride_gate_e
+ w_up_base = W_up_ptr + e * stride_up_e
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_ptr + m_off[:, None] * stride_hidden_h + k_off[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate = tl.load(
+ w_gate_base
+ + k_off[:, None] * stride_gate_h
+ + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_base
+ + k_off[:, None] * stride_up_h
+ + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_ptr = output_ptr + (start + m_start) * stride_out_i
+ tl.store(
+ out_ptr + m_off[:, None] * stride_out_i + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+ T_perm, _ = hidden_states.shape
+ out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ M_avg = T_perm // self.E
+ grid = (self.E * triton.cdiv(self.I, 128),)
+
+ grouped_gemm_swiglu_kernel[grid](
+ hidden_states,
+ self.W_gate,
+ self.W_up,
+ expert_offsets,
+ out,
+ self.H,
+ self.I,
+ hidden_states.stride(0),
+ self.W_gate.stride(0),
+ self.W_gate.stride(1),
+ self.W_gate.stride(2),
+ self.W_up.stride(0),
+ self.W_up.stride(1),
+ self.W_up.stride(2),
+ out.stride(0),
+ BLOCK_M=64,
+ BLOCK_N=128,
+ BLOCK_K=32,
+ num_warps=4,
+ num_stages=3,
+ )
+ return out
+
+
+T_total = 32768
+H = 4096
+I = 1536
+E = 128
+K = 8
+
+
+def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
+ T_perm = T_total * K
+ base = T_perm // E
+ rem = T_perm - base * E
+ counts = torch.full((E,), base, dtype=torch.int32, device=device)
+ counts[:rem] += 1
+ offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
+ offsets[1:] = torch.cumsum(counts, dim=0)
+ return offsets
+
+
+def get_inputs():
+ T_perm = T_total * K
+ hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
+ expert_offsets = _build_routing(T_total, E, K)
+ return [hidden_states, expert_offsets]
+
+
+def get_init_in
... (42 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<5727 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -10,33 +10,6 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-def _get_configs(H: int, I: int, M_avg: int) -> list[triton.Config]:
- """Return heuristic configs based on problem dimensions."""
- configs = []
- # M tile: pick based on average tokens-per-expert
- for bm in [32, 64, 128]:
- if bm > M_avg and bm != 32:
- continue
- for bn in [64, 128, 256]:
- if bn >= I:
- bn = I
- for bk in [32, 64]:
- for nw in [4, 8]:
- configs.append(
- triton.Config(
- {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
- num_warps=nw,
- num_stages=3,
- )
- )
- return configs
-
-
-@triton.autotune(
- configs=[],
- key=["H", "I"],
- use_cuda_graph=True,
-)
@triton.jit
def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<832 chars \u2014 see below>",
"newString": "@triton.jit"
}def _get_configs(H: int, I: int, M_avg: int) -> list[triton.Config]:
"""Return heuristic configs based on problem dimensions."""
configs = []
# M tile: pick based on average tokens-per-expert
for bm in [32, 64, 128]:
if bm > M_avg and bm != 32:
continue
for bn in [64, 128, 256]:
if bn >= I:
bn = I
for bk in [32, 64]:
for nw in [4, 8]:
configs.append(
triton.Config(
{"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
num_warps=nw,
num_stages=3,
)
)
return configs
@triton.autotune(
configs=[],
key=["H", "I"],
use_cuda_graph=True,
)
@triton.jitEdit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -115,7 +115,12 @@
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
M_avg = T_perm // self.E
- grid = (self.E * triton.cdiv(self.I, 128),)
+ bn = 128 if self.I >= 128 else 32
+ bm = 64
+ bk = 32
+ nw = 4
+ ns = 3
+ grid = (self.E * triton.cdiv(self.I, bn),)
grouped_gemm_swiglu_kernel[grid](
hidden_states,
@@ -133,11 +138,11 @@
self.W_up.stride(1),
self.W_up.stride(2),
out.stride(0),
- BLOCK_M=64,
- BLOCK_N=128,
- BLOCK_K=32,
- num_warps=4,
- num_stages=3,
+ BLOCK_M=bm,
+ BLOCK_N=bn,
+ BLOCK_K=bk,
+ num_warps=nw,
+ num_stages=ns,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<971 chars \u2014 see below>",
"newString": "<1075 chars \u2014 see below>"
} def forward(
self,
hidden_states: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor:
T_perm, _ = hidden_states.shape
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
M_avg = T_perm // self.E
grid = (self.E * triton.cdiv(self.I, 128),)
grouped_gemm_swiglu_kernel[grid](
hidden_states,
self.W_gate,
self.W_up,
expert_offsets,
out,
self.H,
self.I,
hidden_states.stride(0),
self.W_gate.stride(0),
self.W_gate.stride(1),
self.W_gate.stride(2),
self.W_up.stride(0),
self.W_up.stride(1),
self.W_up.stride(2),
out.stride(0),
BLOCK_M=64,
BLOCK_N=128,
BLOCK_K=32,
num_warps=4,
num_stages=3,
)
return out def forward(
self,
hidden_states: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor:
T_perm, _ = hidden_states.shape
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
M_avg = T_perm // self.E
bn = 128 if self.I >= 128 else 32
bm = 64
bk = 32
nw = 4
ns = 3
grid = (self.E * triton.cdiv(self.I, bn),)
grouped_gemm_swiglu_kernel[grid](
hidden_states,
self.W_gate,
self.W_up,
expert_offsets,
out,
self.H,
self.I,
hidden_states.stride(0),
self.W_gate.stride(0),
self.W_gate.stride(1),
self.W_gate.stride(2),
self.W_up.stride(0),
self.W_up.stride(1),
self.W_up.stride(2),
out.stride(0),
BLOCK_M=bm,
BLOCK_N=bn,
BLOCK_K=bk,
num_warps=nw,
num_stages=ns,
)
return outEdit applied successfully.{
"description": "Run check.py to test correctness",
"command": "uv run python check.py 2>&1",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu",
"timeout": 300000
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=36.5057 (atol=0.02, rtol=0.02)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py
@@ -0,0 +1,158 @@
+"""Debug script: test the triton kernel on a small case and compare with reference."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+# Small case: 2 experts, 8 tokens each, small dims
+T_total = 4
+H = 32
+I = 16
+E = 2
+K = 2
+T_perm = T_total * K # 8
+
+# Round-robin routing
+base = T_perm // E
+rem = T_perm - base * E
+counts = torch.full((E,), base, dtype=torch.int32)
+counts[:rem] += 1
+expert_offsets = torch.zeros(E + 1, dtype=torch.int32)
+expert_offsets[1:] = torch.cumsum(counts, dim=0)
+print(f"expert_offsets: {expert_offsets}")
+
+hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
+W_gate = torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02
+W_up = torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference
+ref_out = torch.empty(T_perm, I, dtype=torch.bfloat16)
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ x_e = hidden_states[start:end]
+ gate = x_e @ W_gate[e]
+ up = x_e @ W_up[e]
+ ref_out[start:end] = torch.nn.functional.silu(gate) * up
+
+
+# Triton kernel (copy from solution)
+@triton.jit
+def grouped_gemm_swiglu_kernel(
+ hidden_states_ptr,
+ W_gate_ptr,
+ W_up_ptr,
+ expert_offsets_ptr,
+ output_ptr,
+ H: tl.constexpr,
+ I: tl.constexpr,
+ stride_hidden_h: tl.constexpr,
+ stride_gate_e: tl.constexpr,
+ stride_gate_h: tl.constexpr,
+ stride_gate_i: tl.constexpr,
+ stride_up_e: tl.constexpr,
+ stride_up_h: tl.constexpr,
+ stride_up_i: tl.constexpr,
+ stride_out_i: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ e = pid // num_n_tiles
+ n_tile_idx = pid % num_n_tiles
+
+ start = tl.load(expert_offsets_ptr + e)
+ end = tl.load(expert_offsets_ptr + e + 1)
+ n_tokens = end - start
+ if n_tokens <= 0:
+ return
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ w_gate_base = W_gate_ptr + e * stride_gate_e
+ w_up_base = W_up_ptr + e * stride_up_e
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_ptr + m_off[:, None] * stride_hidden_h + k_off[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate = tl.load(
+ w_gate_base
+ + k_off[:, None] * stride_gate_h
+ + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_base
+ + k_off[:, None] * stride_up_h
+ + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_ptr = output_ptr + (start + m_start) * stride_out_i
+ tl.store(
+ out_ptr + m_off[:, None] * stride_out_i + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+
+triton_out = torch.empty(T_perm, I, dtype=torch.bfloat16, device='cuda')
+hs_cuda = hidden_states.cuda()
+wg_cuda = W_gate.cuda()
+wu_cuda = W_up.cuda()
+eo_cuda = expert_offsets.cuda()
+
+grid = (E * triton.cdiv(I, 16),)
+grouped_gemm_swiglu_kernel[grid](
+ hs_cuda, wg_cuda, wu_cuda, eo_cuda, triton_out,
+ H, I,
+ hs_cuda.stride(0),
+ wg_cuda.stride(0), wg_cuda.stride(1), wg_cuda.stride(2),
+ wu_cuda.stride(0), wu_cuda.stride(1), wu_cuda.stride(2),
+ triton_out.stride(0),
+ BLOCK_M=4, BLOCK_N=16, BLOCK_K=8, num_warps=4, num_stages=2,
+)
+
+print(f"Reference:\n{ref_out}")
+print(f"Triton:\n{triton_out.cpu()}")
+diff = (ref_out.float() - triton_out.cpu().float()).abs()
+print(f"Max diff: {diff.max().item()}")
+print(f"Mean diff: {diff.mean().item()}")
+
+if torch.allclose(ref_out.float(), triton_out.cpu().float(), atol=0.02, rtol=0.02):
+ print("PASS: results match")
+else:
+ print("FAIL: results differ")
+ # Print per-expert comparison
+ for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ ref_slice = ref_out[start:end].float()
+ sol_slice = triton_out[start:end].cpu().float()
+ d = (ref_slice - sol_slice).abs().max().item()
+ print(f" Expert {e}: tokens {start}-{end}, max_diff={d:.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py",
"content": "<4921 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug script on small test case",
"command": "uv run python debug.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}expert_offsets: tensor([0, 4, 8], dtype=torch.int32)
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 2054, in dot
res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 1528, in dot
assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Input shapes should have M >= 1, N >= 1 and K >= 16
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py", line 131, in <module>
grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
... (1775 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py
@@ -135,7 +135,7 @@
wg_cuda.stride(0), wg_cuda.stride(1), wg_cuda.stride(2),
wu_cuda.stride(0), wu_cuda.stride(1), wu_cuda.stride(2),
triton_out.stride(0),
- BLOCK_M=4, BLOCK_N=16, BLOCK_K=8, num_warps=4, num_stages=2,
+ BLOCK_M=4, BLOCK_N=16, BLOCK_K=16, num_warps=4, num_stages=2,
)
print(f"Reference:\n{ref_out}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug.py",
"oldString": "<366 chars \u2014 see below>",
"newString": "<367 chars \u2014 see below>"
}grid = (E * triton.cdiv(I, 16),)
grouped_gemm_swiglu_kernel[grid](
hs_cuda, wg_cuda, wu_cuda, eo_cuda, triton_out,
H, I,
hs_cuda.stride(0),
wg_cuda.stride(0), wg_cuda.stride(1), wg_cuda.stride(2),
wu_cuda.stride(0), wu_cuda.stride(1), wu_cuda.stride(2),
triton_out.stride(0),
BLOCK_M=4, BLOCK_N=16, BLOCK_K=8, num_warps=4, num_stages=2,
)grid = (E * triton.cdiv(I, 16),)
grouped_gemm_swiglu_kernel[grid](
hs_cuda, wg_cuda, wu_cuda, eo_cuda, triton_out,
H, I,
hs_cuda.stride(0),
wg_cuda.stride(0), wg_cuda.stride(1), wg_cuda.stride(2),
wu_cuda.stride(0), wu_cuda.stride(1), wu_cuda.stride(2),
triton_out.stride(0),
BLOCK_M=4, BLOCK_N=16, BLOCK_K=16, num_warps=4, num_stages=2,
)Edit applied successfully.{
"description": "Run debug script again",
"command": "uv run python debug.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}expert_offsets: tensor([0, 4, 8], dtype=torch.int32)
Reference:
tensor([[ 8.0466e-06, -1.2636e-05, -2.9802e-06, 2.0489e-07, 2.6345e-05,
-5.6028e-05, 5.4359e-05, -2.2650e-06, 2.3842e-06, -2.5153e-05,
-2.3365e-05, 1.0133e-05, -3.6210e-06, 2.1309e-06, 1.0073e-05,
-4.9174e-06],
[ 1.0204e-04, 2.7776e-05, -2.2769e-05, -9.8944e-06, 9.6321e-05,
-1.6022e-04, -7.3433e-05, -3.5286e-05, 1.9789e-05, 6.0797e-05,
-1.6880e-04, -8.5831e-06, -8.4877e-05, 2.3723e-05, 6.9439e-06,
-1.4976e-06],
[ 3.0160e-05, 2.4915e-05, -3.8147e-05, -7.8678e-05, 1.2684e-04,
-1.9837e-07, -1.8597e-05, -7.6771e-05, 2.4736e-06, 4.5061e-05,
-3.9816e-05, 6.2883e-06, -1.7643e-05, 6.9737e-06, -6.3777e-06,
1.1086e-05],
[-9.4771e-06, 2.4773e-07, 1.6308e-04, 7.8678e-06, 1.6451e-05,
-4.3511e-06, 2.0504e-05, -4.4346e-05, -1.5974e-05, 2.0742e-05,
8.0109e-05, -1.3232e-05, 9.0599e-06, 4.6730e-05, -5.0306e-05,
3.0994e-05],
[-2.1338e-05, -1.0431e-05, -1.5460e-07, 5.1022e-05, 1.6022e-04,
-6.4373e-05, -2.1458e-05, -1.5497e-05, -7.0035e-06, -2.8908e-06,
-1.4591e-04, -6.3777e-06, -2.0504e-05, 1.9789e-05, -8.0109e-05,
-2.1577e-05],
[-5.2214e-05, 2.9564e-05, 8.9407e-06, 1.0681e-04, 3.0518e-04,
2.2769e-05, 4.8876e-05, -1.4603e-05, -1.9193e-05, 2.1011e-06,
1.4961e-05, -3.4809e-05, -2.5940e-04, 1.7524e-05, -2
... (2623 more chars){
"description": "Run check.py to test correctness",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=37.4923 (atol=0.02, rtol=0.02)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug2.py
@@ -0,0 +1,141 @@
+"""Test with actual problem shapes."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+T_total = 2048 # small enough to debug easily
+H = 4096
+I = 1536
+E = 4
+K = 8
+T_perm = T_total * K
+
+# Round-robin routing
+base = T_perm // E
+rem = T_perm - base * E
+counts = torch.full((E,), base, dtype=torch.int32)
+counts[:rem] += 1
+expert_offsets = torch.zeros(E + 1, dtype=torch.int32)
+expert_offsets[1:] = torch.cumsum(counts, dim=0)
+print(f"expert_offsets: {expert_offsets}")
+print(f"Per-expert token counts: {[int(expert_offsets[i+1]-expert_offsets[i]) for i in range(E)]}")
+
+hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
+W_gate = torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02
+W_up = torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference
+ref_out = torch.empty(T_perm, I, dtype=torch.bfloat16)
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ x_e = hidden_states[start:end]
+ gate = x_e @ W_gate[e]
+ up = x_e @ W_up[e]
+ ref_out[start:end] = torch.nn.functional.silu(gate) * up
+
+# Per-expert reference for comparison
+for e in range(E):
+ s = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ n = end - s
+ x_e = hidden_states[s:end]
+ gate_ref = x_e @ W_gate[e]
+ up_ref = x_e @ W_up[e]
+ out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+ # Now test Triton for this expert
+ BLOCK_M = 64
+ BLOCK_N = 128
+ BLOCK_K = 32
+
+ @triton.jit
+ def single_expert_kernel(
+ x_ptr, w_gate_ptr, w_up_ptr, out_ptr,
+ n_tokens: tl.constexpr,
+ H: tl.constexpr, I: tl.constexpr,
+ x_stride: tl.constexpr,
+ w_gate_stride_h: tl.constexpr, w_gate_stride_i: tl.constexpr,
+ w_up_stride_h: tl.constexpr, w_up_stride_i: tl.constexpr,
+ out_stride: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ ):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_base = x_ptr + m_start * x_stride
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_base + m_off[:, None] * x_stride + k_off[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate = tl.load(
+ w_gate_ptr + k_off[:, None] * w_gate_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_ptr + k_off[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+ num_n_tiles = triton.cdiv(I, BLOCK_N)
+ triton_out_e = torch.zeros(n, I, dtype=torch.bfloat16, device='cuda')
+
+ x_cuda = x_e.unsqueeze(0).cuda() # add batch dim... no, keep as (n, H)
+ x_cuda = x_e.cuda()
+ wg_cuda = W_gate[e:e+1].cuda().squeeze(0)
+ wu_cuda = W_up[e:e+1].cuda().squeeze(0)
+
+ # Use parameters without the E dimension
+ single_expert_kernel[(num_n_tiles,)](
+ x_cuda, wg_cuda, wu_cuda, triton_out_e,
+ n, H, I,
+ x_cuda.stride(0),
+ wg_cuda.stride(0), wg_cuda.stride(1),
+ wu_cuda.stride(0), wu_cuda.stride(1),
+ triton_out_e.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ )
+
+ diff = (out_ref.cuda().float() - triton_out_e.float()).abs()
+ print(f" Expert {e} (n={n}): max_diff={diff.max().item():.6f}")
+
+ if diff.max().item() > 1.0:
+ print(f" WARNING: large error!")
+ # Check individual tile
+ for ni in range(num_n_tiles):
+ ns = ni * BLOCK_N
+ ne = min(ns + BLOCK_N, I)
+ ref_tile = out_ref[:, ns:ne].float()
+ sol_tile = triton_out_e[:, ns:ne].cpu().float()
+ tile_diff = (ref_tile - sol_tile).abs()
+ print(f" N-tile {ni} ({ns}:{ne}): max_diff={tile_diff.max().item():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug2.py",
"content": "<5071 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug2 to isolate issue",
"command": "uv run python debug2.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}expert_offsets: tensor([ 0, 4096, 8192, 12288, 16384], dtype=torch.int32)
Per-expert token counts: [4096, 4096, 4096, 4096]
Expert 0 (n=4096): max_diff=36.759033
WARNING: large error!
N-tile 0 (0:128): max_diff=36.759033
N-tile 1 (128:256): max_diff=0.112305
N-tile 2 (256:384): max_diff=0.109375
N-tile 3 (384:512): max_diff=0.107910
N-tile 4 (512:640): max_diff=0.111816
N-tile 5 (640:768): max_diff=0.098145
N-tile 6 (768:896): max_diff=0.116699
N-tile 7 (896:1024): max_diff=0.112793
N-tile 8 (1024:1152): max_diff=0.104492
N-tile 9 (1152:1280): max_diff=0.125977
N-tile 10 (1280:1408): max_diff=0.136719
N-tile 11 (1408:1536): max_diff=0.104004
Expert 1 (n=4096): max_diff=26.745422
WARNING: large error!
N-tile 0 (0:128): max_diff=26.745422
N-tile 1 (128:256): max_diff=0.097168
N-tile 2 (256:384): max_diff=0.108887
N-tile 3 (384:512): max_diff=0.125977
N-tile 4 (512:640): max_diff=0.117188
N-tile 5 (640:768): max_diff=0.123047
N-tile 6 (768:896): max_diff=0.146484
N-tile 7 (896:1024): max_diff=0.097656
N-tile 8 (1024:1152): max_diff=0.109863
N-tile 9 (1152:1280): max_diff=0.118652
N-tile 10 (1280:1408): max_diff=0.111816
N-tile 11 (1408:1536): max_diff=0.109375
Expert 2 (n=4096): max_diff=30.765320
WARNING: large error!
N-tile 0 (0:128): max_diff=30.765320
N-tile 1 (128:256): max_diff=0.108887
N-tile 2 (256:384): max_diff=0.098633
N-tile 3 (384:51
... (949 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -63,18 +63,18 @@
k_mask = k_start + k_off < H
x = tl.load(
- x_ptr + m_off[:, None] * stride_hidden_h + k_off[None, :],
+ x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
mask=m_mask[:, None] & k_mask[None, :],
)
w_gate = tl.load(
w_gate_base
- + k_off[:, None] * stride_gate_h
+ + (k_start + k_off)[:, None] * stride_gate_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
w_up = tl.load(
w_up_base
- + k_off[:, None] * stride_up_h
+ + (k_start + k_off)[:, None] * stride_up_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<609 chars \u2014 see below>",
"newString": "<645 chars \u2014 see below>"
} x = tl.load(
x_ptr + m_off[:, None] * stride_hidden_h + k_off[None, :],
mask=m_mask[:, None] & k_mask[None, :],
)
w_gate = tl.load(
w_gate_base
+ k_off[:, None] * stride_gate_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
w_up = tl.load(
w_up_base
+ k_off[:, None] * stride_up_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
) x = tl.load(
x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
mask=m_mask[:, None] & k_mask[None, :],
)
w_gate = tl.load(
w_gate_base
+ (k_start + k_off)[:, None] * stride_gate_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
w_up = tl.load(
w_up_base
+ (k_start + k_off)[:, None] * stride_up_h
+ (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)Edit applied successfully.{
"description": "Run check.py to verify corrected kernel",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.181641 (atol=0.02, rtol=0.02)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: from __future__ import annotations
2:
3: import torch
4: import torch.nn as nn
5: import triton
6: import triton.language as tl
7:
8: OP_TYPE = "grouped_gemm_swiglu"
9: SUPPORTED_PRECISIONS = ["bf16"]
10: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
11:
12:
13: @triton.jit
14: def grouped_gemm_swiglu_kernel(
15: hidden_states_ptr,
16: W_gate_ptr,
17: W_up_ptr,
18: expert_offsets_ptr,
19: output_ptr,
20: H: tl.constexpr,
21: I: tl.constexpr,
22: stride_hidden_h: tl.constexpr,
23: stride_gate_e: tl.constexpr,
24: stride_gate_h: tl.constexpr,
25: stride_gate_i: tl.constexpr,
26: stride_up_e: tl.constexpr,
27: stride_up_h: tl.constexpr,
28: stride_up_i: tl.constexpr,
29: stride_out_i: tl.constexpr,
30: BLOCK_M: tl.constexpr,
31: BLOCK_N: tl.constexpr,
32: BLOCK_K: tl.constexpr,
33: ):
34: pid = tl.program_id(0)
35: num_n_tiles = tl.cdiv(I, BLOCK_N)
36: e = pid // num_n_tiles
37: n_tile_idx = pid % num_n_tiles
38:
39: start = tl.load(expert_offsets_ptr + e)
40: end = tl.load(expert_offsets_ptr + e + 1)
41: n_tokens = end - start
42: if n_tokens <= 0:
43: return
44:
45: n_start = n_tile_idx * BLOCK_N
46: n_off = tl.arange(0, BLOCK_N)
47: n_mask = n_off < I - n_start
48:
49: w_gate_base = W_gate_ptr + e * stride_gate_e
50
... (4481 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
@@ -0,0 +1,149 @@
+"""Debug remaining error: test individual components."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H = 4096
+I = 1536
+E = 128
+K = 8
+T_total = 4096
+T_perm = T_total * K
+
+# Single expert test
+n_tokens = 512 # tokens for this expert
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_up = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference (full precision)
+gate_ref = x.float() @ w_gate.float()
+up_ref = x.float() @ w_up.float()
+out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+# Reference (bf16 matmul)
+gate_bf16 = x @ w_gate # bf16 matmul
+up_bf16 = x @ w_up
+out_bf16 = torch.nn.functional.silu(gate_bf16) * up_bf16
+
+# Compare
+diff_fp32 = (out_ref - out_bf16.float()).abs()
+print(f"bf16 reference vs fp32 reference: max_diff={diff_fp32.max().item():.6f}")
+
+# Now test with a single M-block
+BLOCK_M = 64
+BLOCK_N = 128
+BLOCK_K = 32
+
+@triton.jit
+def single_expert_kernel(
+ x_ptr, w_gate_ptr, w_up_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_gate_stride_h, w_gate_stride_i,
+ w_up_stride_h, w_up_stride_i,
+ out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ USE_MANUAL_SIGMOID: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_base = x_ptr + m_start * x_stride
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate_tile = tl.load(
+ w_gate_ptr + (k_start + k_off)[:, None] * w_gate_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up_tile = tl.load(
+ w_up_ptr + (k_start + k_off)[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate)
+ acc_up = tl.dot(x_tile, w_up_tile, acc_up)
+
+ if USE_MANUAL_SIGMOID:
+ gate_silu = (1.0 / (1.0 + tl.exp(-acc_gate))) * acc_gate
+ else:
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+num_n_tiles = triton.cdiv(I, BLOCK_N)
+
+for use_manual in [False, True]:
+ triton_out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ single_expert_kernel[(num_n_tiles,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), triton_out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ triton_out.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ USE_MANUAL_SIGMOID=use_manual,
+ )
+
+ diff_vs_bf16 = (out_bf16.cuda().float() - triton_out.float()).abs()
+ diff_vs_fp32 = (out_ref.cuda().float() - triton_out.float()).abs()
+
+ method = "manual sigmoid" if use_manual else "tl.sigmoid"
+ print(f"\n{method}:")
+ print(f" vs bf16 ref: max_diff={diff_vs_bf16.max().item():.6f}")
+ print(f" vs fp32 ref: max_diff={diff_vs_fp32.max().item():.6f}")
+
+ # Per-tile breakdown
+ for ni in range(num_n_tiles):
+ ns = ni * BLOCK_N
+ ne = min(ns + BLOCK_N, I)
+ tile_diff = diff_vs_bf16[0, ns:ne].abs().max().item() if n_tokens > 0 else 0
+ print(f" Tile {ni} ({ns}:{ne}): diff_vs_bf16_max={tile_diff:.6f}")
+
+# Also test: does accumulating twice give the same result?
+print("\n--- Reproducibility test ---")
+out1 = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+single_expert_kernel[(num_n_tiles,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), out1,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ out1.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ USE_MANUAL_SIGMOID=False,
+)
+out2 = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+single_expert_kernel[(num_n_tiles,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), out2,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ out2.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ USE_MANUAL_SIGMOID=False,
+)
+print(f"Run1 vs Run2: max_diff={(out1.float()-out2.float()).abs().max().item():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py",
"content": "<5191 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug3 to find error source",
"command": "uv run python debug3.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}bf16 reference vs fp32 reference: max_diff=0.000704
tl.sigmoid:
vs bf16 ref: max_diff=0.104980
vs fp32 ref: max_diff=0.104720
Tile 0 (0:128): diff_vs_bf16_max=0.037537
Tile 1 (128:256): diff_vs_bf16_max=0.046143
Tile 2 (256:384): diff_vs_bf16_max=0.042969
Tile 3 (384:512): diff_vs_bf16_max=0.035889
Tile 4 (512:640): diff_vs_bf16_max=0.037109
Tile 5 (640:768): diff_vs_bf16_max=0.053467
Tile 6 (768:896): diff_vs_bf16_max=0.039795
Tile 7 (896:1024): diff_vs_bf16_max=0.025024
Tile 8 (1024:1152): diff_vs_bf16_max=0.077637
Tile 9 (1152:1280): diff_vs_bf16_max=0.031982
Tile 10 (1280:1408): diff_vs_bf16_max=0.028320
Tile 11 (1408:1536): diff_vs_bf16_max=0.041260
manual sigmoid:
vs bf16 ref: max_diff=0.104980
vs fp32 ref: max_diff=0.104720
Tile 0 (0:128): diff_vs_bf16_max=0.037537
Tile 1 (128:256): diff_vs_bf16_max=0.046143
Tile 2 (256:384): diff_vs_bf16_max=0.042969
Tile 3 (384:512): diff_vs_bf16_max=0.035889
Tile 4 (512:640): diff_vs_bf16_max=0.037109
Tile 5 (640:768): diff_vs_bf16_max=0.053467
Tile 6 (768:896): diff_vs_bf16_max=0.039795
Tile 7 (896:1024): diff_vs_bf16_max=0.025024
Tile 8 (1024:1152): diff_vs_bf16_max=0.077637
Tile 9 (1152:1280): diff_vs_bf16_max=0.031982
Tile 10 (1280:1408): diff_vs_bf16_max=0.028320
Tile 11 (1408:1536): diff_vs_bf16_max=0.041260
--- Reproducibility test ---
Run1 vs Run2: max_diff=0.091171
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -79,8 +79,8 @@
mask=k_mask[:, None] & n_mask[None, :],
)
- acc_gate = tl.dot(x, w_gate, acc_gate)
- acc_up = tl.dot(x, w_up, acc_up)
+ acc_gate = tl.dot(x, w_gate, acc_gate, input_precision='ieee')
+ acc_up = tl.dot(x, w_up, acc_up, input_precision='ieee')
gate_silu = tl.sigmoid(acc_gate) * acc_gate
result = (gate_silu * acc_up).to(tl.bfloat16)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<95 chars \u2014 see below>",
"newString": "<143 chars \u2014 see below>"
} acc_gate = tl.dot(x, w_gate, acc_gate)
acc_up = tl.dot(x, w_up, acc_up) acc_gate = tl.dot(x, w_gate, acc_gate, input_precision='ieee')
acc_up = tl.dot(x, w_up, acc_up, input_precision='ieee')Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
@@ -79,8 +79,8 @@
w_up_ptr + (k_start + k_off)[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
- acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate)
- acc_up = tl.dot(x_tile, w_up_tile, acc_up)
+ acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee')
+ acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee')
if USE_MANUAL_SIGMOID:
gate_silu = (1.0 / (1.0 + tl.exp(-acc_gate))) * acc_gate
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py",
"oldString": "<115 chars \u2014 see below>",
"newString": "<163 chars \u2014 see below>"
} acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate)
acc_up = tl.dot(x_tile, w_up_tile, acc_up) acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee')
acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee')Edit applied successfully.{
"description": "Run debug3 with ieee precision",
"command": "uv run python debug3.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}bf16 reference vs fp32 reference: max_diff=0.000704
tl.sigmoid:
vs bf16 ref: max_diff=0.104980
vs fp32 ref: max_diff=0.104720
Tile 0 (0:128): diff_vs_bf16_max=0.037537
Tile 1 (128:256): diff_vs_bf16_max=0.046143
Tile 2 (256:384): diff_vs_bf16_max=0.042969
Tile 3 (384:512): diff_vs_bf16_max=0.035889
Tile 4 (512:640): diff_vs_bf16_max=0.037109
Tile 5 (640:768): diff_vs_bf16_max=0.053467
Tile 6 (768:896): diff_vs_bf16_max=0.039795
Tile 7 (896:1024): diff_vs_bf16_max=0.025024
Tile 8 (1024:1152): diff_vs_bf16_max=0.077637
Tile 9 (1152:1280): diff_vs_bf16_max=0.031982
Tile 10 (1280:1408): diff_vs_bf16_max=0.028320
Tile 11 (1408:1536): diff_vs_bf16_max=0.041260
manual sigmoid:
vs bf16 ref: max_diff=0.104980
vs fp32 ref: max_diff=0.104720
Tile 0 (0:128): diff_vs_bf16_max=0.037537
Tile 1 (128:256): diff_vs_bf16_max=0.046143
Tile 2 (256:384): diff_vs_bf16_max=0.042969
Tile 3 (384:512): diff_vs_bf16_max=0.035889
Tile 4 (512:640): diff_vs_bf16_max=0.037109
Tile 5 (640:768): diff_vs_bf16_max=0.053467
Tile 6 (768:896): diff_vs_bf16_max=0.039795
Tile 7 (896:1024): diff_vs_bf16_max=0.025024
Tile 8 (1024:1152): diff_vs_bf16_max=0.077637
Tile 9 (1152:1280): diff_vs_bf16_max=0.031982
Tile 10 (1280:1408): diff_vs_bf16_max=0.028320
Tile 11 (1408:1536): diff_vs_bf16_max=0.041260
--- Reproducibility test ---
Run1 vs Run2: max_diff=0.091171
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
@@ -79,8 +79,8 @@
w_up_ptr + (k_start + k_off)[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
- acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee')
- acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee')
+ acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee', max_num_imprecise_acc=0)
+ acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee', max_num_imprecise_acc=0)
if USE_MANUAL_SIGMOID:
gate_silu = (1.0 / (1.0 + tl.exp(-acc_gate))) * acc_gate
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py",
"oldString": "<163 chars \u2014 see below>",
"newString": "<213 chars \u2014 see below>"
} acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee')
acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee') acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate, input_precision='ieee', max_num_imprecise_acc=0)
acc_up = tl.dot(x_tile, w_up_tile, acc_up, input_precision='ieee', max_num_imprecise_acc=0)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py
@@ -98,31 +98,23 @@
num_n_tiles = triton.cdiv(I, BLOCK_N)
for use_manual in [False, True]:
- triton_out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
- single_expert_kernel[(num_n_tiles,)](
- x.cuda(), w_gate.cuda(), w_up.cuda(), triton_out,
- n_tokens, H, I,
- x.stride(0), w_gate.stride(0), w_gate.stride(1),
- w_up.stride(0), w_up.stride(1),
- triton_out.stride(0),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
- USE_MANUAL_SIGMOID=use_manual,
- )
-
- diff_vs_bf16 = (out_bf16.cuda().float() - triton_out.float()).abs()
- diff_vs_fp32 = (out_ref.cuda().float() - triton_out.float()).abs()
-
- method = "manual sigmoid" if use_manual else "tl.sigmoid"
- print(f"\n{method}:")
- print(f" vs bf16 ref: max_diff={diff_vs_bf16.max().item():.6f}")
- print(f" vs fp32 ref: max_diff={diff_vs_fp32.max().item():.6f}")
-
- # Per-tile breakdown
- for ni in range(num_n_tiles):
- ns = ni * BLOCK_N
- ne = min(ns + BLOCK_N, I)
- tile_diff = diff_vs_bf16[0, ns:ne].abs().max().item() if n_tokens > 0 else 0
- print(f" Tile {ni} ({ns}:{ne}): diff_vs_bf16_max={tile_diff:.6f}")
+ for bk in [32, 64, 128]:
+ triton_out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ num_nt = triton.cdiv(I, BLOCK_N)
+ single_expert_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), triton_out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ triton_out.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=bk,
+ USE_MANUAL_SIGMOID=use_manual,
+ )
+
+ diff_vs_bf16 = (out_bf16.cuda().float() - triton_out.float()).abs()
+ method = "manual sigmoid" if use_manual else "tl.sigmoid"
+ print(f"\n{method} BK={bk}:")
+ print(f" vs bf16 ref: max_diff={diff_vs_bf16.max().item():.6f}")
# Also test: does accumulating twice give the same result?
print("\n--- Reproducibility test ---")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py",
"oldString": "<1132 chars \u2014 see below>",
"newString": "<829 chars \u2014 see below>"
}for use_manual in [False, True]:
triton_out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
single_expert_kernel[(num_n_tiles,)](
x.cuda(), w_gate.cuda(), w_up.cuda(), triton_out,
n_tokens, H, I,
x.stride(0), w_gate.stride(0), w_gate.stride(1),
w_up.stride(0), w_up.stride(1),
triton_out.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
USE_MANUAL_SIGMOID=use_manual,
)
diff_vs_bf16 = (out_bf16.cuda().float() - triton_out.float()).abs()
diff_vs_fp32 = (out_ref.cuda().float() - triton_out.float()).abs()
method = "manual sigmoid" if use_manual else "tl.sigmoid"
print(f"\n{method}:")
print(f" vs bf16 ref: max_diff={diff_vs_bf16.max().item():.6f}")
print(f" vs fp32 ref: max_diff={diff_vs_fp32.max().item():.6f}")
# Per-tile breakdown
for ni in range(num_n_tiles):
ns = ni * BLOCK_N
ne = min(ns + BLOCK_N, I)
tile_diff = diff_vs_bf16[0, ns:ne].abs().max().item() if n_tokens > 0 else 0
print(f" Tile {ni} ({ns}:{ne}): diff_vs_bf16_max={tile_diff:.6f}")for use_manual in [False, True]:
for bk in [32, 64, 128]:
triton_out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
num_nt = triton.cdiv(I, BLOCK_N)
single_expert_kernel[(num_nt,)](
x.cuda(), w_gate.cuda(), w_up.cuda(), triton_out,
n_tokens, H, I,
x.stride(0), w_gate.stride(0), w_gate.stride(1),
w_up.stride(0), w_up.stride(1),
triton_out.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=bk,
USE_MANUAL_SIGMOID=use_manual,
)
diff_vs_bf16 = (out_bf16.cuda().float() - triton_out.float()).abs()
method = "manual sigmoid" if use_manual else "tl.sigmoid"
print(f"\n{method} BK={bk}:")
print(f" vs bf16 ref: max_diff={diff_vs_bf16.max().item():.6f}")Edit applied successfully.{
"description": "Run updated debug3 with BK sweep",
"command": "uv run python debug3.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}bf16 reference vs fp32 reference: max_diff=0.000704
tl.sigmoid BK=32:
vs bf16 ref: max_diff=0.104980
tl.sigmoid BK=64:
vs bf16 ref: max_diff=0.111084
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug3.py", line 104, in <module>
single_expert_kernel[(num_nt,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 448, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource:
... (105 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug4.py
@@ -0,0 +1,123 @@
+"""Investigate non-determinism and compare against reference."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_up = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference (bf16 precision)
+gate_bf16 = x @ w_gate
+up_bf16 = x @ w_up
+out_bf16 = torch.nn.functional.silu(gate_bf16) * up_bf16
+
+# Reference (full fp32)
+gate_fp32 = x.float() @ w_gate.float()
+up_fp32 = x.float() @ w_up.float()
+out_fp32 = torch.nn.functional.silu(gate_fp32) * up_fp32
+
+print(f"bf16 ref vs fp32 ref: max={((out_bf16.float()-out_fp32).abs().max().item()):.6f}")
+
+# Triton kernel
+@triton.jit
+def test_kernel(
+ x_ptr, w_gate_ptr, w_up_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_gate_stride_h, w_gate_stride_i,
+ w_up_stride_h, w_up_stride_i,
+ out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ PREC: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ if PREC == 0:
+ # Standard bf16 dot
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ else:
+ # Load as fp32
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ ).to(tl.float32)
+
+ w_gate_tile = tl.load(
+ w_gate_ptr + (k_start + k_off)[:, None] * w_gate_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up_tile = tl.load(
+ w_up_ptr + (k_start + k_off)[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ if PREC == 0:
+ acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate)
+ acc_up = tl.dot(x_tile, w_up_tile, acc_up)
+ elif PREC == 1:
+ acc_gate = tl.dot(x_tile, w_gate_tile.to(tl.float32), acc_gate)
+ acc_up = tl.dot(x_tile, w_up_tile.to(tl.float32), acc_up)
+ elif PREC == 2:
+ x_f32 = x_tile.to(tl.float32)
+ acc_gate = tl.dot(x_f32, w_gate_tile.to(tl.float32), acc_gate)
+ acc_up = tl.dot(x_f32, w_up_tile.to(tl.float32), acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+num_nt = triton.cdiv(I, 128)
+
+for prec, name in [(0, "bf16 dot"), (1, "bf16@fp32 dot"), (2, "fp32 dot")]:
+ print(f"\n=== {name} ===")
+ diffs_vs_bf16 = []
+ diffs_vs_fp32 = []
+ for run in range(5):
+ out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ test_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ PREC=prec,
+ )
+ dvb = (out_bf16.cuda().float() - out.float()).abs().max().item()
+ dvf = (out_fp32.cuda().float() - out.float()).abs().max().item()
+ diffs_vs_bf16.append(dvb)
+ diffs_vs_fp32.append(dvf)
+ print(f" Run {run}: vs_bf16={dvb:.6f} vs_fp32={dvf:.6f}")
+
+ print(f" --- vs bf16: min={min(diffs_vs_bf16):.6f} max={max(diffs_vs_bf16):.6f} range={max(diffs_vs_bf16)-min(diffs_vs_bf16):.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug4.py",
"content": "<4504 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug4 - compare precisions",
"command": "uv run python debug4.py 2>&1",
"timeout": 180000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}bf16 ref vs fp32 ref: max=0.000704
=== bf16 dot ===
Run 0: vs_bf16=0.104980 vs_fp32=0.104720
Run 1: vs_bf16=0.104980 vs_fp32=0.104720
Run 2: vs_bf16=0.104980 vs_fp32=0.104720
Run 3: vs_bf16=0.104980 vs_fp32=0.104720
Run 4: vs_bf16=0.104980 vs_fp32=0.104720
--- vs bf16: min=0.104980 max=0.104980 range=0.000000
=== bf16@fp32 dot ===
Run 0: vs_bf16=0.104980 vs_fp32=0.104720
Run 1: vs_bf16=0.104980 vs_fp32=0.104720
Run 2: vs_bf16=0.104980 vs_fp32=0.104720
Run 3: vs_bf16=0.104980 vs_fp32=0.104720
Run 4: vs_bf16=0.104980 vs_fp32=0.104720
--- vs bf16: min=0.104980 max=0.104980 range=0.000000
=== fp32 dot ===
Run 0: vs_bf16=0.104980 vs_fp32=0.104720
Run 1: vs_bf16=0.104980 vs_fp32=0.104720
Run 2: vs_bf16=0.104980 vs_fp32=0.104720
Run 3: vs_bf16=0.104980 vs_fp32=0.104720
Run 4: vs_bf16=0.109497 vs_fp32=0.109344
--- vs bf16: min=0.104980 max=0.109497 range=0.004517
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug5.py
@@ -0,0 +1,136 @@
+"""Compare GEMM intermediate values between Triton and reference."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_up = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference GEMMs (bf16 matmul)
+gate_ref = x @ w_gate # (n_tokens, I) bf16
+up_ref = x @ w_up # (n_tokens, I) bf16
+out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+# Reference GEMM (fp32, to see if bf16 truncation is the issue)
+gate_fp32 = x.float() @ w_gate.float()
+up_fp32 = x.float() @ w_up.float()
+
+print("GEMM differences:")
+print(f" gate bf16 vs fp32: max={((gate_ref.float() - gate_fp32).abs().max().item()):.6f}")
+print(f" up bf16 vs fp32: max={((up_ref.float() - up_fp32).abs().max().item()):.6f}")
+
+# Triton kernel that computes JUST the GEMM and outputs bf16 gate and up
+@triton.jit
+def triton_gemm_kernel(
+ x_ptr, w_gate_ptr, w_up_ptr, gate_out_ptr, up_out_ptr,
+ n_tokens, H, I,
+ x_stride, w_gate_stride_h, w_gate_stride_i,
+ w_up_stride_h, w_up_stride_i,
+ gate_stride, up_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate_tile = tl.load(
+ w_gate_ptr + (k_start + k_off)[:, None] * w_gate_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up_tile = tl.load(
+ w_up_ptr + (k_start + k_off)[:, None] * w_up_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc_gate = tl.dot(x_tile, w_gate_tile, acc_gate)
+ acc_up = tl.dot(x_tile, w_up_tile, acc_up)
+
+ # Store gate and up separately as bf16
+ gate_base = gate_out_ptr + m_start * gate_stride
+ up_base = up_out_ptr + m_start * up_stride
+ tl.store(
+ gate_base + m_off[:, None] * gate_stride + n_off[None, :],
+ acc_gate.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+ tl.store(
+ up_base + m_off[:, None] * up_stride + n_off[None, :],
+ acc_up.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+num_nt = triton.cdiv(I, 128)
+gate_triton = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+up_triton = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+
+triton_gemm_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), gate_triton, up_triton,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_gate.stride(1),
+ w_up.stride(0), w_up.stride(1),
+ gate_triton.stride(0), up_triton.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+)
+
+print("\nTriton GEMM vs reference:")
+gate_diff = (gate_ref.cuda().float() - gate_triton.float()).abs()
+up_diff = (up_ref.cuda().float() - up_triton.float()).abs()
+print(f" gate: max_diff={gate_diff.max().item():.6f} mean={gate_diff.mean().item():.6f}")
+print(f" up: max_diff={up_diff.max().item():.6f} mean={up_diff.mean().item():.6f}")
+
+# Now check the silu * up stage
+# Reference SwiGLU
+out_ref_swiglu = torch.nn.functional.silu(gate_ref) * up_ref
+
+# Triton SwiGLU (using Triton-computed gate and up with torch silu)
+out_triton_gemm_torch_silu = torch.nn.functional.silu(gate_triton.cpu()) * up_triton.cpu()
+diff_gemm = (out_ref_swiglu.float() - out_triton_gemm_torch_silu.float()).abs()
+print(f"\nSwiGLU: triton GEMM + torch silu vs ref:")
+print(f" max_diff={diff_gemm.max().item():.6f}")
+
+# Compute silu and multiply in fp32 from triton's gate/up
+gate_triton_f32 = gate_triton.float()
+up_triton_f32 = up_triton.float()
+out_triton_fp32 = torch.nn.functional.silu(gate_triton_f32) * up_triton_f32
+diff_fp32_swiglu = (out_ref_swiglu.float().cuda() - out_triton_fp32.cuda()).abs()
+print(f"\nSwiGLU: triton GEMM fp32 silu vs ref:")
+print(f" max_diff={diff_fp32_swiglu.max().item():.6f}")
+
+# The full pipeline: triton GEMM -> bf16 cutoff -> silu -> multiply -> bf16
+# vs reference
+print(f"\nFull pipeline triton vs ref:")
+out_triton_full = torch.nn.functional.silu(gate_triton.cpu()) * up_triton.cpu()
+diff_full = (out_ref_swiglu.float() - out_triton_full.float()).abs()
+print(f" max_diff={diff_full.max().item():.6f}")
+
+# Also check: what if gate_triton is bf16 and up_triton is bf16, then silu and mul in fp32 on GPU?
+gate_triton_gpu = gate_triton.cuda()
+up_triton_gpu = up_triton.cuda()
+out_triton_gpu_silu = torch.nn.functional.silu(gate_triton_gpu.float()) * up_triton_gpu.float()
+diff_gpu_fp32 = (out_ref_swiglu.float().cuda() - out_triton_gpu_silu).abs()
+print(f" (GPU fp32 silu): max_diff={diff_gpu_fp32.max().item():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug5.py",
"content": "<5529 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug5 - find exact error source",
"command": "uv run python debug5.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}GEMM differences:
gate bf16 vs fp32: max=0.001949
up bf16 vs fp32: max=0.001926
Triton GEMM vs reference:
gate: max_diff=0.756836 mean=0.104116
up: max_diff=0.732422 mean=0.104095
SwiGLU: triton GEMM + torch silu vs ref:
max_diff=0.108688
SwiGLU: triton GEMM fp32 silu vs ref:
max_diff=0.108967
Full pipeline triton vs ref:
max_diff=0.108688
(GPU fp32 silu): max_diff=0.108967
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug6.py
@@ -0,0 +1,111 @@
+"""Tile size sweep for GEMM accuracy."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Reference
+gate_ref = x @ w_gate # cuBLAS bf16 GEMM
+
+@triton.jit
+def gemm_kernel(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_tile = tl.load(
+ w_ptr + (k_start + k_off)[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc = tl.dot(x_tile, w_tile, acc)
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+import itertools
+
+configs = []
+for bm, bn, bk, nw in itertools.product(
+ [32, 64, 128],
+ [64, 128, 256],
+ [16, 32, 64],
+ [4, 8],
+):
+ if bk < 16: continue
+ if bn > I: continue
+ configs.append((bm, bn, bk, nw))
+
+# Also test single-k-tile (load all of H at once if possible)
+for bm, bn, nw in itertools.product([16, 32], [64, 128], [4]):
+ bk = H # full K dimension
+ configs.append((bm, bn, bk, nw))
+
+results = []
+for bm, bn, bk, nw in configs:
+ if bk == H:
+ try:
+ out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ num_nt = triton.cdiv(I, bn)
+ gemm_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), out.stride(0),
+ BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
+ num_warps=nw, num_stages=2,
+ )
+ diff = (gate_ref.cuda().float() - out.float()).abs().max().item()
+ results.append((bm, bn, bk, nw, diff, "OK"))
+ print(f"BM={bm:3d} BN={bn:3d} BK={bk:4d} NW={nw}: diff={diff:.6f} OK")
+ except Exception as e:
+ print(f"BM={bm:3d} BN={bn:3d} BK={bk:4d} NW={nw}: FAILED - {e}")
+ else:
+ out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ num_nt = triton.cdiv(I, bn)
+ gemm_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), out.stride(0),
+ BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
+ num_warps=nw, num_stages=3,
+ )
+ diff = (gate_ref.cuda().float() - out.float()).abs().max().item()
+ results.append((bm, bn, bk, nw, diff, "OK"))
+ print(f"BM={bm:3d} BN={bn:3d} BK={bk:4d} NW={nw}: diff={diff:.6f} OK")
+
+print("\n=== Best configs ===")
+for r in sorted(results, key=lambda x: x[4])[:10]:
+ print(f"BM={r[0]:3d} BN={r[1]:3d} BK={r[2]:4d} NW={r[3]}: diff={r[4]:.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug6.py",
"content": "<3703 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Sweep tile sizes for accuracy",
"command": "uv run python debug6.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}BM= 32 BN= 64 BK= 16 NW=4: diff=0.773438 OK
BM= 32 BN= 64 BK= 16 NW=8: diff=0.781250 OK
BM= 32 BN= 64 BK= 32 NW=4: diff=0.804688 OK
BM= 32 BN= 64 BK= 32 NW=8: diff=0.736328 OK
BM= 32 BN= 64 BK= 64 NW=4: diff=0.755859 OK
BM= 32 BN= 64 BK= 64 NW=8: diff=0.773438 OK
BM= 32 BN=128 BK= 16 NW=4: diff=0.789062 OK
BM= 32 BN=128 BK= 16 NW=8: diff=0.789062 OK
BM= 32 BN=128 BK= 32 NW=4: diff=0.771484 OK
BM= 32 BN=128 BK= 32 NW=8: diff=0.723633 OK
BM= 32 BN=128 BK= 64 NW=4: diff=0.771484 OK
BM= 32 BN=128 BK= 64 NW=8: diff=0.783203 OK
BM= 32 BN=256 BK= 16 NW=4: diff=0.781250 OK
BM= 32 BN=256 BK= 16 NW=8: diff=0.818359 OK
BM= 32 BN=256 BK= 32 NW=4: diff=0.732422 OK
BM= 32 BN=256 BK= 32 NW=8: diff=0.769531 OK
BM= 32 BN=256 BK= 64 NW=4: diff=0.781250 OK
BM= 32 BN=256 BK= 64 NW=8: diff=0.781250 OK
BM= 64 BN= 64 BK= 16 NW=4: diff=0.730469 OK
BM= 64 BN= 64 BK= 16 NW=8: diff=0.736328 OK
BM= 64 BN= 64 BK= 32 NW=4: diff=0.736328 OK
BM= 64 BN= 64 BK= 32 NW=8: diff=0.744141 OK
BM= 64 BN= 64 BK= 64 NW=4: diff=0.800781 OK
BM= 64 BN= 64 BK= 64 NW=8: diff=0.744141 OK
BM= 64 BN=128 BK= 16 NW=4: diff=0.816406 OK
BM= 64 BN=128 BK= 16 NW=8: diff=0.751953 OK
BM= 64 BN=128 BK= 32 NW=4: diff=0.750000 OK
BM= 64 BN=128 BK= 32 NW=8: diff=0.769531 OK
BM= 64 BN=128 BK= 64 NW=4: diff=0.847656 OK
BM= 64 BN=128 BK= 64 NW=8: diff=0.708984 OK
BM= 64 BN=256 BK= 16 NW=4: diff=0.730469 OK
BM= 64 BN=256 BK= 16 NW=8: diff=0.582031 OK
BM= 64 BN=256 BK= 32 NW=4: diff=0.781250 OK
BM= 64 BN=256 B
... (2006 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug7.py
@@ -0,0 +1,158 @@
+"""Test if Triton bf16 GEMM differs fundamentally from cuBLAS."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# cuBLAS bf16 matmul
+ref = x @ w
+
+# Test tl.dot vs manual accumulation in fp32
+x_f32 = x.float().cuda()
+w_f32 = w.float().cuda()
+ref_f32 = x_f32 @ w_f32
+
+print(f"cuBLAS bf16 vs fp32 ref: max={((ref.float()-ref_f32.cpu()).abs().max().item()):.6f}")
+
+# Now: use Triton to load tiles and compute dot product manually (no tl.dot)
+# We'll compute the dot product by multiplying and summing in fp32
+@triton.jit
+def manual_dot_kernel(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ # Convert to fp32 manually, multiply, accumulate
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ ).to(tl.float32)
+ w_tile = tl.load(
+ w_ptr + (k_start + k_off)[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ ).to(tl.float32)
+
+ # Manual dot product: element-wise multiply, then sum along K
+ prod = x_tile[:, :, None] * w_tile[None, :, :]
+ # This creates (BLOCK_M, BLOCK_K, BLOCK_N) and sums along K
+ # But this is insanely expensive in registers
+
+ # Better: compute via reduction
+ # acc[m,n] += sum_k x[m,k] * w[k,n]
+ for k in range(BLOCK_K):
+ acc += x_tile[:, k:k+1] * w_tile[k:k+1, :]
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+# Actually, that manual approach is terrible. Let me test something simpler.
+# Test: what if we just do ONE K-tile and verify the partial sum is correct?
+
+@triton.jit
+def single_tile_gemm(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ TILE_K_START: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ k_start = TILE_K_START
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_tile = tl.load(
+ w_ptr + (k_start + k_off)[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc = tl.dot(x_tile, w_tile, acc)
+
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+# Test: compute partial sums for each K-tile and compare with cuBLAS partials
+BLOCK_K = 32
+BLOCK_M = 64
+BLOCK_N = 128
+num_k_tiles = H // BLOCK_K
+
+for k_tile in range(num_k_tiles):
+ k_start = k_tile * BLOCK_K
+ k_end = k_start + BLOCK_K
+
+ # cuBLAS reference for this K-tile only
+ ref_partial = x[:, k_start:k_end] @ w[k_start:k_end, :]
+
+ # Triton partial
+ out_triton = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ num_nt = triton.cdiv(I, BLOCK_N)
+ single_tile_gemm[(num_nt,)](
+ x.cuda(), w.cuda(), out_triton,
+ n_tokens, H, I,
+ x.stride(0), w.stride(0), out_triton.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ TILE_K_START=k_start,
+ num_warps=4, num_stages=2,
+ )
+
+ diff = (ref_partial.cuda().float() - out_triton.float()).abs()
+ maxd = diff.max().item()
+
+ if maxd > 0.01:
+ print(f"K-tile {k_tile} ({k_start}:{k_end}): max_diff={maxd:.6f} (LARGE!)")
+ else:
+ pass # suppress OK tiles for brevity
+
+# Summarize: how many tiles have large errors?
+print(f"\n(n_tokens=512, checking all {num_k_tiles} K-tiles)")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug7.py",
"content": "<5283 chars \u2014 see diff>"
}Wrote file successfully.--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug8.py
@@ -0,0 +1,93 @@
+"""Magnitude of Triton bf16 GEMM error: is it relative or absolute?"""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+ref = x @ w # cuBLAS
+ref_vals = ref.float()
+print(f"Reference stats: mean={ref_vals.mean():.6f} std={ref_vals.std():.6f} min={ref_vals.min():.6f} max={ref_vals.max():.6f}")
+
+# Test: does tl.dot with a single K-tile give different result from cuBLAS?
+@triton.jit
+def single_tile_gemm(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, H, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ TILE_K_START: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ k_start = TILE_K_START
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_tile = tl.load(
+ w_ptr + (k_start + k_off)[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc = tl.dot(x_tile, w_tile, acc)
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+BLOCK_M = 64
+BLOCK_N = 128
+BLOCK_K = 32
+num_nt = triton.cdiv(I, BLOCK_N)
+
+# Compare individual K-tiles
+for k_start in list(range(0, H, BLOCK_K))[:5]: # first 5 tiles
+ ref_partial = x[:, k_start:k_start+BLOCK_K] @ w[k_start:k_start+BLOCK_K, :]
+ out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ single_tile_gemm[(num_nt,)](
+ x.cuda(), w.cuda(), out,
+ n_tokens, H, I,
+ x.stride(0), w.stride(0), out.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ TILE_K_START=k_start,
+ num_warps=4, num_stages=2,
+ )
+ diff = (ref_partial.cuda().float() - out.float()).abs()
+ ref_partial_stats = ref_partial.float()
+ print(f"K-tile {k_start:4d}: ref mean={ref_partial_stats.mean():.4f} std={ref_partial_stats.std():.4f} diff max={diff.max().item():.6f}")
+
+# Test: does single-tile accumulate correctly?
+# Sum over ALL K-tiles using Triton
+out_all = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+for k_start in range(0, H, BLOCK_K):
+ partial = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ single_tile_gemm[(num_nt,)](
+ x.cuda(), w.cuda(), partial,
+ n_tokens, H, I,
+ x.stride(0), w.stride(0), partial.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ TILE_K_START=k_start,
+ num_warps=4, num_stages=2,
+ )
+ out_all = out_all + partial
+diff = (ref.cuda().float() - out_all.float()).abs()
+print(f"\nAccumulated all K-tiles separately: max_diff={diff.max().item():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug8.py",
"content": "<3464 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug8 - check individual K-tiles",
"command": "uv run python debug8.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Reference stats: mean=-0.000027 std=0.126209 min=-0.570312 max=0.585938
K-tile 0: ref mean=-0.0000 std=0.0112 diff max=0.082031
K-tile 32: ref mean=-0.0000 std=0.0111 diff max=0.075684
K-tile 64: ref mean=0.0000 std=0.0112 diff max=0.079651
K-tile 96: ref mean=-0.0000 std=0.0112 diff max=0.089233
K-tile 128: ref mean=0.0000 std=0.0111 diff max=0.075684
Accumulated all K-tiles separately: max_diff=0.783203
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug9.py
@@ -0,0 +1,101 @@
+"""Find which elements have large errors in single K-tile GEMM."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+H = 4096
+I = 1536
+n_tokens = 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+BLOCK_K = 32
+k_start = 0
+k_end = k_start + BLOCK_K
+
+# cuBLAS reference for first K-tile
+ref_partial = x[:, k_start:k_end] @ w[k_start:k_end, :] # (512, 1536)
+
+# Manual fp32 computation
+x_f32 = x[:, k_start:k_end].float() # (512, 32)
+w_f32 = w[k_start:k_end, :].float() # (32, 1536)
+manual_fp32 = x_f32 @ w_f32 # (512, 1536)
+
+print("Ref (cuBLAS bf16) vs manual fp32:")
+diff_bf16_vs_fp32 = (ref_partial.float() - manual_fp32.cpu()).abs()
+print(f" max={diff_bf16_vs_fp32.max().item():.6f} mean={diff_bf16_vs_fp32.mean().item():.6f}")
+
+# Now test Triton's tl.dot for the same tile
+@triton.jit
+def single_tile_gemm(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ k_off = tl.arange(0, BLOCK_K)
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + k_off[None, :],
+ mask=m_mask[:, None],
+ )
+ w_tile = tl.load(
+ w_ptr + k_off[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=n_mask[None, :],
+ )
+ acc = tl.dot(x_tile, w_tile, acc)
+ out_base = out_ptr + m_start * out_stride
+ tl.store(
+ out_base + m_off[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+# Pass weights without E dimension
+single_tile_gemm[(triton.cdiv(I, 128),)](
+ x.cuda(), w[:BLOCK_K].contiguous().cuda(), out,
+ n_tokens, I,
+ x.stride(0), w[:BLOCK_K].stride(0), out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=2,
+)
+
+diff = (ref_partial.cuda().float() - out.float()).abs()
+print(f"\nTriton single tile vs cuBLAS:")
+print(f" max={diff.max().item():.6f} mean={diff.mean().item():.6f}")
+
+# Find the worst elements
+flat_diff = diff.flatten()
+worst_idx = flat_diff.argsort(descending=True)[:5]
+for idx in worst_idx:
+ m = idx // I
+ n = idx % I
+ print(f" Element ({m},{n}): ref={ref_partial[m,n].item():.4f} triton={out[m,n].item():.4f} diff={flat_diff[idx].item():.4f}")
+
+# Check: are certain columns systematically wrong?
+print(f"\nColumn-wise max diff:")
+col_means = diff.max(dim=0).values
+best_cols = col_means.argsort()[:5]
+worst_cols = col_means.argsort(descending=True)[:5]
+print(f" Best columns: {[(c.item(), col_means[c].item()) for c in best_cols]}")
+print(f" Worst columns: {[(c.item(), col_means[c].item()) for c in worst_cols]}")
+
+# Check max diff per row
+row_means = diff.max(dim=1).values
+best_rows = row_means.argsort()[:5]
+worst_rows = row_means.argsort(descending=True)[:5]
+print(f" Best rows: {[(r.item(), row_means[r].item()) for r in best_rows]}")
+print(f" Worst rows: {[(r.item(), row_means[r].item()) for r in worst_rows]}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug9.py",
"content": "<3565 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug9 - analyze error patterns",
"command": "uv run python debug9.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Ref (cuBLAS bf16) vs manual fp32:
max=0.000122 mean=0.000013
Triton single tile vs cuBLAS:
max=0.071533 mean=0.009201
Element (110,70): ref=-0.0381 triton=0.0334 diff=0.0715
Element (185,63): ref=-0.0337 triton=0.0376 diff=0.0713
Element (374,30): ref=-0.0356 triton=0.0356 diff=0.0713
Element (188,40): ref=-0.0073 triton=0.0596 diff=0.0668
Element (448,8): ref=0.0522 triton=-0.0142 diff=0.0665
Column-wise max diff:
Best columns: [(411, 0.0203857421875), (836, 0.020751953125), (310, 0.021484375), (1089, 0.0218505859375), (604, 0.0223388671875)]
Worst columns: [(70, 0.071533203125), (30, 0.0712890625), (63, 0.0712890625), (40, 0.06683349609375), (8, 0.06646728515625)]
Best rows: [(391, 0.0269775390625), (277, 0.02984619140625), (186, 0.02996826171875), (402, 0.0306396484375), (147, 0.0311279296875)]
Worst rows: [(110, 0.071533203125), (185, 0.0712890625), (374, 0.0712890625), (188, 0.06683349609375), (448, 0.06646728515625)]
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug10.py
@@ -0,0 +1,89 @@
+"""Verify Triton dot product with exact element comparison."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+H, I, n_tokens = 4096, 1536, 512
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Test specific problematic element
+m, n = 110, 70
+
+# Manual dot product for the first K-tile (k=0..31)
+ref_partial = x[:, :32] @ w[:32, :]
+manual = sum(x[m, k].float().item() * w[k, n].float().item() for k in range(32))
+print(f"Manual dot: sum_k x[110,k] * w[k,70] over k=0..31 = {manual:.6f}")
+print(f"cuBLAS: ref_partial[110,70] = {ref_partial[110,70].item():.6f}")
+print(f"Are they equal? {abs(manual - ref_partial[110,70].item()) < 1e-6}")
+
+# Now check what Triton loads
+BLOCK_M = 64
+BLOCK_N = 128
+BLOCK_K = 32
+
+@triton.jit
+def debug_load_kernel(
+ x_ptr, w_ptr, out_ptr,
+ n_tokens, I,
+ x_stride, w_stride_h, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ pid = tl.program_id(0)
+ n_tile_idx = pid
+ n_start = n_tile_idx * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+ x_base = x_ptr + m_start * x_stride
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ k_off = tl.arange(0, BLOCK_K)
+
+ # Instead of tl.dot, manually compute one element
+ x_tile = tl.load(
+ x_base + m_off[:, None] * x_stride + k_off[None, :],
+ mask=m_mask[:, None],
+ )
+ w_tile = tl.load(
+ w_ptr + k_off[:, None] * w_stride_h + (n_start + n_off)[None, :],
+ mask=n_mask[None, :],
+ )
+
+ # Store raw loaded values for debugging
+ out0 = out_ptr + m_start * out_stride
+
+ # Store first row of x_tile and first column of w_tile
+ x_row0 = x_tile[0, :]
+ w_col0 = w_tile[:, 0]
+
+ # Store x_row0 and w_col0 to separate output locations
+ for kk in range(BLOCK_K):
+ tl.store(out0 + kk * 2, x_row0[kk], mask=m_mask[0:1]) # store x values
+ tl.store(out0 + kk * 2 + 128, w_col0[kk], mask=m_mask[0:1]) # store w values
+
+out_debug = torch.zeros(256, dtype=torch.bfloat16, device='cuda')
+debug_load_kernel[(1,)](
+ x.cuda(), w[:32].contiguous().cuda(), out_debug,
+ n_tokens, I,
+ x.stride(0), w[:32].stride(0), 1,
+ BLOCK_M=BLOCK_M, BLOCK_N=128, BLOCK_K=BLOCK_K,
+ num_warps=4, num_stages=2,
+)
+
+# Check loaded values
+print(f"\nFirst 32 x values for row 0: {x[0, :32]}")
+print(f"Triton loaded x row 0 vals: {out_debug[0:64:2]}")
+
+print(f"\nFirst 32 w values for col 0: {w[:32, 0]}")
+print(f"Triton loaded w col 0 vals: {out_debug[1:65:2]}")
+
+# Check: do Triton's loaded values match Python?
+x_match = torch.allclose(x[0, :32].cuda(), out_debug[0:64:2], atol=1e-6, rtol=1e-6)
+w_match = torch.allclose(w[:32, 0].cuda(), out_debug[1:65:2], atol=1e-6, rtol=1e-6)
+print(f"x values match: {x_match}")
+print(f"w values match: {w_match}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug10.py",
"content": "<3143 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug10 - verify loaded values",
"command": "uv run python debug10.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Manual dot: sum_k x[110,k] * w[k,70] over k=0..31 = -0.038020
cuBLAS: ref_partial[110,70] = -0.038086
Are they equal? False
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: constexpr[0]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug10.py", line 70, in <module>
debug_load_kernel[(1,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/
... (1306 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug11.py
@@ -0,0 +1,109 @@
+"""Minimal reproducer: does tl.dot bf16 give wrong results on SM120?"""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+# Tiny test: 16x32 @ 32x16
+M, K, N = 16, 32, 16
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+b = torch.randn(K, N, dtype=torch.bfloat16) * 0.02
+
+ref = a @ b # cuBLAS
+ref_fp32 = a.float() @ b.float() # fp32 reference
+
+# Also check fp16
+a_fp16 = a.to(torch.float16)
+b_fp16 = b.to(torch.float16)
+ref_fp16 = a_fp16 @ b_fp16
+
+print(f"cuBLAS bf16 vs fp32: max={((ref.float()-ref_fp32).abs().max().item()):.6f}")
+print(f"cuBLAS fp16 vs fp32: max={((ref_fp16.float()-ref_fp32).abs().max().item()):.6f}")
+
+@triton.jit
+def tiny_gemm(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ off_m = pid * BLOCK_M
+ off_n = 0
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+
+ a = tl.load(a_ptr + m_off[:, None] * K + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * N + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * N + n_off[None, :], c.to(tl.bfloat16))
+
+# bf16 test
+out_bf16 = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+tiny_gemm[(1,)](
+ a.cuda(), b.cuda(), out_bf16,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
+ num_warps=4
+)
+
+print(f"\nTriton bf16 vs cuBLAS bf16: max={((out_bf16.cpu().float() - ref.float()).abs().max().item()):.6f}")
+
+# Test with fp16 input to tl.dot
+@triton.jit
+def tiny_gemm_fp16(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+
+ a = tl.load(a_ptr + m_off[:, None] * K + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * N + n_off[None, :])
+ c = tl.dot(a.to(tl.float16), b.to(tl.float16))
+ tl.store(out_ptr + m_off[:, None] * N + n_off[None, :], c.to(tl.bfloat16))
+
+out_fp16 = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+tiny_gemm_fp16[(1,)](
+ a.cuda().to(torch.float16), b.cuda().to(torch.float16), out_fp16,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
+ num_warps=4
+)
+
+print(f"Triton fp16 vs cuBLAS fp16: max={((out_fp16.cpu().float() - ref_fp16.float()).abs().max().item()):.6f}")
+
+# Also try split into smaller K tiles
+@triton.jit
+def tiny_gemm_splitk(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, K, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * K + (k_start + k_off)[None, :])
+ b = tl.load(b_ptr + (k_start + k_off)[:, None] * N + n_off[None, :])
+ acc = tl.dot(a, b, acc)
+
+ tl.store(out_ptr + m_off[:, None] * N + n_off[None, :], acc.to(tl.bfloat16))
+
+out_split = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+tiny_gemm_splitk[(1,)](
+ a.cuda(), b.cuda(), out_split,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=16,
+ num_warps=4
+)
+print(f"Triton bf16 splitk vs cuBLAS bf16: max={((out_split.cpu().float() - ref.float()).abs().max().item()):.6f}")
+
+# Check if the issue is with larger K
+K2 = 256
+a2 = torch.randn(M, K2, dtype=torch.bfloat16) * 0.1
+b2 = torch.randn(K2, N, dtype=torch.bfloat16) * 0.02
+ref2 = a2 @ b2
+
+out2 = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+tiny_gemm[(1,)](
+ a2.cuda(), b2.cuda(), out2,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K2,
+ num_warps=4
+)
+print(f"\nK=256 triton vs cuBLAS: max={((out2.cpu().float() - ref2.float()).abs().max().item()):.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug11.py",
"content": "<3804 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug11 - minimal tl.dot test",
"command": "uv run python debug11.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}cuBLAS bf16 vs fp32: max=0.000109
cuBLAS fp16 vs fp32: max=0.000014
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug11.py", line 41, in <module>
tiny_gemm[(1,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py
... (1108 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug12.py
@@ -0,0 +1,67 @@
+"""Try different approaches to get correct bf16 matmul on SM120."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K, N = 64, 32, 64
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+b = torch.randn(K, N, dtype=torch.bfloat16) * 0.02
+
+ref = a @ b
+ref_fp32 = a.float() @ b.float()
+
+print(f"cuBLAS bf16 vs fp32: max={((ref.float()-ref_fp32).abs().max().item()):.6f}")
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr,
+ M, K, N, a_stride, b_stride, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ DTYPE: tl.constexpr):
+ pid = tl.program_id(0)
+ n_tile = pid
+ n_start = n_tile * BLOCK_N
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ for m_start in range(0, M, BLOCK_M):
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, K, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ a_load = tl.load(a_ptr + (m_start + m_off)[:, None] * a_stride + (k_start + k_off)[None, :],
+ mask=(m_start + m_off)[:, None] < M)
+ b_load = tl.load(b_ptr + (k_start + k_off)[:, None] * b_stride + (n_start + n_off)[None, :],
+ mask=(n_start + n_off)[None, :] < N)
+
+ if DTYPE == 0:
+ # bf16 dot
+ acc = tl.dot(a_load, b_load, acc)
+ elif DTYPE == 1:
+ # fp16 dot
+ acc = tl.dot(a_load.to(tl.float16), b_load.to(tl.float16), acc)
+ elif DTYPE == 2:
+ # fp32 dot (tf32 on Blackwell)
+ acc = tl.dot(a_load.to(tl.float32), b_load.to(tl.float32), acc)
+
+ tl.store(out_ptr + (m_start + m_off)[:, None] * out_stride + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=(m_start + m_off)[:, None] < M)
+
+num_nt = triton.cdiv(N, 64)
+
+for dtype_name, dt in [("bf16", 0), ("fp16", 1), ("fp32", 2)]:
+ out = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(num_nt,)](
+ a.cuda(), b.cuda(), out,
+ M, K, N, a.stride(0), b.stride(0), out.stride(0),
+ BLOCK_M=M, BLOCK_N=64, BLOCK_K=K,
+ DTYPE=dt,
+ num_warps=4, num_stages=2,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ maxd = diff.max().item()
+ meand = diff.mean().item()
+ # Check sign flips
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ print(f"{dtype_name}: max_diff={maxd:.6f} mean={meand:.6f} sign_flips={sign_flips}/{M*N}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug12.py",
"content": "<2565 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug12 - compare dtype approaches",
"command": "uv run python debug12.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}cuBLAS bf16 vs fp32: max=0.000120
bf16: max_diff=0.000000 mean=0.000000 sign_flips=0/4096
fp16: max_diff=0.000000 mean=0.000000 sign_flips=0/4096
fp32: max_diff=0.000000 mean=0.000000 sign_flips=0/4096
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug13.py
@@ -0,0 +1,34 @@
+"""Check if error is specific to N=128 tile size."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K = 64, 32
+
+for N in [16, 32, 64, 128, 256]:
+ a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+ b = torch.randn(K, N, dtype=torch.bfloat16) * 0.02
+ ref = a @ b
+
+ @triton.jit
+ def gemm_kernel(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * K + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * N + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * N + n_off[None, :], c.to(tl.bfloat16))
+
+ out = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ print(f"N={N:3d}: max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f} sign_flips={sign_flips}/{M*N}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug13.py",
"content": "<1222 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug13 - test N dimension",
"command": "uv run python debug13.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug13.py", line 27, in <module>
gemm_kernel[(1,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, conte
... (977 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug14.py
@@ -0,0 +1,35 @@
+"""Test single K-tile with different N sizes. Key: is N=128 buggy?"""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K = 64, 32
+
+for N in [64, 128, 256]:
+ a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+ b = torch.randn(K, N, dtype=torch.bfloat16) * 0.02
+ ref = a @ b
+
+ @triton.jit
+ def gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, a_stride, b_stride, out_stride,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * a_stride + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * b_stride + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * out_stride + n_off[None, :], c.to(tl.bfloat16))
+
+ out = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, M, K, N,
+ a.stride(0), b.stride(0), out.stride(0),
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ print(f"N={N:3d}: max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f} sign_flips={sign_flips}/{M*N}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug14.py",
"content": "<1353 chars \u2014 see diff>"
}Wrote file successfully.--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug15.py
@@ -0,0 +1,68 @@
+"""Focused: does tl.dot with BLOCK_N=128 give wrong results on SM120?"""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+# Test 1: (64, 32) @ (32, 128) - the single K-tile case
+M, K, N1 = 64, 32, 64
+N2 = 128
+
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+
+b64 = torch.randn(K, N1, dtype=torch.bfloat16) * 0.02
+b128 = torch.randn(K, N2, dtype=torch.bfloat16) * 0.02
+
+ref64 = a @ b64
+ref128 = a @ b128
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * BLOCK_N + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+for name, N in [("N=64", N1), ("N=128", N2)]:
+ b = b64 if N == N1 else b128
+ ref = ref64 if N == N1 else ref128
+ out = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out,
+ BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ print(f"{name}: max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f} sign_flips={sign_flips}/{M*N}")
+
+# Now test: what about (4, 32) @ (32, 128)?
+M2 = 4
+a_small = a[:M2]
+
+@triton.jit
+def gemm_kernel_small(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ c = tl.dot(
+ tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :]),
+ tl.load(b_ptr + k_off[:, None] * BLOCK_N + n_off[None, :]),
+ )
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+ref_small = a_small @ b128
+out_small = torch.zeros(M2, N2, dtype=torch.bfloat16, device='cuda')
+gemm_kernel_small[(1,)](
+ a_small.cuda(), b128.cuda(), out_small,
+ BLOCK_M=M2, BLOCK_N=N2, BLOCK_K=K,
+ num_warps=4,
+)
+diff_small = (ref_small.cuda().float() - out_small.float()).abs()
+print(f"M=4,N=128: max_diff={diff_small.max().item():.6f} mean={diff_small.mean().item():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug15.py",
"content": "<2396 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug15 - test N=128 specifically",
"command": "uv run python debug15.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}N=64: max_diff=0.000000 mean=0.000000 sign_flips=0/4096
N=128: max_diff=0.000000 mean=0.000000 sign_flips=0/8192
M=4,N=128: max_diff=0.000000 mean=0.000000
{
"description": "Re-run check.py with fixed solution",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.21875 (atol=0.02, rtol=0.02)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
@@ -0,0 +1,136 @@
+"""CLEAN test: one K-tile correctly, compare tl.dot vs cuBLAS."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I = 128, 4096, 1536
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+BLOCK_M, BLOCK_N, BLOCK_K = 64, 128, 32
+
+# Test ONE K-tile: columns k_start..k_start+BLOCK_K
+k_start = 0
+x_slice = x[:, k_start:k_start+BLOCK_K].contiguous() # (n_tokens, 32)
+w_slice = w[k_start:k_start+BLOCK_K, :].contiguous() # (32, 1536)
+
+# cuBLAS reference
+ref_partial = x_slice @ w_slice
+
+# Clean Triton kernel, explicit strides
+@triton.jit
+def clean_gemm(
+ x_ptr, w_ptr, out_ptr,
+ x_stride_0, w_stride_0,
+ out_stride_0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ n_tile = pid
+ n_start = n_tile * BLOCK_N
+
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+
+ # Load x[m, k] where k is relative to the tile start
+ # x shape: (n_tokens, BLOCK_K) with stride (BLOCK_K, 1)
+ # So offset = m * BLOCK_K + k
+ for m_base in range(0, n_tokens, BLOCK_M):
+ a = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_stride_0 + k_off[None, :],
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+
+ # Load w[k, n] where k is relative to the tile start
+ # w shape: (BLOCK_K, I) with stride (I, 1)
+ # So offset = k * I + n
+ b = tl.load(
+ w_ptr + k_off[:, None] * w_stride_0 + (n_start + n_off)[None, :],
+ mask=(n_start + n_off)[None, :] < I,
+ )
+
+ c = tl.dot(a, b)
+
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
+ c.to(tl.bfloat16),
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+
+num_nt = triton.cdiv(I, BLOCK_N)
+out_triton = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+clean_gemm[(num_nt,)](
+ x_slice.cuda(), w_slice.cuda(), out_triton,
+ x_slice.stride(0), w_slice.stride(0),
+ out_triton.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ num_warps=4, num_stages=2,
+)
+
+diff = (ref_partial.cuda().float() - out_triton.float()).abs()
+print(f"Single K-tile (k_start=0):")
+print(f" max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f}")
+sign_flips = ((ref_partial.cuda() * out_triton) < 0).sum().item()
+print(f" sign_flips={sign_flips}/{n_tokens*I}")
+
+# Also test: full GEMM with all K tiles, CORRECTLY advancing k_start
+@triton.jit
+def full_gemm(
+ x_ptr, w_ptr, out_ptr,
+ H, I,
+ x_stride_0, w_stride_0, w_stride_1,
+ out_stride_0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ n_tile = pid
+ n_start = n_tile * BLOCK_N
+
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_base + k_off < H
+
+ a = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_stride_0 + (k_base + k_off)[None, :],
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+ b = tl.load(
+ w_ptr + (k_base + k_off)[:, None] * w_stride_0 + (n_start + n_off)[None, :],
+ mask=(k_base + k_off)[:, None] < H,
+ )
+
+ acc = tl.dot(a, b, acc)
+
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+
+# Full GEMM reference
+ref_full = x @ w
+
+out_full = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+full_gemm[(num_nt,)](
+ x.cuda(), w.cuda(), out_full,
+ H, I,
+ x.stride(0), w.stride(0), w.stride(1),
+ out_full.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ num_warps=4, num_stages=3,
+)
+
+diff_full = (ref_full.cuda().float() - out_full.float()).abs()
+print(f"\nFull GEMM (all K tiles, correct offsets):")
+print(f" max_diff={diff_full.max().item():.6f} mean={diff_full.mean().item():.6f}")
+sign_flips_full = ((ref_full.cuda() * out_full) < 0).sum().item()
+print(f" sign_flips={sign_flips_full}/{n_tokens*I}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py",
"content": "<4412 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run clean debug test",
"command": "uv run python debug_clean.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py", line 65, in <module>
clean_gemm[(num_nt,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, sel
... (1068 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
@@ -23,6 +23,7 @@
@triton.jit
def clean_gemm(
x_ptr, w_ptr, out_ptr,
+ n_tokens, I,
x_stride_0, w_stride_0,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
@@ -35,25 +36,16 @@
n_off = tl.arange(0, BLOCK_N)
k_off = tl.arange(0, BLOCK_K)
- # Load x[m, k] where k is relative to the tile start
- # x shape: (n_tokens, BLOCK_K) with stride (BLOCK_K, 1)
- # So offset = m * BLOCK_K + k
for m_base in range(0, n_tokens, BLOCK_M):
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + k_off[None, :],
mask=(m_base + m_off)[:, None] < n_tokens,
)
-
- # Load w[k, n] where k is relative to the tile start
- # w shape: (BLOCK_K, I) with stride (I, 1)
- # So offset = k * I + n
b = tl.load(
w_ptr + k_off[:, None] * w_stride_0 + (n_start + n_off)[None, :],
mask=(n_start + n_off)[None, :] < I,
)
-
c = tl.dot(a, b)
-
tl.store(
out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
c.to(tl.bfloat16),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py",
"oldString": "<1229 chars \u2014 see below>",
"newString": "<948 chars \u2014 see below>"
}def clean_gemm(
x_ptr, w_ptr, out_ptr,
x_stride_0, w_stride_0,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
n_tile = pid
n_start = n_tile * BLOCK_N
m_off = tl.arange(0, BLOCK_M)
n_off = tl.arange(0, BLOCK_N)
k_off = tl.arange(0, BLOCK_K)
# Load x[m, k] where k is relative to the tile start
# x shape: (n_tokens, BLOCK_K) with stride (BLOCK_K, 1)
# So offset = m * BLOCK_K + k
for m_base in range(0, n_tokens, BLOCK_M):
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + k_off[None, :],
mask=(m_base + m_off)[:, None] < n_tokens,
)
# Load w[k, n] where k is relative to the tile start
# w shape: (BLOCK_K, I) with stride (I, 1)
# So offset = k * I + n
b = tl.load(
w_ptr + k_off[:, None] * w_stride_0 + (n_start + n_off)[None, :],
mask=(n_start + n_off)[None, :] < I,
)
c = tl.dot(a, b)
tl.store(
out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
c.to(tl.bfloat16),
mask=(m_base + m_off)[:, None] < n_tokens,
)def clean_gemm(
x_ptr, w_ptr, out_ptr,
n_tokens, I,
x_stride_0, w_stride_0,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
n_tile = pid
n_start = n_tile * BLOCK_N
m_off = tl.arange(0, BLOCK_M)
n_off = tl.arange(0, BLOCK_N)
k_off = tl.arange(0, BLOCK_K)
for m_base in range(0, n_tokens, BLOCK_M):
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + k_off[None, :],
mask=(m_base + m_off)[:, None] < n_tokens,
)
b = tl.load(
w_ptr + k_off[:, None] * w_stride_0 + (n_start + n_off)[None, :],
mask=(n_start + n_off)[None, :] < I,
)
c = tl.dot(a, b)
tl.store(
out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
c.to(tl.bfloat16),
mask=(m_base + m_off)[:, None] < n_tokens,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
@@ -72,8 +72,8 @@
@triton.jit
def full_gemm(
x_ptr, w_ptr, out_ptr,
- H, I,
- x_stride_0, w_stride_0, w_stride_1,
+ n_tokens, H, I,
+ x_stride_0, w_stride_0,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
@@ -89,7 +89,6 @@
for k_base in range(0, H, BLOCK_K):
k_off = tl.arange(0, BLOCK_K)
- k_mask = k_base + k_off < H
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + (k_base + k_off)[None, :],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py",
"oldString": "<1175 chars \u2014 see below>",
"newString": "<1133 chars \u2014 see below>"
}def full_gemm(
x_ptr, w_ptr, out_ptr,
H, I,
x_stride_0, w_stride_0, w_stride_1,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
n_tile = pid
n_start = n_tile * BLOCK_N
m_off = tl.arange(0, BLOCK_M)
n_off = tl.arange(0, BLOCK_N)
for m_base in range(0, n_tokens, BLOCK_M):
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, H, BLOCK_K):
k_off = tl.arange(0, BLOCK_K)
k_mask = k_base + k_off < H
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + (k_base + k_off)[None, :],
mask=(m_base + m_off)[:, None] < n_tokens,
)
b = tl.load(
w_ptr + (k_base + k_off)[:, None] * w_stride_0 + (n_start + n_off)[None, :],
mask=(k_base + k_off)[:, None] < H,
)
acc = tl.dot(a, b, acc)
tl.store(
out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
acc.to(tl.bfloat16),
mask=(m_base + m_off)[:, None] < n_tokens,
)def full_gemm(
x_ptr, w_ptr, out_ptr,
n_tokens, H, I,
x_stride_0, w_stride_0,
out_stride_0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
n_tile = pid
n_start = n_tile * BLOCK_N
m_off = tl.arange(0, BLOCK_M)
n_off = tl.arange(0, BLOCK_N)
for m_base in range(0, n_tokens, BLOCK_M):
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, H, BLOCK_K):
k_off = tl.arange(0, BLOCK_K)
a = tl.load(
x_ptr + (m_base + m_off)[:, None] * x_stride_0 + (k_base + k_off)[None, :],
mask=(m_base + m_off)[:, None] < n_tokens,
)
b = tl.load(
w_ptr + (k_base + k_off)[:, None] * w_stride_0 + (n_start + n_off)[None, :],
mask=(k_base + k_off)[:, None] < H,
)
acc = tl.dot(a, b, acc)
tl.store(
out_ptr + (m_base + m_off)[:, None] * out_stride_0 + n_off[None, :],
acc.to(tl.bfloat16),
mask=(m_base + m_off)[:, None] < n_tokens,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
@@ -56,6 +56,7 @@
out_triton = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
clean_gemm[(num_nt,)](
x_slice.cuda(), w_slice.cuda(), out_triton,
+ n_tokens, I,
x_slice.stride(0), w_slice.stride(0),
out_triton.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py",
"oldString": "<226 chars \u2014 see below>",
"newString": "<243 chars \u2014 see below>"
}clean_gemm[(num_nt,)](
x_slice.cuda(), w_slice.cuda(), out_triton,
x_slice.stride(0), w_slice.stride(0),
out_triton.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4, num_stages=2,
)clean_gemm[(num_nt,)](
x_slice.cuda(), w_slice.cuda(), out_triton,
n_tokens, I,
x_slice.stride(0), w_slice.stride(0),
out_triton.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4, num_stages=2,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py
@@ -114,8 +114,8 @@
out_full = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
full_gemm[(num_nt,)](
x.cuda(), w.cuda(), out_full,
- H, I,
- x.stride(0), w.stride(0), w.stride(1),
+ n_tokens, H, I,
+ x.stride(0), w.stride(0),
out_full.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4, num_stages=3,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_clean.py",
"oldString": "<293 chars \u2014 see below>",
"newString": "<290 chars \u2014 see below>"
}out_full = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
full_gemm[(num_nt,)](
x.cuda(), w.cuda(), out_full,
H, I,
x.stride(0), w.stride(0), w.stride(1),
out_full.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4, num_stages=3,
)out_full = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
full_gemm[(num_nt,)](
x.cuda(), w.cuda(), out_full,
n_tokens, H, I,
x.stride(0), w.stride(0),
out_full.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4, num_stages=3,
)Edit applied successfully.{
"description": "Run clean debug",
"command": "uv run python debug_clean.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Single K-tile (k_start=0):
max_diff=0.067871 mean=0.009083
sign_flips=8207/196608
Full GEMM (all K tiles, correct offsets):
max_diff=0.684570 mean=0.102348
sign_flips=7096/196608
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_workaround.py
@@ -0,0 +1,85 @@
+"""Test workarounds for bf16 tl.dot bug on SM120."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I = 128, 4096, 1536
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_up = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Full reference
+gate_ref = x @ w_gate
+up_ref = x @ w_up
+out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+BLOCK_M, BLOCK_N, BLOCK_K = 64, 128, 32
+num_nt = triton.cdiv(I, BLOCK_N)
+
+@triton.jit
+def full_kernel(
+ x_ptr, wg_ptr, wu_ptr, out_ptr,
+ n_tokens, H, I,
+ x_s0, wg_s0, wu_s0, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ USE_FP16: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ nt = pid
+ n_start = nt * BLOCK_N
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+
+ a_raw = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+ wg_raw = tl.load(
+ wg_ptr + (k_base + k_off)[:, None] * wg_s0 + (n_start + n_off)[None, :],
+ mask=(k_base + k_off)[:, None] < H,
+ )
+ wu_raw = tl.load(
+ wu_ptr + (k_base + k_off)[:, None] * wu_s0 + (n_start + n_off)[None, :],
+ mask=(k_base + k_off)[:, None] < H,
+ )
+
+ if USE_FP16:
+ acc_g = tl.dot(a_raw.to(tl.float16), wg_raw.to(tl.float16), acc_g)
+ acc_u = tl.dot(a_raw.to(tl.float16), wu_raw.to(tl.float16), acc_u)
+ else:
+ a_f32 = a_raw.to(tl.float32)
+ acc_g = tl.dot(a_f32, wg_raw.to(tl.float32), acc_g)
+ acc_u = tl.dot(a_f32, wu_raw.to(tl.float32), acc_u)
+
+ gate_silu = tl.sigmoid(acc_g) * acc_g
+ result = (gate_silu * acc_u).to(tl.bfloat16)
+
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ result,
+ mask=(m_base + m_off)[:, None] < n_tokens,
+ )
+
+for name, use_fp16 in [("fp16_gemm", True), ("fp32_gemm", False)]:
+ out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ full_kernel[(num_nt,)](
+ x.cuda(), w_gate.cuda(), w_up.cuda(), out,
+ n_tokens, H, I,
+ x.stride(0), w_gate.stride(0), w_up.stride(0), out.stride(0),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ USE_FP16=use_fp16,
+ num_warps=4, num_stages=3,
+ )
+ diff = (out_ref.cuda().float() - out.float()).abs()
+ print(f"{name}: max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f}")
+ sign_flips = ((out_ref.cuda() * out) < 0).sum().item()
+ print(f" sign_flips={sign_flips}/{n_tokens*I}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_workaround.py",
"content": "<3060 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run workaround comparison",
"command": "uv run python debug_workaround.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}fp16_gemm: max_diff=0.094238 mean=0.005291
sign_flips=8231/196608
fp32_gemm: max_diff=0.094238 mean=0.004965
sign_flips=4161/196608
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_load.py
@@ -0,0 +1,80 @@
+"""Test: does Triton load/store bf16 data correctly?"""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+@triton.jit
+def load_store_kernel(in_ptr, out_ptr, N):
+ pid = tl.program_id(0)
+ off = pid * 128 + tl.arange(0, 128)
+ mask = off < N
+ x = tl.load(in_ptr + off, mask=mask)
+ tl.store(out_ptr + off, x, mask=mask)
+
+x = torch.randn(1024, dtype=torch.bfloat16) * 0.1
+out = torch.zeros(1024, dtype=torch.bfloat16, device='cuda')
+load_store_kernel[(8,)](
+ x.cuda(), out, 1024,
+ num_warps=4,
+)
+print(f"Load/store test: max_diff={(x.cuda() - out).abs().max().item():.6f}")
+
+# Now test: does tl.dot with tiny amounts work?
+@triton.jit
+def tiny_dot(a_ptr, b_ptr, out_ptr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * BLOCK_N + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+# Test with very small values where bf16 has limited precision
+for scale in [0.1, 1.0, 10.0]:
+ a = torch.randn(16, 16, dtype=torch.bfloat16) * scale
+ b = torch.randn(16, 16, dtype=torch.bfloat16) * scale * 0.1
+ ref = a @ b
+
+ out = torch.zeros(16, 16, dtype=torch.bfloat16, device='cuda')
+ tiny_dot[(1,)](
+ a.cuda(), b.cuda(), out,
+ BLOCK_M=16, BLOCK_N=16, BLOCK_K=16,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ print(f"16x16 dot (scale={scale}): max_diff={diff.max().item():.6f} sign_flips={sign_flips}/256")
+
+# Now test: does the issue happen with any specific row count?
+for M in [1, 2, 4, 8, 16, 32, 64]:
+ a = torch.randn(M, 32, dtype=torch.bfloat16) * 0.1
+ b = torch.randn(32, 128, dtype=torch.bfloat16) * 0.02
+ ref = a @ b
+
+ @triton.jit
+ def gemm_kernel(a_ptr, b_ptr, out_ptr, M, N,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :], mask=m_off[:, None] < M)
+ b = tl.load(b_ptr + k_off[:, None] * N + n_off[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * N + n_off[None, :], c.to(tl.bfloat16), mask=m_off[:, None] < M)
+
+ out = torch.zeros(M, 128, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, M, 128,
+ BLOCK_M=M, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sign_flips = ((ref.cuda() * out) < 0).sum().item()
+ if sign_flips > 0 or diff.max() > 0.001:
+ print(f"M={M:2d} N=128: max_diff={diff.max():.6f} sign_flips={sign_flips}/{M*128}")
+ else:
+ print(f"M={M:2d} N=128: OK")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_load.py",
"content": "<3077 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run load test and tiny dot tests",
"command": "uv run python debug_load.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Load/store test: max_diff=0.000000
16x16 dot (scale=0.1): max_diff=0.000000 sign_flips=0/256
16x16 dot (scale=1.0): max_diff=0.000000 sign_flips=0/256
16x16 dot (scale=10.0): max_diff=0.000000 sign_flips=0/256
M= 1 N=128: OK
M= 2 N=128: OK
M= 4 N=128: OK
M= 8 N=128: OK
M=16 N=128: OK
M=32 N=128: OK
M=64 N=128: OK
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_scale.py
@@ -0,0 +1,66 @@
+"""Test GEMM at different scales to find where errors start."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, N, a_s0, b_s0, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ n_start = nt * BLOCK_N
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+
+ a = tl.load(a_ptr + m_off[:, None] * a_s0 + k_off[None, :])
+ b = tl.load(b_ptr + k_off[:, None] * b_s0 + (n_start + n_off)[None, :])
+ c = tl.dot(a, b)
+ tl.store(out_ptr + m_off[:, None] * out_s0 + n_off[None, :], c.to(tl.bfloat16))
+
+# Test 1: fixed K=32, N=128, vary M up to 64
+print("Fixed (64,32) x (32,128):")
+a = torch.randn(64, 32, dtype=torch.bfloat16) * 0.1
+b = torch.randn(32, 128, dtype=torch.bfloat16) * 0.02
+ref = a @ b
+out = torch.zeros(64, 128, dtype=torch.bfloat16, device='cuda')
+gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, 128, a.stride(0), b.stride(0), out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, num_warps=4,
+)
+diff = (ref.cuda().float() - out.float()).abs()
+sf = ((ref.cuda() * out) < 0).sum().item()
+print(f" max_diff={diff.max():.6f} sign_flips={sf}")
+
+# Test 2: vary K, keep M,N fixed
+for K in [32, 64, 128, 256, 512, 1024]:
+ a = torch.randn(64, K, dtype=torch.bfloat16) * 0.1
+ b = torch.randn(K, 128, dtype=torch.bfloat16) * 0.02
+ ref = a @ b
+
+ @triton.jit
+ def gemm_split(a_ptr, b_ptr, out_ptr, a_s0, b_s0, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ n_start = nt * BLOCK_N
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_base in range(0, K, BLOCK_K):
+ k_off2 = tl.arange(0, BLOCK_K)
+ a2 = tl.load(a_ptr + m_off[:, None] * a_s0 + (k_base + k_off2)[None, :])
+ b2 = tl.load(b_ptr + (k_base + k_off2)[:, None] * b_s0 + (n_start + n_off)[None, :])
+ acc = tl.dot(a2, b2, acc)
+ tl.store(out_ptr + m_off[:, None] * out_s0 + n_off[None, :], acc.to(tl.bfloat16))
+
+ out = torch.zeros(64, 128, dtype=torch.bfloat16, device='cuda')
+ gemm_split[(1,)](
+ a.cuda(), b.cuda(), out, a.stride(0), b.stride(0), out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, num_warps=4, num_stages=3,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sf = ((ref.cuda() * out) < 0).sum().item()
+ print(f"K={K:4d} ({(K+31)//32} tiles): max_diff={diff.max():.6f} sign_flips={sf}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_scale.py",
"content": "<2697 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run debug_scale to find error boundary",
"command": "uv run python debug_scale.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Fixed (64,32) x (32,128):
max_diff=0.000000 sign_flips=0
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_scale.py", line 60, in <module>
gemm_split[(1,)](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", l
... (1112 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_stride.py
@@ -0,0 +1,89 @@
+"""Test if stride width in the N dimension causes tl.dot errors."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K = 64, 32
+N_tile = 128
+
+# Case 1: w has shape (32, 128) - contiguous, stride(0)=128
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+b1 = torch.randn(K, N_tile, dtype=torch.bfloat16) * 0.02
+ref1 = a @ b1
+
+# Case 2: w has shape (32, 1536) - stride(0)=1536, we access first 128 columns
+w_big = torch.randn(K, 1536, dtype=torch.bfloat16) * 0.02
+b2 = w_big[:, :128] # stride(0) = 1536
+ref2 = a @ b2 # cuBLAS computes the same dot product
+
+print(f"Ref1 vs Ref2 (should be same): max={((ref1-ref2).abs().max().item()):.6f}")
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + n_off[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+for name, b, b_s0 in [("stride=128", b1, b1.stride(0)), ("stride=1536", b2, b2.stride(0))]:
+ out = torch.zeros(M, N_tile, dtype=torch.bfloat16, device='cuda')
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, b_s0,
+ BLOCK_M=M, BLOCK_N=N_tile, BLOCK_K=K,
+ num_warps=4,
+ )
+ ref = a @ b # cuBLAS with whatever shape
+ diff = (ref.cuda().float() - out.float()).abs()
+ sf = ((ref.cuda() * out) < 0).sum().item()
+ print(f"{name}: max_diff={diff.max():.6f} sign_flips={sf}/{M*N_tile}")
+
+# Test: what if the B tensor has a large stride in the K dimension too?
+# Use a big tensor with K rows but skip rows
+w_wide = torch.randn(4096, 1536, dtype=torch.bfloat16) * 0.02
+b3 = w_wide[:K, :N_tile] # stride(0)=1536, stride(1)=1
+ref3 = a @ b3
+out3 = torch.zeros(M, N_tile, dtype=torch.bfloat16, device='cuda')
+gemm_kernel[(1,)](
+ a.cuda(), b3.cuda(), out3, b3.stride(0),
+ BLOCK_M=M, BLOCK_N=N_tile, BLOCK_K=K,
+ num_warps=4,
+)
+diff3 = (ref3.cuda().float() - out3.float()).abs()
+sf3 = ((ref3.cuda() * out3) < 0).sum().item()
+print(f"big_view (K=32,N=128 from (4096,1536)): max_diff={diff3.max():.6f} sign_flips={sf3}/{M*N_tile}")
+
+# Test: do sign flips appear with multiple N-tiles?
+out_multi = torch.zeros(M, 1536, dtype=torch.bfloat16, device='cuda')
+num_nt = triton.cdiv(1536, N_tile)
+
+@triton.jit
+def gemm_multi(a_ptr, b_ptr, out_ptr, b_s0, I,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ n_start = pid * BLOCK_N
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + (n_start + n_off)[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * I + n_off[None, :], c.to(tl.bfloat16))
+
+a_big = torch.randn(64, 32, dtype=torch.bfloat16) * 0.1
+w_big2 = torch.randn(32, 1536, dtype=torch.bfloat16) * 0.02
+ref_big = a_big @ w_big2
+
+gemm_multi[(num_nt,)](
+ a_big.cuda(), w_big2.cuda(), out_multi, w_big2.stride(0), 1536,
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4,
+)
+diff_big = (ref_big.cuda().float() - out_multi.float()).abs()
+sf_big = ((ref_big.cuda() * out_multi) < 0).sum().item()
+print(f"\nMulti N-tile (64,32) x (32,1536): max_diff={diff_big.max():.6f} sign_flips={sf_big}/{64*1536}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_stride.py",
"content": "<3557 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run stride test",
"command": "uv run python debug_stride.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Ref1 vs Ref2 (should be same): max=0.062988
stride=128: max_diff=0.000000 sign_flips=0/8192
stride=1536: max_diff=0.050011 sign_flips=3976/8192
big_view (K=32,N=128 from (4096,1536)): max_diff=49846049841934345233892764760512921600.000000 sign_flips=4022/8192
Multi N-tile (64,32) x (32,1536): max_diff=0.079834 sign_flips=4105/98304
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tmafix.py
@@ -0,0 +1,41 @@
+"""Test TMA workaround: try small num_stages or different approach."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K, N_tile = 64, 32, 128
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+w_big = torch.randn(K, 1536, dtype=torch.bfloat16) * 0.02
+b = w_big[:, :128] # stride(0)=1536
+ref = a @ b
+
+print(f"b shape={b.shape}, stride={b.stride()}")
+print(f"cuBLAS ref: mean={ref.float().mean():.6f}")
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + n_off[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+# Test with different num_stages
+for ns in [0, 1, 2, 3, 4]:
+ out = torch.zeros(M, N_tile, dtype=torch.bfloat16, device='cuda')
+ try:
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, b.stride(0),
+ BLOCK_M=M, BLOCK_N=N_tile, BLOCK_K=K,
+ num_warps=4, num_stages=ns,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sf = ((ref.cuda() * out) < 0).sum().item()
+ print(f"num_stages={ns}: max_diff={diff.max():.6f} sign_flips={sf}/{M*N_tile}")
+ except Exception as e:
+ print(f"num_stages={ns}: FAILED - {e}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tmafix.py",
"content": "<1515 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test num_stages fix",
"command": "uv run python debug_tmafix.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}b shape=torch.Size([32, 128]), stride=(1536, 1)
cuBLAS ref: mean=0.000100
num_stages=0: max_diff=0.049866 sign_flips=3922/8192
num_stages=1: max_diff=713053462628379038341895553024.000000 sign_flips=3985/8192
num_stages=2: max_diff=2920666982925840541048404185186304.000000 sign_flips=3995/8192
num_stages=3: max_diff=689601926524156794414206543724544.000000 sign_flips=4038/8192
num_stages=4: max_diff=nan sign_flips=3962/8192
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: from __future__ import annotations
2:
3: import torch
4: import torch.nn as nn
5: import triton
6: import triton.language as tl
7:
8: OP_TYPE = "grouped_gemm_swiglu"
9: SUPPORTED_PRECISIONS = ["bf16"]
10: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
11:
12:
13: @triton.jit
14: def grouped_gemm_swiglu_kernel(
15: hidden_states_ptr,
16: W_gate_ptr,
17: W_up_ptr,
18: expert_offsets_ptr,
19: output_ptr,
20: H: tl.constexpr,
21: I: tl.constexpr,
22: stride_hidden_h: tl.constexpr,
23: stride_gate_e: tl.constexpr,
24: stride_gate_h: tl.constexpr,
25: stride_gate_i: tl.constexpr,
26: stride_up_e: tl.constexpr,
27: stride_up_h: tl.constexpr,
28: stride_up_i: tl.constexpr,
29: stride_out_i: tl.constexpr,
30: BLOCK_M: tl.constexpr,
31: BLOCK_N: tl.constexpr,
32: BLOCK_K: tl.constexpr,
33: ):
34: pid = tl.program_id(0)
35: num_n_tiles = tl.cdiv(I, BLOCK_N)
36: e = pid // num_n_tiles
37: n_tile_idx = pid % num_n_tiles
38:
39: start = tl.load(expert_offsets_ptr + e)
40: end = tl.load(expert_offsets_ptr + e + 1)
41: n_tokens = end - start
42: if n_tokens <= 0:
43: return
44:
45: n_start = n_tile_idx * BLOCK_N
46: n_off = tl.arange(0, BLOCK_N)
47: n_mask = n_off < I - n_start
48:
49: w_gate_base = W_gate_ptr + e * stride_gate_e
50
... (4529 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_stride2.py
@@ -0,0 +1,37 @@
+"""Find which strides work and which don't for tl.dot."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+M, K, BN = 64, 32, 128
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + n_off[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+for stride in [128, 129, 130, 132, 144, 160, 192, 256, 512, 1024, 1536, 2048]:
+ b = torch.randn(K, stride, dtype=torch.bfloat16) * 0.02
+ # We use first BN=128 columns
+ ref = a @ b[:, :BN]
+ out = torch.zeros(M, BN, dtype=torch.bfloat16, device='cuda')
+ try:
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, b.stride(0),
+ BLOCK_M=M, BLOCK_N=BN, BLOCK_K=K,
+ num_warps=4,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sf = ((ref.cuda() * out) < 0).sum().item()
+ flag = "SIGNFLIP" if sf > 0 else "OK"
+ print(f"stride={stride:4d}: max_diff={diff.max():.6f} {flag} sf={sf}")
+ except Exception as e:
+ print(f"stride={stride:4d}: FAILED - {e}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_stride2.py",
"content": "<1442 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Find working stride values",
"command": "uv run python debug_stride2.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}stride= 128: max_diff=0.000000 OK sf=0
stride= 129: max_diff=0.000000 OK sf=0
stride= 130: max_diff=0.000000 OK sf=0
stride= 132: max_diff=0.000001 OK sf=0
stride= 144: max_diff=0.000000 OK sf=0
stride= 160: max_diff=0.000000 OK sf=0
stride= 192: max_diff=0.000000 OK sf=0
stride= 256: max_diff=0.000000 OK sf=0
stride= 512: max_diff=0.000000 OK sf=0
stride=1024: max_diff=0.000000 OK sf=0
stride=1536: max_diff=0.000000 OK sf=0
stride=2048: max_diff=0.000002 OK sf=0
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_notma.py
@@ -0,0 +1,45 @@
+"""Test if disabling TMA fixes the stride issue."""
+import os
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+M, K, BN = 64, 32, 128
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+w_big = torch.randn(K, 1536, dtype=torch.bfloat16) * 0.02
+b = w_big[:, :128] # view with stride 1536
+ref = a @ b
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + n_off[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+# Test with TRITON_DISABLE_TMA
+for disable_tma in [None, "1"]:
+ if disable_tma:
+ os.environ["TRITON_DISABLE_TMA"] = "1"
+ else:
+ os.environ.pop("TRITON_DISABLE_TMA", None)
+
+ out = torch.zeros(M, BN, dtype=torch.bfloat16, device='cuda')
+ try:
+ gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, b.stride(0),
+ BLOCK_M=M, BLOCK_N=BN, BLOCK_K=K,
+ num_warps=4, num_stages=2,
+ )
+ diff = (ref.cuda().float() - out.float()).abs()
+ sf = ((ref.cuda() * out) < 0).sum().item()
+ flag = "SIGNFLIP" if sf > 0 else "OK"
+ print(f"DISABLE_TMA={disable_tma}: max_diff={diff.max():.6f} {flag} sf={sf}")
+ except Exception as e:
+ print(f"DISABLE_TMA={disable_tma}: FAILED {e}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_notma.py",
"content": "<1586 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test TRITON_DISABLE_TMA in fresh process",
"command": "<1094 chars \u2014 see below>",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}TRITON_DISABLE_TMA=1 uv run python -c "
import torch, triton, triton.language as tl
torch.manual_seed(42)
M, K, BN = 64, 32, 128
a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
w_big = torch.randn(K, 1536, dtype=torch.bfloat16) * 0.02
b = w_big[:, :128]
ref = a @ b
@triton.jit
def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
a_t = tl.load(a_ptr + tl.arange(0, BLOCK_M)[:, None] * BLOCK_K + tl.arange(0, BLOCK_K)[None, :])
b_t = tl.load(b_ptr + tl.arange(0, BLOCK_K)[:, None] * b_s0 + tl.arange(0, BLOCK_N)[None, :])
c = tl.dot(a_t, b_t)
tl.store(out_ptr + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :], c.to(tl.bfloat16))
out = torch.zeros(M, BN, dtype=torch.bfloat16, device='cuda')
gemm_kernel[(1,)](a.cuda(), b.cuda(), out, b.stride(0),
BLOCK_M=M, BLOCK_N=BN, BLOCK_K=K, num_warps=4)
diff = (ref.cuda().float() - out.float()).abs()
sf = ((ref.cuda() * out) < 0).sum().item()
print(f'DISABLE_TMA=1: max_diff={diff.max():.6f} sign_flips={sf}/{M*BN}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 10, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __
... (150 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_notma2.py
@@ -0,0 +1,33 @@
+"""Test if disabling TMA fixes the stride issue. Run with env var."""
+import os
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+M, K, BN = 64, 32, 128
+a = torch.randn(M, K, dtype=torch.bfloat16) * 0.1
+w_big = torch.randn(K, 1536, dtype=torch.bfloat16) * 0.02
+b = w_big[:, :128] # view with stride 1536
+ref = a @ b
+
+@triton.jit
+def gemm_kernel(a_ptr, b_ptr, out_ptr, b_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ b_t = tl.load(b_ptr + k_off[:, None] * b_s0 + n_off[None, :])
+ c = tl.dot(a_t, b_t)
+ tl.store(out_ptr + m_off[:, None] * BLOCK_N + n_off[None, :], c.to(tl.bfloat16))
+
+out = torch.zeros(M, BN, dtype=torch.bfloat16, device='cuda')
+gemm_kernel[(1,)](
+ a.cuda(), b.cuda(), out, b.stride(0),
+ BLOCK_M=M, BLOCK_N=BN, BLOCK_K=K,
+ num_warps=4, num_stages=2,
+)
+diff = (ref.cuda().float() - out.float()).abs()
+sf = ((ref.cuda() * out) < 0).sum().item()
+print(f"DISABLE_TMA={os.environ.get('TRITON_DISABLE_TMA', '0')}: max_diff={diff.max():.6f} sign_flips={sf}/{M*BN}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_notma2.py",
"content": "<1239 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -13,22 +13,24 @@
@triton.jit
def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
- W_gate_ptr,
- W_up_ptr,
+ W_gate_tiled_ptr,
+ W_up_tiled_ptr,
expert_offsets_ptr,
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
+ BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,
- stride_gate_e: tl.constexpr,
- stride_gate_h: tl.constexpr,
- stride_gate_i: tl.constexpr,
- stride_up_e: tl.constexpr,
- stride_up_h: tl.constexpr,
- stride_up_i: tl.constexpr,
+ stride_gate_tiled_e: tl.constexpr,
+ stride_gate_tiled_t: tl.constexpr,
+ stride_gate_tiled_h: tl.constexpr,
+ stride_gate_tiled_n: tl.constexpr,
+ stride_up_tiled_e: tl.constexpr,
+ stride_up_tiled_t: tl.constexpr,
+ stride_up_tiled_h: tl.constexpr,
+ stride_up_tiled_n: tl.constexpr,
stride_out_i: tl.constexpr,
BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
@@ -46,8 +48,8 @@
n_off = tl.arange(0, BLOCK_N)
n_mask = n_off < I - n_start
- w_gate_base = W_gate_ptr + e * stride_gate_e
- w_up_base = W_up_ptr + e * stride_up_e
+ w_gate_tile_ptr = W_gate_tiled_ptr + e * stride_gate_tiled_e + n_tile_idx * stride_gate_tiled_t
+ w_up_tile_ptr = W_up_tiled_ptr + e * stride_up_tiled_e + n_tile_idx * stride_up_tiled_t
for m_start in range(0, n_tokens, BLOCK_M):
m_off = tl.arange(0, BLOCK_M)
@@ -67,20 +69,20 @@
mask=m_mask[:, None] & k_mask[None, :],
)
w_gate = tl.load(
- w_gate_base
- + (k_start + k_off)[:, None] * stride_gate_h
- + (n_start + n_off)[None, :],
+ w_gate_tile_ptr
+ + (k_start + k_off)[:, None] * stride_gate_tiled_h
+ + n_off[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
w_up = tl.load(
- w_up_base
- + (k_start + k_off)[:, None] * stride_up_h
- + (n_start + n_off)[None, :],
+ w_up_tile_ptr
+ + (k_start + k_off)[:, None] * stride_up_tiled_h
+ + n_off[None, :],
mask=k_mask[:, None] & n_mask[None, :],
)
- acc_gate = tl.dot(x, w_gate, acc_gate, input_precision='ieee')
- acc_up = tl.dot(x, w_up, acc_up, input_precision='ieee')
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
gate_silu = tl.sigmoid(acc_gate) * acc_gate
result = (gate_silu * acc_up).to(tl.bfloat16)
@@ -105,6 +107,28 @@
self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
+ self._gate_tiled: torch.Tensor | None = None
+ self._up_tiled: torch.Tensor | None = None
+
+ def _make_tiled(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """Reorganize (E, H, I) weights to (E, num_nt, H, BLOCK_N) for TMA-compatible access."""
+ if self._gate_tiled is not None:
+ return self._gate_tiled, self._up_tiled
+ E, H, I = self.E, self.H, self.I
+ bn = 128 if I >= 128 else 32
+ num_nt = I // bn if I % bn == 0 else triton.cdiv(I, bn)
+ pad = num_nt * bn - I
+ if pad > 0:
+ wg = torch.nn.functional.pad(self.W_gate.view(E, H, -1), (0, pad))
+ wu = torch.nn.functional.pad(self.W_up.view(E, H, -1), (0, pad))
+ else:
+ wg = self.W_gate
+ wu = self.W_up
+ gt = wg.view(E, H, num_nt, bn).permute(0, 2, 1, 3).contiguous()
+ ut = wu.view(E, H, num_nt, bn).permute(0, 2, 1, 3).contiguous()
+ self._gate_tiled = gt
+ self._up_tiled = ut
+ return gt, ut
def forward(
self,
@@ -114,29 +138,33 @@
T_perm, _ = hidden_states.shape
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
- M_avg = T_perm // self.E
bn = 128 if self.I >= 128 else 32
bm = 64
bk = 32
nw = 4
ns = 3
- grid = (self.E * triton.cdiv(self.I, bn),)
-
+ num_nt = triton.cdiv(self.I, bn)
+ grid = (self.E * num_nt,)
+
+ gt, ut = self._make_tiled()
grouped_gemm_swiglu_kernel[grid](
hidden_states,
- self.W_gate,
- self.W_up,
+ gt,
+ ut,
expert_offsets,
out,
self.H,
self.I,
+ bn,
hidden_states.stride(0),
- self.W_gate.stride(0),
- self.W_gate.stride(1),
- self.W_gate.stride(2),
- self.W_up.stride(0),
- self.W_up.stride(1),
- self.W_up.stride(2),
+ gt.stride(0),
+ gt.stride(1),
+ gt.stride(2),
+ gt.stride(3),
+ ut.stride(0),
+ ut.stride(1),
+ ut.stride(2),
+ ut.stride(3),
out.stride(0),
BLOCK_M=bm,
BLOCK_N=bn,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<6298 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run check.py with tiled weights",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 150, in forward
grouped_gemm_swiglu_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 710, in run
bound_args, specialization, options = binder(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: dynamic_func() go
... (41 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -19,7 +19,6 @@
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
- BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,
stride_gate_tiled_e: tl.constexpr,
stride_gate_tiled_t: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<279 chars \u2014 see below>",
"newString": "<252 chars \u2014 see below>"
}def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
W_gate_tiled_ptr,
W_up_tiled_ptr,
expert_offsets_ptr,
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,
stride_gate_tiled_e: tl.constexpr,def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
W_gate_tiled_ptr,
W_up_tiled_ptr,
expert_offsets_ptr,
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
stride_hidden_h: tl.constexpr,
stride_gate_tiled_e: tl.constexpr,Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -19,6 +19,7 @@
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
+ BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,
stride_gate_tiled_e: tl.constexpr,
stride_gate_tiled_t: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<213 chars \u2014 see below>",
"newString": "<240 chars \u2014 see below>"
}def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
W_gate_tiled_ptr,
W_up_tiled_ptr,
expert_offsets_ptr,
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
stride_hidden_h: tl.constexpr,def grouped_gemm_swiglu_kernel(
hidden_states_ptr,
W_gate_tiled_ptr,
W_up_tiled_ptr,
expert_offsets_ptr,
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -167,7 +167,6 @@
ut.stride(3),
out.stride(0),
BLOCK_M=bm,
- BLOCK_N=bn,
BLOCK_K=bk,
num_warps=nw,
num_stages=ns,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"oldString": "<608 chars \u2014 see below>",
"newString": "<584 chars \u2014 see below>"
} grouped_gemm_swiglu_kernel[grid](
hidden_states,
gt,
ut,
expert_offsets,
out,
self.H,
self.I,
bn,
hidden_states.stride(0),
gt.stride(0),
gt.stride(1),
gt.stride(2),
gt.stride(3),
ut.stride(0),
ut.stride(1),
ut.stride(2),
ut.stride(3),
out.stride(0),
BLOCK_M=bm,
BLOCK_N=bn,
BLOCK_K=bk,
num_warps=nw,
num_stages=ns,
) grouped_gemm_swiglu_kernel[grid](
hidden_states,
gt,
ut,
expert_offsets,
out,
self.H,
self.I,
bn,
hidden_states.stride(0),
gt.stride(0),
gt.stride(1),
gt.stride(2),
gt.stride(3),
ut.stride(0),
ut.stride(1),
ut.stride(2),
ut.stride(3),
out.stride(0),
BLOCK_M=bm,
BLOCK_K=bk,
num_warps=nw,
num_stages=ns,
)Edit applied successfully.{
"description": "Run check.py with fixed call",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.178711 (atol=0.02, rtol=0.02)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tiled.py
@@ -0,0 +1,90 @@
+"""Verify tiled kernel correctness."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I = 128, 4096, 1536
+E, num_nt = 1, I // 128
+BN = 128
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w_gate = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_up = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Tiled weights: (E, num_nt, H, BN)
+wg_tiled = w_gate.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+wu_tiled = w_up.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+
+print(f"Tiled shape: {wg_tiled.shape}")
+print(f"Tiled strides: {wg_tiled.stride()}")
+
+# Reference
+gate_ref = x @ w_gate
+up_ref = x @ w_up
+out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+@triton.jit
+def tiled_kernel(x_ptr, wg_ptr, wu_ptr, out_ptr, H, I, BN,
+ x_s0, wg_s0, wg_s1, wg_s2,
+ wu_s0, wu_s1, wu_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt_cnt = tl.cdiv(I, BLOCK_N)
+ e = pid // nt_cnt
+ nt = pid % nt_cnt
+ n_start = nt * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ wg_base = wg_ptr + e * wg_s0 + nt * wg_s1
+ wu_base = wu_ptr + e * wu_s0 + nt * wu_s1
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_base + m_off < n_tokens
+ acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_base + k_off < H
+ x_t = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ wg_t = tl.load(
+ wg_base + (k_base + k_off)[:, None] * wg_s2 + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ wu_t = tl.load(
+ wu_base + (k_base + k_off)[:, None] * wu_s2 + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ acc_g = tl.dot(x_t, wg_t, acc_g)
+ acc_u = tl.dot(x_t, wu_t, acc_u)
+
+ result = (tl.sigmoid(acc_g) * acc_g * acc_u).to(tl.bfloat16)
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+tiled_kernel[(E * num_nt,)](
+ x.cuda(), wg_tiled.cuda(), wu_tiled.cuda(), out,
+ n_tokens, H, I, BN,
+ x.stride(0),
+ wg_tiled.stride(0), wg_tiled.stride(1), wg_tiled.stride(2),
+ wu_tiled.stride(0), wu_tiled.stride(1), wu_tiled.stride(2),
+ out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=3,
+)
+
+diff = (out_ref.cuda().float() - out.float()).abs()
+print(f"Tiled kernel vs ref: max_diff={diff.max().item():.6f} mean={diff.mean().item():.6f}")
+sf = ((out_ref.cuda() * out) < 0).sum().item()
+print(f"Sign flips: {sf}/{n_tokens*I}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tiled.py",
"content": "<3192 chars \u2014 see diff>"
}Wrote file successfully.--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tiled2.py
@@ -0,0 +1,120 @@
+"""Verify tiled kernel correctness - all args passed properly."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I, E = 128, 4096, 1536, 1
+BN, num_nt = 128, I // 128
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+wg = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+wu = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+wg_t = wg.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+wu_t = wu.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+
+gate_ref = x @ wg
+up_ref = x @ wu
+out_ref = torch.nn.functional.silu(gate_ref) * up_ref
+
+@triton.jit
+def tiled_kernel(
+ x_ptr, wg_ptr, wu_ptr, out_ptr,
+ n_tokens, H, I, BN,
+ x_s0, wg_s0, wg_s1, wg_s2,
+ wu_s0, wu_s1, wu_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ nt_cnt = tl.cdiv(I, BLOCK_N)
+ e = pid // nt_cnt
+ nt = pid % nt_cnt
+
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - nt * BLOCK_N
+
+ wg_base = wg_ptr + e * wg_s0 + nt * wg_s1
+ wu_base = wu_ptr + e * wu_s0 + nt * wu_s1
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ m_mask = m_base + m_off < n_tokens
+ acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ x_t = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=m_mask[:, None],
+ )
+ wg_t = tl.load(
+ wg_base + (k_base + k_off)[:, None] * wg_s2 + n_off[None, :],
+ mask=n_mask[None, :],
+ )
+ wu_t = tl.load(
+ wu_base + (k_base + k_off)[:, None] * wu_s2 + n_off[None, :],
+ mask=n_mask[None, :],
+ )
+ acc_g = tl.dot(x_t, wg_t, acc_g)
+ acc_u = tl.dot(x_t, wu_t, acc_u)
+
+ result = (tl.sigmoid(acc_g) * acc_g * acc_u).to(tl.bfloat16)
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+tiled_kernel[(E * num_nt,)](
+ x.cuda(), wg_t.cuda(), wu_t.cuda(), out,
+ n_tokens, H, I, BN,
+ x.stride(0),
+ wg_t.stride(0), wg_t.stride(1), wg_t.stride(2),
+ wu_t.stride(0), wu_t.stride(1), wu_t.stride(2),
+ out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=3,
+)
+
+diff = (out_ref.cuda().float() - out.float()).abs()
+print(f"Tiled kernel: max_diff={diff.max():.6f} mean={diff.mean():.6f}")
+sf = ((out_ref.cuda() * out) < 0).sum().item()
+print(f"Sign flips: {sf}/{n_tokens*I}")
+
+# Also test just the GEMM (no silu)
+@triton.jit
+def gemm_only(x_ptr, w_ptr, out_ptr, n_tokens, H, I, BN,
+ x_s0, w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ w_base = w_ptr + nt * w_s1
+ for m_base in range(0, n_tokens, BLOCK_M):
+ m_mask = m_base + m_off < n_tokens
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ x_t = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=m_mask[:, None],
+ )
+ w_t = tl.load(
+ w_base + (k_base + k_off)[:, None] * w_s2 + n_off[None, :],
+ )
+ acc = tl.dot(x_t, w_t, acc)
+ tl.store(out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ acc.to(tl.bfloat16), mask=m_mask[:, None])
+
+out_g = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+gemm_only[(num_nt,)](x.cuda(), wg_t.cuda(), out_g, n_tokens, H, I, BN,
+ x.stride(0), wg_t.stride(0), wg_t.stride(1), wg_t.stride(2), out_g.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, num_warps=4, num_stages=3)
+dg = (gate_ref.cuda().float() - out_g.float()).abs()
+sfg = ((gate_ref.cuda() * out_g) < 0).sum().item()
+print(f"\nGEMM only (gate): max_diff={dg.max():.6f} sign_flips={sfg}/{n_tokens*I}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_tiled2.py",
"content": "<4398 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test tiled kernel",
"command": "uv run python debug_tiled2.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Tiled kernel: max_diff=0.094238 mean=0.005283
Sign flips: 8155/196608
GEMM only (gate): max_diff=0.779297 sign_flips=3954/196608
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_dataload.py
@@ -0,0 +1,123 @@
+"""Verify data loading from tiled weights is correct."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+H, I, BN = 4096, 1536, 128
+num_nt = I // BN
+
+# Create data
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_copy = w.clone()
+
+# Tiled version
+w_t = w.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+# Verify: w_t[0, t, :, :] should equal w[:, t*BN:(t+1)*BN]
+for t in range(min(3, num_nt)):
+ print(f"Tile {t}: match={torch.allclose(w_t[0,t,:,:], w[:, t*BN:(t+1)*BN], atol=1e-10)}")
+
+# Now test: does tl.load with tiled pointer give the right data?
+@triton.jit
+def load_test(w_ptr, out_ptr, H, BN,
+ w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ w_base = w_ptr + nt * w_s1
+
+ for m_base in range(0, H, BLOCK_M):
+ # Load from tiled weights: shape (H, BN), stride (BN, 1)
+ tile = tl.load(
+ w_base + (m_base + m_off)[:, None] * w_s2 + n_off[None, :],
+ mask=(m_base + m_off)[:, None] < H,
+ )
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ tile.to(tl.bfloat16),
+ mask=(m_base + m_off)[:, None] < H,
+ )
+
+# Load tile 0 using Triton
+out_t = torch.zeros(H, BN, dtype=torch.bfloat16, device='cuda')
+load_test[(1,)](
+ w_t.cuda(), out_t, H, BN,
+ w_t.stride(0), w_t.stride(1), w_t.stride(2), out_t.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=3,
+)
+
+ref_tile = w[:, 0:BN]
+diff = (ref_tile.cuda().float() - out_t.float()).abs()
+print(f"\nLoad test tile 0: max_diff={diff.max():.6f}")
+
+# Now test with a K-tile load (BLOCK_K=32, BLOCK_N=128)
+@triton.jit
+def ktile_load(w_ptr, out_ptr, H, BN,
+ w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+
+ w_base = w_ptr + nt * w_s1
+
+ for m_base in range(0, H, BLOCK_M):
+ tile = tl.load(
+ w_base + (m_base + m_off)[:, None] * w_s2 + n_off[None, :],
+ mask=(m_base + m_off)[:, None] < H,
+ )
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ tile.to(tl.bfloat16),
+ mask=(m_base + m_off)[:, None] < H,
+ )
+
+out_k = torch.zeros(H, BN, dtype=torch.bfloat16, device='cuda')
+ktile_load[(1,)](
+ w_t.cuda(), out_k, H, BN,
+ w_t.stride(0), w_t.stride(1), w_t.stride(2), out_k.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=3,
+)
+
+diff_k = (ref_tile.cuda().float() - out_k.float()).abs()
+print(f"K-tile load test: max_diff={diff_k.max():.6f}")
+
+# Test: tl.load + tl.dot from tiled weights
+a = torch.randn(64, 32, dtype=torch.bfloat16) * 0.1
+# For K-tile at k=0: B = w[0:32, 0:128]
+ref_gemm = a @ w[:32, :BN]
+
+@triton.jit
+def gemm_tiled(a_ptr, w_ptr, out_ptr, w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ m_off = tl.arange(0, BLOCK_M)
+ n_off = tl.arange(0, BLOCK_N)
+ k_off = tl.arange(0, BLOCK_K)
+
+ w_base = w_ptr + nt * w_s1
+
+ a_t = tl.load(a_ptr + m_off[:, None] * BLOCK_K + k_off[None, :])
+ w_t = tl.load(w_base + k_off[:, None] * w_s2 + n_off[None, :])
+ c = tl.dot(a_t, w_t)
+ tl.store(out_ptr + m_off[:, None] * out_s0 + n_off[None, :], c.to(tl.bfloat16))
+
+out_gemm = torch.zeros(64, BN, dtype=torch.bfloat16, device='cuda')
+gemm_tiled[(1,)](
+ a.cuda(), w_t.cuda(), out_gemm,
+ w_t.stride(0), w_t.stride(1), w_t.stride(2), out_gemm.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4,
+)
+
+diff_g = (ref_gemm.cuda().float() - out_gemm.float()).abs()
+sf_g = ((ref_gemm.cuda() * out_gemm) < 0).sum().item()
+print(f"\nGEMM from tiled: max_diff={diff_g.max():.6f} sign_flips={sf_g}/{64*BN}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_dataload.py",
"content": "<4110 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run data load test",
"command": "uv run python debug_dataload.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Tile 0: match=True
Tile 1: match=True
Tile 2: match=True
Load test tile 0: max_diff=0.000000
K-tile load test: max_diff=0.000000
GEMM from tiled: max_diff=0.000000 sign_flips=0/8192
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_multitile.py
@@ -0,0 +1,75 @@
+"""Test full multi-tile GEMM with tiled weights."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I, E = 128, 4096, 1536, 1
+BN, num_nt = 128, I // 128
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+wg = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+
+# Tiled
+wg_t = wg.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+
+# Reference
+gate_ref = x @ wg
+
+@triton.jit
+def gemm_multi(x_ptr, w_ptr, out_ptr, n_tokens, H, I, BN,
+ x_s0, w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ nt_cnt = tl.cdiv(I, BLOCK_N)
+ e = pid // nt_cnt
+ nt = pid % nt_cnt
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - nt * BLOCK_N
+
+ w_base = w_ptr + e * w_s0 + nt * w_s1
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_base + m_off < n_tokens
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ x_t = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=m_mask[:, None],
+ )
+ w_t = tl.load(
+ w_base + (k_base + k_off)[:, None] * w_s2 + n_off[None, :],
+ mask=n_mask[None, :],
+ )
+ acc = tl.dot(x_t, w_t, acc)
+
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+gemm_multi[(E * num_nt,)](
+ x.cuda(), wg_t.cuda(), out, n_tokens, H, I, BN,
+ x.stride(0),
+ wg_t.stride(0), wg_t.stride(1), wg_t.stride(2), out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ num_warps=4, num_stages=3,
+)
+
+diff = (gate_ref.cuda().float() - out.float()).abs()
+print(f"Multi-N-tile GEMM: max_diff={diff.max():.6f} mean={diff.mean():.6f}")
+sf = ((gate_ref.cuda() * out) < 0).sum().item()
+print(f"Sign flips: {sf}/{n_tokens*I}")
+
+# Per-N-tile breakdown
+for nt in range(num_nt):
+ ns = nt * BN
+ ne = min(ns + BN, I)
+ tile_diff = diff[:, ns:ne]
+ print(f" NTile {nt} ({ns}:{ne}): max_diff={tile_diff.max():.6f} mean={tile_diff.mean():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_multitile.py",
"content": "<2433 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test multi-tile with tiled weights",
"command": "uv run python debug_multitile.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Multi-N-tile GEMM: max_diff=0.613281 mean=0.099396
Sign flips: 5028/196608
NTile 0 (0:128): max_diff=0.613281 mean=0.087600
NTile 1 (128:256): max_diff=0.519531 mean=0.100135
NTile 2 (256:384): max_diff=0.550781 mean=0.101003
NTile 3 (384:512): max_diff=0.484375 mean=0.099868
NTile 4 (512:640): max_diff=0.500000 mean=0.100640
NTile 5 (640:768): max_diff=0.484375 mean=0.101139
NTile 6 (768:896): max_diff=0.585938 mean=0.100555
NTile 7 (896:1024): max_diff=0.507812 mean=0.100012
NTile 8 (1024:1152): max_diff=0.511719 mean=0.099888
NTile 9 (1152:1280): max_diff=0.507812 mean=0.100254
NTile 10 (1280:1408): max_diff=0.515625 mean=0.100666
NTile 11 (1408:1536): max_diff=0.515625 mean=0.100992
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_multitile2.py
@@ -0,0 +1,70 @@
+"""Debug multi-tile GEMM - pass NUM_NT as constexpr."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I, E, BN = 128, 4096, 1536, 1, 128
+num_nt = I // BN
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_t = w.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+
+ref = x @ w
+
+@triton.jit
+def gm(x_ptr, w_ptr, out_ptr, n_tokens, H, I,
+ x_s0, w_s0, w_s1, w_s2, out_s0,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ NUM_NT: tl.constexpr):
+ pid = tl.program_id(0)
+ e = pid // NUM_NT
+ nt = pid % NUM_NT
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - nt * BLOCK_N
+
+ w_base = w_ptr + e * w_s0 + nt * w_s1
+
+ for m_base in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_base + m_off < n_tokens
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ x_t = tl.load(
+ x_ptr + (m_base + m_off)[:, None] * x_s0 + (k_base + k_off)[None, :],
+ mask=m_mask[:, None],
+ )
+ w_t = tl.load(
+ w_base + (k_base + k_off)[:, None] * w_s2 + n_off[None, :],
+ )
+ acc = tl.dot(x_t, w_t, acc)
+
+ tl.store(
+ out_ptr + (m_base + m_off)[:, None] * out_s0 + n_off[None, :],
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
+
+out = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+gm[(E * num_nt,)](
+ x.cuda(), w_t.cuda(), out, n_tokens, H, I,
+ x.stride(0), w_t.stride(0), w_t.stride(1), w_t.stride(2), out.stride(0),
+ BLOCK_M=64, BLOCK_N=128, BLOCK_K=32,
+ NUM_NT=num_nt,
+ num_warps=4, num_stages=3,
+)
+
+diff = (ref.cuda().float() - out.float()).abs()
+print(f"Multi-tile (NUM_NT constexpr): max_diff={diff.max():.6f} mean={diff.mean():.6f}")
+sf = ((ref.cuda() * out) < 0).sum().item()
+print(f"Sign flips: {sf}/{n_tokens*I}")
+
+for nt in range(num_nt):
+ ns = nt * BN
+ ne = min(ns + BN, I)
+ td = diff[:, ns:ne]
+ print(f" Tile {nt} ({ns}:{ne}): max={td.max():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_multitile2.py",
"content": "<2266 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test multi-tile with constexpr NUM_NT",
"command": "uv run python debug_multitile2.py 2>&1",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Multi-tile (NUM_NT constexpr): max_diff=0.684570 mean=0.097930
Sign flips: 4067/196608
Tile 0 (0:128): max=0.684570
Tile 1 (128:256): max=0.519531
Tile 2 (256:384): max=0.550781
Tile 3 (384:512): max=0.484375
Tile 4 (512:640): max=0.500000
Tile 5 (640:768): max=0.484375
Tile 6 (768:896): max=0.585938
Tile 7 (896:1024): max=0.507812
Tile 8 (1024:1152): max=0.511719
Tile 9 (1152:1280): max=0.507812
Tile 10 (1280:1408): max=0.515625
Tile 11 (1408:1536): max=0.515625
{
"description": "Fresh process test multi-tile",
"command": "<1509 chars \u2014 see below>",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}uv run python -c "
import torch, triton, triton.language as tl
torch.manual_seed(42)
N, H, I, BN = 128, 4096, 1536, 128
num_nt = I // BN
x = torch.randn(N, H, dtype=torch.bfloat16) * 0.1
w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
w_t = w.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
ref = x @ w
@triton.jit
def gm(xp, wp, op, N, H, I, xs, ws0, ws1, ws2, os0,
BM: tl.constexpr, BN2: tl.constexpr, BK: tl.constexpr, NNT: tl.constexpr):
pid = tl.program_id(0)
e = pid // NNT
nt = pid % NNT
no = tl.arange(0, BN2)
nm = no < I - nt * BN2
wb = wp + e * ws0 + nt * ws1
for mb in range(0, N, BM):
mo = tl.arange(0, BM)
mm = mb + mo < N
ag = tl.zeros((BM, BN2), dtype=tl.float32)
for kb in range(0, H, BK):
ko = tl.arange(0, BK)
ag = tl.dot(tl.load(xp + (mb+mo)[:,None]*xs + (kb+ko)[None,:], mask=mm[:,None]),
tl.load(wb + (kb+ko)[:,None]*ws2 + no[None,:]), ag)
tl.store(op + (mb+mo)[:,None]*os0 + no[None,:], ag.to(tl.bfloat16), mask=mm[:,None] & nm[None,:])
out = torch.zeros(N, I, dtype=torch.bfloat16, device='cuda')
gm[(1*num_nt,)](
x.cuda(), w_t.cuda(), out, N, H, I,
x.stride(0), w_t.stride(0), w_t.stride(1), w_t.stride(2), out.stride(0),
BM=64, BN2=128, BK=32, NNT=num_nt, num_warps=4, num_stages=3)
diff = (ref.cuda().float() - out.float()).abs()
sf = ((ref.cuda()*out) < 0).sum().item()
print(f'max_diff={diff.max():.6f} sign_flips={sf}/{N*I}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 12, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __
... (150 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_compare.py
@@ -0,0 +1,79 @@
+"""Compare single-launch multi-tile vs individual single-tile launches."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I, E, BN = 128, 4096, 1536, 1, 128
+num_nt = I // BN
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_t = w.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+ref = x @ w
+
+@triton.jit
+def gm_single(xp, wp, op, N, H, xs, ws2, os0,
+ BM: tl.constexpr, BN2: tl.constexpr, BK: tl.constexpr):
+ mo = tl.arange(0, BM)
+ no = tl.arange(0, BN2)
+ for mb in range(0, N, BM):
+ mm = mb + mo < N
+ ag = tl.zeros((BM, BN2), dtype=tl.float32)
+ for kb in range(0, H, BK):
+ ko = tl.arange(0, BK)
+ ag = tl.dot(tl.load(xp + (mb+mo)[:,None]*xs + (kb+ko)[None,:], mask=mm[:,None]),
+ tl.load(wp + (kb+ko)[:,None]*ws2 + no[None,:]), ag)
+ tl.store(op + (mb+mo)[:,None]*os0 + no[None,:], ag.to(tl.bfloat16), mask=mm[:,None])
+
+# Approach 1: single launch for all tiles
+@triton.jit
+def gm_multi(xp, wp, op, N, H, I, xs, ws0, ws1, ws2, os0,
+ BM: tl.constexpr, BN2: tl.constexpr, BK: tl.constexpr, NNT: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ no = tl.arange(0, BN2)
+ nm = no < I - nt * BN2
+ wb = wp + nt * ws1
+ for mb in range(0, N, BM):
+ mo = tl.arange(0, BM)
+ mm = mb + mo < N
+ ag = tl.zeros((BM, BN2), dtype=tl.float32)
+ for kb in range(0, H, BK):
+ ko = tl.arange(0, BK)
+ ag = tl.dot(tl.load(xp + (mb+mo)[:,None]*xs + (kb+ko)[None,:], mask=mm[:,None]),
+ tl.load(wb + (kb+ko)[:,None]*ws2 + no[None,:]), ag)
+ tl.store(op + (mb+mo)[:,None]*os0 + no[None,:], ag.to(tl.bfloat16), mask=mm[:,None] & nm[None,:])
+
+# Multi-launch: launch each tile separately
+out_separate = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+for nt in range(num_nt):
+ w_slice = w_t[:, nt, :, :].contiguous()
+ temp = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+ gm_single[(1,)](
+ x.cuda(), w_slice.cuda(), temp, n_tokens, H,
+ x.stride(0), w_slice.stride(0), temp.stride(0),
+ BM=64, BN2=128, BK=32, num_warps=4, num_stages=3,
+ )
+ out_separate[:, nt*BN:(nt+1)*BN] = temp[:, nt*BN:(nt+1)*BN]
+
+diff_sep = (ref.cuda().float() - out_separate.float()).abs()
+sf_sep = ((ref.cuda()*out_separate)<0).sum().item()
+print(f"Separate launches: max_diff={diff_sep.max():.6f} sign_flips={sf_sep}/{n_tokens*I}")
+
+# Single launch
+out_single = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+gm_multi[(num_nt,)](
+ x.cuda(), w_t.cuda(), out_single, n_tokens, H, I,
+ x.stride(0), w_t.stride(0), w_t.stride(1), w_t.stride(2), out_single.stride(0),
+ BM=64, BN2=128, BK=32, NNT=num_nt, num_warps=4, num_stages=3,
+)
+
+diff_single = (ref.cuda().float() - out_single.float()).abs()
+sf_single = ((ref.cuda()*out_single)<0).sum().item()
+print(f"Single launch: max_diff={diff_single.max():.6f} sign_flips={sf_single}/{n_tokens*I}")
+
+# Compare separate vs single
+diff_ss = (out_separate.cuda().float() - out_single.float()).abs()
+print(f"Separate vs Single: max_diff={diff_ss.max():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_compare.py",
"content": "<3235 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Compare separate vs single launch",
"command": "uv run python debug_compare.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_compare.py", line 59, in <module>
out_separate[:, nt*BN:(nt+1)*BN] = temp[:, nt*BN:(nt+1)*BN]
~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: operation not supported on global/shared address space
Search for `cudaErrorInvalidAddressSpace' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_compare2.py
@@ -0,0 +1,80 @@
+"""Compare single-launch multi-tile vs reference, check for consistent errors."""
+import torch
+import triton
+import triton.language as tl
+
+torch.manual_seed(42)
+
+n_tokens, H, I, E, BN = 128, 4096, 1536, 1, 128
+num_nt = I // BN
+
+x = torch.randn(n_tokens, H, dtype=torch.bfloat16) * 0.1
+w = torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+w_t = w.view(1, H, num_nt, BN).permute(0, 2, 1, 3).contiguous()
+ref = x @ w
+
+# Approach: directly tile the output using a single-tile kernel per tile
+# Instead of passing temp tensor with I cols, pass one with BN cols per tile
+out_correct = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+for nt in range(num_nt):
+ w_slice = w_t[:, nt, :, :].contiguous().squeeze(0)
+
+ @triton.jit
+ def gm(xp, wp, op, N, H, xs, ws2, os0,
+ BM: tl.constexpr, BN2: tl.constexpr, BK: tl.constexpr):
+ mo = tl.arange(0, BM)
+ no = tl.arange(0, BN2)
+ for mb in range(0, N, BM):
+ mm = mb + mo < N
+ ag = tl.zeros((BM, BN2), dtype=tl.float32)
+ for kb in range(0, H, BK):
+ ko = tl.arange(0, BK)
+ ag = tl.dot(tl.load(xp + (mb+mo)[:,None]*xs + (kb+ko)[None,:], mask=mm[:,None]),
+ tl.load(wp + (kb+ko)[:,None]*ws2 + no[None,:]), ag)
+ tl.store(op + (mb+mo)[:,None]*os0 + no[None,:], ag.to(tl.bfloat16), mask=mm[:,None])
+
+ out_nt = torch.zeros(n_tokens, BN, dtype=torch.bfloat16, device='cuda')
+ gm[(1,)](
+ x.cuda(), w_slice.unsqueeze(0).cuda(), out_nt, n_tokens, H,
+ x.stride(0), w_slice.stride(0), out_nt.stride(0),
+ BM=64, BN2=128, BK=32, num_warps=4, num_stages=3,
+ )
+ out_correct[:, nt*BN:(nt+1)*BN] = out_nt
+
+diff_correct = (ref.cuda().float() - out_correct.float()).abs()
+sf_correct = ((ref.cuda()*out_correct)<0).sum().item()
+print(f"Per-tile contiguous: max_diff={diff_correct.max():.6f} sign_flips={sf_correct}/{n_tokens*I}")
+
+# Now try with the multi-tile kernel
+@triton.jit
+def gm_multi(xp, wp, op, N, H, I, xs, ws0, ws1, ws2, os0,
+ BM: tl.constexpr, BN2: tl.constexpr, BK: tl.constexpr, NNT: tl.constexpr):
+ pid = tl.program_id(0)
+ nt = pid
+ no = tl.arange(0, BN2)
+ nm = no < I - nt * BN2
+ wb = wp + nt * ws1
+ for mb in range(0, N, BM):
+ mo = tl.arange(0, BM)
+ mm = mb + mo < N
+ ag = tl.zeros((BM, BN2), dtype=tl.float32)
+ for kb in range(0, H, BK):
+ ko = tl.arange(0, BK)
+ ag = tl.dot(tl.load(xp + (mb+mo)[:,None]*xs + (kb+ko)[None,:], mask=mm[:,None]),
+ tl.load(wb + (kb+ko)[:,None]*ws2 + no[None,:]), ag)
+ tl.store(op + (mb+mo)[:,None]*os0 + no[None,:], ag.to(tl.bfloat16), mask=mm[:,None] & nm[None,:])
+
+out_multi = torch.zeros(n_tokens, I, dtype=torch.bfloat16, device='cuda')
+gm_multi[(num_nt,)](
+ x.cuda(), w_t.cuda(), out_multi, n_tokens, H, I,
+ x.stride(0), w_t.stride(0), w_t.stride(1), w_t.stride(2), out_multi.stride(0),
+ BM=64, BN2=128, BK=32, NNT=num_nt, num_warps=4, num_stages=3,
+)
+
+diff_multi = (ref.cuda().float() - out_multi.float()).abs()
+sf_multi = ((ref.cuda()*out_multi)<0).sum().item()
+print(f"Multi-tile: max_diff={diff_multi.max():.6f} sign_flips={sf_multi}/{n_tokens*I}")
+
+# Compare correct vs multi
+diff_cm = (out_correct.cuda().float() - out_multi.float()).abs()
+print(f"Correct vs Multi: max_diff={diff_cm.max():.6f}")
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/debug_compare2.py",
"content": "<3376 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Compare per-tile vs multi-tile",
"command": "uv run python debug_compare2.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}Per-tile contiguous: max_diff=0.001953 sign_flips=0/196608
Multi-tile: max_diff=0.585938 sign_flips=0/196608
Correct vs Multi: max_diff=0.585938
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: from __future__ import annotations
2:
3: import torch
4: import torch.nn as nn
5: import triton
6: import triton.language as tl
7:
8: OP_TYPE = "grouped_gemm_swiglu"
9: SUPPORTED_PRECISIONS = ["bf16"]
10: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
11:
12:
13: @triton.jit
14: def grouped_gemm_swiglu_kernel(
15: hidden_states_ptr,
16: W_gate_tiled_ptr,
17: W_up_tiled_ptr,
18: expert_offsets_ptr,
19: output_ptr,
20: H: tl.constexpr,
21: I: tl.constexpr,
22: BLOCK_N: tl.constexpr,
23: stride_hidden_h: tl.constexpr,
24: stride_gate_tiled_e: tl.constexpr,
25: stride_gate_tiled_t: tl.constexpr,
26: stride_gate_tiled_h: tl.constexpr,
27: stride_gate_tiled_n: tl.constexpr,
28: stride_up_tiled_e: tl.constexpr,
29: stride_up_tiled_t: tl.constexpr,
30: stride_up_tiled_h: tl.constexpr,
31: stride_up_tiled_n: tl.constexpr,
32: stride_out_i: tl.constexpr,
33: BLOCK_M: tl.constexpr,
34: BLOCK_K: tl.constexpr,
35: ):
36: pid = tl.program_id(0)
37: num_n_tiles = tl.cdiv(I, BLOCK_N)
38: e = pid // num_n_tiles
39: n_tile_idx = pid % num_n_tiles
40:
41: start = tl.load(expert_offsets_ptr + e)
42: end = tl.load(expert_offsets_ptr + e + 1)
43: n_tokens = end - start
44: if n_tokens <= 0:
45: return
46:
47: n_start = n_tile_idx * BLOCK_N
48:
... (5844 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -11,7 +11,7 @@
@triton.jit
-def grouped_gemm_swiglu_kernel(
+def expert_kernel(
hidden_states_ptr,
W_gate_tiled_ptr,
W_up_tiled_ptr,
@@ -19,24 +19,23 @@
output_ptr,
H: tl.constexpr,
I: tl.constexpr,
- BLOCK_N: tl.constexpr,
stride_hidden_h: tl.constexpr,
- stride_gate_tiled_e: tl.constexpr,
- stride_gate_tiled_t: tl.constexpr,
- stride_gate_tiled_h: tl.constexpr,
- stride_gate_tiled_n: tl.constexpr,
- stride_up_tiled_e: tl.constexpr,
- stride_up_tiled_t: tl.constexpr,
- stride_up_tiled_h: tl.constexpr,
- stride_up_tiled_n: tl.constexpr,
+ tiled_stride_e: tl.constexpr,
+ tiled_stride_t: tl.constexpr,
+ tiled_stride_h: tl.constexpr,
+ tiled_stride_n: tl.constexpr,
+ tiled_up_stride_e: tl.constexpr,
+ tiled_up_stride_t: tl.constexpr,
+ tiled_up_stride_h: tl.constexpr,
+ tiled_up_stride_n: tl.constexpr,
stride_out_i: tl.constexpr,
BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
+ NUM_NT: tl.constexpr,
):
pid = tl.program_id(0)
- num_n_tiles = tl.cdiv(I, BLOCK_N)
- e = pid // num_n_tiles
- n_tile_idx = pid % num_n_tiles
+ e = pid
start = tl.load(expert_offsets_ptr + e)
end = tl.load(expert_offsets_ptr + e + 1)
@@ -44,55 +43,59 @@
if n_tokens <= 0:
return
- n_start = n_tile_idx * BLOCK_N
- n_off = tl.arange(0, BLOCK_N)
- n_mask = n_off < I - n_start
-
- w_gate_tile_ptr = W_gate_tiled_ptr + e * stride_gate_tiled_e + n_tile_idx * stride_gate_tiled_t
- w_up_tile_ptr = W_up_tiled_ptr + e * stride_up_tiled_e + n_tile_idx * stride_up_tiled_t
-
- for m_start in range(0, n_tokens, BLOCK_M):
- m_off = tl.arange(0, BLOCK_M)
- m_mask = m_start + m_off < n_tokens
-
- x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_start in range(0, H, BLOCK_K):
- k_off = tl.arange(0, BLOCK_K)
- k_mask = k_start + k_off < H
-
- x = tl.load(
- x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
- mask=m_mask[:, None] & k_mask[None, :],
+ w_gate_base = W_gate_tiled_ptr + e * tiled_stride_e
+ w_up_base = W_up_tiled_ptr + e * tiled_up_stride_e
+
+ for nt in range(NUM_NT):
+ n_start = nt * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ w_gate_nt_ptr = w_gate_base + nt * tiled_stride_t
+ w_up_nt_ptr = w_up_base + nt * tiled_up_stride_t
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate = tl.load(
+ w_gate_nt_ptr
+ + (k_start + k_off)[:, None] * tiled_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_nt_ptr
+ + (k_start + k_off)[:, None] * tiled_up_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_ptr = output_ptr + (start + m_start) * stride_out_i
+ tl.store(
+ out_ptr + m_off[:, None] * stride_out_i + (n_start + n_off)[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
)
- w_gate = tl.load(
- w_gate_tile_ptr
- + (k_start + k_off)[:, None] * stride_gate_tiled_h
- + n_off[None, :],
- mask=k_mask[:, None] & n_mask[None, :],
- )
- w_up = tl.load(
- w_up_tile_ptr
- + (k_start + k_off)[:, None] * stride_up_tiled_h
- + n_off[None, :],
- mask=k_mask[:, None] & n_mask[None, :],
- )
-
- acc_gate = tl.dot(x, w_gate, acc_gate)
- acc_up = tl.dot(x, w_up, acc_up)
-
- gate_silu = tl.sigmoid(acc_gate) * acc_gate
- result = (gate_silu * acc_up).to(tl.bfloat16)
-
- out_ptr = output_ptr + (start + m_start) * stride_out_i
- tl.store(
- out_ptr + m_off[:, None] * stride_out_i + n_off[None, :],
- result,
- mask=m_mask[:, None] & n_mask[None, :],
- )
class Model(nn.Module):
@@ -110,13 +113,11 @@
self._gate_tiled: torch.Tensor | None = None
self._up_tiled: torch.Tensor | None = None
- def _make_tiled(self) -> tuple[torch.Tensor, torch.Tensor]:
- """Reorganize (E, H, I) weights to (E, num_nt, H, BLOCK_N) for TMA-compatible access."""
+ def _make_tiled(self, bn: int) -> tuple[torch.Tensor, torch.Tensor]:
if self._gate_tiled is not None:
return self._gate_tiled, self._up_tiled
... (1090 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<6232 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run check.py with per-expert approach",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}PASS
{
"todos": [
{
"priority": "high",
"content": "Read reference, problem files, and understand the op specification",
"status": "completed"
},
{
"priority": "high",
"content": "Write initial Triton-based solution.py with grouped GEMM + fused SwiGLU",
"status": "completed"
},
{
"priority": "high",
"content": "Run check.py to verify correctness across all shapes and seeds",
"status": "completed"
},
{
"priority": "high",
"content": "Profile with benchmark.py and iterate on performance",
"status": "
... (301 more chars)content: 'Read reference, problem files, and understand the op specification'
status: 'completed'
priority: 'high'
content: 'Write initial Triton-based solution.py with grouped GEMM + fused SwiGLU'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes and seeds'
status: 'completed'
priority: 'high'
content: 'Profile with benchmark.py and iterate on performance'
status: 'in_progress'
priority: 'high'
content: 'Optimize tiling, autotuning, and memory access patterns'
status: 'pending'
priority: 'medium'
content: 'Final verification: check.py PASS + benchmark.py'
status: 'pending'
priority: 'high'{
"description": "Run benchmark.py",
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 12:10:38.694000 1159233 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -10,6 +10,20 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+# Autotune configs
+_configs = [
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+]
+
+
+@triton.autotune(_configs, key=[])
@triton.jit
def expert_kernel(
hidden_states_ptr,
@@ -32,10 +46,16 @@
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
- NUM_NT: tl.constexpr,
):
pid = tl.program_id(0)
- e = pid
+
+ # Precompute N-tile count from I and BLOCK_N
+ # Triton can't use tl.cdiv in constexpr context, use Python
+ # Actually tl.cdiv is available as a regular function
+ num_nt = tl.cdiv(I, BLOCK_N)
+
+ e = pid // num_nt
+ nt = pid % num_nt
start = tl.load(expert_offsets_ptr + e)
end = tl.load(expert_offsets_ptr + e + 1)
@@ -43,59 +63,55 @@
if n_tokens <= 0:
return
- w_gate_base = W_gate_tiled_ptr + e * tiled_stride_e
- w_up_base = W_up_tiled_ptr + e * tiled_up_stride_e
-
- for nt in range(NUM_NT):
- n_start = nt * BLOCK_N
- n_off = tl.arange(0, BLOCK_N)
- n_mask = n_off < I - n_start
-
- w_gate_nt_ptr = w_gate_base + nt * tiled_stride_t
- w_up_nt_ptr = w_up_base + nt * tiled_up_stride_t
-
- for m_start in range(0, n_tokens, BLOCK_M):
- m_off = tl.arange(0, BLOCK_M)
- m_mask = m_start + m_off < n_tokens
-
- x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_start in range(0, H, BLOCK_K):
- k_off = tl.arange(0, BLOCK_K)
- k_mask = k_start + k_off < H
-
- x = tl.load(
- x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
- mask=m_mask[:, None] & k_mask[None, :],
- )
- w_gate = tl.load(
- w_gate_nt_ptr
- + (k_start + k_off)[:, None] * tiled_stride_h
- + n_off[None, :],
- mask=k_mask[:, None] & n_mask[None, :],
- )
- w_up = tl.load(
- w_up_nt_ptr
- + (k_start + k_off)[:, None] * tiled_up_stride_h
- + n_off[None, :],
- mask=k_mask[:, None] & n_mask[None, :],
- )
-
- acc_gate = tl.dot(x, w_gate, acc_gate)
- acc_up = tl.dot(x, w_up, acc_up)
-
- gate_silu = tl.sigmoid(acc_gate) * acc_gate
- result = (gate_silu * acc_up).to(tl.bfloat16)
-
- out_ptr = output_ptr + (start + m_start) * stride_out_i
- tl.store(
- out_ptr + m_off[:, None] * stride_out_i + (n_start + n_off)[None, :],
- result,
- mask=m_mask[:, None] & n_mask[None, :],
+ n_start = nt * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ w_gate_ptr = W_gate_tiled_ptr + e * tiled_stride_e + nt * tiled_stride_t
+ w_up_ptr = W_up_tiled_ptr + e * tiled_up_stride_e + nt * tiled_up_stride_t
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
)
+ w_gate = tl.load(
+ w_gate_ptr
+ + (k_start + k_off)[:, None] * tiled_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_ptr
+ + (k_start + k_off)[:, None] * tiled_up_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_ptr = output_ptr + (start + m_start) * stride_out_i
+ tl.store(
+ out_ptr + m_off[:, None] * stride_out_i + (n_start + n_off)[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
+ )
class Model(nn.Module):
@@ -117,11 +133,11 @@
if self._gate_tiled is not None:
return self._gate_tiled, self._up_tiled
E, H, I = self.E, self.H, self.I
- num_nt = triton.cdiv(I, bn)
+ num_nt = (I + bn - 1) // bn
pad = num_nt * bn - I
if pad > 0:
- wg = torch
... (1203 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<6769 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Check correctness with autotune",
"command": "uv run python check.py 2>&1",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.162201 (atol=0.02, rtol=0.02)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,3 +1,4 @@
+"""Per-expert kernel with N-tile iteration - correct + optimized."""
from __future__ import annotations
import torch
@@ -10,20 +11,20 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-# Autotune configs
-_configs = [
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
-]
-
-
-@triton.autotune(_configs, key=[])
+@triton.autotune(
+ [
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ ],
+ key=[],
+ warmup=5,
+ rep=10,
+)
@triton.jit
def expert_kernel(
hidden_states_ptr,
@@ -37,25 +38,16 @@
tiled_stride_e: tl.constexpr,
tiled_stride_t: tl.constexpr,
tiled_stride_h: tl.constexpr,
- tiled_stride_n: tl.constexpr,
tiled_up_stride_e: tl.constexpr,
tiled_up_stride_t: tl.constexpr,
tiled_up_stride_h: tl.constexpr,
- tiled_up_stride_n: tl.constexpr,
stride_out_i: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
-
- # Precompute N-tile count from I and BLOCK_N
- # Triton can't use tl.cdiv in constexpr context, use Python
- # Actually tl.cdiv is available as a regular function
- num_nt = tl.cdiv(I, BLOCK_N)
-
- e = pid // num_nt
- nt = pid % num_nt
+ e = pid
start = tl.load(expert_offsets_ptr + e)
end = tl.load(expert_offsets_ptr + e + 1)
@@ -63,55 +55,61 @@
if n_tokens <= 0:
return
- n_start = nt * BLOCK_N
- n_off = tl.arange(0, BLOCK_N)
- n_mask = n_off < I - n_start
-
- w_gate_ptr = W_gate_tiled_ptr + e * tiled_stride_e + nt * tiled_stride_t
- w_up_ptr = W_up_tiled_ptr + e * tiled_up_stride_e + nt * tiled_up_stride_t
-
- for m_start in range(0, n_tokens, BLOCK_M):
- m_off = tl.arange(0, BLOCK_M)
- m_mask = m_start + m_off < n_tokens
-
- x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_start in range(0, H, BLOCK_K):
- k_off = tl.arange(0, BLOCK_K)
- k_mask = k_start + k_off < H
-
- x = tl.load(
- x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
- mask=m_mask[:, None] & k_mask[None, :],
+ w_gate_base = W_gate_tiled_ptr + e * tiled_stride_e
+ w_up_base = W_up_tiled_ptr + e * tiled_up_stride_e
+
+ num_nt = tl.cdiv(I, BLOCK_N)
+
+ for nt in range(num_nt):
+ n_start = nt * BLOCK_N
+ n_off = tl.arange(0, BLOCK_N)
+ n_mask = n_off < I - n_start
+
+ w_gate_ptr = w_gate_base + nt * tiled_stride_t
+ w_up_ptr = w_up_base + nt * tiled_up_stride_t
+
+ for m_start in range(0, n_tokens, BLOCK_M):
+ m_off = tl.arange(0, BLOCK_M)
+ m_mask = m_start + m_off < n_tokens
+
+ x_ptr = hidden_states_ptr + (start + m_start) * stride_hidden_h
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, H, BLOCK_K):
+ k_off = tl.arange(0, BLOCK_K)
+ k_mask = k_start + k_off < H
+
+ x = tl.load(
+ x_ptr + m_off[:, None] * stride_hidden_h + (k_start + k_off)[None, :],
+ mask=m_mask[:, None] & k_mask[None, :],
+ )
+ w_gate = tl.load(
+ w_gate_ptr
+ + (k_start + k_off)[:, None] * tiled_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+ w_up = tl.load(
+ w_up_ptr
+ + (k_start + k_off)[:, None] * tiled_up_stride_h
+ + n_off[None, :],
+ mask=k_mask[:, None] & n_mask[None, :],
+ )
+
+ acc_gate = tl.dot(x, w_gate, acc_gate)
+ acc_up = tl.dot(x, w_up, acc_up)
+
+ gate_silu = tl.sigmoid(acc_gate) * acc_gate
+ result = (gate_silu * acc_up).to(tl.bfloat16)
+
+ out_ptr = output_ptr + (start + m_start) * stride_out_i
+ tl.store(
+ out_ptr + m_off[:, None] * stride_out_i + (n_start + n_off)[None, :],
+ result,
+ mask=m_mask[:, None] & n_mask[None, :],
)
- w_gate = tl
... (1604 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<6623 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run check with optimized kernel",
"command": "uv run python check.py 2>&1",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.201923 (atol=0.02, rtol=0.02)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,4 +1,3 @@
-"""Per-expert kernel with N-tile iteration - correct + optimized."""
from __future__ import annotations
import torch
@@ -11,20 +10,6 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
-@triton.autotune(
- [
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
- ],
- key=[],
- warmup=5,
- rep=10,
-)
@triton.jit
def expert_kernel(
hidden_states_ptr,
@@ -38,9 +23,11 @@
tiled_stride_e: tl.constexpr,
tiled_stride_t: tl.constexpr,
tiled_stride_h: tl.constexpr,
+ tiled_stride_n: tl.constexpr,
tiled_up_stride_e: tl.constexpr,
tiled_up_stride_t: tl.constexpr,
tiled_up_stride_h: tl.constexpr,
+ tiled_up_stride_n: tl.constexpr,
stride_out_i: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -154,6 +141,11 @@
out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
bn = 128
+ bm = 64
+ bk = 32
+ nw = 4
+ ns = 3
+
gt, ut = self._make_tiled(bn)
expert_kernel[(self.E,)](
@@ -168,10 +160,17 @@
gt.stride(0),
gt.stride(1),
gt.stride(2),
+ gt.stride(3),
ut.stride(0),
ut.stride(1),
ut.stride(2),
+ ut.stride(3),
out.stride(0),
+ BLOCK_M=bm,
+ BLOCK_N=bn,
+ BLOCK_K=bk,
+ num_warps=nw,
+ num_stages=ns,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"content": "<6106 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Verify correctness",
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}PASS
{
"description": "Benchmark current solution",
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 12:18:29.483000 1161514 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0
... (808 more chars)