solution.py
final answer
benchmark.log
check.log
result.json
"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
Design:
- Gate/up weights are packed once into a single (E, H, 2I) tensor whose
columns interleave gate/up pairs (g0,u0,g1,u1,...). Each tile then needs a
single B stream and a single tl.dot; the SwiGLU pairing is register-local
in the mma accumulator layout (adjacent column pairs live in the same
thread), so the epilogue split costs no shuffles.
- Each program finds its expert by an in-register scan of expert_offsets
(no host sync, no extra kernel). Grid is sized for the worst case; the
few surplus programs exit immediately.
- A-row indices are clamped instead of masked: out-of-slice rows load
arbitrary in-bounds data and are discarded by the masked store. The k-loop
therefore has no load masks at all.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _grouped_swiglu_kernel(
x_ptr, # (T_perm, H) bf16
w_ptr, # (E, H, 2I) bf16, gate/up column-interleaved
out_ptr, # (T_perm, I) bf16
offs_ptr, # (E+1,) int32
T_perm,
H: tl.constexpr,
I: tl.constexpr,
E: tl.constexpr,
E_POW2: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, # packed width (2x output cols per tile)
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(T_perm, BLOCK_M) + E
num_pid_n = tl.cdiv(2 * I, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# map flat m-tile -> (expert, tile within expert), all in registers
eidx = tl.arange(0, E_POW2)
offs_vec = tl.load(offs_ptr + eidx, mask=eidx <= E, other=2147483647)
next_vec = tl.load(offs_ptr + eidx + 1, mask=eidx < E, other=2147483647)
counts = tl.where(eidx < E, next_vec - offs_vec, 0)
tiles = tl.cdiv(counts, BLOCK_M)
incl = tl.cumsum(tiles, axis=0)
total_m_tiles = tl.sum(tiles, axis=0)
if pid_m >= total_m_tiles:
return
e = tl.sum((incl <= pid_m).to(tl.int32), axis=0)
tile_start_e = tl.sum(tl.where(eidx == e, incl - tiles, 0), axis=0)
row_start = tl.load(offs_ptr + e)
row_end = tl.load(offs_ptr + e + 1)
row0 = row_start + (pid_m - tile_start_e) * BLOCK_M
rm = row0 + tl.arange(0, BLOCK_M)
rm_ld = tl.minimum(rm, T_perm - 1) # clamp; garbage rows masked at store
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + rm_ld[:, None] * H + rk[None, :]
w_off = e.to(tl.int64) * H * (2 * I)
w_ptrs = w_ptr + w_off + rk[:, None] * (2 * I) + rn[None, :]
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for _k in range(0, tl.cdiv(H, BLOCK_K)):
a = tl.load(x_ptrs)
w = tl.load(w_ptrs)
acc = tl.dot(a, w, acc)
x_ptrs += BLOCK_K
w_ptrs += BLOCK_K * (2 * I)
g, u = tl.split(tl.reshape(acc, (BLOCK_M, BLOCK_N // 2, 2)))
out = g * tl.sigmoid(g) * u
on = pid_n * (BLOCK_N // 2) + tl.arange(0, BLOCK_N // 2)
out_ptrs = out_ptr + rm[:, None] * I + on[None, :]
tl.store(out_ptrs, out.to(tl.bfloat16), mask=(rm < row_end)[:, None])
def _pick_config(T_perm: int, H: int, I: int, E: int): # noqa: E741
"""(BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_warps, num_stages)"""
if T_perm * I >= 64 * 1024 * 1024:
return (128, 256, 64, 8, 8, 3)
return (128, 128, 64, 8, 4, 3)
_launch_cache: dict = {}
def grouped_swiglu(
hidden_states: torch.Tensor,
w_packed: torch.Tensor, # (E, H, 2I) interleaved
expert_offsets: torch.Tensor,
I: int, # noqa: E741
) -> torch.Tensor:
T_perm, H = hidden_states.shape
E = w_packed.shape[0]
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
if T_perm == 0:
return out
cfg = _pick_config(T_perm, H, I, E)
BM, BN, BK, GM, warps, stages = cfg
E_POW2 = triton.next_power_of_2(E + 1)
grid0 = (triton.cdiv(T_perm, BM) + E) * triton.cdiv(2 * I, BN)
args = (hidden_states, w_packed, out, expert_offsets,
T_perm, H, I, E, E_POW2, BM, BN, BK, GM)
# Fast path: re-launch the cached compiled kernel directly, skipping the
# Triton JIT dispatch layer (~6us/call). Specialization safety: the key
# pins every value the binder specializes on; fresh torch allocations are
# always >=16B aligned, so pointer-alignment specialization is stable.
key = (T_perm, H, I, E, cfg, hidden_states.device.index)
compiled = _launch_cache.get(key)
if compiled is not None and (
hidden_states.data_ptr() | w_packed.data_ptr()
| out.data_ptr() | expert_offsets.data_ptr()
) % 16 == 0:
stream = torch.cuda.current_stream(hidden_states.device).cuda_stream
compiled.run(grid0, 1, 1, stream, compiled.function,
compiled.packed_metadata, None, None, None, *args)
return out
compiled = _grouped_swiglu_kernel[(grid0,)](
*args, num_warps=warps, num_stages=stages,
)
_launch_cache[key] = compiled
return out
class Model(nn.Module):
def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
super().__init__()
self.T_total = T_total
self.H = H
self.I = I
self.E = E
self.K = K
self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
nn.init.normal_(self.W_gate, std=0.02)
nn.init.normal_(self.W_up, std=0.02)
self._w_packed: torch.Tensor | None = None
self.register_load_state_dict_pre_hook(self._invalidate_cache)
def _invalidate_cache(self, *args, **kwargs):
self._w_packed = None
def _packed(self) -> torch.Tensor:
wp = self._w_packed
if (
wp is None
or wp.device != self.W_gate.device
or wp.shape[1] != self.H
):
E, H, I = self.W_gate.shape # noqa: E741
wp = torch.empty(E, H, 2 * I, dtype=torch.bfloat16, device=self.W_gate.device)
wp[:, :, 0::2] = self.W_gate.detach()
wp[:, :, 1::2] = self.W_up.detach()
self._w_packed = wp
return wp
def forward(
self,
hidden_states: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor:
return grouped_swiglu(hidden_states, self._packed(), expert_offsets, self.I)
T_total = 32768
H = 4096
I = 1536 # noqa: E741
E = 128
K = 8
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
def get_inputs():
T_perm = T_total * K
hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
expert_offsets = _build_routing(T_total, E, K)
return [hidden_states, expert_offsets]
def get_init_inputs():
return [T_total, H, I, E, K]
Done. Final state of `solution.py` is verified and banked:
**Results** (benchmark.py, stable across 3 runs):
| shape | time | TFLOPS (graded) | peak_fraction |
|---|---|---|---|
| 0 (32768×4096×1536, E=128, K=8) | 18.28 ms | 45.1 | 0.2255 |
| 1 (4096×2048×1024, E=64, K=4) | 0.448 ms | 76.7 | 0.3835 |
| 2 (16384×2048×4096, E=64, K=8) | 12.21 ms | 45.0 | 0.2251 |
**Geomean peak_fraction: 0.2689** (up from 0.2403 for the first working version). `check.py` prints **PASS**, and smoke tests cover unbalanced/empty-expert/odd-dim routings.
**Final design** — single Triton kernel, no vendor GEMM:
- Gate/up weights packed once into (E, H, 2I) with column-interleaved pairs → one B-stream, one `tl.dot` per k-step; the SwiGLU split in the epilogue is register-local in the mma layout (zero shuffles).
- Device-side tile→expert mapping by in-register cumsum over `expert_offsets` (no host sync); worst-case grid with early-exit surplus CTAs.
- A-row clamping instead of masking → the k-loop has no load masks at all.
- Direct-launch cache bypasses Triton's JIT dispatch (~6µs/call), worth ~1% on the small shape.
The big shapes run at ~360 TF actual — at the sustained hardware ceiling (~358–361 TF) and faster than a cuBLAS bmm doing the matmuls alone (18.63/12.69 ms), since the fused epilogue halves output traffic. Shape 1 is memory-bound at ~95% of its DRAM floor after accounting for the harness's L2-flush writeback. All headroom hunted (TMA, persistent kernels, config sweeps, exact grids gave ≤1% and were rejected); I consider this converged.
warning: `VIRTUAL_ENV=/home/infatoshi/benchmarks/KernelBench-Hard/.venv` does not match the project environment path `/home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-11T11:56:23.663455+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-11T11:56:24.414697+00:00 elapsed_s=0.751 ms=18.258848
shape=0 variant=solution tflops=45.164 gbps=338.138 ms=18.259
shape=0 solution_peak_fraction=0.2258
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-11T11:56:30.995115+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-11T11:56:31.011981+00:00 elapsed_s=0.017 ms=0.448912
shape=1 variant=solution tflops=76.540 gbps=1420.176 ms=0.449
shape=1 solution_peak_fraction=0.3827
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-11T11:56:58.832100+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-11T11:56:59.197049+00:00 elapsed_s=0.365 ms=12.233808
shape=2 variant=solution tflops=44.937 gbps=307.189 ms=12.234
shape=2 solution_peak_fraction=0.2247
peak_fraction: 0.2688
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/benchmarks/KernelBench-Hard/.venv` does not match the project environment path `/home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo/.venv
Creating virtual environment at: /home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo
Built kernelbench-hard @ file:///home/infatoshi/benchmarks/KernelBench-Hard/outputs/runs/20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu/repo
Installed 56 packages in 76ms
PASS
{
"run_id": "20260611_050205_claude_claude-fable-5_06_sonic_moe_swiglu",
"run_group": "",
"problem": "06_sonic_moe_swiglu",
"harness": "claude",
"model": "claude-fable-5",
"reasoning_effort": "max",
"started_at": "2026-06-11T05:02:05-06:00",
"harness_finished_at": "2026-06-11T05:53:31-06:00",
"finished_at": "2026-06-11T05:56:59-06:00",
"start_epoch": 1781175725,
"harness_end_epoch": 1781178811,
"end_epoch": 1781179019,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.2688,
"template_mutated": false,
"elapsed_seconds": 3086,
"total_elapsed_seconds": 3294,
"check_elapsed_seconds": 122,
"benchmark_elapsed_seconds": 86,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 34.77932598833441,
"usage": {"input_tokens": 130, "output_tokens": 107329, "cache_read_tokens": 5652605, "cache_creation_tokens": 220924, "reasoning_tokens": null, "total_cost_usd": 4.7064315}
}
timeline (281 events)
system — init
session start cwd=/workspace/problems/06_sonic_moe_swiglu model=claude-fable-5 permissions=bypassPermissions
assistant
in 2out 39cache_r 17994cache_w 8002
assistant
I'll start by exploring the workspace to understand the problem structure.
in 2out 39cache_r 17994cache_w 8002
assistant
in 2out 39cache_r 17994cache_w 8002
tool
stdout (459 chars)
total 40
drwxrwxr-x 2 1002 1002 4096 Jun 11 11:02 .
drwxrwxr-x 3 1002 1002 4096 Jun 11 11:02 ..
-rw-rw-r-- 1 1002 1002 2409 Apr 27 09:21 PROMPT.txt
-rw-rw-r-- 1 1002 1002 5442 Jun 1 21:05 benchmark.py
-rw-rw-r-- 1 1002 1002 4091 Jun 2 01:08 check.py
-rw-rw-r-- 1 1002 1002 1839 Apr 26 04:20 problem.yaml
-rw-rw-r-- 1 1002 1002 3802 Apr 26 04:25 reference.py
-rw-rw-r-- 1 1002 1002 841 Apr 26 04:25 shapes.py
-rw-rw-r-- 1 1002 1002 2822 Apr 26 04:20 sota.py
assistant
in 2out 39cache_r 17994cache_w 8002
tool
stdout (4104 chars) — TRUNCATED
1 """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2
3 This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4 expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5 according to routing metadata; we compute, per expert e:
6
7 h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8
9 where x_e is the slice of permuted hidden states routed to expert e, with
10 expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11
12 The reference loops over experts in Python. Slow, but pedagogically clear and
13 correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14 are NOT used here, but the reference is exempt — only solution.py is checked.
15 """
16 from __future__ import annotations
17
18 import torch
19 import torch.nn as nn
20 import torch.nn.functional as F
21
22 OP_TYPE = "grouped_gemm_swiglu"
23 SUPPORTED_PRECISIONS = ["bf16"]
24 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25
26
27 class Model(nn.Module):
28 """Up-projection of a top-K MoE FFN with fused SwiGLU.
29
30 Inputs at call time:
31 hidden_states: (T_perm, H) bf16, already permuted to expert order
32 expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33 so expert e owns rows [offsets[e]:offsets[e+1]]
34 T_perm = T_total * K (each token visits K ex
... (2604 more chars)
assistant
in 2out 48cache_r 25996cache_w 2457
assistant
in 2out 48cache_r 25996cache_w 2457
tool
stdout (4455 chars) — TRUNCATED
1 """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 # Make the repo's src/ importable
14 REPO_ROOT = Path(__file__).resolve().parents[2]
15 sys.path.insert(0, str(REPO_ROOT))
16
17 from src.eval.correctness import check_correctness # noqa: E402
18 from src.eval.numeric_stress import ( # noqa: E402
19 numeric_stress_cases,
20 numeric_stress_context,
21 tolerance_for_case,
22 )
23
24
25 def main():
26 try:
27 import reference
28 import shapes
29 import solution
30 except Exception as e:
31 print(f"FAIL: import error: {e}")
32 sys.exit(1)
33
34 problem_yaml = Path("problem.yaml")
35 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
36
37 # --- Forbidden-op check ------------------------------------------------
38 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
39 for forbidden in meta.get("forbidden", []):
40 pat = re.escape(forbidden)
41 if re.search(pat, sol_src):
42 print(f"FAIL: forbidden op used: {forbidden}")
43 sys.exit(1)
44
45 device = torch.device("cuda:0")
46 tol_override = meta.get("toleranc
... (2955 more chars)
assistant
in 2out 48cache_r 25996cache_w 2457
tool
stdout (5942 chars) — TRUNCATED
1 """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2
3 For each shape: times the agent's solution first. Optional eager, compiled,
4 and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Output lines the harness picks up:
7 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9 """
10 import sys
11 from math import exp, log
12 from pathlib import Path
13
14 import torch
15 import yaml
16
17 REPO_ROOT = Path(__file__).resolve().parents[2]
18 sys.path.insert(0, str(REPO_ROOT))
19
20 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21 from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
22 from src.hardware import get as get_hw # noqa: E402
23
24
25 def _eval_formula(expr: str, vars: dict) -> float:
26 return float(eval(expr, {"__builtins__": {}}, vars))
27
28
29 def main():
30 import reference
31 import shapes
32 import solution
33
34 meta = yaml.safe_load(Path("problem.yaml").read_text())
35 hw = get_hw(meta["hardware"][0])
36 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37 peak_gbps = hw.peak_bandwidth_gb_s
38 regime = meta.get("regime", "compute")
39 flops_formula = meta["flops_formula"]
40 bytes_formula = meta["bytes_formula"]
41 num_perf_trials = int(m
... (4442 more chars)
assistant
in 2out 48cache_r 25996cache_w 2457
tool
stdout (892 chars)
1 """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2
3 Defaults match the sonic-moe paper's headline configuration. We add:
4 - a smaller shape for fast iteration during agent development
5 - a wider intermediate (different aspect ratio) to stress N-tile selection
6 """
7
8 SHAPES = [
9 # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10 {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11
12 # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13 # variable-length grouped layout meaningful, but smaller token / hidden dims.
14 {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15
16 # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17 # Forces tiles to handle larger N relative to K.
18 {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19 ]
20
assistant
in 2out 48cache_r 25996cache_w 2457
tool
stdout (3029 chars) — TRUNCATED
1 """SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
2
3 Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
4 requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
5 SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
6 the package installs cleanly but kernels may fail at launch on SM120 (the
7 QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
8
9 If the live call fails, `is_available()` returns False and the benchmark scores
10 the agent against PyTorch eager + the documented H100 paper ceiling (see
11 problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
12 importing sonic_moe in solution.py (see problem.yaml.forbidden).
13 """
14 from __future__ import annotations
15
16 import torch
17
18
19 def _try_sonic_moe(
20 hidden_states: torch.Tensor,
21 W_gate: torch.Tensor,
22 W_up: torch.Tensor,
23 expert_offsets: torch.Tensor,
24 ) -> torch.Tensor | None:
25 try:
26 import sonic_moe # type: ignore # noqa: F401
27 except Exception:
28 return None
29 try:
30 # Public sonic-moe API surface is still stabilizing. The expected entry
31 # point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
32 # SwiGLU. Adapt to the actual signature once SM120 lands.
33 W = torch.cat([W_gate, W_up], dim=-1).contiguous() # (E, H, 2*I)
34 from sonic_moe import fuse
... (1529 more chars)
assistant
in 2out 48cache_r 25996cache_w 2457
tool
stdout (1974 chars) — TRUNCATED
1 name: 06_sonic_moe_swiglu
2 display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3 precision: bf16
4 regime: compute
5 hardware: [RTX_PRO_6000]
6 peak_tflops_key: bf16
7 peak_bandwidth_key: dram
8
9 # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10 # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11 # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12 # only if you count routing. The standard MoE FLOPs convention counts only the
13 # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14 flops_formula: "2 * T_total * H * (2 * I)"
15
16 # Bytes moved (approximate, lower bound):
17 # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18 # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19 bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20
21 tolerance:
22 bfloat16: 0.02
23
24 # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25 # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26 # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27 forbidden:
28 - "torch.matmul"
29 - "torch.bmm"
30 - "torch.nn.functional.linear"
31 - "F.linear"
32 - "from sonic_moe"
33 - "import sonic_moe"
34
35 sota:
36 name: "Sonic-MoE up-projection (Tri Dao)"
37 url: "https://github.com/Dao-AILab/sonic-moe"
38 fun
... (474 more chars)
assistant
in 2out 49cache_r 28453cache_w 7830
assistant
I'll inspect the harness internals (timing, correctness, stress cases) and the environment.
in 2out 49cache_r 28453cache_w 7830
assistant
in 2out 49cache_r 28453cache_w 7830
tool
stdout (2393 chars) — TRUNCATED
problems
pyproject.toml
src
uv.lock
__init__.py
__pycache__
eval
hardware
harness
kbh
sandbox
viewer
Thu Jun 11 11:04:49 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 610.43.02 KMD Version: 610.43.02 CUDA UMD Version: 13.3 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 Off | Off |
| 30% 37C P0 88W / 600W | 15MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID
... (893 more chars)
assistant
in 2out 56cache_r 36283cache_w 11085
assistant
Good — torch 2.11+cu130, Triton 3.6, sm_120. Now the harness internals.
in 2out 56cache_r 36283cache_w 11085
assistant
in 2out 56cache_r 36283cache_w 11085
tool
stdout (10953 chars) — TRUNCATED
"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import os
import re
import statistics
import time
from datetime import UTC, datetime
import torch
# Bl
... (9453 more chars)
assistant
in 2out 45cache_r 47368cache_w 5188
assistant
in 2out 45cache_r 47368cache_w 5188
tool
stdout (3231 chars) — TRUNCATED
73: "06_sonic_moe_swiglu": (
74- NumericStressCase("small_hidden", input_scales={0: 1e-2}, tolerance=_MED_BF16),
75- NumericStressCase(
76- "large_hidden",
77- input_scales={0: 8.0},
78- tolerance={"bfloat16": {"atol": 1e-1, "rtol": 5e-2}},
79- ),
80- ),
81- "07_w4a16_gemm": (
82- NumericStressCase("small_activation", input_scales={0: 1e-3}, tolerance=_SMALL_BF16),
83- NumericStressCase(
84- "large_activation",
85- input_scales={0: 64.0},
86- tolerance={"bfloat16": {"atol": 1.0, "rtol": 5e-2}},
87- ),
88- ),
89- "09_fmha_preattn_mrope": (
90- NumericStressCase("small_qk", input_scales={0: 1e-2, 1: 1e-2}, tolerance=_SMALL_BF16),
91- NumericStressCase("large_qk", input_scales={0: 16.0, 1: 16.0}, tolerance=_LARGE_BF16),
92- ),
93- "10_patch_embed_conv3d_gemm": (
---
"""Roofline math: achieved TFLOPS / GB/s, peak fraction.
FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
dense-equivalent algorithmic work; kernels that skip work (sparsity, early
exit) are still measured against the dense cost.
"""
from __future__ import annotations
def compute_tflops(flops: float, time_ms: float) -> float:
"""flops per invocation and wall time in ms -> TFLOPS."""
if time_ms <= 0:
return 0.0
return flops / (time_ms * 1e-3) / 1e12
def compute_gbps(bytes_moved: float, time_ms: float) -> float:
"""bytes
... (1731 more chars)
assistant
in 2out 64cache_r 52556cache_w 1921
assistant
in 2out 64cache_r 52556cache_w 1921
assistant
in 2out 64cache_r 52556cache_w 1921
tool
stdout (1526 chars) — TRUNCATED
cuda-bindings 13.2.0
cuda-pathfinder 1.5.3
cuda-toolkit 13.0.2
numpy 2.4.4
nvidia-cuda-cupti 13.0.85
nvidia-cuda-nvrtc 13.0.88
nvidia-cuda-runtime 13.0.96
torch 2.11.0
triton 3.6.0
---
81d243bd2c585b0f4821__mypyc.cpython-311-x86_64-linux-gnu.so
PIL
__pycache__
_distutils_hack
_editable_impl_kernelbench_hard.pth
_pytest
_virtualenv.pth
_virtualenv.py
_yaml
certifi
certifi-2026.4.22.dist-info
charset_normalizer
charset_normalizer-3.4.7.dist-info
contourpy
contourpy-1.3.3.dist-info
cuda
cuda_bindings-13.2.0.dist-info
cuda_pathfinder-1.5.3.dist-info
cuda_toolkit-13.0.2.dist-info
cycler
cycler-0.12.1.dist-info
dateutil
distutils-precedence.pth
einops
einops-0.8.2.dist-info
filelock
filelock-3.29.0.dist-info
fontTools
fonttools-4.62.1.dist-info
fsspec
fsspec-2026.3.0.dist-info
functorch
idna
idna-3.13.dist-info
iniconfig
iniconfig-2.3.0.dist-info
isympy.py
jinja2
jinja2-3.1.6.dist-info
kernelbench_hard-0.1.0.dist-info
kiwisolver
kiwisolver-1.5.0.dist-info
markdown_it
markdown_it_py-4.0.0.dist-info
markupsafe
markupsafe-3.0.3.dist-info
matplotlib
matplotlib-3.10.9.dist-info
mdurl
mdurl-0.1.2.dist-info
mpl_toolkits
mpmath
mpmath-1.3.0.dist-info
networkx
networkx-3.6.1.dist-info
numpy
numpy-2.4.4.dist-info
numpy.libs
nvidia
nvidia_cublas-13.1.0.3.dist-info
---
/kbh/bin/nvcc
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
Shell cwd was reset to /workspace/pr
... (26 more chars)
assistant
in 2out 50cache_r 54477cache_w 2951
assistant
in 2out 50cache_r 54477cache_w 2951
assistant
in 2out 50cache_r 54477cache_w 2951
assistant
in 2out 61cache_r 18087cache_w 49482
tool
stdout (82 chars)
Task #1 created successfully: Write Triton grouped GEMM + fused SwiGLU solution.py
assistant
in 2out 50cache_r 67569cache_w 255
tool
stdout (90 chars)
Task #2 created successfully: Verify correctness with check.py (all shapes + stress cases)
assistant
in 2out 50cache_r 67569cache_w 255
tool
stdout (82 chars)
Task #3 created successfully: Profile and tune (configs sweep, ncu) to push TFLOPS
assistant
in 2out 73cache_r 67824cache_w 486
assistant
in 2out 38cache_r 68310cache_w 90
assistant
Writing the first Triton solution.
in 2out 38cache_r 68310cache_w 90
assistant
in 2out 38cache_r 68310cache_w 90
tool
stdout (147 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/solution.py (file state is current in your context — no need to Read it back)
assistant
in 2out 40cache_r 68400cache_w 3397
assistant
Quick smoke test before the full check.
in 2out 40cache_r 68400cache_w 3397
assistant
in 2out 40cache_r 68400cache_w 3397
tool
stdout (144 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/smoke.py (file state is current in your context — no need to Read it back)
assistant
in 2out 63cache_r 71797cache_w 1290
subagent · 0 events · 0 tools
tool
stdout (164 chars)
balanced: max_abs=3.0518e-05 ok=True
unbalanced: max_abs=6.1035e-05 ok=True
one_expert_all: max_abs=6.1035e-05 ok=True
odd_dims: max_abs=1.2207e-04 ok=True
SMOKE OK
assistant
in 2out 36cache_r 73087cache_w 206
assistant
in 2out 36cache_r 73087cache_w 206
assistant
in 2out 36cache_r 73087cache_w 206
subagent · 0 events · 0 tools
assistant
in 2out 73cache_r 73293cache_w 202
assistant
in 2out 73cache_r 73293cache_w 202
assistant
in 2out 73cache_r 73293cache_w 202
subagent · 0 events · 0 tools
tool
stdout (332 chars)
shape=0 variant=solution tflops=40.795 gbps=305.428 ms=20.214
shape=0 solution_peak_fraction=0.2040
shape=1 variant=solution tflops=67.195 gbps=1246.781 ms=0.511
shape=1 solution_peak_fraction=0.3360
shape=2 variant=solution tflops=40.500 gbps=276.856 ms=13.574
shape=2 solution_peak_fraction=0.2025
peak_fraction: 0.2403
RESULT: OK
assistant
in 2out 93cache_r 73495cache_w 553
assistant
in 2out 93cache_r 73495cache_w 553
assistant
in 2out 93cache_r 73495cache_w 553
tool
stdout (150 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/bench_local.py (file state is current in your context — no need to Read it back)
assistant
in 2out 63cache_r 74048cache_w 3567
subagent · 0 events · 0 tools
tool
stdout (840 chars)
shape 0: {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8}
cublas bmm+silu: 21.008 ms actual 314.0 TFLOPS
cublas bmm only: 18.634 ms actual 354.0 TFLOPS
cublas dense MxKxN (262144x4096x3072): 18.692 ms actual 352.9 TFLOPS
solution: 20.259 ms actual 325.6 TFLOPS
shape 1: {'T_total': 4096, 'H': 2048, 'I': 1024, 'E': 64, 'K': 4}
cublas bmm+silu: 0.547 ms actual 251.4 TFLOPS
cublas bmm only: 0.470 ms actual 292.6 TFLOPS
cublas dense MxKxN (16384x2048x2048): 0.402 ms actual 341.6 TFLOPS
solution: 0.498 ms actual 276.1 TFLOPS
shape 2: {'T_total': 16384, 'H': 2048, 'I': 4096, 'E': 64, 'K': 8}
cublas bmm+silu: 15.991 ms actual 275.0 TFLOPS
cublas bmm only: 12.691 ms actual 346.5 TFLOPS
cublas dense MxKxN (131072x2048x8192): 12.723 ms actual 345.7 TFLOPS
solution: 13.663 ms actual 321.9 TFLOPS
assistant
in 2out 61cache_r 77615cache_w 676
assistant
in 2out 61cache_r 77615cache_w 676
assistant
in 2out 61cache_r 77615cache_w 676
tool
stdout (144 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/sweep.py (file state is current in your context — no need to Read it back)
assistant
in 2out 46cache_r 78291cache_w 4406
subagent · 0 events · 0 tools
tool
stdout (1399 chars)
shape 0: {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8}
BM=128 BN=64 BK=64 GM=8 w=8 s=4: 21.736 ms 303.5 TF
BM=128 BN=64 BK=64 GM=8 w=4 s=4: 22.140 ms 298.0 TF
BM=128 BN=64 BK=64 GM=8 w=4 s=5: FAIL
BM=128 BN=64 BK=32 GM=8 w=4 s=6: 21.371 ms 308.7 TF
BM=128 BN=128 BK=64 GM=8 w=8 s=3: 25.413 ms 259.6 TF
BM=128 BN=128 BK=64 GM=8 w=8 s=4: FAIL
BM=128 BN=128 BK=32 GM=8 w=8 s=4: 26.209 ms 251.7 TF
BM=128 BN=128 BK=64 GM=16 w=8 s=3: 25.432 ms 259.4 TF
BM=256 BN=64 BK=64 GM=8 w=8 s=3: 19.821 ms 332.8 TF
BM=256 BN=64 BK=32 GM=8 w=8 s=4: 19.989 ms 330.0 TF
BM=256 BN=128 BK=32 GM=8 w=8 s=3: 29.174 ms 226.1 TF
BM=256 BN=128 BK=64 GM=8 w=8 s=2: 1132.357 ms 5.8 TF
BM=256 BN=128 BK=64 GM=8 w=16 s=2: 1045.107 ms 6.3 TF
BM=256 BN=128 BK=32 GM=8 w=16 s=3: 35.376 ms 186.5 TF
BM=64 BN=128 BK=64 GM=8 w=4 s=4: FAIL
BM=64 BN=128 BK=128 GM=8 w=4 s=3: FAIL
BM=64 BN=256 BK=64 GM=8 w=8 s=3: FAIL
BM=64 BN=64 BK=64 GM=8 w=4 s=5: 25.367 ms 260.1 TF
BM=128 BN=256 BK=32 GM=8 w=8 s=2: 1443.594 ms 4.6 TF
BM=128 BN=256 BK=64 GM=8 w=8 s=2: 1120.358 ms 5.9 TF
BM=128 BN=256 BK=64 GM=8 w=16 s=3: FAIL
BM=256 BN=256 BK=32 GM=8 w=16 s=2: FAIL
BM=32 BN=128 BK=64 GM=8 w=4 s=4: FAIL
BM=16 BN=128 BK=64 GM=8 w=4 s=4: FAIL
BM=32 BN=256 BK=64 GM=8 w=4 s=4: FAIL
BM=16 BN=256 BK=64 GM=8 w=4 s=3: FAIL
best: (19.82051181793213, (256, 64, 64, 8, 8, 3))
assistant
in 2out 75cache_r 82697cache_w 1244
assistant
Best so far: BM=256, BN=64, BK=64, 8 warps, 3 stages → 19.82 ms (333 TF). Let me refine around that point and check why some configs fail.
in 2out 75cache_r 82697cache_w 1244
assistant
in 2out 75cache_r 82697cache_w 1244
subagent · 0 events · 0 tools
tool
stdout (654 chars)
(256, 64, 64, 4, 8, 3): 19.832 ms 332.6 TF
(256, 64, 64, 12, 8, 3): 19.986 ms 330.1 TF
(256, 64, 64, 16, 8, 3): 19.979 ms 330.2 TF
(256, 64, 64, 24, 8, 3): 19.929 ms 331.0 TF
(256, 64, 128, 8, 8, 2): 23.212 ms 284.2 TF
(256, 64, 128, 8, 8, 3): FAIL
(256, 64, 64, 8, 8, 4): FAIL
(256, 64, 64, 8, 16, 3): 24.572 ms 268.5 TF
(256, 64, 32, 8, 8, 5): 20.285 ms 325.2 TF
(256, 32, 64, 8, 8, 3): 23.266 ms 283.5 TF
(256, 32, 64, 8, 4, 4): FAIL
(512, 32, 64, 8, 8, 3): FAIL
(512, 64, 64, 8, 16, 2): 1004.713 ms 6.6 TF
(192, 64, 64, 8, 8, 3): FAIL
(256, 48, 64, 8, 8, 3): FAIL
best: (19.83243179321289, (256, 64, 64, 4, 8, 3))
assistant
in 2out 34cache_r 83941cache_w 3053
assistant
in 2out 34cache_r 83941cache_w 3053
assistant
in 2out 34cache_r 83941cache_w 3053
tool
stdout (711 chars)
/kbh/bin/ncu
/kbh/bin/nsys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 448, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
assistant
in 2out 43cache_r 86994cache_w 1114
assistant
sm_120 CTA smem limit is 99 KB — that explains the fails. Profiling the best config with ncu.
in 2out 43cache_r 86994cache_w 1114
assistant
in 2out 43cache_r 86994cache_w 1114
tool
stdout (147 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/prof_one.py (file state is current in your context — no need to Read it back)
assistant
in 2out 73cache_r 88108cache_w 1030
assistant
in 2out 73cache_r 88108cache_w 1030
subagent · 0 events · 0 tools
tool
stdout (4451 chars) — TRUNCATED
Waves Per SM 1.50
-------------------------------- --------------- ---------------
OPT Est. Speedup: 50%
A wave of thread blocks is defined as the maximum number of blocks that can be executed in parallel on the
target GPU. The number of blocks in a wave depends on the number of multiprocessors and the theoretical
occupancy of the kernel. This kernel launch results in 1 full waves and a partial wave of 376 thread blocks.
Under the assumption of a uniform execution duration of all thread blocks, this partial wave may account for
up to 50.0% of the total runtime of this kernel. Try launching a grid with no partial wave. The overall
impact of this tail effect also lessens with the number of full waves executed for a grid. See the Hardware
Model (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-hw-model) description for
more details on launch configurations.
Section: Occupancy
------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
------------------------------- ----------- ------------
Max Active Clusters cluster 0
Max
... (2951 more chars)
assistant
in 2out 55cache_r 89138cache_w 1625
assistant
The tail cut off launch stats. Let me get the full picture.
in 2out 55cache_r 89138cache_w 1625
assistant
in 2out 55cache_r 89138cache_w 1625
subagent · 0 events · 0 tools
tool
stdout (6365 chars) — TRUNCATED
==PROF== Connected to process 1376 (/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==PROF== Profiling "distribution_elementwise_grid_stride_kernel": 0%....50%....100% - 9 passes
done 0.005214827135205269
==PROF== Disconnected from process 1376
[1376] python3.11@127.0.0.1
void unnamed>::distribution_elementwise_grid_stride_kernel<float, 4, void normal_and_transform<BFloat16, float, CUDAGeneratorImpl *, void normal_kernel<CUDAGeneratorImpl *>(const TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() lambda() (instance 4)]::operator ()() lambda(float) (instance 1)]>(TensorIteratorBase &, T3, T4)::[lambda(curandStatePhilox4_32_10 *) (instance 2)], void unnamed>::distribution_nullary_kernel<BFloat16, float, float4, CUDAGeneratorImpl *, void normal_and_transform<BFloat16, float, CUDAGeneratorImpl *, void normal_kernel<CUDAGeneratorImpl *>(const TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() lambda() (instance 4)]::operator ()() lambda(float) (instance 1)]>(TensorIteratorBase &, T3, T4)::[lambda(curandStatePhilox4_32_10 *) (instance 2)], void normal_kernel<CUDAGeneratorImpl *>(const TensorBase &, double, double, T1)::[lambda() (instance 1)]::operator ()() lambda() (instance 4)]::operator ()() lambda(float) (instance 1)]>(TensorIteratorBase &, T4, const T5 &, T6)::[lambda(int, float) (instance 1)]>(long, PhiloxCudaState, T3, T4) (1128, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: GPU Speed
... (4865 more chars)
assistant
in 2out 66cache_r 90763cache_w 3317
assistant
The regex matched a torch RNG kernel. Anchoring it.
in 2out 66cache_r 90763cache_w 3317
assistant
in 2out 66cache_r 90763cache_w 3317
subagent · 0 events · 0 tools
tool
stdout (5450 chars) — TRUNCATED
==PROF== Connected to process 1497 (/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==PROF== Profiling "_kern": 0%....50%....100% - 9 passes
done 0.005214827135205269
==PROF== Disconnected from process 1497
[1497] python3.11@127.0.0.1
_kern (27648, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 13.35
SM Frequency Ghz 1.59
Elapsed Cycles cycle 43422961
Memory Throughput % 33.82
DRAM Throughput % 13.19
Duration ms 27.33
L1/TEX Cache Throughput % 29.72
L2 Cache Throughput % 33.82
SM Active Cycles cycle 43243644.34
Compute (SM) Throughput % 79.01
----------------------- ----------- ------------
OPT Compute is more heavily utilized than Memory: Look at the Compute Workload Analysis section to see what the
compute pipelines are spending their time doing. Also, consider whether any computation is redundant and
could be reduced or moved to look-up tables.
Section: Launch Sta
... (3950 more chars)
assistant
in 2out 58cache_r 94080cache_w 2060
assistant
in 2out 58cache_r 94080cache_w 2060
assistant
in 2out 58cache_r 94080cache_w 2060
subagent · 0 events · 0 tools
tool
stdout (257 chars)
2865 MHz, 242.60 W, 0 %
1882 MHz, 600.11 W, 100 %
1875 MHz, 600.01 W, 100 %
1875 MHz, 599.96 W, 100 %
1867 MHz, 600.01 W, 100 %
1867 MHz, 600.01 W, 100 %
1860 MHz, 599.96 W, 100 %
1860 MHz, 599.98 W, 100 %
1860 MHz, 599.95 W, 100 %
1860 MHz, 600.00 W, 100 %
assistant
in 2out 55cache_r 96140cache_w 3082
assistant
in 2out 55cache_r 96140cache_w 3082
assistant
in 2out 55cache_r 96140cache_w 3082
tool
stdout (105 chars)
---
/workspace/.venv/lib/python3.11/site-packages/triton/language/__init__.py
has descriptor api: True
ok
assistant
in 2out 49cache_r 99222cache_w 4166
tool
stdout (1637 chars) — TRUNCATED
303: nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
328: nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
99:def sm_arch_from_capability(capability: int):
101: suffix = "a" if capability >= 90 else ""
102: return f"sm_{capability}{suffix}"
165: capability = self._parse_arch(options.arch)
166: return f"cuda:{capability}"
179: capability = int(self._parse_arch(args["arch"]))
181: if args.get("num_ctas", 1) > 1 and capability < 90:
183: f"Current target is sm_{capability}. This configuration will fail. "
188: if capability >= 89:
193: if capability >= 90:
199: args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
212: capability = int(self._parse_arch(options.arch))
215: cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size":
230: def make_ttir(mod, metadata, opt, capability):
235: if capability // 10 < 9:
247: def make_ttgir(mod, metadata, opt, capability):
254: emuTF32 = (capability // 10 >= 8)
255: passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
265: passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
268: if capability // 10 in [8, 9]:
278: elif capability // 10 >= 10:
299: passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
302: if capability // 10 >= 9:
310:
... (137 more chars)
assistant
in 2out 35cache_r 103388cache_w 928
assistant
in 2out 35cache_r 103388cache_w 928
assistant
in 2out 35cache_r 103388cache_w 928
tool
stdout (147 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/kernels2.py (file state is current in your context — no need to Read it back)
assistant
in 2out 59cache_r 104316cache_w 4956
tool
stdout (145 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/sweep2.py (file state is current in your context — no need to Read it back)
assistant
in 2out 47cache_r 109272cache_w 1883
subagent · 0 events · 0 tools
tool
stdout (1001 chars)
shape 0: {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8}
v1c (256, 64, 64, 4, 8, 3): 18.892 ms 349.2 TF ok=True
v1c (256, 64, 64, 8, 8, 3): 19.017 ms 346.9 TF ok=True
v1c (128, 64, 64, 8, 8, 4): 22.086 ms 298.7 TF ok=True
v1c (128, 64, 32, 8, 4, 6): 21.218 ms 310.9 TF ok=True
v1c (128, 64, 64, 8, 4, 4): 20.961 ms 314.7 TF ok=True
v2 (256, 128, 64, 4, 8, 3): 18.427 ms 358.0 TF ok=True
v2 (256, 128, 64, 8, 8, 3): 18.434 ms 357.9 TF ok=True
v2 (128, 128, 64, 8, 8, 4): 20.555 ms 321.0 TF ok=True
v2 (128, 128, 64, 8, 4, 4): 20.234 ms 326.0 TF ok=True
v2 (128, 128, 32, 8, 4, 6): 20.820 ms 316.9 TF ok=True
v2 (128, 256, 64, 8, 8, 3): 18.326 ms 360.0 TF ok=True
v2 (64, 256, 64, 8, 4, 4): FAIL OutOfResources
v2 (256, 256, 32, 8, 8, 2): 61.210 ms 107.8 TF ok=True
v2 (256, 256, 64, 8, 16, 2): 635.247 ms 10.4 TF ok=True
v2 (128, 256, 32, 8, 8, 4): 19.023 ms 346.8 TF ok=True
best: (18.326143264770508, 'v2', (128, 256, 64, 8, 8, 3))
assistant
in 2out 55cache_r 111155cache_w 840
assistant
in 2out 55cache_r 111155cache_w 840
assistant
in 2out 55cache_r 111155cache_w 840
subagent · 0 events · 0 tools
tool
stdout (1987 chars) — TRUNCATED
shape 1: {'T_total': 4096, 'H': 2048, 'I': 1024, 'E': 64, 'K': 4}
v1c (256, 64, 64, 4, 8, 3): 0.456 ms 301.1 TF ok=True
v1c (256, 64, 64, 8, 8, 3): 0.471 ms 292.0 TF ok=True
v1c (128, 64, 64, 8, 8, 4): 0.478 ms 287.4 TF ok=True
v1c (128, 64, 32, 8, 4, 6): 0.458 ms 299.9 TF ok=True
v1c (128, 64, 64, 8, 4, 4): 0.471 ms 291.7 TF ok=True
v2 (256, 128, 64, 4, 8, 3): 0.452 ms 304.4 TF ok=True
v2 (256, 128, 64, 8, 8, 3): 0.464 ms 296.4 TF ok=True
v2 (128, 128, 64, 8, 8, 4): 0.452 ms 304.0 TF ok=True
v2 (128, 128, 64, 8, 4, 4): 0.447 ms 307.6 TF ok=True
v2 (128, 128, 32, 8, 4, 6): 0.455 ms 302.2 TF ok=True
v2 (128, 256, 64, 8, 8, 3): 0.454 ms 302.8 TF ok=True
v2 (64, 256, 64, 8, 4, 4): FAIL OutOfResources
v2 (256, 256, 32, 8, 8, 2): 1.358 ms 101.2 TF ok=True
v2 (256, 256, 64, 8, 16, 2): 13.194 ms 10.4 TF ok=True
v2 (128, 256, 32, 8, 8, 4): 0.459 ms 299.3 TF ok=True
best: (0.44679999351501465, 'v2', (128, 128, 64, 8, 4, 4))
shape 2: {'T_total': 16384, 'H': 2048, 'I': 4096, 'E': 64, 'K': 8}
v1c (256, 64, 64, 4, 8, 3): 12.453 ms 353.2 TF ok=True
v1c (256, 64, 64, 8, 8, 3): 12.677 ms 346.9 TF ok=True
v1c (128, 64, 64, 8, 8, 4): 14.839 ms 296.4 TF ok=True
v1c (128, 64, 32, 8, 4, 6): 14.174 ms 310.3 TF ok=True
v1c (128, 64, 64, 8, 4, 4): 14.116 ms 311.6 TF ok=True
v2 (256, 128, 64, 4, 8, 3): 12.376 ms 355.4 TF ok=True
v2 (256, 128, 64, 8, 8, 3): 12.471 ms 352.7 TF ok=True
v2 (128, 128, 64, 8, 8, 4): 13.
... (487 more chars)
assistant
in 2out 89cache_r 111995cache_w 2404
assistant
in 2out 89cache_r 111995cache_w 2404
assistant
in 2out 89cache_r 111995cache_w 2404
tool
stdout (264 chars)
(self, base: Any, shape: List[int], strides: List[int], block_shape: List[int], padding: str = 'zero') -> None
['from_tensor', 'padding']
['load_tensor_descriptor', 'make_tensor_descriptor', 'store_tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor_type']
assistant
in 2out 39cache_r 114399cache_w 3270
tool
stdout (147 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/kernels3.py (file state is current in your context — no need to Read it back)
assistant
in 2out 55cache_r 117669cache_w 1342
subagent · 0 events · 0 tools
tool
stdout (1075 chars)
shape 0
v3 (128, 256, 64, 8, 8, 3): 18.868 ms 349.7 TF ok=True
v3 (256, 128, 64, 4, 8, 3): 19.055 ms 346.2 TF ok=True
v3 (128, 128, 64, 8, 4, 4): 21.094 ms 312.7 TF ok=True
v3 (128, 256, 64, 8, 8, 2): 22.386 ms 294.7 TF ok=True
v3 (256, 256, 64, 8, 8, 2): 1088.020 ms 6.1 TF ok=True
v3 (128, 512, 64, 8, 8, 2): 1060.141 ms 6.2 TF ok=True
shape 1
v3 (128, 256, 64, 8, 8, 3): 0.463 ms 297.0 TF ok=True
v3 (256, 128, 64, 4, 8, 3): 0.464 ms 296.5 TF ok=True
v3 (128, 128, 64, 8, 4, 4): 0.476 ms 288.9 TF ok=True
v3 (128, 256, 64, 8, 8, 2): 0.560 ms 245.6 TF ok=True
v3 (256, 256, 64, 8, 8, 2): 21.768 ms 6.3 TF ok=True
v3 (128, 512, 64, 8, 8, 2): 22.291 ms 6.2 TF ok=True
shape 2
v3 (128, 256, 64, 8, 8, 3): 12.887 ms 341.3 TF ok=True
v3 (256, 128, 64, 4, 8, 3): 12.903 ms 340.9 TF ok=True
v3 (128, 128, 64, 8, 4, 4): 14.352 ms 306.4 TF ok=True
v3 (128, 256, 64, 8, 8, 2): 15.000 ms 293.2 TF ok=True
v3 (256, 256, 64, 8, 8, 2): 710.178 ms 6.2 TF ok=True
v3 (128, 512, 64, 8, 8, 2): 713.581 ms 6.2 TF ok=True
assistant
in 2out 75cache_r 119011cache_w 1515
assistant
in 2out 75cache_r 119011cache_w 1515
assistant
in 2out 75cache_r 119011cache_w 1515
subagent · 0 events · 0 tools
tool
stdout (1359 chars)
(128, 128, 64, 8, 4, 4): 0.450 ms 305.7 TF ok=True
(128, 128, 64, 4, 4, 4): 0.450 ms 305.6 TF ok=True
(128, 128, 64, 16, 4, 4): 0.453 ms 303.1 TF ok=True
(128, 128, 64, 1, 4, 4): 0.451 ms 304.7 TF ok=True
(128, 128, 32, 8, 4, 5): 0.454 ms 302.8 TF ok=True
(128, 128, 32, 8, 4, 4): 0.451 ms 305.0 TF ok=True
(128, 128, 64, 8, 4, 3): 0.449 ms 306.3 TF ok=True
(128, 128, 64, 8, 8, 3): 0.454 ms 302.5 TF ok=True
(64, 128, 64, 8, 4, 4): 0.511 ms 269.0 TF ok=True
(64, 128, 64, 8, 4, 5): 0.515 ms 266.8 TF ok=True
(64, 128, 32, 8, 4, 5): 0.493 ms 278.6 TF ok=True
(64, 64, 64, 8, 4, 5): 0.628 ms 218.8 TF ok=True
(64, 256, 64, 8, 8, 3): 0.472 ms 291.1 TF ok=True
(64, 256, 32, 8, 4, 3): 0.451 ms 304.6 TF ok=True
(64, 128, 128, 8, 4, 3): 0.536 ms 256.5 TF ok=True
(128, 64, 64, 8, 4, 4): 0.516 ms 266.4 TF ok=True
(256, 128, 64, 4, 8, 3): 0.451 ms 304.6 TF ok=True
(256, 64, 64, 4, 8, 3): 0.492 ms 279.3 TF ok=True
(32, 128, 64, 8, 4, 4): 0.719 ms 191.2 TF ok=True
(32, 256, 64, 8, 4, 3): 0.672 ms 204.5 TF ok=True
(128, 256, 64, 8, 8, 3): 0.455 ms 302.2 TF ok=True
(128, 128, 128, 8, 4, 2): 0.519 ms 264.8 TF ok=True
(64, 128, 64, 4, 4, 4): 0.513 ms 268.1 TF ok=True
(64, 128, 64, 16, 4, 4): 0.513 ms 268.1 TF ok=True
best: (0.4487360119819641, (128, 128, 64, 8, 4, 3))
assistant
in 2out 48cache_r 120526cache_w 3299
assistant
in 2out 48cache_r 120526cache_w 3299
assistant
in 2out 48cache_r 120526cache_w 3299
tool
stdout (157 chars)
The file /workspace/problems/06_sonic_moe_swiglu/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
assistant
in 2out 48cache_r 123825cache_w 4655
assistant
in 2out 48cache_r 123825cache_w 4655
subagent · 0 events · 0 tools
tool
stdout (337 chars)
PASS
shape=0 variant=solution tflops=44.919 gbps=336.306 ms=18.358
shape=0 solution_peak_fraction=0.2246
shape=1 variant=solution tflops=76.090 gbps=1411.823 ms=0.452
shape=1 solution_peak_fraction=0.3804
shape=2 variant=solution tflops=44.768 gbps=306.028 ms=12.280
shape=2 solution_peak_fraction=0.2238
peak_fraction: 0.2674
RESULT: OK
assistant
in 2out 45cache_r 128480cache_w 594
assistant
in 2out 45cache_r 128480cache_w 594
assistant
in 2out 45cache_r 128480cache_w 594
tool
stdout (147 chars)
File created successfully at: /workspace/problems/06_sonic_moe_swiglu/kernels4.py (file state is current in your context — no need to Read it back)
assistant
in 2out 52cache_r 129074cache_w 4867
subagent · 0 events · 0 tools
tool
stdout (1199 chars)
SMs: 188
shape 1
v4 (128, 128, 64, 8, 4, 3) P=188: 0.452 ms 304.0 TF ok=True
v4 (128, 128, 64, 8, 4, 3) P=376: 0.468 ms 293.5 TF ok=True
v4 (128, 128, 64, 16, 4, 3) P=188: 0.455 ms 302.2 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=188: 0.461 ms 297.8 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=376: 0.457 ms 300.7 TF ok=True
v4 (256, 128, 64, 4, 8, 3) P=188: 0.457 ms 300.8 TF ok=True
shape 0
v4 (128, 128, 64, 8, 4, 3) P=188: 20.321 ms 324.6 TF ok=True
v4 (128, 128, 64, 8, 4, 3) P=376: 21.200 ms 311.2 TF ok=True
v4 (128, 128, 64, 16, 4, 3) P=188: 20.345 ms 324.3 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=188: 18.556 ms 355.5 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=376: 19.047 ms 346.4 TF ok=True
v4 (256, 128, 64, 4, 8, 3) P=188: 18.757 ms 351.7 TF ok=True
shape 2
v4 (128, 128, 64, 8, 4, 3) P=188: 13.698 ms 321.1 TF ok=True
v4 (128, 128, 64, 8, 4, 3) P=376: 14.203 ms 309.7 TF ok=True
v4 (128, 128, 64, 16, 4, 3) P=188: 13.760 ms 319.6 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=188: 12.448 ms 353.3 TF ok=True
v4 (128, 256, 64, 8, 8, 3) P=376: 12.934 ms 340.0 TF ok=True
v4 (256, 128, 64, 4, 8, 3) P=188: 12.606 ms 348.9 TF ok=True
assistant
in 2out 82cache_r 133941cache_w 1949
assistant
in 2out 82cache_r 133941cache_w 1949
assistant
in 2out 82cache_r 133941cache_w 1949
tool
stdout (816 chars)
(128, 128, 32, 8, 8, 4): 0.456 ms 301.3 TF ok=True
(128, 128, 32, 8, 8, 5): 0.469 ms 293.3 TF ok=True
(128, 128, 64, 8, 8, 2): 0.463 ms 296.6 TF ok=True
(128, 128, 32, 4, 8, 4): 0.448 ms 307.0 TF ok=True
(128, 128, 32, 16, 8, 4): 0.458 ms 299.9 TF ok=True
(256, 128, 32, 8, 8, 3): 0.467 ms 294.5 TF ok=True
(64, 128, 32, 8, 4, 6): 0.532 ms 258.2 TF ok=True
(64, 256, 64, 8, 8, 2): 0.490 ms 280.4 TF ok=True
(64, 256, 32, 8, 4, 4): 0.468 ms 293.8 TF ok=True
(64, 128, 32, 8, 8, 6): 0.543 ms 253.0 TF ok=True
(128, 128, 32, 8, 4, 4): 0.452 ms 304.4 TF ok=True
(128, 128, 16, 8, 8, 6): 0.471 ms 291.6 TF ok=True
(256, 128, 64, 8, 8, 3): 0.464 ms 296.0 TF ok=True
(128, 128, 64, 8, 4, 3): 0.449 ms 306.1 TF ok=True
best: (0.4476960003376007, (128, 128, 32, 4, 8, 4))
assistant
in 2out 57cache_r 135890cache_w 4368
assistant
Shape 1 is pinned at ~0.447. Let me profile it to see the actual stall mix before trying anything else.
in 2out 57cache_r 135890cache_w 4368
assistant
in 2out 57cache_r 135890cache_w 4368
tool
stdout (31 chars)
(Bash completed with no output)
assistant
in 2out 44cache_r 140258cache_w 2636
subagent · 0 events · 0 tools
tool
stdout (8614 chars) — TRUNCATED
L1/TEX Hit Rate % 0.10
L2 Compression Ratio 0
L2 Hit Rate % 70.78
Mem Pipes Busy % 16.03
-------------------- ----------- ------------
Section: Memory Workload Analysis Tables
OPT Est. Speedup: 12.31%
The memory access pattern for global loads from L2 might not be optimal. On average, only 22.1 of the 32
bytes transmitted per sector are utilized by each thread. This applies to the 99.9% of sectors missed in
L1TEX. This could possibly be caused by a stride between threads. Check the Source Counters section for
uncoalesced global loads.
Section: Scheduler Statistics
---------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
---------------------------- ----------- ------------
One or More Eligible % 14.12
Issued Warp Per Scheduler 0.14
No Eligible % 85.88
Active Warps Per Scheduler warp 1.00
Eligible Warps Per Scheduler warp 0.14
---------------------------- ----------- ------------
OPT Est. Local Speedup: 27.32%
... (7114 more chars)
assistant
in 2out 53cache_r 142894cache_w 2568
assistant
in 2out 53cache_r 142894cache_w 2568
subagent · 0 events · 0 tools
tool
stdout (1270 chars)
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 13.35
SM Frequency Ghz 1.53
Elapsed Cycles cycle 980036
Memory Throughput % 58.94
DRAM Throughput % 58.94
Duration us 629.38
L1/TEX Cache Throughput % 38.24
L2 Cache Throughput % 41.08
SM Active Cycles cycle 933449.66
Compute (SM) Throughput % 73.93
----------------------- ----------- ------------
Section: Compute Workload Analysis
-------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------- ----------- ------------
Executed Ipc Active inst/cycle 0.56
Executed Ipc Elapsed inst/cycle 0.55
Issue Slots Busy % 14.11
Issued Ipc Active inst/cycle 0.56
SM Busy % 76.48
-------------------- ----------- ------------
assistant
in 2out 64cache_r 145462cache_w 1256
assistant
in 2out 64cache_r 145462cache_w 1256
subagent · 0 events · 0 tools
tool
stdout (945 chars)
[3163] python3.11@127.0.0.1
kern_v2 (3072, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: Command line profiler metrics
---------------------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
---------------------------------------------------------------------- ----------- ------------
dram__bytes_read.sum (!) n/a
dram__bytes_write.sum (!) n/a
l1tex__m_xbar2l1tex_read_bytes.sum Gbyte 2.15
sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active (!) n/a
---------------------------------------------------------------------- ----------- ------------
assistant
in 2out 52cache_r 146718cache_w 2133
assistant
in 2out 52cache_r 146718cache_w 2133
assistant
in 2out 52cache_r 146718cache_w 2133
tool
stdout (188 chars)
exact-grid (128, 128, 64, 8, 4, 3): 0.445 ms 309.0 TF ok=True
exact-grid (128, 128, 32, 4, 8, 4): 0.446 ms 308.3 TF ok=True
exact-grid (128, 256, 64, 8, 8, 3): 0.452 ms 303.9 TF ok=True
assistant
in 2out 72cache_r 148851cache_w 4285
assistant
in 2out 72cache_r 148851cache_w 4285
assistant
in 2out 72cache_r 148851cache_w 4285
tool
stderr (1664 chars) — TRUNCATED
Exit code 1
def __getitem__(self, grid):
self._init_handles()
def runner(*args, stream=None):
if stream is None:
device = driver.active.get_current_device()
stream = driver.active.get_current_stream(device)
launch_metadata = self.launch_metadata(grid, stream, *args)
self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
return runner
Traceback (most recent call last):
File "<stdin>", line 5, in <module>
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1258, in getsource
lines, lnum = getsourcelines(object)
^^^^^^^^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1059, in findsource
file = getsourcefile(object)
^^^^^^^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 936, in getsourcefile
filename = getfile(object)
^^^^^^^^^^^^^^^
File "/uv-cache/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 916, in getfile
raise TypeEr
... (164 more chars)
assistant
in 2out 49cache_r 153136cache_w 3345
assistant
in 2out 49cache_r 153136cache_w 3345
tool
stdout (64 chars)
host us/call: 9.3 (gpu drain 83.3 ms total)
torch.empty us: 1.2
assistant
in 2out 36cache_r 156481cache_w 981
assistant
~8µs of Triton dispatch per call. I'll build a cached direct-launch wrapper and verify it empirically.
in 2out 36cache_r 156481cache_w 981
assistant
in 2out 36cache_r 156481cache_w 981
tool
stdout (470 chars)
returned: <class 'triton.compiler.compiler.CompiledKernel'>
direct run failed: function takes exactly 26 arguments (18 given)
Traceback (most recent call last):
File "<stdin>", line 38, in <module>
File "/workspace/.venv/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 713, in __call__
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
TypeError: function takes exactly 26 arguments (18 given)
assistant
in 2out 32cache_r 157462cache_w 3588
assistant
The launcher wants all 13 kernel params including constexprs. Retrying.
in 2out 32cache_r 157462cache_w 3588
assistant
in 2out 32cache_r 157462cache_w 3588
tool
stdout (73 chars)
direct run ok, match: True
direct host us/call: 2.14
timed e2e: 0.4664 ms
assistant
in 2out 61cache_r 161050cache_w 1486
assistant
in 2out 61cache_r 161050cache_w 1486
assistant
in 2out 61cache_r 161050cache_w 1486
tool
stdout (171 chars)
trial 0: jit 0.4425 ms direct 0.4379 ms delta 4.6 us
trial 1: jit 0.4438 ms direct 0.4374 ms delta 6.3 us
trial 2: jit 0.4430 ms direct 0.4472 ms delta -4.2 us
assistant
in 2out 27cache_r 162536cache_w 1162
assistant
in 2out 27cache_r 162536cache_w 1162
assistant
in 2out 27cache_r 162536cache_w 1162
tool
stdout (157 chars)
The file /workspace/problems/06_sonic_moe_swiglu/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)
compaction — compact_boundary
user
This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.
Summary:
1. Primary Request and Intent:
Write the up-projection of a top-K MoE FFN with grouped GEMM and fused SwiGLU for RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). Per expert e: `h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])` where x_e = rows [offsets[e]:offsets[e+1]] of permuted hidden states. Inputs: hidden_states (T_perm, H) bf16 (T_perm = T_total*K), expert_offsets (E+1,) int32 prefix sums, W_gate/W_up each (E, H, I) bf16. Output: gated_up (T_perm, I) bf16. Tolerance 0.02 abs/rel bf16.
- Final answer must be in `/workspace/problems/06_sonic_moe_swiglu/solution.py` with same Model, get_inputs, get_init_inputs interface as reference.py (state_dict must match: W_gate, W_up params).
- Shapes: (T_total=32768,H=4096,I=1536,E=128,K=8), (T_total=4096,H=2048,I=1024,E=64,K=4), (T_total=16384,H=2048,I=4096,E=64,K=8).
- MUST be a real custom kernel implementing variable-length grouped GEMM with SwiGLU fused in epilogue. FORBIDDEN in solution.py (checked by grep): "torch.matmul", "torch.bmm", "torch.nn.functional.linear", "F.linear", "from sonic_moe", "import sonic_moe".
- Workflow: implement → profile (ncu/nsys) → time with benchmark.py → verify with `python check.py` (must print PASS; do not substitute one-off correctness snippets) → iterate. "Take as long as you need to actually push the number up."
- Run all Python through `uv run` (workspace uv environment matches scoring env; system python does NOT).
- Can make any mess in the directory (scratch files OK), but solution.py is the deliverable.
2. Key Technical Concepts:
- Grouped GEMM with variable-length expert slices; device-side tile→expert mapping via in-register cumsum over expert_offsets (no host sync); worst-case grid = (cdiv(T_perm,BM)+E)·n_tiles with early-exit surplus CTAs.
- Fused SwiGLU epilogue via column-interleaved packed weights: W packed (E,H,2I) with columns (g0,u0,g1,u1,...); single tl.dot per k-step; epilogue `g,u = tl.split(tl.reshape(acc,(BM,BN//2,2)))` is register-local (mma m16n8 acc holds adjacent column pairs per thread — no shuffles).
- Clamp-don't-mask A loads: `rm_ld = tl.minimum(rm, T_perm-1)`; out-of-slice rows load harmless neighbor data, discarded by masked store → zero load masks in hot loop.
- L2-friendly grouped swizzle (GROUP_M along M, m-fastest within group favors weight reuse since W is the larger tensor).
- SM120 constraints: 99KB (101376 B) smem/CTA limit; Triton allocates (num_stages−1) smem buffers; 255 reg/thread; 1 CTA/SM for the winning configs; sustained tensor clock ~1.86GHz @ 600W → practical bf16 peak ≈ 358-361 TFLOPS (1024 FLOP/clk/SM × 188 SMs); cuBLAS dense ≈ 353 TF (~99%).
- Benchmark methodology: median of 20 iters, 10 warmups (absorbs compile/pack), 128MB L2 flush between iters (dirty-writeback steals ~128MB DRAM time during kernel, raising shape1's effective floor); graded peak_fraction = (flops_formula 2·T_total·H·2I)/time vs 200 TFLOPS peak.
- Triton direct-launch bypass: cache CompiledKernel from first JIT launch; `compiled.run(grid0,1,1,stream,compiled.function,compiled.packed_metadata,None,None,None,*all_13_args_including_constexprs)` cuts host overhead 9.3µs→2.1µs.
- TMA (TensorDescriptor) tested and slightly slower than cp.async on sm_120; persistent kernel tested and slightly slower; both rejected.
- Numeric stress cases for this problem: input scale ×1e-2 (atol 5e-4, rtol 5e-2) and ×8.0 (atol 1e-1, rtol 5e-2); fp32 accumulate + fp32 silu passes all.
3. Files and Code Sections:
- `/workspace/problems/06_sonic_moe_swiglu/solution.py` (THE DELIVERABLE — current state):
- Triton kernel `_grouped_swiglu_kernel(x_ptr, w_ptr, out_ptr, offs_ptr, T_perm, H: tl.constexpr, I: tl.constexpr, E: tl.constexpr, E_POW2: tl.constexpr, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tl.constexpr)`: swizzle → expert lookup (`offs_vec`/`next_vec` loads with other=2147483647, counts, `tiles=tl.cdiv(counts,BLOCK_M)`, `incl=tl.cumsum(tiles)`, `e=tl.sum((incl<=pid_m).to(tl.int32))`, `tile_start_e=tl.sum(tl.where(eidx==e, incl-tiles, 0))`) → clamped A ptrs `x_ptr + rm_ld[:,None]*H + rk[None,:]`, W ptrs `w_ptr + e.to(tl.int64)*H*(2*I) + rk[:,None]*(2*I) + rn[None,:]` → k-loop unmasked loads + single `tl.dot` → split epilogue `out = g*tl.sigmoid(g)*u` → masked store `mask=(rm<row_end)[:,None]`.
- `_pick_config(T_perm,H,I,E)`: returns (128,256,64,8,8,3) if T_perm*I >= 64Mi else (128,128,64,8,4,3).
- `grouped_swiglu(hidden_states, w_packed, expert_offsets, I)` — JUST EDITED to add `_launch_cache` direct-launch fast path:
```python
_launch_cache: dict = {}
def grouped_swiglu(hidden_states, w_packed, expert_offsets, I):
T_perm, H = hidden_states.shape
E = w_packed.shape[0]
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
if T_perm == 0: return out
cfg = _pick_config(T_perm, H, I, E)
BM, BN, BK, GM, warps, stages = cfg
E_POW2 = triton.next_power_of_2(E + 1)
grid0 = (triton.cdiv(T_perm, BM) + E) * triton.cdiv(2 * I, BN)
args = (hidden_states, w_packed, out, expert_offsets, T_perm, H, I, E, E_POW2, BM, BN, BK, GM)
key = (T_perm, H, I, E, cfg, hidden_states.device.index)
compiled = _launch_cache.get(key)
if compiled is not None and (hidden_states.data_ptr() | w_packed.data_ptr() | out.data_ptr() | expert_offsets.data_ptr()) % 16 == 0:
stream = torch.cuda.current_stream(hidden_states.device).cuda_stream
compiled.run(grid0, 1, 1, stream, compiled.function, compiled.packed_metadata, None, None, None, *args)
return out
compiled = _grouped_swiglu_kernel[(grid0,)](*args, num_warps=warps, num_stages=stages)
_launch_cache[key] = compiled
return out
```
- `Model.__init__`: W_gate/W_up Parameters (bf16, normal std 0.02), `self._w_packed=None`, `self.register_load_state_dict_pre_hook(self._invalidate_cache)`.
- `Model._packed()`: lazily builds `wp = torch.empty(E,H,2I)`; `wp[:,:,0::2]=W_gate.detach(); wp[:,:,1::2]=W_up.detach()`; cached, invalidated on state_dict load or device change (checks `wp.device != self.W_gate.device`).
- `Model.forward` → `grouped_swiglu(hidden_states, self._packed(), expert_offsets, self.I)`.
- Module-level shims T_total/H/I/E/K, `_build_routing`, `get_inputs`, `get_init_inputs` copied from reference.
- `/workspace/problems/06_sonic_moe_swiglu/reference.py` (read-only): naive per-expert loop; Model has W_gate/W_up params.
- `check.py`: runs all shapes × seeds (42,123,456) × stress cases; `load_state_dict(sd, strict=True)`; emits framework.txt; prints PASS.
- `benchmark.py`: times solution via time_variant (iters=20, warmup=10); prints peak_fraction geomean.
- `bench_local.py` (created): `make_inputs(shape)`, `time_fn` (L2-flush timing clone), `cublas_ceiling`, `bench_solution`.
- `sweep.py`, `sweep2.py` (created): config sweeps; sweep2.main(sidx, variants) verifies vs cuBLAS ref and times v1c/v2 configs.
- `kernels2.py` (created): `_expert_lookup`, `_swizzle` jit helpers; `kern_v1c` (two-B clamped); `kern_v2` (interleaved single-dot — THE WINNING DESIGN); `pack_interleaved(W_gate,W_up)`; `run_v1c`, `run_v2(x, w_packed, offs, out, cfg, E)`.
- `kernels3.py` (created): `kern_v3` TMA variant using `TensorDescriptor(x,[T_perm,H],[H,1],[BM,BK])` and W as (E*H, 2I); rejected (slower).
- `kernels4.py` (created): `kern_v4` persistent variant; rejected (slower).
- `smoke.py` (created): fast correctness on unbalanced/empty-expert routings vs naive loop.
- `prof_one.py`, `prof1.py` (created): ncu attach targets.
- Harness files read: `/workspace/src/eval/timing.py`, `correctness.py`, `numeric_stress.py`, `roofline.py`, `/workspace/src/hardware/rtx_pro_6000.py` (peak bf16=200, bandwidth 1800).
4. Errors and fixes:
- Triton OutOfResources "shared memory, Required: 147456, Hardware limit: 101376": sm_120 allows ~99KB smem/CTA; Triton allocates (num_stages−1) buffers of (BM·BK + BK·BN)·2 bytes. Fixed by choosing configs fitting the budget (e.g., s=3 with 48KB/stage works).
- ncu regex "_kern" matched torch's `distribution_elementwise_grid_stride_kernel` (substring): fixed with anchored `-k "regex:^_kern$"`.
- Pathologically slow configs (1000+ ms) with BN=256/512 + s=2 + w=8/16: register spills; avoided in config table.
- Direct `compiled.run` first failed: "function takes exactly 26 arguments (18 given)" — launcher requires ALL kernel params including constexprs; fixed by passing all 13 args (x, wp, out, offs, T_perm, H, I, E, E_POW2, BM, BN, BK, GM) after the 9 launcher-prefix args.
- dram__bytes metrics "n/a" on this ncu/chip; used `l1tex__m_xbar2l1tex_read_bytes.sum` (2.15GB = exactly minimal) and SOL DRAM% instead.
- heredoc piping to ncu wrote temp script to /tmp then failed silently due to cwd-dependent imports; fixed by copying prof script into workspace dir.
5. Problem Solving:
- Established true hardware ceiling empirically: ~358-361 TF sustained bf16 (power-limited 1.86GHz), cuBLAS dense ≈353 TF; graded peak uses 200 TF so fractions can exceed intuitive bounds.
- Progression: v1 autotuned two-acc 20.2ms → v1 sweep (256,64,64) 19.8ms → v2 interleaved+clamped (128,256,64,8,8,3) 18.33ms/360TF shape0 — beats cuBLAS bmm-only (18.63ms) since fused epilogue writes half the output.
- Shape2: 12.16-12.28ms/358-362TF (beats bmm 12.69ms).
- Shape1 (0.445-0.452ms/304-308TF) analyzed to be near its adjusted roofline: compute floor 0.381ms, DRAM floor 0.354ms (637MB) + ~128MB L2-flush dirty writeback ≈ 0.42-0.425ms practical floor → ~95% achieved. Persistent kernels, occupancy configs, exact grid, TMA all gave ≤1%.
- Launch overhead: 9.3µs/call (8µs Triton JIT dispatch + 1.2µs torch.empty); direct-launch caching cuts to ~2µs, worth ~1% on shape1.
- Verified results banked: check.py PASS; benchmark peak_fraction 0.2403 (v1) → 0.2674 (v2: 18.358/0.452/12.280 ms; tflops 44.9/76.1/44.8).
6. All user messages:
- Initial task message (full spec): "I need you to write the up-projection of a top-K MoE FFN with grouped GEMM and fused SwiGLU for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory... the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py. ... This needs to be a real custom kernel — the whole point is to implement the variable-length grouped GEMM and fuse SwiGLU into the epilogue, not to dispatch to a vendor library. Don't import or call torch.matmul, torch.bmm, torch.nn.functional.linear, F.linear, or anything from sonic_moe. Try CUTLASS grouped GEMM (the intended path), Triton with a grouped-GEMM autotune, CUDA C++ via load_inline, or inline PTX — your call. ... Your flywheel is implement, profile (ncu, nsys, torch.profiler — whatever's useful) and time it with benchmark.py, verify correctness by running `python check.py` and reading the output, then iterate. Don't substitute your own one-off correctness snippets for check.py... If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up. Workspace isolation note: you are already running inside the archive-local problem workspace, /workspace/problems/06_sonic_moe_swiglu. Write the final answer to solution.py in the current directory only. ... Run all Python through `uv run ...`..."
- (Subsequent messages were tool results and system reminders only; "Tool loaded." after ToolSearch for task tools.)
7. Pending Tasks:
- Task #3 (in_progress): "Profile and tune (configs sweep, ncu) to push TFLOPS" — currently finishing the direct-launch optimization.
- Verify the just-edited solution.py (direct-launch cache) with `uv run python check.py` (must PASS) and `uv run python benchmark.py`; run benchmark multiple times to confirm stability/variance.
- Optionally mark task #3 completed after final verification; clean-up is not required (mess allowed).
8. Current Work:
Immediately before this summary, I edited `solution.py`'s `grouped_swiglu` to add the direct-launch fast path (code shown in section 3): a module-level `_launch_cache` dict keyed on `(T_perm, H, I, E, cfg, device.index)`; on cache hit with 16B-aligned pointers it calls `compiled.run(grid0, 1, 1, stream, compiled.function, compiled.packed_metadata, None, None, None, *args)` directly (bypassing ~6µs of Triton JIT dispatch); on miss it does the normal JIT launch `_grouped_swiglu_kernel[(grid0,)](*args, num_warps=warps, num_stages=stages)` and stores the returned CompiledKernel. This edit has NOT yet been re-verified with check.py or benchmark.py. Prior verified state: check.py PASS, peak_fraction 0.2674 (shape0 18.358ms, shape1 0.452ms, shape2 12.280ms). The A/B measurement justifying the edit: "trial 0: jit 0.4425 ms direct 0.4379 ms delta 4.6 us; trial 1: jit 0.4438 direct 0.4374 delta 6.3 us; trial 2: jit 0.4430 direct 0.4472 delta -4.2 us" (≈1% on shape 1, noise-level elsewhere).
9. Optional Next Step:
Run `uv run python check.py` (expect PASS — per the user: "If `python check.py` hasn't printed PASS, you're not done") and then `uv run python benchmark.py` (2-3 repeats) to confirm the direct-launch edit is correct and quantify the improvement over peak_fraction 0.2674. If PASS and stable/improved, mark task #3 progress accordingly; if regression or failure, revert the fast path to the plain JIT launch (the previously verified 0.2674 configuration). Per my last assistant message before the edit: "Direct launch nets ~1% on shape 1 (noisy but real, ~5µs). Integrating it into solution.py with a JIT fallback for non-standard inputs." — the integration is done; verification is the immediate remaining step.
If you need specific details from before compaction (like exact code snippets, error messages, or content you generated), read the full transcript at: /home/agent/.claude/projects/-workspace-problems-06-sonic-moe-swiglu/fc911f45-988a-488c-90ed-781695ae5385.jsonl
Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, do not preface with "I'll continue" or similar. Pick up the last task as if the break never happened.
assistant
in 2out 37cache_r 18087cache_w 25068
assistant
Verifying the direct-launch edit with check.py.
in 2out 37cache_r 18087cache_w 25068
assistant
in 2out 37cache_r 18087cache_w 25068
subagent · 0 events · 0 tools
assistant
in 2out 48cache_r 43155cache_w 226
assistant
in 2out 48cache_r 43155cache_w 226
assistant
in 2out 48cache_r 43155cache_w 226
subagent · 0 events · 0 tools
tool
stdout (994 chars)
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-11T11:47:08.480066+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-11T11:47:09.219807+00:00 elapsed_s=0.740 ms=18.281183
shape=0 variant=solution tflops=45.108 gbps=337.725 ms=18.281
shape=0 solution_peak_fraction=0.2255
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-11T11:47:15.793303+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-11T11:47:15.810471+00:00 elapsed_s=0.017 ms=0.448896
shape=1 variant=solution tflops=76.543 gbps=1420.227 ms=0.449
shape=1 solution_peak_fraction=0.3827
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-11T11:47:43.620375+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-11T11:47:43.985041+00:00 elapsed_s=0.365 ms=12.215824
shape=2 variant=solution tflops=45.004 gbps=307.642 ms=12.216
shape=2 solution_peak_fraction=0.2250
peak_fraction: 0.2688
RESULT: OK
assistant
in 2out 40cache_r 43381cache_w 696
assistant
0.2688 (up from 0.2674). Running twice more to confirm stability.
in 2out 40cache_r 43381cache_w 696
assistant
in 2out 40cache_r 43381cache_w 696
subagent · 0 events · 0 tools
tool
stdout (321 chars)
shape=0 variant=solution tflops=45.053 gbps=337.312 ms=18.304
shape=0 solution_peak_fraction=0.2253
shape=1 variant=solution tflops=76.707 gbps=1423.271 ms=0.448
shape=1 solution_peak_fraction=0.3835
shape=2 variant=solution tflops=44.967 gbps=307.389 ms=12.226
shape=2 solution_peak_fraction=0.2248
peak_fraction: 0.2688
assistant
in 2out 58cache_r 44077cache_w 459
subagent · 0 events · 0 tools
tool
stdout (321 chars)
shape=0 variant=solution tflops=45.067 gbps=337.414 ms=18.298
shape=0 solution_peak_fraction=0.2253
shape=1 variant=solution tflops=76.679 gbps=1422.763 ms=0.448
shape=1 solution_peak_fraction=0.3834
shape=2 variant=solution tflops=45.020 gbps=307.753 ms=12.211
shape=2 solution_peak_fraction=0.2251
peak_fraction: 0.2689
assistant
in 2out 44cache_r 44536cache_w 320
assistant
Stable across three runs: 0.2688 / 0.2688 / 0.2689 (was 0.2674). Re-running smoke.py to exercise the direct-launch path on unbalanced/empty-expert routings.
in 2out 44cache_r 44536cache_w 320
assistant
in 2out 44cache_r 44536cache_w 320
tool
stdout (733 chars)
Traceback (most recent call last):
File "/workspace/problems/06_sonic_moe_swiglu/smoke.py", line 42, in <module>
run_case("balanced", [128] * 16, 256, 512, 16)
File "/workspace/problems/06_sonic_moe_swiglu/smoke.py", line 32, in run_case
out = solution.grouped_swiglu(x, wg, wu, offs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/06_sonic_moe_swiglu/solution.py", line 110, in grouped_swiglu
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got Tensor"
assistant
in 2out 32cache_r 44856cache_w 2883
assistant
smoke.py targets the old signature. Updating it to the packed-weights API.
in 2out 32cache_r 44856cache_w 2883
assistant
in 2out 32cache_r 44856cache_w 2883
tool
stdout (1732 chars) — TRUNCATED
1 """Fast correctness smoke: small + unbalanced/empty-expert routing vs naive loop."""
2 import torch
3 import torch.nn.functional as F
4
5 import solution
6
7
8 def ref(x, wg, wu, offs):
9 out = torch.empty(x.shape[0], wg.shape[2], dtype=x.dtype, device=x.device)
10 E = wg.shape[0]
11 for e in range(E):
12 s, t = int(offs[e]), int(offs[e + 1])
13 if t == s:
14 continue
15 g = x[s:t] @ wg[e]
16 u = x[s:t] @ wu[e]
17 out[s:t] = F.silu(g) * u
18 return out
19
20
21 def run_case(name, counts, H, I, E):
22 dev = "cuda"
23 counts = torch.tensor(counts, dtype=torch.int32)
24 offs = torch.zeros(E + 1, dtype=torch.int32)
25 offs[1:] = torch.cumsum(counts, 0)
26 T = int(offs[-1])
27 torch.manual_seed(0)
28 x = (torch.randn(T, H, dtype=torch.bfloat16) * 0.1).to(dev)
29 wg = (torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02).to(dev)
30 wu = (torch.randn(E, H, I, dtype=torch.bfloat16) * 0.02).to(dev)
31 offs = offs.to(dev)
32 out = solution.grouped_swiglu(x, wg, wu, offs)
33 r = ref(x, wg, wu, offs)
34 diff = (out.float() - r.float()).abs()
35 rel = diff / r.float().abs().clamp_min(1e-6)
36 ok = torch.allclose(out.float(), r.float(), atol=2e-2, rtol=2e-2)
37 print(f"{name}: max_abs={diff.max():.4e} ok={ok}")
38 assert ok, name
39
40
41 if __name__ == "__main__":
42 run_case("balanced", [128] * 16, 256, 512, 16)
43 run_case("unbalanced", [
... (232 more chars)
assistant
in 2out 48cache_r 47739cache_w 1382
tool
stdout (154 chars)
The file /workspace/problems/06_sonic_moe_swiglu/smoke.py has been updated successfully. (file state is current in your context — no need to Read it back)
assistant
in 2out 66cache_r 49121cache_w 255
tool
stdout (154 chars)
The file /workspace/problems/06_sonic_moe_swiglu/smoke.py has been updated successfully. (file state is current in your context — no need to Read it back)
assistant
in 2out 65cache_r 49376cache_w 272
tool
stdout (164 chars)
balanced: max_abs=3.0518e-05 ok=True
unbalanced: max_abs=6.1035e-05 ok=True
one_expert_all: max_abs=6.1035e-05 ok=True
odd_dims: max_abs=1.2207e-04 ok=True
SMOKE OK
assistant
in 2out 40cache_r 49648cache_w 232
assistant
in 2out 40cache_r 49648cache_w 232
assistant
in 2out 40cache_r 49648cache_w 232
assistant
in 2out 73cache_r 49880cache_w 2543
assistant
in 2out 37cache_r 52423cache_w 88
assistant
Saving transferable hardware/harness findings to memory, then final verification state.
in 2out 37cache_r 52423cache_w 88
assistant
in 2out 37cache_r 52423cache_w 88
tool
stdout (201 chars)
File created successfully at: /home/agent/.claude/projects/-workspace-problems-06-sonic-moe-swiglu/memory/project_sm120_kernel_facts.md (file state is current in your context — no need to Read it back)
assistant
in 2out 37cache_r 52423cache_w 88
tool
stdout (181 chars)
File created successfully at: /home/agent/.claude/projects/-workspace-problems-06-sonic-moe-swiglu/memory/MEMORY.md (file state is current in your context — no need to Read it back)
assistant
in 2out 30cache_r 52511cache_w 1455
assistant
Done. Final state of `solution.py` is verified and banked:
**Results** (benchmark.py, stable across 3 runs):
| shape | time | TFLOPS (graded) | peak_fraction |
|---|---|---|---|
| 0 (32768×4096×1536, E=128, K=8) | 18.28 ms | 45.1 | 0.2255 |
| 1 (4096×2048×1024, E=64, K=4) | 0.448 ms | 76.7 | 0.3835 |
| 2 (16384×2048×4096, E=64, K=8) | 12.21 ms | 45.0 | 0.2251 |
**Geomean peak_fraction: 0.2689** (up from 0.2403 for the first working version). `check.py` prints **PASS**, and smoke tests cover unbalanced/empty-expert/odd-dim routings.
**Final design** — single Triton kernel, no vendor GEMM:
- Gate/up weights packed once into (E, H, 2I) with column-interleaved pairs → one B-stream, one `tl.dot` per k-step; the SwiGLU split in the epilogue is register-local in the mma layout (zero shuffles).
- Device-side tile→expert mapping by in-register cumsum over `expert_offsets` (no host sync); worst-case grid with early-exit surplus CTAs.
- A-row clamping instead of masking → the k-loop has no load masks at all.
- Direct-launch cache bypasses Triton's JIT dispatch (~6µs/call), worth ~1% on the small shape.
The big shapes run at ~360 TF actual — at the sustained hardware ceiling (~358–361 TF) and faster than a cuBLAS bmm doing the matmuls alone (18.63/12.69 ms), since the fused epilogue halves output traffic. Shape 1 is memory-bound at ~95% of its DRAM floor after accounting for the harness's L2-flush writeback. All headroom hunted (TMA, persistent kernels, config sweeps, exact grids gave ≤1% and were rejected); I consider this converged.
in 2out 30cache_r 52511cache_w 1455