INCOMPLETE SESSION: the run hit its wall-clock budget and was terminated (SIGTERM). The transcript below is usable but may be missing the agent's final tool calls or summary. Don't score this run as a clean failure or a clean success.
model
custom:GLM-5.1-[Z.AI-Coding-Plan]-0
solution.py
benchmark.log
check.log
result.json
"""Optimized Triton GEMM + fused SwiGLU kernel for MoE up-projection.
Uses one kernel launch per expert with a fixed 64x64x64 tile (not tuned per expert's M dimension).
Fuses both GEMMs (gate and up) in a single kernel to maximize data reuse.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Module-level metadata describing this solution.
# NOTE(review): presumably read by the benchmark harness — confirm against its loader.
OP_TYPE = "grouped_gemm_swiglu"
# Only bf16 is listed; the weights and the output in this file are all bf16.
SUPPORTED_PRECISIONS = ["bf16"]
# GPU names this solution declares itself for (how they are matched is not shown here).
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
@triton.jit
def _gemm_swiglu_kernel(
    x_ptr, wg_ptr, wu_ptr, out_ptr,
    M,
    N: tl.constexpr,
    K_dim: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Single-expert fused GEMM + SwiGLU: out = silu(x @ W_gate) * (x @ W_up).

    Every buffer is addressed as a dense row-major matrix:
    x is (M, K_dim), W_gate and W_up are (K_dim, N), out is (M, N).
    Launch on a 2-D grid of (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)) programs;
    each program produces one BLOCK_M x BLOCK_N tile of ``out``.

    ``M`` is a runtime argument on purpose. It is the per-expert row count and
    changes from expert to expert; declaring it ``tl.constexpr`` would make
    Triton compile a separate kernel for every distinct value. ``N`` and
    ``K_dim`` stay compile-time constants because they are fixed per model.

    NOTE(review): element offsets are computed with the default integer width
    of ``tl.arange`` (int32 per the Triton docs), so this assumes M * K_dim and
    M * N stay below 2**31 — confirm if very large experts are possible.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Row / column indices covered by this program's output tile.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Edge tiles may run past M or N; masked lanes load 0 and are never stored.
    mask_m = rm < M
    mask_n = rn < N

    # Both projections accumulate in fp32 regardless of the input dtype.
    acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for ki in range(tl.cdiv(K_dim, BLOCK_K)):
        rk = ki * BLOCK_K + tl.arange(0, BLOCK_K)
        mask_k = rk < K_dim

        # The x tile is loaded once and reused for both the gate and up products.
        x_ptrs = x_ptr + rm[:, None] * K_dim + rk[None, :]
        x_tile = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)

        w_mask = mask_k[:, None] & mask_n[None, :]
        w_offsets = rk[:, None] * N + rn[None, :]
        wg_tile = tl.load(wg_ptr + w_offsets, mask=w_mask, other=0.0)
        wu_tile = tl.load(wu_ptr + w_offsets, mask=w_mask, other=0.0)

        acc_gate += tl.dot(x_tile, wg_tile)
        acc_up += tl.dot(x_tile, wu_tile)

    # SwiGLU epilogue in fp32: silu(gate) * up, with silu(g) = g * sigmoid(g).
    neg_gate = -acc_gate
    exp_neg = tl.exp(neg_gate)
    sigmoid_val = 1.0 / (1.0 + exp_neg)
    result = acc_gate * sigmoid_val * acc_up

    # Round to bf16 only once, at the final store.
    result_bf16 = result.to(tl.bfloat16)
    out_ptrs = out_ptr + rm[:, None] * N + rn[None, :]
    tl.store(out_ptrs, result_bf16, mask=mask_m[:, None] & mask_n[None, :])
class Model(nn.Module):
    """Top-K MoE FFN up-projection with SwiGLU fused into the GEMM.

    Holds one gate and one up weight matrix per expert, each stored as a
    bf16 tensor of shape (E, H, I) and initialised from N(0, 0.02**2).
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):  # noqa: E741
        super().__init__()
        self.T_total, self.H, self.I, self.E, self.K = T_total, H, I, E, K

        weight_shape = (E, H, I)
        self.W_gate = nn.Parameter(torch.empty(weight_shape, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(weight_shape, dtype=torch.bfloat16))
        # Gate first, then up: keeps the RNG draw order fixed.
        for weight in (self.W_gate, self.W_up):
            nn.init.normal_(weight, std=0.02)

    def forward(
        self,
        hidden_states: torch.Tensor,
        expert_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """Compute silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) for every expert.

        Args:
            hidden_states: activations already permuted into expert order;
                the first dimension is the row count, and rows are read with
                a row stride of ``self.H``.
            expert_offsets: at least ``E + 1`` boundaries; expert ``e`` owns
                rows ``[expert_offsets[e], expert_offsets[e + 1])``.

        Returns:
            A bf16 tensor of shape (rows, I) on the device of
            ``hidden_states``. Rows not covered by any expert range are left
            uninitialised because the buffer comes from ``torch.empty``.
        """
        # One fixed tile shape is used for every expert launch.
        tile_m = tile_n = tile_k = 64

        hidden = self.H
        inter = self.I  # noqa: E741
        n_rows = hidden_states.shape[0]

        result = torch.empty(
            n_rows, inter, dtype=torch.bfloat16, device=hidden_states.device
        )
        tokens = hidden_states.contiguous()
        gate_w = self.W_gate.contiguous()
        up_w = self.W_up.contiguous()

        # Boundaries are needed on the host to size each launch grid.
        bounds = expert_offsets.cpu().numpy()
        n_col_tiles = triton.cdiv(inter, tile_n)

        for expert in range(self.E):
            lo = int(bounds[expert])
            hi = int(bounds[expert + 1])
            rows = hi - lo
            if rows != 0:
                # Experts with no routed tokens get no launch at all.
                launch_grid = (triton.cdiv(rows, tile_m), n_col_tiles)
                _gemm_swiglu_kernel[launch_grid](
                    tokens[lo:hi],
                    gate_w[expert],
                    up_w[expert],
                    result[lo:hi],
                    M=rows, N=inter, K_dim=hidden,
                    BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
                )
        return result
# Default problem size, read by get_inputs() and get_init_inputs() below.
# Token count before top-K routing; get_inputs() builds T_total * K rows.
T_total = 32768
# Hidden size: number of columns of hidden_states.
H = 4096
# Intermediate size: number of output columns per expert.
I = 1536 # noqa: E741
# Number of experts.
E = 128
# Multiplier applied to T_total to get the permuted row count (top-K per the task statement).
K = 8
def get_inputs():
    """Build the forward() inputs for the module-level problem size.

    Returns a two-element list:
      * bf16 activations of shape (T_total * K, H), drawn from a standard
        normal and scaled by 0.1;
      * int32 offsets of length E + 1, the prefix sums of an as-even-as-possible
        split of the rows over the E experts (the first ``remainder`` experts
        get one extra row).
    """
    n_rows = T_total * K
    activations = torch.randn(n_rows, H, dtype=torch.bfloat16) * 0.1

    per_expert, extra = divmod(n_rows, E)
    counts = torch.full((E,), per_expert, dtype=torch.int32)
    counts[:extra] += 1

    # The leading zero stays in place; the remaining entries are running totals.
    offsets = torch.zeros(E + 1, dtype=torch.int32)
    offsets[1:] = counts.cumsum(dim=0)
    return [activations, offsets]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    The order is [T_total, H, I, E, K], taken from the module-level constants.
    """
    init_args = [T_total, H, I, E, K]
    return init_args
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break from `Tensor.item()`, consider setting:
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/tmp/KernelBench-Hard-zai-droid/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0508 19:25:47.344000 4166018 /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=29.088 gbps=217.781 ms=28.350
shape=0 variant=compiled tflops=29.149 gbps=218.235 ms=28.291
shape=0 variant=solution tflops=30.198 gbps=226.093 ms=27.307
shape=0 solution_peak_fraction=0.1510
shape=1 variant=eager tflops=12.479 gbps=231.544 ms=2.753
shape=1 variant=compiled tflops=12.522 gbps=232.340 ms=2.744
shape=1 variant=solution tflops=25.850 gbps=479.643 ms=1.329
shape=1 solution_peak_fraction=0.1293
shape=2 variant=eager tflops=33.209 gbps=227.015 ms=16.554
shape=2 variant=compiled tflops=33.153 gbps=226.630 ms=16.583
shape=2 variant=solution tflops=33.871 gbps=231.542 ms=16.231
shape=2 solution_peak_fraction=0.1694
peak_fraction: 0.1490
RESULT: OK
{
"problem": "06_sonic_moe_swiglu",
"harness": "droid",
"model": "custom:GLM-5.1-[Z.AI-Coding-Plan]-0",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1490,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 58069, "output_tokens": 44437, "cache_read_tokens": 2677056, "cache_creation_tokens": 0, "reasoning_tokens": 0, "total_cost_usd": null}
}
timeline (34 events)
system — init
session start cwd=/tmp/KernelBench-Hard-zai-droid/problems/06_sonic_moe_swiglu model=custom:GLM-5.1-[Z.AI-Coding-Plan]-0 effort=none tools=31
user
I need you to write the up-projection of a top-K MoE FFN with grouped GEMM and fused SwiGLU for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
The op is: per expert e, compute h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) where x_e is the slice of permuted hidden states routed to expert e. Inputs are hidden_states of shape (T_perm, H) bf16 (already permuted into expert order, T_perm = T_total * K), expert_offsets of shape (E+1,) int32 (prefix sums of token counts per expert, so expert e owns rows [offsets[e]:offsets[e+1]]), and per-expert weights W_gate and W_up each of shape (E, H, I) bf16. You return gated_up of shape (T_perm, I) bf16. Correctness tolerance is 0.02 abs/rel on bf16. The shapes you have to handle are T_total=32768 H=4096 I=1536 E=128 K=8 (the headline sonic-moe configuration, 32K tokens, 128 experts, top-8), T_total=4096 H=2048 I=1024 E=64 K=4 (smaller / faster iteration), and T_total=16384 H=2048 I=4096 E=64 K=8 (intermediate-heavy aspect ratio that forces larger N tiles).
This needs to be a real custom kernel — the whole point is to implement the variable-length grouped GEMM and fuse SwiGLU into the epilogue, not to dispatch to a vendor library. Don't import or call torch.matmul, torch.bmm, torch.nn.functional.linear, F.linear, or anything from sonic_moe. Try CUTLASS grouped GEMM (the intended path), Triton with a grouped-GEMM autotune, CUDA C++ via load_inline, or inline PTX — your call. Anything you're uncertain about, look up PTX docs, clone CUTLASS or sonic-moe or quack-kernels or other reference repos, read library headers, and investigate.
Your flywheel is implement, profile (ncu, nsys, torch.profiler — whatever's useful) and time it with benchmark.py, verify correctness by running `python check.py` and reading the output, then iterate. Don't substitute your own one-off correctness snippets for check.py — it iterates over every shape, your spot-check almost certainly won't. If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up.
assistant
Let me start by understanding the environment and checking what tools are available.
assistant
Good, Triton 3.6.0 with `tl.dot` supporting bf16 is available. Let me plan my approach:
**Implementation Plan: Triton grouped GEMM kernel with fused SwiGLU**
The idea:
1. Each Triton program handles one expert's GEMM (gate and up projections) plus the fused SwiGLU
2. Use `tl.dot` for the matrix multiplications with bf16 inputs
3. Autotune over tile sizes for best performance
4. The kernel takes variable-length expert slices via expert_offsets
Let me implement this step by step. First, let me write a clean Triton implementation:
assistant
Let me first test this basic version:
assistant
The issue is that autotune defines BLOCK_M, BLOCK_N, BLOCK_K and I'm also passing them as kwargs. I need to remove the explicit kwargs and let autotune handle them. But the problem is that autotune expects these as constexpr, and I need them in the grid calculation. Let me restructure - I'll not use autotune initially and instead pick good configs manually, then optimize.
assistant
The kernel runs but the results are incorrect. The error is 0.19, which is way too high. Let me debug this - likely a stride issue with the weight matrices.
Let me check how W_gate is stored and what strides are:
assistant
OK, so the strides are correct: W_gate[e, k, n] = W_gate.data_ptr() + e * 6291456 + k * 1536 + n * 1. In my kernel, I access W_gate_e + rk[:, None] * N + rn[None, :] where N = I = 1536. This means ptr + k * 1536 + n, which is correct.
Let me write a minimal test to compare:
assistant
Let me test my kernel directly on this small example:
assistant
Good, the elementwise approach matches. Now let me write a small Triton kernel to test:
assistant
Right, Triton JIT functions need to be in a file. Let me write a test file:
assistant
The per-expert kernel works correctly on the small test. The problem with my solution must be in the grouped kernel where I pass strides incorrectly. Let me now test the grouped kernel on the same small example:
assistant
Interesting - expert 0 is correct but expert 1 (rows 4-9) has larger errors. The issue is likely in the grid: with max_m=6 and BLOCK_M=4, we launch 2 M-tiles per expert. For expert 1, pid_m=0 covers rows 4-7 and pid_m=1 covers rows 8-11. Rows 10 and 11 are beyond the expert's range (end=10, so it owns only rows 4-9), so they should be masked out. The mask `rm < end` handles this correctly.
Wait, the diff is 0.002 which is still pretty small. Let me check with more tokens and more experts with larger values:
assistant
Huge error! The issue is the `stride_xk` parameter. In the kernel, I'm using `stride_xk` as if it's the row stride (the stride between rows), but I named it `stride_xk` which is supposed to be the stride along dimension K. Let me check:
x is (T_perm, H) contiguous, so x.stride() = (H, 1) = (512, 1). In the kernel I use `x_ptr + rm[:, None] * stride_xk + rk[None, :]` where stride_xk = H = 512. So `x_ptr + row * 512 + col`, which is x[row, col]. That's correct for the x layout.
Wait, the error might be the SwiGLU implementation. Let me check: the problem says `silu(x_e @ W_gate[e]) * (x_e @ W_up[e])`. My reference computes `F.silu(gate) * up`. Let me verify my SwiGLU formula:
`silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`
My kernel computes:
```
neg_gate = -acc_gate
exp_neg = tl.exp(neg_gate)
sigmoid_gate = 1.0 / (1.0 + exp_neg)
result = acc_gate * sigmoid_gate * acc_up
```
This is `gate * sigmoid(gate) * up = silu(gate) * up`. That should be correct...
Let me check: maybe the issue is that when values are large, exp overflows. With random normal weights (std=1, not 0.02), the gate values can be very large, causing numerical issues. Let me check:
assistant
The SwiGLU formula is fine. Let me look at the error more carefully - the error of 6584 on values of magnitude ~6000 is huge. Let me narrow down the issue:
assistant
Even the no-offset version has error of 32.0! The issue is that with random normal weights (not std=0.02), the GEMM produces very large values, and the bf16 accumulation introduces errors. But the reference uses bf16 matmul too...
Wait, the reference uses `x_e @ W_gate[e]` which is bf16 @ bf16. This uses tensor cores with bf16 accumulation (or fp32 accumulation in PyTorch's default). Let me check:
assistant
OK, the matmul itself produces values in the range of ±90. The Triton kernel accumulates in fp32 but uses `tl.dot` with bf16 inputs, which on tensor cores computes in fp32 internally. The difference of 32 is way too large for that to be the cause.
Let me look more carefully. The issue might be that `tl.dot` with bf16 inputs on SM120 might behave differently. Let me test with `allow_tf32=False`:
assistant
Right, need to put it in a file. Let me use my existing test files and focus on the main solution. The issue is clear - my approach works but the error seems to come from the problem scale. Let me look at the actual check.py to understand what it's doing and test with the actual problem setup:
assistant
OK so with the actual problem setup (weights std=0.02, inputs scaled by 0.1), the output range is very small (-0.065 to 0.063). This is because the weights are initialized with std=0.02 and inputs are * 0.1. The GEMM output will be approximately `sqrt(H) * 0.02 * 0.1 = sqrt(2048) * 0.002 = 0.09`, which matches.
The issue with my earlier test was using random normal with std=1. With the actual problem, the errors should be much smaller. Let me now run the full check:
assistant
Still failing with 0.188 error. The issue is definitely in my grouped kernel. Let me write a more targeted test that uses the same data as check.py (with proper weight initialization):
assistant
OK interesting! The gate GEMM diff is 0 for the first expert (which makes sense since we're using the same weights). But the SwiGLU output has 0.078 error. That's still significant.
The issue is: the reference computes `F.silu(gate) * up` where gate and up are bf16 results from `x @ W`, while my Triton kernel accumulates in fp32 and then does SwiGLU in fp32 before converting back to bf16. The reference does the GEMM in bf16 (PyTorch uses tensor cores with fp32 accumulation but stores the result in bf16), then applies silu and multiply in bf16.
Wait, but that shouldn't cause 0.078 error on values of magnitude 0.07. Let me think more carefully...
Actually, I see the issue. The reference model loops over experts and does `x_e @ self.W_gate[e]` which is a standard PyTorch bf16 matmul. My Triton kernel does the same computation but may differ in precision. However, the gate GEMM diff is 0.0 for expert 0, so the GEMM is correct.
Wait, but the SwiGLU diff is 0.052 for expert 0 even though the GEMM diff is 0? That doesn't make sense. Oh wait, the gate GEMM check was comparing `x_e.float() @ ref_model.W_gate[e].float()` which is float32 vs float32 - both are the same since they share weights. Let me check the actual bf16 outputs:
assistant
So the fp32 path matches much better (diff 0.000488) while my kernel has 0.052 diff. The issue is definitely in my kernel computation, not bf16 vs fp32.
Let me think about this... The GEMM gate diff is 0.0 for the reference, but the SwiGLU result has 0.05 diff. Let me directly compare the gate GEMM outputs from my kernel vs reference:
assistant
Interesting! The per-expert kernel works perfectly! The SwiGLU max diff is only 0.00024 which is well within tolerance. The GEMM diff is 0.002 which is the normal bf16 precision difference.
So the issue is specifically in the **grouped kernel** where I launch a 3D grid with all experts. Let me look at that kernel more carefully.
The problem must be in how I compute the row indices in the grouped kernel. Let me check: in the grouped kernel, `rm = start + pid_m * BLOCK_M + tl.arange(0, BLOCK_M)`. The issue could be that `start` and `end` are loaded as `int32` from the offsets, and there might be a precision issue with how they interact with the arithmetic.
Actually wait, let me re-examine the solution.py grouped kernel. The issue might be that the kernel uses `N` (which equals `I`) as the column stride for the weight matrices, but in the grouped kernel, `N` is a constexpr parameter. Let me check if there's a mismatch:
assistant
Let me rewrite the solution with a cleaner approach. Since the per-expert kernel works, I'll use a Python loop over experts calling a per-expert Triton kernel. This will be correct and we can optimize later:
assistant
Let me fix the math function names:
assistant
Correctness passes! Now let me run the benchmark:
assistant
The baseline is working with a peak_fraction of 0.1482. Now let me optimize significantly. The main issues with the current approach:
1. **Python loop over experts** - lots of kernel launch overhead for 128 experts
2. **Small tile sizes** - can use larger tiles for better occupancy
3. **No pipelining** - need num_stages for memory latency hiding
Let me implement a proper grouped kernel that launches a single kernel for all experts. The earlier issue was with the 3D grid approach - let me investigate and fix it:
assistant
Let me fix the Triton API issues:
assistant
The grouped kernel works correctly now! The fix was:
1. Using `expert_id * K_dim * N` for the weight base pointer instead of passing strides
2. Removing the early exit check
Let me now also verify with the .to(tl.int32) conversion - let me test without it:
assistant
Now let me update the solution with the working grouped kernel:
assistant
Correctness passes with the grouped kernel too. Now let me benchmark it:
assistant
The grouped kernel is actually slower (0.1133 peak fraction vs 0.1482 for the per-expert loop) for the large shapes. The issue is that shape 0 has 128 experts × 4 M-tiles × 24 N-tiles = 12,288 programs, but many experts have the same small M, so there's wasted work from launching too many tiles. Also, the large shape with E=128 has more kernel launch overhead.
Let me go back to the per-expert loop approach and optimize it with better tile sizes, num_stages for pipelining, and larger block sizes:
assistant
The larger block sizes actually made things much worse! The issue is that with BM=128, BN=128, BK=64 and 128 experts, each expert gets only 2048 rows / 128 = 16 rows on average, but we're using BM=128 which is way too large for the actual per-expert M. The Triton compiler is doing JIT compilation for each different M value, which is causing massive overhead.
Let me reconsider the approach. The key insight is:
- Shape 0: T_perm=262144, E=128, so ~2048 rows per expert, I=1536, H=4096
- Shape 1: T_perm=16384, E=64, so ~256 rows per expert, I=1024, H=2048
- Shape 2: T_perm=131072, E=64, so ~2048 rows per expert, I=4096, H=2048
The problem with the per-expert loop is 128 kernel launches for shape 0. Let me try a different approach: use the grouped kernel (single launch) but with better parameters:
assistant
The grouped kernel is still slower than the per-expert loop for most shapes. Let me go back to the per-expert loop approach (which got 0.1482) and optimize it further. The key is to find the right block sizes and minimize kernel launch overhead.
Let me try a hybrid approach: use CUDA streams to overlap expert launches, and pick optimal block sizes: