./kernelbenchhardv3runsopencode · ?
opencode / ? session: ses_2266f37d3ffeFQ2DF0wluaZgTO cwd: ?
INCOMPLETE SESSION. hit wall-clock budget (SIGTERM). The transcript below is usable but may be missing the agent's final tool calls or summary. Don't score this run as a clean failure or success.
harness
opencode
model
?
turns
111
tools called
57
events
168
input toks
1,928,295
output toks
27,957
cache hit
0
solution.py
final answer
benchmark.log
check.log
result.json
"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.

2D grid: (sum_e cdiv(n_e, BM), cdiv(I, BN)) with per-block expert lookup
via binary search on cumulative tile counts.

Block sizes tuned for SM120 shared memory budget (~98KB usable).
"""
from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def _grouped_gemm_swiglu_kernel(
    X_ptr,             # (T_perm, H), row-major
    W_gate_ptr,        # (E, H, I), row-major
    W_up_ptr,          # (E, H, I), row-major
    Out_ptr,           # (T_perm, I), row-major
    Offsets_ptr,       # (E+1,) int32
    CumMTiles_ptr,     # (E+1,) int32
    H,
    I_val,
    BM: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_n = tl.program_id(1)

    # 1D -> (pid_m, ...) mapping
    pid_m = pid

    # Find expert: largest e such that CumMTiles[e] <= pid_m
    # E is at most 128; binary search
    lo = 0
    hi = 127
    while lo < hi:
        mid = (lo + hi + 1) >> 1
        if tl.load(CumMTiles_ptr + mid).to(tl.int32) <= pid_m:
            lo = mid
        else:
            hi = mid - 1
    expert = lo

    # Intra-expert M-tile index
    m_tile_idx = pid_m - tl.load(CumMTiles_ptr + expert).to(tl.int32)

    # Expert row range
    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
    n_e = row_end - row_start

    # Row offsets
    m_offs = m_tile_idx * BM + tl.arange(0, BM)
    m_valid = m_offs < n_e
    row_g = row_start + m_offs

    # Column offsets
    n_offs = pid_n * BN + tl.arange(0, BN)
    n_valid = n_offs < I_val

    # K offsets
    k_offs = tl.arange(0, BK)

    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
    acc_u = tl.zeros((BM, BN), dtype=tl.float32)

    w_base = expert * H * I_val

    for k_start in range(0, H, BK):
        k = k_start + k_offs
        k_mask = k < H

        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)

        wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)

        wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)

        acc_g += tl.dot(a, wg)
        acc_u += tl.dot(a, wu)

    out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])


def _launch(X, W_gate, W_up, out, expert_offsets):
    H = X.size(1)
    I_val = out.size(1)
    E = expert_offsets.size(0) - 1
    device = X.device

    # Shared memory: (BM*BK + 2*BK*BN)*2 bytes, num_stages=1
    # We want to maximize BM*BN*BK under ~98KB.
    # 256x128x64: 65536, flops/tile=2097152
    # 128x128x128: 98304, flops/tile=2097152
    # 128x256x64: 65536, flops/tile=2097152
    if H >= 4096 and I_val <= 4096:
        BM, BN, BK = 256, 128, 64
    elif I_val >= 4096:
        BM, BN, BK = 128, 128, 128
    else:
        BM, BN, BK = 128, 128, 128

    counts = expert_offsets[1:] - expert_offsets[:-1]
    m_tiles = (counts + BM - 1) // BM  # (E,)
    cum_m_tiles = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=device),
        torch.cumsum(m_tiles, dim=0).to(torch.int32)
    ])
    total_m_tiles = int(cum_m_tiles[-1])

    n_tiles = triton.cdiv(I_val, BN)
    grid = (total_m_tiles, n_tiles)

    _grouped_gemm_swiglu_kernel[grid](
        X, W_gate, W_up, out, expert_offsets, cum_m_tiles,
        H, I_val,
        BM=BM, BN=BN, BK=BK,
        NUM_SMS=128,
        num_stages=1,
        num_warps=8,
    )


class Model(nn.Module):
    def __init__(self, T_total, H, I, E, K):  # noqa: E741
        super().__init__()
        self.I = I
        self.H = H
        self.E = E
        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        nn.init.normal_(self.W_gate, std=0.02)
        nn.init.normal_(self.W_up, std=0.02)

    def forward(self, hidden_states, expert_offsets):
        T_perm = hidden_states.size(0)
        out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
        _launch(hidden_states, self.W_gate, self.W_up, out, expert_offsets)
        return out


T_total = 32768
H = 4096
I = 1536  # noqa: E741
E = 128
K = 8


def get_inputs():
    from reference import _build_routing
    T_perm = T_total * K
    hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
    expert_offsets = _build_routing(T_total, E, K)
    return [hidden_states, expert_offsets]


def get_init_inputs():
    return [T_total, H, I, E, K]
Let me reconsider the approach. The binary search in Triton adds latency. Let me think about the actual compute bound vs memory bound aspects... Actually for H100/SM120 with 1.8TB/s and H=4096, I=1536: - FLOPs: 2 * T_total * K * 2 * H * I = 2 * 32768 * 8 * 2 * 4096 * 1536 = ~1.32 * 10^12 FMA ops - Bytes read X: T_perm * H * 2 = 262144 * 4096 * 2 = ~2.1 GB - Bytes read W: E * 2 * H * I * 2 = 128 * 2 * 4096 * 1536 * 2 = ~3.2 GB - Bytes write: T_perm * I * 2 = 262144 * 1536 * 2 = ~0.8 GB - Total bytes: ~6.1 GB So the peak is roughly 220 TFLOPs and 1800 GB/s. At compute bound, minimum time is 1.32e12 / 200e12 / 2 = ~3.3ms. At memory bound, 6.1e9 / 1800e9 = ~3.4ms. It's actually right at the bandwidth/compute balance. The issue with the current code is that loading X from global memory is expensive for large H. We need to load X row tiles into shared memory and compute everything at once, OR use Triton's automatic tiling more effectively. Let me step back and think about what's efficient here. Each token-tile (say 256 tokens) touches: - 256*4096 bf16 = 2MB of X input (streaming read, could use L2 cache) - 128 * 4096 * 1536 bf16 = ~1.56GB of weights The weights are the dominant bandwidth. Since weights are 128 experts * 2 matrix types * H * I, that's 3.2GB. If streamed through L2 and GEMM is compute bound, we want to maximize compute density by making BM large. Let me rewrite with a cleaner approach and then benchmark.
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     start = int(expert_offsets[e].item())
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:55:32.418000 1047727 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
shape=0 variant=eager tflops=32.879 gbps=246.163 ms=25.081
shape=0 variant=compiled tflops=32.402 gbps=242.591 ms=25.450
shape=0 variant=solution tflops=0.488 gbps=3.653 ms=1690.150
shape=0 solution_peak_fraction=0.0024
shape=1 variant=eager tflops=13.959 gbps=259.002 ms=2.462
shape=1 variant=compiled tflops=14.003 gbps=259.826 ms=2.454
shape=1 variant=solution tflops=43.275 gbps=802.956 ms=0.794
shape=1 solution_peak_fraction=0.2164
shape=2 variant=eager tflops=34.766 gbps=237.657 ms=15.813
shape=2 variant=compiled tflops=34.578 gbps=236.376 ms=15.899
shape=2 variant=solution tflops=24.664 gbps=168.604 ms=22.289
shape=2 solution_peak_fraction=0.1233
peak_fraction: 0.0402
RESULT: LOW
PASS
{
    "problem": "06_sonic_moe_swiglu",
    "harness": "opencode",
    "model": "openrouter-pinned/qwen/qwen3.6-plus",
    "reasoning_effort": "",
    "has_solution": true,
    "correct": true,
    "peak_fraction": 0.0402,
    "elapsed_seconds": 2700,
    "harness_exit_code": 124,
    "session_complete": false,
    "usage": {"input_tokens": 1928295, "output_tokens": 27957, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 21104, "total_cost_usd": null}
}

timeline (168 events)

assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu"
}
tool
stdout (213 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py

(7 entries)
</entries>
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}
tool
stdout (4366 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2: 
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6: 
7:     h_e = silu(x_e @ W_gate[e])  *  (x_e @ W_up[e])
8: 
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11: 
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17: 
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21: 
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25: 
26: 
27: class Model(nn.Module):
28:     """Up-projection of a top-K MoE FFN with fused SwiGLU.
29: 
30:     Inputs at call time:
31:       hidden_states:    (T_perm, H)  bf16, already permuted to expert order
32:       expert_offsets:   (E+1,)       int32, prefix sums of token counts per expert
33:                

