from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.jit
def _w4a16_gemv_splitk_kernel(
x_ptr,
w_q_ptr,
scales_ptr,
zeros_ptr,
out_ptr,
N,
K,
stride_xk,
stride_wk,
stride_wn,
stride_sg,
stride_sn,
stride_zg,
stride_zn,
GROUP_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
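    # Split-K GEMV for the M == 1 decode case: the 1-D grid covers
    # num_pid_n * num_k_blocks programs; each one dequantizes a (BLOCK_K, BLOCK_N)
    # weight tile, computes a partial dot product, and accumulates it into the
    # fp32 output via tl.atomic_add.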
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_k_blocks = tl.cdiv(K, BLOCK_K)
pid_n = pid // num_k_blocks
k_idx = pid % num_k_blocks
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
k_start = k_idx * BLOCK_K
HALF_K: tl.constexpr = BLOCK_K // 2
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(
x_ptr + k_even[None, :] * stride_xk,
mask=k_even[None, :] < K,
other=0.0,
)
x_odd = tl.load(
x_ptr + k_odd[None, :] * stride_xk,
mask=k_odd[None, :] < K,
other=0.0,
)
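    # w_q packs two int4 weights per byte along K (low nibble = even k,
    # high nibble = odd k), so K/2 packed rows cover one BLOCK_K slice.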
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(
w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
other=0,
eviction_policy="evict_last",
)
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = (w_packed >> 4).to(tl.float32)
g = k_start // GROUP_SIZE
s = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
mask=mask_n,
other=1.0,
).to(tl.float32)
z = tl.load(
zeros_ptr + g * stride_zg + offs_n * stride_zn,
mask=mask_n,
other=0.0,
).to(tl.float32)
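    # Asymmetric per-group dequant, w = (q - zero) * scale, cast to bf16 for tl.dot.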
w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
partial = tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
acc = partial.reshape(BLOCK_N)
tl.atomic_add(out_ptr + offs_n, acc, mask=mask_n)
def w4a16_gemv_splitk(x, w_q, scales, zeros):
M, K = x.shape
N = w_q.shape[1]
assert M == 1
BLOCK_N = 128
BLOCK_K = 128
acc = torch.zeros((N,), dtype=torch.float32, device=x.device)
num_pid_n = triton.cdiv(N, BLOCK_N)
num_k_blocks = triton.cdiv(K, BLOCK_K)
grid = (num_pid_n * num_k_blocks,)
_w4a16_gemv_splitk_kernel[grid](
x, w_q, scales, zeros, acc,
N, K,
x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
GROUP_SIZE=GROUP_SIZE,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return acc.to(torch.bfloat16).unsqueeze(0)
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@triton.jit
def _w4a16_gemm_kernel(
x_ptr,
w_q_ptr,
scales_ptr,
zeros_ptr,
out_ptr,
M,
N,
K,
stride_xm,
stride_xk,
stride_wk,
stride_wn,
stride_sg,
stride_sn,
stride_zg,
stride_zn,
stride_om,
stride_on,
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
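    # Prefill path (M > 1): standard tiled GEMM over (BLOCK_M, BLOCK_N) output tiles;
    # each K iteration fuses the int4 unpack and per-group dequant before tl.dot.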
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_m = offs_m < M
mask_n = offs_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
HALF_K: tl.constexpr = BLOCK_K // 2
for k_start in range(0, K, BLOCK_K):
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_even[None, :] * stride_xk,
mask=mask_m[:, None] & (k_even[None, :] < K),
other=0.0,
)
x_odd = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_odd[None, :] * stride_xk,
mask=mask_m[:, None] & (k_odd[None, :] < K),
other=0.0,
)
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(
w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
other=0,
eviction_policy="evict_last",
)
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = (w_packed >> 4).to(tl.float32)
g = k_start // GROUP_SIZE
s = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
mask=mask_n,
other=1.0,
).to(tl.float32)
z = tl.load(
zeros_ptr + g * stride_zg + offs_n * stride_zn,
mask=mask_n,
other=0.0,
).to(tl.float32)
w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
acc += tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
tl.store(
out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
acc.to(tl.bfloat16),
mask=mask_m[:, None] & mask_n[None, :],
)
def w4a16_gemm_triton(x, w_q, scales, zeros):
M, K = x.shape
N = w_q.shape[1]
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
)
_w4a16_gemm_kernel[grid](
x, w_q, scales, zeros, out,
M, N, K,
x.stride(0), x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
out.stride(0), out.stride(1),
GROUP_SIZE=GROUP_SIZE,
)
return out
def w4a16_gemm(x, w_q, scales, zeros):
M = x.shape[0]
if M == 1:
return w4a16_gemv_splitk(x, w_q, scales, zeros)
return w4a16_gemm_triton(x, w_q, scales, zeros)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
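        # Synthetic AWQ-style weights: per-group min/max -> scale = range/15 and an
        # integer zero-point, quantize to [0, 15], then pack two int4 values per byte.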
w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
w_g = w_full.view(n_groups, group_size, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
lo = w_q[0::2].to(torch.uint8) & 0xF
hi = w_q[1::2].to(torch.uint8) & 0xF
w_packed = (lo | (hi << 4)).contiguous()
self.register_buffer("w_q", w_packed)
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return w4a16_gemm(x, self.w_q, self.scales, self.zeros)
M = 1
N = 12288
K = 4096
def get_inputs():
x = torch.randn(M, K, dtype=torch.bfloat16)
return [x]
def get_init_inputs():
return [M, N, K]
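For context, a minimal unfused sketch of the same dequant scheme the kernels above fuse. The helper name w4a16_dequant_matmul_reference is illustrative, not part of the harness; it assumes the packing convention from Model.__init__ (low nibble = even-k row, high nibble = odd-k row) and an asymmetric zero-point per 128-row group:

import torch

def w4a16_dequant_matmul_reference(x, w_q_packed, scales, zeros, group_size=128):
    # Hypothetical helper, not in the repo: unpack (K//2, N) uint8 -> (K, N) values in [0, 15].
    K2, N = w_q_packed.shape
    K = K2 * 2
    w = torch.empty(K, N, dtype=torch.float32, device=w_q_packed.device)
    w[0::2] = (w_q_packed & 0xF).float()  # low nibble = even k rows
    w[1::2] = (w_q_packed >> 4).float()   # high nibble = odd k rows
    # Per-group asymmetric dequant: w[k, n] = (q - zeros[k // group, n]) * scales[k // group, n]
    g = torch.arange(K, device=w.device) // group_size
    w = (w - zeros.float()[g]) * scales.float()[g]
    return x @ w.to(x.dtype)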
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.966 ms=0.766
shape=0 variant=compiled tflops=0.663 gbps=176.202 ms=0.152
shape=0 variant=sota tflops=2.291 gbps=609.328 ms=0.044
shape=0 variant=solution tflops=2.874 gbps=764.375 ms=0.035
shape=0 solution_peak_fraction=0.4247
shape=1 variant=eager tflops=4.180 gbps=36.054 ms=0.771
shape=1 variant=compiled tflops=20.740 gbps=178.913 ms=0.155
shape=1 variant=sota tflops=23.200 gbps=200.127 ms=0.139
shape=1 variant=solution tflops=36.145 gbps=311.796 ms=0.089
shape=1 solution_peak_fraction=0.1732
shape=2 variant=eager tflops=31.354 gbps=42.740 ms=0.822
shape=2 variant=compiled tflops=121.017 gbps=164.960 ms=0.213
shape=2 variant=sota tflops=133.108 gbps=181.443 ms=0.194
shape=2 variant=solution tflops=80.826 gbps=110.175 ms=0.319
shape=2 solution_peak_fraction=0.0612
shape=3 variant=eager tflops=0.182 gbps=48.457 ms=0.184
shape=3 variant=compiled tflops=0.513 gbps=136.450 ms=0.065
shape=3 variant=sota tflops=1.295 gbps=344.707 ms=0.026
shape=3 variant=solution tflops=1.323 gbps=352.101 ms=0.025
shape=3 solution_peak_fraction=0.1956
shape=4 variant=eager tflops=2.060 gbps=34.853 ms=0.912
shape=4 variant=compiled tflops=10.382 gbps=175.615 ms=0.181
shape=4 variant=sota tflops=11.344 gbps=191.883 ms=0.166
shape=4 variant=solution tflops=22.800 gbps=385.665 ms=0.082
shape=4 solution_peak_fraction=0.2143
peak_fraction: 0.1800
RESULT: OK
PASS
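Sanity check on the shape=0 line above (a back-of-envelope sketch: the byte counts come from problem.yaml's bytes_formula, the 1800 GB/s DRAM peak from the RTX_PRO_6000 hardware table queried later in the session):

# Roofline arithmetic for shape 0 (M=1, N=12288, K=4096).
M, N, K = 1, 12288, 4096
bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2  # ~26.77e6 bytes
gbps = bytes_moved / 0.035e-3 / 1e9       # ~765 GB/s at the measured 0.035 ms
print(gbps, gbps * 1e9 / 1.8e12)          # ~765 GB/s, ~0.42 of the 1800 GB/s DRAM peak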
{
"problem": "07_w4a16_gemm",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1800,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 130478, "output_tokens": 26285, "cache_read_tokens": 3698496, "cache_creation_tokens": 0, "reasoning_tokens": 75691, "total_cost_usd": null}
}
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2:
3: AWQ/GPTQ-style scheme:
4: x: (M, K) bf16
5: w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6: scales: (K // group, N) bf16
7: zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8: out: (M, N) bf16
9:
10: Dequant (per group along K):
11: w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12: where w_q[k, n] is the unpacked 4-bit value (0..15).
13:
14: This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15: memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16: """
17: from __future__ import annotations
18:
19: import torch
20: import torch.nn as nn
21:
22: OP_TYPE = "gemm_w4a16"
23: SUPPORTED_PRECISIONS = ["int4_bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26: GROUP_SIZE = 128
27:
28:
29: def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30: """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31:
32: Even rows go in the low nibble, odd rows in the high nibble.
33: """
34: K, N = w_q.shape
35: assert K % 2 == 0
36: lo = w_q[0::2].to(torch.uint8) & 0xF
37: hi = w_q[1::2].to(torch.uint8) & 0xF
38: ret
... (3312 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape
... (2909 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formul
... (3415 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for W4A16 GEMM.
2:
3: Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4: case every inference engine optimizes -- it's the bar to beat.
5: """
6:
7: SHAPES = [
8: {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9: {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10: {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11: {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12: {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13: ]
(End of file - total 13 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 07_w4a16_gemm
2: display_name: "W4A16 Weight-only Quantized GEMM"
3: precision: int4_bf16
4: regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5:
6: # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7: flops_formula: "2 * M * N * K"
8:
9: # Bytes moved per call (memory roofline):
10: # x: M*K*2 (bf16 activations, streamed in once)
11: # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12: # scales: (K/128)*N*2 (bf16 scales)
13: # zeros: (K/128)*N*2 (bf16 zero-points)
14: # out: M*N*2 (bf16 store)
15: bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23:
24: # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25: # vendor library that does both.
26: forbidden:
27: - "bitsandbytes.functional.dequantize_4bit"
28: - "bitsandbytes.functional.gemv_4bit"
29: - "marlin_kernel.gemm"
30: - "torch.nn.functional.linear"
31:
32: sota:
33: name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34: url: "https://github.com/TimDettmers/bitsandbytes"
35: function: "bitsandbytes.functional.gemv_4bit"
36: notes: |
37:
... (757 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for W4A16 GEMM.
2:
3: Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4:
5: - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6: - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7: but is not faster than Marlin on its target HW
8: and has no Blackwell tuning. Skip as primary.
9: - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10: Skip.
11: - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12: on this machine). Different quant scheme (NF4,
13: symmetric, blocksize 64) than our reference's
14: AWQ-style asymmetric INT4 with group_size 128,
15: but it occupies the same memory regime and is
16: the only tuned W4A16-class kernel that runs on
17: SM120 today. Used here as an *informational*
18: SOTA line, not as a numerical reference.
19:
20: The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21: NOT checked against this path (the quant scheme differs).
22: """
23: from
... (2725 more chars){
"todos": [
{
"content": "Explore environment: check Triton version, CUDA toolkit, existing solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement Triton fused dequant+GEMM kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py to verify correctness across all shapes",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py and profile to identify bottlenecks",
"status": "pending",
"priority": "medium"
... (132 more chars)content: 'Explore environment: check Triton version, CUDA toolkit, existing solution.py'
status: 'in_progress'
priority: 'high'
content: 'Implement Triton fused dequant+GEMM kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py and profile to identify bottlenecks'
status: 'pending'
priority: 'medium'
content: 'Optimize kernel performance (iterate)'
status: 'pending'
priority: 'medium'{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py 2>/dev/null && echo \"exists\" || echo \"no solution.py\"",
"description": "Check existing solution.py"
}no solution.py
{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null || echo \"nvidia-smi not available\"",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability())\"",
"description": "Check torch version and compute cap"
}2.11.0+cu130
(12, 0)
{
"todos": [
{
"content": "Explore environment: check Triton version, CUDA toolkit, existing solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Implement Triton fused dequant+GEMM kernel in solution.py",
"priority": "high",
"status": "in_progress"
},
{
"content": "Run check.py to verify correctness across all shapes",
"priority": "high",
"status": "pending"
},
{
"content": "Run benchmark.py and profile to identify bottlenecks",
"priority": "medium",
"status": "pending"
... (134 more chars)content: 'Explore environment: check Triton version, CUDA toolkit, existing solution.py'
status: 'completed'
priority: 'high'
content: 'Implement Triton fused dequant+GEMM kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness across all shapes'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py and profile to identify bottlenecks'
status: 'pending'
priority: 'medium'
content: 'Optimize kernel performance (iterate)'
status: 'pending'
priority: 'medium'{
"command": "<452 chars \u2014 see below>",
"description": "Quick Triton+SM120 smoke test"
}uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(X, Y, N: tl.constexpr):
pid = tl.program_id(0)
offs = pid * N + tl.arange(0, N)
x = tl.load(X + offs)
tl.store(Y + offs, x + 1)
x = torch.randn(256, device='cuda', dtype=torch.bfloat16)
y = torch.empty_like(x)
test_kernel[(1,)](x, y, N=256)
print('Triton basic test passed')
print(f'Result matches: {(x + 1 == y).all().item()}')
"
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars){
"command": "<503 chars \u2014 see below>",
"description": "Triton smoke test from file"
}cat > /tmp/triton_test.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_kernel(X, Y, N: tl.constexpr):
pid = tl.program_id(0)
offs = pid * N + tl.arange(0, N)
x = tl.load(X + offs)
tl.store(Y + offs, x + 1)
x = torch.randn(256, device='cuda', dtype=torch.bfloat16)
y = torch.empty_like(x)
test_kernel[(1,)](x, y, N=256)
print('Triton basic test passed')
print(f'Result matches: {(x + 1 == y).all().item()}')
EOF
uv run python /tmp/triton_test.pyTriton basic test passed
Result matches: True
{
"command": "<2122 chars \u2014 see below>",
"description": "Test tl.dot with BLOCK_M=1"
}cat > /tmp/triton_dot_test.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def dot_test_kernel(X, W, Y, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
pid_m = pid // tl.cdiv(N, BLOCK_N)
pid_n = pid % tl.cdiv(N, BLOCK_N)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
k_offs = k + tl.arange(0, BLOCK_K)
x_tile = tl.load(X + offs_m[:, None] * K + k_offs[None, :], mask=(offs_m[:, None] < M) & (k_offs[None, :] < K), other=0.0)
w_tile = tl.load(W + k_offs[:, None] * N + offs_n[None, :], mask=(k_offs[:, None] < K) & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(x_tile, w_tile)
tl.store(Y + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
M, N, K = 1, 128, 128
x = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)
w = torch.randn(K, N, device='cuda', dtype=torch.bfloat16)
y = torch.empty(M, N, device='cuda', dtype=torch.bfloat16)
grid = (triton.cdiv(M, 16) * triton.cdiv(N, 64),)
dot_test_kernel[grid](x, w, y, M, N, K, BLOCK_M=16, BLOCK_N=64, BLOCK_K=32)
ref = (x @ w)
print(f"tl.dot with M=1 (BLOCK_M=16 padded): max_diff = {(y - ref).abs().max().item():.6f}")
print(f"Match: {torch.allclose(y, ref, atol=0.01)}")
# Test with BLOCK_M=1
M2, N2, K2 = 1, 128, 64
x2 = torch.randn(M2, K2, device='cuda', dtype=torch.bfloat16)
w2 = torch.randn(K2, N2, device='cuda', dtype=torch.bfloat16)
y2 = torch.empty(M2, N2, device='cuda', dtype=torch.bfloat16)
try:
grid2 = (triton.cdiv(M2, 1) * triton.cdiv(N2, 64),)
dot_test_kernel[grid2](x2, w2, y2, M2, N2, K2, BLOCK_M=1, BLOCK_N=64, BLOCK_K=32)
ref2 = (x2 @ w2)
print(f"tl.dot with BLOCK_M=1: max_diff = {(y2 - ref2).abs().max().item():.6f}")
print(f"Match: {torch.allclose(y2, ref2, atol=0.01)}")
except Exception as e:
print(f"BLOCK_M=1 failed: {e}")
EOF
uv run python /tmp/triton_dot_test.pytl.dot with M=1 (BLOCK_M=16 padded): max_diff = 0.000000
Match: True
tl.dot with BLOCK_M=1: max_diff = 0.000000
Match: True
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr,
+ w_q_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wk,
+ stride_wn,
+ stride_sg,
+ stride_sn,
+ stride_zg,
+ stride_zn,
+ stride_om,
+ stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ pid_m = pid // num_pid_n
+ pid_n = pid % num_pid_n
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_m = offs_m < M
+ mask_n = offs_n < N
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ HALF_K: tl.constexpr = BLOCK_K // 2
+
+ for k_start in range(0, K, BLOCK_K):
+ k_even = k_start + tl.arange(0, HALF_K) * 2
+ k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
+
+ x_even = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_even[None, :] * stride_xk,
+ mask=mask_m[:, None] & (k_even[None, :] < K),
+ other=0.0,
+ )
+ x_odd = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_odd[None, :] * stride_xk,
+ mask=mask_m[:, None] & (k_odd[None, :] < K),
+ other=0.0,
+ )
+
+ packed_k = k_start // 2 + tl.arange(0, HALF_K)
+ w_packed = tl.load(
+ w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
+ other=0,
+ )
+
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = (w_packed >> 4).to(tl.float32)
+
+ g = k_start // GROUP_SIZE
+ s = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn,
+ mask=mask_n,
+ other=1.0,
+ ).to(tl.float32)
+ z = tl.load(
+ zeros_ptr + g * stride_zg + offs_n * stride_zn,
+ mask=mask_n,
+ other=0.0,
+ ).to(tl.float32)
+
+ w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
+ w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
+
+ acc += tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
+
+ tl.store(
+ out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
+ acc.to(tl.bfloat16),
+ mask=mask_m[:, None] & mask_n[None, :],
+ )
+
+
+def w4a16_gemm(x, w_q, scales, zeros):
+ M, K = x.shape
+ K2, N = w_q.shape
+ assert K2 * 2 == K
+
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+
+ grid = lambda META: (
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ )
+
+ _w4a16_gemm_kernel[grid](
+ x,
+ w_q,
+ scales,
+ zeros,
+ out,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ w_q.stride(0),
+ w_q.stride(1),
+ scales.stride(0),
+ scales.stride(1),
+ zeros.stride(0),
+ zeros.stride(1),
+ out.stride(0),
+ out.stride(1),
+ GROUP_SIZE=GROUP_SIZE,
+ )
+
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+ w_g = w_full.view(n_groups, group_size, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
+ w_q = w_q.view(K, N)
+ scales_2d = scales.squeeze(1).to(torch.bfloat16)
+ zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales_2d)
+ self.register_buffer("zeros", zeros_2d)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return w4a16_gemm(x, self.w_q, self.scales, self.zeros)
+
+
+M = 1
+N = 12288
+K = 4096
+
+
+def get_inputs():
+ x = torch.randn(M, K, dtype=torch.bfloat16)
+ return [x]
+
+
+def get_
... (37 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<5734 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm",
"timeout": 180000
}PASS
{
"command": "uv run python benchmark.py",
"description": "Run performance benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.132 gbps=35.002 ms=0.765
shape=0 variant=compiled tflops=0.660 gbps=175.537 ms=0.153
shape=0 variant=sota tflops=2.328 gbps=619.022 ms=0.043
shape=0 variant=solution tflops=1.466 gbps=389.754 ms=0.069
shape=0 solution_peak_fraction=0.2165
shape=1 variant=eager tflops=4.192 gbps=36.159 ms=0.768
shape=1 variant=compiled tflops=20.721 gbps=178.747 ms=0.155
shape=1 variant=sota tflops=23.221 gbps=200.312 ms=0.139
shape=1 variant=solution tflops=36.347 gbps=313.541 ms=0.089
shape=1 solution_peak_fraction=0.1742
shape=2 variant=eager tflops=31.394 gbps=42.793 ms=0.821
shape=2 variant=compiled tflops=121.409 gbps=165.495 ms=0.212
shape=2 variant=sota tflops=133.175 gbps=181.533 ms=0.194
shape=2 variant=solution tflops=81.037 gbps=110.463 ms=0.318
shape=2 solution_peak_fraction=0.0614
shape=3 variant=eager tflops=0.183 gbps=48.571 ms=0.184
shape=3 variant=compiled tflops=0.505 gbps=134.477 ms=0
... (465 more chars){
"command": "<257 chars \u2014 see below>",
"description": "Check hardware peak specs",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print('Peak bf16 TFLOPS:', hw.peak_tflops_dense.get('bf16', 'N/A'))
print('Peak bandwidth GB/s:', hw.peak_bandwidth_gb_s)
print('All tflops keys:', hw.peak_tflops_dense)
"Peak bf16 TFLOPS: 200.0
Peak bandwidth GB/s: 1800.0
All tflops keys: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -10,16 +10,28 @@
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 512, "BLOCK_K": 64}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@@ -64,25 +76,21 @@
HALF_K: tl.constexpr = BLOCK_K // 2
for k_start in range(0, K, BLOCK_K):
- k_even = k_start + tl.arange(0, HALF_K) * 2
- k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
-
- x_even = tl.load(
- x_ptr + offs_m[:, None] * stride_xm + k_even[None, :] * stride_xk,
- mask=mask_m[:, None] & (k_even[None, :] < K),
+ k_offs = k_start + tl.arange(0, BLOCK_K)
+ x_block = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
+ mask=mask_m[:, None] & (k_offs[None, :] < K),
other=0.0,
)
- x_odd = tl.load(
- x_ptr + offs_m[:, None] * stride_xm + k_odd[None, :] * stride_xk,
- mask=mask_m[:, None] & (k_odd[None, :] < K),
- other=0.0,
- )
+ x_even = x_block[:, 0::2]
+ x_odd = x_block[:, 1::2]
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(
w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
other=0,
+ eviction_policy="evict_last",
)
w_lo = (w_packed & 0xF).to(tl.float32)
{
"content": "<6768 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py",
"description": "Verify correctness after optimization",
"timeout": 180000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: <triton.language.core.slice object at 0x7bf185036ad0>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 186, in forward
return w4a16_gemm(x, self.w_q, self.scales, self.zeros)
... (3567 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -76,14 +76,19 @@
HALF_K: tl.constexpr = BLOCK_K // 2
for k_start in range(0, K, BLOCK_K):
- k_offs = k_start + tl.arange(0, BLOCK_K)
- x_block = tl.load(
- x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
- mask=mask_m[:, None] & (k_offs[None, :] < K),
+ k_even = k_start + tl.arange(0, HALF_K) * 2
+ k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
+
+ x_even = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_even[None, :] * stride_xk,
+ mask=mask_m[:, None] & (k_even[None, :] < K),
other=0.0,
)
- x_even = x_block[:, 0::2]
- x_odd = x_block[:, 1::2]
+ x_odd = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_odd[None, :] * stride_xk,
+ mask=mask_m[:, None] & (k_odd[None, :] < K),
+ other=0.0,
+ )
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<353 chars \u2014 see below>",
"newString": "<537 chars \u2014 see below>"
} for k_start in range(0, K, BLOCK_K):
k_offs = k_start + tl.arange(0, BLOCK_K)
x_block = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=mask_m[:, None] & (k_offs[None, :] < K),
other=0.0,
)
x_even = x_block[:, 0::2]
x_odd = x_block[:, 1::2] for k_start in range(0, K, BLOCK_K):
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_even[None, :] * stride_xk,
mask=mask_m[:, None] & (k_even[None, :] < K),
other=0.0,
)
x_odd = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_odd[None, :] * stride_xk,
mask=mask_m[:, None] & (k_odd[None, :] < K),
other=0.0,
)Edit applied successfully.{
"command": "uv run python check.py",
"description": "Verify correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark with more autotune configs",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.132 gbps=34.978 ms=0.765
shape=0 variant=compiled tflops=0.663 gbps=176.314 ms=0.152
shape=0 variant=sota tflops=2.311 gbps=614.701 ms=0.044
shape=0 variant=solution tflops=1.505 gbps=400.291 ms=0.067
shape=0 solution_peak_fraction=0.2224
shape=1 variant=eager tflops=4.195 gbps=36.184 ms=0.768
shape=1 variant=compiled tflops=20.636 gbps=178.014 ms=0.156
shape=1 variant=sota tflops=23.221 gbps=200.312 ms=0.139
shape=1 variant=solution tflops=35.900 gbps=309.683 ms=0.090
shape=1 solution_peak_fraction=0.1720
shape=2 variant=eager tflops=31.304 gbps=42.672 ms=0.823
shape=2 variant=compiled tflops=120.016 gbps=163.596 ms=0.215
shape=2 variant=sota tflops=133.175 gbps=181.533 ms=0.194
shape=2 variant=solution tflops=80.672 gbps=109.965 ms=0.319
shape=2 solution_peak_fraction=0.0611
shape=3 variant=eager tflops=0.182 gbps=48.482 ms=0.184
shape=3 variant=compiled tflops=0.511 gbps=136.018 ms=0
... (465 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: from __future__ import annotations
2:
3: import torch
4: import torch.nn as nn
5: import triton
6: import triton.language as tl
7:
8: GROUP_SIZE = 128
9:
10:
11: @triton.autotune(
12: configs=[
13: triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=5, num_warps=2),
14: triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=5, num_warps=2),
15: triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
16: triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=4),
17: triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=2),
18: triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
19: triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4),
20: triton.Config({"BLOCK_M": 1, "BLOCK_N": 512, "BLOCK_K": 64}, num_stages=3, num_warps=8),
21: triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=3, num_warps=4),
22: triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
23: triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
24: triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK
... (6526 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -7,19 +7,123 @@
GROUP_SIZE = 128
+CUDA_SRC = r"""
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+template<int BLOCK_N, int GROUP_SIZE, int NTHREADS>
+__global__ void __launch_bounds__(NTHREADS)
+w4a16_gemv_kernel(
+ const __nv_bfloat16* __restrict__ x,
+ const uint8_t* __restrict__ w_q,
+ const __nv_bfloat16* __restrict__ scales,
+ const __nv_bfloat16* __restrict__ zeros,
+ __nv_bfloat16* __restrict__ out,
+ int M, int N, int K
+) {
+ const int m = blockIdx.y;
+ const int n_base = blockIdx.x * BLOCK_N;
+ const int tid = threadIdx.x;
+
+ constexpr int ELEMS = BLOCK_N / NTHREADS;
+
+ float acc[ELEMS];
+ #pragma unroll
+ for (int j = 0; j < ELEMS; j++) acc[j] = 0.0f;
+
+ const int n_groups = K / GROUP_SIZE;
+
+ for (int g = 0; g < n_groups; g++) {
+ float s[ELEMS], z[ELEMS];
+ #pragma unroll
+ for (int j = 0; j < ELEMS; j++) {
+ int n = n_base + tid + j * NTHREADS;
+ if (n < N) {
+ s[j] = __bfloat162float(scales[g * N + n]);
+ z[j] = __bfloat162float(zeros[g * N + n]);
+ } else {
+ s[j] = 1.0f;
+ z[j] = 0.0f;
+ }
+ }
+
+ for (int k = g * GROUP_SIZE; k < (g + 1) * GROUP_SIZE; k += 2) {
+ float x_even = __bfloat162float(x[m * K + k]);
+ float x_odd = __bfloat162float(x[m * K + k + 1]);
+ int packed_k = k / 2;
+
+ #pragma unroll
+ for (int j = 0; j < ELEMS; j++) {
+ int n = n_base + tid + j * NTHREADS;
+ if (n < N) {
+ uint8_t packed = w_q[packed_k * N + n];
+ float w_lo = (float)(packed & 0xF);
+ float w_hi = (float)(packed >> 4);
+ acc[j] += x_even * (w_lo - z[j]) * s[j]
+ + x_odd * (w_hi - z[j]) * s[j];
+ }
+ }
+ }
+ }
+
+ #pragma unroll
+ for (int j = 0; j < ELEMS; j++) {
+ int n = n_base + tid + j * NTHREADS;
+ if (n < N) {
+ out[m * N + n] = __float2bfloat16(acc[j]);
+ }
+ }
+}
+
+// Dispatch wrapper
+void w4a16_gemv_cuda(
+ const __nv_bfloat16* x, const uint8_t* w_q,
+ const __nv_bfloat16* scales, const __nv_bfloat16* zeros,
+ __nv_bfloat16* out,
+ int M, int N, int K
+) {
+ const int BLOCK_N = 256;
+ const int NTHREADS = 256;
+ dim3 grid((N + BLOCK_N - 1) / BLOCK_N, M);
+ w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+ x, w_q, scales, zeros, out, M, N, K);
+}
+"""
+
+_cuda_module = None
+
+def _get_cuda_module():
+ global _cuda_module
+ if _cuda_module is not None:
+ return _cuda_module
+ from torch.utils.cpp_extension import load_inline
+ _cuda_module = load_inline(
+ name="w4a16_cuda",
+ cpp_sources="",
+ cuda_sources=CUDA_SRC,
+ functions=["w4a16_gemv_cuda"],
+ extra_cuda_cflags=["-O3", "-arch=sm_120"],
+ verbose=False,
+ )
+ return _cuda_module
+
+
+def w4a16_gemv_cuda(x, w_q, scales, zeros):
+ M, K = x.shape
+ N = w_q.shape[1]
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ mod = _get_cuda_module()
+ mod.w4a16_gemv_cuda(
+ x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(),
+ out.data_ptr(), M, N, K,
+ )
+ return out
+
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=5, num_warps=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=5, num_warps=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 512, "BLOCK_K": 64}, num_stages=3, num_warps=8),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
@@ -125,10 +229,9 @@
)
-def w4a16_gemm(x, w_q, scales, zeros):
+def w4a16_gemm_triton(x, w_q, scales, zeros):
M, K = x.shape
- K2, N = w_q.shape
- assert K2 * 2 == K
+ N = w_q.shape[1]
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
@@ -137,28 +240,24 @@
)
_w4a16_gemm_kernel[grid](
- x,
- w_q,
- scales,
- zeros,
- out,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- w_q.stride(0),
- w_q.stride(1),
- scales.stride(0),
- scales.stride(1),
- zeros.stride(0),
- zeros.stride(1),
- out.stride(0),
- out.stride(1),
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0),
... (289 more chars){
"content": "<9481 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA+Triton hybrid solution",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 290, in forward
return w4a16_gemm(x, self.w_q, self.scales, self.zeros)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
... (4970 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -11,6 +11,8 @@
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
template<int BLOCK_N, int GROUP_SIZE, int NTHREADS>
__global__ void __launch_bounds__(NTHREADS)
@@ -76,18 +78,21 @@
}
}
-// Dispatch wrapper
-void w4a16_gemv_cuda(
- const __nv_bfloat16* x, const uint8_t* w_q,
- const __nv_bfloat16* scales, const __nv_bfloat16* zeros,
- __nv_bfloat16* out,
- int M, int N, int K
+void w4a16_gemv_cuda_launch(
+ at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
+ at::Tensor out, int64_t M, int64_t N, int64_t K
) {
const int BLOCK_N = 256;
const int NTHREADS = 256;
dim3 grid((N + BLOCK_N - 1) / BLOCK_N, M);
- w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
- x, w_q, scales, zeros, out, M, N, K);
+ auto stream = c10::cuda::getCurrentCUDAStream();
+ w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, stream>>>(
+ reinterpret_cast<const __nv_bfloat16*>(x.data_ptr<at::BFloat16>()),
+ w_q.data_ptr<uint8_t>(),
+ reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(zeros.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(out.data_ptr<at::BFloat16>()),
+ M, N, K);
}
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<2548 chars \u2014 see below>",
"newString": "<2904 chars \u2014 see below>"
}CUDA_SRC = r"""
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
template<int BLOCK_N, int GROUP_SIZE, int NTHREADS>
__global__ void __launch_bounds__(NTHREADS)
w4a16_gemv_kernel(
const __nv_bfloat16* __restrict__ x,
const uint8_t* __restrict__ w_q,
const __nv_bfloat16* __restrict__ scales,
const __nv_bfloat16* __restrict__ zeros,
__nv_bfloat16* __restrict__ out,
int M, int N, int K
) {
const int m = blockIdx.y;
const int n_base = blockIdx.x * BLOCK_N;
const int tid = threadIdx.x;
constexpr int ELEMS = BLOCK_N / NTHREADS;
float acc[ELEMS];
#pragma unroll
for (int j = 0; j < ELEMS; j++) acc[j] = 0.0f;
const int n_groups = K / GROUP_SIZE;
for (int g = 0; g < n_groups; g++) {
float s[ELEMS], z[ELEMS];
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
s[j] = __bfloat162float(scales[g * N + n]);
z[j] = __bfloat162float(zeros[g * N + n]);
} else {
s[j] = 1.0f;
z[j] = 0.0f;
}
}
for (int k = g * GROUP_SIZE; k < (g + 1) * GROUP_SIZE; k += 2) {
float x_even = __bfloat162float(x[m * K + k]);
float x_odd = __bfloat162float(x[m * K + k + 1]);
int packed_k = k / 2;
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
uint8_t packed = w_q[packed_k * N + n];
float w_lo = (float)(packed & 0xF);
float w_hi = (float)(packed >> 4);
acc[j] += x_even * (w_lo - z[j]) * s[j]
+ x_odd * (w_hi - z[j]) * s[j];
}
}
}
}
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
out[m * N + n] = __float2bfloat16(acc[j]);
}
}
}
// Dispatch wrapper
void w4a16_gemv_cuda(
const __nv_bfloat16* x, const uint8_t* w_q,
const __nv_bfloat16* scales, const __nv_bfloat16* zeros,
__nv_bfloat16* out,
int M, int N, int K
) {
const int BLOCK_N = 256;
const int NTHREADS = 256;
dim3 grid((N + BLOCK_N - 1) / BLOCK_N, M);
w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
x, w_q, scales, zeros, out, M, N, K);
}
"""CUDA_SRC = r"""
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
template<int BLOCK_N, int GROUP_SIZE, int NTHREADS>
__global__ void __launch_bounds__(NTHREADS)
w4a16_gemv_kernel(
const __nv_bfloat16* __restrict__ x,
const uint8_t* __restrict__ w_q,
const __nv_bfloat16* __restrict__ scales,
const __nv_bfloat16* __restrict__ zeros,
__nv_bfloat16* __restrict__ out,
int M, int N, int K
) {
const int m = blockIdx.y;
const int n_base = blockIdx.x * BLOCK_N;
const int tid = threadIdx.x;
constexpr int ELEMS = BLOCK_N / NTHREADS;
float acc[ELEMS];
#pragma unroll
for (int j = 0; j < ELEMS; j++) acc[j] = 0.0f;
const int n_groups = K / GROUP_SIZE;
for (int g = 0; g < n_groups; g++) {
float s[ELEMS], z[ELEMS];
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
s[j] = __bfloat162float(scales[g * N + n]);
z[j] = __bfloat162float(zeros[g * N + n]);
} else {
s[j] = 1.0f;
z[j] = 0.0f;
}
}
for (int k = g * GROUP_SIZE; k < (g + 1) * GROUP_SIZE; k += 2) {
float x_even = __bfloat162float(x[m * K + k]);
float x_odd = __bfloat162float(x[m * K + k + 1]);
int packed_k = k / 2;
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
uint8_t packed = w_q[packed_k * N + n];
float w_lo = (float)(packed & 0xF);
float w_hi = (float)(packed >> 4);
acc[j] += x_even * (w_lo - z[j]) * s[j]
+ x_odd * (w_hi - z[j]) * s[j];
}
}
}
}
#pragma unroll
for (int j = 0; j < ELEMS; j++) {
int n = n_base + tid + j * NTHREADS;
if (n < N) {
out[m * N + n] = __float2bfloat16(acc[j]);
}
}
}
void w4a16_gemv_cuda_launch(
at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
at::Tensor out, int64_t M, int64_t N, int64_t K
) {
const int BLOCK_N = 256;
const int NTHREADS = 256;
dim3 grid((N + BLOCK_N - 1) / BLOCK_N, M);
auto stream = c10::cuda::getCurrentCUDAStream();
w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(x.data_ptr<at::BFloat16>()),
w_q.data_ptr<uint8_t>(),
reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(zeros.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(out.data_ptr<at::BFloat16>()),
M, N, K);
}
"""Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -107,7 +107,7 @@
name="w4a16_cuda",
cpp_sources="",
cuda_sources=CUDA_SRC,
- functions=["w4a16_gemv_cuda"],
+ functions=["w4a16_gemv_cuda_launch"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
verbose=False,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<293 chars \u2014 see below>",
"oldString": "<286 chars \u2014 see below>"
}
from torch.utils.cpp_extension import load_inline
_cuda_module = load_inline(
name="w4a16_cuda",
cpp_sources="",
cuda_sources=CUDA_SRC,
functions=["w4a16_gemv_cuda_launch"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
verbose=False,
)
from torch.utils.cpp_extension import load_inline
_cuda_module = load_inline(
name="w4a16_cuda",
cpp_sources="",
cuda_sources=CUDA_SRC,
functions=["w4a16_gemv_cuda"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
verbose=False,
)
Edit applied successfully.
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -119,10 +119,7 @@
N = w_q.shape[1]
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
mod = _get_cuda_module()
- mod.w4a16_gemv_cuda(
- x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(),
- out.data_ptr(), M, N, K,
- )
+ mod.w4a16_gemv_cuda_launch(x, w_q, scales, zeros, out, M, N, K)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<264 chars \u2014 see below>",
"oldString": "<335 chars \u2014 see below>"
}
def w4a16_gemv_cuda(x, w_q, scales, zeros):
M, K = x.shape
N = w_q.shape[1]
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
mod = _get_cuda_module()
mod.w4a16_gemv_cuda_launch(x, w_q, scales, zeros, out, M, N, K)
    return out
def w4a16_gemv_cuda(x, w_q, scales, zeros):
M, K = x.shape
N = w_q.shape[1]
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
mod = _get_cuda_module()
mod.w4a16_gemv_cuda(
x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(),
out.data_ptr(), M, N, K,
)
    return out
Edit applied successfully.
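With the launcher now taking at::Tensor arguments, the generated binding checks the dtype when data_ptr<at::BFloat16>() is called, but nothing verifies device, contiguity, or the dense row-major layout that the kernel's manual indexing (x[m*K+k], w_q[packed_k*N+n]) assumes. A minimal guard one might place in front of the call; the helper name and the exact checks are assumptions, not part of the session:

import torch

def _check_w4a16_inputs(x, w_q, scales, zeros):
    # All operands must live on the GPU and be contiguous, since the CUDA
    # kernel treats them as dense row-major buffers.
    for t in (x, w_q, scales, zeros):
        assert t.is_cuda and t.is_contiguous()
    # data_ptr<at::BFloat16>() in the launcher throws on a dtype mismatch,
    # but failing here gives a clearer Python-side error message.
    assert x.dtype == scales.dtype == zeros.dtype == torch.bfloat16
    assert w_q.dtype == torch.uint8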
"command": "rm -rf ~/.cache/torch_extensions/py311_cu130/w4a16_cuda 2>/dev/null; uv run python check.py 2>&1",
"description": "Test with fixed CUDA kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
{
"command": "rm -rf ~/.cache/torch_extensions/py311_cu130/w4a16_cuda 2>/dev/null && uv run python check.py 2>&1",
"description": "Test with fixed CUDA kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
{
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda 2>/dev/null; uv run python check.py 2>&1",
"description": "Test with fixed CUDA kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
{
"command": "<256 chars \u2014 see below>",
"description": "Clean CUDA extension cache"
}
uv run python -c "
import subprocess, shutil, os
cache_dir = os.path.expanduser('~/.cache/torch_extensions/py311_cu130/w4a16_cuda')
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
print('Cleaned cache')
else:
print('No cache to clean')
"Cleaned cache
{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA+Triton hybrid solution",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 292, in forward
return w4a16_gemm(x, self.w_q, self.scales, self.zeros)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
... (3778 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -105,7 +105,7 @@
from torch.utils.cpp_extension import load_inline
_cuda_module = load_inline(
name="w4a16_cuda",
- cpp_sources="",
+ cpp_sources='void w4a16_gemv_cuda_launch(at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros, at::Tensor out, int64_t M, int64_t N, int64_t K);',
cuda_sources=CUDA_SRC,
functions=["w4a16_gemv_cuda_launch"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<383 chars \u2014 see below>",
"oldString": "<239 chars \u2014 see below>"
}
_cuda_module = load_inline(
name="w4a16_cuda",
cpp_sources='void w4a16_gemv_cuda_launch(at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros, at::Tensor out, int64_t M, int64_t N, int64_t K);',
cuda_sources=CUDA_SRC,
functions=["w4a16_gemv_cuda_launch"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
verbose=False,
)
_cuda_module = load_inline(
name="w4a16_cuda",
cpp_sources="",
cuda_sources=CUDA_SRC,
functions=["w4a16_gemv_cuda_launch"],
extra_cuda_cflags=["-O3", "-arch=sm_120"],
verbose=False,
)
Edit applied successfully.
{
"command": "python3 -c \"import shutil; shutil.rmtree('/home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda', True)\"",
"description": "Clean CUDA cache"
}
(no output)
{
"command": "uv run python check.py 2>&1",
"description": "Test CUDA+Triton hybrid",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
PASS
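The earlier ninja failure came from binding generation, not from the kernel itself: load_inline emits the pybind11 module inside the C++ translation unit built from cpp_sources, so every name listed in functions= must be visible there, and a definition that lives only in cuda_sources needs a matching declaration. A stripped-down sketch of that pattern, with illustrative names rather than anything from this session (load_inline prepends #include <torch/extension.h> to cpp_sources, so the declaration alone is enough):

from torch.utils.cpp_extension import load_inline

# Hypothetical toy extension: the CUDA file defines the function, the C++
# side only declares it so the generated binding code can reference it.
_demo = load_inline(
    name="decl_demo",
    cpp_sources="void copy_launch(at::Tensor src, at::Tensor dst);",
    cuda_sources=r"""
#include <torch/extension.h>
void copy_launch(at::Tensor src, at::Tensor dst) { dst.copy_(src); }
""",
    functions=["copy_launch"],
)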
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark hybrid CUDA+Triton",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.132 gbps=34.994 ms=0.765
shape=0 variant=compiled tflops=0.663 gbps=176.239 ms=0.152
shape=0 variant=sota tflops=2.305 gbps=613.124 ms=0.044
shape=0 variant=solution tflops=0.247 gbps=65.686 ms=0.408
shape=0 solution_peak_fraction=0.0365
shape=1 variant=eager tflops=4.180 gbps=36.060 ms=0.771
shape=1 variant=compiled tflops=20.691 gbps=178.490 ms=0.156
shape=1 variant=sota tflops=23.261 gbps=200.659 ms=0.138
shape=1 variant=solution tflops=35.495 gbps=306.189 ms=0.091
shape=1 solution_peak_fraction=0.1701
shape=2 variant=eager tflops=31.362 gbps=42.750 ms=0.822
shape=2 variant=compiled tflops=121.418 gbps=165.507 ms=0.212
shape=2 variant=sota tflops=133.517 gbps=181.999 ms=0.193
shape=2 variant=solution tflops=80.246 gbps=109.384 ms=0.321
shape=2 solution_peak_fraction=0.0608
shape=3 variant=eager tflops=0.182 gbps=48.415 ms=0.184
shape=3 variant=compiled tflops=0.509 gbps=135.391 ms=0.
... (464 more chars)
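The decode shape is the telling number here. Plugging the shape-0 dimensions reported later in the log (M=1, N=12288, K=4096) into a simple traffic model reproduces the benchmark's gbps column almost exactly, so that column is essentially bytes moved over time, and the CUDA GEMV is streaming the packed weights at roughly 66 GB/s while the sota path reaches about 610 GB/s. That gap is what motivates the Triton split-K rewrite in the next edit. A quick check; the traffic model itself is an assumption about what benchmark.py counts:

# Approximate bytes touched per forward pass for shape 0 (assumed model).
M, N, K, group = 1, 12288, 4096, 128
bytes_moved = (
    N * K // 2                   # int4 weights, two values per byte
    + M * K * 2 + M * N * 2      # bf16 activations and output
    + 2 * (K // group) * N * 2   # bf16 scales and zeros
)
for name, ms in [("cuda gemv", 0.408), ("sota", 0.044)]:
    print(f"{name}: {bytes_moved / (ms * 1e-3) / 1e9:.1f} GB/s")
# cuda gemv: ~65.6 GB/s, sota: ~608 GB/s, matching the gbps column above.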
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -7,120 +7,125 @@
GROUP_SIZE = 128
-CUDA_SRC = r"""
-#include <cuda.h>
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
-#include <torch/extension.h>
-#include <c10/cuda/CUDAStream.h>
-
-template<int BLOCK_N, int GROUP_SIZE, int NTHREADS>
-__global__ void __launch_bounds__(NTHREADS)
-w4a16_gemv_kernel(
- const __nv_bfloat16* __restrict__ x,
- const uint8_t* __restrict__ w_q,
- const __nv_bfloat16* __restrict__ scales,
- const __nv_bfloat16* __restrict__ zeros,
- __nv_bfloat16* __restrict__ out,
- int M, int N, int K
-) {
- const int m = blockIdx.y;
- const int n_base = blockIdx.x * BLOCK_N;
- const int tid = threadIdx.x;
-
- constexpr int ELEMS = BLOCK_N / NTHREADS;
-
- float acc[ELEMS];
- #pragma unroll
- for (int j = 0; j < ELEMS; j++) acc[j] = 0.0f;
-
- const int n_groups = K / GROUP_SIZE;
-
- for (int g = 0; g < n_groups; g++) {
- float s[ELEMS], z[ELEMS];
- #pragma unroll
- for (int j = 0; j < ELEMS; j++) {
- int n = n_base + tid + j * NTHREADS;
- if (n < N) {
- s[j] = __bfloat162float(scales[g * N + n]);
- z[j] = __bfloat162float(zeros[g * N + n]);
- } else {
- s[j] = 1.0f;
- z[j] = 0.0f;
- }
- }
-
- for (int k = g * GROUP_SIZE; k < (g + 1) * GROUP_SIZE; k += 2) {
- float x_even = __bfloat162float(x[m * K + k]);
- float x_odd = __bfloat162float(x[m * K + k + 1]);
- int packed_k = k / 2;
-
- #pragma unroll
- for (int j = 0; j < ELEMS; j++) {
- int n = n_base + tid + j * NTHREADS;
- if (n < N) {
- uint8_t packed = w_q[packed_k * N + n];
- float w_lo = (float)(packed & 0xF);
- float w_hi = (float)(packed >> 4);
- acc[j] += x_even * (w_lo - z[j]) * s[j]
- + x_odd * (w_hi - z[j]) * s[j];
- }
- }
- }
- }
-
- #pragma unroll
- for (int j = 0; j < ELEMS; j++) {
- int n = n_base + tid + j * NTHREADS;
- if (n < N) {
- out[m * N + n] = __float2bfloat16(acc[j]);
- }
- }
-}
-
-void w4a16_gemv_cuda_launch(
- at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
- at::Tensor out, int64_t M, int64_t N, int64_t K
-) {
- const int BLOCK_N = 256;
- const int NTHREADS = 256;
- dim3 grid((N + BLOCK_N - 1) / BLOCK_N, M);
- auto stream = c10::cuda::getCurrentCUDAStream();
- w4a16_gemv_kernel<BLOCK_N, 128, NTHREADS><<<grid, NTHREADS, 0, stream>>>(
- reinterpret_cast<const __nv_bfloat16*>(x.data_ptr<at::BFloat16>()),
- w_q.data_ptr<uint8_t>(),
- reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr<at::BFloat16>()),
- reinterpret_cast<const __nv_bfloat16*>(zeros.data_ptr<at::BFloat16>()),
- reinterpret_cast<__nv_bfloat16*>(out.data_ptr<at::BFloat16>()),
- M, N, K);
-}
-"""
-
-_cuda_module = None
-
-def _get_cuda_module():
- global _cuda_module
- if _cuda_module is not None:
- return _cuda_module
- from torch.utils.cpp_extension import load_inline
- _cuda_module = load_inline(
- name="w4a16_cuda",
- cpp_sources='void w4a16_gemv_cuda_launch(at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros, at::Tensor out, int64_t M, int64_t N, int64_t K);',
- cuda_sources=CUDA_SRC,
- functions=["w4a16_gemv_cuda_launch"],
- extra_cuda_cflags=["-O3", "-arch=sm_120"],
- verbose=False,
- )
- return _cuda_module
-
-
-def w4a16_gemv_cuda(x, w_q, scales, zeros):
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=4),
+ triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_stages=5, num_warps=4),
+ triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
+ triton.Config({"BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4),
+ triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=2),
+ triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_stages=5, num_warps=2),
+ ],
+ key=["N", "K"],
+)
+@triton.jit
+def _w4a16_gemv_splitk_kernel(
+ x_ptr,
+ w_q_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wk,
+ stride_wn,
+ stride_sg,
+ stride_sn,
+ stride_zg,
+ stride_zn,
+ stride_om,
+ stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_k_blocks = tl.cdiv(K, BLOCK_K)
+ pid_n = pid // num_k_blocks
+ k_idx = pid % num_k_blocks
+
+ row = 0
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < N
+
+ k_start = k_idx * BLOCK_K
+ HALF_K: tl.constexpr = BLOCK_K // 2
+
+ k_even = k_start + tl.arange(0, HALF_K) * 2
+ k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
+
+ x_even = tl.load(
+ x_ptr + row * stride_xm + k_even * stride_xk,
+ mask=k_even < K,
+ other=0.0,
+ )
+ x_odd = tl.load(
+ x_ptr + row * stride_xm + k_odd * stride_xk,
+ mask=k_odd < K,
+ other=0.0,
+ )
+
+ packed_k = k_start // 2 + tl.arange(0, HALF_K)
+ w_packed = tl.load(
+ w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
+ other=0,
+ eviction_policy="evict_last",
+ )
+
+ w_lo = (w_pack
... (1778 more chars)
{
"content": "<9470 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Test split-K Triton kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=26619.2 (atol=0.1, rtol=0.1)
{
"command": "<780 chars \u2014 see below>",
"description": "Debug split-K kernel correctness",
"timeout": 120000
}
cat > /tmp/test_splitk.py << 'PYEOF'
import torch
import sys
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm")
import reference
import solution
M, N, K = 1, 128, 128
reference.M, reference.N, reference.K = M, N, K
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
with torch.no_grad():
ref_out = ref(x)
sol_out = sol(x)
print(f"ref_out shape: {ref_out.shape}, sol_out shape: {sol_out.shape}")
print(f"ref_out: {ref_out[0, :5]}")
print(f"sol_out: {sol_out[0, :5]}")
print(f"max_diff: {(ref_out - sol_out).abs().max().item()}")
PYEOF
uv run python /tmp/test_splitk.py
ref_out shape: torch.Size([1, 128]), sol_out shape: torch.Size([1, 128])
ref_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
sol_out: tensor([2512.0000, -732.0000, -30.2500, 936.0000, 424.0000],
device='cuda:0', dtype=torch.bfloat16)
max_diff: 3248.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -15,9 +15,7 @@
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_stages=5, num_warps=4),
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=2),
- triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_stages=5, num_warps=2),
],
key=["N", "K"],
)
@@ -28,10 +26,8 @@
scales_ptr,
zeros_ptr,
out_ptr,
- M,
N,
K,
- stride_xm,
stride_xk,
stride_wk,
stride_wn,
@@ -39,8 +35,6 @@
stride_sn,
stride_zg,
stride_zn,
- stride_om,
- stride_on,
GROUP_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
@@ -51,24 +45,25 @@
pid_n = pid // num_k_blocks
k_idx = pid % num_k_blocks
- row = 0
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
k_start = k_idx * BLOCK_K
HALF_K: tl.constexpr = BLOCK_K // 2
+ offs_m = tl.zeros((1, 1), dtype=tl.int32)
+
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(
- x_ptr + row * stride_xm + k_even * stride_xk,
- mask=k_even < K,
+ x_ptr + k_even[None, :] * stride_xk,
+ mask=k_even[None, :] < K,
other=0.0,
)
x_odd = tl.load(
- x_ptr + row * stride_xm + k_odd * stride_xk,
- mask=k_odd < K,
+ x_ptr + k_odd[None, :] * stride_xk,
+ mask=k_odd[None, :] < K,
other=0.0,
)
@@ -95,10 +90,12 @@
other=0.0,
).to(tl.float32)
- w_lo_dq = (w_lo - z[None, :]) * s[None, :]
- w_hi_dq = (w_hi - z[None, :]) * s[None, :]
-
- acc = tl.sum(x_even[:, None] * w_lo_dq, axis=0) + tl.sum(x_odd[:, None] * w_hi_dq, axis=0)
+ w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
+ w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
+
+ partial = tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
+
+ acc = partial.reshape(BLOCK_N)
tl.atomic_add(out_ptr + offs_n, acc, mask=mask_n)
@@ -116,12 +113,11 @@
_w4a16_gemv_splitk_kernel[grid](
x, w_q, scales, zeros, acc,
- M, N, K,
- x.stride(0), x.stride(1),
+ N, K,
+ x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
- 0, 1,
GROUP_SIZE=GROUP_SIZE,
)
@@ -141,8 +137,9 @@
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
{
"content": "<9396 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
Wrote file successfully.
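The rewrite above swaps the elementwise tl.sum reduction for tl.dot, which forces explicit 2-D shapes: the activations are loaded as (1, HALF_K) rows, the dequantized weights form a (HALF_K, BLOCK_N) tile, and the (1, BLOCK_N) product is reshaped to a vector before the atomic update. A plain PyTorch restatement of one (pid_n, k_idx) tile, with names mirroring the kernel and dtypes simplified to fp32; this is an illustration, not code from the session:

import torch

def splitk_tile(x, w_q, scales, zeros, k_start, offs_n, BLOCK_K=128, GROUP_SIZE=128):
    half = BLOCK_K // 2
    k_even = k_start + torch.arange(half, device=x.device) * 2
    k_odd = k_even + 1
    x_even = x[0, k_even].float().unsqueeze(0)                    # (1, HALF_K)
    x_odd = x[0, k_odd].float().unsqueeze(0)                      # (1, HALF_K)
    packed = w_q[k_start // 2 : k_start // 2 + half][:, offs_n]   # (HALF_K, BLOCK_N)
    w_lo = (packed & 0xF).float()                                 # low nibble = even k
    w_hi = (packed >> 4).float()                                  # high nibble = odd k
    g = k_start // GROUP_SIZE
    s = scales[g, offs_n].float()                                 # (BLOCK_N,)
    z = zeros[g, offs_n].float()
    partial = x_even @ ((w_lo - z) * s) + x_odd @ ((w_hi - z) * s)
    return partial.reshape(-1)                                    # one split-K partial sum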
"command": "<714 chars \u2014 see below>",
"description": "Test fixed split-K kernel",
"timeout": 120000
}
cat > /tmp/test_splitk2.py << 'PYEOF'
import torch
import sys
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm")
import reference
import solution
M, N, K = 1, 128, 128
reference.M, reference.N, reference.K = M, N, K
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
with torch.no_grad():
ref_out = ref(x)
sol_out = sol(x)
print(f"ref_out: {ref_out[0, :5]}")
print(f"sol_out: {sol_out[0, :5]}")
print(f"max_diff: {(ref_out - sol_out).abs().max().item()}")
PYEOF
uv run python /tmp/test_splitk2.py 2>&1
ref_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
sol_out: tensor([1896.0000, -552.0000, -22.6250, 708.0000, 320.0000],
device='cuda:0', dtype=torch.bfloat16)
max_diff: 2464.0
{
"command": "<653 chars \u2014 see below>",
"description": "Test tl.atomic_add correctness"
}
cat > /tmp/test_atomic.py << 'PYEOF'
import torch
import triton
import triton.language as tl
@triton.jit
def atomic_add_test(out_ptr, N, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
vals = offs.to(tl.float32) + 1.0
tl.atomic_add(out_ptr + offs, vals, mask=offs < N)
N = 64
out = torch.zeros(N, dtype=torch.float32, device='cuda')
grid = (1,)
atomic_add_test[grid](out, N, BLOCK=64)
print("out:", out[:10].tolist())
print("expected:", [i+1 for i in range(10)])
print("match:", torch.allclose(out, torch.arange(1, N+1, dtype=torch.float32, device='cuda')))
PYEOF
uv run python /tmp/test_atomic.py
out: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
expected: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
match: True
{
"command": "<2474 chars \u2014 see below>",
"description": "Debug split-K kernel step by step",
"timeout": 120000
}
cat > /tmp/test_splitk_debug.py << 'PYEOF'
import torch
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.jit
def splitk_debug_kernel(
x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
N, K, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn, stride_zg, stride_zn,
GROUP_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
offs_n = tl.arange(0, BLOCK_N)
k_start = 0
HALF_K: tl.constexpr = BLOCK_K // 2
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(x_ptr + k_even[None, :] * stride_xk, mask=k_even[None, :] < K, other=0.0)
x_odd = tl.load(x_ptr + k_odd[None, :] * stride_xk, mask=k_odd[None, :] < K, other=0.0)
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=(packed_k[:, None] < K // 2) & (offs_n[None, :] < N), other=0)
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = (w_packed >> 4).to(tl.float32)
g = 0
s = tl.load(scales_ptr + g * stride_sg + offs_n * stride_sn, mask=offs_n < N, other=1.0).to(tl.float32)
z = tl.load(zeros_ptr + g * stride_zg + offs_n * stride_zn, mask=offs_n < N, other=0.0).to(tl.float32)
w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
partial = tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
acc = partial.reshape(BLOCK_N)
tl.store(out_ptr + offs_n, acc, mask=offs_n < N)
M, N, K = 1, 128, 128
import sys
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm")
import reference
ref = reference.Model(M, N, K).cuda().eval()
torch.manual_seed(42)
x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
out = torch.empty(N, dtype=torch.float32, device='cuda')
splitk_debug_kernel[(1,)](x, ref.w_q, ref.scales, ref.zeros, out,
N, K, x.stride(1),
ref.w_q.stride(0), ref.w_q.stride(1),
ref.scales.stride(0), ref.scales.stride(1), ref.zeros.stride(0), ref.zeros.stride(1),
GROUP_SIZE=128, BLOCK_N=128, BLOCK_K=128)
sol_out = out.to(torch.bfloat16).unsqueeze(0)
with torch.no_grad():
ref_out = ref(x)
print(f"ref_out: {ref_out[0, :5]}")
print(f"sol_out: {sol_out[0, :5]}")
print(f"max_diff: {(ref_out - sol_out).abs().max().item()}")
PYEOF
uv run python /tmp/test_splitk_debug.py
ref_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
sol_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
max_diff: 0.0
{
"command": "<2893 chars \u2014 see below>",
"description": "Test full split-K kernel",
"timeout": 120000
}
cat > /tmp/test_splitk_full.py << 'PYEOF'
import torch
import triton
import triton.language as tl
import sys
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm")
import reference
GROUP_SIZE = 128
@triton.jit
def splitk_full_kernel(
x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
N, K, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn, stride_zg, stride_zn,
GROUP_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_k_blocks = tl.cdiv(K, BLOCK_K)
pid_n = pid // num_k_blocks
k_idx = pid % num_k_blocks
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
k_start = k_idx * BLOCK_K
HALF_K: tl.constexpr = BLOCK_K // 2
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
x_even = tl.load(x_ptr + k_even[None, :] * stride_xk, mask=k_even[None, :] < K, other=0.0)
x_odd = tl.load(x_ptr + k_odd[None, :] * stride_xk, mask=k_odd[None, :] < K, other=0.0)
packed_k = k_start // 2 + tl.arange(0, HALF_K)
w_packed = tl.load(
w_q_ptr + packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=(packed_k[:, None] < K // 2) & mask_n[None, :],
other=0, eviction_policy="evict_last",
)
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = (w_packed >> 4).to(tl.float32)
g = k_start // GROUP_SIZE
s = tl.load(scales_ptr + g * stride_sg + offs_n * stride_sn, mask=mask_n, other=1.0).to(tl.float32)
z = tl.load(zeros_ptr + g * stride_zg + offs_n * stride_zn, mask=mask_n, other=0.0).to(tl.float32)
w_lo_dq = ((w_lo - z[None, :]) * s[None, :]).to(tl.bfloat16)
w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
partial = tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
acc = partial.reshape(BLOCK_N)
tl.atomic_add(out_ptr + offs_n, acc, mask=mask_n)
M, N, K = 1, 128, 128
ref = reference.Model(M, N, K).cuda().eval()
torch.manual_seed(42)
x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
acc = torch.zeros((N,), dtype=torch.float32, device=x.device)
BLOCK_N, BLOCK_K = 128, 128
num_pid_n = (N + BLOCK_N - 1) // BLOCK_N
num_k_blocks = (K + BLOCK_K - 1) // BLOCK_K
grid = (num_pid_n * num_k_blocks,)
splitk_full_kernel[grid](
x, ref.w_q, ref.scales, ref.zeros, acc,
N, K, x.stride(1),
ref.w_q.stride(0), ref.w_q.stride(1),
ref.scales.stride(0), ref.scales.stride(1), ref.zeros.stride(0), ref.zeros.stride(1),
GROUP_SIZE=GROUP_SIZE, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
)
sol_out = acc.to(torch.bfloat16).unsqueeze(0)
with torch.no_grad():
ref_out = ref(x)
print(f"ref_out: {ref_out[0, :5]}")
print(f"sol_out: {sol_out[0, :5]}")
print(f"max_diff: {(ref_out - sol_out).abs().max().item()}")
PYEOF
uv run python /tmp/test_splitk_full.py
ref_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
sol_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
max_diff: 0.0
{
"command": "<770 chars \u2014 see below>",
"description": "Test solution.py's splitk function directly",
"timeout": 120000
}
cat > /tmp/test_sol_splitk.py << 'PYEOF'
import torch
import sys
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm")
import reference
import solution
M, N, K = 1, 128, 128
reference.M, reference.N, reference.K = M, N, K
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, K, dtype=torch.bfloat16, device='cuda')
with torch.no_grad():
ref_out = ref(x)
sol_out = solution.w4a16_gemv_splitk(x, ref.w_q, ref.scales, ref.zeros)
print(f"ref_out: {ref_out[0, :5]}")
print(f"sol_out: {sol_out[0, :5]}")
print(f"max_diff: {(ref_out - sol_out).abs().max().item()}")
PYEOF
uv run python /tmp/test_sol_splitk.py
ref_out: tensor([ 0.4199, -0.1226, -0.0050, 0.1572, 0.0713], device='cuda:0',
dtype=torch.bfloat16)
sol_out: tensor([1896.0000, -552.0000, -22.6250, 708.0000, 320.0000],
device='cuda:0', dtype=torch.bfloat16)
max_diff: 2464.0
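At this point the kernel body is exonerated: tl.atomic_add behaves (match: True above) and both direct launches of the split-K kernel reproduce the reference exactly, yet the same kernel called through solution.py is off by a large, roughly constant factor (3248 earlier, 2464 after two autotune configs were pruned, which roughly tracks the config count). The remaining difference is the @triton.autotune decorator: the autotuner times each config by re-launching the kernel with the caller's real arguments, and a split-K kernel that accumulates into a caller-zeroed buffer via tl.atomic_add is not idempotent, so the extra timing launches pile into acc. The edit below simply drops the decorator and hardcodes BLOCK_N/BLOCK_K; Triton's autotune also accepts a reset_to_zero argument aimed at this accumulation pattern, though that route was not tried here. A minimal reproduction of the failure mode, illustrative rather than taken from the session:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=4),
    ],
    key=["N"],
)
@triton.jit
def _acc_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.atomic_add(out_ptr + offs, x, mask=mask)  # accumulates, does not overwrite

N = 4096
x = torch.randn(N, device="cuda")
out = torch.zeros(N, device="cuda")
_acc_kernel[lambda meta: (triton.cdiv(N, meta["BLOCK"]),)](x, out, N)
# The first call triggers config benchmarking, so out ends up as many copies
# of x added together instead of one, the same inflation seen in sol_out above.
print((out / x).median().item())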
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -8,17 +8,6 @@
GROUP_SIZE = 128
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=4),
- triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
- triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_stages=5, num_warps=4),
- triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=3, num_warps=4),
- triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_stages=5, num_warps=2),
- ],
- key=["N", "K"],
-)
@triton.jit
def _w4a16_gemv_splitk_kernel(
x_ptr,
@@ -51,8 +40,6 @@
k_start = k_idx * BLOCK_K
HALF_K: tl.constexpr = BLOCK_K // 2
- offs_m = tl.zeros((1, 1), dtype=tl.int32)
-
k_even = k_start + tl.arange(0, HALF_K) * 2
k_odd = k_start + tl.arange(0, HALF_K) * 2 + 1
@@ -94,7 +81,6 @@
w_hi_dq = ((w_hi - z[None, :]) * s[None, :]).to(tl.bfloat16)
partial = tl.dot(x_even, w_lo_dq) + tl.dot(x_odd, w_hi_dq)
-
acc = partial.reshape(BLOCK_N)
tl.atomic_add(out_ptr + offs_n, acc, mask=mask_n)
@@ -105,11 +91,14 @@
N = w_q.shape[1]
assert M == 1
+ BLOCK_N = 128
+ BLOCK_K = 128
+
acc = torch.zeros((N,), dtype=torch.float32, device=x.device)
- grid = lambda META: (
- triton.cdiv(N, META["BLOCK_N"]) * triton.cdiv(K, META["BLOCK_K"]),
- )
+ num_pid_n = triton.cdiv(N, BLOCK_N)
+ num_k_blocks = triton.cdiv(K, BLOCK_K)
+ grid = (num_pid_n * num_k_blocks,)
_w4a16_gemv_splitk_kernel[grid](
x, w_q, scales, zeros, acc,
@@ -119,6 +108,8 @@
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
GROUP_SIZE=GROUP_SIZE,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
)
return acc.to(torch.bfloat16).unsqueeze(0)
{
"content": "<8889 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
Wrote file successfully.
{
"command": "uv run python check.py 2>&1",
"description": "Test with fixed split-K",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark split-K + GEMM hybrid",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.132 gbps=34.974 ms=0.765
shape=0 variant=compiled tflops=0.662 gbps=176.017 ms=0.152
shape=0 variant=sota tflops=2.300 gbps=611.556 ms=0.044
shape=0 variant=solution tflops=2.887 gbps=767.882 ms=0.035
shape=0 solution_peak_fraction=0.4266
shape=1 variant=eager tflops=4.182 gbps=36.072 ms=0.770
shape=1 variant=compiled tflops=20.698 gbps=178.545 ms=0.156
shape=1 variant=sota tflops=23.210 gbps=200.219 ms=0.139
shape=1 variant=solution tflops=36.249 gbps=312.694 ms=0.089
shape=1 solution_peak_fraction=0.1737
shape=2 variant=eager tflops=31.337 gbps=42.716 ms=0.822
shape=2 variant=compiled tflops=120.727 gbps=164.565 ms=0.213
shape=2 variant=sota tflops=133.716 gbps=182.271 ms=0.193
shape=2 variant=solution tflops=80.809 gbps=110.153 ms=0.319
shape=2 solution_peak_fraction=0.0612
shape=3 variant=eager tflops=0.182 gbps=48.453 ms=0.184
shape=3 variant=compiled tflops=0.512 gbps=136.117 ms=0
... (465 more chars)