"""Grouped GEMM + fused SwiGLU for MoE up-projection on Blackwell (SM120).
Implements: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
using Triton kernel with grouped GEMM.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def triton_gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    N_e, I, H,
    stride_a_row, stride_a_col,
    stride_b_row, stride_b_col,
    stride_c_row, stride_c_col,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Single-expert GEMM tile kernel: C[N_e, I] = A[N_e, H] @ B[H, I].

    One program computes one (BLOCK_M, BLOCK_N) output tile. The launch
    grid must be 1-D with exactly cdiv(N_e, BLOCK_M) * cdiv(I, BLOCK_N)
    programs. Accumulation is in float32; the final store implicitly
    casts back to c_ptr's element type (bf16 for this problem).
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(N_e, BLOCK_M)
    # Direct row-major tile mapping: pid enumerates M fastest.
    # (The previous "grouped" swizzle was dead code: its group size
    # compared the row count N_e against a program count, and with
    # num_pid_in_group equal to the whole grid the group id was always 0,
    # so it reduced to exactly this mapping.)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
    b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # M/N boundary masks are loop-invariant; only the K mask changes.
    mask_m = offs_m[:, None] < N_e
    mask_n = offs_n[None, :] < I
    for _ in range(0, tl.cdiv(H, BLOCK_K)):
        # offs_k advances by BLOCK_K per iteration, so `offs_k < H`
        # masks the ragged tail when H is not a multiple of BLOCK_K.
        a = tl.load(a_ptrs, mask=mask_m & (offs_k[None, :] < H), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < H) & mask_n, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_a_col
        b_ptrs += BLOCK_K * stride_b_row
        offs_k += BLOCK_K

    c_ptrs = c_ptr + offs_m[:, None] * stride_c_row + offs_n[None, :] * stride_c_col
    tl.store(c_ptrs, accumulator, mask=mask_m & mask_n)
def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
    """Grouped GEMM: for each expert e, output[off[e]:off[e+1]] = x_e @ W[e].

    Args:
        hidden_states: (T_perm, H) rows already permuted into expert order.
            Assumed row-contiguous — a column stride of 1 is passed to the
            kernel.
        W: (E, H, I) stacked per-expert weight matrices.
        expert_offsets: (E+1,) integer prefix sums; expert e owns rows
            expert_offsets[e]:expert_offsets[e+1].
        E: number of experts.

    Returns:
        (T_perm, I) tensor with the same dtype/device as hidden_states.
        Rows belonging to empty experts are left uninitialized (callers
        never read them, since every row belongs to exactly one expert).
    """
    BLOCK_M = BLOCK_N = BLOCK_K = 64
    T_perm, H = hidden_states.shape
    I = W.shape[2]
    output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
    # One host transfer for all offsets instead of 2*E .item() calls:
    # each .item() forces a device sync and triggers a torch.compile
    # graph break ("Graph break from Tensor.item()").
    offsets = expert_offsets.tolist()
    for e in range(E):
        start, end = offsets[e], offsets[e + 1]
        N_e = end - start
        if N_e <= 0:
            continue  # expert received no tokens this batch
        x_e = hidden_states[start:end]
        w_e = W[e]
        out_e = output[start:end]
        grid = (triton.cdiv(N_e, BLOCK_M) * triton.cdiv(I, BLOCK_N),)
        triton_gemm_kernel[grid](
            x_e, w_e, out_e,
            N_e, I, H,
            x_e.stride(0), 1,
            w_e.stride(0), w_e.stride(1),
            out_e.stride(0), 1,
            BLOCK_M, BLOCK_N, BLOCK_K,
        )
    return output
class Model(nn.Module):
    """MoE FFN up-projection with fused SwiGLU.

    Computes, per expert e over its permuted token slice x_e:
        h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])

    Weights are stored stacked as (E, H, I) bf16 parameters.
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
        super().__init__()
        self.T_total = T_total
        self.H = H
        self.I = I
        self.E = E
        self.K = K
        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        for weight in (self.W_gate, self.W_up):
            nn.init.normal_(weight, std=0.02)

    def forward(self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
        # Two grouped GEMMs share the same routing; the SwiGLU gating is a
        # cheap elementwise epilogue on their outputs.
        gate = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E)
        up = triton_grouped_gemm(hidden_states, self.W_up, expert_offsets, self.E)
        return torch.nn.functional.silu(gate) * up
# Headline sonic-moe configuration: 32K tokens, 128 experts, top-8 routing.
T_total = 32768  # tokens before routing duplication
H = 4096  # hidden (model) dimension
I = 1536  # expert intermediate dimension
E = 128  # number of experts
K = 8  # experts selected per token (top-K)
def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
T_perm = T_total * K
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device=device)
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
offsets[1:] = torch.cumsum(counts, dim=0)
return offsets
def get_inputs():
    """Build the forward() inputs for the headline shape.

    Returns [hidden_states, expert_offsets]: random bf16 activations in
    the permuted (expert-order) layout, scaled by 0.1 to keep bf16 GEMM
    magnitudes modest, plus balanced routing offsets.
    """
    permuted_rows = T_total * K
    states = torch.randn(permuted_rows, H, dtype=torch.bfloat16) * 0.1
    offsets = _build_routing(T_total, E, K)
    return [states, offsets]