... (2866 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}
tool
stdout (4293 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2: 
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9: 
10: import torch
11: import yaml
12: 
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16: 
17: from src.eval.correctness import check_correctness  # noqa: E402
18: 
19: 
20: def main():
21:     try:
22:         import reference
23:         import shapes
24:         import solution
25:     except Exception as e:
26:         print(f"FAIL: import error: {e}")
27:         sys.exit(1)
28: 
29:     problem_yaml = Path("problem.yaml")
30:     meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31: 
32:     # --- Forbidden-op check ------------------------------------------------
33:     sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34:     for forbidden in meta.get("forbidden", []):
35:         pat = re.escape(forbidden)
36:         if re.search(pat, sol_src):
37:             print(f"FAIL: forbidden op used: {forbidden}")
38:             sys.exit(1)
39: 
40:     device = torch.device("cuda:0")
41:     tol_override = meta.ge

... (2793 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}
tool
stdout (5201 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2: 
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5: 
6: Output lines the harness picks up:
7:   shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8:   peak_fraction: <N>  (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13: 
14: import torch
15: import yaml
16: 
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19: 
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction  # noqa: E402
21: from src.eval.timing import time_fn  # noqa: E402
22: from src.hardware import get as get_hw  # noqa: E402
23: 
24: 
25: def _eval_formula(expr: str, vars: dict) -> float:
26:     return float(eval(expr, {"__builtins__": {}}, vars))
27: 
28: 
29: def main():
30:     import reference
31:     import shapes
32:     import solution
33: 
34:     meta = yaml.safe_load(Path("problem.yaml").read_text())
35:     hw = get_hw(meta["hardware"][0])
36:     peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37:     peak_gbps = hw.peak_bandwidth_gb_s
38:     regime = meta.get("regime", "compute")
39:     flops_formula = meta["

... (3701 more chars)
assistant
in 16664out 203cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}
tool
stdout (1068 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2: 
3: Defaults match the sonic-moe paper's headline configuration. We add:
4:   - a smaller shape for fast iteration during agent development
5:   - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7: 
8: SHAPES = [
9:     # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10:     {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11: 
12:     # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13:     # variable-length grouped layout meaningful, but smaller token / hidden dims.
14:     {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15: 
16:     # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17:     # Forces tiles to handle larger N relative to K.
18:     {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]

(End of file - total 19 lines)
</content>
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}
tool
stdout (2181 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8: 
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15: 
16: # Bytes moved (approximate, lower bound):
17: #   read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: #   expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20: 
21: tolerance:
22:   bfloat16: 0.02
23: 
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28:   - "torch.matmul"
29:   - "torch.bmm"
30:   - "torch.nn.functional.linear"
31:   - "F.linear"
32:   - "from son

... (681 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py"
}
tool
stdout (3255 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2: 
3: Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4: requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5: SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6: the package installs cleanly but kernels may fail at launch on SM120 (the
7: QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8: 
9: If the live call fails, `is_available()` returns False and the benchmark scores
10: the agent against PyTorch eager + the documented H100 paper ceiling (see
11: problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12: importing sonic_moe in solution.py (see problem.yaml.forbidden).
13: """
14: from __future__ import annotations
15: 
16: import torch
17: 
18: 
19: def _try_sonic_moe(
20:     hidden_states: torch.Tensor,
21:     W_gate: torch.Tensor,
22:     W_up: torch.Tensor,
23:     expert_offsets: torch.Tensor,
24: ) -> torch.Tensor | None:
25:     try:
26:         import sonic_moe  # type: ignore  # noqa: F401
27:     except Exception:
28:         return None
29:     try:
30:         # Public sonic-moe API surface is still stabilizing. The expected entry
31:         # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32:         # SwiGLU. Adapt to the a

... (1755 more chars)
assistant
glob /home/infatoshi/cuda/KernelBench-Hard
args
{
  "pattern": "**/solution.py",
  "path": "/home/infatoshi/cuda/KernelBench-Hard"
}
tool
stdout (14 chars)
No files found
assistant
in 21640out 208cache_r 0cache_w 0
assistant
task
args
{
  "prompt": "<545 chars \u2014 see below>",
  "description": "Explore existing solutions",
  "subagent_type": "general"
}
prompt (545 chars)
Search the KernelBench-Hard codebase for:
1. All existing solution.py files that use Triton - find 2-3 examples and read them
2. Check if Triton is installed and what version (look in pyproject.toml or constraints files)
3. Check if CUTTLE is available and what version
4. Look at any existing grouped GEMM implementations in the repo
5. Look at src/eval/correctness.py to understand the check_correctness function

Return: Triton version, CUTLASS availability, code patterns from 2-3 triton solutions, and the check_correctness tolerance logic.
tool
stdout (1100 chars)
task_id: ses_2266ee687ffe4ve73pZR7wQNzM (for resuming to continue this task if needed)

<task_result>
**Results:**

**1. No existing solution.py files** — none have been generated yet (agents haven't produced them). There are no import triton/`from triton` statements anywhere in the repo.

**2. Triton version:** 3.6.0 (installed and working with `triton.language` module)

**3. CUTLASS:** NOT available — not installed at `/usr/local/cuda/include/cutlass` or anywhere under the default CUDA path.

**4. Grouped GEMM reference** (`problems/06_sonic_moe_swiglu/reference.py`):
- `reference.py:61-70` — naive Python loop: `for e in range(E)`, slices `hidden_states[start:end]`, does `x_e @ W_gate[e]` and `x_e @ W_up[e]`, then `F.silu(gate) * up`.
- No existing Triton grouped GEMM implementations exist in the repo.

**5. check_correctness tolerances** (`src/eval/correctness.py`):
- fp32: atol=1e-4, rtol=1e-4
- fp16/bf16: atol=1e-2, rtol=1e-2
- fp8 (e4m3/e5m2): atol=0.1, rtol=0.1
- int8/32/64: exact match (atol=0, rtol=0)
- Casts both tensors to fp32 before comparison (line 80-81)
</task_result>
assistant
in 24220out 190cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/pyproject.toml
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml"
}
tool
stdout (914 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml</path>
<type>file</type>
<content>
1: [project]
2: name = "kernelbench-hard"
3: version = "0.1.0"
4: description = "Surgical GPU kernel benchmark. Small deck, frontier agents, roofline metric, link-don't-spoil problem briefs."
5: requires-python = ">=3.11"
6: dependencies = [
7:     "torch>=2.11",
8:     "numpy>=1.26",
9:     "pyyaml>=6.0",
10:     "rich>=13.7",
11:     "matplotlib>=3.9",
12:     "pytest>=8.0",
13:     "ruff>=0.6",
14:     "einops>=0.8.2",
15:     "requests<2.33",
16: ]
17: 
18: [build-system]
19: requires = ["hatchling"]
20: build-backend = "hatchling.build"
21: 
22: [tool.hatch.build.targets.wheel]
23: packages = ["src"]
24: 
25: [tool.ruff]
26: line-length = 100
27: target-version = "py311"
28: 
29: [tool.ruff.lint]
30: select = ["E", "F", "W", "I", "B", "UP"]
31: ignore = ["E501"]

(End of file - total 31 lines)
</content>
assistant
glob /home/infatoshi/cuda/KernelBench-Hard
args
{
  "pattern": "**/solution.py",
  "path": "/home/infatoshi/cuda/KernelBench-Hard"
}
tool
stdout (14 chars)
No files found
assistant
read /home/infatoshi/cuda/KernelBench-Hard
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard"
}
tool
stdout (314 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard</path>
<type>directory</type>
<entries>
.cursorrules
.git/
.gitignore
.pytest_cache/
.python-version
.ruff_cache/
.venv/
AGENTS.md
CLAUDE.md
DEVLOG.md
docs/
LICENSE
outputs/
problems/
pyproject.toml
README.md
scripts/
SPEC.md
src/
tests/
uv.lock

(21 entries)
</entries>
assistant
invalid
args
{
  "tool": "",
  "error": "Model tried to call unavailable tool ''. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill."
}
tool
stdout (185 chars)
The arguments provided to the tool are invalid: Model tried to call unavailable tool ''. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.
assistant
in 24813out 151cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}
tool
stdout (4080 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2: 
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7: 
8: import torch
9: 
10: DEFAULT_TOLERANCE = {
11:     torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12:     torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13:     torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14:     torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15:     torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16:     torch.int8: {"atol": 0, "rtol": 0},
17:     torch.int32: {"atol": 0, "rtol": 0},
18:     torch.int64: {"atol": 0, "rtol": 0},
19: }
20: 
21: 
22: def _coerce_float(x):
23:     """Accept int/float/str scientific-notation as float; pass through dicts."""
24:     if isinstance(x, (int, float)):
25:         return float(x)
26:     if isinstance(x, str):
27:         try:
28:             return float(x)
29:         except ValueError:
30:             return x
31:     return x
32: 
33: 
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35:     """Lookup atol/rtol for a given dtype, with optional per-problem override.
36: 
37:     The override dict is keyed by str(dtype). Values may be:
38:       - a single number (or a string parseable as float, e.g. "5e-2") -> used
39:         as both atol and rtol
4

... (2580 more chars)
assistant
bash
args
{
  "command": "nvidia-smi --query-gpu=name,count,driver_version --format=csv",
  "description": "Check GPU info"
}
tool
stdout (92 chars)
name, count, driver_version
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 1, 595.58.03
assistant
bash
args
{
  "command": "uv run python -c \"import triton; print(triton.__version__); print(triton.runtime.driver.active.get_current_target())\"",
  "description": "Check Triton version and target"
}
tool
stdout (56 chars)
3.6.0
GPUTarget(backend='cuda', arch=120, warp_size=32)
assistant
in 25605out 173cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python -c \"import triton; from triton.language._experimental import *; print('Triton works')\"",
  "description": "Verify Triton basic imports"
}
tool
stdout (143 chars)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'triton.language._experimental'
assistant
bash
args
{
  "command": "<191 chars \u2014 see below>",
  "description": "Test basic BF16 matmul on device"
}
command (191 chars)
uv run python -c "
import torch
x = torch.randn(4, 4, device='cuda', dtype=torch.bfloat16)
y = torch.randn(4, 4, device='cuda', dtype=torch.bfloat16)
print(x @ y)
print('BF16 matmul works')
"
tool
stdout (249 chars)
tensor([[-0.7734,  2.4375,  0.5195, -2.5156],
        [ 1.0625,  2.9688, -0.8477,  1.4453],
        [ 1.1797,  1.9219, -0.4219,  1.2891],
        [ 1.2500, -0.2461,  1.2656, -1.0000]], device='cuda:0',
       dtype=torch.bfloat16)
BF16 matmul works
assistant
in 27322out 187cache_r 0cache_w 0
assistant
todowrite
args
{
  "todos": [
    {
      "content": "Implement basic Triton grouped GEMM + SwiGLU kernel with Python loop over experts",
      "status": "in_progress",
      "priority": "high"
    },
    {
      "content": "Verify correctness with check.py",
      "status": "pending",
      "priority": "high"
    },
    {
      "content": "Profile and optimize kernel",
      "status": "pending",
      "priority": "high"
    },
    {
      "content": "Run benchmark to get timing numbers",
      "status": "pending",
      "priority": "medium"
    }
  ]
}
tool
stdout (368 chars)
content: 'Implement basic Triton grouped GEMM + SwiGLU kernel with Python loop over experts'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Run benchmark to get timing numbers'
status: 'pending'
priority: 'medium'
assistant
in 27731out 136cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (193 lines — TRUNCATED)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,190 @@
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
+
+Per-expert dense GEMM with SwiGLU fused in the epilogue. Loops over experts
+in Python (each expert is a dense tile); the Triton kernel computes both
+x @ W_gate and x @ W_up simultaneously and applies silu(gate) * up.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel: fused dual GEMM + SwiGLU
+# ---------------------------------------------------------------------------
+# C[m,n] = silu(A[m,k] @ W_gate[k,n]) * (A[m,k] @ W_up[k,n])
+#
+# A is (M, K), W_gate is (K, N), W_up is (K, N), out is (M, N).
+# ---------------------------------------------------------------------------
+@triton.jit
+def _grouped_gemm_swiglu_kernel(
+    A_ptr,          # base ptr to (n_e, H)
+    W_gate_ptr,     # base ptr to (H, I) for this expert
+    W_up_ptr,       # base ptr to (H, I) for this expert
+    Out_ptr,        # base ptr to (n_e, I) for this expert
+    M: tl.constexpr,
+    K: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    m_mask = m_offsets < M
+    n_mask = n_offsets < N
+
+    acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(0, K, BLOCK_K):
+        k_offsets = k + tl.arange(0, BLOCK_K)
+        k_mask = k_offsets < K
+
+        # Load A block: (BLOCK_M, BLOCK_K), col-major read row-major store
+        # A is row-major with stride K
+        a_ptrs = A_ptr + m_offsets[:, None] * K + k_offsets[None, :]
+        a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+        # Load W_gate block: (BLOCK_K, BLOCK_N), row-major stride N
+        wg_ptrs = W_gate_ptr + k_offsets[:, None] * N + n_offsets[None, :]
+        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+        # Load W_up block: (BLOCK_K, BLOCK_N), row-major stride N
+        wu_ptrs = W_up_ptr + k_offsets[:, None] * N + n_offsets[None, :]
+        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+        acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
+        acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+
+    # SwiGLU: silu(gate) * up
+    # silu(x) = x * sigmoid(x)
+    acc_gate = acc_gate.astype(tl.bfloat16)
+    acc_gate_fp32 = acc_gate.to(tl.float32)
+    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
+    silu = acc_gate.to(tl.float32) * sigmoid
+    out_fp32 = silu * acc_up
+
+    out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+    out = out_fp32.to(tl.bfloat16)
+    tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
+
+
+def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
+    """Launch fused dual GEMM + SwiGLU for one expert.
+
+    x:        (n_tokens, H) bf16
+    w_gate:   (H, I) bf16
+    w_up:     (H, I) bf16
+    n_tokens: int
+    out:      (n_tokens, I) bf16
+    """
+    H = x.shape[1]
+    I = w_gate.shape[1]
+
+    # Heuristic block sizes
+    # SM120 has many SMs, try larger tiles for larger matrices
+    if n_tokens >= 1024 and I >= 1024:
+        BLOCK_M = 256
+        BLOCK_N = 256
+        BLOCK_K = 128
+    elif n_tokens >= 512:
+        BLOCK_M = 128
+        BLOCK_N = 128
+        BLOCK_K = 128
+    elif n_tokens >= 128:
+        BLOCK_M = 128
+        BLOCK_N = 64
+        BLOCK_K = 64
+    else:
+        BLOCK_M = 64
+        BLOCK_N = 64
+        BLOCK_K = 64
+
+    grid = (
+        triton.cdiv(n_tokens, BLOCK_M),
+        triton.cdiv(I, BLOCK_N),
+    )
+
+    _grouped_gemm_swiglu_kernel[grid](
+        x, w_gate, w_up, out,
+        M=n_tokens,
+        K=H,
+        N=I,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        BLOCK_K=BLOCK_K,
+    )
+
+
+class Model(nn.Module):
+    """Up-projection of a top-K MoE FFN with fused SwiGLU.
+
+    Inputs at call time:
+      hidden_states:    (T_perm, H)  bf16, already permuted to expert order
+      expert_offsets:   (E+1,)       int32, prefix sums of token counts per expert
+
+    Output:
+      gated_up:         (T_perm, I)  bf16
+    """
+
+    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):  # noqa: E741
+        super().__init__()
+        self.T_total = T_total
+        self.H = H
+        self.I = I
+        self.E = E
+        self.K = K
+        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+        nn.init.normal_(self.W_gate, std=0.02)
+        nn.init.normal_(self.W_up, std=0.02)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,   # (T_perm, H) bf16
+        expert_offsets: torch.Tensor,  # (E+1,) int32
+    ) -> torch.Tensor:
+        T_perm, H = hidden_states.shape
+        out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=hidden_states.device)
+
+        for e in range(self.E):
+            start = int(expert_offsets[e].item())
+            end = int(expert_offsets[e + 1].item())
+            if end == start:
+                continue
+            n_e = end - start
+            x_e = hidden_states[start:end].contiguous()
+            w_gate_e = self.W_gate[e].contiguous()
+            w_up_e = self.W_up[e].contiguous()
+            out_e = out[start:end]
+            grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
+
+        return out
+
+
+# Module-level shape shims rewritten by check.py / benchmark.py per shape.
+T_total = 32768
+H 

... (368 more chars)
args
{
  "content": "<6061 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Now I'll implement the Triton kernel. The approach: Python loop over experts, each expert gets a Triton GEMM with fused SwiGLU. This beats the reference's Python GEMM loop and avoids forbidden ops.
assistant
in 28034out 2100cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (3262 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 168, in forward
    grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 116, in grouped_gemm_swiglu_launch
    _grouped_gemm_swiglu_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/trito

... (1762 more chars)
assistant
in 30152out 69cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (17 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -65,11 +65,9 @@
         acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
 
     # SwiGLU: silu(gate) * up
-    # silu(x) = x * sigmoid(x)
-    acc_gate = acc_gate.astype(tl.bfloat16)
-    acc_gate_fp32 = acc_gate.to(tl.float32)
-    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
-    silu = acc_gate.to(tl.float32) * sigmoid
+    # silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
+    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
+    silu = acc_gate * sigmoid
     out_fp32 = silu * acc_up
 
     out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
  "oldString": "<444 chars \u2014 see below>",
  "newString": "<379 chars \u2014 see below>"
}
oldString (444 chars)
    # SwiGLU: silu(gate) * up
    # silu(x) = x * sigmoid(x)
    acc_gate = acc_gate.astype(tl.bfloat16)
    acc_gate_fp32 = acc_gate.to(tl.float32)
    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate_fp32))
    silu = acc_gate.to(tl.float32) * sigmoid
    out_fp32 = silu * acc_up

    out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
    out = out_fp32.to(tl.bfloat16)
    tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
newString (379 chars)
    # SwiGLU: silu(gate) * up
    # silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
    silu = acc_gate * sigmoid
    out_fp32 = silu * acc_up

    out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
    out = out_fp32.to(tl.bfloat16)
    tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
tool
stdout (26 chars)
Edit applied successfully.
assistant
in 31265out 392cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check again"
}
tool
stdout (2383 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 166, in forward
    grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_e)
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 114, in grouped_gemm_swiglu_launch
    _grouped_gemm_swiglu_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/trito

... (883 more chars)
assistant
in 31674out 70cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (196 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -12,25 +12,27 @@
 import triton.language as tl
 
 
-# ---------------------------------------------------------------------------
-# Triton kernel: fused dual GEMM + SwiGLU
-# ---------------------------------------------------------------------------
-# C[m,n] = silu(A[m,k] @ W_gate[k,n]) * (A[m,k] @ W_up[k,n])
-#
-# A is (M, K), W_gate is (K, N), W_up is (K, N), out is (M, N).
-# ---------------------------------------------------------------------------
 @triton.jit
 def _grouped_gemm_swiglu_kernel(
     A_ptr,          # base ptr to (n_e, H)
     W_gate_ptr,     # base ptr to (H, I) for this expert
     W_up_ptr,       # base ptr to (H, I) for this expert
     Out_ptr,        # base ptr to (n_e, I) for this expert
-    M: tl.constexpr,
-    K: tl.constexpr,
-    N: tl.constexpr,
+    M,
+    K,
+    N,
+    stride_am,
+    stride_ak,
+    stride_wg_k,
+    stride_wg_n,
+    stride_wu_k,
+    stride_wu_n,
+    stride_om,
+    stride_on,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
@@ -44,67 +46,63 @@
     acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 
-    for k in range(0, K, BLOCK_K):
-        k_offsets = k + tl.arange(0, BLOCK_K)
-        k_mask = k_offsets < K
+    for k_start in range(0, K, BLOCK_K * NUM_STAGES):
+        for stage in range(NUM_STAGES):
+            k = k_start + stage * BLOCK_K
+            if k >= K:
+                break
+            k_offsets = k + tl.arange(0, BLOCK_K)
+            k_mask = k_offsets < K
 
-        # Load A block: (BLOCK_M, BLOCK_K), col-major read row-major store
-        # A is row-major with stride K
-        a_ptrs = A_ptr + m_offsets[:, None] * K + k_offsets[None, :]
-        a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+            a_ptrs = A_ptr + m_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak
+            a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
 
-        # Load W_gate block: (BLOCK_K, BLOCK_N), row-major stride N
-        wg_ptrs = W_gate_ptr + k_offsets[:, None] * N + n_offsets[None, :]
-        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+            wg_ptrs = W_gate_ptr + k_offsets[:, None] * stride_wg_k + n_offsets[None, :] * stride_wg_n
+            wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
 
-        # Load W_up block: (BLOCK_K, BLOCK_N), row-major stride N
-        wu_ptrs = W_up_ptr + k_offsets[:, None] * N + n_offsets[None, :]
-        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+            wu_ptrs = W_up_ptr + k_offsets[:, None] * stride_wu_k + n_offsets[None, :] * stride_wu_n
+            wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
 
-        acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
-        acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+            acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
+            acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
 
     # SwiGLU: silu(gate) * up
-    # silu(x) = x * sigmoid(x); everything in fp32, cast to bf16 on store
     sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
     silu = acc_gate * sigmoid
     out_fp32 = silu * acc_up
 
-    out_ptr = Out_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+    out_ptrs = Out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
     out = out_fp32.to(tl.bfloat16)
-    tl.store(out_ptr, out, mask=m_mask[:, None] & n_mask[None, :])
+    tl.store(out_ptrs, out, mask=m_mask[:, None] & n_mask[None, :])
 
 
 def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
-    """Launch fused dual GEMM + SwiGLU for one expert.
-
-    x:        (n_tokens, H) bf16
-    w_gate:   (H, I) bf16
-    w_up:     (H, I) bf16
-    n_tokens: int
-    out:      (n_tokens, I) bf16
-    """
+    """Launch fused dual GEMM + SwiGLU for one expert."""
     H = x.shape[1]
     I = w_gate.shape[1]
 
-    # Heuristic block sizes
-    # SM120 has many SMs, try larger tiles for larger matrices
-    if n_tokens >= 1024 and I >= 1024:
-        BLOCK_M = 256
-        BLOCK_N = 256
-        BLOCK_K = 128
-    elif n_tokens >= 512:
+    # Constrain block sizes to fit in shared memory on SM120 (~101KB limit)
+    # Each tile at bf16: 2 bytes per element
+    # Shared memory needed per stage: A(BLOCK_M, BLOCK_K) + W_gate(BLOCK_K, BLOCK_N) + W_up(BLOCK_K, BLOCK_N)
+    # For 128x128x64: 2*(128*64 + 128*128 + 128*128) = 2*(8192 + 16384 + 16384) = 81664 bytes
+    # Still large with num_stages>1. Let's use 64x64x64 or 128x64x64.
+    # Actually, Triton doesn't necessarily load all stages at once for non-pipelined kernels.
+    # Let's use moderate sizes.
+    if n_tokens >= 512 and I >= 512:
         BLOCK_M = 128
         BLOCK_N = 128
-        BLOCK_K = 128
-    elif n_tokens >= 128:
+        BLOCK_K = 64
+        NUM_STAGES = 2
+    elif n_tokens >= 256 and I >= 256:
         BLOCK_M = 128
         BLOCK_N = 64
         BLOCK_K = 64
+        NUM_STAGES = 2
     else:
         BLOCK_M = 64
         BLOCK_N = 64
         BLOCK_K = 64
+        NUM_STAGES = 2
 
     grid = (
         triton.cdiv(n_tokens, BLOCK_M),
@@ -116,22 +114,23 @@
         M=n_tokens,
         K=H,
         N=I,
+        stride_am=x.stride(0),
+        stride_ak=x.stride(1),
+        stride_wg_k=w_gate.stride(0),
+        stride_wg_n=w_gate.stride(1),
+        stride_wu_k=w_up.stride(0),
+        stride_wu_n=w_up.stride(1),
+        stride_om=out.stride(0),
+        stride_on=out.stride(1),
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
+        NUM_STAGES=NUM_

... (1432 more chars)
args
{
  "content": "<6019 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 32506out 2090cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (3077 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 167, in forward
    grouped_gemm_swiglu_launch(x_e, w_gate_e, w_up_e, n_e, out_slice)
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 112, in grouped_gemm_swiglu_launch
    _grouped_gemm_swiglu_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/t

... (1577 more chars)
assistant
in 34614out 68cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (223 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,9 +1,4 @@
-"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
-
-Per-expert dense GEMM with SwiGLU fused in the epilogue. Loops over experts
-in Python (each expert is a dense tile); the Triton kernel computes both
-x @ W_gate and x @ W_up simultaneously and applies silu(gate) * up.
-"""
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection."""
 from __future__ import annotations
 
 import torch
@@ -14,158 +9,94 @@
 
 @triton.jit
 def _grouped_gemm_swiglu_kernel(
-    A_ptr,          # base ptr to (n_e, H)
-    W_gate_ptr,     # base ptr to (H, I) for this expert
-    W_up_ptr,       # base ptr to (H, I) for this expert
-    Out_ptr,        # base ptr to (n_e, I) for this expert
+    A_ptr,
+    W_gate_ptr,
+    W_up_ptr,
+    Out_ptr,
     M,
     K,
     N,
-    stride_am,
-    stride_ak,
-    stride_wg_k,
-    stride_wg_n,
-    stride_wu_k,
-    stride_wu_n,
-    stride_om,
-    stride_on,
+    stride_a_row,
+    stride_a_col,
+    stride_w_row,
+    stride_w_col,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    NUM_STAGES: tl.constexpr,
 ):
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
 
-    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    k_offs = tl.arange(0, BLOCK_K)
 
-    m_mask = m_offsets < M
-    n_mask = n_offsets < N
+    a_base = A_ptr + m_offs[:, None] * stride_a_row
+    wg_base = W_gate_ptr + n_offs[None, :] * stride_w_col
+    wu_base = W_up_ptr + n_offs[None, :] * stride_w_col
 
     acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 
-    for k_start in range(0, K, BLOCK_K * NUM_STAGES):
-        for stage in range(NUM_STAGES):
-            k = k_start + stage * BLOCK_K
-            if k >= K:
-                break
-            k_offsets = k + tl.arange(0, BLOCK_K)
-            k_mask = k_offsets < K
+    for _ in range(0, K, BLOCK_K):
+        a_ptrs = a_base + k_offs[None, :] * stride_a_col
+        wg_ptrs = wg_base + k_offs[:, None] * stride_w_row
+        wu_ptrs = wu_base + k_offs[:, None] * stride_w_row
 
-            a_ptrs = A_ptr + m_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak
-            a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+        a = tl.load(a_ptrs, mask=(m_offs[:, None] < M) & (k_offs[None, :] < K), other=0.0)
+        wg = tl.load(wg_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
+        wu = tl.load(wu_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
 
-            wg_ptrs = W_gate_ptr + k_offsets[:, None] * stride_wg_k + n_offsets[None, :] * stride_wg_n
-            wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+        acc_gate += tl.dot(a, wg)
+        acc_up += tl.dot(a, wu)
 
-            wu_ptrs = W_up_ptr + k_offsets[:, None] * stride_wu_k + n_offsets[None, :] * stride_wu_n
-            wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+        a_base += BLOCK_K * stride_a_col
+        wg_base += BLOCK_K * stride_w_row
+        wu_base += BLOCK_K * stride_w_row
 
-            acc_gate = tl.dot(a, wg, acc_gate, out_dtype=tl.float32)
-            acc_up = tl.dot(a, wu, acc_up, out_dtype=tl.float32)
+    # SwiGLU
+    silu = acc_gate * tl.sigmoid(acc_gate)
+    out_fp = silu * acc_up
+    out_bf = out_fp.to(tl.bfloat16)
 
-    # SwiGLU: silu(gate) * up
-    sigmoid = 1.0 / (1.0 + tl.exp(-acc_gate))
-    silu = acc_gate * sigmoid
-    out_fp32 = silu * acc_up
-
-    out_ptrs = Out_ptr + m_offsets[:, None] * stride_om + n_offsets[None, :] * stride_on
-    out = out_fp32.to(tl.bfloat16)
-    tl.store(out_ptrs, out, mask=m_mask[:, None] & n_mask[None, :])
+    out_ptrs = Out_ptr + m_offs[:, None] * N + n_offs[None, :]
+    tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
 
 
-def grouped_gemm_swiglu_launch(x, w_gate, w_up, n_tokens, out):
-    """Launch fused dual GEMM + SwiGLU for one expert."""
-    H = x.shape[1]
-    I = w_gate.shape[1]
-
-    # Constrain block sizes to fit in shared memory on SM120 (~101KB limit)
-    # Each tile at bf16: 2 bytes per element
-    # Shared memory needed per stage: A(BLOCK_M, BLOCK_K) + W_gate(BLOCK_K, BLOCK_N) + W_up(BLOCK_K, BLOCK_N)
-    # For 128x128x64: 2*(128*64 + 128*128 + 128*128) = 2*(8192 + 16384 + 16384) = 81664 bytes
-    # Still large with num_stages>1. Let's use 64x64x64 or 128x64x64.
-    # Actually, Triton doesn't necessarily load all stages at once for non-pipelined kernels.
-    # Let's use moderate sizes.
-    if n_tokens >= 512 and I >= 512:
-        BLOCK_M = 128
-        BLOCK_N = 128
-        BLOCK_K = 64
-        NUM_STAGES = 2
-    elif n_tokens >= 256 and I >= 256:
-        BLOCK_M = 128
-        BLOCK_N = 64
-        BLOCK_K = 64
-        NUM_STAGES = 2
-    else:
-        BLOCK_M = 64
-        BLOCK_N = 64
-        BLOCK_K = 64
-        NUM_STAGES = 2
-
-    grid = (
-        triton.cdiv(n_tokens, BLOCK_M),
-        triton.cdiv(I, BLOCK_N),
-    )
-
+def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
+    grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
     _grouped_gemm_swiglu_kernel[grid](
-        x, w_gate, w_up, out,
-        M=n_tokens,
-        K=H,
-        N=I,
-        stride_am=x.stride(0),
-        stride_ak=x.stride(1),
-        stride_wg_k=w_gate.stride(0),
-        stride_wg_n=w_gate.stride(1),
-        stride_wu_k=w_up.stride(0),
-        stride_wu_n=w_up.stride(1),
-        stride_om=out.stride(0),
-        stride_on=out.stride(1),
-        BLOCK_M=BLOCK_M,
-        BLOC

... (2368 more chars)
args
{
  "content": "<3695 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 35640out 1349cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (2355 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 99, in forward
    _launch(x_e, self.W_gate[e], self.W_up[e], n_e, out_e)
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 68, in _launch
    _grouped_gemm_swiglu_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743,

... (855 more chars)
assistant
in 37007out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "<705 chars \u2014 see below>",
  "description": "Calculate shared memory usage"
}
command (705 chars)
uv run python -c "
# Check shared memory requirements
bm, bn, bk = 64, 64, 64
# num_stages=1
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'64x64x64 num_stages=1: {total} bytes')

bm, bn, bk = 64, 64, 128
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'64x64x128 num_stages=1: {total} bytes')

bm, bn, bk = 128, 64, 128
shm_bm = bm * bk * 2
shm_wg = bk * bn * 2
shm_wu = bk * bn * 2
total = shm_bm + shm_wg + shm_wu
print(f'128x64x128 num_stages=1: {total} bytes')

# num_stages=2
bm, bn, bk = 64, 64, 128
shm = (bm * bk + 2 * bk * bn) * 2 * 2
print(f'64x64x128 num_stages=2: {shm} bytes')
"
tool
stdout (144 chars)
64x64x64 num_stages=1: 24576 bytes
64x64x128 num_stages=1: 49152 bytes
128x64x128 num_stages=1: 65536 bytes
64x64x128 num_stages=2: 98304 bytes
assistant
in 37828out 400cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (29 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -63,14 +63,23 @@
     tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
 
 
-def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
-    grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
+def _launch(x, w_gate_e, w_up_e, n_e, out_e):
+    I = out_e.shape[1]
+    # 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
+    if n_e >= 512 and I >= 512:
+        bm, bn, bk = 64, 128, 128
+    elif n_e >= 128:
+        bm, bn, bk = 64, 64, 128
+    else:
+        bm, bn, bk = 32, 32, 64
+    grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
     _grouped_gemm_swiglu_kernel[grid](
         x, w_gate_e, w_up_e, out_e,
-        M=n_e, K=x.shape[1], N=out_e.shape[1],
+        M=n_e, K=x.shape[1], N=I,
         stride_a_row=x.stride(0), stride_a_col=x.stride(1),
         stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
         BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
+        num_stages=1,
     )
 
 
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
  "oldString": "<442 chars \u2014 see below>",
  "newString": "<684 chars \u2014 see below>"
}
oldString (442 chars)
def _launch(x, w_gate_e, w_up_e, n_e, out_e, bm=128, bn=128, bk=128):
    grid = (triton.cdiv(n_e, bm), triton.cdiv(out_e.shape[1], bn))
    _grouped_gemm_swiglu_kernel[grid](
        x, w_gate_e, w_up_e, out_e,
        M=n_e, K=x.shape[1], N=out_e.shape[1],
        stride_a_row=x.stride(0), stride_a_col=x.stride(1),
        stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
        BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
    )
newString (684 chars)
def _launch(x, w_gate_e, w_up_e, n_e, out_e):
    I = out_e.shape[1]
    # 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
    if n_e >= 512 and I >= 512:
        bm, bn, bk = 64, 128, 128
    elif n_e >= 128:
        bm, bn, bk = 64, 64, 128
    else:
        bm, bn, bk = 32, 32, 64
    grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
    _grouped_gemm_swiglu_kernel[grid](
        x, w_gate_e, w_up_e, out_e,
        M=n_e, K=x.shape[1], N=I,
        stride_a_row=x.stride(0), stride_a_col=x.stride(1),
        stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
        BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
        num_stages=1,
    )
tool
stdout (26 chars)
Edit applied successfully.
assistant
in 38332out 545cache_r 0cache_w 0
assistant
bash
args
{
  "command": "<313 chars \u2014 see below>",
  "description": "Check shared memory for different block sizes"
}
command (313 chars)
# Check shared memory for 64x128x128 num_stages=1
uv run python -c "
bm, bn, bk = 64, 128, 128
shm = (bm * bk + bk * bn + bk * bn) * 2
print(f'64x128x128 num_stages=1: {shm} bytes (limit ~101KB)')

bm, bn, bk = 128, 64, 128
shm = (bm * bk + bk * bn + bk * bn) * 2
print(f'128x64x128 num_stages=1: {shm} bytes')
"
tool
stdout (89 chars)
64x128x128 num_stages=1: 81920 bytes (limit ~101KB)
128x64x128 num_stages=1: 65536 bytes
assistant
in 38894out 213cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 39175out 68cache_r 0cache_w 0
assistant
todowrite
args
{
  "todos": [
    {
      "content": "Implement basic Triton grouped GEMM + fused SwiGLU kernel with Python loop over experts",
      "status": "completed",
      "priority": "high"
    },
    {
      "content": "Verify correctness with check.py",
      "status": "completed",
      "priority": "high"
    },
    {
      "content": "Profile and optimize kernel for performance",
      "status": "in_progress",
      "priority": "high"
    },
    {
      "content": "Run benchmark to get timing numbers",
      "status": "pending",
      "priority": "medium"
    }
  ]
}
tool
stdout (394 chars)
content: 'Implement basic Triton grouped GEMM + fused SwiGLU kernel with Python loop over experts'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize kernel for performance'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark to get timing numbers'
status: 'pending'
priority: 'medium'
assistant
Correctness passes. Now let me benchmark.
assistant
in 39258out 148cache_r 0cache_w 0
assistant
bash
args
{
  "command": "./scripts/patch_torch.sh 2>&1 | tail -3 && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (2364 chars) — TRUNCATED
zsh:1: no such file or directory: ./scripts/patch_torch.sh
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:23:18.551000 1027626 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     start = int(expert_offsets[e].item())
W0429 08:23:18.551000 1027626 .venv/lib/python3.11

... (864 more chars)
assistant
in 39577out 83cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (231 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,4 +1,9 @@
-"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection."""
+"""Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
+
+Single kernel handles all experts in one launch to avoid Python loop overhead.
+Grid is (E, cdiv(M_expert, BLOCK_M), cdiv(N, BLOCK_N)) where M_expert varies
+per expert (computed from expert_offsets).
+"""
 from __future__ import annotations
 
 import torch
@@ -8,104 +13,146 @@
 
 
 @triton.jit
-def _grouped_gemm_swiglu_kernel(
-    A_ptr,
-    W_gate_ptr,
-    W_up_ptr,
-    Out_ptr,
-    M,
-    K,
-    N,
-    stride_a_row,
-    stride_a_col,
-    stride_w_row,
-    stride_w_col,
-    BLOCK_M: tl.constexpr,
+def _grouped_gemm_swiglu_single_kernel(
+    X_ptr,           # (T_perm, H), row-major
+    W_gate_ptr,      # (E, H, I), row-major
+    W_up_ptr,        # (E, H, I), row-major
+    Out_ptr,         # (T_perm, I), row-major
+    Offsets_ptr,     # (E+1,) int32
+    T_perm,
+    H,
+    I,
+    E,
+    stride_x_row,
+    stride_w_H,
+    stride_w_I,
+    stride_o_row,
+    BLOCK_H: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
+    # Grid: (E, cdiv(max_tokens, BLOCK_M), cdiv(I, BLOCK_N))
+    pid_e = tl.program_id(0)
+    pid_bm = tl.program_id(1)
+    pid_n = tl.program_id(2)
 
-    m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    # Read expert offset bounds
+    off_start = tl.load(Offsets_ptr + pid_e)
+    off_end = tl.load(Offsets_ptr + pid_e + 1)
+    n_e = off_end - off_start
+
+    # M dimension for this tile
+    m_base = pid_bm * 256  # We fix BLOCK_M = 256 for simplicity
+    m_offs = m_base + tl.arange(0, 256)
+
+    # Global row indices within hidden_states
+    row_global = off_start + m_offs
+    m_valid = m_offs < n_e
+
+    # N (output) dimension
     n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    k_offs = tl.arange(0, BLOCK_K)
+    n_valid = n_offs < I
 
-    a_base = A_ptr + m_offs[:, None] * stride_a_row
-    wg_base = W_gate_ptr + n_offs[None, :] * stride_w_col
-    wu_base = W_up_ptr + n_offs[None, :] * stride_w_col
+    # K (hidden) dimension
+    k_offs = tl.arange(0, BLOCK_H)
+    k_valid = k_offs < H
 
-    acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    # Pointers
+    # X[row, k] = X_ptr + row * stride_x_row + k
+    # W_gate[e, k, n] = W_gate_ptr + e*stride_w_H*H + k*stride_w_I + n
 
-    for _ in range(0, K, BLOCK_K):
-        a_ptrs = a_base + k_offs[None, :] * stride_a_col
-        wg_ptrs = wg_base + k_offs[:, None] * stride_w_row
-        wu_ptrs = wu_base + k_offs[:, None] * stride_w_row
+    acc_gate = tl.zeros((256, BLOCK_N), dtype=tl.float32)
+    acc_up = tl.zeros((256, BLOCK_N), dtype=tl.float32)
 
-        a = tl.load(a_ptrs, mask=(m_offs[:, None] < M) & (k_offs[None, :] < K), other=0.0)
-        wg = tl.load(wg_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
-        wu = tl.load(wu_ptrs, mask=(k_offs[:, None] < K) & (n_offs[None, :] < N), other=0.0)
+    # Load base W pointers (add expert offset once)
+    w_e_start = pid_e * stride_w_H * H  # offset for this expert's weights
+    wg_base = W_gate_ptr + w_e_start
+    wu_base = W_up_ptr + w_e_start
 
-        acc_gate += tl.dot(a, wg)
-        acc_up += tl.dot(a, wu)
+    for h_start in range(0, H, BLOCK_H):
+        h_offset = h_start + k_offs
+        h_mask = h_offset < H
 
-        a_base += BLOCK_K * stride_a_col
-        wg_base += BLOCK_K * stride_w_row
-        wu_base += BLOCK_K * stride_w_row
+        # Load X block: (256, BLOCK_H)
+        x_ptrs = X_ptr + row_global[:, None] * stride_x_row + h_offset[None, :]
+        x = tl.load(x_ptrs, mask=m_valid[:, None] & h_mask[None, :], other=0.0)
+
+        # Load W_gate block: (BLOCK_H, BLOCK_N) for expert e
+        wg_ptrs = wg_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
+        wg = tl.load(wg_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+
+        # Load W_up block: (BLOCK_H, BLOCK_N) for expert e
+        wu_ptrs = wu_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
+        wu = tl.load(wu_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+
+        acc_gate += tl.dot(x, wg)
+        acc_up += tl.dot(x, wu)
 
     # SwiGLU
-    silu = acc_gate * tl.sigmoid(acc_gate)
-    out_fp = silu * acc_up
-    out_bf = out_fp.to(tl.bfloat16)
+    gate_fp32 = acc_gate
+    silu = gate_fp32 * tl.sigmoid(gate_fp32)
+    out_fp32 = silu * acc_up
 
-    out_ptrs = Out_ptr + m_offs[:, None] * N + n_offs[None, :]
-    tl.store(out_ptrs, out_bf, mask=(m_offs[:, None] < M) & (n_offs[None, :] < N))
-
-
-def _launch(x, w_gate_e, w_up_e, n_e, out_e):
-    I = out_e.shape[1]
-    # 64x64x128 fits in shared memory (num_stages=2 -> 98KB, just under 101KB limit)
-    if n_e >= 512 and I >= 512:
-        bm, bn, bk = 64, 128, 128
-    elif n_e >= 128:
-        bm, bn, bk = 64, 64, 128
-    else:
-        bm, bn, bk = 32, 32, 64
-    grid = (triton.cdiv(n_e, bm), triton.cdiv(I, bn))
-    _grouped_gemm_swiglu_kernel[grid](
-        x, w_gate_e, w_up_e, out_e,
-        M=n_e, K=x.shape[1], N=I,
-        stride_a_row=x.stride(0), stride_a_col=x.stride(1),
-        stride_w_row=w_gate_e.stride(0), stride_w_col=w_gate_e.stride(1),
-        BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
-        num_stages=1,
-    )
+    # Store
+    out_ptrs = Out_ptr + row_global[:, None] * stride_o_row + n_offs[None, :]
+    out = out_fp32.to(tl.bfloat16)
+    tl.store(out_ptrs, out, mask=m_valid[:, None] & n_valid[None, :])
 
 
 class Model(nn.Module):
     def __init__(self, T_total, H, I, E, K):  # noqa: E741
         super().__init__()
+        self.T_total = T_total
+        self.H = H
         self.I = I
+        self.E = E
+   

... (2141 more chars)
args
{
  "content": "<5293 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The performance is terrible - only 4-6% peak. The Python loop over 128/64 experts is the bottleneck (128 kernel launches). I need a single kernel that handles all experts in one launch.
assistant
in 40833out 1856cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (2195 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 144, in forward
    _grouped_gemm_swiglu_single_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
    launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

... (695 more chars)
assistant
in 42707out 68cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (269 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,8 +1,9 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-Single kernel handles all experts in one launch to avoid Python loop overhead.
-Grid is (E, cdiv(M_expert, BLOCK_M), cdiv(N, BLOCK_N)) where M_expert varies
-per expert (computed from expert_offsets).
+Single kernel handles all experts in one launch. The 3D grid is
+(E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
+(E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
+We keep tiles small enough for shared memory.
 """
 from __future__ import annotations
 
@@ -13,100 +14,131 @@
 
 
 @triton.jit
-def _grouped_gemm_swiglu_single_kernel(
-    X_ptr,           # (T_perm, H), row-major
-    W_gate_ptr,      # (E, H, I), row-major
-    W_up_ptr,        # (E, H, I), row-major
-    Out_ptr,         # (T_perm, I), row-major
-    Offsets_ptr,     # (E+1,) int32
-    T_perm,
+def _grouped_gemm_swiglu_kernel(
+    X_ptr,
+    W_gate_ptr,
+    W_up_ptr,
+    Out_ptr,
+    Offsets_ptr,
     H,
     I,
-    E,
     stride_x_row,
-    stride_w_H,
-    stride_w_I,
+    stride_w_row,
+    stride_w_col,
     stride_o_row,
-    BLOCK_H: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    BLOCK_K: tl.constexpr,
+    BM: tl.constexpr,
+    BN: tl.constexpr,
+    BK: tl.constexpr,
 ):
-    # Grid: (E, cdiv(max_tokens, BLOCK_M), cdiv(I, BLOCK_N))
+    pid_mn = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    # We encode expert id in the upper bits of pid_mn for small E,
+    # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
+    # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
+    # We need to know which expert each tile belongs to and the intra-expert row.
+    pass
+
+
+# Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
+# The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
+# Shared memory: BM * BK + 2 * BK * BN bf16 values.
+# For BM=128, BN=128, BK=64: 128*64 + 2*64*128 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
+# For BM=64, BN=64, BK=128: 64*128 + 2*128*64 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
+
+@triton.jit
+def _gemma_gemm_kernel(
+    X_ptr,
+    W_gate_ptr,
+    W_up_ptr,
+    Out_ptr,
+    Offsets_ptr,
+    H,
+    I,
+    stride_x_row,
+    stride_w_row,
+    stride_w_col,
+    stride_o_row,
+    BM: tl.constexpr,
+    BN: tl.constexpr,
+    BK: tl.constexpr,
+):
     pid_e = tl.program_id(0)
     pid_bm = tl.program_id(1)
     pid_n = tl.program_id(2)
 
-    # Read expert offset bounds
-    off_start = tl.load(Offsets_ptr + pid_e)
-    off_end = tl.load(Offsets_ptr + pid_e + 1)
-    n_e = off_end - off_start
+    # Expert bounds
+    os = tl.load(Offsets_ptr + pid_e).to(tl.int64)
+    oe = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+    n_e = oe - os
 
-    # M dimension for this tile
-    m_base = pid_bm * 256  # We fix BLOCK_M = 256 for simplicity
-    m_offs = m_base + tl.arange(0, 256)
+    # Local M offset
+    m_local = pid_bm * BM + tl.arange(0, BM)
+    m_valid = m_local < n_e
 
-    # Global row indices within hidden_states
-    row_global = off_start + m_offs
-    m_valid = m_offs < n_e
+    # Global row
+    row_g = os + m_local
 
-    # N (output) dimension
-    n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    n_valid = n_offs < I
+    # N offsets
+    n_local = pid_n * BN + tl.arange(0, BN)
+    n_valid = n_local < I
 
-    # K (hidden) dimension
-    k_offs = tl.arange(0, BLOCK_H)
-    k_valid = k_offs < H
+    # K offsets
+    k_local = tl.arange(0, BK)
 
-    # Pointers
-    # X[row, k] = X_ptr + row * stride_x_row + k
-    # W_gate[e, k, n] = W_gate_ptr + e*stride_w_H*H + k*stride_w_I + n
+    # Accumulators
+    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
 
-    acc_gate = tl.zeros((256, BLOCK_N), dtype=tl.float32)
-    acc_up = tl.zeros((256, BLOCK_N), dtype=tl.float32)
+    # Weight base for this expert
+    wg_base = W_gate_ptr + pid_e * stride_w_row * H
+    wu_base = W_up_ptr + pid_e * stride_w_row * H
 
-    # Load base W pointers (add expert offset once)
-    w_e_start = pid_e * stride_w_H * H  # offset for this expert's weights
-    wg_base = W_gate_ptr + w_e_start
-    wu_base = W_up_ptr + w_e_start
+    for k_start in range(0, H, BK):
+        k_offs = k_start + k_local
+        k_mask = k_offs < H
 
-    for h_start in range(0, H, BLOCK_H):
-        h_offset = h_start + k_offs
-        h_mask = h_offset < H
+        # A[m, k]
+        a_ptrs = X_ptr + row_g[:, None] * stride_x_row + k_offs[None, :]
+        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
 
-        # Load X block: (256, BLOCK_H)
-        x_ptrs = X_ptr + row_global[:, None] * stride_x_row + h_offset[None, :]
-        x = tl.load(x_ptrs, mask=m_valid[:, None] & h_mask[None, :], other=0.0)
+        # W_gate[e][k, n]
+        wg_ptrs = wg_base + k_offs[:, None] * stride_w_row + n_local[None, :] * stride_w_col
+        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        # Load W_gate block: (BLOCK_H, BLOCK_N) for expert e
-        wg_ptrs = wg_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
-        wg = tl.load(wg_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
+        # W_up[e][k, n]
+        wu_ptrs = wu_base + k_offs[:, None] * stride_w_row + n_local[None, :] * stride_w_col
+        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        # Load W_up block: (BLOCK_H, BLOCK_N) for expert e
-        wu_ptrs = wu_base + h_offset[:, None] * stride_w_I + n_offs[None, :]
-        wu = tl.load(wu_ptrs, mask=h_mask[:, None] & n_valid[None, :], other=0.0)
-
-        acc_gate += tl.dot(x, wg)
-        acc_up += tl.dot(x, wu)
+        acc_g += tl.dot(a, wg)
+        acc_u += tl.dot(a, wu)
 
     #

... (3445 more chars)
args
{
  "content": "<5405 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 43467out 2062cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check"
}
tool
stdout (245 chars)
(no output)

<bash_metadata>
bash tool terminated command after exceeding timeout 300000 ms. If this command is expected to take longer and is not waiting for interactive input, retry with a larger timeout value in milliseconds.
</bash_metadata>
assistant
in 45548out 68cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (6365 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
2: 
3: Single kernel handles all experts in one launch. The 3D grid is
4: (E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
5: (E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
6: We keep tiles small enough for shared memory.
7: """
8: from __future__ import annotations
9: 
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14: 
15: 
16: @triton.jit
17: def _grouped_gemm_swiglu_kernel(
18:     X_ptr,
19:     W_gate_ptr,
20:     W_up_ptr,
21:     Out_ptr,
22:     Offsets_ptr,
23:     H,
24:     I,
25:     stride_x_row,
26:     stride_w_row,
27:     stride_w_col,
28:     stride_o_row,
29:     BM: tl.constexpr,
30:     BN: tl.constexpr,
31:     BK: tl.constexpr,
32: ):
33:     pid_mn = tl.program_id(0)
34:     pid_n = tl.program_id(1)
35: 
36:     # We encode expert id in the upper bits of pid_mn for small E,
37:     # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
38:     # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
39:     # We need to know which expert each tile belongs to and the intra-expert row.
40:     pass
41: 
42: 
43: # Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
44: # The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
45: # Sha

... (4865 more chars)
assistant
The kernel is hanging, likely due to watchdog timeout from a huge grid. Let me restructure it properly.
assistant
in 45685out 78cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (273 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,9 +1,7 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-Single kernel handles all experts in one launch. The 3D grid is
-(E, cdiv(M_expert, BM), cdiv(I, BN)), but we use a 2D grid
-(E * cdiv(total_rows/avg, BM), cdiv(I, BN)) and compute M from offsets.
-We keep tiles small enough for shared memory.
+Flattened grid approach: grid dimension 0 covers all row-blocks across all experts.
+A per-block expert-lookup uses binary search on offsets (or a cumulative tile-count array).
 """
 from __future__ import annotations
 
@@ -11,6 +9,23 @@
 import torch.nn as nn
 import triton
 import triton.language as tl
+
+
+# We use a simple approach: grid is (total_m_blocks, cdiv(I, BN), E)
+# Actually, let's keep it as (E, cdiv(n_tokens_for_expert_e, BM), cdiv(I, BN))
+# but compute the grid size from actual per-expert tile counts.
+
+# The issue with 3D grid: the grid size in dimension 1 is not uniform across experts.
+# Standard solution: flatten dimension 0+1 into a single dimension.
+#
+# For each expert e, num_tiles_e = cdiv(n_e, BM)
+# cumulative_tiles[e] = sum(num_tiles_e' for e' < e)
+# total_tiles = cumulative_tiles[E]
+#
+# Grid: (total_tiles, cdiv(I, BN))
+# Given pid_0 in range [cumulative_tiles[e], cumulative_tiles[e+1]),
+#   expert_id = e
+#   pid_bm = pid_0 - cumulative_tiles[e]
 
 
 @triton.jit
@@ -20,94 +35,91 @@
     W_up_ptr,
     Out_ptr,
     Offsets_ptr,
+    TileOffs_ptr,  # (E+1,) int32, cumulative start tile indices
     H,
     I,
+    BM,
+    BN,
     stride_x_row,
     stride_w_row,
     stride_w_col,
     stride_o_row,
-    BM: tl.constexpr,
-    BN: tl.constexpr,
-    BK: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
 ):
-    pid_mn = tl.program_id(0)
+    pid_tile = tl.program_id(0)
     pid_n = tl.program_id(1)
 
-    # We encode expert id in the upper bits of pid_mn for small E,
-    # OR we restructure as 3D grid: (E, BN, BK). Let me think differently.
-    # Actually the standard approach: grid (total_tiles_M, cdiv(I, BN)).
-    # We need to know which expert each tile belongs to and the intra-expert row.
-    pass
-
-
-# Better approach: use 3D grid (E, cdiv(max_M, BM), cdiv(I, BN)).
-# The kernel computes row = off_start[pid_e] + pid_bm * BM + local_m.
-# Shared memory: BM * BK + 2 * BK * BN bf16 values.
-# For BM=128, BN=128, BK=64: 128*64 + 2*64*128 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
-# For BM=64, BN=64, BK=128: 64*128 + 2*128*64 = 8192 + 16384 = 24576 * 2 = 49152 (fits!)
-
-@triton.jit
-def _gemma_gemm_kernel(
-    X_ptr,
-    W_gate_ptr,
-    W_up_ptr,
-    Out_ptr,
-    Offsets_ptr,
-    H,
-    I,
-    stride_x_row,
-    stride_w_row,
-    stride_w_col,
-    stride_o_row,
-    BM: tl.constexpr,
-    BN: tl.constexpr,
-    BK: tl.constexpr,
-):
-    pid_e = tl.program_id(0)
-    pid_bm = tl.program_id(1)
-    pid_n = tl.program_id(2)
-
-    # Expert bounds
-    os = tl.load(Offsets_ptr + pid_e).to(tl.int64)
-    oe = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
-    n_e = oe - os
-
-    # Local M offset
-    m_local = pid_bm * BM + tl.arange(0, BM)
+    # Find which expert this tile belongs to via binary search on TileOffs.
+    # TileOffs has E+1 entries; we need the largest e such that TileOffs[e] <= pid_tile.
+    # Since E is small (64 or 128), linear search is fine.
+    lo = 0
+    hi = Offsets_ptr.numel() - 2  # E-1, but we don't know E at compile time
+    # We'll just do a simple binary search
+    expert = 0
+    
+    # Binary search: find largest lo such that TileOffs[lo] <= pid_tile
+    left = 0
+    right = 128  # upper bound on E
+    while left < right:
+        mid = (left + right + 1) // 2
+        val = tl.load(TileOffs_ptr + mid).to(tl.int32)
+        if val <= pid_tile:
+            left = mid
+        else:
+            right = mid - 1
+    expert = left
+
+    tile_base = tl.load(TileOffs_ptr + expert).to(tl.int32)
+    pid_bm = pid_tile - tile_base
+
+    # Expert row bounds
+    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+    n_e = row_end - row_start
+
+    # Local row offsets within expert
+    m_local = pid_bm * BLOCK_M + tl.arange(0, BLOCK_M)
     m_valid = m_local < n_e
 
-    # Global row
-    row_g = os + m_local
-
-    # N offsets
-    n_local = pid_n * BN + tl.arange(0, BN)
+    # Global rows
+    row_g = row_start + m_local
+
+    # Column offsets
+    n_local = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     n_valid = n_local < I
 
     # K offsets
-    k_local = tl.arange(0, BK)
+    k_local = tl.arange(0, BLOCK_K)
 
     # Accumulators
-    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
-    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
-
-    # Weight base for this expert
-    wg_base = W_gate_ptr + pid_e * stride_w_row * H
-    wu_base = W_up_ptr + pid_e * stride_w_row * H
-
-    for k_start in range(0, H, BK):
+    acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    # Weight base for this expert: W is (E, H, I), row-major
+    # stride_w_row = W.stride(1) = I, stride_w_col = W.stride(2) = 1
+    # W[e, k, n] = W_ptr + e*H*I + k*I + n
+    wg_base = W_gate_ptr + expert * H * I.to(tl.int64)
+    wu_base = W_up_ptr + expert * H * I.to(tl.int64)
+
+    for k_start in range(0, H, BLOCK_K):
         k_offs = k_start + k_local
         k_mask = k_offs < H
 
-        # A[m, k]
-        a_ptrs = X_ptr + row_g[:, None] * stride_x_row + k_offs[None, :]
+        # A[m, k]: (BLOCK_M, BLOCK_K)
+        a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :] * stride_x_row // BLOCK_M
+        # Actually, stride_x_row might not be H if the tensor isn't contiguous...
+        # Let me just use stride_x_row directly
+       

... (4233 more chars)
args
{
  "content": "<7051 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 48486out 2401cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (234 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-Flattened grid approach: grid dimension 0 covers all row-blocks across all experts.
-A per-block expert-lookup uses binary search on offsets (or a cumulative tile-count array).
+Flattened 2D grid: (cdiv(T_perm, BM), cdiv(I, BN)) with expert lookup
+via binary search on offsets.
 """
 from __future__ import annotations
 
@@ -11,23 +11,6 @@
 import triton.language as tl
 
 
-# We use a simple approach: grid is (total_m_blocks, cdiv(I, BN), E)
-# Actually, let's keep it as (E, cdiv(n_tokens_for_expert_e, BM), cdiv(I, BN))
-# but compute the grid size from actual per-expert tile counts.
-
-# The issue with 3D grid: the grid size in dimension 1 is not uniform across experts.
-# Standard solution: flatten dimension 0+1 into a single dimension.
-#
-# For each expert e, num_tiles_e = cdiv(n_e, BM)
-# cumulative_tiles[e] = sum(num_tiles_e' for e' < e)
-# total_tiles = cumulative_tiles[E]
-#
-# Grid: (total_tiles, cdiv(I, BN))
-# Given pid_0 in range [cumulative_tiles[e], cumulative_tiles[e+1]),
-#   expert_id = e
-#   pid_bm = pid_0 - cumulative_tiles[e]
-
-
 @triton.jit
 def _grouped_gemm_swiglu_kernel(
     X_ptr,
@@ -35,132 +18,97 @@
     W_up_ptr,
     Out_ptr,
     Offsets_ptr,
-    TileOffs_ptr,  # (E+1,) int32, cumulative start tile indices
+    T_perm,
     H,
-    I,
-    BM,
-    BN,
-    stride_x_row,
-    stride_w_row,
-    stride_w_col,
-    stride_o_row,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    BLOCK_K: tl.constexpr,
+    I_val,
+    E,
+    BM: tl.constexpr,
+    BN: tl.constexpr,
+    BK: tl.constexpr,
 ):
-    pid_tile = tl.program_id(0)
+    pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
 
-    # Find which expert this tile belongs to via binary search on TileOffs.
-    # TileOffs has E+1 entries; we need the largest e such that TileOffs[e] <= pid_tile.
-    # Since E is small (64 or 128), linear search is fine.
+    # Find expert for the first row in this block
+    row_0 = pid_m * BM
+    # Binary search for expert: find smallest e such that Offsets[e+1] > row_0
     lo = 0
-    hi = Offsets_ptr.numel() - 2  # E-1, but we don't know E at compile time
-    # We'll just do a simple binary search
-    expert = 0
-    
-    # Binary search: find largest lo such that TileOffs[lo] <= pid_tile
-    left = 0
-    right = 128  # upper bound on E
-    while left < right:
-        mid = (left + right + 1) // 2
-        val = tl.load(TileOffs_ptr + mid).to(tl.int32)
-        if val <= pid_tile:
-            left = mid
+    hi = E - 1
+    while lo < hi:
+        mid = (lo + hi) // 2
+        off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
+        if off <= row_0:
+            lo = mid + 1
         else:
-            right = mid - 1
-    expert = left
+            hi = mid
+    expert = lo
 
-    tile_base = tl.load(TileOffs_ptr + expert).to(tl.int32)
-    pid_bm = pid_tile - tile_base
-
-    # Expert row bounds
     row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
     row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
-    n_e = row_end - row_start
 
-    # Local row offsets within expert
-    m_local = pid_bm * BLOCK_M + tl.arange(0, BLOCK_M)
-    m_valid = m_local < n_e
+    # M offset (intra-expert rows)
+    m_local = pid_m * BM + tl.arange(0, BM) - row_start
+    m_valid = (pid_m * BM + tl.arange(0, BM)) < row_end
 
     # Global rows
-    row_g = row_start + m_local
+    row_g = pid_m * BM + tl.arange(0, BM)
+    row_valid = row_g < T_perm & m_valid
 
-    # Column offsets
-    n_local = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-    n_valid = n_local < I
+    # N offset
+    n_offs = pid_n * BN + tl.arange(0, BN)
+    n_valid = n_offs < I_val
 
-    # K offsets
-    k_local = tl.arange(0, BLOCK_K)
+    # K offset
+    k_offs = tl.arange(0, BK)
 
-    # Accumulators
-    acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
 
-    # Weight base for this expert: W is (E, H, I), row-major
-    # stride_w_row = W.stride(1) = I, stride_w_col = W.stride(2) = 1
-    # W[e, k, n] = W_ptr + e*H*I + k*I + n
-    wg_base = W_gate_ptr + expert * H * I.to(tl.int64)
-    wu_base = W_up_ptr + expert * H * I.to(tl.int64)
+    # Weight base: W is (E, H, I), row major
+    # W[e, k, n] offset = e * H * I + k * I + n
+    w_base = expert * H * I_val
 
-    for k_start in range(0, H, BLOCK_K):
-        k_offs = k_start + k_local
-        k_mask = k_offs < H
+    for k_start in range(0, H, BK):
+        k_mask = k_start + k_offs < H
+        k = k_start + k_offs
 
-        # A[m, k]: (BLOCK_M, BLOCK_K)
-        a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :] * stride_x_row // BLOCK_M
-        # Actually, stride_x_row might not be H if the tensor isn't contiguous...
-        # Let me just use stride_x_row directly
-        a_ptrs = X_ptr + row_g.to(tl.int64)[:, None] * stride_x_row + k_offs[None, :]
-        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
+        # A[row, k]
+        a_ptrs = X_ptr + row_g[:, None] * H + k[None, :]
+        a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
 
-        # W_gate: (BLOCK_K, BLOCK_N) -> W[e, k, n]
-        wg_ptrs = wg_base + k_offs[:, None].to(tl.int64) * stride_w_row + n_local[None, :].to(tl.int64) * stride_w_col
+        # W_gate[e, k, n]
+        wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
         wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        # W_up
-        wu_ptrs = wu_base + k_offs[:, None].to(tl.int64) * stride_w_row + n_local[None, :].to(tl.int64) * stride_

... (2914 more chars)
args
{
  "content": "<4100 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 50905out 1521cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (2370 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 127, in forward
    _launch(hidden_states, self.W_gate, self.W_up, out, expert_offsets)
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 108, in _launch
    _grouped_gemm_swiglu_kernel[grid](
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit

... (870 more chars)
assistant
in 52444out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "<401 chars \u2014 see below>",
  "description": "Find valid block sizes within shared memory limit"
}
command (401 chars)
uv run python -c "
# Shared memory for BM x BK A + 2 x BK x BN W (num_stages=1)
for BM in [32, 64, 128]:
    for BN in [32, 64, 128]:
        for BK in [32, 64, 128]:
            shm = (BM * BK + BK * BN + BK * BN) * 2  # bf16 = 2 bytes
            if shm <= 101376:
                flops = BM * BN * BK  # fma per tile
                print(f'BM={BM} BN={BN} BK={BK} shm={shm} flops/tile={flops}')
"
tool
stdout (1267 chars)
BM=32 BN=32 BK=32 shm=6144 flops/tile=32768
BM=32 BN=32 BK=64 shm=12288 flops/tile=65536
BM=32 BN=32 BK=128 shm=24576 flops/tile=131072
BM=32 BN=64 BK=32 shm=10240 flops/tile=65536
BM=32 BN=64 BK=64 shm=20480 flops/tile=131072
BM=32 BN=64 BK=128 shm=40960 flops/tile=262144
BM=32 BN=128 BK=32 shm=18432 flops/tile=131072
BM=32 BN=128 BK=64 shm=36864 flops/tile=262144
BM=32 BN=128 BK=128 shm=73728 flops/tile=524288
BM=64 BN=32 BK=32 shm=8192 flops/tile=65536
BM=64 BN=32 BK=64 shm=16384 flops/tile=131072
BM=64 BN=32 BK=128 shm=32768 flops/tile=262144
BM=64 BN=64 BK=32 shm=12288 flops/tile=131072
BM=64 BN=64 BK=64 shm=24576 flops/tile=262144
BM=64 BN=64 BK=128 shm=49152 flops/tile=524288
BM=64 BN=128 BK=32 shm=20480 flops/tile=262144
BM=64 BN=128 BK=64 shm=40960 flops/tile=524288
BM=64 BN=128 BK=128 shm=81920 flops/tile=1048576
BM=128 BN=32 BK=32 shm=12288 flops/tile=131072
BM=128 BN=32 BK=64 shm=24576 flops/tile=262144
BM=128 BN=32 BK=128 shm=49152 flops/tile=524288
BM=128 BN=64 BK=32 shm=16384 flops/tile=262144
BM=128 BN=64 BK=64 shm=32768 flops/tile=524288
BM=128 BN=64 BK=128 shm=65536 flops/tile=1048576
BM=128 BN=128 BK=32 shm=24576 flops/tile=524288
BM=128 BN=128 BK=64 shm=49152 flops/tile=1048576
BM=128 BN=128 BK=128 shm=98304 flops/tile=2097152
assistant
in 53264out 219cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (134 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-Flattened 2D grid: (cdiv(T_perm, BM), cdiv(I, BN)) with expert lookup
-via binary search on offsets.
+Single kernel launch, flattened 2D grid (cdiv(T_perm, BM), cdiv(I, BN)).
+Each kernel block determines its owning expert by binary-searching offsets.
 """
 from __future__ import annotations
 
@@ -21,7 +21,6 @@
     T_perm,
     H,
     I_val,
-    E,
     BM: tl.constexpr,
     BN: tl.constexpr,
     BK: tl.constexpr,
@@ -29,67 +28,69 @@
     pid_m = tl.program_id(0)
     pid_n = tl.program_id(1)
 
-    # Find expert for the first row in this block
-    row_0 = pid_m * BM
-    # Binary search for expert: find smallest e such that Offsets[e+1] > row_0
+    # First row this block covers
+    row_start_block = pid_m * BM
+
+    # Binary search for expert: find smallest e where Offsets[e+1] > row_start_block
     lo = 0
-    hi = E - 1
+    hi = 127  # max E is 128
     while lo < hi:
-        mid = (lo + hi) // 2
+        mid = (lo + hi) >> 1
         off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
-        if off <= row_0:
+        if off <= row_start_block:
             lo = mid + 1
         else:
             hi = mid
     expert = lo
 
+    # Read expert boundaries
     row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
     row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
 
-    # M offset (intra-expert rows)
-    m_local = pid_m * BM + tl.arange(0, BM) - row_start
-    m_valid = (pid_m * BM + tl.arange(0, BM)) < row_end
+    # Local row offsets within this expert
+    m_local = pid_m * BM - row_start + tl.arange(0, BM)
+    m_valid = m_local >= 0 and m_local < (row_end - row_start)
+    row_valid = row_start_block + tl.arange(0, BM) < T_perm
 
-    # Global rows
     row_g = pid_m * BM + tl.arange(0, BM)
-    row_valid = row_g < T_perm & m_valid
 
     # N offset
     n_offs = pid_n * BN + tl.arange(0, BN)
     n_valid = n_offs < I_val
 
-    # K offset
+    # K
     k_offs = tl.arange(0, BK)
 
+    # Accumulators
     acc_g = tl.zeros((BM, BN), dtype=tl.float32)
     acc_u = tl.zeros((BM, BN), dtype=tl.float32)
 
-    # Weight base: W is (E, H, I), row major
-    # W[e, k, n] offset = e * H * I + k * I + n
-    w_base = expert * H * I_val
+    # Weight base: expert * H * I
+    w_base = expert.to(tl.int64) * H * I_val
 
+    # Main GEMM loop
     for k_start in range(0, H, BK):
-        k_mask = k_start + k_offs < H
         k = k_start + k_offs
+        k_mask = k < H
 
-        # A[row, k]
-        a_ptrs = X_ptr + row_g[:, None] * H + k[None, :]
+        # X[row, k]: stride is H (row-major contiguous)
+        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
         a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
 
-        # W_gate[e, k, n]
+        # W_gate[expert, k, n]: (H, I) slice
         wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
         wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        # W_up[e, k, n]
+        # W_up[expert, k, n]
         wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
         wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        acc_g = tl.dot(a, wg, acc_g, out_dtype=tl.float32)
-        acc_u = tl.dot(a, wu, acc_u, out_dtype=tl.float32)
+        acc_g += tl.dot(a, wg)
+        acc_u += tl.dot(a, wu)
 
     # SwiGLU
     out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
-    out_ptrs = Out_ptr + row_g[:, None] * I_val + n_offs[None, :]
+    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
     tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=row_valid[:, None] & n_valid[None, :])
 
 
@@ -97,17 +98,16 @@
     T_perm = X.size(0)
     H = X.size(1)
     I_val = out.size(1)
-    E_val = expert_offsets.size(0) - 1
 
-    BM = 256
-    BN = 128
+    BM = 128
+    BN = 64
     BK = 128
 
     grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
 
     _grouped_gemm_swiglu_kernel[grid](
         X, W_gate, W_up, out, expert_offsets,
-        T_perm, H, I_val, E_val,
+        T_perm, H, I_val,
         BM=BM, BN=BN, BK=BK,
     )
 
args
{
  "content": "<4167 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Now let me rewrite with proper block sizes and the correct approach - a single kernel with a flattened 2D grid using offset-to-expert mapping.
assistant
in 54358out 1535cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 180000,
  "description": "Run correctness check"
}
tool
stdout (3030 chars) — TRUNCATED
/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py:52: UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
  m_valid = m_local >= 0 and m_local < (row_end - row_start)
UserWarning: Enable tracemalloc to get the object allocation traceback
/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py:52: UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
  m_valid = m_local >= 0 and m_local < (row_end - row_start)
UserWarning: Enable tracemalloc to get the object allocation traceback
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_s

... (1530 more chars)
assistant
in 55912out 68cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (10 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -105,6 +105,7 @@
 
     grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
 
+    # num_stages=1 to stay within shared memory limit on SM120
     _grouped_gemm_swiglu_kernel[grid](
         X, W_gate, W_up, out, expert_offsets,
         T_perm, H, I_val,
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
  "oldString": "<100 chars \u2014 see below>",
  "newString": "<163 chars \u2014 see below>"
}
oldString (100 chars)
    grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))

    _grouped_gemm_swiglu_kernel[grid](
newString (163 chars)
    grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))

    # num_stages=1 to stay within shared memory limit on SM120
    _grouped_gemm_swiglu_kernel[grid](
tool
stdout (26 chars)
Edit applied successfully.
assistant
in 56918out 170cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (15 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -105,11 +105,11 @@
 
     grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
 
-    # num_stages=1 to stay within shared memory limit on SM120
     _grouped_gemm_swiglu_kernel[grid](
         X, W_gate, W_up, out, expert_offsets,
         T_perm, H, I_val,
         BM=BM, BN=BN, BK=BK,
+        num_stages=1, num_warps=8,
     )
 
 
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
  "newString": "<242 chars \u2014 see below>",
  "oldString": "<270 chars \u2014 see below>"
}
newString (242 chars)
    grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))

    _grouped_gemm_swiglu_kernel[grid](
        X, W_gate, W_up, out, expert_offsets,
        T_perm, H, I_val,
        BM=BM, BN=BN, BK=BK,
        num_stages=1, num_warps=8,
    )
oldString (270 chars)
    grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))

    # num_stages=1 to stay within shared memory limit on SM120
    _grouped_gemm_swiglu_kernel[grid](
        X, W_gate, W_up, out, expert_offsets,
        T_perm, H, I_val,
        BM=BM, BN=BN, BK=BK,
    )
tool
stdout (26 chars)
Edit applied successfully.
assistant
The issue is 128*128 + 128*64 + 128*64 = 65536 * 2 = 131072 with num_stages=2 (default). The table showed this fits with num_stages=1. Let me explicitly set `num_stages=1`:
assistant
in 57105out 345cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (11 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -49,7 +49,7 @@
 
     # Local row offsets within this expert
     m_local = pid_m * BM - row_start + tl.arange(0, BM)
-    m_valid = m_local >= 0 and m_local < (row_end - row_start)
+    m_valid = (m_local >= 0) & (m_local < (row_end - row_start))
     row_valid = row_start_block + tl.arange(0, BM) < T_perm
 
     row_g = pid_m * BM + tl.arange(0, BM)
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
  "newString": "    m_valid = (m_local >= 0) & (m_local < (row_end - row_start))",
  "oldString": "    m_valid = m_local >= 0 and m_local < (row_end - row_start)"
}
tool
stdout (26 chars)
Edit applied successfully.
assistant
I also need to fix the `and` operator issue:
assistant
in 57468out 134cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 180000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 57620out 68cache_r 0cache_w 0
assistant
Correctness passes. Now let me benchmark:
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (2308 chars) — TRUNCATED
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     start = int(expert_offsets[e].item())
W0429 08:40:01.257000 1038271 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0

... (808 more chars)
assistant
in 57703out 77cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (186 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-Single kernel launch, flattened 2D grid (cdiv(T_perm, BM), cdiv(I, BN)).
-Each kernel block determines its owning expert by binary-searching offsets.
+3D grid (E, cdiv(MAX_M, BM), cdiv(I, BN)) - no binary search needed.
+Blocks whose range falls outside the expert's row range are no-ops.
 """
 from __future__ import annotations
 
@@ -18,71 +18,68 @@
     W_up_ptr,
     Out_ptr,
     Offsets_ptr,
-    T_perm,
+    MAX_M,
     H,
     I_val,
+    stride_x_row,
+    stride_o_row,
     BM: tl.constexpr,
     BN: tl.constexpr,
     BK: tl.constexpr,
 ):
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
+    pid_e = tl.program_id(0)
+    pid_bm = tl.program_id(1)
+    pid_n = tl.program_id(2)
 
-    # First row this block covers
-    row_start_block = pid_m * BM
+    # Expert row bounds
+    row_start = tl.load(Offsets_ptr + pid_e).to(tl.int64)
+    row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+    n_e = row_end - row_start
 
-    # Binary search for expert: find smallest e where Offsets[e+1] > row_start_block
-    lo = 0
-    hi = 127  # max E is 128
-    while lo < hi:
-        mid = (lo + hi) >> 1
-        off = tl.load(Offsets_ptr + mid + 1).to(tl.int64)
-        if off <= row_start_block:
-            lo = mid + 1
-        else:
-            hi = mid
-    expert = lo
+    # First row of this block
+    block_first = pid_bm * BM
 
-    # Read expert boundaries
-    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
-    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+    # If this block is entirely outside the expert's range, skip
+    if block_first >= n_e:
+        return
 
-    # Local row offsets within this expert
-    m_local = pid_m * BM - row_start + tl.arange(0, BM)
-    m_valid = (m_local >= 0) & (m_local < (row_end - row_start))
-    row_valid = row_start_block + tl.arange(0, BM) < T_perm
+    # Local row offsets within expert
+    m_local = block_first + tl.arange(0, BM)
+    m_valid = m_local < n_e
 
-    row_g = pid_m * BM + tl.arange(0, BM)
+    # Global row indices
+    row_g = row_start + m_local
 
-    # N offset
+    # Column offsets
     n_offs = pid_n * BN + tl.arange(0, BN)
     n_valid = n_offs < I_val
 
-    # K
+    # Reduction dimension offsets
     k_offs = tl.arange(0, BK)
 
     # Accumulators
     acc_g = tl.zeros((BM, BN), dtype=tl.float32)
     acc_u = tl.zeros((BM, BN), dtype=tl.float32)
 
-    # Weight base: expert * H * I
-    w_base = expert.to(tl.int64) * H * I_val
+    # Weight base for this expert: W is (E, H, I), row-major
+    wg_base = W_gate_ptr + pid_e * H * I_val
+    wu_base = W_up_ptr + pid_e * H * I_val
 
-    # Main GEMM loop
+    # Main GEMM loop over K=H dimension
     for k_start in range(0, H, BK):
         k = k_start + k_offs
         k_mask = k < H
 
-        # X[row, k]: stride is H (row-major contiguous)
-        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
-        a = tl.load(a_ptrs, mask=row_valid[:, None] & k_mask[None, :], other=0.0)
+        # Load A: (BM, BK)
+        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * stride_x_row + k[None, :]
+        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
 
-        # W_gate[expert, k, n]: (H, I) slice
-        wg_ptrs = W_gate_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
+        # Load W_gate[e, k, n]: (BK, BN)
+        wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
         wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
-        # W_up[expert, k, n]
-        wu_ptrs = W_up_ptr + w_base + k[:, None] * I_val + n_offs[None, :]
+        # Load W_up[e, k, n]: (BK, BN)
+        wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
         wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
 
         acc_g += tl.dot(a, wg)
@@ -90,26 +87,52 @@
 
     # SwiGLU
     out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
-    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
-    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=row_valid[:, None] & n_valid[None, :])
+    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
+    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
 
 
 def _launch(X, W_gate, W_up, out, expert_offsets):
     T_perm = X.size(0)
     H = X.size(1)
     I_val = out.size(1)
+    E = expert_offsets.size(0) - 1
 
-    BM = 128
-    BN = 64
-    BK = 128
+    # Heuristic: pick block sizes that maximize FLOPS/tile while fitting in shared memory
+    # Shared memory: (BM * BK + 2 * BK * BN) * 2 bytes (bf16, num_stages=1)
+    # For BM=256, BN=256, BK=64: (256*64 + 2*64*256)*2 = (16384+32768)*2 = 98304 < 101k OK
+    # For BM=256, BN=128, BK=128: (256*128 + 2*128*128)*2 = (32768+32768)*2 = 131072 - too big
+    # For BM=128, BN=128, BK=128: (128*128 + 2*128*128)*2 = (16384+32768)*2 = 98304 < 101k OK
+    if H >= 2048:
+        BM = 128
+        BN = 128
+        BK = 128
+    else:
+        BM = 64
+        BN = 128
+        BK = 64
 
-    grid = (triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
+    # For SM120, try larger blocks when possible
+    # 128x256x64: (128*64 + 2*64*256)*2 = (8192+32768)*2 = 81920 - OK
+    # 256x128x64: (256*64 + 2*64*128)*2 = (16384+16384)*2 = 65536 - OK
+    if I_val >= 1024 and H >= 4096:
+        BM = 256
+        BN = 128
+        BK = 64
+    elif I_val >= 4096:
+        BM = 128
+        BN = 128
+        BK = 128
+
+    grid = (E, triton.cdiv(T_perm, BM), triton.cdiv(I_val, BN))
 
     _grouped_gemm_swiglu_kernel[grid](
         X, W_gate, W_up, out, expert_offsets,
         T_perm, H, I_val,
+        stride_x_row=X.stride(0),
+        stride_o_row=out.stride(0),
 

... (495 more chars)
args
{
  "content": "<5072 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 58938out 2015cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 180000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 60971out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (2308 chars) — TRUNCATED
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     start = int(expert_offsets[e].item())
W0429 08:45:05.200000 1041350 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0

... (808 more chars)
assistant
in 61054out 67cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (237 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,7 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-3D grid (E, cdiv(MAX_M, BM), cdiv(I, BN)) - no binary search needed.
-Blocks whose range falls outside the expert's row range are no-ops.
+3D grid (E, cdiv(n_tokens_for_e, BM), cdiv(I, BN)) - grid size is
+computed from actual expert token counts, no wasted blocks.
 """
 from __future__ import annotations
 
@@ -18,7 +18,6 @@
     W_up_ptr,
     Out_ptr,
     Offsets_ptr,
-    MAX_M,
     H,
     I_val,
     stride_x_row,
@@ -36,15 +35,8 @@
     row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
     n_e = row_end - row_start
 
-    # First row of this block
-    block_first = pid_bm * BM
-
-    # If this block is entirely outside the expert's range, skip
-    if block_first >= n_e:
-        return
-
     # Local row offsets within expert
-    m_local = block_first + tl.arange(0, BM)
+    m_local = pid_bm * BM + tl.arange(0, BM)
     m_valid = m_local < n_e
 
     # Global row indices
@@ -88,6 +80,129 @@
     # SwiGLU
     out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
     out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
+    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
+
+
+def _expert_grid(expert_offsets, BM, BN, I):
+    """Build 3D grid from expert_offsets: (E_list, M_list, N_list) where
+    E_list is just [0,1,...,E-1] repeated, and M_list/N_list are per-expert tile counts."""
+    E = len(expert_offsets) - 1
+    counts = expert_offsets[1:] - expert_offsets[:-1]  # tokens per expert
+    m_tiles = (counts + BM - 1) // BM  # cdiv per expert
+    n_tiles = (I + BN - 1) // BN
+
+    # Create grid as tensor of shape (total_work_items, 3)
+    # Actually, Triton doesn't support tensor grids. We need to use a flat list approach.
+    # Instead, let's use a 1D grid and decode (e, m, n) from the flat index.
+    total_work = int(m_tiles.sum()) * n_tiles
+    return E, int(m_tiles.sum()), n_tiles, m_tiles, total_work
+
+
+@triton.jit
+def _grouped_gemm_1d_kernel(
+    X_ptr,
+    W_gate_ptr,
+    W_up_ptr,
+    Out_ptr,
+    Offsets_ptr,
+    MTilesPrefix_ptr,  # cumulative sum of m_tiles: (E+1,) int32, prefix sum of m_tiles
+    H,
+    I_val,
+    total_m_tiles,
+    BN: tl.constexpr,
+    stride_x_row,
+    stride_o_row,
+    BM: tl.constexpr,
+    BLOCK_BN: tl.constexpr,
+    BK: tl.constexpr,
+):
+    pid_work = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    pid_bm_in_expert = pid_work  # each pid_work handles one M-tile for one expert
+    # But we need to know WHICH expert AND which M-tile.
+    # Actually, let's restructure: grid = (total_m_tiles, cdiv(I, BN))
+    # Then for each pid_m in total_m_tiles, find which expert it belongs to.
+    pass
+
+
+# Simplest approach: 2D flattened grid (total_tiles_M, cdiv(I, BN))
+# where total_tiles_M = sum_e cdiv(n_tokens_e, BM)
+# Use binary search on cumulative_M_tiles to find expert index.
+
+
+@triton.jit
+def _grouped_gemm_flat_kernel(
+    X_ptr,
+    W_gate_ptr,
+    W_up_ptr,
+    Out_ptr,
+    Offsets_ptr,
+    CumMTiles_ptr,  # (E+1,) int32, cumulative M-tiles: CumMTiles[e] = sum of cdiv(n_tokens_e', BM) for e' < e
+    H,
+    I_val,
+    n_experts,
+    BM: tl.constexpr,
+    BN: tl.constexpr,
+    BK: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    # Find expert for this M-tile: largest e such that CumMTiles[e] <= pid_m
+    lo = 0
+    hi = n_experts - 1
+    while lo < hi:
+        mid = (lo + hi + 1) >> 1
+        val = tl.load(CumMTiles_ptr + mid).to(tl.int32)
+        if val <= pid_m:
+            lo = mid
+        else:
+            hi = mid - 1
+    expert = lo
+
+    # M-tile index within this expert
+    tile_off = tl.load(CumMTiles_ptr + expert).to(tl.int32)
+    pid_bm = pid_m - tile_off
+
+    # Expert row bounds
+    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
+    n_e = row_end - row_start
+
+    m_local = pid_bm * BM + tl.arange(0, BM)
+    m_valid = m_local < n_e
+
+    row_g = row_start + m_local
+
+    n_offs = pid_n * BN + tl.arange(0, BN)
+    n_valid = n_offs < I_val
+
+    k_offs = tl.arange(0, BK)
+
+    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
+    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
+
+    wg_base = W_gate_ptr + expert * H * I_val
+    wu_base = W_up_ptr + expert * H * I_val
+
+    for k_start in range(0, H, BK):
+        k = k_start + k_offs
+        k_mask = k < H
+
+        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * H + k[None, :]
+        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
+
+        wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
+        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
+
+        wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
+        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
+
+        acc_g += tl.dot(a, wg)
+        acc_u += tl.dot(a, wu)
+
+    out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
+    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * I_val + n_offs[None, :]
     tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
 
 
@@ -96,40 +211,42 @@
     H = X.size(1)
     I_val = out.size(1)
     E = expert_offsets.size(0) - 1
-
-    # Heuristic: pick block sizes that maximize FLOPS/tile while fitting in shared memory
-    # Shared memory: (BM * BK + 2 * BK * BN) * 2 bytes (bf16, num_stages=1)
-    # For BM=256, BN=256, BK=64: (256*64 + 2*64*256)*2 = (16384+32768)*2 = 98304 < 101k OK
-    # For BM=256, BN=128, BK=128: (256*128 + 2*128*128)*2 = (32768+32768)*2 = 131072 - too big
-    # For BM=128, BN=128, BK=128: (128*128 + 2*128*128)*2 = (16384+32768)*2 = 98304 < 101k OK
-    if H >= 

... (1858 more chars)
args
{
  "content": "<8717 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The grid is way too sparse - cdiv(T_perm, BM) = 1024 blocks per expert, but most are empty. Need to build per-expert tile grid from actual counts.
assistant
in 62278out 3197cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
diff (305 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,7 +1,9 @@
 """Triton grouped GEMM + fused SwiGLU for Sonic-MoE up-projection.
 
-3D grid (E, cdiv(n_tokens_for_e, BM), cdiv(I, BN)) - grid size is
-computed from actual expert token counts, no wasted blocks.
+2D grid: (sum_e cdiv(n_e, BM), cdiv(I, BN)) with per-block expert lookup
+via binary search on cumulative tile counts.
+
+Block sizes tuned for SM120 shared memory budget (~98KB usable).
 """
 from __future__ import annotations
 
@@ -13,177 +15,61 @@
 
 @triton.jit
 def _grouped_gemm_swiglu_kernel(
-    X_ptr,
-    W_gate_ptr,
-    W_up_ptr,
-    Out_ptr,
-    Offsets_ptr,
+    X_ptr,             # (T_perm, H), row-major
+    W_gate_ptr,        # (E, H, I), row-major
+    W_up_ptr,          # (E, H, I), row-major
+    Out_ptr,           # (T_perm, I), row-major
+    Offsets_ptr,       # (E+1,) int32
+    CumMTiles_ptr,     # (E+1,) int32
     H,
     I_val,
-    stride_x_row,
-    stride_o_row,
     BM: tl.constexpr,
     BN: tl.constexpr,
     BK: tl.constexpr,
+    NUM_SMS: tl.constexpr,
 ):
-    pid_e = tl.program_id(0)
-    pid_bm = tl.program_id(1)
-    pid_n = tl.program_id(2)
+    pid = tl.program_id(0)
+    pid_n = tl.program_id(1)
 
-    # Expert row bounds
-    row_start = tl.load(Offsets_ptr + pid_e).to(tl.int64)
-    row_end = tl.load(Offsets_ptr + pid_e + 1).to(tl.int64)
+    # 1D -> (pid_m, ...) mapping
+    pid_m = pid
+
+    # Find expert: largest e such that CumMTiles[e] <= pid_m
+    # E is at most 128; binary search
+    lo = 0
+    hi = 127
+    while lo < hi:
+        mid = (lo + hi + 1) >> 1
+        if tl.load(CumMTiles_ptr + mid).to(tl.int32) <= pid_m:
+            lo = mid
+        else:
+            hi = mid - 1
+    expert = lo
+
+    # Intra-expert M-tile index
+    m_tile_idx = pid_m - tl.load(CumMTiles_ptr + expert).to(tl.int32)
+
+    # Expert row range
+    row_start = tl.load(Offsets_ptr + expert).to(tl.int64)
+    row_end = tl.load(Offsets_ptr + expert + 1).to(tl.int64)
     n_e = row_end - row_start
 
-    # Local row offsets within expert
-    m_local = pid_bm * BM + tl.arange(0, BM)
-    m_valid = m_local < n_e
-
-    # Global row indices
-    row_g = row_start + m_local
+    # Row offsets
+    m_offs = m_tile_idx * BM + tl.arange(0, BM)
+    m_valid = m_offs < n_e
+    row_g = row_start + m_offs
 
     # Column offsets
     n_offs = pid_n * BN + tl.arange(0, BN)
     n_valid = n_offs < I_val
 
-    # Reduction dimension offsets
-    k_offs = tl.arange(0, BK)
-
-    # Accumulators
-    acc_g = tl.zeros((BM, BN), dtype=tl.float32)
-    acc_u = tl.zeros((BM, BN), dtype=tl.float32)
-
-    # Weight base for this expert: W is (E, H, I), row-major
-    wg_base = W_gate_ptr + pid_e * H * I_val
-    wu_base = W_up_ptr + pid_e * H * I_val
-
-    # Main GEMM loop over K=H dimension
-    for k_start in range(0, H, BK):
-        k = k_start + k_offs
-        k_mask = k < H
-
-        # Load A: (BM, BK)
-        a_ptrs = X_ptr + row_g[:, None].to(tl.int64) * stride_x_row + k[None, :]
-        a = tl.load(a_ptrs, mask=m_valid[:, None] & k_mask[None, :], other=0.0)
-
-        # Load W_gate[e, k, n]: (BK, BN)
-        wg_ptrs = wg_base + k[:, None] * I_val + n_offs[None, :]
-        wg = tl.load(wg_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
-
-        # Load W_up[e, k, n]: (BK, BN)
-        wu_ptrs = wu_base + k[:, None] * I_val + n_offs[None, :]
-        wu = tl.load(wu_ptrs, mask=k_mask[:, None] & n_valid[None, :], other=0.0)
-
-        acc_g += tl.dot(a, wg)
-        acc_u += tl.dot(a, wu)
-
-    # SwiGLU
-    out_fp = acc_g * tl.sigmoid(acc_g) * acc_u
-    out_ptrs = Out_ptr + row_g[:, None].to(tl.int64) * stride_o_row + n_offs[None, :]
-    tl.store(out_ptrs, out_fp.to(tl.bfloat16), mask=m_valid[:, None] & n_valid[None, :])
-
-
-def _expert_grid(expert_offsets, BM, BN, I):
-    """Build 3D grid from expert_offsets: (E_list, M_list, N_list) where
-    E_list is just [0,1,...,E-1] repeated, and M_list/N_list are per-expert tile counts."""
-    E = len(expert_offsets) - 1
-    counts = expert_offsets[1:] - expert_offsets[:-1]  # tokens per expert
-    m_tiles = (counts + BM - 1) // BM  # cdiv per expert
-    n_tiles = (I + BN - 1) // BN
-
-    # Create grid as tensor of shape (total_work_items, 3)
-    # Actually, Triton doesn't support tensor grids. We need to use a flat list approach.
-    # Instead, let's use a 1D grid and decode (e, m, n) from the flat index.
-    total_work = int(m_tiles.sum()) * n_tiles
-    return E, int(m_tiles.sum()), n_tiles, m_tiles, total_work
-
-
-@triton.jit
-def _grouped_gemm_1d_kernel(
-    X_ptr,
-    W_gate_ptr,
-    W_up_ptr,
-    Out_ptr,
-    Offsets_ptr,
-    MTilesPrefix_ptr,  # cumulative sum of m_tiles: (E+1,) int32, prefix sum of m_tiles
-    H,
-    I_val,
-    total_m_tiles,
-    BN: tl.constexpr,
-    stride_x_row,
-    stride_o_row,
-    BM: tl.constexpr,
-    BLOCK_BN: tl.constexpr,
-    BK: tl.constexpr,
-):
-    pid_work = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    pid_bm_in_expert = pid_work  # each pid_work handles one M-tile for one expert
-    # But we need to know WHICH expert AND which M-tile.
-    # Actually, let's restructure: grid = (total_m_tiles, cdiv(I, BN))
-    # Then for each pid_m in total_m_tiles, find which expert it belongs to.
-    pass
-
-
-# Simplest approach: 2D flattened grid (total_tiles_M, cdiv(I, BN))
-# where total_tiles_M = sum_e cdiv(n_tokens_e, BM)
-# Use binary search on cumulative_M_tiles to find expert index.
-
-
-@triton.jit
-def _grouped_gemm_flat_kernel(
-    X_ptr,
-    W_gate_ptr,
-    W_up_ptr,
-    Out_ptr,
-    Offsets_ptr,
-    CumMTiles_ptr,  # (E+1,) int32, cumulative M-tiles: CumMTiles[e] = sum of cdiv(n_tokens_e', BM) for e' < e
-    H,
-    I_val,
-    n_experts,
-    BM: tl.constexpr,
-    BN: tl.constexpr,
-    BK: tl.constexpr,
-):
-    pid_m = tl.program_id(0)
-  

... (3921 more chars)
args
{
  "content": "<4949 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me reconsider the approach. The binary search in Triton adds latency. Let me think about the actual compute bound vs memory bound aspects... Actually for H100/SM120 with 1.8TB/s and H=4096, I=1536: - FLOPs: 2 * T_total * K * 2 * H * I = 2 * 32768 * 8 * 2 * 4096 * 1536 = ~1.32 * 10^12 FMA ops - Bytes read X: T_perm * H * 2 = 262144 * 4096 * 2 = ~2.1 GB - Bytes read W: E * 2 * H * I * 2 = 128 * 2 * 4096 * 1536 * 2 = ~3.2 GB - Bytes write: T_perm * I * 2 = 262144 * 1536 * 2 = ~0.8 GB - Total bytes: ~6.1 GB So the peak is roughly 220 TFLOPs and 1800 GB/s. At compute bound, minimum time is 1.32e12 / 200e12 / 2 = ~3.3ms. At memory bound, 6.1e9 / 1800e9 = ~3.4ms. It's actually right at the bandwidth/compute balance. The issue with the current code is that loading X from global memory is expensive for large H. We need to load X row tiles into shared memory and compute everything at once, OR use Triton's automatic tiling more effectively. Let me step back and think about what's efficient here. Each token-tile (say 256 tokens) touches: - 256*4096 bf16 = 2MB of X input (streaming read, could use L2 cache) - 128 * 4096 * 1536 bf16 = ~1.56GB of weights The weights are the dominant bandwidth. Since weights are 128 experts * 2 matrix types * H * I, that's 3.2GB. If streamed through L2 and GEMM is compute bound, we want to maximize compute density by making BM large. Let me rewrite with a cleaner approach and then benchmark.
assistant
in 65493out 2434cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 180000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (2308 chars) — TRUNCATED
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] 
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]   File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]     start = int(expert_offsets[e].item())
W0429 08:49:30.800000 1044425 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0

... (808 more chars)
assistant
in 67946out 132cache_r 0cache_w 0
assistant
bash
args
{
  "command": "<1027 chars \u2014 see below>",
  "timeout": 60000,
  "description": "Debug timing issue"
}
command (1027 chars)
uv run python -c "
import torch
import sys
sys.path.insert(0, '.')
import solution
import reference
from shapes import SHAPES
import time

shape = SHAPES[0]
reference.T_total = shape['T_total']
reference.H = shape['H']
reference.I = shape['I']
reference.E = shape['E']
reference.K = shape['K']

init = reference.get_init_inputs()
model = solution.Model(*init).cuda().eval()
model.W_gate.data = torch.randn_like(model.W_gate)
model.W_up.data = torch.randn_like(model.W_up)

torch.manual_seed(42)
inputs = [t.cuda() for t in reference.get_inputs()]

# Single forward
torch.cuda.synchronize()
start = time.time()
out = model(*inputs)
torch.cuda.synchronize()
print(f'Single forward: {(time.time()-start)*1000:.1f} ms')
print(f'Output shape: {out.shape}')
print(f'Input T_perm: {inputs[0].shape}')

# Debug grid
HM = shape['H']
IM = shape['I']
BM = 256
BN = 128
counts = inputs[1][1:] - inputs[1][:-1]
m_tiles = (counts + BM - 1) // BM
print(f'Total M tiles: {m_tiles.sum().item()}')
print(f'N tiles: {(IM + BN - 1) // BN}')
" 2>&1
tool
stdout (140 chars)
Single forward: 1920.5 ms
Output shape: torch.Size([262144, 1536])
Input T_perm: torch.Size([262144, 4096])
Total M tiles: 1024
N tiles: 12
assistant
in 69241out 395cache_r 0cache_w 0