"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
Per expert e:
h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
def _num_sms(device: torch.device | int | str | None = None) -> int:
    """Return the streaming-multiprocessor count of a CUDA device.

    Args:
        device: CUDA device to query. ``None`` (the default) selects the
            currently active CUDA device (``torch.cuda.current_device()``)
            rather than unconditionally querying device index 0.

    Returns:
        The ``multi_processor_count`` reported by the device properties.
    """
    if device is None:
        device = torch.cuda.current_device()
    return torch.cuda.get_device_properties(device).multi_processor_count
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4),
    ],
    key=["H", "I", "E"],
)
@triton.jit
def _grouped_swiglu_kernel(
    a_ptr,
    b_gate_ptr,
    b_up_ptr,
    c_ptr,
    offsets_ptr,
    E,
    H,
    I,
    stride_am,
    stride_ak,
    stride_bg,
    stride_bh,
    stride_bi,
    stride_cm,
    stride_cn,
    NUM_SMS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute ``silu(x_e @ W_gate[e]) * (x_e @ W_up[e])`` for every expert ``e``.

    Tiles are numbered globally by concatenating, in expert order, the
    ``cdiv(m_size, BLOCK_M) * cdiv(I, BLOCK_N)`` output tiles of each
    non-empty expert. A program starts at tile ``program_id(0)`` and advances
    by ``NUM_SMS`` after each tile it finishes.

    Args:
        a_ptr: activations, addressed as ``row * stride_am + k * stride_ak``.
        b_gate_ptr: gate weights, addressed as
            ``e * stride_bg + k * stride_bh + n * stride_bi``.
        b_up_ptr: up weights; addressed with the SAME strides as the gate
            weights, so both tensors must share one layout.
        c_ptr: output, addressed as ``row * stride_cm + n * stride_cn``;
            values are stored as bfloat16.
        offsets_ptr: ``E + 1`` row offsets; expert ``g`` owns rows
            ``[offsets[g], offsets[g + 1])``.
        E, H, I: number of experts, reduction size, and output width.
        NUM_SMS: tile-index stride between successive tiles of one program.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes chosen by the autotuner.
    """
    tidx = tl.program_id(0)
    iterated_tiles = 0
    for g in tl.range(E):
        m_start = tl.load(offsets_ptr + g)
        m_end = tl.load(offsets_ptr + g + 1)
        m_size = m_end - m_start
        if m_size > 0:
            num_m_tiles = tl.cdiv(m_size, BLOCK_M)
            num_n_tiles = tl.cdiv(I, BLOCK_N)
            num_tiles = num_m_tiles * num_n_tiles
            # Expert base offset in 64-bit: g * stride_bg overflows int32 once
            # the stacked weights exceed 2**31 elements.
            w_base = g.to(tl.int64) * stride_bg
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # The M tile index varies fastest within an expert.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles
                acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
                # Global row indices in 64-bit so that row * stride cannot
                # wrap around for large activation / output tensors.
                rows = (m_start + offs_am).to(tl.int64)
                # Everything below is invariant across the K loop.
                a_row_off = rows[:, None] * stride_am
                b_col_off = offs_bn[None, :] * stride_bi
                row_ok = offs_am[:, None] < m_size
                col_ok = offs_bn[None, :] < I
                for k_block in range(0, H, BLOCK_K):
                    offs_k = k_block + tl.arange(0, BLOCK_K)
                    a_ptrs = a_ptr + (a_row_off + offs_k[None, :] * stride_ak)
                    # One offset tile serves both weight tensors (shared strides).
                    # The intra-expert part stays 32-bit; this assumes a single
                    # expert's (H, I) matrix spans fewer than 2**31 elements.
                    b_off = w_base + (offs_k[:, None] * stride_bh + b_col_off)
                    a_mask = row_ok & (offs_k[None, :] < H)
                    b_mask = (offs_k[:, None] < H) & col_ok
                    # Out-of-range elements load as 0 and so add nothing to the dot.
                    a = tl.load(a_ptrs, mask=a_mask, other=0.0)
                    bg = tl.load(b_gate_ptr + b_off, mask=b_mask, other=0.0)
                    bu = tl.load(b_up_ptr + b_off, mask=b_mask, other=0.0)
                    acc_gate = tl.dot(a, bg, acc_gate)
                    acc_up = tl.dot(a, bu, acc_up)
                # Fused epilogue on the fp32 accumulators: silu(gate) * up.
                silu_gate = acc_gate * tl.sigmoid(acc_gate)
                c = (silu_gate * acc_up).to(tl.bfloat16)
                c_ptrs = c_ptr + (rows[:, None] * stride_cm + offs_bn[None, :] * stride_cn)
                tl.store(c_ptrs, c, mask=row_ok & col_ok)
                tidx += NUM_SMS
            # Only non-empty experts contribute tiles to the global numbering.
            iterated_tiles += num_tiles
def grouped_swiglu(
    hidden_states: torch.Tensor,
    W_gate: torch.Tensor,
    W_up: torch.Tensor,
    expert_offsets: torch.Tensor,
) -> torch.Tensor:
    """Launch the fused grouped-GEMM + SwiGLU kernel.

    Args:
        hidden_states: ``(T_perm, H)`` CUDA activations; rows
            ``expert_offsets[e]:expert_offsets[e + 1]`` belong to expert ``e``.
        W_gate: ``(E, H, I)`` gate weights.
        W_up: ``(E, H, I)`` up weights (same shape as ``W_gate``).
        expert_offsets: ``E + 1`` row offsets. They are converted to a
            contiguous int32 tensor on the device of ``hidden_states``.

    Returns:
        A ``(T_perm, I)`` bfloat16 tensor on the device of ``hidden_states``.
        It is allocated with ``torch.empty``, so rows not covered by any
        expert range are left uninitialised.

    Raises:
        AssertionError: if the weight shapes disagree with each other or
            with ``hidden_states``.
        ValueError: if ``expert_offsets`` does not hold ``E + 1`` entries or
            ``hidden_states`` is not a CUDA tensor.
    """
    T_perm, H = hidden_states.shape
    E, H_w, I = W_gate.shape
    assert H == H_w and W_up.shape == W_gate.shape
    if expert_offsets.numel() != E + 1:
        # The kernel reads offsets[g] and offsets[g + 1] for every g < E.
        raise ValueError(
            f"expert_offsets must have E + 1 = {E + 1} entries, "
            f"got {expert_offsets.numel()}"
        )
    if not hidden_states.is_cuda:
        raise ValueError("grouped_swiglu requires hidden_states to be a CUDA tensor")
    device = hidden_states.device
    # The kernel addresses W_up with W_gate's strides, so the two layouts must
    # agree; fall back to the canonical contiguous layout when they do not.
    if W_up.stride() != W_gate.stride():
        W_gate = W_gate.contiguous()
        W_up = W_up.contiguous()
    # The kernel walks the offsets with unit stride and mixes them with int32
    # tile counters; this is a no-op for an already conforming tensor.
    expert_offsets = expert_offsets.to(device=device, dtype=torch.int32).contiguous()
    out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
    if T_perm == 0 or E == 0:
        # No tiles to compute: skip the launch (and autotuning) entirely.
        return out
    # Make the inputs' GPU current for the launch.
    with torch.cuda.device(device):
        num_sms = _num_sms()
        _grouped_swiglu_kernel[(num_sms,)](
            hidden_states,
            W_gate,
            W_up,
            out,
            expert_offsets,
            E,
            H,
            I,
            hidden_states.stride(0),
            hidden_states.stride(1),
            W_gate.stride(0),
            W_gate.stride(1),
            W_gate.stride(2),
            out.stride(0),
            out.stride(1),
            NUM_SMS=num_sms,
        )
    return out
class Model(nn.Module):
    """MoE up-projection module: per-expert gate/up GEMMs fused with SwiGLU.

    Holds two bfloat16 parameters of shape ``(E, H, I)``, ``W_gate`` and
    ``W_up``, each initialised from a normal distribution with std 0.02.
    ``T_total`` and ``K`` are only recorded on the instance.
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):  # noqa: E741
        super().__init__()
        self.T_total, self.H, self.I, self.E, self.K = T_total, H, I, E, K

        def _new_expert_weight() -> nn.Parameter:
            # Allocate one (E, H, I) bf16 weight and fill it with N(0, 0.02^2).
            weight = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
            nn.init.normal_(weight, std=0.02)
            return weight

        # Gate first, then up: keeps registration and RNG consumption order.
        self.W_gate = _new_expert_weight()
        self.W_up = _new_expert_weight()

    def forward(
        self,
        hidden_states: torch.Tensor,
        expert_offsets: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``grouped_swiglu`` on contiguous copies of inputs and weights."""
        operands = (hidden_states, self.W_gate, self.W_up, expert_offsets)
        return grouped_swiglu(*(t.contiguous() for t in operands))
# Default problem size used by get_inputs() and get_init_inputs().
# Number of tokens; get_inputs() builds T_total * K activation rows.
T_total = 32 * 1024
# Hidden size: columns of hidden_states and rows of each expert weight.
H = 4 * 1024
# Intermediate size: columns of each expert weight and of the output.
I = 3 * 512  # noqa: E741
# Number of experts (leading dimension of the weights).
E = 128
# Multiplier applied to T_total to get the number of routed rows.
K = 8
def _build_routing(
    T_total: int,
    E: int,
    K: int,
    device: str | torch.device = "cpu",
) -> torch.Tensor:
    """Build balanced expert offsets for ``T_total * K`` routed rows.

    The rows are split as evenly as possible over ``E`` experts; the first
    ``(T_total * K) % E`` experts receive one extra row.

    Args:
        T_total: number of tokens.
        E: number of experts; must be positive.
        K: multiplier applied to ``T_total`` to get the routed row count.
        device: device on which the offsets are created.

    Returns:
        An int32 tensor of ``E + 1`` prefix sums, starting at 0 and ending at
        ``T_total * K``.

    Raises:
        ValueError: if ``E`` is not positive.
    """
    if E <= 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"E must be a positive expert count, got {E}")
    base, rem = divmod(T_total * K, E)
    counts = torch.full((E,), base, dtype=torch.int32, device=device)
    counts[:rem] += 1
    offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
    # Assigning into the int32 slice casts the cumulative sum back to int32.
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets
def get_inputs():
    """Return ``[hidden_states, expert_offsets]`` for the default problem size.

    ``hidden_states`` is a ``(T_total * K, H)`` bfloat16 tensor of standard
    normal samples scaled by 0.1; ``expert_offsets`` comes from
    ``_build_routing`` with its default device.
    """
    num_rows = T_total * K
    activations = torch.randn(num_rows, H, dtype=torch.bfloat16)
    activations = activations * 0.1
    return [activations, _build_routing(T_total, E, K)]
def get_init_inputs():
    """Return the ``Model`` constructor arguments in positional order."""
    init_args = (T_total, H, I, E, K)
    return list(init_args)
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:55:11.950917+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:55:13.708995+00:00 elapsed_s=1.758 ms=19.345648
shape=0 variant=solution tflops=42.626 gbps=319.142 ms=19.346
shape=0 solution_peak_fraction=0.0853
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:55:20.398657+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:55:20.790946+00:00 elapsed_s=0.392 ms=0.476256
shape=1 variant=solution tflops=72.146 gbps=1338.638 ms=0.476
shape=1 solution_peak_fraction=0.1443
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:55:49.045279+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:55:50.155498+00:00 elapsed_s=1.110 ms=13.068016
shape=2 variant=solution tflops=42.069 gbps=287.580 ms=13.068
shape=2 solution_peak_fraction=0.0841
peak_fraction: 0.1012
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu/repo
Installed 56 packages in 65ms
PASS
{
"run_id": "20260615_121705_cursor_composer-2.5-fast_06_sonic_moe_swiglu",
"run_group": "",
"problem": "06_sonic_moe_swiglu",
"harness": "cursor",
"model": "composer-2.5-fast",
"reasoning_effort": "",
"started_at": "2026-06-15T12:17:05-06:00",
"harness_finished_at": "2026-06-15T12:52:14-06:00",
"finished_at": "2026-06-15T12:55:50-06:00",
"start_epoch": 1781547425,
"harness_end_epoch": 1781549534,
"end_epoch": 1781549750,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.1012,
"template_mutated": false,
"elapsed_seconds": 2109,
"total_elapsed_seconds": 2325,
"check_elapsed_seconds": 126,
"benchmark_elapsed_seconds": 90,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 14.998577524893314,
"usage": {"input_tokens": 198436, "output_tokens": 31632, "cache_read_tokens": 2577834, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"targetDirectory": "/workspace/problems/06_sonic_moe_swiglu",
"globPattern": "**/*"
}null{
"path": "/workspace/problems/06_sonic_moe_swiglu/reference.py"
}"""Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
according to routing metadata; we compute, per expert e:
h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
where x_e is the slice of permuted hidden states routed to expert e, with
expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
The reference loops over experts in Python. Slow, but pedagogically clear and
correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
are NOT used here, but the reference is exempt — only solution.py is checked.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
OP_TYPE = "grouped_gemm_swiglu"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
class Model(nn.Module):
"""Up-projection of a top-K MoE FFN with fused SwiGLU.
Inputs at call time:
hidden_states: (T_perm, H) bf16, already permuted to expert order
expert_offsets: (E+1,) int32, prefix sums of token counts per expert
so expert e owns rows [offsets[e]:offsets[e+1]]
T_perm = T_total * K (each token visits K experts)
Output:
gated_up: (T_perm, I) bf16
"""
def __init__(self,
... (2300 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/check.py"
}"""Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
"""
import re
import sys
from pathlib import Path
import torch
import yaml
# Make the repo's src/ importable
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.correctness import check_correctness # noqa: E402
from src.eval.numeric_stress import ( # noqa: E402
numeric_stress_cases,
numeric_stress_context,
tolerance_for_case,
)
def main():
try:
import reference
import shapes
import solution
except Exception as e:
print(f"FAIL: import error: {e}")
sys.exit(1)
problem_yaml = Path("problem.yaml")
meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
# --- Forbidden-op check ------------------------------------------------
sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
for forbidden in meta.get("forbidden", []):
pat = re.escape(forbidden)
if re.search(pat, sol_src):
print(f"FAIL: forbidden op used: {forbidden}")
sys.exit(1)
device = torch.device("cuda:0")
tol_override = meta.get("tolerance") or None
# --- Per-shape correctness --------------------------------------------
all_shapes = shapes.SHAPES
for
... (2591 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/benchmark.py"
}"""Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
For each shape: times the agent's solution first. Optional eager, compiled,
and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
Output lines the harness picks up:
shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
"""
import sys
from math import exp, log
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
from src.hardware import get as get_hw # noqa: E402
def _eval_formula(expr: str, vars: dict) -> float:
return float(eval(expr, {"__builtins__": {}}, vars))
def main():
import reference
import shapes
import solution
meta = yaml.safe_load(Path("problem.yaml").read_text())
hw = get_hw(meta["hardware"][0])
peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
peak_gbps = hw.peak_bandwidth_gb_s
regime = meta.get("regime", "compute")
flops_formula = meta["flops_formula"]
bytes_formula = meta["bytes_formula"]
num_perf_trials = int(meta.get("num_perf_trials", 20))
device = torch.device("cuda:0")
include_baselines = benchmark_baselines_e
... (3942 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/shapes.py"
}{
"path": "/workspace/problems/06_sonic_moe_swiglu/problem.yaml"
}{
"path": "/workspace/problems/06_sonic_moe_swiglu/sota.py"
}"""Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
Defaults match the sonic-moe paper's headline configuration. We add:
- a smaller shape for fast iteration during agent development
- a wider intermediate (different aspect ratio) to stress N-tile selection
"""
SHAPES = [
# Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
{"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
# Fast-iteration shape (~16x cheaper). Same expert count to keep the
# variable-length grouped layout meaningful, but smaller token / hidden dims.
{"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
# Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
# Forces tiles to handle larger N relative to K.
{"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
]
name: 06_sonic_moe_swiglu
display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
precision: bf16
regime: compute
hardware: [RTX_PRO_6000]
peak_tflops_key: bf16
peak_bandwidth_key: dram
# Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
# Per token: 2*H*I FLOPs (H*I FMAs) for gate, 2*H*I for up => 2 * T_total * H * (2*I).
# (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
# only if you count routing. The standard MoE FLOPs convention counts only the
# active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
flops_formula: "2 * T_total * H * (2 * I)"
# Bytes moved (approximate, lower bound):
# read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
# expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
tolerance:
bfloat16: 0.02
# Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
# - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
# - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
forbidden:
- "torch.matmul"
- "torch.bmm"
- "torch.nn.functional.linear"
- "F.linear"
- "from sonic_moe"
- "import sonic_moe"
sota:
name: "Sonic-MoE up-projection (Tri Dao)"
url: "https://github.com/Dao-AILab/sonic-moe"
function: "sonic_moe.fused_moe_up"
deps:
- "sonic-moe>=0.1.2" # requires Python>=3.12, sm_120 suppor
... (339 more chars)"""SOTA reference for Sonic-MoE up-projection: Tri Dao's sonic-moe.
Status (2026-04): sonic-moe ships on PyPI as `sonic-moe` (>=0.1.2.post1) and
requires Python>=3.12. It dispatches to QuACK CuTeDSL grouped GEMM kernels.
SM120 (RTX PRO 6000 Blackwell Workstation) support is in-progress upstream --
the package installs cleanly but kernels may fail at launch on SM120 (the
QuACK grouped-GEMM path targets Sm90/Sm100 in the public release).
If the live call fails, `is_available()` returns False and the benchmark scores
the agent against PyTorch eager + the documented H100 paper ceiling (see
problem.yaml.sota.reference_throughput_tflops_h100). Agents are FORBIDDEN from
importing sonic_moe in solution.py (see problem.yaml.forbidden).
"""
from __future__ import annotations
import torch
def _try_sonic_moe(
hidden_states: torch.Tensor,
W_gate: torch.Tensor,
W_up: torch.Tensor,
expert_offsets: torch.Tensor,
) -> torch.Tensor | None:
try:
import sonic_moe # type: ignore # noqa: F401
except Exception:
return None
try:
# Public sonic-moe API surface is still stabilizing. The expected entry
# point bundles gate+up weights as a single (E, H, 2*I) tensor and fuses
# SwiGLU. Adapt to the actual signature once SM120 lands.
W = torch.cat([W_gate, W_up], dim=-1).contiguous() # (E, H, 2*I)
from sonic_moe import fused_moe_up # type: ignore
return fused_moe_up(hidden_states, W, expert_offsets)
ex
... (1322 more chars){
"targetDirectory": "/workspace/problems/06_sonic_moe_swiglu",
"globPattern": "solution.py"
}null{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null; uv run python -c \"import torch; print('torch', torch.__version__); print('cuda', torch.cuda.is_available()); print('cap', torch.cuda.get_device_capability() if torch.cuda.is_available() else None)\"",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_8bb63c22-0977-4498-9c9d-996a0fbce65",
"simpleCommands": [
"nvidia-smi",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
... (1719 more chars)null{
"command": "uv run python -c \"import triton; print('triton', triton.__version__)\" 2>&1; ls /usr/local/cuda/include/cutlass 2>/dev/null | head -5; uv run python -c \"import cutlass; print('cutlass')\" 2>&1; find /workspace -name \"*.py\" -path \"*grouped*\" 2>/dev/null | head -20",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_5ee625ad-1aab-4513-a328-86809c8c987",
"simpleCommands": [
"uv",
"ls",
"head",
"uv",
"find",
"head"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"ex
... (3235 more chars){
"query": "grouped GEMM SwiGLU MoE kernel implementation",
"targetDirectories": [],
"explanation": ""
}errorMessage: 'Codebase search is unavailable for this conversation because the codebase index could not be found. Do not retry SemanticSearch. Use Grep, Glob, and Read instead.'{
"pattern": "grouped.*gemm|GroupGemm|grouped_gemm",
"path": "/workspace",
"caseInsensitive": true,
"headLimit": 40,
"multiline": false,
"toolCallId": "tool_7f8c25ce-326b-490f-b23e-8874bc42c0f",
"offset": 0
}nullnull{
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
"limit": 200
}# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
... (7537 more chars){
"pattern": "numeric_stress",
"path": "/workspace/src",
"outputMode": "files_with_matches",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_ecdcb798-ce9c-4887-aef0-6347a5c6245",
"offset": 0
}null{
"command": "<216 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_0585226d-bff0-4349-bb05-734555ecfdb",
"simpleCommands": [
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value": "run"
},
{
"type": "word",
"value": "python"
},
{
"type": "word",
"value": "-c
... (904 more chars)uv run python -c "
import triton
import triton.language as tl
print('has dot', hasattr(tl, 'dot'))
# Check bf16 support
import torch
x = torch.randn(16,16, device='cuda', dtype=torch.bfloat16)
print('bf16 cuda ok')
"null{
"path": "/workspace/src/eval/numeric_stress.py"
}{
"targetDirectory": "/workspace/problems",
"globPattern": "**/solution.py"
}"""Numeric distribution stress cases for correctness checks.
These cases are not hidden shapes. They rerun the same shape/seed validation
under a few scale regimes that catch kernels that only work for the nominal
N(0, 1)-ish inputs.
"""
from __future__ import annotations
import os
from collections.abc import Iterator, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
ToleranceOverride = Mapping[str, float | Mapping[str, float]]
@dataclass(frozen=True)
class NumericStressCase:
name: str
input_scales: Mapping[int, float] = field(default_factory=dict)
state_scales: Mapping[str, float] = field(default_factory=dict)
tolerance: ToleranceOverride | None = None
NOMINAL_CASE = NumericStressCase("nominal")
_SMALL_BF16 = {"bfloat16": {"atol": 1e-4, "rtol": 5e-2}}
_MED_BF16 = {"bfloat16": {"atol": 5e-4, "rtol": 5e-2}}
_LARGE_BF16 = {"bfloat16": {"atol": 2e-1, "rtol": 5e-2}}
_TINY_FP32 = {"float32": {"atol": 1e-7, "rtol": 1e-4}}
_LARGE_FP32 = {"float32": {"atol": 1e-1, "rtol": 1e-4}}
_CASES: dict[str, tuple[NumericStressCase, ...]] = {
"01_fp8_gemm": (
# fp8 x fp8 has accumulation-order noise vs the bf16-matmul reference that
# scales with input magnitude; atols are calibrated to absorb it (measured
# fp8-MMA residual x ~1.5 margin) while rtol=5e-2 still catches gross error.
NumericStressCase("small_input", input_scales={0: 1e-3},
tolerance={
... (4968 more chars)null{
"command": "uv pip show cutlass 2>&1; uv pip show nvidia-cutlass-dsl 2>&1; uv pip show cutlass-dsl 2>&1; ls /workspace/.venv/lib/python3.11/site-packages/ | rg -i cutlass",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_c1da35ac-3690-444c-b304-bd839aba35e",
"simpleCommands": [
"uv",
"uv",
"uv",
"ls",
"rg"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value
... (2361 more chars)failure:
command: 'uv pip show cutlass 2>&1; uv pip show nvidia-cutlass-dsl 2>&1; uv pip show cutlass-dsl 2>&1; ls /workspace/.venv/lib/python3.11/site-packages/ | rg -i cutlass'
workingDirectory: ''
exitCode: 127
signal: ''
stdout:
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: cutlass
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: nvidia-cutlass-dsl
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: cutlass-dsl
stderr:
--: line 1: rg: command not found
executionTime: 138
interleavedOutput:
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: cutlass
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: nvidia-cutlass-dsl
Using Python 3.11.15 environment at: /workspace/.venv
warning: Package(s) not found for: cutlass-dsl
--: line 1: rg: command not found
aborted: False
localExecutionTimeMs: 73
isBackground: False{
"command": "uv pip install nvidia-cutlass-dsl 2>&1 | tail -20",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_c33aeaf2-f5b3-4533-af79-c9308f0a06e",
"simpleCommands": [
"uv",
"tail"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value": "pip"
},
{
"type": "word",
"value": "install"
},
{
"type"
... (858 more chars){
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh",
"limit": 150
}{
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_grouped.py",
"limit": 150
}# mypy: allow-untyped-defs
import logging
from dataclasses import asdict, dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
from ..ir import ChoiceCaller, Layout, TensorBox
from ..lowering import register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
realize_inputs,
TritonTemplate,
)
from ..utils import (
get_gpu_shared_memory,
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_nv_universal_gemm_template,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten
@dataclass
class Config:
kwargs: dict[str, int]
num_stages: int
num_warps: int
_NV_CONFIGS = [
Config(
{
"BLOCK_M": block_size_m,
"BLOCK_N": block_size_n,
"BLOCK_K": block_size_k,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [16, 32, 64, 128]
... (2546 more chars)#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <cutlass/util/packed_stride.hpp>
namespace at::cuda::detail {
using Strides = std::array<int64_t, 3>;
template <
typename DtypeA,
typename DtypeB,
typename DtypeOutput,
typename DtypeScale,
typename ProblemShape,
typename StrideA,
typename StrideB,
typename StrideOutput>
__global__ void prepare_grouped_gemm_data(
DtypeA* A,
DtypeB* B,
DtypeOutput* output,
DtypeScale* scale_A,
DtypeScale* scale_B,
DtypeA** A_ptrs,
DtypeB** B_ptrs,
DtypeOutput** output_ptrs,
DtypeScale** inputA_scale_ptrs,
DtypeScale** inputB_scale_ptrs,
ProblemShape* problem_sizes,
// Strides for cutlass, cute::Stride
StrideA* stride_A,
StrideB* stride_B,
StrideOutput* stride_output,
const int32_t* offs,
int32_t M,
int32_t N,
int32_t K,
// Original strides of the input tensors
Strides tensor_StrideA,
Strides tensor_StrideB,
Strides tensor_StrideOutput,
Strides tensor_ShapeA,
Strides tensor_ShapeB,
int64_t a_scale_stride,
int64_t b_scale_stride,
bool a_row_major = true,
bool b_row_major = false) {
int32_t tid = threadIdx.x;
int32_t delta = 0;
int32_t offset = 0;
if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1];
offset = offs[tid];
delta = offset - start;
CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0
... (4689 more chars)null{
"targetDirectory": "/workspace/.venv/lib/python3.11/site-packages/torch",
"globPattern": "**/triton_mm_grouped*"
}null{
"command": "<220 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_ed14f75a-a364-48c5-a052-261b1c9b94c",
"simpleCommands": [
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value": "run"
},
{
"type": "word",
"value": "python"
},
{
"type": "word",
"value": "-c
... (908 more chars)uv run python -c "
import cutlass
import cutlass.cute as cute
print('cutlass', cutlass.__version__ if hasattr(cutlass,'__version__') else 'ok')
import cutlass.utils.blackwell_helpers as bh
print('blackwell helpers ok')
"null{
"taskId": "248045",
"blockUntilMs": 60000,
"regex": "blackwell|Error|ok"
}{
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja",
"limit": 200
}{% macro assign_maybe_constexpr(name, value_expr) -%}
{%- set value_str = value_expr | string -%}
{%- set sentinel = "__NOT_A_NUMBER__" -%}
{%- set as_int = value_str | int(default=sentinel) -%}
{%- set as_float = value_str | float(default=sentinel) -%}
{%- set is_constexpr = (as_int != sentinel) or (as_float != sentinel) -%}
{{ name }}{{ ": tl.constexpr" if is_constexpr else "" }} = {{ value_expr }}
{%- endmacro %}
import triton
import triton.language as tl
@triton.jit
def do_tma_loads(
g, a_desc, b_desc, m_offset, n_offset, k_offset,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
{%- if A_IS_2D %}
{%- if A_IS_K_MAJOR %}
a = a_desc.load([m_offset, k_offset])
{%- else %}
a = a_desc.load([k_offset, m_offset])
{%- endif %}
{%- else %}
{%- if A_IS_K_MAJOR %}
a = a_desc.load([g, m_offset, k_offset]).reshape(BLOCK_M, BLOCK_K)
{%- else %}
a = a_desc.load([g, k_offset, m_offset]).reshape(BLOCK_K, BLOCK_M)
{%- endif %}
{%- endif %}
{%- if B_IS_2D %}
{%- if B_IS_K_MAJOR %}
b = b_desc.load([n_offset, k_offset])
{%- else %}
b = b_desc.load([k_offset, n_offset])
{%- endif %}
{%- else %}
{%- if B_IS_K_MAJOR %}
b = b_desc.load([g, n_offset, k_offset]).reshape(BLOCK_N, BLOCK_K)
{%- else %}
b = b_desc.load([g, k_offset, n_offset]).reshape(BLOCK_K, BLOCK_N)
{%- endif %}
{%- endif %}
return (a, b)
@triton.jit
def do_mma(a, b, accumulator):
{%- if USE_FAST_ACCUM %}
{%- if A_IS_K_MAJOR and B_IS_K_MA
... (4333 more chars)null{
"path": "/workspace/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja",
"offset": 200
}{
"path": "/home/agent/.cursor/projects/workspace-problems-06-sonic-moe-swiglu/terminals/248045.txt"
}{%- endif %}
iterated_tiles = 0
for g in tl.range(G):
{%- if M_IS_VARYING %}
# Move across groups
m_start_offset = m_end_offset
m_end_offset = tl.load(offsets_ptr + g)
m_size = m_end_offset - m_start_offset
{%- if SCALED %}
m_scale_start_offset = m_start_offset
{%- endif %}
{%- else %}
m_start_offset = 0
m_size = M
{%- if SCALED %}
m_scale_start_offset = g * M
{%- endif %}
{%- endif %}
{%- if N_IS_VARYING %}
# Move across groups
n_start_offset = n_end_offset
n_end_offset = tl.load(offsets_ptr + g)
n_size = n_end_offset - n_start_offset
{%- if SCALED %}
n_scale_start_offset = n_start_offset
{%- endif %}
{%- else %}
n_start_offset = 0
n_size = N
{%- if SCALED %}
n_scale_start_offset = g * N
{%- endif %}
{%- endif %}
if m_size > 0 and n_size > 0:
{%- if K_IS_VARYING %}
# Move across groups
k_start_offset = k_end_offset
k_end_offset = tl.load(offsets_ptr + g)
k_size = k_end_offset - k_start_offset
{%- else %}
k_start_offset = 0
k_size = K
{%- endif %}
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
... (5151 more chars)---
pid: 976
cwd: "/workspace/problems/06_sonic_moe_swiglu"
command: "uv run python -c \"\nimport cutlass\nimport cutlass.cute as cute\nprint('cutlass', cutlass.__version__ if hasattr(cutlass,'__version__') else 'ok')\nimport cutlass.utils.blackwell_helpers as bh\nprint('blackwell helpers ok')\n\""
started_at: 2026-06-15T18:17:41.511Z
running_for_ms: 55005
---
Uninstalled 1 package in 1ms
warning: Failed to hardlink files; falling back to full copy. This may lead to degraded performance.
If the cache and target directories are on different filesystems, hardlinking may not be supported.
If this is intentional, set `export UV_LINK_MODE=copy` or use `--link-mode=copy` to suppress this warning.
Installed 1 package in 7ms
cutlass 4.5.2
blackwell helpers ok
---
exit_code: 0
elapsed_ms: 59396
ended_at: 2026-06-15T18:18:40.907Z
---
--- /dev/null
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,214 @@
+"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
+
+Per expert e:
+ h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _num_sms() -> int:
+ props = torch.cuda.get_device_properties(0)
+ return props.multi_processor_count
+
+
+@triton.jit
+def _grouped_swiglu_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ offsets_ptr,
+ E,
+ H,
+ I,
+ stride_am,
+ stride_ak,
+ stride_bg,
+ stride_bh,
+ stride_bi,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+):
+ tidx = tl.program_id(0)
+ iterated_tiles = 0
+
+ for g in tl.range(E):
+ m_start = tl.load(offsets_ptr + g)
+ m_end = tl.load(offsets_ptr + g + 1)
+ m_size = m_end - m_start
+ if m_size <= 0:
+ continue
+
+ num_m_tiles = tl.cdiv(m_size, BLOCK_M)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ num_tiles = num_m_tiles * num_n_tiles
+
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+ gidx = tidx - iterated_tiles
+ tile_m_idx = gidx % num_m_tiles
+ tile_n_idx = gidx // num_m_tiles
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+
+ a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
+ b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = (
+ c_ptr
+ + (m_start + offs_am[:, None]) * stride_cm
+ + offs_bn[None, :] * stride_cn
+ )
+ c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+ tidx += NUM_SMS
+
+ iterated_tiles += num_tiles
+
+
+def _pick_config(H: int, I: int, E: int) -> dict:
+ """Heuristic tile sizes tuned for the three benchmark shapes."""
+ if I >= 4096:
+ return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8}
+ if H >= 4096:
+ return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8}
+ return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8}
+
+
+def grouped_swiglu(
+ hidden_states: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ expert_offsets: torch.Tensor,
+) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ E, H_w, I = W_gate.shape
+ assert H == H_w and W_up.shape == W_gate.shape
+
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+
+ cfg = _pick_config(H, I, E)
+ num_sms = _num_sms()
+ grid = (num_sms,)
+
+ _grouped_swiglu_kernel[grid](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ expert_offsets,
+ E,
+ H,
+ I,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ W_gate.stride(0),
+ W_gate.stride(1),
+ W_gate.stride(2),
+ out.stride(0),
+ out.stride(1),
+ **{k: v for k, v in cfg.items() if k.startswith("BLOCK_")},
+ NUM_SMS=num_sms,
+ num_warps=cfg["num_warps"],
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(self.W_up, std=0.02)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+ return grouped_swiglu(
+ hidden_states.contiguous(),
+ self.W_gate.contiguous(),
+ self.W_up.contiguous(),
+ expert_offsets.contiguous(),
+ )
+
+
+T_total = 32768
+H = 4096
+I = 1536 # noqa: E741
+E = 128
+K = 8
+
+
+def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
+ T_perm = T_total * K
+ base = T_perm // E
+ rem = T_perm - base * E
+ counts = torch.full((E,), base, dtype=torch.int32, device=device)
+ counts[:rem] += 1
+ offsets = torch.zeros(E + 1, dtype=
... (374 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<6070 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_59518e2d-bcd3-47c4-a390-7e59abdc136",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06
... (911 more chars)failure:
command: 'cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/06_sonic_moe_swiglu/check.py", line 117, in <module>
main()
File "/workspace/problems/06_sonic_moe_swiglu/check.py", line 78, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/06_sonic_moe_swiglu/solution.py", line 180, in forward
return grouped_swiglu(
^^^^^^^^^^^^^^^
File "/workspace/problems/06_sonic_moe_swiglu/solution.py", line 139, in grouped_swiglu
_grouped_swiglu_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/run
... (4529 more chars)--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,214 +1,11 @@
-"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
-
-Per expert e:
- h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-"""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-def _num_sms() -> int:
- props = torch.cuda.get_device_properties(0)
- return props.multi_processor_count
-
-
-@triton.jit
-def _grouped_swiglu_kernel(
- a_ptr,
- b_gate_ptr,
- b_up_ptr,
- c_ptr,
- offsets_ptr,
- E,
- H,
- I,
- stride_am,
- stride_ak,
- stride_bg,
- stride_bh,
- stride_bi,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- NUM_SMS: tl.constexpr,
-):
- tidx = tl.program_id(0)
- iterated_tiles = 0
-
for g in tl.range(E):
m_start = tl.load(offsets_ptr + g)
m_end = tl.load(offsets_ptr + g + 1)
m_size = m_end - m_start
- if m_size <= 0:
- continue
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(I, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
- while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
- gidx = tidx - iterated_tiles
- tile_m_idx = gidx % num_m_tiles
- tile_n_idx = gidx // num_m_tiles
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
-
- a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
- acc_gate = tl.dot(a, bg, acc_gate)
- acc_up = tl.dot(a, bu, acc_up)
-
- gate = acc_gate
- silu_gate = gate * tl.sigmoid(gate)
- c = (silu_gate * acc_up).to(tl.bfloat16)
-
- c_ptrs = (
- c_ptr
- + (m_start + offs_am[:, None]) * stride_cm
- + offs_bn[None, :] * stride_cn
- )
- c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
- tl.store(c_ptrs, c, mask=c_mask)
-
- tidx += NUM_SMS
-
- iterated_tiles += num_tiles
-
-
-def _pick_config(H: int, I: int, E: int) -> dict:
- """Heuristic tile sizes tuned for the three benchmark shapes."""
- if I >= 4096:
- return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8}
- if H >= 4096:
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8}
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8}
-
-
-def grouped_swiglu(
- hidden_states: torch.Tensor,
- W_gate: torch.Tensor,
- W_up: torch.Tensor,
- expert_offsets: torch.Tensor,
-) -> torch.Tensor:
- T_perm, H = hidden_states.shape
- E, H_w, I = W_gate.shape
- assert H == H_w and W_up.shape == W_gate.shape
-
- out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
-
- cfg = _pick_config(H, I, E)
- num_sms = _num_sms()
- grid = (num_sms,)
-
- _grouped_swiglu_kernel[grid](
- hidden_states,
- W_gate,
- W_up,
- out,
- expert_offsets,
- E,
- H,
- I,
- hidden_states.stride(0),
- hidden_states.stride(1),
- W_gate.stride(0),
- W_gate.stride(1),
- W_gate.stride(2),
- out.stride(0),
- out.stride(1),
- **{k: v for k, v in cfg.items() if k.startswith("BLOCK_")},
- NUM_SMS=num_sms,
- num_warps=cfg["num_warps"],
- )
- return out
-
-
-class Model(nn.Module):
- def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
- super().__init__()
- self.T_total = T_total
- self.H = H
- self.I = I
- self.E = E
- self.K = K
- self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
- self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
- nn.init.normal_(self.W_gate, std=0.02)
- nn.init.normal_(self.W_up, std=0.02)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- expert_offsets: torch.Tensor,
- ) -> torch.Tensor:
- return grouped_swiglu(
- hidden_states.contiguous(),
- self.W_gate.contiguous(),
- self.W_up.contiguous(),
- expert_offsets.contiguous(),
- )
-
-
-T_total = 32768
-H = 4096
-I = 1536 # noqa: E741
-E = 128
-K = 8
-
-
-def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
- T_perm = T_total * K
- base = T_perm // E
- rem = T_perm - base * E
- counts = torch.full((E,), base, dtype=torch.int32, device=device)
- counts[:rem]
... (521 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<384 chars \u2014 see diff>"
}null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,11 +1,3 @@
- for g in tl.range(E):
- m_start = tl.load(offsets_ptr + g)
- m_end = tl.load(offsets_ptr + g + 1)
- m_size = m_end - m_start
+ tidx += NUM_SMS
- num_m_tiles = tl.cdiv(m_size, BLOCK_M)
- num_n_tiles = tl.cdiv(I, BLOCK_N)
- num_tiles = num_m_tiles * num_n_tiles
-
- if m_size > 0:
- while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:+ iterated_tiles += num_tiles{
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<64 chars \u2014 see diff>"
}path: ''
error: 'Incorrect tool arguments'
modelVisibleError: 'old_string and new_string are exactly the same'{
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"offset": 40,
"limit": 80
}):
tidx = tl.program_id(0)
iterated_tiles = 0
for g in tl.range(E):
m_start = tl.load(offsets_ptr + g)
m_end = tl.load(offsets_ptr + g + 1)
m_size = m_end - m_start
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(I, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
if m_size > 0:
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
for k_block in range(0, H, BLOCK_K):
offs_k = k_block + tl.arange(0, BLOCK_K)
a_ptrs = (
a_ptr
+ (m_start + offs_am[:, None]) * stride_am
+ offs_k[None, :] * stride_ak
)
bg_ptrs = (
b_gate_ptr
+ g * stride_bg
+ offs_k[:, None] * stride_bh
+ offs_bn[None, :] * stride_bi
)
bu_ptrs = (
b_up_ptr
+ g * stride_bg
+ offs_k[:, None] * stride_bh
... (1363 more chars){
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_4580d077-1a7e-4898-b8a3-e2661cdda24",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06
... (917 more chars)null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_9bd77355-b906-49e0-89d8-b7107ddaef8",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problem
... (927 more chars){
"pattern": "RTX_PRO_6000|peak_tflops",
"path": "/workspace/src/hardware",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_106269ad-a705-4be1-b078-4e79cc96699",
"offset": 0
}nullnull{
"path": "/workspace/src/hardware/rtx_pro_6000.py"
}"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
name: str
sm: str
vram_gb: int
peak_bandwidth_gb_s: float # DRAM
peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
# Blackwell GB202 dense tensor peaks. Derived from NVIDIA's headline
# 4000 fp4-sparse AI TOPS -> fp4 dense 2000 -> fp8 1000 -> bf16 500
# (each precision step = 2x; sparse = 2x dense). Verified empirically:
# cuBLAS hits fp8 773 / bf16 412 TFLOPS on 4096^3 (~77-82% of these,
# normal cuBLAS efficiency). The prior table was ~2.5x too low, which
# produced peak_fraction > 1.0 for real fp8 kernels. fp32 is the 125
# TFLOPS SIMT figure from the spec sheet.
"fp4": 2000.0,
"nvfp4": 2000.0,
"mxfp4": 2000.0,
"fp6": 1000.0,
"fp8": 1000.0,
"bf16": 500.0,
"fp16": 500.0,
"tf32": 250.0,
"fp32": 125.0, # non-tensor-core SIMT fp32 (spec sheet)
"int8": 1000.0,
"int4": 2000.0,
},
)
{
"command": "<379 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_ba38a10a-0015-482f-9b31-8c9f63c2300",
"simpleCommands": [
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value": "run"
},
{
"type": "word",
"value": "python"
},
{
"type": "word",
"value": "-c
... (1241 more chars)uv run python -c "
import triton
print('tma', hasattr(triton.language, 'make_tensor_descriptor'))
import triton.language as tl
print('experimental', hasattr(tl, '_experimental_make_tensor_descriptor'))
# num sms
import torch
print('sms', torch.cuda.get_device_properties(0).multi_processor_count)
print('smem', torch.cuda.get_device_properties(0).shared_memory_per_block_optin)
"null{
"taskId": "19302",
"blockUntilMs": 30000,
"regex": "sms|Error"
}{
"targetDirectory": "/workspace/.venv/lib/python3.11/site-packages",
"globPattern": "**/grouped_gemm*.py"
}nullnull{
"targetDirectory": "/workspace/.venv/lib/python3.11/site-packages/nvidia_cutlass_dsl",
"globPattern": "**/*.py"
}{
"path": "/home/agent/.cursor/projects/workspace-problems-06-sonic-moe-swiglu/terminals/19302.txt"
}null---
pid: 1516
cwd: "/workspace/problems/06_sonic_moe_swiglu"
command: "uv run python -c \"\nimport triton\nprint('tma', hasattr(triton.language, 'make_tensor_descriptor'))\nimport triton.language as tl\nprint('experimental', hasattr(tl, '_experimental_make_tensor_descriptor'))\n# num sms\nimport torch\nprint('sms', torch.cuda.get_device_properties(0).multi_processor_count)\nprint('smem', torch.cuda.get_device_properties(0).shared_memory_per_block_optin)\n\""
started_at: 2026-06-15T18:24:10.634Z
running_for_ms: 60002
---
tma True
experimental False
sms 188
smem 101376
---
exit_code: 0
elapsed_ms: 60561
ended_at: 2026-06-15T18:25:11.195Z
---
{
"path": "/workspace/.venv/lib/python3.11/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/torch.py",
"limit": 100
}# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import ctypes
from math import prod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, Tuple, Type, Union
from cutlass.cute.typing import (
Numeric,
Boolean,
TFloat32,
Float8E4M3B11FNUZ,
Float8E4M3FN,
Float8E5M2,
Float8E8M0FNU,
Float4E2M1FN,
Float6E3M2FN,
Float6E2M3FN,
Int4,
Tensor,
)
from cutlass.cute.runtime import from_dlpack
import cutlass.cute as cute
import torch
import cuda.bindings.driver as cuda
def dtype(ty: Type[Numeric]) -> "torch.dtype":
"""
Return the corresponding torch.dtype per the given DSL type
"""
torch_dtype = getattr(torch, ty.__name__.lower(), None)
torch_type_map = {
Boolean: torch.bool,
# TFloat32 is just alias of float32
TFloat32: torch.float32,
Float8E5M2: torch.float8_e5m2,
Float8E4M3FN: torch.float8_e4m3fn,
Float8E4M3B11FNUZ: torch.float8_e4m3fnuz,
}
... (1347 more chars){
"pattern": "GroupedGemm|grouped_gemm",
"path": "/workspace/.venv/lib/python3.11/site-packages/nvidia_cutlass_dsl",
"caseInsensitive": false,
"headLimit": 30,
"multiline": false,
"toolCallId": "tool_3d09542d-e2ba-4a8e-8173-9d3e2cfc7f8",
"offset": 0
}null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,3 +1,293 @@
- tidx += NUM_SMS
-
- iterated_tiles += num_tiles+"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
+
+Per expert e:
+ h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _num_sms() -> int:
+ return torch.cuda.get_device_properties(0).multi_processor_count
+
+
+# ---------------------------------------------------------------------------
+# Per-expert tiled kernel (best when each expert has enough tokens).
+# ---------------------------------------------------------------------------
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _expert_swiglu_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ M,
+ N,
+ K,
+ m_start,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k in range(0, K, BLOCK_K):
+ offs_k = k + tl.arange(0, BLOCK_K)
+ a_ptrs = (
+ a_ptr + (m_start + offs_m[:, None]) * stride_am + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = b_gate_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+ bu_ptrs = b_up_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+ a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+ b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = c_ptr + (m_start + offs_m[:, None]) * stride_cm + offs_n[None, :] * stride_cn
+ mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ tl.store(c_ptrs, c, mask=mask)
+
+
+# ---------------------------------------------------------------------------
+# Persistent grouped kernel (handles variable M per expert in one launch).
+# ---------------------------------------------------------------------------
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
+ ],
+ key=["H", "I", "E"],
+)
+@triton.jit
+def _grouped_swiglu_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ offsets_ptr,
+ E,
+ H,
+ I,
+ stride_am,
+ stride_ak,
+ stride_bg,
+ stride_bh,
+ stride_bi,
+ stride_cm,
+ stride_cn,
+ NUM_SMS: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ tidx = tl.program_id(0)
+ iterated_tiles = 0
+
+ for g in tl.range(E):
+ m_start = tl.load(offsets_ptr + g)
+ m_end = tl.load(offsets_ptr + g + 1)
+ m_size = m_end - m_start
+
+ num_m_tiles = tl.cdiv(m_size, BLOCK_M)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ num_tiles = num_m_tiles * num_n_tiles
+
+ if m_size > 0:
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+ gidx = tidx - iterated_tiles
+ tile_m_idx = gidx % num_m_tiles
+ tile_n_idx = gidx // num_m_tiles
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + g * str
... (3894 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<9400 chars \u2014 see diff>"
}null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,293 +1 @@
-"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
-
-Per expert e:
- h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-"""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-def _num_sms() -> int:
- return torch.cuda.get_device_properties(0).multi_processor_count
-
-
-# ---------------------------------------------------------------------------
-# Per-expert tiled kernel (best when each expert has enough tokens).
-# ---------------------------------------------------------------------------
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
- ],
- key=["M", "N", "K"],
-)
-@triton.jit
-def _expert_swiglu_kernel(
- a_ptr,
- b_gate_ptr,
- b_up_ptr,
- c_ptr,
- M,
- N,
- K,
- m_start,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- pid_m = pid % num_pid_m
- pid_n = pid // num_pid_m
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k in range(0, K, BLOCK_K):
- offs_k = k + tl.arange(0, BLOCK_K)
- a_ptrs = (
- a_ptr + (m_start + offs_m[:, None]) * stride_am + offs_k[None, :] * stride_ak
- )
- bg_ptrs = b_gate_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
- bu_ptrs = b_up_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
- a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
- b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
- acc_gate = tl.dot(a, bg, acc_gate)
- acc_up = tl.dot(a, bu, acc_up)
-
- gate = acc_gate
- silu_gate = gate * tl.sigmoid(gate)
- c = (silu_gate * acc_up).to(tl.bfloat16)
-
- c_ptrs = c_ptr + (m_start + offs_m[:, None]) * stride_cm + offs_n[None, :] * stride_cn
- mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- tl.store(c_ptrs, c, mask=mask)
-
-
-# ---------------------------------------------------------------------------
-# Persistent grouped kernel (handles variable M per expert in one launch).
-# ---------------------------------------------------------------------------
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
- ],
- key=["H", "I", "E"],
-)
-@triton.jit
-def _grouped_swiglu_kernel(
- a_ptr,
- b_gate_ptr,
- b_up_ptr,
- c_ptr,
- offsets_ptr,
- E,
- H,
- I,
- stride_am,
- stride_ak,
- stride_bg,
- stride_bh,
- stride_bi,
- stride_cm,
- stride_cn,
- NUM_SMS: tl.constexpr,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- tidx = tl.program_id(0)
- iterated_tiles = 0
-
- for g in tl.range(E):
- m_start = tl.load(offsets_ptr + g)
- m_end = tl.load(offsets_ptr + g + 1)
- m_size = m_end - m_start
-
- num_m_tiles = tl.cdiv(m_size, BLOCK_M)
- num_n_tiles = tl.cdiv(I, BLOCK_N)
- num_tiles = num_m_tiles * num_n_tiles
-
- if m_size > 0:
- while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
- gidx = tidx - iterated_tiles
- tile_m_idx = gidx % num_m_tiles
- tile_n_idx = gidx // num_m_tiles
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
-
... (3923 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": " triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64}, num_stages=4, num_warps=8),"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_8a35cbd1-cfcf-42d2-b6c5-3f80cfa552d",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06
... (927 more chars)failure:
command: 'cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42 case nominal: tolerance exceeded: max_abs_diff=0.194336 max_rel_diff=5.21541e+24 bad=9668829/402653184 worst_idx=(96606, 470) (atol=0.02, rtol=0.02)
stderr: ''
executionTime: 84304
interleavedOutput:
FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42 case nominal: tolerance exceeded: max_abs_diff=0.194336 max_rel_diff=5.21541e+24 bad=9668829/402653184 worst_idx=(96606, 470) (atol=0.02, rtol=0.02)
aborted: False
localExecutionTimeMs: 84234
isBackground: False--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1 +1,15 @@
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=8),+ # Per-expert path: one launch per expert with autotuned tile GEMM.
+ offsets = expert_offsets.tolist()
+ for e in range(E):
+ m_start = offsets[e]
+ m_end = offsets[e + 1]
+ m_size = m_end - m_start
+ if m_size == 0:
+ continue
+
+ def grid(meta):
+ return (
+ triton.cdiv(m_size, meta["BLOCK_M"]) * triton.cdiv(I, meta["BLOCK_N"]),
+ )
+
+ _expert_swiglu_kernel[grid]({
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<455 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_aad5aa7c-f2aa-4176-a856-6904523ccf8",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06
... (914 more chars)null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_707beb2b-a2aa-4dce-8b17-610d880775f",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problem
... (930 more chars)null{
"command": "<1737 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_3bbbe7a9-d86a-4256-8a96-2ae30380793",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": true,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06_sonic_moe_swiglu"
},
{
... (752 more chars)cd /workspace/problems/06_sonic_moe_swiglu && uv run python - <<'PY'
import torch, triton, time
import reference, solution, shapes
device = torch.device('cuda')
shape = shapes.SHAPES[0]
for k,v in shape.items():
setattr(reference, k, v)
setattr(solution, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict())
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# warmup
for _ in range(3):
solution.grouped_swiglu(inputs[0], sol_model.W_gate, sol_model.W_up, inputs[1])
torch.cuda.synchronize()
# time persistent vs per-expert by temporarily swapping
import importlib
mod = importlib.import_module('solution')
# benchmark current
st=time.perf_counter();
for _ in range(10):
mod.grouped_swiglu(inputs[0], sol_model.W_gate, sol_model.W_up, inputs[1])
torch.cuda.synchronize();
print('current ms', (time.perf_counter()-st)*100)
# benchmark persistent kernel only
E,H,I = shape['E'], shape['H'], shape['I']
T_perm = shape['T_total']*shape['K']
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
st=time.perf_counter();
for _ in range(10):
mod._grouped_swiglu_kernel[(num_sms,)](
inputs[0], sol_model.W_gate, sol_model.W_up, out, inputs[1],
E,H,I,
inputs[0].stride(0), inputs[0].stride(1),
sol_model.W_gate.stride(0), sol_model.W_gate.stride(1), sol_model.W_gate.stride(2),
out.stride(0), out.stride(1),
NUM_SMS=num_sms,
)
torch.cuda.synchronize();
print('persistent ms', (time.perf_counter()-st)*100)
PY{
"path": "/workspace/.venv/lib/python3.11/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/utils/blackwell_helpers.py",
"limit": 150
}# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Any, List, Optional, Type, Union, Tuple, overload
from typing_extensions import deprecated
from cutlass.cutlass_dsl import (
Float16,
BFloat16,
TFloat32,
Float32,
Uint8,
Int8,
Float8E4M3FN,
Float8E5M2,
Float6E3M2FN,
Float6E2M3FN,
Float4E2M1FN,
Numeric,
NumericMeta,
dsl_user_op,
)
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass.cute.nvgpu import OperandMajorMode
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp
from cutlass.cute.nvgpu.tcgen05 import (
MmaF16BF16Op,
MmaTF32Op,
MmaI8Op,
MmaF8F6F4Op,
MmaMXF8F6F4Op,
MmaMXF4Op,
MmaMXF4NVF4Op,
SM103MmaMXF4Op,
SM103MmaMXF4NVF4Op,
OperandSource as Tcgen05OperandSource,
CtaGroup,
Ld16x64bOp,
Ld16x128bOp,
Ld16x256bOp,
Ld16x32bx2Op,
Ld32x32bOp,
Repetition,
Pack,
SmemL
... (3447 more chars)null{
"command": "git clone --depth 1 https://github.com/Dao-AILab/quack.git /tmp/quack 2>&1 | tail -5",
"workingDirectory": "",
"timeout": 60000,
"toolCallId": "tool_50c7e552-bc52-4fb4-8bb8-c664363f1eb",
"simpleCommands": [
"git",
"tail"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "git",
"args": [
{
"type": "word",
"value": "clone"
},
{
"type": "word",
"value": "--depth"
... (1116 more chars)--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,15 +1,209 @@
- # Per-expert path: one launch per expert with autotuned tile GEMM.
- offsets = expert_offsets.tolist()
- for e in range(E):
- m_start = offsets[e]
- m_end = offsets[e + 1]
+"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
+
+Per expert e:
+ h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _num_sms() -> int:
+ return torch.cuda.get_device_properties(0).multi_processor_count
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4),
+ ],
+ key=["H", "I", "E"],
+)
+@triton.jit
+def _grouped_swiglu_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ offsets_ptr,
+ E,
+ H,
+ I,
+ stride_am,
+ stride_ak,
+ stride_bg,
+ stride_bh,
+ stride_bi,
+ stride_cm,
+ stride_cn,
+ NUM_SMS: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ tidx = tl.program_id(0)
+ iterated_tiles = 0
+
+ for g in tl.range(E):
+ m_start = tl.load(offsets_ptr + g)
+ m_end = tl.load(offsets_ptr + g + 1)
m_size = m_end - m_start
- if m_size == 0:
- continue
-
- def grid(meta):
- return (
- triton.cdiv(m_size, meta["BLOCK_M"]) * triton.cdiv(I, meta["BLOCK_N"]),
- )
-
- _expert_swiglu_kernel[grid](+
+ num_m_tiles = tl.cdiv(m_size, BLOCK_M)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ num_tiles = num_m_tiles * num_n_tiles
+
+ if m_size > 0:
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+ gidx = tidx - iterated_tiles
+ tile_m_idx = gidx % num_m_tiles
+ tile_n_idx = gidx // num_m_tiles
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+
+ a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
+ b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = (
+ c_ptr
+ + (m_start + offs_am[:, None]) * stride_cm
+ + offs_bn[None, :] * stride_cn
+ )
+ c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+ tidx += NUM_SMS
+
+ iterated_tiles += num_tiles
+
+
+def grouped_swiglu(
+ hidden_states: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ expert_offsets: torch.Tensor,
+) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ E, H_w, I = W_gate.shape
+ assert H == H_w and W_up.shape == W_gate.shape
+
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+ num_sms = _num_sms()
+
+ _grouped_swiglu_kernel[(num_sms,)](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ expert_offsets,
+ E,
+ H,
+ I,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ W_gate.stride(0),
+ W_gate.stride(1),
+ W_gate.stride(2),
+ out.stride(0),
+ out.stride(1),
+ NUM_SMS=num_sms,
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.init.normal_(s
... (1108 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<6328 chars \u2014 see diff>"
}nullnull{
"pattern": "swiglu|grouped.*gemm|moe_up",
"path": "/tmp/quack",
"caseInsensitive": true,
"headLimit": 40,
"multiline": false,
"toolCallId": "tool_d74dc5e7-32de-4240-9a63-044dd5dc0e0",
"offset": 0
}{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_bb3ec604-a4b6-4ddc-a1c8-d25c0116af6",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
... (1510 more chars)nullnull{
"pattern": "varlen|grouped|offset",
"path": "/tmp/quack/quack",
"caseInsensitive": false,
"headLimit": 30,
"multiline": false,
"toolCallId": "tool_75f82bb4-c14d-4c75-b451-0365e1a8c93",
"offset": 0
}{
"targetDirectory": "/tmp/quack",
"globPattern": "**/*varlen*"
}nullnull{
"pattern": "linear_gated|gemm_gated",
"path": "/tmp/quack/quack",
"outputMode": "files_with_matches",
"caseInsensitive": false,
"headLimit": 20,
"multiline": false,
"toolCallId": "tool_6a253e9a-9da1-4cec-b916-6ea3986d587",
"offset": 0
}null{
"path": "/tmp/quack/tests/test_linear_varlen_m.py",
"limit": 100
}# Copyright (C) 2025, Tri Dao.
import math
import pytest
import torch
from quack.testing.fake_compat import assert_aliased
from quack.cute_dsl_utils import get_device_capacity
from quack.gemm import gemm as quack_gemm
from quack.gemm_interface import (
gemm,
gemm_ref,
gemm_add,
gemm_add_ref,
gemm_add_inplace,
gemm_act,
gemm_dact,
gemm_act_ref,
gemm_dact_ref,
gemm_gated,
gemm_dgated,
gemm_gated_ref,
gemm_dgated_ref,
)
sm100_tma_gather_only = pytest.mark.skipif(
not torch.cuda.is_available() or get_device_capacity(torch.device("cuda"))[0] not in (10, 11),
reason="TMA gather tests require SM100/SM110",
)
def generate_A_with_gather(total_m, k, device, dtype, gather_A=False):
"""Generate A matrix and optionally A_idx for gather_A case.
Args:
total_m: Number of rows needed
k: Number of columns
device: Device to create tensors on
dtype: Data type of tensors
gather_A: Whether to create gather indices
Returns:
A: Matrix of shape (larger_m, k) if gather_A else (total_m, k)
A_idx: Index tensor of shape (total_m,) if gather_A else None
"""
if gather_A:
# Create random indices for gathering from a larger A matrix
larger_m = total_m * 2 # Make A larger than needed
A = torch.randn((larger_m, k), device=device, dtype=dtype)
# Create random indices to gather from A
A_idx = torch.randperm(larger_m, device=device,
... (1245 more chars){
"path": "/tmp/quack/quack/gemm_interface.py",
"limit": 150
}# Copyright (c) 2025, Tri Dao
import os
from typing import Optional, Tuple, Literal
from functools import partial
import torch
import torch.nn.functional as F
from torch import Tensor
from quack.gemm_config import GemmConfig, get_all_configs
from quack.autotuner import autotune, AutotuneConfig
from quack.cute_dsl_utils import get_device_capacity
from quack.gemm import gemm as gemm_dispatch
from quack.gemm_act import gemm_act as gemm_act_dispatch
from quack.gemm_dact import gemm_dact as gemm_dact_dispatch
from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_dispatch
from quack.gemm_sq_reduce import gemm_sq_reduce as gemm_sq_reduce_dispatch
from quack.gemm_norm_act import gemm_norm_act_fn as gemm_norm_act_dispatch
from quack.rms_final_reduce import rms_final_reduce
from quack.rounding import RoundingMode
def _empty_k_matmul_into(
out: Tensor,
*,
bias: Optional[Tensor] = None,
C: Optional[Tensor] = None,
beta: float | Tensor = 1.0,
) -> None:
"""K=0 fast path: write `beta * C + bias` (or zero if neither) into `out`.
Used by every gemm-flavored wrapper to skip a kernel launch when the
contraction dim is empty. The matmul A @ B contributes zero, so the only
remaining terms are the C term and the (broadcast) bias.
"""
if C is not None:
if isinstance(beta, float) and beta == 1.0:
out.copy_(C)
else:
torch.mul(C, beta, out=out)
else:
out.zero_()
if bias is not None:
... (2852 more chars){
"pattern": "cu_seqlens|varlen_m",
"path": "/tmp/quack/quack",
"caseInsensitive": false,
"headLimit": 30,
"multiline": false,
"toolCallId": "tool_8088ed3e-e018-4136-b649-04664b4016d",
"offset": 0
}null{
"command": "uv pip install quack-kernels 2>&1 | tail -15",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_2e299d36-692c-44d0-a883-3274da471d8",
"simpleCommands": [
"uv",
"tail"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "uv",
"args": [
{
"type": "word",
"value": "pip"
},
{
"type": "word",
"value": "install"
},
{
"type": "wo
... (857 more chars){
"path": "/tmp/quack/quack/gemm_interface.py",
"offset": 145,
"limit": 80
}elif cap == 12:
return GemmConfig(
tile_m=128,
tile_n=128,
cluster_m=1,
cluster_n=1,
pingpong=True,
is_dynamic_persistent=True,
device_capacity=12,
)
else:
return GemmConfig(
tile_m=128,
tile_n=192,
cluster_m=2,
cluster_n=1,
pingpong=True,
is_dynamic_persistent=False,
)
def nvmmh_config(A, B, device_capacity):
"""Use nvMatmulHeuristics to pick a config for pure GEMM (no varlen/gather/epilogue).
Returns None if unavailable, caller should fall back to default_config.
"""
try:
from quack.nvmmh_heuristic import nvmmh_default_config
return nvmmh_default_config(A, B, device_capacity)
except Exception:
return None
def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
kwargs = named_args | kwargs
device_capacity = get_device_capacity(kwargs["A"].device)[0]
configs = [conf for conf in configs if conf.kwargs["config"].device_capacity == device_capacity]
gather_A = kwargs.get("A_idx", None) is not None
varlen_m = kwargs.get("cu_seqlens_m", None) is not None
if varlen_m or gather_A: # Doesn't support swap_ab
configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
if gather_A:
configs = [conf for conf in configs if conf.kwargs["config"].cluster_n == 1]
if device_capa
... (1878 more chars)null{
"pattern": "def gemm_gated",
"path": "/tmp/quack/quack/gemm_interface.py",
"caseInsensitive": false,
"headLimit": 20,
"multiline": false,
"toolCallId": "tool_e0e35520-dc92-48c0-a79d-72af6f315ce",
"offset": 0
}null{
"command": "<1331 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_6d4c10ef-0004-4273-b108-631bc21d58d",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": true,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06_sonic_moe_swiglu"
},
{
... (750 more chars)cd /workspace/problems/06_sonic_moe_swiglu && uv run python - <<'PY'
import torch, time
import reference, shapes
from quack.gemm_interface import gemm_gated
device = torch.device('cuda')
shape = shapes.SHAPES[0]
for k,v in shape.items():
setattr(reference, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
hidden, offsets = inputs
T_perm, H = hidden.shape
E, _, I = ref_model.W_gate.shape
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
# Stack gate and up weights interleaved for gated gemm? Check quack API
# gemm_gated likely expects B with 2*N for gate+up
# Try: B shape (E, H, 2*I) with gate and up concatenated on last dim
W = torch.cat([ref_model.W_gate, ref_model.W_up], dim=-1).contiguous()
# warmup
for _ in range(3):
gemm_gated(hidden, W, out, cu_seqlens_m=offsets, activation='swiglu')
torch.cuda.synchronize()
st=time.perf_counter()
for _ in range(20):
gemm_gated(hidden, W, out, cu_seqlens_m=offsets, activation='swiglu')
torch.cuda.synchronize()
ms=(time.perf_counter()-st)/20*1000
print('quack ms', ms)
# correctness vs ref
with torch.no_grad():
ref_out = ref_model(hidden, offsets)
err = (ref_out.float()-out.float()).abs().max().item()
print('max err', err)
PYfailure:
command:
cd /workspace/problems/06_sonic_moe_swiglu && uv run python - <<'PY'
import torch, time
import reference, shapes
from quack.gemm_interface import gemm_gated
device = torch.device('cuda')
shape = shapes.SHAPES[0]
for k,v in shape.items():
setattr(reference, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
hidden, offsets = inputs
T_perm, H = hidden.shape
E, _, I = ref_model.W_gate.shape
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
# Stack gate and up weights interleaved for gated gemm? Check quack API
# gemm_gated likely expects B with 2*N for gate+up
# Try: B shape (E, H, 2*I) with gate and up concatenated on last dim
W = torch.cat([ref_model.W_gate, ref_model.W_up], dim=-1).contiguous()
# warmup
for _ in range(3):
gemm_gated(hidden, W, out, cu_seqlens_m=offsets, activation='swiglu')
torch.cuda.synchronize()
st=time.perf_counter()
for _ in range(20):
gemm_gated(hidden, W, out, cu_seqlens_m=offsets, activation='swiglu')
torch.cuda.synchronize()
ms=(time.perf_counter()-st)/20*1000
print('quack ms', ms)
# correctness vs ref
with torch.no_grad():
ref_out = ref_model(hidden, offsets)
err = (ref_out.float()-out.float()).abs().max().item()
print
... (11027 more chars){
"path": "/tmp/quack/quack/gemm_interface.py",
"offset": 1390,
"limit": 100
}configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")],
key=["activation", "dynamic_scheduler"],
prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_gated_tuned(
# (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
A: Tensor,
B: Tensor, # (K, N) or (L, K, N)
# (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
preact_out: Optional[Tensor],
postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
bias: Optional[Tensor] = None, # (N,) or (L, N)
activation: GatedActivation = "swiglu",
cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
dynamic_scheduler: bool = False,
config: Optional[GemmConfig] = None,
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
) -> None:
if config is None:
config = default_config(A.device)
varlen_m = cu_seqlens_m is not None
if varlen_m:
assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
if A.ndim == 2 and not varlen_m:
A = A.unsqueeze(0) # (1, M, K)
B = B.mT # (N, K) or (L, N, K)
if B.ndim == 2:
B = B.unsqueeze(0) # (1, N, K)
if C is not None and C.ndim == 2 and not va
... (2914 more chars){
"path": "/tmp/quack/tests/test_linear_varlen_m.py",
"offset": 520,
"limit": 80
}A, B, PreAct, activation=activation, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
)
assert (dx - dx_ref).abs().max() < 2 * (dx_pt - dx_ref).abs().max() + 1e-5
assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-5
@pytest.mark.parametrize("pre_allocate_out", [False, True])
@pytest.mark.parametrize("gather_A", [False, True])
@pytest.mark.parametrize("activation", ["swiglu", "geglu"])
@pytest.mark.parametrize("dynamic_scheduler", [False, True])
@pytest.mark.parametrize("B_major", ["k", "n"])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("n", [1024, 1504])
@pytest.mark.parametrize("k", [512, 768])
@pytest.mark.parametrize("num_groups", [2, 4])
def test_gemm_gated_varlen_m(
num_groups,
k,
n,
has_bias,
input_dtype,
B_major,
dynamic_scheduler,
activation,
gather_A,
pre_allocate_out,
):
"""Test GEMM with gated activation and variable length M dimension."""
device = "cuda"
torch.random.manual_seed(42)
seq_lens = torch.randint(50, 300, (num_groups,), device="cpu")
total_m = seq_lens.sum().item()
cu_seqlens_m = torch.cat(
[torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
)
cu_seqlens_m = cu_seqlens_m.to(device)
A, A_idx = generate_A_with_gather(total_m, k, device, input_dtype, gather_A)
B = torch.randn((num_groups, k, n), device=device, dtype=
... (1383 more chars){
"pattern": "^def gemm_gated\\(",
"path": "/tmp/quack/quack/gemm_interface.py",
"contextAfter": 40,
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_8b13b8c7-6c72-4fb9-986f-c652cce81f3",
"offset": 0
}null{
"pattern": "gemm_gated =",
"path": "/tmp/quack/quack/gemm_interface.py",
"caseInsensitive": false,
"headLimit": 10,
"multiline": false,
"toolCallId": "tool_03877782-3208-4453-af17-f9f2d999b49",
"offset": 0
}{
"path": "/tmp/quack/quack/gemm_interface.py",
"offset": 1575,
"limit": 80
}null@torch.library.custom_op(
"quack::gemm_gated_out",
mutates_args=("preact_out", "postact_out"),
device_types="cuda",
schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True, str? concat_layout=None) -> ()",
)
def gemm_gated_out(
A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
B: Tensor, # (K, N) or (L, K, N)
preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m
postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
bias: Optional[Tensor] = None, # (N,) or (L, N)
activation: GatedActivation = "swiglu",
cu_seqlens_m: Optional[Tensor] = None,
A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
dynamic_scheduler: bool = False,
tuned: bool = True,
concat_layout: Optional[str] = None,
) -> None:
"""GEMM with gated activation and pre-allocated output tensors."""
fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None)
fn(
A,
B,
preact_out,
postact_out,
C,
bias,
activation,
cu_seqlens_m,
A_idx,
dynamic_scheduler,
concat_lay
... (1867 more chars){
"path": "/tmp/quack/quack/gemm_interface.py",
"offset": 990,
"limit": 80
}dynamic_scheduler: bool = False,
tuned: bool = True,
concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
) -> Tuple[Optional[Tensor], Tensor]:
"""GEMM with activation (or gated activation) and optional output tensors."""
is_gated = activation in gated_to_pytorch_fn_map
out_dtype = A.dtype if out_dtype is None else out_dtype
postact_dtype = A.dtype if postact_dtype is None else postact_dtype
varlen_m = cu_seqlens_m is not None
# Determine output shape based on gather_A
if varlen_m:
total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
out_shape = (total_m, B.shape[-1])
elif A.ndim == 2:
out_shape = (A.shape[0], B.shape[-1])
else:
out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape
if preact_out is None and store_preact:
preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
if postact_out is None:
postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
# Empty-input fast path. For M=0 or N=0 the outputs are empty; for K=0
# (A@B == 0) the no-bias / no-C surface yields preact=0 and act(0)=0 for
# every supported activation, so both outputs are zero.
if postact_out.numel() == 0 or A.numel() == 0:
if preact_out is not None:
_empty_k_matmul_into(preact_out)
_empty_k_matmu
... (1650 more chars){
"command": "<1435 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 180000,
"toolCallId": "tool_541763bf-e7c1-4077-a975-0a0d6f0ac61",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": true,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
"fullText": "cd /workspace/problems/06_sonic_moe_swiglu"
},
{
... (750 more chars)cd /workspace/problems/06_sonic_moe_swiglu && uv run python - <<'PY'
import torch, time
import reference, shapes
from quack.gemm_interface import gemm_gated
device = torch.device('cuda')
for si, shape in enumerate(shapes.SHAPES):
for k,v in shape.items():
setattr(reference, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
torch.manual_seed(42)
hidden, offsets = [t.to(device) for t in reference.get_inputs()]
T_perm, H = hidden.shape
E, _, I = ref_model.W_gate.shape
W = torch.cat([ref_model.W_gate, ref_model.W_up], dim=-1).contiguous()
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
# warmup with tuned=False
for _ in range(5):
gemm_gated(hidden, W, None, out, cu_seqlens_m=offsets, activation='swiglu', tuned=False, dynamic_scheduler=True)
torch.cuda.synchronize()
st=time.perf_counter()
for _ in range(20):
gemm_gated(hidden, W, None, out, cu_seqlens_m=offsets, activation='swiglu', tuned=False, dynamic_scheduler=True)
torch.cuda.synchronize()
ms=(time.perf_counter()-st)/20*1000
with torch.no_grad():
ref_out = ref_model(hidden, offsets)
err = (ref_out.float()-out.float()).abs().max().item()
flops = 2 * shape['T_total'] * shape['H'] * (2 * shape['I'])
tflops = flops / ms / 1e9
print(f'shape {si} quack ms={ms:.3f} tflops={tflops:.1f} max_err={err:.4f}')
PYfailure:
command:
cd /workspace/problems/06_sonic_moe_swiglu && uv run python - <<'PY'
import torch, time
import reference, shapes
from quack.gemm_interface import gemm_gated
device = torch.device('cuda')
for si, shape in enumerate(shapes.SHAPES):
for k,v in shape.items():
setattr(reference, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
torch.manual_seed(42)
hidden, offsets = [t.to(device) for t in reference.get_inputs()]
T_perm, H = hidden.shape
E, _, I = ref_model.W_gate.shape
W = torch.cat([ref_model.W_gate, ref_model.W_up], dim=-1).contiguous()
out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=device)
# warmup with tuned=False
for _ in range(5):
gemm_gated(hidden, W, None, out, cu_seqlens_m=offsets, activation='swiglu', tuned=False, dynamic_scheduler=True)
torch.cuda.synchronize()
st=time.perf_counter()
for _ in range(20):
gemm_gated(hidden, W, None, out, cu_seqlens_m=offsets, activation='swiglu', tuned=False, dynamic_scheduler=True)
torch.cuda.synchronize()
ms=(time.perf_counter()-st)/20*1000
with torch.no_grad():
ref_out = ref_model(hidden, offsets)
err = (ref_out.float()-out.float()).abs().max().item()
flops = 2 * shape['T_total'] * shape['H'] * (2 * shape['I'])
tflops = flops / ms /
... (7847 more chars){
"pattern": "concat_interleave|2 \\* n|gate.*up",
"path": "/tmp/quack",
"caseInsensitive": false,
"headLimit": 30,
"multiline": false,
"toolCallId": "tool_a998c147-c077-4b59-9a6b-50e30e1b83b",
"offset": 0
}null{
"path": "/tmp/quack/quack/layout_utils.py"
}{
"path": "/tmp/quack/quack/mlp.py",
"offset": 200,
"limit": 80
}if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
return matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
else:
matmul_inplace_fn(dout.T, x, weight_og.grad)
dweight = weight_og.grad
weight_og.grad = None
return dweight
def mlp_func(
x,
weight1,
weight2,
activation: str,
bias1=None,
bias2=None,
fuse_grad_accum=False,
tuned=True,
recompute=False,
concat_layout=False,
):
gated = activation in gate_fn_map
if concat_layout:
assert gated, "concat_layout is only supported for gated MLP"
if recompute:
if concat_layout:
ops = _MLPGatedConcatOps if tuned else _MLPGatedConcatUntunedOps
elif gated:
ops = _MLPGatedOps if tuned else _MLPGatedUntunedOps
else:
ops = _MLPOps if tuned else _MLPUntunedOps
return MLPRecomputeFunc.apply(x, weight1, weight2, activation, fuse_grad_accum, ops)
fc1_fn = linear_gated_func if gated else linear_act_func
fc2_fn = gated_linear_func if gated else act_linear_func
preact, postact = fc1_fn(
x,
weight1,
activation,
bias=bias1,
store_preact=torch.is_grad_enabled(),
fuse_grad_accum=fuse_grad_accum,
tuned=tuned,
**({"concat_layout": concat_layout} if concat_layout and gated else {}),
)
out = fc2_fn(
preact,
weight2,
postact,
activation=activa
... (929 more chars)null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -18,22 +18,23 @@
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 192, "BLOCK_K": 64}, num_stages=4, num_warps=8),
],
- key=["H", "I", "E"],
+ key=["H", "I"],
)
@triton.jit
-def _grouped_swiglu_kernel(
+def _swiglu_tile_kernel(
a_ptr,
b_gate_ptr,
b_up_ptr,
c_ptr,
- offsets_ptr,
- E,
+ tile_info_ptr,
+ num_tiles,
H,
I,
stride_am,
@@ -43,81 +44,102 @@
stride_bi,
stride_cm,
stride_cn,
- NUM_SMS: tl.constexpr,
+ NUM_WORKERS: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
tidx = tl.program_id(0)
- iterated_tiles = 0
-
- for g in tl.range(E):
- m_start = tl.load(offsets_ptr + g)
- m_end = tl.load(offsets_ptr + g + 1)
+
+ while tidx < num_tiles:
+ expert = tl.load(tile_info_ptr + tidx * 4 + 0)
+ m_start = tl.load(tile_info_ptr + tidx * 4 + 1)
+ m_size = tl.load(tile_info_ptr + tidx * 4 + 2)
+ tile_m_idx = tl.load(tile_info_ptr + tidx * 4 + 3)
+ tile_n_idx = tl.load(tile_info_ptr + tidx * 4 + 4)
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+
+ a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
+ b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = (
+ c_ptr
+ + (m_start + offs_am[:, None]) * stride_cm
+ + offs_bn[None, :] * stride_cn
+ )
+ c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+ tidx += NUM_WORKERS
+
+
+def _build_tile_schedule(
+ expert_offsets: torch.Tensor,
+ I: int,
+ block_m: int,
+ block_n: int,
+ device: torch.device,
+) -> tuple[torch.Tensor, int]:
+ offsets = expert_offsets.tolist()
+ E = len(offsets) - 1
+ rows: list[list[int]] = []
+ for e in range(E):
+ m_start = offsets[e]
+ m_end = offsets[e + 1]
m_size = m_end - m_start
-
- num_m_tiles = tl.cdiv(m_size, BLOCK_M)
- num_n_tiles = tl.cdiv(I, BLOCK_N)
- num_tiles = num_m_tiles * num_n_tiles
-
- if m_size > 0:
- while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
- gidx = tidx - iterated_tiles
- tile_m_idx = gidx % num_m_tiles
- tile_n_idx = gidx // num_m_tiles
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + g * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
-
- a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
-
... (2525 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<7383 chars \u2014 see diff>"
}null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,245 +1,5 @@
-"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
-
-Per expert e:
- h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-"""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-def _num_sms() -> int:
- return torch.cuda.get_device_properties(0).multi_processor_count
-
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 192, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- ],
- key=["H", "I"],
-)
-@triton.jit
-def _swiglu_tile_kernel(
- a_ptr,
- b_gate_ptr,
- b_up_ptr,
- c_ptr,
- tile_info_ptr,
- num_tiles,
- H,
- I,
- stride_am,
- stride_ak,
- stride_bg,
- stride_bh,
- stride_bi,
- stride_cm,
- stride_cn,
- NUM_WORKERS: tl.constexpr,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- tidx = tl.program_id(0)
-
- while tidx < num_tiles:
- expert = tl.load(tile_info_ptr + tidx * 4 + 0)
- m_start = tl.load(tile_info_ptr + tidx * 4 + 1)
- m_size = tl.load(tile_info_ptr + tidx * 4 + 2)
- tile_m_idx = tl.load(tile_info_ptr + tidx * 4 + 3)
- tile_n_idx = tl.load(tile_info_ptr + tidx * 4 + 4)
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
-
- a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
- acc_gate = tl.dot(a, bg, acc_gate)
- acc_up = tl.dot(a, bu, acc_up)
-
- gate = acc_gate
- silu_gate = gate * tl.sigmoid(gate)
- c = (silu_gate * acc_up).to(tl.bfloat16)
-
- c_ptrs = (
- c_ptr
- + (m_start + offs_am[:, None]) * stride_cm
- + offs_bn[None, :] * stride_cn
- )
- c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
- tl.store(c_ptrs, c, mask=c_mask)
-
- tidx += NUM_WORKERS
-
-
-def _build_tile_schedule(
- expert_offsets: torch.Tensor,
- I: int,
- block_m: int,
- block_n: int,
- device: torch.device,
-) -> tuple[torch.Tensor, int]:
- offsets = expert_offsets.tolist()
- E = len(offsets) - 1
- rows: list[list[int]] = []
- for e in range(E):
- m_start = offsets[e]
- m_end = offsets[e + 1]
- m_size = m_end - m_start
- if m_size == 0:
- continue
- num_m = triton.cdiv(m_size, block_m)
- num_n = triton.cdiv(I, block_n)
- for tm in range(num_m):
- for tn in range(num_n):
- rows.append([e, m_start, m_size, tm, tn])
- if not rows:
- tile_info = torch.empty(0, 5, dtype=torch.int32, device=device)
- return tile_info, 0
- tile_info = torch.tensor(rows, dtype=torch.int32, device=device)
- return tile_info, tile_info.shape[0]
-
-
-# Cache tile schedules keyed by (offsets ptr, I, block sizes).
-_schedule_cache: dict[tuple, tuple[torch.Tensor, int]] = {}
-
-
-def grouped_swiglu(
- hidden_states: torch.Tensor,
- W_gate: torch.Tensor,
- W_up: torch.Tensor,
- expert_offsets: torch.Tensor,
-) -> torch.Tensor:
- T_perm, H = hidden_states.shape
- E, H_w, I = W_gate.shape
- assert H == H_w and W_up.shape == W_gate.shape
-
- out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
-
- # Use a representative block size for scheduling; autotune may pick different
- # sizes but schedule with 128x128 is a good default for SM120.
- block_m, block_n = 128, 128
- cache_key = (expert_offsets.data_ptr(), I, block_m, block_n)
- if cache_key not in _schedule_cache:
- _schedule_cache[cache_key] = _build_tile_schedule(
- expert_offsets, I, block_m, block_n, hidden_states.device
- )
- tile_info, num_tiles = _schedule_cache[cache_key]
- if num_tiles == 0:
- return out
-
- num_workers = _num_sms()
- grid = (num_workers,)
-
- _swiglu_tile_kernel[grid](
- hidden_states,
- W_gate,
- W_up,
- out,
- tile_info,
- num_tiles,
- H,
- I,
- hidden_states.stride(0),
- hidden_states.stride(1),
- W_gate.st
... (2050 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<283 chars \u2014 see diff>"
}null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,5 +1,245 @@
- expert = tl.load(tile_info_ptr + tidx * 5 + 0)
- m_start = tl.load(tile_info_ptr + tidx * 5 + 1)
- m_size = tl.load(tile_info_ptr + tidx * 5 + 2)
- tile_m_idx = tl.load(tile_info_ptr + tidx * 5 + 3)
- tile_n_idx = tl.load(tile_info_ptr + tidx * 5 + 4)+"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
+
+Per expert e:
+ h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _num_sms() -> int:
+ return torch.cuda.get_device_properties(0).multi_processor_count
+
+
+def _pick_config(H: int, I: int) -> dict:
+ if I >= 4096:
+ return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 4}
+ if H >= 4096:
+ return {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 4}
+ return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "num_warps": 8, "num_stages": 3}
+
+
+@triton.jit
+def _swiglu_tile_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ tile_info_ptr,
+ num_tiles,
+ H,
+ I,
+ stride_am,
+ stride_ak,
+ stride_bg,
+ stride_bh,
+ stride_bi,
+ stride_cm,
+ stride_cn,
+ NUM_WORKERS: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ tidx = tl.program_id(0)
+
+ while tidx < num_tiles:
+ base = tidx * 5
+ expert = tl.load(tile_info_ptr + base + 0)
+ m_start = tl.load(tile_info_ptr + base + 1)
+ m_size = tl.load(tile_info_ptr + base + 2)
+ tile_m_idx = tl.load(tile_info_ptr + base + 3)
+ tile_n_idx = tl.load(tile_info_ptr + base + 4)
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+
+ a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
+ b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = (
+ c_ptr
+ + (m_start + offs_am[:, None]) * stride_cm
+ + offs_bn[None, :] * stride_cn
+ )
+ c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+ tidx += NUM_WORKERS
+
+
+def _build_tile_schedule(
+ expert_offsets: torch.Tensor,
+ I: int,
+ block_m: int,
+ block_n: int,
+ device: torch.device,
+) -> tuple[torch.Tensor, int]:
+ offsets = expert_offsets.tolist()
+ E = len(offsets) - 1
+ rows: list[list[int]] = []
+ for e in range(E):
+ m_start = offsets[e]
+ m_end = offsets[e + 1]
+ m_size = m_end - m_start
+ if m_size == 0:
+ continue
+ num_m = triton.cdiv(m_size, block_m)
+ num_n = triton.cdiv(I, block_n)
+ for tm in range(num_m):
+ for tn in range(num_n):
+ rows.append([e, m_start, m_size, tm, tn])
+ if not rows:
+ return torch.empty(0, dtype=torch.int32, device=device), 0
+ tile_info = torch.tensor(rows, dtype=torch.int32, device=device).reshape(-1)
+ return tile_info, len(rows)
+
+
+_schedule_cache: dict[tuple, tuple[torch.Tensor, int, dict]] = {}
+
+
+def grouped_swiglu(
+ hidden_states: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ expert_offsets: torch.Tensor,
+) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ E, H_w, I = W_gate.shape
+ assert H == H_w and W_up.shape == W_gate.shape
+
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+ cfg = _pick_config(H, I)
+ block_m, block_n = cfg["BLOCK_M"], cfg["BLOCK_N"]
+
+ cache_key = (expert_offsets.data_ptr(), I, block_m, block_n)
+ if cache_key not in _schedule_cache:
+ tile_info, num_tiles = _build_tile_schedule(
+ expert_offsets, I, block_m, block_n, hidden_states.device
+ )
+ _schedule_cache[cache_key] = (tile_info, num_tiles, cfg)
+ else:
+ tile_info, num_tiles, cfg = _schedule_cache[cache_key]
+
+ if num_tiles == 0:
+ return out
+
+ num_workers = _num_sms()
+ _swiglu_tile_kernel[(num_workers,)](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ tile_info,
+ num_tiles,
+ H,
+ I,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ W_gate.stride(0),
+ W_gate.stride(1),
+ W_gate.stride(2),
+ out.stride(0),
+ out.stride(1),
+ NUM_WORKERS=num_workers,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=cfg["BLOCK_
... (1698 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<7031 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_50f38f89-4bba-496f-85f2-7b3053845e2",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
... (1512 more chars)failure:
command: 'cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/06_sonic_moe_swiglu/check.py", line 117, in <module>
main()
File "/workspace/problems/06_sonic_moe_swiglu/check.py", line 78, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/06_sonic_moe_swiglu/solution.py", line 211, in forward
return grouped_swiglu(
^^^^^^^^^^^^^^^
File "/workspace/problems/06_sonic_moe_swiglu/solution.py", line 167, in grouped_swiglu
_swiglu_tile_kernel[(num_workers,)](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.ve
... (3182 more chars)--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,245 +1,6 @@
-"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
-
-Per expert e:
- h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-"""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-def _num_sms() -> int:
- return torch.cuda.get_device_properties(0).multi_processor_count
-
-
def _pick_config(H: int, I: int) -> dict:
if I >= 4096:
- return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 4}
+ return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
if H >= 4096:
- return {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 4}
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "num_warps": 8, "num_stages": 3}
-
-
-@triton.jit
-def _swiglu_tile_kernel(
- a_ptr,
- b_gate_ptr,
- b_up_ptr,
- c_ptr,
- tile_info_ptr,
- num_tiles,
- H,
- I,
- stride_am,
- stride_ak,
- stride_bg,
- stride_bh,
- stride_bi,
- stride_cm,
- stride_cn,
- NUM_WORKERS: tl.constexpr,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- tidx = tl.program_id(0)
-
- while tidx < num_tiles:
- base = tidx * 5
- expert = tl.load(tile_info_ptr + base + 0)
- m_start = tl.load(tile_info_ptr + base + 1)
- m_size = tl.load(tile_info_ptr + base + 2)
- tile_m_idx = tl.load(tile_info_ptr + base + 3)
- tile_n_idx = tl.load(tile_info_ptr + base + 4)
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
-
- a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
- acc_gate = tl.dot(a, bg, acc_gate)
- acc_up = tl.dot(a, bu, acc_up)
-
- gate = acc_gate
- silu_gate = gate * tl.sigmoid(gate)
- c = (silu_gate * acc_up).to(tl.bfloat16)
-
- c_ptrs = (
- c_ptr
- + (m_start + offs_am[:, None]) * stride_cm
- + offs_bn[None, :] * stride_cn
- )
- c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
- tl.store(c_ptrs, c, mask=c_mask)
-
- tidx += NUM_WORKERS
-
-
-def _build_tile_schedule(
- expert_offsets: torch.Tensor,
- I: int,
- block_m: int,
- block_n: int,
- device: torch.device,
-) -> tuple[torch.Tensor, int]:
- offsets = expert_offsets.tolist()
- E = len(offsets) - 1
- rows: list[list[int]] = []
- for e in range(E):
- m_start = offsets[e]
- m_end = offsets[e + 1]
- m_size = m_end - m_start
- if m_size == 0:
- continue
- num_m = triton.cdiv(m_size, block_m)
- num_n = triton.cdiv(I, block_n)
- for tm in range(num_m):
- for tn in range(num_n):
- rows.append([e, m_start, m_size, tm, tn])
- if not rows:
- return torch.empty(0, dtype=torch.int32, device=device), 0
- tile_info = torch.tensor(rows, dtype=torch.int32, device=device).reshape(-1)
- return tile_info, len(rows)
-
-
-_schedule_cache: dict[tuple, tuple[torch.Tensor, int, dict]] = {}
-
-
-def grouped_swiglu(
- hidden_states: torch.Tensor,
- W_gate: torch.Tensor,
- W_up: torch.Tensor,
- expert_offsets: torch.Tensor,
-) -> torch.Tensor:
- T_perm, H = hidden_states.shape
- E, H_w, I = W_gate.shape
- assert H == H_w and W_up.shape == W_gate.shape
-
- out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
- cfg = _pick_config(H, I)
- block_m, block_n = cfg["BLOCK_M"], cfg["BLOCK_N"]
-
- cache_key = (expert_offsets.data_ptr(), I, block_m, block_n)
- if cache_key not in _schedule_cache:
- tile_info, num_tiles = _build_tile_schedule(
- expert_offsets, I, block_m, block_n, hidden_states.device
- )
- _schedule_cache[cache_key] = (tile_info, num_tiles, cfg)
- else:
- tile_info, num_tiles, cfg = _schedule_cache[cache_key]
-
- if num_tiles == 0:
- return out
-
- num_workers = _num_sms()
- _swiglu_tile_kernel[(num_workers,)](
- hidden_states,
- W_gate,
- W_up,
- out,
- tile_info,
- num_tiles,
- H,
- I,
- hidden_states.stride(0),
- hidden_states.stride(1),
- W_gate.stride(0),
- W_gate.stride(1),
- W_gate.stride(2),
- out.stride(0),
- out.stride(1),
- NUM_WORKERS=num_workers,
- BLOCK_M=block_m,
- BLOCK_N=block_n,
- BLOCK_K=cfg["BLOCK_K"],
- num_warps=cfg["num_warps"],
- num_stages=cfg["num_stages"],
- )
- return out
-
-
-class Model(nn.Module):
- def __init__(self, T_total: int, H: int, I: int, E: in
... (1693 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<358 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_b06ed3f3-fa7e-4907-a06a-bd790ddde80",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
... (1485 more chars)null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -1,6 +1,236 @@
+"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120).
+
+Per expert e:
+ h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+def _num_sms() -> int:
+ return torch.cuda.get_device_properties(0).multi_processor_count
+
+
def _pick_config(H: int, I: int) -> dict:
if I >= 4096:
return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
if H >= 4096:
return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}+ return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
+
+
+@triton.jit
+def _swiglu_tile_kernel(
+ a_ptr,
+ b_gate_ptr,
+ b_up_ptr,
+ c_ptr,
+ tile_info_ptr,
+ H,
+ I,
+ stride_am,
+ stride_ak,
+ stride_bg,
+ stride_bh,
+ stride_bi,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ tidx = tl.program_id(0)
+ base = tidx * 5
+ expert = tl.load(tile_info_ptr + base + 0)
+ m_start = tl.load(tile_info_ptr + base + 1)
+ m_size = tl.load(tile_info_ptr + base + 2)
+ tile_m_idx = tl.load(tile_info_ptr + base + 3)
+ tile_n_idx = tl.load(tile_info_ptr + base + 4)
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + expert * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+
+ a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
+ b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+ bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
+ bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
+
+ acc_gate = tl.dot(a, bg, acc_gate)
+ acc_up = tl.dot(a, bu, acc_up)
+
+ gate = acc_gate
+ silu_gate = gate * tl.sigmoid(gate)
+ c = (silu_gate * acc_up).to(tl.bfloat16)
+
+ c_ptrs = (
+ c_ptr
+ + (m_start + offs_am[:, None]) * stride_cm
+ + offs_bn[None, :] * stride_cn
+ )
+ c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+
+def _build_tile_schedule(
+ expert_offsets: torch.Tensor,
+ I: int,
+ block_m: int,
+ block_n: int,
+ device: torch.device,
+) -> tuple[torch.Tensor, int]:
+ offsets = expert_offsets.tolist()
+ E = len(offsets) - 1
+ rows: list[list[int]] = []
+ for e in range(E):
+ m_start = offsets[e]
+ m_end = offsets[e + 1]
+ m_size = m_end - m_start
+ if m_size == 0:
+ continue
+ num_m = triton.cdiv(m_size, block_m)
+ num_n = triton.cdiv(I, block_n)
+ for tm in range(num_m):
+ for tn in range(num_n):
+ rows.append([e, m_start, m_size, tm, tn])
+ if not rows:
+ return torch.empty(0, dtype=torch.int32, device=device), 0
+ tile_info = torch.tensor(rows, dtype=torch.int32, device=device).reshape(-1)
+ return tile_info, len(rows)
+
+
+_schedule_cache: dict[tuple, tuple[torch.Tensor, int, dict]] = {}
+
+
+def grouped_swiglu(
+ hidden_states: torch.Tensor,
+ W_gate: torch.Tensor,
+ W_up: torch.Tensor,
+ expert_offsets: torch.Tensor,
+) -> torch.Tensor:
+ T_perm, H = hidden_states.shape
+ E, H_w, I = W_gate.shape
+ assert H == H_w and W_up.shape == W_gate.shape
+
+ out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device)
+ cfg = _pick_config(H, I)
+ block_m, block_n = cfg["BLOCK_M"], cfg["BLOCK_N"]
+
+ cache_key = (expert_offsets.data_ptr(), I, block_m, block_n)
+ if cache_key not in _schedule_cache:
+ tile_info, num_tiles = _build_tile_schedule(
+ expert_offsets, I, block_m, block_n, hidden_states.device
+ )
+ _schedule_cache[cache_key] = (tile_info, num_tiles, cfg)
+ else:
+ tile_info, num_tiles, cfg = _schedule_cache[cache_key]
+
+ if num_tiles == 0:
+ return out
+
+ _swiglu_tile_kernel[(num_tiles,)](
+ hidden_states,
+ W_gate,
+ W_up,
+ out,
+ tile_info,
+ H,
+ I,
+ hidden_states.stride(0),
+ hidden_states.stride(1),
+ W_gate.stride(0),
+ W_gate.stride(1),
+ W_gate.stride(2),
+ out.stride(0),
+ out.stride(1),
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=cfg["BLOCK_K"],
+ num_warps=cfg["num_warps"],
+ num_stages=cfg["num_stages"],
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741
+ super().__init__()
+ self.T_total = T_total
+ self.H = H
+ self.I = I
+ self.E = E
+ self.K = K
+ self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
+ nn.init.normal_(self.W_gate, std=0.02)
+ nn.ini
... (1119 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<6658 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_4e202c67-dcfe-4643-96ef-b9bac6a5167",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
... (1487 more chars)null--- a//workspace/problems/06_sonic_moe_swiglu/solution.py
+++ b//workspace/problems/06_sonic_moe_swiglu/solution.py
@@ -15,21 +15,25 @@
return torch.cuda.get_device_properties(0).multi_processor_count
-def _pick_config(H: int, I: int) -> dict:
- if I >= 4096:
- return {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
- if H >= 4096:
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
- return {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3}
-
-
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4),
+ ],
+ key=["H", "I", "E"],
+)
@triton.jit
-def _swiglu_tile_kernel(
+def _grouped_swiglu_kernel(
a_ptr,
b_gate_ptr,
b_up_ptr,
c_ptr,
- tile_info_ptr,
+ offsets_ptr,
+ E,
H,
I,
stride_am,
@@ -39,96 +43,81 @@
stride_bi,
stride_cm,
stride_cn,
+ NUM_SMS: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
tidx = tl.program_id(0)
- base = tidx * 5
- expert = tl.load(tile_info_ptr + base + 0)
- m_start = tl.load(tile_info_ptr + base + 1)
- m_size = tl.load(tile_info_ptr + base + 2)
- tile_m_idx = tl.load(tile_info_ptr + base + 3)
- tile_n_idx = tl.load(tile_info_ptr + base + 4)
-
- acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-
- for k_block in range(0, H, BLOCK_K):
- offs_k = k_block + tl.arange(0, BLOCK_K)
-
- a_ptrs = (
- a_ptr
- + (m_start + offs_am[:, None]) * stride_am
- + offs_k[None, :] * stride_ak
- )
- bg_ptrs = (
- b_gate_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
- bu_ptrs = (
- b_up_ptr
- + expert * stride_bg
- + offs_k[:, None] * stride_bh
- + offs_bn[None, :] * stride_bi
- )
-
- a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- bg = tl.load(bg_ptrs, mask=b_mask, other=0.0)
- bu = tl.load(bu_ptrs, mask=b_mask, other=0.0)
-
- acc_gate = tl.dot(a, bg, acc_gate)
- acc_up = tl.dot(a, bu, acc_up)
-
- gate = acc_gate
- silu_gate = gate * tl.sigmoid(gate)
- c = (silu_gate * acc_up).to(tl.bfloat16)
-
- c_ptrs = (
- c_ptr
- + (m_start + offs_am[:, None]) * stride_cm
- + offs_bn[None, :] * stride_cn
- )
- c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I)
- tl.store(c_ptrs, c, mask=c_mask)
-
-
-def _build_tile_schedule(
- expert_offsets: torch.Tensor,
- I: int,
- block_m: int,
- block_n: int,
- device: torch.device,
-) -> tuple[torch.Tensor, int]:
- offsets = expert_offsets.tolist()
- E = len(offsets) - 1
- rows: list[list[int]] = []
- for e in range(E):
- m_start = offsets[e]
- m_end = offsets[e + 1]
+ iterated_tiles = 0
+
+ for g in tl.range(E):
+ m_start = tl.load(offsets_ptr + g)
+ m_end = tl.load(offsets_ptr + g + 1)
m_size = m_end - m_start
- if m_size == 0:
- continue
- num_m = triton.cdiv(m_size, block_m)
- num_n = triton.cdiv(I, block_n)
- for tm in range(num_m):
- for tn in range(num_n):
- rows.append([e, m_start, m_size, tm, tn])
- if not rows:
- return torch.empty(0, dtype=torch.int32, device=device), 0
- tile_info = torch.tensor(rows, dtype=torch.int32, device=device).reshape(-1)
- return tile_info, len(rows)
-
-
-_schedule_cache: dict[tuple, tuple[torch.Tensor, int, dict]] = {}
+
+ num_m_tiles = tl.cdiv(m_size, BLOCK_M)
+ num_n_tiles = tl.cdiv(I, BLOCK_N)
+ num_tiles = num_m_tiles * num_n_tiles
+
+ if m_size > 0:
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+ gidx = tidx - iterated_tiles
+ tile_m_idx = gidx % num_m_tiles
+ tile_n_idx = gidx // num_m_tiles
+
+ acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ for k_block in range(0, H, BLOCK_K):
+ offs_k = k_block + tl.arange(0, BLOCK_K)
+
+ a_ptrs = (
+ a_ptr
+ + (m_start + offs_am[:, None]) * stride_am
+ + offs_k[None, :] * stride_ak
+ )
+ bg_ptrs = (
+ b_gate_ptr
+ + g * stride_bg
+ + offs_k[:, None] * stride_bh
+ + offs_bn[None, :] * stride_bi
+ )
+ bu_ptrs = (
+ b_up_ptr
+ + g * stride_bg
+ + offs_k[:, None] * st
... (2408 more chars){
"path": "/workspace/problems/06_sonic_moe_swiglu/solution.py",
"streamContent": "<6328 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_0db5e526-d7ac-4a93-9fc2-d2631999375",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/06_sonic_moe_swiglu"
}
],
... (1503 more chars)null