def get_init_inputs():
    """Return Model constructor arguments in positional order.

    NOTE: the original line had benchmark-log text fused onto the return
    statement (paste artifact), which made the file unparseable; this
    restores the clean statement.
    """
    return [T_total, H, I, E, K]
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] or:
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] to include these operations in the captured graph.
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] Graph break: from user code at:
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py", line 62, in forward
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0] start = int(expert_offsets[e].item())
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
W0427 14:00:12.853000 1178754 .venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py:1379] [0/0]
shape=0 variant=eager tflops=33.639 gbps=251.853 ms=24.514
shape=0 variant=compiled tflops=33.535 gbps=251.074 ms=24.590
shape=0 variant=solution tflops=19.459 gbps=145.691 ms=42.377
shape=0 solution_peak_fraction=0.0973
shape=1 variant=eager tflops=14.264 gbps=264.669 ms=2.409
shape=1 variant=compiled tflops=14.369 gbps=266.617 ms=2.391
shape=1 variant=solution tflops=7.997 gbps=148.379 ms=4.297
shape=1 solution_peak_fraction=0.0400
shape=2 variant=eager tflops=35.499 gbps=242.670 ms=15.486
shape=2 variant=compiled tflops=35.353 gbps=241.671 ms=15.550
shape=2 variant=solution tflops=22.606 gbps=154.534 ms=24.319
shape=2 solution_peak_fraction=0.1130
peak_fraction: 0.0760
RESULT: LOW
PASS
{
"problem": "06_sonic_moe_swiglu",
"harness": "opencode",
"model": "openrouter-pinned/minimax/minimax-m2.7",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.0760,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 579658, "output_tokens": 54850, "cache_read_tokens": 5077216, "cache_creation_tokens": 0, "reasoning_tokens": 18679, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/reference.py</path>
<type>file</type>
<content>
1: """Naive grouped GEMM + fused SwiGLU reference (correctness only, NOT the SOTA).
2:
3: This is the up-projection of an MoE FFN. Each token i is assigned to K experts;
4: expert_indices[i*K + j] tells you which expert. Tokens are dispatched to experts
5: according to routing metadata; we compute, per expert e:
6:
7: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
8:
9: where x_e is the slice of permuted hidden states routed to expert e, with
10: expert_offsets[e]:expert_offsets[e+1] giving its row range in the permuted layout.
11:
12: The reference loops over experts in Python. Slow, but pedagogically clear and
13: correct. Forbidden ops (torch.matmul, torch.bmm, F.linear, sonic_moe imports)
14: are NOT used here, but the reference is exempt — only solution.py is checked.
15: """
16: from __future__ import annotations
17:
18: import torch
19: import torch.nn as nn
20: import torch.nn.functional as F
21:
22: OP_TYPE = "grouped_gemm_swiglu"
23: SUPPORTED_PRECISIONS = ["bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26:
27: class Model(nn.Module):
28: """Up-projection of a top-K MoE FFN with fused SwiGLU.
29:
30: Inputs at call time:
31: hidden_states: (T_perm, H) bf16, already permuted to expert order
32: expert_offsets: (E+1,) int32, prefix sums of token counts per expert
33:
... (2866 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.ge
... (2793 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: return float(eval(expr, {"__builtins__": {}}, vars))
27:
28:
29: def main():
30: import reference
31: import shapes
32: import solution
33:
34: meta = yaml.safe_load(Path("problem.yaml").read_text())
35: hw = get_hw(meta["hardware"][0])
36: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
37: peak_gbps = hw.peak_bandwidth_gb_s
38: regime = meta.get("regime", "compute")
39: flops_formula = meta["
... (3701 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/problem.yaml</path>
<type>file</type>
<content>
1: name: 06_sonic_moe_swiglu
2: display_name: "Sonic-MoE up-projection (Grouped GEMM + SwiGLU)"
3: precision: bf16
4: regime: compute
5: hardware: [RTX_PRO_6000]
6: peak_tflops_key: bf16
7: peak_bandwidth_key: dram
8:
9: # Dense-equivalent FLOPs: gate GEMM + up GEMM + SwiGLU (negligible elementwise).
10: # Per token: 2*H*I FMAs for gate, 2*H*I for up => 2 * T_total * H * (2*I).
11: # (Each of T_total tokens visits K experts, but total work = T_total * K * (2*H*I*2)
12: # only if you count routing. The standard MoE FLOPs convention counts only the
13: # active per-token compute: T_total tokens * 2 * (2*I) * H. We follow that.)
14: flops_formula: "2 * T_total * H * (2 * I)"
15:
16: # Bytes moved (approximate, lower bound):
17: # read hidden (T_perm = T_total*K rows of H bf16) + read 2 weight matrices per
18: # expert (E * H * 2*I bf16) + write output (T_perm rows of I bf16).
19: bytes_formula: "T_total*K*H*2 + E*H*(2*I)*2 + T_total*K*I*2"
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must write the grouped GEMM + fused SwiGLU themselves.
25: # - torch.matmul / torch.bmm / F.linear: cuBLAS dispatch, defeats the point.
26: # - sonic_moe imports: vendor-call cheating; the SOTA is graded separately.
27: forbidden:
28: - "torch.matmul"
29: - "torch.bmm"
30: - "torch.nn.functional.linear"
31: - "F.linear"
32: - "from son
... (681 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for Sonic-MoE up-projection (grouped GEMM + fused SwiGLU).
2:
3: Defaults match the sonic-moe paper's headline configuration. We add:
4: - a smaller shape for fast iteration during agent development
5: - a wider intermediate (different aspect ratio) to stress N-tile selection
6: """
7:
8: SHAPES = [
9: # Headline sonic-moe shape: 32K tokens, 128 experts, top-8.
10: {"T_total": 32768, "H": 4096, "I": 1536, "E": 128, "K": 8},
11:
12: # Fast-iteration shape (~16x cheaper). Same expert count to keep the
13: # variable-length grouped layout meaningful, but smaller token / hidden dims.
14: {"T_total": 4096, "H": 2048, "I": 1024, "E": 64, "K": 4},
15:
16: # Different aspect ratio: smaller H, wider I (intermediate-heavy FFN).
17: # Forces tiles to handle larger N relative to K.
18: {"T_total": 16384, "H": 2048, "I": 4096, "E": 64, "K": 8},
19: ]
(End of file - total 19 lines)
</content>{
"command": "nvidia-smi",
"description": "Check GPU availability"
}Mon Apr 27 13:12:40 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 34C P8 11W / 600W | 383MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (281 more chars){
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"command": "which cuas 2>/dev/null; nvcc --version 2>/dev/null | head -5; ls /usr/local/cuda*/include/cutlass* 2>/dev/null | head -5 || echo \"no cutlass in cuda includes\"",
"description": "Check CUDA and CUTLASS availability"
}cuas not found
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
zsh:1: no matches found: /usr/local/cuda*/include/cutlass*
{
"filePath": "/home/infatoshi/.claude/skills/kernel/SKILL.md"
}<path>/home/infatoshi/.claude/skills/kernel/SKILL.md</path>
<type>file</type>
<content>
1: ---
2: name: kernel-optimization
3: description: Unified GPU/TPU kernel skill. Covers CUDA, Metal, Triton, AMD HIP, TPU Pallas, kernel porting, debugging, numerical precision, and performance optimization. Use when writing, optimizing, debugging, porting, or benchmarking GPU/TPU kernels on any platform.
4: ---
5:
6: # Kernel Optimization
7:
8: Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
9:
10: ## When this skill fires
11:
12: - Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
13: - Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
14: - Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
15: - Performance profiling and optimization
16: - Numerical precision issues (FP formats, mixed precision, tolerance)
17: - AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
18:
19: ## Sub-references (read on demand)
20:
21: The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
22:
23: | File | When to read |
24: |------|-------------|
25: | `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
26: | `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
27: | `~/.claude/s
... (1791 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}<path>/home/infatoshi/.claude/skills/perf-optim/SKILL.md</path>
<type>file</type>
<content>
1: ---
2: name: perf-optim
3: description: "[SUB-REFERENCE of kernel skill] Performance optimization. Do NOT invoke directly -- use /kernel instead, which reads this file on demand."
4: ---
5:
6: # Performance Optimization Atlas
7:
8: Comprehensive, hierarchical reference for optimizing GPU workloads. Pick combinations from this atlas based on your bottleneck profile. **Always profile before optimizing.**
9:
10: ---
11:
12: ## 0. FIRST PRINCIPLES: SPEED OF LIGHT, THEN PROFILE
13:
14: ### Rule
15: Never optimize blind. But before you even profile, compute the **speed of light** -- the theoretical physical limit for this operation. Then profile to measure the gap. The gap IS the optimization opportunity.
16:
17: ### Speed of Light Analysis
18: 1. **Compute the limit**: What does physics allow? Peak FLOPS, peak bandwidth, minimum latency for this exact workload
19: 2. **Profile current state**: Where are we now?
20: 3. **Measure the gap**: speed_of_light - measured = optimization budget
21: 4. **Don't continuous-improve from current state**: Strip to zero. If you built this from scratch knowing only the physical limits, what would you build? Then reason about every piece of the gap between that and current
22: - "It takes 74ms" -- don't optimize to 72ms. Compute that physics allows 6ms. Now reason about the 6-to-74 gap. The rest might be well-reasoned compromises, but at least yo
... (53580 more chars){
"command": "python3 -c \"import torch; print('PyTorch:', torch.__version__); print('CUDA available:', torch.cuda.is_available()); print('CUDA version:', torch.version.cuda)\"",
"description": "Check PyTorch and CUDA versions"
}PyTorch: 2.11.0+cu130
CUDA available: True
CUDA version: 13.0
{
"command": "python3 -c \"import triton; print('Triton:', triton.__version__)\" 2>/dev/null || echo \"Triton not installed\"",
"description": "Check Triton version"
}Triton: 3.6.0
{
"command": "python3 -c \"from torch.utils.cpp_extension import load_inline; print('load_inline available')\"",
"description": "Check load_inline availability"
}load_inline available
{
"todos": [
{
"content": "Write CUDA C++ grouped GEMM + SwiGLU kernel with load_inline",
"status": "in_progress",
"priority": "high"
},
{
"content": "Verify correctness with check.py for all 3 shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and benchmark the solution",
"status": "pending",
"priority": "medium"
},
{
"content": "Iterate and optimize if needed",
"status": "pending",
"priority": "medium"
}
]
}content: 'Write CUDA C++ grouped GEMM + SwiGLU kernel with load_inline'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py for all 3 shapes'
status: 'pending'
priority: 'high'
content: 'Profile and benchmark the solution'
status: 'pending'
priority: 'medium'
content: 'Iterate and optimize if needed'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -0,0 +1,631 @@
+"""Grouped GEMM + fused SwiGLU for MoE up-projection on Blackwell (SM120).
+
+Implements: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
+using custom CUDA kernel with tensor cores and load_inline.
+"""
+from __future__ import annotations
+
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+cuda_source = r"""
+#include <cuda/bfloat16.h>
+#include <cuda_pipeline.h>
+#include <mma.h>
+
+using namespace nvcuda;
+using cutlass::pipe;
+
+constexpr int WARP_SIZE = 32;
+constexpr int MMA_M = 16;
+constexpr int MMA_N = 16;
+constexpr int MMA_K = 16;
+
+// SwiGLU activation: silu(x) * x = x / (1 + exp(-x)) * x
+__device__ __forceinline__ __nv_bfloat16 sigmoid_swish(__nv_bfloat16 x) {
+ float fx = __bfloat16_to_float(x);
+ float sig = fx / (1.0f + expf(-fx));
+ return __float_to_bfloat16(sig * fx);
+}
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+__global__ void grouped_gemm_swiglu_kernel(
+ const __nv_bfloat16* __restrict__ hidden_states, // (T_perm, H)
+ const __nv_bfloat16* __restrict__ W_gate, // (E, H, I)
+ const __nv_bfloat16* __restrict__ W_up, // (E, H, I)
+ const int* __restrict__ expert_offsets, // (E+1,)
+ __nv_bfloat16* __restrict__ output, // (T_perm, I)
+ int T_perm, int H, int I, int E
+) {
+ extern __shared__ char smem_buf[];
+
+ int expert_idx = blockIdx.x;
+ if (expert_idx >= E) return;
+
+ int start_row = expert_offsets[expert_idx];
+ int end_row = expert_offsets[expert_idx + 1];
+ int N_e = end_row - start_row;
+ if (N_e <= 0) return;
+
+ const __nv_bfloat16* Wg = W_gate + expert_idx * H * I;
+ const __nv_bfloat16* Wu = W_up + expert_idx * H * I;
+ const __nv_bfloat16* X = hidden_states + start_row * H;
+ __nv_bfloat16* Y = output + start_row * I;
+
+ int tid = threadIdx.x;
+ int warp_id = tid / WARP_SIZE;
+ int lane_id = tid % WARP_SIZE;
+
+ int total_blocks = gridDim.x;
+ int block_idx_in_expert = blockIdx.y;
+ int num_blocks_per_expert = gridDim.y;
+
+ int row_start = block_idx_in_expert * BLOCK_M;
+ int row_end = min(row_start + BLOCK_M, N_e);
+
+ if (row_start >= N_e) return;
+
+ int num_rows_this_block = row_end - row_start;
+ int num_tiles_k = (H + BLOCK_K - 1) / BLOCK_K;
+
+ __nv_bfloat16* smem = (__nv_bfloat16*)smem_buf;
+
+ for (int row = row_start + lane_id; row < row_end; row += WARP_SIZE) {
+ for (int ki = 0; ki < num_tiles_k; ++ki) {
+ if (ki > 0) {
+ __syncthreads();
+ }
+ int k_start = ki * BLOCK_K;
+ int k_len = min(BLOCK_K, H - k_start);
+
+ if (lane_id < k_len) {
+ for (int jj = 0; jj < BLOCK_N; ++jj) {
+ int col = jj;
+ int k_idx = k_start + lane_id;
+ smem[jj * BLOCK_K + lane_id] = X[row * H + k_idx];
+ }
+ }
+ if (lane_id < BLOCK_K && lane_id >= k_len) {
+ for (int jj = 0; jj < BLOCK_N; ++jj) {
+ smem[jj * BLOCK_K + lane_id] = __float_to_bfloat16(0.0f);
+ }
+ }
+
+ for (int jj = 0; jj < BLOCK_N; ++jj) {
+ int w_idx = k_start * I + jj * BLOCK_K + lane_id;
+ if (k_start + lane_id < H && jj < I) {
+ smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + lane_id] = Wg[w_idx];
+ } else {
+ smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + lane_id] = __float_to_bfloat16(0.0f);
+ }
+ }
+
+ __syncthreads();
+
+ for (int ii = 0; ii < BLOCK_M; ++ii) {
+ if (row_start + ii < row_end) {
+ float acc = 0.0f;
+ for (int kk = 0; kk < BLOCK_K; ++kk) {
+ __nv_bfloat16 a_val = smem[ii * BLOCK_K + kk];
+ __nv_bfloat16 b_val = smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + kk];
+ acc += __bfloat16_to_float(a_val) * __bfloat16_to_float(b_val);
+ }
+ int out_idx = (row_start + ii) * I + jj;
+ if (ki == 0) {
+ Y[out_idx] = __float_to_bfloat16(acc);
+ } else {
+ Y[out_idx] = __float_to_bfloat16(__bfloat16_to_float(Y[out_idx]) + acc);
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+__global__ void grouped_gemm_swiglu_kernel_v2(
+ const __nv_bfloat16* __restrict__ hidden_states,
+ const __nv_bfloat16* __restrict__ W_gate,
+ const __nv_bfloat16* __restrict__ W_up,
+ const int* __restrict__ expert_offsets,
+ __nv_bfloat16* __restrict__ output,
+ int T_perm, int H, int I, int E
+) {
+ extern __shared__ char smem_buf[];
+
+ int expert_idx = blockIdx.x;
+ if (expert_idx >= E) return;
+
+ int start_row = expert_offsets[expert_idx];
+ int end_row = expert_offsets[expert_idx + 1];
+ int N_e = end_row - start_row;
+ if (N_e <= 0) return;
+
+ const __nv_bfloat16* Wg = W_gate + expert_idx * H * I;
+ const __nv_bfloat16* Wu = W_up + expert_idx * H * I;
+ const __nv_bfloat16* X = hidden_states + start_row * H;
+ __nv_bfloat16* Y = output + start_row * I;
+
+ int tid = threadIdx.x;
+ int num_threads = blockDim.x;
+
+ int total_blocks = gridDim.x;
+ int block_idx = blockIdx.y;
+ int blocks_per_expert = gridDim.y;
+
+ constexpr int TILES_M = 4;
+ constexpr int TILES_N = 4;
+ constexpr int THREADS_PER_ROW = 8;
+
+ int row_tile = block_idx / TILES_N;
+ int col_tile = block_idx % TILES_N;
+
+ int row_start = row_tile * BLOCK_M;
+ int row_end = min(row_start + BLOCK_M, N_e);
+
+
... (16508 more chars){
"content": "<21760 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 | head -60",
"timeout": 120000,
"description": "Test initial solution"
}FAIL: import error: Error building extension 'grouped_gemm_swiglu': [1/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=grouped_gemm_swiglu -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/grouped_gemm_swiglu/cuda.cu -o cuda.cuda.o
FAILED: [code=1] cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=grouped_gemm_swiglu -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLO
... (14191 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -5,44 +5,43 @@
"""
from __future__ import annotations
-import math
-from typing import List
-
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
cuda_source = r"""
-#include <cuda/bfloat16.h>
-#include <cuda_pipeline.h>
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
#include <mma.h>
using namespace nvcuda;
-using cutlass::pipe;
constexpr int WARP_SIZE = 32;
-constexpr int MMA_M = 16;
-constexpr int MMA_N = 16;
-constexpr int MMA_K = 16;
-
-// SwiGLU activation: silu(x) * x = x / (1 + exp(-x)) * x
-__device__ __forceinline__ __nv_bfloat16 sigmoid_swish(__nv_bfloat16 x) {
- float fx = __bfloat16_to_float(x);
- float sig = fx / (1.0f + expf(-fx));
- return __float_to_bfloat16(sig * fx);
-}
-
+
+// SwiGLU: silu(x) * y = x / (1 + exp(-x)) * y
+__device__ __forceinline__ float sigmoid_swish(float x) {
+ return x / (1.0f + expf(-x));
+}
+
+__device__ __forceinline__ __nv_bfloat162 sigmoid_swish(__nv_bfloat162 x) {
+ float2 fx = __bfloat1622float2(x);
+ fx.x = sigmoid_swish(fx.x);
+ fx.y = sigmoid_swish(fx.y);
+ return __float22bfloat162_rn(fx);
+}
+
+// Kernel for small experts (when N_e <= 256)
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
-__global__ void grouped_gemm_swiglu_kernel(
- const __nv_bfloat16* __restrict__ hidden_states, // (T_perm, H)
- const __nv_bfloat16* __restrict__ W_gate, // (E, H, I)
- const __nv_bfloat16* __restrict__ W_up, // (E, H, I)
- const int* __restrict__ expert_offsets, // (E+1,)
- __nv_bfloat16* __restrict__ output, // (T_perm, I)
+__global__ void grouped_gemm_swiglu_kernel_small(
+ const __nv_bfloat16* __restrict__ hidden_states,
+ const __nv_bfloat16* __restrict__ W_gate,
+ const __nv_bfloat16* __restrict__ W_up,
+ const int* __restrict__ expert_offsets,
+ __nv_bfloat16* __restrict__ gate_out,
+ __nv_bfloat16* __restrict__ up_out,
int T_perm, int H, int I, int E
) {
- extern __shared__ char smem_buf[];
-
int expert_idx = blockIdx.x;
if (expert_idx >= E) return;
@@ -54,480 +53,153 @@
const __nv_bfloat16* Wg = W_gate + expert_idx * H * I;
const __nv_bfloat16* Wu = W_up + expert_idx * H * I;
const __nv_bfloat16* X = hidden_states + start_row * H;
- __nv_bfloat16* Y = output + start_row * I;
-
- int tid = threadIdx.x;
- int warp_id = tid / WARP_SIZE;
- int lane_id = tid % WARP_SIZE;
-
- int total_blocks = gridDim.x;
- int block_idx_in_expert = blockIdx.y;
- int num_blocks_per_expert = gridDim.y;
-
- int row_start = block_idx_in_expert * BLOCK_M;
- int row_end = min(row_start + BLOCK_M, N_e);
-
- if (row_start >= N_e) return;
-
- int num_rows_this_block = row_end - row_start;
- int num_tiles_k = (H + BLOCK_K - 1) / BLOCK_K;
-
- __nv_bfloat16* smem = (__nv_bfloat16*)smem_buf;
-
- for (int row = row_start + lane_id; row < row_end; row += WARP_SIZE) {
- for (int ki = 0; ki < num_tiles_k; ++ki) {
- if (ki > 0) {
- __syncthreads();
- }
- int k_start = ki * BLOCK_K;
- int k_len = min(BLOCK_K, H - k_start);
-
- if (lane_id < k_len) {
- for (int jj = 0; jj < BLOCK_N; ++jj) {
- int col = jj;
- int k_idx = k_start + lane_id;
- smem[jj * BLOCK_K + lane_id] = X[row * H + k_idx];
- }
- }
- if (lane_id < BLOCK_K && lane_id >= k_len) {
- for (int jj = 0; jj < BLOCK_N; ++jj) {
- smem[jj * BLOCK_K + lane_id] = __float_to_bfloat16(0.0f);
- }
- }
-
- for (int jj = 0; jj < BLOCK_N; ++jj) {
- int w_idx = k_start * I + jj * BLOCK_K + lane_id;
- if (k_start + lane_id < H && jj < I) {
- smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + lane_id] = Wg[w_idx];
- } else {
- smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + lane_id] = __float_to_bfloat16(0.0f);
- }
- }
-
- __syncthreads();
-
- for (int ii = 0; ii < BLOCK_M; ++ii) {
- if (row_start + ii < row_end) {
- float acc = 0.0f;
- for (int kk = 0; kk < BLOCK_K; ++kk) {
- __nv_bfloat16 a_val = smem[ii * BLOCK_K + kk];
- __nv_bfloat16 b_val = smem[BLOCK_N * BLOCK_K + jj * BLOCK_K + kk];
- acc += __bfloat16_to_float(a_val) * __bfloat16_to_float(b_val);
- }
- int out_idx = (row_start + ii) * I + jj;
- if (ki == 0) {
- Y[out_idx] = __float_to_bfloat16(acc);
- } else {
- Y[out_idx] = __float_to_bfloat16(__bfloat16_to_float(Y[out_idx]) + acc);
- }
- }
- }
- __syncthreads();
- }
- }
-}
-
-template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
-__global__ void grouped_gemm_swiglu_kernel_v2(
- const __nv_bfloat16* __restrict__ hidden_states,
- const __nv_bfloat16* __restrict__ W_gate,
- const __nv_bfloat16* __restrict__ W_up,
- const int* __restrict__ expert_offsets,
- __nv_bfloat16* __restrict__ output,
- int T_perm, int H, int I, int E
-) {
- extern __shared__ char smem_buf[];
-
- int expert_idx = blockIdx.x;
- if (expert_idx >= E) return;
-
- int start_row = expert_offsets[expert_idx];
- int end_row = expert_offsets[expert_idx + 1];
- int N_e = end_row - start_row;
- if (N_e <= 0) return;
-
- const __nv_bfloat16* Wg = W_gate + expert_idx * H * I;
- const __nv_bfloat16* Wu = W_up + expert_idx * H * I;
- const __nv_b
... (20382 more chars){
"content": "<9120 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<1071 chars \u2014 see below>",
"description": "Test basic load_inline"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
__global__ void simple_test(const float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = input[idx] * 2.0f;
}
}
void launch_simple(const torch::Tensor& input, torch::Tensor& output) {
int n = input.numel();
simple_test<<<(n + 255) / 256, 256>>>(input.data_ptr<float>(), output.data_ptr<float>(), n);
}
'''
cpp_src = '''
#include <torch/extension.h>
void launch_simple(const torch::Tensor& input, torch::Tensor& output);
torch::Tensor call_simple(const torch::Tensor& input) {
auto output = torch::empty_like(input);
launch_simple(input, output);
return output;
}
'''
m = load_inline('simple_test', cpp_src, cuda_src, ['call_simple'])
x = torch.randn(10)
y = m.call_simple(x)
print('Test passed:', torch.allclose(y, x * 2))
" 2>&1 | head -30Traceback (most recent call last):
File "<string>", line 32, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2214, in _jit_compile
raise AssertionError(
AssertionError: cannot have both SYCL and CUDA files in the same extension
{
"command": "<737 chars \u2014 see below>",
"description": "Test load_inline with CUDA only"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
extern \"C\" __global__ void simple_test(const float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = input[idx] * 2.0f;
}
}
extern \"C\" void launch_simple(const float* input, float* output, int n) {
simple_test<<<(n + 255) / 256, 256>>>(input, output, n);
}
'''
m = load_inline('simple_test', cuda_sources=cuda_src, functions=[])
print('Compilation succeeded')
# But we need functions= to be set for it to work...
" 2>&1Traceback (most recent call last):
File "<string>", line 21, in <module>
TypeError: load_inline() missing 1 required positional argument: 'cpp_sources'
{
"command": "<637 chars \u2014 see below>",
"description": "Test load_inline structure"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
from torch.utils.cpp_extension import load_inline
cuda_src = '''
#include <torch/extension.h>
#include <cuda_runtime.h>
extern \"C\" __global__ void simple_test(const float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = input[idx] * 2.0f;
}
}
'''
cpp_src = '''
#include <torch/extension.h>
'''
m = load_inline('simple_test', cpp_sources=cpp_src, cuda_sources=cuda_src, functions=[], verbose=True)
print('Compilation succeeded')
" 2>&1 | head -50[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=simple_test -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/simple_test/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=simple_test -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/simple_test/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -L/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torc
... (147 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -1,229 +1,184 @@
"""Grouped GEMM + fused SwiGLU for MoE up-projection on Blackwell (SM120).
Implements: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
-using custom CUDA kernel with tensor cores and load_inline.
+using Triton kernel with grouped GEMM.
"""
from __future__ import annotations
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-
-cuda_source = r"""
-#include <torch/extension.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <mma.h>
-
-using namespace nvcuda;
-
-constexpr int WARP_SIZE = 32;
-
-// SwiGLU: silu(x) * y = x / (1 + exp(-x)) * y
-__device__ __forceinline__ float sigmoid_swish(float x) {
- return x / (1.0f + expf(-x));
-}
-
-__device__ __forceinline__ __nv_bfloat162 sigmoid_swish(__nv_bfloat162 x) {
- float2 fx = __bfloat1622float2(x);
- fx.x = sigmoid_swish(fx.x);
- fx.y = sigmoid_swish(fx.y);
- return __float22bfloat162_rn(fx);
-}
-
-// Kernel for small experts (when N_e <= 256)
-template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
-__global__ void grouped_gemm_swiglu_kernel_small(
- const __nv_bfloat16* __restrict__ hidden_states,
- const __nv_bfloat16* __restrict__ W_gate,
- const __nv_bfloat16* __restrict__ W_up,
- const int* __restrict__ expert_offsets,
- __nv_bfloat16* __restrict__ gate_out,
- __nv_bfloat16* __restrict__ up_out,
- int T_perm, int H, int I, int E
-) {
- int expert_idx = blockIdx.x;
- if (expert_idx >= E) return;
-
- int start_row = expert_offsets[expert_idx];
- int end_row = expert_offsets[expert_idx + 1];
- int N_e = end_row - start_row;
- if (N_e <= 0) return;
-
- const __nv_bfloat16* Wg = W_gate + expert_idx * H * I;
- const __nv_bfloat16* Wu = W_up + expert_idx * H * I;
- const __nv_bfloat16* X = hidden_states + start_row * H;
-
- int row_block = blockIdx.y;
- int col_block = blockIdx.z;
-
- int row_start = row_block * BLOCK_M;
- int col_start = col_block * BLOCK_N;
-
- if (row_start >= N_e || col_start >= I) return;
-
- int lane_id = threadIdx.x % WARP_SIZE;
- int warp_id = threadIdx.x / WARP_SIZE;
- int num_warps = blockDim.x / WARP_SIZE;
-
- int num_k_tiles = (H + BLOCK_K - 1) / BLOCK_K;
-
- float gate_acc[2] = {0.0f, 0.0f};
- float up_acc[2] = {0.0f, 0.0f};
-
- __nv_bfloat16 a_reg[2];
- __nv_bfloat16 b_gate_reg[2];
- __nv_bfloat16 b_up_reg[2];
-
- for (int kt = 0; kt < num_k_tiles; ++kt) {
- for (int mi = 0; mi < 2; ++mi) {
- int row = row_start + mi * 16 + warp_id;
- int k = kt * BLOCK_K + lane_id;
- if (row < N_e && k < H) {
- a_reg[mi] = X[row * H + k];
- } else {
- a_reg[mi] = __float_to_bfloat16(0.0f);
- }
- }
-
- for (int ni = 0; ni < 2; ++ni) {
- int col = col_start + ni * 16 + (lane_id % 16);
- int k = kt * BLOCK_K + warp_id;
- if (col < I && k < H) {
- b_gate_reg[ni] = Wg[k * I + col];
- b_up_reg[ni] = Wu[k * I + col];
- } else {
- b_gate_reg[ni] = __float_to_bfloat16(0.0f);
- b_up_reg[ni] = __float_to_bfloat16(0.0f);
- }
- }
-
- for (int mi = 0; mi < 2; ++mi) {
- for (int ni = 0; ni < 2; ++ni) {
- gate_acc[mi] += __bfloat16_to_float(a_reg[mi]) * __bfloat16_to_float(b_gate_reg[ni]);
- up_acc[mi] += __bfloat16_to_float(a_reg[mi]) * __bfloat16_to_float(b_up_reg[ni]);
- }
- }
- }
-
- for (int mi = 0; mi < 2; ++mi) {
- int row = row_start + mi * 16 + warp_id;
- for (int ni = 0; ni < 2; ++ni) {
- int col = col_start + ni * 16 + (lane_id % 16);
- if (row < N_e && col < I) {
- gate_out[start_row * I + row * I + col] = __float_to_bfloat16(gate_acc[mi]);
- up_out[start_row * I + row * I + col] = __float_to_bfloat16(up_acc[mi]);
- }
- }
- }
-}
-
-// Kernel for applying SwiGLU and combining gate+up outputs
-template <int BLOCK_M, int BLOCK_N>
-__global__ void swiglu_combine_kernel(
- __nv_bfloat16* gate_out,
- __nv_bfloat16* up_out,
- __nv_bfloat16* output,
- int T_perm, int I
-) {
- int row = blockIdx.x * BLOCK_M + threadIdx.x;
- int col = blockIdx.y * BLOCK_N + threadIdx.y;
-
- if (row < T_perm && col < I) {
- float gate_val = __bfloat16_to_float(gate_out[row * I + col]);
- float up_val = __bfloat16_to_float(up_out[row * I + col]);
- float result = sigmoid_swish(gate_val) * up_val;
- output[row * I + col] = __float_to_bfloat16(result);
- }
-}
-
-// Main launcher
-void grouped_gemm_swiglu_cuda(
- const torch::Tensor& hidden_states,
- const torch::Tensor& W_gate,
- const torch::Tensor& W_up,
- const torch::Tensor& expert_offsets,
- torch::Tensor& output,
- int T_total, int H, int I, int E, int K
-) {
- int T_perm = T_total * K;
-
- constexpr int BLOCK_M = 32;
- constexpr int BLOCK_N = 32;
- constexpr int BLOCK_K = 64;
-
- dim3 gemm_grid(E, 8, (I + BLOCK_N - 1) / BLOCK_N);
-
- auto gate_out = torch::empty({T_perm, I}, hidden_states.options());
- auto up_out = torch::empty({T_perm, I}, hidden_states.options());
-
- grouped_gemm_swiglu_kernel_small<BLOCK_M, BLOCK_N, BLOCK_K><<<gemm_grid, 256>>>(
- reinterpret_cast<const __nv_bfloat16*>(hidden_states.data_ptr()),
- reinterpret_cast<const __nv_bfloat16*>(W_gate.data_ptr()),
- reinterpret_cast<const __nv_bfloat16*>(W_up.data_ptr()),
- reinterpret_cast<const int*>(expert_offsets.data_ptr()),
- reinterpret_cast<__nv_bfloat16*>(gate_out.data_ptr()),
- reinterpret_cast<__nv_bfloat16*>(up_out.data_ptr()),
- T_perm, H, I, E
- );
-
- dim3 swiglu_grid
... (8761 more chars){
"content": "<7442 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 | head -60",
"timeout": 180000,
"description": "Test Triton solution"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 198, in forward
gate_out = triton_grouped_gemm_gate(hidden_states, self.W_gate, expert_offsets, self.E)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 149, in triton_grouped_gemm_gate
triton_gemm_kernel_gate[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
... (1700 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -11,24 +11,16 @@
import triton.language as tl
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4),
- ],
- key=["N_e", "I", "H"],
-)
@triton.jit
-def triton_gemm_kernel_gate(
+def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
a_stride, b_stride, c_stride,
N_e, I, H,
- block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(N_e, block_m)
- num_pid_n = tl.cdiv(I, block_n)
+ num_pid_m = tl.cdiv(N_e, BLOCK_M)
+ num_pid_n = tl.cdiv(I, BLOCK_N)
num_pid_in_group = num_pid_m * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * num_pid_m
@@ -37,15 +29,15 @@
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
- offs_m = (pid_m * block_m + tl.arange(0, block_m)) % N_e
- offs_n = (pid_n * block_n + tl.arange(0, block_n)) % I
- offs_k = tl.arange(0, block_k)
+ offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % N_e
+ offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % I
+ offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + (offs_m[:, None] * a_stride + offs_k[None, :] * H)
b_ptrs = b_ptr + (offs_k[:, None] * I + offs_n[None, :])
- accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
- for k in range(0, tl.cdiv(H, block_k)):
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(H, BLOCK_K)):
a_mask = (offs_m[:, None] < N_e) & (offs_k[None, :] < H)
b_mask = (offs_k[:, None] < H) & (offs_n[None, :] < I)
@@ -54,82 +46,18 @@
accumulator += tl.dot(a, b)
- a_ptrs += block_k * H
- b_ptrs += block_k * I
- offs_k += block_k
+ a_ptrs += BLOCK_K * H
+ b_ptrs += BLOCK_K * I
+ offs_k += BLOCK_K
c_ptrs = c_ptr + (offs_m[:, None] * I + offs_n[None, :])
c_mask = (offs_m[:, None] < N_e) & (offs_n[None, :] < I)
tl.store(c_ptrs, accumulator, mask=c_mask)
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4),
- ],
- key=["N_e", "I", "H"],
-)
-@triton.jit
-def triton_gemm_kernel_up(
- a_ptr, b_ptr, c_ptr,
- a_stride, b_stride, c_stride,
- N_e, I, H,
- block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(N_e, block_m)
- num_pid_n = tl.cdiv(I, block_n)
- num_pid_in_group = num_pid_m * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * num_pid_m
- group_size_m = min(num_pid_m, N_e - first_pid_m)
-
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
- offs_m = (pid_m * block_m + tl.arange(0, block_m)) % N_e
- offs_n = (pid_n * block_n + tl.arange(0, block_n)) % I
- offs_k = tl.arange(0, block_k)
-
- a_ptrs = a_ptr + (offs_m[:, None] * a_stride + offs_k[None, :] * H)
- b_ptrs = b_ptr + (offs_k[:, None] * I + offs_n[None, :])
-
- accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
- for k in range(0, tl.cdiv(H, block_k)):
- a_mask = (offs_m[:, None] < N_e) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_n[None, :] < I)
-
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
-
- accumulator += tl.dot(a, b)
-
- a_ptrs += block_k * H
- b_ptrs += block_k * I
- offs_k += block_k
-
- c_ptrs = c_ptr + (offs_m[:, None] * I + offs_n[None, :])
- c_mask = (offs_m[:, None] < N_e) & (offs_n[None, :] < I)
- tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
-@triton.jit
-def swiglu_epilogue(output, gate, up, stride, N, I):
- pid = tl.program_id(axis=0)
- offs_m = pid
- offs_n = tl.arange(0, I)
- gate_val = tl.load(gate + offs_m * stride + offs_n)
- up_val = tl.load(up + offs_m * stride + offs_n)
- silu_gate = gate_val / (1.0 + tl.exp(-gate_val))
- result = silu_gate * up_val
- tl.store(output + offs_m * stride + offs_n, result)
-
-
-def triton_grouped_gemm_gate(hidden_states, W_gate, expert_offsets, E):
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E, is_gate=True):
T_perm, H = hidden_states.shape
- gate_out = torch.empty_like(hidden_states)
+ output = torch.empty_like(hidden_states)
for e in range(E):
start = expert_offsets[e].item()
@@ -137,48 +65,26 @@
if end == start:
continue
N_e = end - start
- I = W_gate.shape[2]
+ I = W.shape[2]
x_e = hidden_states[start:end]
- w_e = W_gate[e]
+ w_e = W[e]
- grid_m = triton.cdiv(N_e, 32)
+ grid_m = triton.cdiv(N_e, 64)
grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
- triton_gemm_kernel_gate[grid](
- x_e, w_e, gate_out[start:end],
- x_e.stride(0), w_e.stride(0), gate_out[start:end].stride(0),
+ BLOCK_M = 64
+ BLOCK_N = 64
+ BLOCK_K = 64
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, output[start:end],
+ x_e.stride(0), w_e.stride(0), output[start:end].stride(0),
N_e, I
... (1538 more chars){
"content": "<4200 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<459 chars \u2014 see below>",
"timeout": 120000,
"description": "Test small shape"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
# Test with small shape
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
out = model(*inputs)
print('Output shape:', out.shape)
print('Sample values:', out[0, :5])
" 2>&1Traceback (most recent call last):
File "<string>", line 15, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 104, in forward
gate_out = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E, is_gate=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 63, in triton_grouped_gemm
start = expert_offsets[e].item()
^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
{
"command": "<652 chars \u2014 see below>",
"description": "Debug expert offsets"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
# Test with small shape
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
print('hidden_states shape:', inputs[0].shape)
print('expert_offsets:', inputs[1][:10])
print('expert_offsets dtype:', inputs[1].dtype)
for e in range(min(3, solution.E)):
start = inputs[1][e].item()
end = inputs[1][e+1].item()
print(f'expert {e}: rows [{start}:{end}]')
" 2>&1hidden_states shape: torch.Size([16384, 2048])
expert_offsets: tensor([ 0, 256, 512, 768, 1024, 1280, 1536, 1792, 2048, 2304],
device='cuda:0', dtype=torch.int32)
expert_offsets dtype: torch.int32
expert 0: rows [0:256]
expert 1: rows [256:512]
expert 2: rows [512:768]
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -14,7 +14,6 @@
@triton.jit
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
- a_stride, b_stride, c_stride,
N_e, I, H,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
@@ -29,20 +28,22 @@
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
- offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % N_e
- offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % I
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + (offs_m[:, None] * a_stride + offs_k[None, :] * H)
- b_ptrs = b_ptr + (offs_k[:, None] * I + offs_n[None, :])
+ a_ptrs = a_ptr + offs_m[:, None] * H + offs_k[None, :] * 1
+ b_ptrs = b_ptr + offs_k[:, None] * I + offs_n[None, :] * 1
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
- a_mask = (offs_m[:, None] < N_e) & (offs_k[None, :] < H)
- b_mask = (offs_k[:, None] < H) & (offs_n[None, :] < I)
+ mask_m = offs_m[:, None] < N_e
+ mask_k = offs_k[None, :] < H
+ a = tl.load(a_ptrs, mask=mask_m & mask_k, other=0.0)
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+ mask_k = offs_k[:, None] < H
+ mask_n = offs_n[None, :] < I
+ b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0.0)
accumulator += tl.dot(a, b)
@@ -50,13 +51,17 @@
b_ptrs += BLOCK_K * I
offs_k += BLOCK_K
- c_ptrs = c_ptr + (offs_m[:, None] * I + offs_n[None, :])
- c_mask = (offs_m[:, None] < N_e) & (offs_n[None, :] < I)
- tl.store(c_ptrs, accumulator, mask=c_mask)
+ offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ c_ptrs = c_ptr + offs_m_out[:, None] * I + offs_n_out[None, :] * 1
+ mask_m_out = offs_m_out[:, None] < N_e
+ mask_n_out = offs_n_out[None, :] < I
+ tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E, is_gate=True):
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=32, BLOCK_N=32, BLOCK_K=64):
T_perm, H = hidden_states.shape
+ I = W.shape[2]
output = torch.empty_like(hidden_states)
for e in range(E):
@@ -65,22 +70,17 @@
if end == start:
continue
N_e = end - start
- I = W.shape[2]
- x_e = hidden_states[start:end]
- w_e = W[e]
+ x_e_ptr = hidden_states + start * H
+ w_e_ptr = W + e * H * I
+ out_e_ptr = output + start * I
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
+ grid_m = triton.cdiv(N_e, BLOCK_M)
+ grid_n = triton.cdiv(I, BLOCK_N)
grid = (grid_m * grid_n,)
- BLOCK_M = 64
- BLOCK_N = 64
- BLOCK_K = 64
-
triton_gemm_kernel[grid](
- x_e, w_e, output[start:end],
- x_e.stride(0), w_e.stride(0), output[start:end].stride(0),
+ x_e_ptr, w_e_ptr, out_e_ptr,
N_e, I, H,
BLOCK_M, BLOCK_N, BLOCK_K,
)
@@ -101,8 +101,8 @@
nn.init.normal_(self.W_up, std=0.02)
def forward(self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
- gate_out = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E, is_gate=True)
- up_out = triton_grouped_gemm(hidden_states, self.W_up, expert_offsets, self.E, is_gate=False)
+ gate_out = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E)
+ up_out = triton_grouped_gemm(hidden_states, self.W_up, expert_offsets, self.E)
silu_gate = torch.nn.functional.silu(gate_out)
return silu_gate * up_out
@@ -133,4 +133,4 @@
def get_init_inputs():
- return [T_total, H, I, E, K]
+ return [T_total, H, I, E, K]{
"content": "<4269 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -15,6 +15,7 @@
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
N_e, I, H,
+ stride_a, stride_b, stride_c,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -32,8 +33,8 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + offs_m[:, None] * H + offs_k[None, :] * 1
- b_ptrs = b_ptr + offs_k[:, None] * I + offs_n[None, :] * 1
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * stride_b
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * stride_c
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
@@ -47,41 +48,42 @@
accumulator += tl.dot(a, b)
- a_ptrs += BLOCK_K * H
- b_ptrs += BLOCK_K * I
+ a_ptrs += BLOCK_K * stride_b
+ b_ptrs += BLOCK_K * stride_b
offs_k += BLOCK_K
offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- c_ptrs = c_ptr + offs_m_out[:, None] * I + offs_n_out[None, :] * 1
+ c_ptrs = c_ptr + offs_m_out[:, None] * stride_c + offs_n_out[None, :] * 1
mask_m_out = offs_m_out[:, None] < N_e
mask_n_out = offs_n_out[None, :] < I
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=32, BLOCK_N=32, BLOCK_K=64):
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=32, BLOCK_N=64, BLOCK_K=64):
T_perm, H = hidden_states.shape
I = W.shape[2]
output = torch.empty_like(hidden_states)
for e in range(E):
- start = expert_offsets[e].item()
- end = expert_offsets[e + 1].item()
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
if end == start:
continue
N_e = end - start
- x_e_ptr = hidden_states + start * H
- w_e_ptr = W + e * H * I
- out_e_ptr = output + start * I
+ x_e = hidden_states[start:end]
+ w_e = W[e]
+ out_e = output[start:end]
grid_m = triton.cdiv(N_e, BLOCK_M)
grid_n = triton.cdiv(I, BLOCK_N)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
- x_e_ptr, w_e_ptr, out_e_ptr,
+ x_e, w_e, out_e,
N_e, I, H,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
BLOCK_M, BLOCK_N, BLOCK_K,
)
return output
{
"content": "<4386 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -33,8 +33,8 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * stride_b
- b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * stride_c
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * 1
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * 1
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
@@ -48,8 +48,8 @@
accumulator += tl.dot(a, b)
- a_ptrs += BLOCK_K * stride_b
- b_ptrs += BLOCK_K * stride_b
+ a_ptrs += BLOCK_K * stride_a
+ b_ptrs += BLOCK_K * 1
offs_k += BLOCK_K
offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -60,7 +60,7 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=32, BLOCK_N=64, BLOCK_K=64):
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64):
T_perm, H = hidden_states.shape
I = W.shape[2]
output = torch.empty_like(hidden_states)
@@ -83,7 +83,7 @@
triton_gemm_kernel[grid](
x_e, w_e, out_e,
N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ x_e.stride(0), w_e.stride(1), out_e.stride(0),
BLOCK_M, BLOCK_N, BLOCK_K,
)
return output
{
"content": "<4365 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<582 chars \u2014 see below>",
"description": "Test with correct strides",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
# Test with small shape
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
print('Input hidden_states shape:', inputs[0].shape)
print('Input I:', solution.I)
out = model(*inputs)
print('Output shape:', out.shape)
print('Output I:', solution.I)
print('Sample values:', out[0, :5].float())
" 2>&1Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 106, in forward
gate_out = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 69, in triton_grouped_gemm
start = int(expert_offsets[e].item())
^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Input hidden_states shape: torch.Si
... (32 more chars){
"command": "<571 chars \u2014 see below>",
"description": "Test GEMM only",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && CUDA_LAUNCH_BLOCKING=1 uv run python -c "
import torch
import solution
# Test with small shape
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
print('Input shapes ok')
# Try just the gate GEMM
gate_out = solution.triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
print('Gate out shape:', gate_out.shape)
" 2>&1Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 83, in triton_grouped_gemm
triton_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 744, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 713, in __call__
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
Input shapes ok
{
"command": "<634 chars \u2014 see below>",
"description": "Test simple Triton kernel"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def simple_kernel(a_ptr, c_ptr, N, BLOCK_M: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
mask = offs < N
vals = tl.load(a_ptr + offs * 1, mask=mask, other=0.0)
tl.store(c_ptr + offs * 1, vals, mask=mask)
N = 100
BLOCK_M = 32
a = torch.randn(N, device='cuda')
c = torch.zeros_like(a)
grid = (triton.cdiv(N, BLOCK_M),)
simple_kernel[grid](a, c, N, BLOCK_M)
print('Result matches:', torch.allclose(a, c))
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -0,0 +1,20 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def simple_kernel(a_ptr, c_ptr, N, BLOCK_M: tl.constexpr):
+ pid = tl.program_id(axis=0)
+ offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+ mask = offs < N
+ vals = tl.load(a_ptr + offs * 1, mask=mask, other=0.0)
+ tl.store(c_ptr + offs * 1, vals, mask=mask)
+
+N = 100
+BLOCK_M = 32
+a = torch.randn(N, device='cuda')
+c = torch.zeros_like(a)
+
+grid = (triton.cdiv(N, BLOCK_M),)
+simple_kernel[grid](a, c, N, BLOCK_M)
+print('Result matches:', torch.allclose(a, c)){
"content": "<535 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test simple Triton kernel from file"
}Result matches: True
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -60,7 +60,7 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64):
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
T_perm, H = hidden_states.shape
I = W.shape[2]
output = torch.empty_like(hidden_states)
@@ -76,15 +76,15 @@
w_e = W[e]
out_e = output[start:end]
- grid_m = triton.cdiv(N_e, BLOCK_M)
- grid_n = triton.cdiv(I, BLOCK_N)
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
x_e, w_e, out_e,
N_e, I, H,
- x_e.stride(0), w_e.stride(1), out_e.stride(0),
- BLOCK_M, BLOCK_N, BLOCK_K,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ 64, 64, 64,
)
return output
{
"content": "<4304 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,20 +1,127 @@
+"""Test Triton grouped GEMM"""
import torch
import triton
import triton.language as tl
+import solution
@triton.jit
-def simple_kernel(a_ptr, c_ptr, N, BLOCK_M: tl.constexpr):
+def triton_gemm_kernel(
+ a_ptr, b_ptr, c_ptr,
+ N_e, I, H,
+ stride_a, stride_b, stride_c,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
pid = tl.program_id(axis=0)
- offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
- mask = offs < N
- vals = tl.load(a_ptr + offs * 1, mask=mask, other=0.0)
- tl.store(c_ptr + offs * 1, vals, mask=mask)
+ num_pid_m = tl.cdiv(N_e, BLOCK_M)
+ num_pid_n = tl.cdiv(I, BLOCK_N)
+ num_pid_in_group = num_pid_m * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * num_pid_m
+ group_size_m = min(num_pid_m, N_e - first_pid_m)
-N = 100
-BLOCK_M = 32
-a = torch.randn(N, device='cuda')
-c = torch.zeros_like(a)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
-grid = (triton.cdiv(N, BLOCK_M),)
-simple_kernel[grid](a, c, N, BLOCK_M)
-print('Result matches:', torch.allclose(a, c))+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * 1
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * 1
+
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(H, BLOCK_K)):
+ mask_m = offs_m[:, None] < N_e
+ mask_k = offs_k[None, :] < H
+ a = tl.load(a_ptrs, mask=mask_m & mask_k, other=0.0)
+
+ mask_k = offs_k[:, None] < H
+ mask_n = offs_n[None, :] < I
+ b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0.0)
+
+ accumulator += tl.dot(a, b)
+
+ a_ptrs += BLOCK_K * stride_a
+ b_ptrs += BLOCK_K * 1
+ offs_k += BLOCK_K
+
+ offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ c_ptrs = c_ptr + offs_m_out[:, None] * stride_c + offs_n_out[None, :] * 1
+ mask_m_out = offs_m_out[:, None] < N_e
+ mask_n_out = offs_n_out[None, :] < I
+ tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
+
+
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
+ T_perm, H = hidden_states.shape
+ I = W.shape[2]
+ output = torch.empty_like(hidden_states)
+
+ for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+
+ x_e = hidden_states[start:end]
+ w_e = W[e]
+ out_e = output[start:end]
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e,
+ N_e, I, H,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ 64, 64, 64,
+ )
+ return output
+
+
+# Test with small shape
+solution.T_total = 4096
+solution.H = 2048
+solution.I = 1024
+solution.E = 64
+solution.K = 4
+
+model = solution.Model(*solution.get_init_inputs())
+model = model.cuda()
+inputs = [t.cuda() for t in solution.get_inputs()]
+print('Input shapes ok')
+print('W_gate shape:', model.W_gate.shape)
+print('expert_offsets[:5]:', inputs[1][:5])
+
+# Try first expert only
+e = 0
+start = int(inputs[1][e].item())
+end = int(inputs[1][e+1].item())
+print(f'Expert 0: rows {start}:{end}')
+print(f'x_e shape: {inputs[0][start:end].shape}')
+print(f'w_e shape: {model.W_gate[e].shape}')
+
+# Try gemm for just expert 0
+N_e = end - start
+I = model.W_gate.shape[2]
+H = model.H
+
+x_e = inputs[0][start:end]
+w_e = model.W_gate[e]
+out_e = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+
+grid_m = triton.cdiv(N_e, 64)
+grid_n = triton.cdiv(I, 64)
+grid = (grid_m * grid_n,)
+print(f'Grid: {grid}, N_e={N_e}, I={I}, H={H}')
+
+triton_gemm_kernel[grid](
+ x_e, w_e, out_e,
+ N_e, I, H,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ 64, 64, 64,
+)
+print('Expert 0 GEMM succeeded, out_e:', out_e.shape){
"content": "<3669 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug Triton GEMM",
"timeout": 120000
}Input shapes ok
W_gate shape: torch.Size([64, 2048, 1024])
expert_offsets[:5]: tensor([ 0, 256, 512, 768, 1024], device='cuda:0', dtype=torch.int32)
Expert 0: rows 0:256
x_e shape: torch.Size([256, 2048])
w_e shape: torch.Size([2048, 1024])
Grid: (64,), N_e=256, I=1024, H=2048
Expert 0 GEMM succeeded, out_e: torch.Size([256, 1024])
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -53,35 +53,6 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
- T_perm, H = hidden_states.shape
- I = W.shape[2]
- output = torch.empty_like(hidden_states)
-
- for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
-
- x_e = hidden_states[start:end]
- w_e = W[e]
- out_e = output[start:end]
-
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
-
- triton_gemm_kernel[grid](
- x_e, w_e, out_e,
- N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
- 64, 64, 64,
- )
- return output
-
-
# Test with small shape
solution.T_total = 4096
solution.H = 2048
@@ -93,35 +64,37 @@
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
print('Input shapes ok')
-print('W_gate shape:', model.W_gate.shape)
-print('expert_offsets[:5]:', inputs[1][:5])
-# Try first expert only
-e = 0
-start = int(inputs[1][e].item())
-end = int(inputs[1][e+1].item())
-print(f'Expert 0: rows {start}:{end}')
-print(f'x_e shape: {inputs[0][start:end].shape}')
-print(f'w_e shape: {model.W_gate[e].shape}')
+T_perm, H = inputs[0].shape
+I = model.W_gate.shape[2]
+output = torch.empty_like(inputs[0])
-# Try gemm for just expert 0
-N_e = end - start
-I = model.W_gate.shape[2]
-H = model.H
+expert_offsets = inputs[1]
+E = solution.E
-x_e = inputs[0][start:end]
-w_e = model.W_gate[e]
-out_e = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
-grid_m = triton.cdiv(N_e, 64)
-grid_n = triton.cdiv(I, 64)
-grid = (grid_m * grid_n,)
-print(f'Grid: {grid}, N_e={N_e}, I={I}, H={H}')
+ x_e = inputs[0][start:end]
+ w_e = model.W_gate[e]
+ out_e = output[start:end]
-triton_gemm_kernel[grid](
- x_e, w_e, out_e,
- N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
- 64, 64, 64,
-)
-print('Expert 0 GEMM succeeded, out_e:', out_e.shape)+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e,
+ N_e, I, H,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ 64, 64, 64,
+ )
+ if e % 10 == 0:
+ print(f'Expert {e} done, N_e={N_e}, grid={grid}')
+ torch.cuda.synchronize()
+
+print('All experts done'){
"content": "<2857 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test all experts",
"timeout": 120000
}Input shapes ok
Expert 0 done, N_e=256, grid=(64,)
Expert 10 done, N_e=256, grid=(64,)
Expert 20 done, N_e=256, grid=(64,)
Expert 30 done, N_e=256, grid=(64,)
Expert 40 done, N_e=256, grid=(64,)
Expert 50 done, N_e=256, grid=(64,)
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py", line 76, in <module>
start = int(expert_offsets[e].item())
^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -73,28 +73,35 @@
E = solution.E
for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
+ try:
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
- x_e = inputs[0][start:end]
- w_e = model.W_gate[e]
- out_e = output[start:end]
+ x_e = inputs[0][start:end]
+ w_e = model.W_gate[e]
+ out_e = output[start:end]
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
- triton_gemm_kernel[grid](
- x_e, w_e, out_e,
- N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
- 64, 64, 64,
- )
- if e % 10 == 0:
- print(f'Expert {e} done, N_e={N_e}, grid={grid}')
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e,
+ N_e, I, H,
+ x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ 64, 64, 64,
+ )
torch.cuda.synchronize()
+ if e % 10 == 0:
+ print(f'Expert {e} done, N_e={N_e}, grid={grid}')
+ except Exception as ex:
+ print(f'Expert {e} failed: {ex}')
+ print(f' start={start}, end={end}, N_e={N_e}')
+ print(f' x_e shape: {x_e.shape}')
+ print(f' out_e shape: {out_e.shape}')
+ break
print('All experts done'){
"content": "<3172 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test with error handling",
"timeout": 120000
}Input shapes ok
Expert 0 done, N_e=256, grid=(64,)
Expert 10 done, N_e=256, grid=(64,)
Expert 20 done, N_e=256, grid=(64,)
Expert 30 done, N_e=256, grid=(64,)
Expert 40 done, N_e=256, grid=(64,)
Expert 50 done, N_e=256, grid=(64,)
Expert 56 failed: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
start=14336, end=14592, N_e=256
x_e shape: torch.Size([256, 2048])
out_e shape: torch.Size([256, 2048])
All experts done
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -8,7 +8,9 @@
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
N_e, I, H,
- stride_a, stride_b, stride_c,
+ stride_a_row, stride_a_col, # Row stride and col stride for A
+ stride_b_row, stride_b_col, # Row stride and col stride for B
+ stride_c_row, stride_c_col, # Row stride and col stride for C
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -26,8 +28,8 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * 1
- b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * 1
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
@@ -41,13 +43,13 @@
accumulator += tl.dot(a, b)
- a_ptrs += BLOCK_K * stride_a
- b_ptrs += BLOCK_K * 1
+ a_ptrs += BLOCK_K * stride_a_col
+ b_ptrs += BLOCK_K * stride_b_row
offs_k += BLOCK_K
offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- c_ptrs = c_ptr + offs_m_out[:, None] * stride_c + offs_n_out[None, :] * 1
+ c_ptrs = c_ptr + offs_m_out[:, None] * stride_c_row + offs_n_out[None, :] * stride_c_col
mask_m_out = offs_m_out[:, None] < N_e
mask_n_out = offs_n_out[None, :] < I
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
@@ -67,7 +69,9 @@
T_perm, H = inputs[0].shape
I = model.W_gate.shape[2]
-output = torch.empty_like(inputs[0])
+
+# Output is (T_perm, I), not (T_perm, H)
+output = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
expert_offsets = inputs[1]
E = solution.E
@@ -82,16 +86,18 @@
x_e = inputs[0][start:end]
w_e = model.W_gate[e]
- out_e = output[start:end]
+ out_e_base = output + start * I # Pointer to start position in output
grid_m = triton.cdiv(N_e, 64)
grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
- x_e, w_e, out_e,
+ x_e, w_e, out_e_base,
N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ x_e.stride(0), 1, # A: row stride, col stride
+ w_e.stride(0), w_e.stride(1), # B: row stride, col stride
+ I, 1, # C: row stride, col stride
64, 64, 64,
)
torch.cuda.synchronize()
@@ -100,8 +106,7 @@
except Exception as ex:
print(f'Expert {e} failed: {ex}')
print(f' start={start}, end={end}, N_e={N_e}')
- print(f' x_e shape: {x_e.shape}')
- print(f' out_e shape: {out_e.shape}')
break
-print('All experts done')+print('All experts done')
+print('Output shape:', output.shape){
"content": "<3609 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test with correct output shape",
"timeout": 120000
}Input shapes ok
Expert 0 done, N_e=256, grid=(64,)
Expert 10 done, N_e=256, grid=(64,)
Expert 20 done, N_e=256, grid=(64,)
Expert 30 done, N_e=256, grid=(64,)
Expert 40 done, N_e=256, grid=(64,)
Expert 50 done, N_e=256, grid=(64,)
Expert 60 done, N_e=256, grid=(64,)
All experts done
Output shape: torch.Size([16384, 1024])
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,4 +1,4 @@
-"""Test Triton grouped GEMM"""
+"""Verify correctness of Triton grouped GEMM + SwiGLU"""
import torch
import triton
import triton.language as tl
@@ -8,9 +8,9 @@
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
N_e, I, H,
- stride_a_row, stride_a_col, # Row stride and col stride for A
- stride_b_row, stride_b_col, # Row stride and col stride for B
- stride_c_row, stride_c_col, # Row stride and col stride for C
+ stride_a_row, stride_a_col,
+ stride_b_row, stride_b_col,
+ stride_c_row, stride_c_col,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -55,6 +55,37 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
+ T_perm, H = hidden_states.shape
+ I = W.shape[2]
+ output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
+
+ for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+
+ x_e = hidden_states[start:end]
+ w_e = W[e]
+ out_e_base = output + start * I
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e_base,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_e.stride(0), w_e.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
+ return output
+
+
# Test with small shape
solution.T_total = 4096
solution.H = 2048
@@ -64,49 +95,25 @@
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
+
+torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
-print('Input shapes ok')
-T_perm, H = inputs[0].shape
-I = model.W_gate.shape[2]
+# Reference
+ref_out = solution.reference.Model(*solution.get_init_inputs()).cuda()
+ref_out.load_state_dict(model.state_dict())
+ref_out.eval()
+with torch.no_grad():
+ ref_result = ref_out(*inputs)
-# Output is (T_perm, I), not (T_perm, H)
-output = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+# Triton GEMM + SwiGLU
+gate_out = triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
+up_out = triton_grouped_gemm(inputs[0], model.W_up, inputs[1], solution.E)
+silu_gate = torch.nn.functional.silu(gate_out)
+triton_result = silu_gate * up_out
-expert_offsets = inputs[1]
-E = solution.E
-
-for e in range(E):
- try:
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
-
- x_e = inputs[0][start:end]
- w_e = model.W_gate[e]
- out_e_base = output + start * I # Pointer to start position in output
-
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
-
- triton_gemm_kernel[grid](
- x_e, w_e, out_e_base,
- N_e, I, H,
- x_e.stride(0), 1, # A: row stride, col stride
- w_e.stride(0), w_e.stride(1), # B: row stride, col stride
- I, 1, # C: row stride, col stride
- 64, 64, 64,
- )
- torch.cuda.synchronize()
- if e % 10 == 0:
- print(f'Expert {e} done, N_e={N_e}, grid={grid}')
- except Exception as ex:
- print(f'Expert {e} failed: {ex}')
- print(f' start={start}, end={end}, N_e={N_e}')
- break
-
-print('All experts done')
-print('Output shape:', output.shape)+print('Reference shape:', ref_result.shape)
+print('Triton shape:', triton_result.shape)
+print('Max diff:', (ref_result - triton_result).abs().max().item())
+print('Mean diff:', (ref_result - triton_result).abs().mean().item())
+print('Close?', torch.allclose(ref_result, triton_result, atol=0.02, rtol=0.02)){
"content": "<3833 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Verify Triton correctness",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py", line 103, in <module>
ref_out = solution.reference.Model(*solution.get_init_inputs()).cuda()
^^^^^^^^^^^^^^^^^^
AttributeError: module 'solution' has no attribute 'reference'
{
"command": "<624 chars \u2014 see below>",
"description": "Test reference import"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
import reference
# Test with small shape
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
# Reference
ref_out = reference.Model(*reference.get_init_inputs()).cuda()
ref_out.load_state_dict(model.state_dict())
ref_out.eval()
with torch.no_grad():
ref_result = ref_out(*inputs)
print('Reference done')
" 2>&1Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2639, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for W_gate: copying a param with shape torch.Size([64, 2048, 1024]) from checkpoint, the shape in current model is torch.Size([128, 4096, 1536]).
size mismatch for W_up: copying a param with shape torch.Size([64, 2048, 1024]) from checkpoint, the shape in current model is torch.Size([128, 4096, 1536]).
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 | head -60",
"description": "Run check.py",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 110, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/check.py", line 71, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 106, in forward
gate_out = triton_grouped_gemm(hidden_states, self.W_gate, expert_offsets, self.E)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py", line 69, in triton_grouped_gemm
start = int(expert_offsets[e].item())
^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be as
... (213 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -15,7 +15,9 @@
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
N_e, I, H,
- stride_a, stride_b, stride_c,
+ stride_a_row, stride_a_col,
+ stride_b_row, stride_b_col,
+ stride_c_row, stride_c_col,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
@@ -33,8 +35,8 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + offs_m[:, None] * stride_a + offs_k[None, :] * 1
- b_ptrs = b_ptr + offs_k[:, None] * stride_b + offs_n[None, :] * 1
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
@@ -48,13 +50,13 @@
accumulator += tl.dot(a, b)
- a_ptrs += BLOCK_K * stride_a
- b_ptrs += BLOCK_K * 1
+ a_ptrs += BLOCK_K * stride_a_col
+ b_ptrs += BLOCK_K * stride_b_row
offs_k += BLOCK_K
offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- c_ptrs = c_ptr + offs_m_out[:, None] * stride_c + offs_n_out[None, :] * 1
+ c_ptrs = c_ptr + offs_m_out[:, None] * stride_c_row + offs_n_out[None, :] * stride_c_col
mask_m_out = offs_m_out[:, None] < N_e
mask_n_out = offs_n_out[None, :] < I
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
@@ -63,7 +65,7 @@
def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
T_perm, H = hidden_states.shape
I = W.shape[2]
- output = torch.empty_like(hidden_states)
+ output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
for e in range(E):
start = int(expert_offsets[e].item())
@@ -74,16 +76,18 @@
x_e = hidden_states[start:end]
w_e = W[e]
- out_e = output[start:end]
+ out_e_base = output + start * I
grid_m = triton.cdiv(N_e, 64)
grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
- x_e, w_e, out_e,
+ x_e, w_e, out_e_base,
N_e, I, H,
- x_e.stride(0), w_e.stride(0), out_e.stride(0),
+ x_e.stride(0), 1,
+ w_e.stride(0), w_e.stride(1),
+ I, 1,
64, 64, 64,
)
return output
{
"content": "<4515 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<191 chars \u2014 see below>",
"description": "Test .item() on CUDA tensor"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
t = torch.tensor([1,2,3], device='cuda')
print(t[0].item()) # This should work
" 2>&11
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,4 +1,4 @@
-"""Verify correctness of Triton grouped GEMM + SwiGLU"""
+"""Test Triton grouped GEMM with better error handling"""
import torch
import triton
import triton.language as tl
@@ -95,25 +95,9 @@
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
-
-torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
-# Reference
-ref_out = solution.reference.Model(*solution.get_init_inputs()).cuda()
-ref_out.load_state_dict(model.state_dict())
-ref_out.eval()
-with torch.no_grad():
- ref_result = ref_out(*inputs)
-
-# Triton GEMM + SwiGLU
gate_out = triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
-up_out = triton_grouped_gemm(inputs[0], model.W_up, inputs[1], solution.E)
-silu_gate = torch.nn.functional.silu(gate_out)
-triton_result = silu_gate * up_out
-
-print('Reference shape:', ref_result.shape)
-print('Triton shape:', triton_result.shape)
-print('Max diff:', (ref_result - triton_result).abs().max().item())
-print('Mean diff:', (ref_result - triton_result).abs().mean().item())
-print('Close?', torch.allclose(ref_result, triton_result, atol=0.02, rtol=0.02))+print('Gate out shape:', gate_out.shape)
+torch.cuda.synchronize()
+print('Done'){
"content": "<3204 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test updated solution",
"timeout": 120000
}Gate out shape: torch.Size([16384, 1024])
Done
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,4 +1,4 @@
-"""Test Triton grouped GEMM with better error handling"""
+"""Test Triton grouped GEMM + SwiGLU correctness"""
import torch
import triton
import triton.language as tl
@@ -97,7 +97,21 @@
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
+# Reference
+import reference as ref_module
+ref_model = ref_module.Model(*ref_module.get_init_inputs()).cuda()
+ref_model.load_state_dict(model.state_dict())
+ref_model.eval()
+with torch.no_grad():
+ ref_out = ref_model(*inputs)
+
+# Triton
gate_out = triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
-print('Gate out shape:', gate_out.shape)
-torch.cuda.synchronize()
-print('Done')+up_out = triton_grouped_gemm(inputs[0], model.W_up, inputs[1], solution.E)
+silu_gate = torch.nn.functional.silu(gate_out)
+triton_result = silu_gate * up_out
+
+print('Reference shape:', ref_out.shape)
+print('Triton shape:', triton_result.shape)
+print('Max diff:', (ref_out - triton_result).abs().max().item())
+print('Close?', torch.allclose(ref_out, triton_result, atol=0.02, rtol=0.02)){
"content": "<3742 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -3,6 +3,7 @@
import triton
import triton.language as tl
import solution
+import reference as ref_mod
@triton.jit
def triton_gemm_kernel(
@@ -98,8 +99,7 @@
inputs = [t.cuda() for t in solution.get_inputs()]
# Reference
-import reference as ref_module
-ref_model = ref_module.Model(*ref_module.get_init_inputs()).cuda()
+ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
ref_model.load_state_dict(model.state_dict())
ref_model.eval()
with torch.no_grad():
{
"content": "<3733 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test correctness",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py", line 103, in <module>
ref_model.load_state_dict(model.state_dict())
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2639, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for W_gate: copying a param with shape torch.Size([64, 2048, 1024]) from checkpoint, the shape in current model is torch.Size([128, 4096, 1536]).
size mismatch for W_up: copying a param with shape torch.Size([64, 2048, 1024]) from checkpoint, the shape in current model is torch.Size([128, 4096, 1536]).
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -4,6 +4,17 @@
import triton.language as tl
import solution
import reference as ref_mod
+
+# Set shapes for both reference and solution
+for mod, T_total, H, I, E, K in [
+ (solution, 4096, 2048, 1024, 64, 4),
+ (ref_mod, 4096, 2048, 1024, 64, 4),
+]:
+ mod.T_total = T_total
+ mod.H = H
+ mod.I = I
+ mod.E = E
+ mod.K = K
@triton.jit
def triton_gemm_kernel(
@@ -87,13 +98,6 @@
return output
-# Test with small shape
-solution.T_total = 4096
-solution.H = 2048
-solution.I = 1024
-solution.E = 64
-solution.K = 4
-
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
{
"content": "<3863 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test correctness v2",
"timeout": 120000
}Reference shape: torch.Size([16384, 1024])
Triton shape: torch.Size([16384, 1024])
Max diff: 0.06640625
Close? False
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,20 +1,7 @@
-"""Test Triton grouped GEMM + SwiGLU correctness"""
+"""Debug Triton GEMM correctness"""
import torch
import triton
import triton.language as tl
-import solution
-import reference as ref_mod
-
-# Set shapes for both reference and solution
-for mod, T_total, H, I, E, K in [
- (solution, 4096, 2048, 1024, 64, 4),
- (ref_mod, 4096, 2048, 1024, 64, 4),
-]:
- mod.T_total = T_total
- mod.H = H
- mod.I = I
- mod.E = E
- mod.K = K
@triton.jit
def triton_gemm_kernel(
@@ -67,55 +54,99 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
- T_perm, H = hidden_states.shape
- I = W.shape[2]
- output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
+# Small test case
+N_e = 32
+I = 64
+H = 128
- for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
+torch.manual_seed(42)
+A = torch.randn(N_e, H, dtype=torch.bfloat16, device='cuda')
+B = torch.randn(H, I, dtype=torch.bfloat16, device='cuda')
+C = torch.zeros(N_e, I, dtype=torch.float32, device='cuda')
- x_e = hidden_states[start:end]
- w_e = W[e]
- out_e_base = output + start * I
+# Reference: matmul in fp32
+C_ref = (A.float() @ B.float()).to(torch.bfloat16)
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
+grid_m = triton.cdiv(N_e, 64)
+grid_n = triton.cdiv(I, 64)
+grid = (grid_m * grid_n,)
- triton_gemm_kernel[grid](
- x_e, w_e, out_e_base,
- N_e, I, H,
- x_e.stride(0), 1,
- w_e.stride(0), w_e.stride(1),
- I, 1,
- 64, 64, 64,
- )
- return output
+triton_gemm_kernel[grid](
+ A, B, C,
+ N_e, I, H,
+ A.stride(0), 1,
+ B.stride(0), B.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+
+print('Triton result shape:', C.shape)
+print('Reference result shape:', C_ref.shape)
+print('Max diff (fp32 accum):', (C.float() - C_ref.float()).abs().max().item())
+print('Close (fp32)?', torch.allclose(C.float(), C_ref.float(), atol=0.01, rtol=0.01))
+
+# Try with BLOCK_M=32
+C2 = torch.zeros(N_e, I, dtype=torch.float32, device='cuda')
+grid2 = (triton.cdiv(N_e, 32) * triton.cdiv(I, 64),)
+
+@triton.jit
+def triton_gemm_kernel2(
+ a_ptr, b_ptr, c_ptr,
+ N_e, I, H,
+ stride_a_row, stride_a_col,
+ stride_b_row, stride_b_col,
+ stride_c_row, stride_c_col,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(N_e, BLOCK_M)
+ num_pid_n = tl.cdiv(I, BLOCK_N)
+ num_pid_in_group = num_pid_m * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * num_pid_m
+ group_size_m = min(num_pid_m, N_e - first_pid_m)
+
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
+ b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col
+
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(H, BLOCK_K)):
+ mask_m = offs_m[:, None] < N_e
+ mask_k = offs_k[None, :] < H
+ a = tl.load(a_ptrs, mask=mask_m & mask_k, other=0.0)
+
+ mask_k = offs_k[:, None] < H
+ mask_n = offs_n[None, :] < I
+ b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0.0)
+
+ accumulator += tl.dot(a, b)
+
+ a_ptrs += BLOCK_K * stride_a_col
+ b_ptrs += BLOCK_K * stride_b_row
+ offs_k += BLOCK_K
+
+ offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ c_ptrs = c_ptr + offs_m_out[:, None] * stride_c_row + offs_n_out[None, :] * stride_c_col
+ mask_m_out = offs_m_out[:, None] < N_e
+ mask_n_out = offs_n_out[None, :] < I
+ tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-model = solution.Model(*solution.get_init_inputs())
-model = model.cuda()
-inputs = [t.cuda() for t in solution.get_inputs()]
+triton_gemm_kernel2[grid2](
+ A, B, C2,
+ N_e, I, H,
+ A.stride(0), 1,
+ B.stride(0), B.stride(1),
+ I, 1,
+ 32, 64, 64,
+)
-# Reference
-ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
-ref_model.load_state_dict(model.state_dict())
-ref_model.eval()
-with torch.no_grad():
- ref_out = ref_model(*inputs)
-
-# Triton
-gate_out = triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
-up_out = triton_grouped_gemm(inputs[0], model.W_up, inputs[1], solution.E)
-silu_gate = torch.nn.functional.silu(gate_out)
-triton_result = silu_gate * up_out
-
-print('Reference shape:', ref_out.shape)
-print('Triton shape:', triton_result.shape)
-print('Max diff:', (ref_out - triton_result).abs().max().item())
-print('Close?', torch.allclose(ref_out, triton_result, atol=0.02, rtol=0.02))+print('Max diff (BLOCK_M=32):', (C2.float() - C_ref.float()).abs().max().item()){
"content": "<4906 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test GEMM correctness",
"timeout": 60000
}Triton result shape: torch.Size([32, 64])
Reference result shape: torch.Size([32, 64])
Max diff (fp32 accum): 0.12030029296875
Close (fp32)? True
Max diff (BLOCK_M=32): 0.12030029296875
{
"command": "<3764 chars \u2014 see below>",
"description": "Test grouped GEMM with full precision reference",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def triton_gemm_kernel(
a_ptr, b_ptr, c_ptr,
N_e, I, H,
stride_a_row, stride_a_col,
stride_b_row, stride_b_col,
stride_c_row, stride_c_col,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(N_e, BLOCK_M)
num_pid_n = tl.cdiv(I, BLOCK_N)
num_pid_in_group = num_pid_m * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * num_pid_m
group_size_m = min(num_pid_m, N_e - first_pid_m)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(H, BLOCK_K)):
mask_m = offs_m[:, None] < N_e
mask_k = offs_k[None, :] < H
a = tl.load(a_ptrs, mask=mask_m & mask_k, other=0.0)
mask_k = offs_k[:, None] < H
mask_n = offs_n[None, :] < I
b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_a_col
b_ptrs += BLOCK_K * stride_b_row
offs_k += BLOCK_K
offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + offs_m_out[:, None] * stride_c_row + offs_n_out[None, :] * stride_c_col
mask_m_out = offs_m_out[:, None] < N_e
mask_n_out = offs_n_out[None, :] < I
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
# Test grouped GEMM on the MoE shape
T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
T_perm = T_total * K
torch.manual_seed(42)
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device='cuda')
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
offsets[1:] = torch.cumsum(counts, dim=0)
# Run grouped GEMM
output = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
for e in range(E):
start = int(offsets[e].item())
end = int(offsets[e + 1].item())
if end == start:
continue
N_e = end - start
x_e = hidden[start:end]
w_e = W[e]
out_e = output[start:end]
grid_m = triton.cdiv(N_e, 64)
grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
x_e, w_e, out_e,
N_e, I, H,
x_e.stride(0), 1,
w_e.stride(0), w_e.stride(1),
I, 1,
64, 64, 64,
)
torch.cuda.synchronize()
# Compare with reference
ref_output = torch.zeros(T_perm, I, device='cuda', dtype=torch.float32)
for e in range(E):
start = int(offsets[e].item())
end = int(offsets[e + 1].item())
if end == start:
continue
x_e = hidden[start:end].float()
w_e = W[e].float()
ref_output[start:end] = x_e @ w_e
ref_output_bf16 = ref_output.to(torch.bfloat16)
max_diff = (output.float() - ref_output).abs().max().item()
print(f'Max diff (fp32 ref vs bf16 out): {max_diff}')
print(f'Max diff (bf16 vs bf16): {(output - ref_output_bf16).abs().max().item()}')
print(f'Close (bf16 atol=0.02)? {torch.allclose(output, ref_output_bf16, atol=0.02)}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,4 +1,4 @@
-"""Debug Triton GEMM correctness"""
+"""Test Triton grouped GEMM correctness"""
import torch
import triton
import triton.language as tl
@@ -54,99 +54,63 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-# Small test case
-N_e = 32
-I = 64
-H = 128
+# Test grouped GEMM on the MoE shape
+T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
+T_perm = T_total * K
torch.manual_seed(42)
-A = torch.randn(N_e, H, dtype=torch.bfloat16, device='cuda')
-B = torch.randn(H, I, dtype=torch.bfloat16, device='cuda')
-C = torch.zeros(N_e, I, dtype=torch.float32, device='cuda')
+hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
+W = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
-# Reference: matmul in fp32
-C_ref = (A.float() @ B.float()).to(torch.bfloat16)
+base = T_perm // E
+rem = T_perm - base * E
+counts = torch.full((E,), base, dtype=torch.int32, device='cuda')
+counts[:rem] += 1
+offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
+offsets[1:] = torch.cumsum(counts, dim=0)
-grid_m = triton.cdiv(N_e, 64)
-grid_n = triton.cdiv(I, 64)
-grid = (grid_m * grid_n,)
+# Run grouped GEMM
+output = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+for e in range(E):
+ start = int(offsets[e].item())
+ end = int(offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
-triton_gemm_kernel[grid](
- A, B, C,
- N_e, I, H,
- A.stride(0), 1,
- B.stride(0), B.stride(1),
- I, 1,
- 64, 64, 64,
-)
+ x_e = hidden[start:end]
+ w_e = W[e]
+ out_e = output[start:end]
-print('Triton result shape:', C.shape)
-print('Reference result shape:', C_ref.shape)
-print('Max diff (fp32 accum):', (C.float() - C_ref.float()).abs().max().item())
-print('Close (fp32)?', torch.allclose(C.float(), C_ref.float(), atol=0.01, rtol=0.01))
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
-# Try with BLOCK_M=32
-C2 = torch.zeros(N_e, I, dtype=torch.float32, device='cuda')
-grid2 = (triton.cdiv(N_e, 32) * triton.cdiv(I, 64),)
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_e.stride(0), w_e.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
-@triton.jit
-def triton_gemm_kernel2(
- a_ptr, b_ptr, c_ptr,
- N_e, I, H,
- stride_a_row, stride_a_col,
- stride_b_row, stride_b_col,
- stride_c_row, stride_c_col,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- num_pid_m = tl.cdiv(N_e, BLOCK_M)
- num_pid_n = tl.cdiv(I, BLOCK_N)
- num_pid_in_group = num_pid_m * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * num_pid_m
- group_size_m = min(num_pid_m, N_e - first_pid_m)
+torch.cuda.synchronize()
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+# Compare with reference
+ref_output = torch.zeros(T_perm, I, device='cuda', dtype=torch.float32)
+for e in range(E):
+ start = int(offsets[e].item())
+ end = int(offsets[e + 1].item())
+ if end == start:
+ continue
+ x_e = hidden[start:end].float()
+ w_e = W[e].float()
+ ref_output[start:end] = x_e @ w_e
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_K)
+ref_output_bf16 = ref_output.to(torch.bfloat16)
- a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
- b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_col
-
- accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k in range(0, tl.cdiv(H, BLOCK_K)):
- mask_m = offs_m[:, None] < N_e
- mask_k = offs_k[None, :] < H
- a = tl.load(a_ptrs, mask=mask_m & mask_k, other=0.0)
-
- mask_k = offs_k[:, None] < H
- mask_n = offs_n[None, :] < I
- b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0.0)
-
- accumulator += tl.dot(a, b)
-
- a_ptrs += BLOCK_K * stride_a_col
- b_ptrs += BLOCK_K * stride_b_row
- offs_k += BLOCK_K
-
- offs_m_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n_out = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- c_ptrs = c_ptr + offs_m_out[:, None] * stride_c_row + offs_n_out[None, :] * stride_c_col
- mask_m_out = offs_m_out[:, None] < N_e
- mask_n_out = offs_n_out[None, :] < I
- tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-
-
-triton_gemm_kernel2[grid2](
- A, B, C2,
- N_e, I, H,
- A.stride(0), 1,
- B.stride(0), B.stride(1),
- I, 1,
- 32, 64, 64,
-)
-
-print('Max diff (BLOCK_M=32):', (C2.float() - C_ref.float()).abs().max().item())+max_diff = (output.float() - ref_output).abs().max().item()
+print(f'Max diff (fp32 ref vs bf16 out): {max_diff}')
+print(f'Max diff (bf16 vs bf16): {(output - ref_output_bf16).abs().max().item()}')
+print(f'Close (bf16 atol=0.02)? {torch.allclose(output, ref_output_bf16, atol=0.02)}'){
"content": "<3708 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test GEMM correctness with full reference",
"timeout": 120000
}Max diff (fp32 ref vs bf16 out): 0.000977158546447754
Max diff (bf16 vs bf16): 0.001953125
Close (bf16 atol=0.02)? True
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && CUDA_LAUNCH_BLOCKING=1 uv run python check.py 2>&1 | head -80",
"description": "Run check.py with CUDA_LAUNCH_BLOCKING",
"timeout": 300000
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.180664 (atol=0.02, rtol=0.02)
{
"command": "<970 chars \u2014 see below>",
"description": "Debug shape 0"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import triton
import triton.language as tl
# Test shape 0 directly
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
print(f'Total tokens: {T_perm}')
print(f'H={H}, I={I}, E={E}, K={K}')
print(f'Bytes per token: {H * 2 + 2 * H * I * 2}')
torch.manual_seed(42)
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
W = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
base = T_perm // E
rem = T_perm - base * E
counts = torch.full((E,), base, dtype=torch.int32, device='cuda')
counts[:rem] += 1
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
offsets[1:] = torch.cumsum(counts, dim=0)
# Check first expert
e = 0
start = int(offsets[e].item())
end = int(offsets[e + 1].item())
N_e = end - start
print(f'Expert 0: N_e={N_e}, grid_m={triton.cdiv(N_e, 64)}, grid_n={triton.cdiv(I, 64)}')
" 2>&1Total tokens: 262144
H=4096, I=1536, E=128, K=8
Bytes per token: 25174016
Expert 0: N_e=2048, grid_m=32, grid_n=24
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,7 +1,23 @@
-"""Test Triton grouped GEMM correctness"""
+"""Debug correctness issue"""
import torch
import triton
import triton.language as tl
+import solution
+import reference as ref_mod
+
+# Test with shape 0
+T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
+T_perm = T_total * K
+
+for mod, t, h, ii, ee, kk in [
+ (solution, T_total, H, I, E, K),
+ (ref_mod, T_total, H, I, E, K),
+]:
+ mod.T_total = t
+ mod.H = h
+ mod.I = ii
+ mod.E = ee
+ mod.K = kk
@triton.jit
def triton_gemm_kernel(
@@ -54,13 +70,43 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-# Test grouped GEMM on the MoE shape
-T_total, H, I, E, K = 4096, 2048, 1024, 64, 4
-T_perm = T_total * K
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
+ T_perm, H = hidden_states.shape
+ I = W.shape[2]
+ output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
+ for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+
+ x_e = hidden_states[start:end]
+ w_e = W[e]
+ out_e_base = output + start * I
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e_base,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_e.stride(0), w_e.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
+ return output
+
+
+# Just test expert 0
torch.manual_seed(42)
hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
-W = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
+
+W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
+W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
base = T_perm // E
rem = T_perm - base * E
@@ -69,48 +115,53 @@
offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
offsets[1:] = torch.cumsum(counts, dim=0)
-# Run grouped GEMM
-output = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
-for e in range(E):
- start = int(offsets[e].item())
- end = int(offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
+# Expert 0 only
+e = 0
+start = int(offsets[e].item())
+end = int(offsets[e + 1].item())
+N_e = end - start
- x_e = hidden[start:end]
- w_e = W[e]
- out_e = output[start:end]
+x_e = hidden[start:end]
+w_g = W_gate[e]
+w_u = W_up[e]
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
+print(f'Expert 0: rows {start}:{end} = {N_e} tokens')
- triton_gemm_kernel[grid](
- x_e, w_e, out_e,
- N_e, I, H,
- x_e.stride(0), 1,
- w_e.stride(0), w_e.stride(1),
- I, 1,
- 64, 64, 64,
- )
+# Reference
+gate_ref = x_e.float() @ w_g.float()
+up_ref = x_e.float() @ w_u.float()
+gate_ref_bf16 = gate_ref.to(torch.bfloat16)
+up_ref_bf16 = up_ref.to(torch.bfloat16)
+# Triton
+gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+
+grid_m = triton.cdiv(N_e, 64)
+grid_n = triton.cdiv(I, 64)
+grid = (grid_m * grid_n,)
+print(f'Grid: {grid}, N_e={N_e}, I={I}, H={H}')
+
+triton_gemm_kernel[grid](
+ x_e, w_g, gate_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
torch.cuda.synchronize()
-# Compare with reference
-ref_output = torch.zeros(T_perm, I, device='cuda', dtype=torch.float32)
-for e in range(E):
- start = int(offsets[e].item())
- end = int(offsets[e + 1].item())
- if end == start:
- continue
- x_e = hidden[start:end].float()
- w_e = W[e].float()
- ref_output[start:end] = x_e @ w_e
+triton_gemm_kernel[grid](
+ x_e, w_u, up_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
-ref_output_bf16 = ref_output.to(torch.bfloat16)
-
-max_diff = (output.float() - ref_output).abs().max().item()
-print(f'Max diff (fp32 ref vs bf16 out): {max_diff}')
-print(f'Max diff (bf16 vs bf16): {(output - ref_output_bf16).abs().max().item()}')
-print(f'Close (bf16 atol=0.02)? {torch.allclose(output, ref_output_bf16, atol=0.02)}')+print(f'Gate ref max: {gate_ref.abs().max().item()}, Triton max: {gate_out.abs().max().item()}')
+print(f'Max diff gate: {(gate_out.float() - gate_ref).abs().max().item()}')
+print(f'Max diff up: {(up_out.float() - up_ref).abs().max().item()}'){
"content": "<4832 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug shape 0 expert 0",
"timeout": 120000
}Expert 0: rows 0:2048 = 2048 tokens
Grid: (768,), N_e=2048, I=1536, H=4096
Gate ref max: 0.6814078688621521, Triton max: 0.6796875
Max diff gate: 0.0019540786743164062
Max diff up: 0.001942455768585205
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,9 +1,8 @@
-"""Debug correctness issue"""
+"""Debug correctness issue on shape 0"""
import torch
import triton
import triton.language as tl
import solution
-import reference as ref_mod
# Test with shape 0
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
@@ -11,7 +10,6 @@
for mod, t, h, ii, ee, kk in [
(solution, T_total, H, I, E, K),
- (ref_mod, T_total, H, I, E, K),
]:
mod.T_total = t
mod.H = h
@@ -70,98 +68,88 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
-def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
- T_perm, H = hidden_states.shape
- I = W.shape[2]
- output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
-
- for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
-
- x_e = hidden_states[start:end]
- w_e = W[e]
- out_e_base = output + start * I
-
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
-
- triton_gemm_kernel[grid](
- x_e, w_e, out_e_base,
- N_e, I, H,
- x_e.stride(0), 1,
- w_e.stride(0), w_e.stride(1),
- I, 1,
- 64, 64, 64,
- )
- return output
-
-
-# Just test expert 0
+# Build model
+model = solution.Model(*solution.get_init_inputs())
+model = model.cuda()
torch.manual_seed(42)
-hidden = torch.randn(T_perm, H, dtype=torch.bfloat16, device='cuda') * 0.1
-
-W_gate = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
-W_up = torch.randn(E, H, I, dtype=torch.bfloat16, device='cuda') * 0.02
-
-base = T_perm // E
-rem = T_perm - base * E
-counts = torch.full((E,), base, dtype=torch.int32, device='cuda')
-counts[:rem] += 1
-offsets = torch.zeros(E + 1, dtype=torch.int32, device='cuda')
-offsets[1:] = torch.cumsum(counts, dim=0)
-
-# Expert 0 only
-e = 0
-start = int(offsets[e].item())
-end = int(offsets[e + 1].item())
-N_e = end - start
-
-x_e = hidden[start:end]
-w_g = W_gate[e]
-w_u = W_up[e]
-
-print(f'Expert 0: rows {start}:{end} = {N_e} tokens')
+inputs = [t.cuda() for t in solution.get_inputs()]
# Reference
-gate_ref = x_e.float() @ w_g.float()
-up_ref = x_e.float() @ w_u.float()
-gate_ref_bf16 = gate_ref.to(torch.bfloat16)
-up_ref_bf16 = up_ref.to(torch.bfloat16)
+import reference as ref_mod
+ref_mod.T_total = T_total
+ref_mod.H = H
+ref_mod.I = I
+ref_mod.E = E
+ref_mod.K = K
+ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
+ref_model.load_state_dict(model.state_dict())
+ref_model.eval()
+with torch.no_grad():
+ ref_out = ref_model(*inputs)
-# Triton
-gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
-up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+# Triton GEMM
+gate_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+up_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+expert_offsets = inputs[1]
-grid_m = triton.cdiv(N_e, 64)
-grid_n = triton.cdiv(I, 64)
-grid = (grid_m * grid_n,)
-print(f'Grid: {grid}, N_e={N_e}, I={I}, H={H}')
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+ I_e = model.I
-triton_gemm_kernel[grid](
- x_e, w_g, gate_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
-)
+ x_e = inputs[0][start:end]
+ w_g = model.W_gate[e]
+ w_u = model.W_up[e]
+ out_g = gate_out[start:start+N_e]
+ out_u = up_out[start:start+N_e]
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I_e, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_g, out_g,
+ N_e, I_e, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I_e, 1,
+ 64, 64, 64,
+ )
+ triton_gemm_kernel[grid](
+ x_e, w_u, out_u,
+ N_e, I_e, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I_e, 1,
+ 64, 64, 64,
+ )
+
torch.cuda.synchronize()
-triton_gemm_kernel[grid](
- x_e, w_u, up_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
+silu_gate = torch.nn.functional.silu(gate_out)
+triton_out = silu_gate * up_out
-print(f'Gate ref max: {gate_ref.abs().max().item()}, Triton max: {gate_out.abs().max().item()}')
-print(f'Max diff gate: {(gate_out.float() - gate_ref).abs().max().item()}')
-print(f'Max diff up: {(up_out.float() - up_ref).abs().max().item()}')+# Find worst expert
+print(f'ref_out shape: {ref_out.shape}, triton_out shape: {triton_out.shape}')
+
+# Find which rows have biggest diff
+diffs = (ref_out.float() - triton_out.float()).abs()
+max_diff, max_idx = diffs.max(), diffs.argmax()
+max_row = max_idx // I
+max_col = max_idx % I
+print(f'Max diff at row {max_row}, col {max_col}: {max_diff.item()}')
+
+# Which expert owns that row?
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e+1].item())
+ if max_row >= start and max_row < end:
+ print(f'Expert {e} owns row {max_row} (offset {start}:{end})')
+ break
+
+# Check expert offsets around that area
+max_expert = max_row // (T_perm // E) # rough guess
+print(f'Expert offsets around {max_expert}: {expert_offsets[max_expert:max_expert+5]}'){
"content": "<4536 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -87,7 +87,7 @@
with torch.no_grad():
ref_out = ref_model(*inputs)
-# Triton GEMM
+# Triton GEMM - use explicit pointer arithmetic
gate_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
up_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
expert_offsets = inputs[1]
@@ -98,32 +98,30 @@
if end == start:
continue
N_e = end - start
- I_e = model.I
x_e = inputs[0][start:end]
w_g = model.W_gate[e]
w_u = model.W_up[e]
- out_g = gate_out[start:start+N_e]
- out_u = up_out[start:start+N_e]
grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I_e, 64)
+ grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
+ # Use explicit pointer offset
triton_gemm_kernel[grid](
- x_e, w_g, out_g,
- N_e, I_e, H,
+ x_e, w_g, gate_out + start * I,
+ N_e, I, H,
x_e.stride(0), 1,
w_g.stride(0), w_g.stride(1),
- I_e, 1,
+ I, 1,
64, 64, 64,
)
triton_gemm_kernel[grid](
- x_e, w_u, out_u,
- N_e, I_e, H,
+ x_e, w_u, up_out + start * I,
+ N_e, I, H,
x_e.stride(0), 1,
w_u.stride(0), w_u.stride(1),
- I_e, 1,
+ I, 1,
64, 64, 64,
)
@@ -148,8 +146,4 @@
end = int(expert_offsets[e+1].item())
if max_row >= start and max_row < end:
print(f'Expert {e} owns row {max_row} (offset {start}:{end})')
- break
-
-# Check expert offsets around that area
-max_expert = max_row // (T_perm // E) # rough guess
-print(f'Expert offsets around {max_expert}: {expert_offsets[max_expert:max_expert+5]}')+ break{
"content": "<4348 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py",
"limit": 80
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py</path>
<type>file</type>
<content>
1: """Grouped GEMM + fused SwiGLU for MoE up-projection on Blackwell (SM120).
2:
3: Implements: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e])
4: using Triton kernel with grouped GEMM.
5: """
6: from __future__ import annotations
7:
8: import torch
9: import torch.nn as nn
10: import triton
11: import triton.language as tl
12:
13:
14: @triton.jit
15: def triton_gemm_kernel(
16: a_ptr, b_ptr, c_ptr,
17: N_e, I, H,
18: stride_a_row, stride_a_col,
19: stride_b_row, stride_b_col,
20: stride_c_row, stride_c_col,
21: BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
22: ):
23: pid = tl.program_id(axis=0)
24: num_pid_m = tl.cdiv(N_e, BLOCK_M)
25: num_pid_n = tl.cdiv(I, BLOCK_N)
26: num_pid_in_group = num_pid_m * num_pid_n
27: group_id = pid // num_pid_in_group
28: first_pid_m = group_id * num_pid_m
29: group_size_m = min(num_pid_m, N_e - first_pid_m)
30:
31: pid_m = first_pid_m + (pid % group_size_m)
32: pid_n = (pid % num_pid_in_group) // group_size_m
33:
34: offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
35: offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
36: offs_k = tl.arange(0, BLOCK_K)
37:
38: a_ptrs = a_ptr + offs_m[:, None] * stride_a_row + offs_k[None, :] * stride_a_col
39: b_ptrs = b_ptr + offs_k[:, None] * stride_b_row + offs_n[None, :] * stride_b_
... (1621 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug correctness issue",
"timeout": 300000
}ref_out shape: torch.Size([262144, 1536]), triton_out shape: torch.Size([262144, 1536])
Max diff at row 112671, col 1123: 0.1748046875
Expert 55 owns row 112671 (offset 112640:114688)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -87,63 +87,90 @@
with torch.no_grad():
ref_out = ref_model(*inputs)
-# Triton GEMM - use explicit pointer arithmetic
-gate_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
-up_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+# Test expert 55 specifically
expert_offsets = inputs[1]
+e = 55
+start = int(expert_offsets[e].item())
+end = int(expert_offsets[e + 1].item())
+N_e = end - start
+print(f'Expert {e}: rows {start}:{end}, N_e={N_e}')
-for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
+x_e = inputs[0][start:end]
+w_g = model.W_gate[e]
+w_u = model.W_up[e]
- x_e = inputs[0][start:end]
- w_g = model.W_gate[e]
- w_u = model.W_up[e]
+# Reference computation for this expert
+gate_ref = x_e.float() @ w_g.float()
+up_ref = x_e.float() @ w_u.float()
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
+# Triton for this expert
+gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
- # Use explicit pointer offset
- triton_gemm_kernel[grid](
- x_e, w_g, gate_out + start * I,
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
- )
- triton_gemm_kernel[grid](
- x_e, w_u, up_out + start * I,
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
- )
+grid_m = triton.cdiv(N_e, 64)
+grid_n = triton.cdiv(I, 64)
+grid = (grid_m * grid_n,)
+print(f'Grid: {grid}')
+triton_gemm_kernel[grid](
+ x_e, w_g, gate_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
torch.cuda.synchronize()
-silu_gate = torch.nn.functional.silu(gate_out)
-triton_out = silu_gate * up_out
+triton_gemm_kernel[grid](
+ x_e, w_u, up_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
-# Find worst expert
-print(f'ref_out shape: {ref_out.shape}, triton_out shape: {triton_out.shape}')
+# Check diffs within this expert
+gate_diff = (gate_out.float() - gate_ref).abs()
+up_diff = (up_out.float() - up_ref).abs()
-# Find which rows have biggest diff
-diffs = (ref_out.float() - triton_out.float()).abs()
-max_diff, max_idx = diffs.max(), diffs.argmax()
-max_row = max_idx // I
-max_col = max_idx % I
-print(f'Max diff at row {max_row}, col {max_col}: {max_diff.item()}')
+print(f'Gate max diff: {gate_diff.max().item()}')
+print(f'Up max diff: {up_diff.max().item()}')
-# Which expert owns that row?
-for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e+1].item())
- if max_row >= start and max_row < end:
- print(f'Expert {e} owns row {max_row} (offset {start}:{end})')
- break+# Find worst row
+worst_gate_row = gate_diff.max(dim=1).values.argmax()
+print(f'Worst gate row: {worst_gate_row}, diff: {gate_diff.max(dim=1).values[worst_gate_row].item()}')
+
+# Now test with output as a slice
+gate_out2 = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+up_out2 = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+
+# Use slice approach
+triton_gemm_kernel[grid](
+ x_e, w_g, gate_out2[start:end],
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
+
+triton_gemm_kernel[grid](
+ x_e, w_u, up_out2[start:end],
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
+
+# Compare
+gate_diff2 = (gate_out2[start:end].float() - gate_ref).abs()
+up_diff2 = (up_out2[start:end].float() - up_ref).abs()
+
+print(f'Gate max diff (slice): {gate_diff2.max().item()}')
+print(f'Up max diff (slice): {up_diff2.max().item()}'){
"content": "<4867 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug expert 55",
"timeout": 120000
}Expert 55: rows 112640:114688, N_e=2048
Grid: (768,)
Gate max diff: 0.001954495906829834
Up max diff: 0.001954495906829834
Worst gate row: 1390, diff: 0.001954495906829834
Gate max diff (slice): 0.001954495906829834
Up max diff (slice): 0.001954495906829834
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -87,90 +87,87 @@
with torch.no_grad():
ref_out = ref_model(*inputs)
-# Test expert 55 specifically
+# Triton GEMM with slice output
+gate_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
+up_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
expert_offsets = inputs[1]
+
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+
+ x_e = inputs[0][start:end]
+ w_g = model.W_gate[e]
+ w_u = model.W_up[e]
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ # Use slice for output
+ triton_gemm_kernel[grid](
+ x_e, w_g, gate_out[start:end],
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
+ triton_gemm_kernel[grid](
+ x_e, w_u, up_out[start:end],
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
+
+torch.cuda.synchronize()
+
+silu_gate = torch.nn.functional.silu(gate_out)
+triton_out = silu_gate * up_out
+
+# Find worst expert
+diffs = (ref_out.float() - triton_out.float()).abs()
+max_diff, max_idx = diffs.max(), diffs.argmax()
+max_row = max_idx // I
+max_col = max_idx % I
+print(f'Max diff at row {max_row}, col {max_col}: {max_diff.item()}')
+
+# Which expert owns that row?
+for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e+1].item())
+ if max_row >= start and max_row < end:
+ print(f'Expert {e} owns row {max_row} (offset {start}:{end})')
+ break
+
+# Now check the GEMM for this expert specifically
e = 55
start = int(expert_offsets[e].item())
end = int(expert_offsets[e + 1].item())
N_e = end - start
-print(f'Expert {e}: rows {start}:{end}, N_e={N_e}')
x_e = inputs[0][start:end]
w_g = model.W_gate[e]
w_u = model.W_up[e]
-# Reference computation for this expert
gate_ref = x_e.float() @ w_g.float()
up_ref = x_e.float() @ w_u.float()
-# Triton for this expert
-gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
-up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+# Check if our output buffer for this expert matches reference
+print(f'\nExpert 55 GEMM check:')
+print(f'Gate out buffer rows 112640:114688:')
+print(f' Max diff: {(gate_out[start:end].float() - gate_ref).abs().max().item()}')
+print(f' Up out buffer rows 112640:114688:')
+print(f' Max diff: {(up_out[start:end].float() - up_ref).abs().max().item()}')
-grid_m = triton.cdiv(N_e, 64)
-grid_n = triton.cdiv(I, 64)
-grid = (grid_m * grid_n,)
-print(f'Grid: {grid}')
-
-triton_gemm_kernel[grid](
- x_e, w_g, gate_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
-
-triton_gemm_kernel[grid](
- x_e, w_u, up_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
-
-# Check diffs within this expert
-gate_diff = (gate_out.float() - gate_ref).abs()
-up_diff = (up_out.float() - up_ref).abs()
-
-print(f'Gate max diff: {gate_diff.max().item()}')
-print(f'Up max diff: {up_diff.max().item()}')
-
-# Find worst row
-worst_gate_row = gate_diff.max(dim=1).values.argmax()
-print(f'Worst gate row: {worst_gate_row}, diff: {gate_diff.max(dim=1).values[worst_gate_row].item()}')
-
-# Now test with output as a slice
-gate_out2 = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
-up_out2 = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
-
-# Use slice approach
-triton_gemm_kernel[grid](
- x_e, w_g, gate_out2[start:end],
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
-
-triton_gemm_kernel[grid](
- x_e, w_u, up_out2[start:end],
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
-
-# Compare
-gate_diff2 = (gate_out2[start:end].float() - gate_ref).abs()
-up_diff2 = (up_out2[start:end].float() - up_ref).abs()
-
-print(f'Gate max diff (slice): {gate_diff2.max().item()}')
-print(f'Up max diff (slice): {up_diff2.max().item()}')+# But what about the overall output?
+print(f'\nOverall output for expert 55 region:')
+print(f' Gate ref max: {gate_ref.abs().max().item()}')
+print(f' Gate out max: {gate_out[start:end].abs().max().item()}')
+print(f' Up ref max: {up_ref.abs().max().item()}')
+print(f' Up out max: {up_out[start:end].abs().max().item()}'){
"content": "<5183 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug expert 55 GEMM and SwiGLU",
"timeout": 300000
}Max diff at row 36272, col 1338: 0.001953125
Expert 17 owns row 36272 (offset 34816:36864)
Expert 55 GEMM check:
Gate out buffer rows 112640:114688:
Max diff: 0.0019475221633911133
Up out buffer rows 112640:114688:
Max diff: 0.0019519329071044922
Overall output for expert 55 region:
Gate ref max: 0.6343462467193604
Gate out max: 0.6328125
Up ref max: 0.6757517457008362
Up out max: 0.67578125
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python check.py 2>&1 | head -60",
"description": "Run check.py",
"timeout": 600000
}FAIL: shape 0 {'T_total': 32768, 'H': 4096, 'I': 1536, 'E': 128, 'K': 8} seed 42: tolerance exceeded: max_abs_diff=0.209961 (atol=0.02, rtol=0.02)
{
"command": "<1430 chars \u2014 see below>",
"description": "Debug model output",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
import reference as ref_mod
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
for mod, t, h, ii, ee, kk in [
(solution, T_total, H, I, E, K),
(ref_mod, T_total, H, I, E, K),
]:
mod.T_total = t
mod.H = h
mod.I = ii
mod.E = ee
mod.K = kk
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
ref_model.load_state_dict(model.state_dict())
ref_model.eval()
with torch.no_grad():
ref_out = ref_model(*inputs)
# Forward using solution
with torch.no_grad():
sol_out = model(*inputs)
diffs = (ref_out.float() - sol_out.float()).abs()
max_diff, max_idx = diffs.max().item(), diffs.argmax().item()
max_row = max_idx // I
max_col = max_idx % I
print(f'Max diff: {max_diff} at row {max_row}, col {max_col}')
print(f'ref value: {ref_out.flatten()[max_idx].item()}')
print(f'sol value: {sol_out.flatten()[max_idx].item()}')
# Check row stats
row_diffs = diffs.max(dim=1).values
print(f'Max row diff: {row_diffs.max().item()} at row {row_diffs.argmax().item()}')
# Check col stats
col_diffs = diffs.max(dim=0).values
print(f'Max col diff: {col_diffs.max().item()} at col {col_diffs.argmax().item()}')
" 2>&1Max diff: 0.189453125 at row 204884, col 30
ref value: -0.189453125
sol value: 0.0
Max row diff: 0.189453125 at row 204884
Max col diff: 0.189453125 at col 30
{
"command": "<829 chars \u2014 see below>",
"description": "Find which expert owns row 204884"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
for mod, t, h, ii, ee, kk in [
(solution, T_total, H, I, E, K),
]:
mod.T_total = t
mod.H = h
mod.I = ii
mod.E = ee
mod.K = kk
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
# Find what row 204884 is
expert_offsets = inputs[1]
for e in range(E):
start = int(expert_offsets[e].item())
end = int(expert_offsets[e + 1].item())
if 204884 >= start and 204884 < end:
print(f'Expert {e} owns row 204884 (offset {start}:{end})')
print(f'Local row index: {204884 - start}')
break
" 2>&1Expert 100 owns row 204884 (offset 204800:206848)
Local row index: 84
{
"command": "<745 chars \u2014 see below>",
"description": "Check expert 100 grid"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
for mod, t, h, ii, ee, kk in [
(solution, T_total, H, I, E, K),
]:
mod.T_total = t
mod.H = h
mod.I = ii
mod.E = ee
mod.K = kk
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
expert_offsets = inputs[1]
e = 100
start = int(expert_offsets[e].item())
end = int(expert_offsets[e + 1].item())
N_e = end - start
print(f'Expert {e}: rows {start}:{end}, N_e={N_e}')
print(f'Grid: {triton.cdiv(N_e, 64)} x {triton.cdiv(1536, 64)}')
" 2>&1Traceback (most recent call last):
File "<string>", line 29, in <module>
NameError: name 'triton' is not defined
Expert 100: rows 204800:206848, N_e=2048
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,4 +1,4 @@
-"""Debug correctness issue on shape 0"""
+"""Debug expert 100 in detail"""
import torch
import triton
import triton.language as tl
@@ -74,79 +74,8 @@
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
-# Reference
-import reference as ref_mod
-ref_mod.T_total = T_total
-ref_mod.H = H
-ref_mod.I = I
-ref_mod.E = E
-ref_mod.K = K
-ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
-ref_model.load_state_dict(model.state_dict())
-ref_model.eval()
-with torch.no_grad():
- ref_out = ref_model(*inputs)
-
-# Triton GEMM with slice output
-gate_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
-up_out = torch.empty(T_perm, I, device='cuda', dtype=torch.bfloat16)
expert_offsets = inputs[1]
-
-for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e + 1].item())
- if end == start:
- continue
- N_e = end - start
-
- x_e = inputs[0][start:end]
- w_g = model.W_gate[e]
- w_u = model.W_up[e]
-
- grid_m = triton.cdiv(N_e, 64)
- grid_n = triton.cdiv(I, 64)
- grid = (grid_m * grid_n,)
-
- # Use slice for output
- triton_gemm_kernel[grid](
- x_e, w_g, gate_out[start:end],
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
- )
- triton_gemm_kernel[grid](
- x_e, w_u, up_out[start:end],
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
- )
-
-torch.cuda.synchronize()
-
-silu_gate = torch.nn.functional.silu(gate_out)
-triton_out = silu_gate * up_out
-
-# Find worst expert
-diffs = (ref_out.float() - triton_out.float()).abs()
-max_diff, max_idx = diffs.max(), diffs.argmax()
-max_row = max_idx // I
-max_col = max_idx % I
-print(f'Max diff at row {max_row}, col {max_col}: {max_diff.item()}')
-
-# Which expert owns that row?
-for e in range(E):
- start = int(expert_offsets[e].item())
- end = int(expert_offsets[e+1].item())
- if max_row >= start and max_row < end:
- print(f'Expert {e} owns row {max_row} (offset {start}:{end})')
- break
-
-# Now check the GEMM for this expert specifically
-e = 55
+e = 100
start = int(expert_offsets[e].item())
end = int(expert_offsets[e + 1].item())
N_e = end - start
@@ -155,19 +84,58 @@
w_g = model.W_gate[e]
w_u = model.W_up[e]
+print(f'Expert {e}: rows {start}:{end}, N_e={N_e}')
+
+# Reference
gate_ref = x_e.float() @ w_g.float()
up_ref = x_e.float() @ w_u.float()
-# Check if our output buffer for this expert matches reference
-print(f'\nExpert 55 GEMM check:')
-print(f'Gate out buffer rows 112640:114688:')
-print(f' Max diff: {(gate_out[start:end].float() - gate_ref).abs().max().item()}')
-print(f' Up out buffer rows 112640:114688:')
-print(f' Max diff: {(up_out[start:end].float() - up_ref).abs().max().item()}')
+# Triton output for this expert
+gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
+up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
-# But what about the overall output?
-print(f'\nOverall output for expert 55 region:')
-print(f' Gate ref max: {gate_ref.abs().max().item()}')
-print(f' Gate out max: {gate_out[start:end].abs().max().item()}')
-print(f' Up ref max: {up_ref.abs().max().item()}')
-print(f' Up out max: {up_out[start:end].abs().max().item()}')+grid_m = triton.cdiv(N_e, 64)
+grid_n = triton.cdiv(I, 64)
+grid = (grid_m * grid_n,)
+print(f'Grid: {grid}')
+
+triton_gemm_kernel[grid](
+ x_e, w_g, gate_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_g.stride(0), w_g.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
+
+triton_gemm_kernel[grid](
+ x_e, w_u, up_out,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_u.stride(0), w_u.stride(1),
+ I, 1,
+ 64, 64, 64,
+)
+torch.cuda.synchronize()
+
+# Find the worst row (row 84, col 30)
+row = 84
+col = 30
+print(f'\nAt local row {row}, col {col}:')
+print(f' Gate ref: {gate_ref[row, col].item()}')
+print(f' Gate out: {gate_out[row, col].item()}')
+print(f' Up ref: {up_ref[row, col].item()}')
+print(f' Up out: {up_out[row, col].item()}')
+
+# Check all columns for row 84
+row_diff = (gate_out[row].float() - gate_ref[row]).abs()
+print(f'\nRow 84 gate diff stats:')
+print(f' Max: {row_diff.max().item()}, Mean: {row_diff.mean().item()}')
+print(f' ArgMax col: {row_diff.argmax().item()}')
+
+# Check all rows for col 30
+col_diff = (gate_out[:, 30].float() - gate_ref[:, 30]).abs()
+print(f'\nCol 30 gate diff stats:')
+print(f' Max: {col_diff.max().item()}, Mean: {col_diff.mean().item()}')
+print(f' ArgMax row: {col_diff.argmax().item()}'){
"content": "<4091 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Debug expert 100",
"timeout": 300000
}Expert 100: rows 204800:206848, N_e=2048
Grid: (768,)
At local row 84, col 30:
Gate ref: 0.08788036555051804
Gate out: 0.087890625
Up ref: 0.004511021077632904
Up out: 0.0045166015625
Row 84 gate diff stats:
Max: 0.0009746551513671875, Mean: 0.00014581058348994702
ArgMax col: 1292
Col 30 gate diff stats:
Max: 0.000964730978012085, Mean: 0.00014076537627261132
ArgMax row: 660
{
"command": "<2162 chars \u2014 see below>",
"description": "Check zero values in output",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
import reference as ref_mod
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
for mod, t, h, ii, ee, kk in [
(solution, T_total, H, I, E, K),
(ref_mod, T_total, H, I, E, K),
]:
mod.T_total = t
mod.H = h
mod.I = ii
mod.E = ee
mod.K = kk
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
ref_model = ref_mod.Model(*ref_mod.get_init_inputs()).cuda()
ref_model.load_state_dict(model.state_dict())
ref_model.eval()
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = model(*inputs)
# Check row 204884, col 30
row, col = 204884, 30
print(f'At row {row}, col {col}:')
print(f' ref_out: {ref_out[row, col].item()}')
print(f' sol_out: {sol_out[row, col].item()}')
# Check the gate and up before SwiGLU
# But we don't store those... let me check the final result differently
# Find max diff across all data
diffs = (ref_out.float() - sol_out.float()).abs()
max_diff, max_idx = diffs.max().item(), diffs.argmax().item()
max_row = max_idx // I
max_col = max_idx % I
print(f'\nMax diff: {max_diff} at ({max_row}, {max_col})')
print(f'ref: {ref_out.flatten()[max_idx].item()}, sol: {sol_out.flatten()[max_idx].item()}')
# Check how many values are exactly 0 in sol_out
zero_count = (sol_out == 0).sum().item()
print(f'\nZero count in sol_out: {zero_count}')
print(f'Zero count in ref_out: {(ref_out == 0).sum().item()}')
# Check if all zeros are in the same expert region
if zero_count > 0:
zero_mask = (sol_out == 0)
zero_rows = zero_mask.any(dim=1).nonzero().squeeze()
if zero_rows.numel() == 1:
zero_row = zero_rows.item()
for e in range(E):
start = int(inputs[1][e].item())
end = int(inputs[1][e+1].item())
if zero_row >= start and zero_row < end:
print(f'All zeros in expert {e} (offset {start}:{end})')
break
else:
print(f'Multiple zero rows: {zero_rows[:10]}')
" 2>&1At row 204884, col 30:
ref_out: 0.0017242431640625
sol_out: 0.0
Max diff: 0.173828125 at (205433, 1531)
ref: -0.173828125, sol: 0.0
Zero count in sol_out: 402653184
Zero count in ref_out: 7
Multiple zero rows: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py
@@ -1,10 +1,9 @@
-"""Debug expert 100 in detail"""
+"""Test triton_grouped_gemm function directly"""
import torch
import triton
import triton.language as tl
import solution
-# Test with shape 0
T_total, H, I, E, K = 32768, 4096, 1536, 128, 8
T_perm = T_total * K
@@ -68,74 +67,50 @@
tl.store(c_ptrs, accumulator, mask=mask_m_out & mask_n_out)
+def triton_grouped_gemm(hidden_states, W, expert_offsets, E):
+ T_perm, H = hidden_states.shape
+ I = W.shape[2]
+ output = torch.empty(T_perm, I, device=hidden_states.device, dtype=hidden_states.dtype)
+
+ for e in range(E):
+ start = int(expert_offsets[e].item())
+ end = int(expert_offsets[e + 1].item())
+ if end == start:
+ continue
+ N_e = end - start
+
+ x_e = hidden_states[start:end]
+ w_e = W[e]
+ out_e_base = output + start * I
+
+ grid_m = triton.cdiv(N_e, 64)
+ grid_n = triton.cdiv(I, 64)
+ grid = (grid_m * grid_n,)
+
+ triton_gemm_kernel[grid](
+ x_e, w_e, out_e_base,
+ N_e, I, H,
+ x_e.stride(0), 1,
+ w_e.stride(0), w_e.stride(1),
+ I, 1,
+ 64, 64, 64,
+ )
+ return output
+
+
# Build model
model = solution.Model(*solution.get_init_inputs())
model = model.cuda()
torch.manual_seed(42)
inputs = [t.cuda() for t in solution.get_inputs()]
-expert_offsets = inputs[1]
-e = 100
-start = int(expert_offsets[e].item())
-end = int(expert_offsets[e + 1].item())
-N_e = end - start
-
-x_e = inputs[0][start:end]
-w_g = model.W_gate[e]
-w_u = model.W_up[e]
-
-print(f'Expert {e}: rows {start}:{end}, N_e={N_e}')
-
-# Reference
-gate_ref = x_e.float() @ w_g.float()
-up_ref = x_e.float() @ w_u.float()
-
-# Triton output for this expert
-gate_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
-up_out = torch.empty(N_e, I, device='cuda', dtype=torch.bfloat16)
-
-grid_m = triton.cdiv(N_e, 64)
-grid_n = triton.cdiv(I, 64)
-grid = (grid_m * grid_n,)
-print(f'Grid: {grid}')
-
-triton_gemm_kernel[grid](
- x_e, w_g, gate_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_g.stride(0), w_g.stride(1),
- I, 1,
- 64, 64, 64,
-)
+# Run triton_grouped_gemm directly
+gate_out = triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], E)
torch.cuda.synchronize()
-triton_gemm_kernel[grid](
- x_e, w_u, up_out,
- N_e, I, H,
- x_e.stride(0), 1,
- w_u.stride(0), w_u.stride(1),
- I, 1,
- 64, 64, 64,
-)
-torch.cuda.synchronize()
-
-# Find the worst row (row 84, col 30)
-row = 84
-col = 30
-print(f'\nAt local row {row}, col {col}:')
-print(f' Gate ref: {gate_ref[row, col].item()}')
-print(f' Gate out: {gate_out[row, col].item()}')
-print(f' Up ref: {up_ref[row, col].item()}')
-print(f' Up out: {up_out[row, col].item()}')
-
-# Check all columns for row 84
-row_diff = (gate_out[row].float() - gate_ref[row]).abs()
-print(f'\nRow 84 gate diff stats:')
-print(f' Max: {row_diff.max().item()}, Mean: {row_diff.mean().item()}')
-print(f' ArgMax col: {row_diff.argmax().item()}')
-
-# Check all rows for col 30
-col_diff = (gate_out[:, 30].float() - gate_ref[:, 30]).abs()
-print(f'\nCol 30 gate diff stats:')
-print(f' Max: {col_diff.max().item()}, Mean: {col_diff.mean().item()}')
-print(f' ArgMax row: {col_diff.argmax().item()}')+print(f'gate_out shape: {gate_out.shape}')
+print(f'gate_out non-zero count: {(gate_out != 0).sum().item()}')
+print(f'gate_out max: {gate_out.abs().max().item()}')
+print(f'gate_out sample at row 0, col 0: {gate_out[0, 0].item()}')
+print(f'gate_out sample at row 100, col 100: {gate_out[100, 100].item()}')
+print(f'gate_out sample at row 200000, col 100: {gate_out[200000, 100].item()}'){
"content": "<3694 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/test_triton.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python test_triton.py 2>&1",
"description": "Test triton_grouped_gemm",
"timeout": 300000
}gate_out shape: torch.Size([262144, 1536])
gate_out non-zero count: 0
gate_out max: 0.0
gate_out sample at row 0, col 0: 0.0
gate_out sample at row 100, col 100: 0.0
gate_out sample at row 200000, col 100: 0.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py
@@ -76,18 +76,18 @@
x_e = hidden_states[start:end]
w_e = W[e]
- out_e_base = output + start * I
+ out_e = output[start:end]
grid_m = triton.cdiv(N_e, 64)
grid_n = triton.cdiv(I, 64)
grid = (grid_m * grid_n,)
triton_gemm_kernel[grid](
- x_e, w_e, out_e_base,
+ x_e, w_e, out_e,
N_e, I, H,
x_e.stride(0), 1,
w_e.stride(0), w_e.stride(1),
- I, 1,
+ out_e.stride(0), 1,
64, 64, 64,
)
return output
{
"content": "<4518 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu/solution.py"
}Wrote file successfully.{
"command": "<579 chars \u2014 see below>",
"description": "Test slice approach",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/06_sonic_moe_swiglu && uv run python -c "
import torch
import solution
solution.T_total = 4096
solution.H = 2048
solution.I = 1024
solution.E = 64
solution.K = 4
model = solution.Model(*solution.get_init_inputs()).cuda()
inputs = [t.cuda() for t in solution.get_inputs()]
gate_out = solution.triton_grouped_gemm(inputs[0], model.W_gate, inputs[1], solution.E)
torch.cuda.synchronize()
print(f'gate_out non-zero: {(gate_out != 0).sum().item()} / {gate_out.numel()}')
print(f'gate_out max: {gate_out.abs().max().item()}')
" 2>&1gate_out non-zero: 16777216 / 16777216
gate_out max: 0.4921875