"""W4A16 weight-only quantized GEMM via Triton.
Fused unpack + GEMM in one pass: int4 weights are streamed as packed bytes
(2 nibbles per byte), unpacked, dequantized per group, and matmul'd against
bf16 activations without ever materialising a full bf16 weight matrix.
Layout matches reference.py:
x: (M, K) bf16
w_q: (K // 2, N) uint8 (low nibble = even-K row, high = odd-K row)
scales: (K // g, N) bf16
zeros: (K // g, N) bf16
out: (M, N) bf16
Two kernels:
- gemv: M=1 SIMT path with split-K to saturate SMs on memory-bound decodes.
- gemm: M>=16 tensor-core path with hand-tuned tile sizes per M regime.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128  # quantization group length along K (matches reference.py layout)
SM_COUNT = 188  # RTX PRO 6000 Blackwell SM count; NOTE(review): not referenced anywhere in this file
# ---------------------------------------------------------------------------
# GEMM kernel (M >= 16): tensor-core path. BLOCK_K may be a multiple of GROUP_SIZE.
# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sg, stride_sn,
    stride_zg, stride_zn,
    stride_om, stride_on,
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tensor-core GEMM tile with fused int4 unpack + per-group dequant.

    Each program computes one (BLOCK_M, BLOCK_N) tile of out = x @ dequant(w).
    The packed weight is streamed as uint8 (two 4-bit values per byte: low
    nibble = even K row, high nibble = odd K row), dequantized in-register as
    (q - zero) * scale, and fed to tl.dot without materialising a bf16 weight.

    Unmasked-load assumptions (only M is masked):
      - K % BLOCK_K == 0 (the K loop runs exactly K // BLOCK_K iterations)
      - N % BLOCK_N == 0 (offs_n is never masked) -- holds for BLOCK_N=64 and
        the benchmark shapes; TODO confirm before adding new shapes
      - BLOCK_K is a positive multiple of GROUP_SIZE, and GROUP_SIZE is even
        (so the even/odd rows of one packed byte share a quant group)
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # fp32 accumulator for the bf16 tl.dot products.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # One packed row of w_q holds two logical K rows, so a BLOCK_K slice of K
    # spans BLOCK_K // 2 packed rows.
    BLOCK_KH: tl.constexpr = BLOCK_K // 2
    offs_kh = tl.arange(0, BLOCK_KH)
    offs_k = tl.arange(0, BLOCK_K)
    n_iters = K // BLOCK_K
    n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
    for k_iter in range(n_iters):
        kbase_packed = k_iter * BLOCK_KH  # row offset into packed (K//2, N) w_q
        kbase_full = k_iter * BLOCK_K     # row offset into the logical (K, N) weight
        w_ptrs = (
            w_ptr
            + (kbase_packed + offs_kh)[:, None] * stride_wk
            + offs_n[None, :] * stride_wn
        )
        w_packed = tl.load(w_ptrs)
        # Split the byte: low nibble = even K row, high nibble = odd K row.
        w_lo = (w_packed & 0xF).to(tl.bfloat16)
        w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
        if n_groups_per_block == 1:
            # Whole BLOCK_K slice lies in one quant group: scalar group index,
            # 1D scale/zero vectors broadcast across the K dimension.
            group_idx = kbase_full // GROUP_SIZE
            s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
            z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
            w_lo_bf = (w_lo - z[None, :]) * s[None, :]
            w_hi_bf = (w_hi - z[None, :]) * s[None, :]
        else:
            # Several groups per BLOCK_K: per-packed-row group index. Rows
            # 2*i and 2*i+1 of one byte always fall in the same group when
            # GROUP_SIZE is even, so one (s, z) pair serves both nibbles.
            group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
            s = tl.load(
                s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
            )
            z = tl.load(
                z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
            )
            w_lo_bf = (w_lo - z) * s
            w_hi_bf = (w_hi - z) * s
        # Re-interleave (lo, hi) pairs back into K order so row order matches
        # the unpacked weight: [lo0, hi0, lo1, hi1, ...].
        w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
        w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
        w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N)) # (BLOCK_K, BLOCK_N)
        x_ptrs = (
            x_ptr
            + offs_m[:, None] * stride_xm
            + (kbase_full + offs_k)[None, :] * stride_xk
        )
        # Only the M edge can be ragged (BLOCK_M may exceed M on the last tile).
        x_mask = offs_m[:, None] < M
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        accumulator = tl.dot(x, w, accumulator)
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_mask = offs_m[:, None] < M
    tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=out_mask)
# ---------------------------------------------------------------------------
# GEMV path (M == 1) with split-K: each split atomically accumulates its fp32
# partial into a single pre-zeroed (1, N) fp32 buffer, which the small helper
# kernel below then casts to bf16. (SPLIT_K == 1 stores bf16 directly.)
# ---------------------------------------------------------------------------
@triton.jit
def cast_fp32_to_bf16_kernel(
    src_ptr, dst_ptr, N,
    BLOCK_N: tl.constexpr,
):
    """Elementwise copy of a length-N fp32 buffer into a bf16 buffer."""
    block_start = tl.program_id(0) * BLOCK_N
    idx = block_start + tl.arange(0, BLOCK_N)
    in_bounds = idx < N
    vals = tl.load(src_ptr + idx, mask=in_bounds, other=0.0)
    tl.store(dst_ptr + idx, vals.to(tl.bfloat16), mask=in_bounds)
@triton.jit
def w4a16_gemv_kernel(
    x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
    N, K,
    stride_xk,
    stride_wk, stride_wn,
    stride_sg, stride_sn,
    stride_zg, stride_zn,
    stride_on,
    GROUP_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    K_PER_SPLIT: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """M == 1 decode path: SIMT dot product with optional split-K.

    Grid is (cdiv(N, BLOCK_N), SPLIT_K). Program (pid_n, pid_k) reduces its
    own K_PER_SPLIT slice of K for BLOCK_N output columns. With SPLIT_K == 1
    the result is stored directly as bf16; with SPLIT_K > 1 each split
    atomic-adds its fp32 partial into out_ptr, which the caller must have
    zero-initialized as fp32 (the fp32 atomic accumulation order is not
    deterministic across launches).

    Unmasked-load assumptions: N % BLOCK_N == 0, K_PER_SPLIT % BLOCK_K == 0,
    and BLOCK_K a positive multiple of GROUP_SIZE with GROUP_SIZE even --
    holds for configs from _gemv_config; TODO confirm for new shapes.
    """
    pid_n = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # fp32 accumulator; all math below is done in fp32.
    accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # One packed row holds two K rows, so BLOCK_K covers BLOCK_K // 2 packed rows.
    BLOCK_KH: tl.constexpr = BLOCK_K // 2
    offs_kh = tl.arange(0, BLOCK_KH)
    k_start_full = pid_k * K_PER_SPLIT
    n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
    n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
    for k_iter in range(n_iters):
        kbase_full = k_start_full + k_iter * BLOCK_K
        kbase_packed = kbase_full // 2
        w_ptrs = (
            w_ptr
            + (kbase_packed + offs_kh)[:, None] * stride_wk
            + offs_n[None, :] * stride_wn
        )
        w_packed = tl.load(w_ptrs)
        # Low nibble = even K row, high nibble = odd K row (reference packing).
        w_lo = (w_packed & 0xF).to(tl.float32)
        w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
        # Gather the matching activation elements: even K indices pair with the
        # low nibbles, odd K indices with the high nibbles.
        x_offs_even = kbase_full + 2 * offs_kh
        x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
        x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
        if n_groups_per_block == 1:
            # One quant group per BLOCK_K slice: scalar group index, 1D loads.
            group_idx = kbase_full // GROUP_SIZE
            s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
            z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
            w_lo_f = (w_lo - z[None, :]) * s[None, :]
            w_hi_f = (w_hi - z[None, :]) * s[None, :]
        else:
            # Several groups per slice: per-packed-row group index. Even/odd
            # rows of one byte share a group because GROUP_SIZE is even.
            group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
            s = tl.load(
                s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
            ).to(tl.float32)
            z = tl.load(
                z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
            ).to(tl.float32)
            w_lo_f = (w_lo - z) * s
            w_hi_f = (w_hi - z) * s
        # Rank-1 reduce over the K-half axis for both nibble streams.
        accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
        accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
    out_ptrs = out_ptr + offs_n * stride_on
    if SPLIT_K == 1:
        tl.store(out_ptrs, accumulator.to(tl.bfloat16))
    else:
        # fp32 atomic_add into pre-zeroed fp32 buffer
        tl.atomic_add(out_ptrs, accumulator)
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
def _gemv_config(N: int, K: int) -> tuple[int, int, int, int, int]:
"""Return (BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages) for M=1 GEMV."""
BLOCK_N = 64
n_tiles = (N + BLOCK_N - 1) // BLOCK_N
if n_tiles <= 64: # small N (e.g., shape 3: N=4096)
BLOCK_K = 512 if K % 512 == 0 else 256
SPLIT_K = 8
num_warps = 8
num_stages = 2
elif n_tiles <= 128:
BLOCK_K = 256
SPLIT_K = 4
num_warps = 4
num_stages = 3
else: # large N (e.g., shape 0: N=12288)
BLOCK_K = 256
SPLIT_K = 2
num_warps = 4
num_stages = 2
return BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
def _gemm_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
"""Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) for GEMM."""
if M <= 16:
return 16, 64, 256, 4, 2
if M <= 32:
return 32, 64, 256, 4, 2
if M <= 64:
return 64, 64, 256, 4, 2
return 128, 64, 256, 8, 2
def w4a16_gemm(
    x: torch.Tensor,
    w_q: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    group_size: int = GROUP_SIZE,
) -> torch.Tensor:
    """Dispatch W4A16 matmul to the GEMV (M == 1) or GEMM (M > 1) kernel.

    Args:
        x: (M, K) bf16 activations.
        w_q: (K // 2, N) uint8 packed int4 weights (low nibble = even K row).
        scales: (K // group_size, N) bf16 per-group scales.
        zeros: (K // group_size, N) bf16 per-group zero-points.
        group_size: quantization group length along K.

    Returns:
        (M, N) bf16 output.

    Fix over the previous revision: the SPLIT_K == 1 and SPLIT_K > 1 GEMV
    branches duplicated the entire kernel-launch argument list; the launch is
    now written once and only the output buffer differs.
    """
    M, K = x.shape
    K_packed, N = w_q.shape
    assert K_packed * 2 == K
    assert K % group_size == 0
    if M == 1:
        BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = _gemv_config(N, K)
        K_PER_SPLIT = K // SPLIT_K
        grid = (triton.cdiv(N, BLOCK_N), SPLIT_K)
        if SPLIT_K == 1:
            # Single split: the kernel stores bf16 directly.
            out_buf = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
        else:
            # Multiple splits: kernel atomic-adds fp32 partials, so the
            # buffer must be fp32 and pre-zeroed.
            out_buf = torch.zeros((1, N), dtype=torch.float32, device=x.device)
        w4a16_gemv_kernel[grid](
            x, w_q, scales, zeros, out_buf,
            N, K,
            x.stride(1),
            w_q.stride(0), w_q.stride(1),
            scales.stride(0), scales.stride(1),
            zeros.stride(0), zeros.stride(1),
            out_buf.stride(1),
            GROUP_SIZE=group_size,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            K_PER_SPLIT=K_PER_SPLIT,
            SPLIT_K=SPLIT_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        if SPLIT_K == 1:
            return out_buf
        # Cast the accumulated fp32 partial sums down to bf16.
        out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
        cast_fp32_to_bf16_kernel[(triton.cdiv(N, 1024),)](
            out_buf, out, N, BLOCK_N=1024, num_warps=4, num_stages=2,
        )
        return out
    # M > 1: tensor-core GEMM path.
    BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _gemm_config(M, N, K)
    out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    w4a16_gemm_kernel[grid](
        x, w_q, scales, zeros, out,
        M, N, K,
        x.stride(0), x.stride(1),
        w_q.stride(0), w_q.stride(1),
        scales.stride(0), scales.stride(1),
        zeros.stride(0), zeros.stride(1),
        out.stride(0), out.stride(1),
        GROUP_SIZE=group_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out
class Model(nn.Module):
    """W4A16 quantized linear: holds packed int4 weights + group quant params.

    Buffers are zero-initialized here; NOTE(review): presumably populated by
    the harness before forward is exercised -- confirm against check.py.
    """

    def __init__(self, M, N, K, group_size=GROUP_SIZE):
        super().__init__()
        assert K % group_size == 0
        assert K % 2 == 0
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        group_rows = K // group_size
        # Packed weight: two int4 values per byte along K.
        self.register_buffer("w_q", torch.zeros((K // 2, N), dtype=torch.uint8))
        # One (scale, zero) pair per group per output column.
        self.register_buffer("scales", torch.zeros((group_rows, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.zeros((group_rows, N), dtype=torch.bfloat16))

    def forward(self, x):
        """Compute x @ dequant(w_q) -> (M, N) bf16 via the fused kernels."""
        return w4a16_gemm(
            x.to(torch.bfloat16), self.w_q, self.scales, self.zeros, self.group_size
        )
# Benchmark shape: decode (M=1) against the large N=12288 projection.
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return the forward-call inputs: one (M, K) bf16 activation tensor."""
    activation = torch.randn(M, K, dtype=torch.bfloat16)
    return [activation]


def get_init_inputs():
    """Return the positional args for the Model constructor."""
    return [M, N, K]
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.777 ms=0.770
shape=0 variant=compiled tflops=0.645 gbps=171.524 ms=0.156
shape=0 variant=sota tflops=2.278 gbps=605.799 ms=0.044
shape=0 variant=solution tflops=1.620 gbps=430.797 ms=0.062
shape=0 solution_peak_fraction=0.2393
shape=1 variant=eager tflops=4.140 gbps=35.709 ms=0.778
shape=1 variant=compiled tflops=19.674 gbps=169.716 ms=0.164
shape=1 variant=sota tflops=21.758 gbps=187.691 ms=0.148
shape=1 variant=solution tflops=45.060 gbps=388.698 ms=0.071
shape=1 solution_peak_fraction=0.2159
shape=2 variant=eager tflops=31.030 gbps=42.298 ms=0.830
shape=2 variant=compiled tflops=117.135 gbps=159.670 ms=0.220
shape=2 variant=sota tflops=131.747 gbps=179.587 ms=0.196
shape=2 variant=solution tflops=137.214 gbps=187.038 ms=0.188
shape=2 solution_peak_fraction=0.1039
shape=3 variant=eager tflops=0.176 gbps=46.752 ms=0.191
shape=3 variant=compiled tflops=0.498 gbps=132.624 ms=0.067
shape=3 variant=sota tflops=1.195 gbps=317.994 ms=0.028
shape=3 variant=solution tflops=0.989 gbps=263.245 ms=0.034
shape=3 solution_peak_fraction=0.1462
shape=4 variant=eager tflops=2.047 gbps=34.619 ms=0.918
shape=4 variant=compiled tflops=10.113 gbps=171.063 ms=0.186
shape=4 variant=sota tflops=11.028 gbps=186.549 ms=0.170
shape=4 variant=solution tflops=28.360 gbps=479.730 ms=0.066
shape=4 solution_peak_fraction=0.2665
peak_fraction: 0.1837
RESULT: OK
PASS
{
"problem": "07_w4a16_gemm",
"harness": "claude",
"model": "claude-opus-4-7",
"reasoning_effort": "max",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1837,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 210, "output_tokens": 11216, "cache_read_tokens": 28906182, "cache_creation_tokens": 439853, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"command": "ls problems/",
"description": "List problems directory"
}01_fp8_gemm
02_kda_cutlass
03_paged_attention
04_kahan_softmax
05_topk_bitonic
06_sonic_moe_swiglu
07_w4a16_gemm
08_metal_lightning_attn{
"command": "ls problems/07_w4a16_gemm/",
"description": "List w4a16 problem"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}1 """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2
3 AWQ/GPTQ-style scheme:
4 x: (M, K) bf16
5 w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6 scales: (K // group, N) bf16
7 zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8 out: (M, N) bf16
9
10 Dequant (per group along K):
11 w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12 where w_q[k, n] is the unpacked 4-bit value (0..15).
13
14 This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15 memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16 """
17 from __future__ import annotations
18
19 import torch
20 import torch.nn as nn
21
22 OP_TYPE = "gemm_w4a16"
23 SUPPORTED_PRECISIONS = ["int4_bf16"]
24 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25
26 GROUP_SIZE = 128
27
28
29 def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30 """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31
32 Even rows go in the low nibble, odd rows in the high nibble.
33 """
34 K, N = w_q.shape
35 assert K % 2 == 0
36 lo = w_q[0::2].to(torch.uint8) & 0xF
37 hi = w_q[1::2].to(torch.uint8) & 0xF
38 return (lo | (hi << 4)).contiguous()
39
40
41 def _unpack_int4(w_packed: torch.Tensor, K: int) -> torch.Tensor:
42 """Unpack (K//2, N) uint8 -> (K, N)
... (3046 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}1 """Correctness runner for FP8 GEMM.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 # Make the repo's src/ importable
14 REPO_ROOT = Path(__file__).resolve().parents[2]
15 sys.path.insert(0, str(REPO_ROOT))
16
17 from src.eval.correctness import check_correctness # noqa: E402
18
19
20 def main():
21 try:
22 import reference
23 import shapes
24 import solution
25 except Exception as e:
26 print(f"FAIL: import error: {e}")
27 sys.exit(1)
28
29 problem_yaml = Path("problem.yaml")
30 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31
32 # --- Forbidden-op check ------------------------------------------------
33 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34 for forbidden in meta.get("forbidden", []):
35 pat = re.escape(forbidden)
36 if re.search(pat, sol_src):
37 print(f"FAIL: forbidden op used: {forbidden}")
38 sys.exit(1)
39
40 device = torch.device("cuda:0")
41 tol_override = meta.get("tolerance") or None
42
43 # --- Per-shape correctness --------------------------------------------
44 all_shapes = shapes.SHAPES
45 for shape_idx, shape in enumerate(all_shapes):
46
... (2647 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py"
}1 """Roofline benchmark for FP8 GEMM.
2
3 For each shape: times eager reference, compiled reference, SOTA (if available),
4 and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Output lines the harness picks up:
7 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9 """
10 import sys
11 from math import exp, log
12 from pathlib import Path
13
14 import torch
15 import yaml
16
17 REPO_ROOT = Path(__file__).resolve().parents[2]
18 sys.path.insert(0, str(REPO_ROOT))
19
20 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21 from src.eval.timing import time_fn # noqa: E402
22 from src.hardware import get as get_hw # noqa: E402
23
24
25 def _eval_formula(expr: str, vars: dict) -> float:
26 # Very small eval: only names from `vars` are valid.
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def main():
31 import reference
32 import shapes
33 import solution
34
35 meta = yaml.safe_load(Path("problem.yaml").read_text())
36 hw = get_hw(meta["hardware"][0])
37 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38 peak_gbps = hw.peak_bandwidth_gb_s
39 regime = meta.get("regime", "compute")
40 flops_formula = meta["flops_formula"]
41 bytes_formula = meta["bytes_formula"]
42 num_perf_trials = int(meta.get("num_perf_trials", 30))
43
44 device = to
... (3133 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}1 name: 07_w4a16_gemm
2 display_name: "W4A16 Weight-only Quantized GEMM"
3 precision: int4_bf16
4 regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5
6 # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7 flops_formula: "2 * M * N * K"
8
9 # Bytes moved per call (memory roofline):
10 # x: M*K*2 (bf16 activations, streamed in once)
11 # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12 # scales: (K/128)*N*2 (bf16 scales)
13 # zeros: (K/128)*N*2 (bf16 zero-points)
14 # out: M*N*2 (bf16 store)
15 bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16
17 hardware: [RTX_PRO_6000]
18 peak_tflops_key: bf16
19 peak_bandwidth_key: dram
20
21 tolerance:
22 bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23
24 # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25 # vendor library that does both.
26 forbidden:
27 - "bitsandbytes.functional.dequantize_4bit"
28 - "bitsandbytes.functional.gemv_4bit"
29 - "marlin_kernel.gemm"
30 - "torch.nn.functional.linear"
31
32 sota:
33 name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34 url: "https://github.com/TimDettmers/bitsandbytes"
35 function: "bitsandbytes.functional.gemv_4bit"
36 notes: |
37 Marlin (IST-DASLab) is the W4A16 SOTA on Ampere/Hopper but does not have
38 SM120 (Blackwell consumer) kernels yet. GPTQ-Triton is unmaintained and
... (554 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py"
}1 """Shape sweep for W4A16 GEMM.
2
3 Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4 case every inference engine optimizes -- it's the bar to beat.
5 """
6
7 SHAPES = [
8 {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9 {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10 {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11 {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12 {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13 ]
14 {
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}1 """SOTA reference for W4A16 GEMM.
2
3 Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4
5 - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6 - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7 but is not faster than Marlin on its target HW
8 and has no Blackwell tuning. Skip as primary.
9 - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10 Skip.
11 - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12 on this machine). Different quant scheme (NF4,
13 symmetric, blocksize 64) than our reference's
14 AWQ-style asymmetric INT4 with group_size 128,
15 but it occupies the same memory regime and is
16 the only tuned W4A16-class kernel that runs on
17 SM120 today. Used here as an *informational*
18 SOTA line, not as a numerical reference.
19
20 The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21 NOT checked against this path (the quant scheme differs).
22 """
23 from __future__ import annotations
24
25 import torch
26
27 _BNB_OK: bool | None = None
28
29
30 def is_available() -> bool:
31 gl
... (2489 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/PROMPT.txt"
}1 I need you to write a weight-only int4 quantized GEMM (W4A16) for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 The scheme is AWQ/GPTQ-style asymmetric int4 with explicit zero-points and per-group bf16 scales. Inputs are x of shape (M, K) bf16, w_q of shape (K // 2, N) uint8 (two int4 weights packed per byte, low nibble = even-K row, high nibble = odd-K row), scales of shape (K // 128, N) bf16, and zeros of shape (K // 128, N) bf16. Group size is 128 along K. Dequant per group is w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n], and the output is (M, N) bf16. Correctness tolerance is 0.10 abs/rel — group-quant adds noise on top of bf16 accumulator slop. The shapes you have to handle are M=1 N=12288 K=4096 (decode, memory-bound on the int4 weight read), M=32 N=12288 K=4096 (small prefill, mixed regime), M=256 N=12288 K=4096 (larger prefill, approaching compute-bound), M=1 N=4096 K=4096 (decode, square), and M=16 N=14336 K=4096 (speculative-decode-ish).
4
5 This needs to be a real custom kernel that fuses unpack and GEMM in the same pass — a separate dequant-then-matmul wastes the entire bandwidth advan
... (923 more chars){
"command": "nvidia-smi",
"description": "Check GPU availability and load"
}Tue Apr 28 08:39:53 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 39C P8 15W / 600W | 733MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (740 more chars){
"query": "select:TodoWrite",
"max_results": 1
}{
"todos": [
{
"content": "Explore existing scaffolding (src/eval, prior solutions in outputs/runs)",
"activeForm": "Exploring existing scaffolding",
"status": "in_progress"
},
{
"content": "Implement initial Triton W4A16 kernel (fused unpack+GEMM)",
"activeForm": "Implementing initial Triton W4A16 kernel",
"status": "pending"
},
{
"content": "Verify correctness with check.py",
"activeForm": "Verifying correctness",
"status": "pending"
},
{
"content": "Profile and benchmark with benchmark.py",
"acti
... (220 more chars)Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable{
"command": "ls outputs/runs/ 2>/dev/null | head -30",
"description": "List previous runs"
}20260427_023632_opencode_deepseek_deepseek-v4-flash_05_topk_bitonic
20260427_032119_opencode_deepseek_deepseek-v4-flash_05_topk_bitonic
20260427_044055_opencode_deepseek_deepseek-v4-flash_01_fp8_gemm
20260427_044258_opencode_deepseek_deepseek-v4-flash_01_fp8_gemm
20260427_051426_opencode_deepseek_deepseek-v4-pro_01_fp8_gemm
20260427_055936_opencode_openrouter-pinned_minimax_minimax-m2.7_01_fp8_gemm
20260427_060216_opencode_openrouter-pinned_qwen_qwen3.6-27b_01_fp8_gemm
20260427_060329_opencode_openrouter-pinned_qwen_qwen3.6-35b-a3b_01_fp8_gemm
20260427_060330_opencode_deepseek_deepseek-v4-flash_02_kda_cutlass
20260427_065343_opencode_deepseek_deepseek-v4-pro_02_kda_cutlass
20260427_073847_opencode_openrouter-pinned_minimax_minimax-m2.7_02_kda_cutlass
20260427_082347_opencode_openrouter-pinned_qwen_qwen3.6-27b_02_kda_cutlass
20260427_082405_opencode_openrouter-pinned_qwen_qwen3.6-35b-a3b_02_kda_cutlass
20260427_082406_opencode_deepseek_deepseek-v4-flash_03_paged_attention
20260427_085225_opencode_deepseek_deepseek-v4-pro_03_paged_attention
20260427_092504_opencode_openrouter-pinned_minimax_minimax-m2.7_03_paged_attention
20260427_093353_opencode_openrouter-pinned_qwen_qwen3.6-27b_03_paged_attention
20260427_094003_opencode_openrouter-pinned_qwen_qwen3.6-35b-a3b_03_paged_attention
20260427_094004_opencode_deepseek_deepseek-v4-flash_04_kahan_softmax
20260427_095552_opencode_deepseek_deepseek-v4-pro_04_kahan_softmax
20260427_102645_opencode_openrouter-pinned_minimax_minimax-m2.7_
... (686 more chars){
"command": "ls outputs/runs/ | grep w4a16 | head -10",
"description": "Look at w4a16 prior runs"
}20260427_140540_opencode_deepseek_deepseek-v4-flash_07_w4a16_gemm
20260427_142403_opencode_deepseek_deepseek-v4-pro_07_w4a16_gemm
20260427_150902_opencode_openrouter-pinned_minimax_minimax-m2.7_07_w4a16_gemm
20260427_153033_opencode_openrouter-pinned_qwen_qwen3.6-27b_07_w4a16_gemm
20260427_153040_opencode_openrouter-pinned_qwen_qwen3.6-35b-a3b_07_w4a16_gemm
20260427_205128_opencode_openrouter-pinned_qwen_qwen3.6-27b_07_w4a16_gemm
20260428_010912_codex_gpt-5.5_07_w4a16_gemm
20260428_015550_codex_gpt-5.5_07_w4a16_gemm
20260428_083917_claude_claude-opus-4-7_07_w4a16_gemm{
"command": "ls -la problems/07_w4a16_gemm/",
"description": "Check for existing solution.py"
}drwxrwxr-x infatoshi infatoshi 4.0 KB Tue Apr 28 02:10:12 2026 .
drwxrwxr-x infatoshi infatoshi 4.0 KB Fri Apr 24 00:30:02 2026 ..
.rw-rw-r-- infatoshi infatoshi 4.1 KB Sun Apr 26 03:11:43 2026 benchmark.py
.rw-rw-r-- infatoshi infatoshi 3.7 KB Sat Apr 25 22:23:16 2026 check.py
.rw-rw-r-- infatoshi infatoshi 1.9 KB Sat Apr 25 22:22:46 2026 problem.yaml
.rw-rw-r-- infatoshi infatoshi 2.4 KB Mon Apr 27 03:21:13 2026 PROMPT.txt
.rw-rw-r-- infatoshi infatoshi 4.1 KB Sat Apr 25 22:25:27 2026 reference.py
.rw-rw-r-- infatoshi infatoshi 552 B Sat Apr 25 22:22:28 2026 shapes.py
.rw-rw-r-- infatoshi infatoshi 3.6 KB Sat Apr 25 22:24:51 2026 sota.py{
"command": "ls outputs/runs/20260428_083917_claude_claude-opus-4-7_07_w4a16_gemm/ 2>/dev/null",
"description": "Check if there's an in-progress run"
}stderr.log
transcript.jsonl{
"command": "cat outputs/runs/20260428_083917_claude_claude-opus-4-7_07_w4a16_gemm/stderr.log 2>/dev/null | head -20",
"description": "Check stderr log"
}(Bash completed with no output){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}1 """Per-dtype correctness tolerance.
2
3 Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4 We extend to fp8 (0.1, permissive) and int (exact).
5 """
6 from __future__ import annotations
7
8 import torch
9
10 DEFAULT_TOLERANCE = {
11 torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12 torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13 torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14 torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15 torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16 torch.int8: {"atol": 0, "rtol": 0},
17 torch.int32: {"atol": 0, "rtol": 0},
18 torch.int64: {"atol": 0, "rtol": 0},
19 }
20
21
22 def _coerce_float(x):
23 """Accept int/float/str scientific-notation as float; pass through dicts."""
24 if isinstance(x, (int, float)):
25 return float(x)
26 if isinstance(x, str):
27 try:
28 return float(x)
29 except ValueError:
30 return x
31 return x
32
33
34 def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35 """Lookup atol/rtol for a given dtype, with optional per-problem override.
36
37 The override dict is keyed by str(dtype). Values may be:
38 - a single number (or a string parseable as float, e.g. "5e-2") -> used
39 as both atol and rtol
40 - a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
41 """
42 # Accept several spellings of the same dty
... (2345 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}1 """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2
3 FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4 dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5 exit) are still measured against the dense cost.
6 """
7 from __future__ import annotations
8
9
10 def compute_tflops(flops: float, time_ms: float) -> float:
11 """flops per invocation and wall time in ms -> TFLOPS."""
12 if time_ms <= 0:
13 return 0.0
14 return flops / (time_ms * 1e-3) / 1e12
15
16
17 def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18 """bytes per invocation and wall time in ms -> GB/s."""
19 if time_ms <= 0:
20 return 0.0
21 return bytes_moved / (time_ms * 1e-3) / 1e9
22
23
24 def peak_fraction(
25 achieved: float,
26 peak: float,
27 ) -> float:
28 """Return achieved / peak, clamped to [0, infinity)."""
29 if peak <= 0:
30 return 0.0
31 return max(0.0, achieved / peak)
32 {
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}1 """Centralized GPU timing for benchmark.py files.
2
3 All problems' benchmark.py should call `time_fn` from here so we have one
4 implementation of warmup, L2 flush, and CUDA event capture to maintain.
5
6 Methodology:
7 - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8 torch.compile reduce-overhead CUDA-graph capture.
9 - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10 (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11 we measure HBM-load bandwidth, not L2-cached bandwidth).
12 - GPU timing via cuda.Event with synchronize() AFTER record() but
13 BEFORE elapsed_time() — the canonical NVIDIA pattern.
14 - Reported value is the median of `iters` trials, robust to outliers.
15
16 Notes / known biases not addressed here:
17 - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18 launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19 On small shapes where launch overhead matters, this gives compile an
20 artificial advantage. We accept this as the cost of using the agreed
21 "compiled" baseline policy.
22 - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23 - Median over a fairly small number of trials (default 30) is fine for
24 headline numbers but won't catch bimodal distributions. Use --extra
25 instrumentation if you ever care.
26 """
27 from __future__ import annotations
28
29 import statistics
30
31 i
... (1271 more chars){
"command": "cat src/hardware/rtx_pro_6000.py",
"description": "Check hardware specs"
}"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
    """Static description of one GPU target, used to compute peak-fraction metrics.

    Frozen so instances can be shared as module-level constants (e.g.
    RTX_PRO_6000) without accidental mutation.
    """

    # Human-readable device name (marketing name).
    name: str
    # CUDA arch string as passed to the compiler, e.g. "sm_120a".
    sm: str
    # Device memory capacity.  NOTE(review): assumed to be GiB vs GB — confirm.
    vram_gb: int
    peak_bandwidth_gb_s: float  # DRAM peak bandwidth in GB/s
    peak_tflops_dense: dict[str, float]  # dtype -> advertised dense TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
"fp4": 800.0,
"nvfp4": 800.0,
"mxfp4": 800.0,
"fp6": 800.0,
"fp8": 400.0,
"bf16": 200.0,
"fp16": 200.0,
"tf32": 100.0,
"fp32": 12.0, # non-tensor-core SIMT fp32
"int8": 400.0,
"int4": 800.0,
},
){
"command": "cd problems/07_w4a16_gemm && uv run python -c \"import torch, triton; print('torch', torch.__version__); print('triton', triton.__version__); print('cuda', torch.version.cuda); print('device', torch.cuda.get_device_name(0)); print('compute cap', torch.cuda.get_device_capability(0))\"",
"description": "Check torch and triton versions"
}torch 2.11.0+cu130
triton 3.6.0
cuda 13.0
device NVIDIA RTX PRO 6000 Blackwell Workstation Edition
compute cap (12, 0){
"todos": [
{
"content": "Explore existing scaffolding (src/eval, prior solutions in outputs/runs)",
"activeForm": "Exploring existing scaffolding",
"status": "completed"
},
{
"content": "Implement initial Triton W4A16 kernel (fused unpack+GEMM)",
"activeForm": "Implementing initial Triton W4A16 kernel",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"activeForm": "Verifying correctness",
"status": "pending"
},
{
"content": "Profile and benchmark with benchmark.py",
"ac
... (222 more chars)Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable{
"command": "<448 chars \u2014 see below>",
"description": "Check Triton APIs"
}cd problems/07_w4a16_gemm && uv run python -c "
import triton
import triton.language as tl
print('triton', triton.__version__)
# Check available API
attrs = [a for a in dir(tl) if not a.startswith('_')]
print('has join:', 'join' in attrs)
print('has split:', 'split' in attrs)
print('has interleave:', 'interleave' in attrs)
print('has permute:', 'permute' in attrs)
print('has reshape:', 'reshape' in attrs)
print('has trans:', 'trans' in attrs)
"Exit code 1
(eval):cd:1: no such file or directory: problems/07_w4a16_gemm{
"command": "pwd && ls",
"description": "Check pwd"
}/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py{
"command": "<431 chars \u2014 see below>",
"description": "Check Triton APIs"
}uv run python -c "
import triton
import triton.language as tl
print('triton', triton.__version__)
attrs = [a for a in dir(tl) if not a.startswith('_')]
print('has join:', 'join' in attrs)
print('has split:', 'split' in attrs)
print('has interleave:', 'interleave' in attrs)
print('has permute:', 'permute' in attrs)
print('has reshape:', 'reshape' in attrs)
print('has trans:', 'trans' in attrs)
print('has dot:', 'dot' in attrs)
"triton 3.6.0
has join: True
has split: True
has interleave: True
has permute: True
has reshape: True
has trans: True
has dot: True--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,172 @@
+"""W4A16 weight-only quantized GEMM via Triton.
+
+Fused unpack + GEMM in one pass: int4 weights are streamed as packed bytes
+(2 nibbles per byte), unpacked, dequantized per group, and matmul'd against
+bf16 activations without ever materialising a full bf16 weight matrix.
+
+Layout matches reference.py:
+ x: (M, K) bf16
+ w_q: (K // 2, N) uint8 (low nibble = even-K row, high = odd-K row)
+ scales: (K // g, N) bf16
+ zeros: (K // g, N) bf16
+ out: (M, N) bf16
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def w4a16_gemm_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ n_iters = K // BLOCK_K
+
+ for k_iter in range(n_iters):
+ kbase_packed = k_iter * BLOCK_KH
+ kbase_full = k_iter * BLOCK_K
+
+ # Load packed weights: (BLOCK_KH, BLOCK_N) uint8, contiguous along N.
+ w_ptrs = (
+ w_ptr
+ + (kbase_packed + offs_kh)[:, None] * stride_wk
+ + offs_n[None, :] * stride_wn
+ )
+ w_packed = tl.load(w_ptrs)
+
+ # Unpack: low nibble = even-K row, high nibble = odd-K row.
+ w_lo = (w_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+
+ # Group-wise scale/zero. BLOCK_K == GROUP_SIZE so one (s, z) per tile-N column.
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
+
+ w_lo_bf = (w_lo - z[None, :]) * s[None, :]
+ w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+
+ # Interleave: row 2i = w_lo[i], row 2i+1 = w_hi[i].
+ w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
+ w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
+ w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N)) # (BLOCK_K, BLOCK_N)
+
+ # Load activation tile.
+ x_ptrs = (
+ x_ptr
+ + offs_m[:, None] * stride_xm
+ + (kbase_full + offs_k)[None, :] * stride_xk
+ )
+ x_mask = offs_m[:, None] < M
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ accumulator = tl.dot(x, w, accumulator)
+
+ out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_mask = offs_m[:, None] < M
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=out_mask)
+
+
+def w4a16_gemm(
+ x: torch.Tensor,
+ w_q: torch.Tensor,
+ scales: torch.Tensor,
+ zeros: torch.Tensor,
+ group_size: int = GROUP_SIZE,
+) -> torch.Tensor:
+ M, K = x.shape
+ K_packed, N = w_q.shape
+ assert K_packed * 2 == K
+ assert K % group_size == 0
+
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+
+ if M <= 16:
+ BLOCK_M = 16
+ elif M <= 32:
+ BLOCK_M = 32
+ elif M <= 64:
+ BLOCK_M = 64
+ else:
+ BLOCK_M = 128
+
+ BLOCK_N = 128
+ BLOCK_K = group_size
+
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+
+ w4a16_gemm_kernel[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ GROUP_SIZE=group_size,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ num_warps=4,
+ num_stages=4,
+ )
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, M, N, K, group_size=GROUP_SIZE):
+ super().__init__()
+ assert K % group_size == 0
+ assert K % 2 == 0
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+ self.register_buffer("w_q", torch.zeros((K // 2, N), dtype=torch.uint8))
+ self.register_buffer("scales", torch.zeros((n_groups, N), dtype=torch.bfloat16))
+ self.register_buffer("zeros", torch.zeros((n_groups, N), dtype=torch.bfloat16))
+
+ def forward(self, x):
+ return w4a16_gemm(
+ x.to(torch.bfloat16), self.w_q, self.scales, self.zeros, self.group_size
+ )
+
+
+M = 1
+N = 12288
+K = 4096
+
+
+def get_inputs():
+ x = torch.randn(M, K, dtype=torch.bfloat16)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<4977 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -30",
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 156, in forward
return w4a16_gemm(
^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 125, in w4a16_gemm
w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launc
... (812 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -135,7 +135,7 @@
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
num_warps=4,
- num_stages=4,
+ num_stages=2,
)
return out
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<42 chars \u2014 see diff>",
"new_string": "<42 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -20",
"description": "Run correctness check with stages=2"
}PASS{
"command": "uv run python benchmark.py 2>&1 | tail -50",
"description": "Run benchmark"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.906 ms=0.767
shape=0 variant=compiled tflops=0.637 gbps=169.405 ms=0.158
shape=0 variant=sota tflops=1.848 gbps=491.400 ms=0.054
shape=0 variant=solution tflops=0.996 gbps=264.875 ms=0.101
shape=0 solution_peak_fraction=0.1472
shape=1 variant=eager tflops=4.132 gbps=35.644 ms=0.780
shape=1 variant=compiled tflops=19.977 gbps=172.326 ms=0.161
shape=1 variant=sota tflops=22.962 gbps=198.073 ms=0.140
shape=1 variant=solution tflops=26.819 gbps=231.345 ms=0.120
shape=1 solution_peak_fraction=0.1285
shape=2 variant=eager tflops=31.246 gbps=42.591 ms=0.825
shape=2 variant=compiled tflops=118.900 gbps=162.074 ms=0.217
shape=2 variant=sota tflops=132.093 gbps=180.059 ms=0.195
shape=2 variant=solution tflops=75.844 gbps=103.384 ms=0.340
shape=2 solution_peak_fraction=0.0574
shape=3 variant=eager tflops=0.177 gbps=47.223 ms=0.189
shape=3 variant=compiled tflops=0.505 gbps=134.380 ms=0
... (464 more chars){
"command": "<338 chars \u2014 see below>",
"description": "Check SM properties"
}uv run python -c "
import torch
from torch.cuda import get_device_properties
p = get_device_properties(0)
print('SMs:', p.multi_processor_count)
print('SMEM per block:', p.shared_memory_per_block)
print('SMEM per SM:', p.shared_memory_per_multiprocessor)
print('regs per SM:', p.regs_per_multiprocessor)
print('warp size:', p.warp_size)
"SMs: 188
SMEM per block: 49152
SMEM per SM: 102400
regs per SM: 65536
warp size: 32--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -10,6 +10,10 @@
scales: (K // g, N) bf16
zeros: (K // g, N) bf16
out: (M, N) bf16
+
+Two kernels:
+ - gemv: M=1 SIMT path, no tensor cores, maximises bandwidth on the int4 stream.
+ - gemm: M>=16 tensor-core path with hand-tuned tile sizes per M regime.
"""
from __future__ import annotations
@@ -21,6 +25,9 @@
GROUP_SIZE = 128
+# ---------------------------------------------------------------------------
+# GEMM kernel (M >= 16): tensor-core path
+# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
@@ -65,7 +72,7 @@
w_lo = (w_packed & 0xF).to(tl.bfloat16)
w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Group-wise scale/zero. BLOCK_K == GROUP_SIZE so one (s, z) per tile-N column.
+ # Group-wise scale/zero. BLOCK_K == GROUP_SIZE => one (s, z) per N column.
group_idx = kbase_full // GROUP_SIZE
s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
@@ -73,12 +80,11 @@
w_lo_bf = (w_lo - z[None, :]) * s[None, :]
w_hi_bf = (w_hi - z[None, :]) * s[None, :]
- # Interleave: row 2i = w_lo[i], row 2i+1 = w_hi[i].
+ # Interleave (row 2i = w_lo[i], row 2i+1 = w_hi[i]) -> (BLOCK_K, BLOCK_N).
w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N)) # (BLOCK_K, BLOCK_N)
- # Load activation tile.
x_ptrs = (
x_ptr
+ offs_m[:, None] * stride_xm
@@ -94,6 +100,69 @@
tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=out_mask)
+# ---------------------------------------------------------------------------
+# GEMV kernel (M == 1): SIMT path, no tensor cores. Memory-bound on int4 stream.
+# ---------------------------------------------------------------------------
+@triton.jit
+def w4a16_gemv_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+
+ n_iters = K // BLOCK_K
+
+ for k_iter in range(n_iters):
+ kbase_packed = k_iter * BLOCK_KH
+ kbase_full = k_iter * BLOCK_K
+
+ # Load packed weights: (BLOCK_KH, BLOCK_N) uint8, contiguous along N.
+ w_ptrs = (
+ w_ptr
+ + (kbase_packed + offs_kh)[:, None] * stride_wk
+ + offs_n[None, :] * stride_wn
+ )
+ w_packed = tl.load(w_ptrs)
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
+
+ # Load x slice as even/odd halves.
+ x_offs_even = kbase_full + 2 * offs_kh
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
+ x_odd = tl.load(x_ptr + x_offs_odd * stride_xk).to(tl.float32)
+
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
+
+ w_lo_f = (w_lo - z[None, :]) * s[None, :]
+ w_hi_f = (w_hi - z[None, :]) * s[None, :]
+
+ accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
+ accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
+
+ out_ptrs = out_ptr + offs_n * stride_on
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16))
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
def w4a16_gemm(
x: torch.Tensor,
w_q: torch.Tensor,
@@ -108,17 +177,52 @@
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ if M == 1:
+ # GEMV path
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ grid = (triton.cdiv(N, BLOCK_N),)
+ w4a16_gemv_kernel[grid](
+ x, w_q, scales, zeros, out,
+ N, K,
+ x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(1),
+ GROUP_SIZE=group_size,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ num_warps=4,
+ num_stages=4,
+ )
+ return out
+
+ # GEMM path
if M <= 16:
BLOCK_M = 16
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ num_warps = 4
+ num_stages = 2
elif M <= 32:
BLOCK_M = 32
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ num_warps = 4
+ num_stages = 2
elif M <= 64:
BLOCK_M = 64
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ num_warps = 4
+ num_stages = 2
else:
BLOCK_M = 128
-
- BLOCK_N = 128
- BLOCK_K = group_size
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ num_warps = 8
+ num_stages = 2
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
@@ -134,8 +238,8 @@
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
- num_warps=4,
- num_stages=2,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
return out
... (2 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<8522 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Verify correctness"
}PASS{
"command": "uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.131 gbps=34.892 ms=0.767
shape=0 variant=compiled tflops=0.640 gbps=170.181 ms=0.157
shape=0 variant=sota tflops=2.263 gbps=601.876 ms=0.044
shape=0 variant=solution tflops=1.053 gbps=280.083 ms=0.096
shape=0 solution_peak_fraction=0.1556
shape=1 variant=eager tflops=4.141 gbps=35.718 ms=0.778
shape=1 variant=compiled tflops=20.209 gbps=174.333 ms=0.159
shape=1 variant=sota tflops=22.831 gbps=196.950 ms=0.141
shape=1 variant=solution tflops=26.676 gbps=230.118 ms=0.121
shape=1 solution_peak_fraction=0.1278
shape=2 variant=eager tflops=31.221 gbps=42.558 ms=0.825
shape=2 variant=compiled tflops=118.471 gbps=161.490 ms=0.218
shape=2 variant=sota tflops=131.157 gbps=178.783 ms=0.196
shape=2 variant=solution tflops=95.196 gbps=129.763 ms=0.271
shape=2 solution_peak_fraction=0.0721
shape=3 variant=eager tflops=0.175 gbps=46.619 ms=0.192
shape=3 variant=compiled tflops=0.505 gbps=134.510 ms=0.066
shape=3 variant=sota tflops=1.035 gbps=275.459 ms=0.032
shape=3 variant=solution tflops=0.386 gbps=102.588 ms=0.087
shape=3 solution_peak_fraction=0.0570
shape=4 variant=eager tflops=2.042 gbps=34.546 ms=0.920
shape=4 variant=compiled tflops=10.280 gbps=173.894 ms=0.183
shape=4 variant=sota tflops=11.246 gbps=190.229 ms=0.167
shape=4 variant=solution tflops=17.466 gbps=295.443 ms=0.108
shape=4 solution_peak_fraction=0.1641
peak_fraction: 0.1061
RESULT: OK--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -12,7 +12,7 @@
out: (M, N) bf16
Two kernels:
- - gemv: M=1 SIMT path, no tensor cores, maximises bandwidth on the int4 stream.
+ - gemv: M=1 SIMT path with split-K to saturate SMs on memory-bound decodes.
- gemm: M>=16 tensor-core path with hand-tuned tile sizes per M regime.
"""
from __future__ import annotations
@@ -23,6 +23,7 @@
import triton.language as tl
GROUP_SIZE = 128
+SM_COUNT = 188 # RTX PRO 6000 Blackwell
# ---------------------------------------------------------------------------
@@ -60,7 +61,6 @@
kbase_packed = k_iter * BLOCK_KH
kbase_full = k_iter * BLOCK_K
- # Load packed weights: (BLOCK_KH, BLOCK_N) uint8, contiguous along N.
w_ptrs = (
w_ptr
+ (kbase_packed + offs_kh)[:, None] * stride_wk
@@ -68,11 +68,9 @@
)
w_packed = tl.load(w_ptrs)
- # Unpack: low nibble = even-K row, high nibble = odd-K row.
w_lo = (w_packed & 0xF).to(tl.bfloat16)
w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Group-wise scale/zero. BLOCK_K == GROUP_SIZE => one (s, z) per N column.
group_idx = kbase_full // GROUP_SIZE
s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
@@ -80,7 +78,6 @@
w_lo_bf = (w_lo - z[None, :]) * s[None, :]
w_hi_bf = (w_hi - z[None, :]) * s[None, :]
- # Interleave (row 2i = w_lo[i], row 2i+1 = w_hi[i]) -> (BLOCK_K, BLOCK_N).
w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N)) # (BLOCK_K, BLOCK_N)
@@ -101,7 +98,7 @@
# ---------------------------------------------------------------------------
-# GEMV kernel (M == 1): SIMT path, no tensor cores. Memory-bound on int4 stream.
+# GEMV kernel (M == 1) with split-K. Writes fp32 partial sums of shape (SPLIT_K, N).
# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemv_kernel(
@@ -111,12 +108,14 @@
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
- stride_on,
+ stride_obs, stride_on,
GROUP_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
+ K_PER_SPLIT: tl.constexpr,
):
pid_n = tl.program_id(0)
+ pid_k = tl.program_id(1)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32)
@@ -124,13 +123,13 @@
BLOCK_KH: tl.constexpr = BLOCK_K // 2
offs_kh = tl.arange(0, BLOCK_KH)
- n_iters = K // BLOCK_K
+ k_start_full = pid_k * K_PER_SPLIT
+ n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
for k_iter in range(n_iters):
- kbase_packed = k_iter * BLOCK_KH
- kbase_full = k_iter * BLOCK_K
-
- # Load packed weights: (BLOCK_KH, BLOCK_N) uint8, contiguous along N.
+ kbase_full = k_start_full + k_iter * BLOCK_K
+ kbase_packed = kbase_full // 2
+
w_ptrs = (
w_ptr
+ (kbase_packed + offs_kh)[:, None] * stride_wk
@@ -140,11 +139,9 @@
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
- # Load x slice as even/odd halves.
x_offs_even = kbase_full + 2 * offs_kh
- x_offs_odd = x_offs_even + 1
x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
- x_odd = tl.load(x_ptr + x_offs_odd * stride_xk).to(tl.float32)
+ x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
group_idx = kbase_full // GROUP_SIZE
s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
@@ -156,13 +153,26 @@
accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
- out_ptrs = out_ptr + offs_n * stride_on
- tl.store(out_ptrs, accumulator.to(tl.bfloat16))
+ out_ptrs = out_ptr + pid_k * stride_obs + offs_n * stride_on
+ tl.store(out_ptrs, accumulator)
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
+def _choose_split_k(K: int, n_tiles: int, block_k: int) -> int:
+ n_k_iters = K // block_k
+ candidates = [s for s in [1, 2, 4, 8, 16, 32] if n_k_iters % s == 0]
+ target = SM_COUNT
+ chosen = 1
+ for s in candidates:
+ if n_tiles * s >= target:
+ chosen = s
+ break
+ chosen = s
+ return chosen
+
+
def w4a16_gemm(
x: torch.Tensor,
w_q: torch.Tensor,
@@ -175,30 +185,36 @@
assert K_packed * 2 == K
assert K % group_size == 0
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
-
if M == 1:
- # GEMV path
- BLOCK_N = 128
- BLOCK_K = group_size
- grid = (triton.cdiv(N, BLOCK_N),)
+ BLOCK_N = 128
+ BLOCK_K = group_size
+ n_tiles = triton.cdiv(N, BLOCK_N)
+ SPLIT_K = _choose_split_k(K, n_tiles, BLOCK_K)
+ K_PER_SPLIT = K // SPLIT_K
+
+ partials = torch.empty((SPLIT_K, N), dtype=torch.float32, device=x.device)
+ grid = (n_tiles, SPLIT_K)
w4a16_gemv_kernel[grid](
- x, w_q, scales, zeros, out,
+ x, w_q, scales, zeros, partials,
N, K,
x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
- out.stride(1),
+ partials.stride(0), partials.stride(1),
GROUP_SIZE=group_size,
... (477 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<8963 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -10",
"description": "Verify correctness"
}PASS{
"command": "uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Run benchmark with split-K GEMV"
}shape=0 variant=eager tflops=0.131 gbps=34.854 ms=0.768
shape=0 variant=compiled tflops=0.636 gbps=169.251 ms=0.158
shape=0 variant=sota tflops=2.191 gbps=582.596 ms=0.046
shape=0 variant=solution tflops=1.238 gbps=329.114 ms=0.081
shape=0 solution_peak_fraction=0.1828
shape=1 variant=eager tflops=4.130 gbps=35.630 ms=0.780
shape=1 variant=compiled tflops=19.767 gbps=170.516 ms=0.163
shape=1 variant=sota tflops=22.639 gbps=195.289 ms=0.142
shape=1 variant=solution tflops=26.273 gbps=226.635 ms=0.123
shape=1 solution_peak_fraction=0.1259
shape=2 variant=eager tflops=31.140 gbps=42.447 ms=0.828
shape=2 variant=compiled tflops=117.366 gbps=159.984 ms=0.220
shape=2 variant=sota tflops=132.115 gbps=180.088 ms=0.195
shape=2 variant=solution tflops=95.027 gbps=129.533 ms=0.271
shape=2 solution_peak_fraction=0.0720
shape=3 variant=eager tflops=0.178 gbps=47.492 ms=0.188
shape=3 variant=compiled tflops=0.464 gbps=123.414 ms=0.072
shape=3 variant=sota tflops=1.259 gbps=334.982 ms=0.027
shape=3 variant=solution tflops=0.847 gbps=225.487 ms=0.040
shape=3 solution_peak_fraction=0.1253
shape=4 variant=eager tflops=2.049 gbps=34.668 ms=0.917
shape=4 variant=compiled tflops=10.180 gbps=172.205 ms=0.185
shape=4 variant=sota tflops=11.008 gbps=186.199 ms=0.171
shape=4 variant=solution tflops=17.035 gbps=288.158 ms=0.110
shape=4 solution_peak_fraction=0.1601
peak_fraction: 0.1271
RESULT: OK{
"command": "<753 chars \u2014 see below>",
"description": "Quick timing of M=1 case"
}uv run python -c "
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
# warmup
for _ in range(20): sol(x)
torch.cuda.synchronize()
import time
# Pure timing without L2 flush
t0 = time.perf_counter()
n = 1000
for _ in range(n):
sol(x)
torch.cuda.synchronize()
t1 = time.perf_counter()
ms = (t1-t0)/n*1000
print(f'avg ms: {ms:.4f}')
# bytes_formula
bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
print(f'bytes: {bytes_moved/1e6:.1f}MB')
print(f'gbps: {bytes_moved/(ms*1e-3)/1e9:.1f}')
"avg ms: 0.0463
bytes: 26.8MB
gbps: 577.6{
"todos": [
{
"content": "Explore existing scaffolding (src/eval, prior solutions in outputs/runs)",
"activeForm": "Exploring existing scaffolding",
"status": "completed"
},
{
"content": "Implement initial Triton W4A16 kernel (fused unpack+GEMM)",
"activeForm": "Implementing initial Triton W4A16 kernel",
"status": "completed"
},
{
"content": "Verify correctness with check.py",
"activeForm": "Verifying correctness",
"status": "completed"
},
{
"content": "Profile and benchmark with benchmark.py",
"ac
... (365 more chars)Todos have been modified successfully. Ensure that you continue to use the todo list to track your progress. Please proceed with the current tasks if applicable--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune.py
@@ -0,0 +1,61 @@
+"""Quick tuner: sweep configs for each shape with L2 flush."""
+import statistics
+import time
+
+import torch
+
+import reference
+import shapes as shapes_mod
+import solution
+
+
def time_fn(fn, inputs, iters=30, warmup=10, flush=True):
    """Median wall time in ms of `fn(*inputs)` over `iters` CUDA-event trials.

    Runs `warmup` untimed calls first; when `flush` is set, a 128 MB scratch
    buffer is zeroed before every timed call to evict L2 so HBM bandwidth is
    measured rather than cache bandwidth.
    """
    if flush:
        # 32M fp32 elements == 128 MB; zeroing it between trials evicts L2.
        l2_scratch = torch.empty(
            128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda"
        )
    for _ in range(warmup):
        fn(*inputs)
    torch.cuda.synchronize()
    samples = []
    for _ in range(iters):
        if flush:
            l2_scratch.zero_()
        torch.cuda.synchronize()
        start_evt = torch.cuda.Event(enable_timing=True)
        stop_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        fn(*inputs)
        stop_evt.record()
        # Synchronize after record() but before elapsed_time() — the
        # canonical CUDA event timing pattern.
        torch.cuda.synchronize()
        samples.append(start_evt.elapsed_time(stop_evt))
    return statistics.median(samples)
+
+
def bytes_for(M, N, K, group_size=128):
    """Total HBM bytes moved by one W4A16 GEMM call of shape (M, N, K).

    Counts the bf16 activations (M*K*2 bytes), the packed int4 weights
    (two nibbles per byte, so K//2 * N bytes), the bf16 scales and zeros
    ((K // group_size) * N * 2 bytes each), and the bf16 output (M*N*2).
    `group_size` generalizes the previously hard-coded 128 and defaults to it.
    """
    n_groups = K // group_size
    return (
        M * K * 2            # x: bf16 activations
        + (K // 2) * N       # w_q: packed int4, 2 values per byte
        + n_groups * N * 2   # scales: bf16 per (group, N)
        + n_groups * N * 2   # zeros: bf16 per (group, N)
        + M * N * 2          # out: bf16
    )
+
+
def main():
    """Benchmark the solution against every shape in shapes.py with L2 flush."""
    dev = torch.device("cuda:0")

    for idx, sh in enumerate(shapes_mod.SHAPES):
        M, N, K = sh["M"], sh["N"], sh["K"]
        # reference.get_init_inputs() reads these module-level globals.
        reference.M = M
        reference.N = N
        reference.K = K

        init_args = reference.get_init_inputs()
        ref_model = reference.Model(*init_args).to(dev).eval()
        sol_model = solution.Model(*init_args).to(dev).eval()
        sol_model.load_state_dict(ref_model.state_dict(), strict=True)

        torch.manual_seed(42)
        x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)

        ms = time_fn(lambda inp: sol_model(inp), (x,))
        moved = bytes_for(M, N, K)
        gbps = moved / (ms * 1e-3) / 1e9
        tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
        print(f"shape={idx} M={M} N={N} K={K} ms={ms:.4f} gbps={gbps:.1f} tflops={tflops:.2f}")
+
+
+if __name__ == "__main__":
+ main()
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune.py",
"content": "<1763 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune.py (file state is current in your context — no need to Read it back){
"command": "uv run python _tune.py 2>&1 | tail -10",
"description": "Quick tuner baseline"
}shape=0 M=1 N=12288 K=4096 ms=0.0684 gbps=391.3 tflops=1.47
shape=1 M=32 N=12288 K=4096 ms=0.1194 gbps=232.7 tflops=26.98
shape=2 M=256 N=12288 K=4096 ms=0.2706 gbps=129.8 tflops=95.22
shape=3 M=1 N=4096 K=4096 ms=0.0437 gbps=204.1 tflops=0.77
shape=4 M=16 N=14336 K=4096 ms=0.1062 gbps=299.3 tflops=17.69--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py
@@ -0,0 +1,189 @@
+"""Sweep GEMV configs for M=1 shapes."""
+import statistics
+
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+SM_COUNT = 188
+
+
+@triton.jit
+def gemv_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_obs, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ K_PER_SPLIT: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ pid_k = tl.program_id(1)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+
+ k_start_full = pid_k * K_PER_SPLIT
+ n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
+
+ for k_iter in range(n_iters):
+ kbase_full = k_start_full + k_iter * BLOCK_K
+ kbase_packed = kbase_full // 2
+
+ w_ptrs = (
+ w_ptr
+ + (kbase_packed + offs_kh)[:, None] * stride_wk
+ + offs_n[None, :] * stride_wn
+ )
+ w_packed = tl.load(w_ptrs)
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
+
+ x_offs_even = kbase_full + 2 * offs_kh
+ x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
+ x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
+
+ # Per-group scale/zero. Each row of (BLOCK_KH,) maps to a group.
+ # Compute group index for even-K rows (k_full = kbase_full + 2*kh).
+ # group = (kbase_full + 2*kh) // GROUP_SIZE. For even and odd, group is the same
+ # (both 2*kh and 2*kh+1 fall in the same group when GROUP_SIZE >= 2).
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
+ w_lo_f = (w_lo - z[None, :]) * s[None, :]
+ w_hi_f = (w_hi - z[None, :]) * s[None, :]
+ else:
+ # General: group per row
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(
+ s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ ).to(tl.float32)
+ z = tl.load(
+ z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
+ ).to(tl.float32)
+ w_lo_f = (w_lo - z) * s
+ w_hi_f = (w_hi - z) * s
+
+ accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
+ accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
+
+ out_ptrs = out_ptr + pid_k * stride_obs + offs_n * stride_on
+ tl.store(out_ptrs, accumulator)
+
+
+def run(M, N, K, w_q, scales, zeros, x, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages):
+ K_PER_SPLIT = K // SPLIT_K
+ n_tiles = (N + BLOCK_N - 1) // BLOCK_N
+ partials = torch.empty((SPLIT_K, N), dtype=torch.float32, device=x.device)
+ grid = (n_tiles, SPLIT_K)
+ gemv_kernel[grid](
+ x, w_q, scales, zeros, partials,
+ N, K,
+ x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ partials.stride(0), partials.stride(1),
+ GROUP_SIZE=GROUP_SIZE,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ K_PER_SPLIT=K_PER_SPLIT,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return partials.sum(dim=0).to(torch.bfloat16).view(M, N)
+
+
+def time_call(fn, args, iters=30, warmup=10, flush=True):
+ if flush:
+ scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
+ for _ in range(warmup):
+ fn(*args)
+ torch.cuda.synchronize()
+ times = []
+ for _ in range(iters):
+ if flush:
+ scratch.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ fn(*args)
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+def main():
+ import reference
+
+ for shape_name, (M, N, K) in [("0", (1, 12288, 4096)), ("3", (1, 4096, 4096))]:
+ reference.M = M
+ reference.N = N
+ reference.K = K
+
+ ref = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+ ref_out = ref(x).clone()
+
+ bytes_moved = M * K * 2 + (K // 2) * N + (K // 128) * N * 2 + (K // 128) * N * 2 + M * N * 2
+ print(f"\n=== shape {shape_name} M={M} N={N} K={K} bytes={bytes_moved/1e6:.1f}MB ===")
+
+ configs = []
+ for bn in [64, 128, 256]:
+ for bk in [128, 256, 512]:
+ if K % bk != 0:
+ continue
+ # Compute SPLIT_K candidates
+ n_tiles = (N + bn - 1) // bn
+ n_k_iters = K // bk
+ for sk in [1, 2, 4, 8, 16]:
+ if n_k_iters % sk != 0:
+ continue
+ if K % (sk * bk) != 0:
+ continue
+ if n_tiles * sk < 32 or n_tiles * sk > 4 * SM_COUNT:
+ continue
+ for nw in [2, 4, 8]:
+ for ns in [2, 3, 4]:
+ configs.append((bn, bk, sk, nw, ns))
+
+ results = []
+ for cfg in configs:
+ bn, bk, sk, nw, ns = cfg
+ try:
+
... (861 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py",
"content": "<6559 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py (file state is current in your context — no need to Read it back){
"command": "uv run python _tune_gemv.py 2>&1 | tail -30",
"description": "Tune GEMV configs"
}Command running in background with ID: bn21zx0ok. Output is being written to: /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output{
"command": "cat /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output 2>&1 | tail -40",
"description": "Check tune output"
}(Bash completed with no output){
"command": "ls /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/",
"description": "List task outputs"
}bbl0fw3un.output
bn21zx0ok.output{
"command": "tail -100 /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output",
"description": "Tail tuner output"
}(Bash completed with no output){
"command": "wc -l /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output ; ls -la /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output",
"description": "Check tuner output size"
}0 /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output
.rw-rw-r-- infatoshi infatoshi 0 B Tue Apr 28 08:54:32 2026 /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output{
"command": "sleep 60 && wc -l /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output 2>&1",
"description": "Wait and check"
}<tool_use_error>Blocked: sleep 60 followed by: wc -l /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output. To wait for a condition, use Monitor with an until-loop (e.g. `until <check>; do sleep 2; done`). To wait for a command you started, use run_in_background: true. Do not chain shorter sleeps to work around this block.</tool_use_error>{
"query": "select:Monitor",
"max_results": 1
}{
"command": "until [ -s /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output ] && grep -q 'Best 6' /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output 2>/dev/null; do sleep 5; done; cat /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output",
"description": "Wait for tuner to finish",
"run_in_background": true
}Command running in background with ID: bwe3l5rxi. Output is being written to: /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bwe3l5rxi.output{
"command": "cat /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bwe3l5rxi.output",
"description": "Check tuner result"
}(Bash completed with no output){
"command": "ps aux | grep -E \"tune_gemv|python\" | grep -v grep | head",
"description": "Check running python"
}root 1907 0.0 0.0 126720 24796 ? Ssl 01:49 0:00 /usr/bin/python3 /usr/share/unattended-upgrades/unattended-upgrade-shutdown --wait-for-signal
infatos+ 14352 0.0 0.0 344524 95372 ? Ssl 01:50 0:03 /home/infatoshi/hermes-agent/.venv/bin/python -m hermes_cli.main gateway run --replace
infatos+ 280959 0.0 0.0 16972 2252 ? S 08:39 0:00 timeout 2700 claude --dangerously-skip-permissions --print --verbose --output-format stream-json --model claude-opus-4-7 --effort max --add-dir /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm -p I need you to write a weight-only int4 quantized GEMM (W4A16) for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py. The scheme is AWQ/GPTQ-style asymmetric int4 with explicit zero-points and per-group bf16 scales. Inputs are x of shape (M, K) bf16, w_q of shape (K // 2, N) uint8 (two int4 weights packed per byte, low nibble = even-K row, high nibble = odd-K row), scales of shape (K // 128, N) bf16, and zeros of shape (K // 128, N) bf16. Group size is 128 along K. Dequant per group is w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n],
... (4993 more chars){
"command": "kill 289316 289320 290814 2>/dev/null; sleep 2; ps aux | grep _tune | grep -v grep",
"description": "Kill tuner"
}infatos+ 289314 0.0 0.0 19720 4600 ? Ss 08:54 0:00 /usr/bin/zsh -c source /home/infatoshi/.claude/shell-snapshots/snapshot-zsh-1777387164593-thbrlr.sh 2>/dev/null || true && setopt NO_EXTENDED_GLOB 2>/dev/null || true && eval 'uv run python _tune_gemv.py 2>&1 | tail -30' < /dev/null && pwd -P >| /tmp/claude-11e9-cwd{
"command": "pkill -f _tune_gemv; sleep 2; pkill -f bn21zx0ok; pkill -f bwe3l5rxi; sleep 1; ps aux | grep _tune | grep -v grep",
"description": "Force kill tuner"
}Exit code 144{
"command": "cat /tmp/claude-1002/-home-infatoshi-cuda-KernelBench-Hard/095e16f5-cebf-4044-bed1-c18d99a9d145/tasks/bn21zx0ok.output",
"description": "Check tuner output"
}(Bash completed with no output)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py
@@ -1,5 +1,6 @@
-"""Sweep GEMV configs for M=1 shapes."""
+"""Sweep GEMV configs for M=1 shapes -- focused subset."""
import statistics
+import sys
import torch
import triton
@@ -53,10 +54,6 @@
x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
- # Per-group scale/zero. Each row of (BLOCK_KH,) maps to a group.
- # Compute group index for even-K rows (k_full = kbase_full + 2*kh).
- # group = (kbase_full + 2*kh) // GROUP_SIZE. For even and odd, group is the same
- # (both 2*kh and 2*kh+1 fall in the same group when GROUP_SIZE >= 2).
if n_groups_per_block == 1:
group_idx = kbase_full // GROUP_SIZE
s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
@@ -64,7 +61,6 @@
w_lo_f = (w_lo - z[None, :]) * s[None, :]
w_hi_f = (w_hi - z[None, :]) * s[None, :]
else:
- # General: group per row
group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
s = tl.load(
s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
@@ -105,7 +101,7 @@
return partials.sum(dim=0).to(torch.bfloat16).view(M, N)
-def time_call(fn, args, iters=30, warmup=10, flush=True):
+def time_call(fn, args, iters=20, warmup=5, flush=True):
if flush:
scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
for _ in range(warmup):
@@ -129,7 +125,12 @@
def main():
import reference
- for shape_name, (M, N, K) in [("0", (1, 12288, 4096)), ("3", (1, 4096, 4096))]:
+ only_shape = int(sys.argv[1]) if len(sys.argv) > 1 else -1
+
+ cases = [(0, 1, 12288, 4096), (3, 1, 4096, 4096)]
+ for sh, M, N, K in cases:
+ if only_shape >= 0 and sh != only_shape:
+ continue
reference.M = M
reference.N = N
reference.K = K
@@ -139,24 +140,19 @@
ref_out = ref(x).clone()
bytes_moved = M * K * 2 + (K // 2) * N + (K // 128) * N * 2 + (K // 128) * N * 2 + M * N * 2
- print(f"\n=== shape {shape_name} M={M} N={N} K={K} bytes={bytes_moved/1e6:.1f}MB ===")
+ print(f"\n=== shape {sh} M={M} N={N} K={K} bytes={bytes_moved/1e6:.1f}MB ===", flush=True)
+ # Targeted small set
configs = []
for bn in [64, 128, 256]:
for bk in [128, 256, 512]:
- if K % bk != 0:
- continue
- # Compute SPLIT_K candidates
- n_tiles = (N + bn - 1) // bn
- n_k_iters = K // bk
- for sk in [1, 2, 4, 8, 16]:
- if n_k_iters % sk != 0:
- continue
+ for sk in [1, 2, 4, 8]:
if K % (sk * bk) != 0:
continue
+ n_tiles = (N + bn - 1) // bn
if n_tiles * sk < 32 or n_tiles * sk > 4 * SM_COUNT:
continue
- for nw in [2, 4, 8]:
+ for nw in [4, 8]:
for ns in [2, 3, 4]:
configs.append((bn, bk, sk, nw, ns))
@@ -164,25 +160,20 @@
for cfg in configs:
bn, bk, sk, nw, ns = cfg
try:
- # Verify correctness
sol_out = run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns)
if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
continue
- ms = time_call(
- run,
- (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns),
- iters=20,
- warmup=5,
- )
+ ms = time_call(run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns), iters=15, warmup=4)
gbps = bytes_moved / (ms * 1e-3) / 1e9
results.append((cfg, ms, gbps))
+ print(f" cfg={cfg} ms={ms:.4f} gbps={gbps:.1f}", flush=True)
except Exception as e:
- pass
+ print(f" cfg={cfg} ERR={type(e).__name__}: {str(e)[:80]}", flush=True)
results.sort(key=lambda r: r[1])
- print(f"Best 6 configs:")
- for cfg, ms, gbps in results[:6]:
- print(f" cfg={cfg} ms={ms:.4f} gbps={gbps:.1f}")
+ print(f" >> Best 5:")
+ for cfg, ms, gbps in results[:5]:
+ print(f" cfg={cfg} ms={ms:.4f} gbps={gbps:.1f}")
if __name__ == "__main__":
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py",
"content": "<6186 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemv.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python _tune_gemv.py 0 2>&1",
"description": "Tune shape 0 GEMV"
}=== shape 0 M=1 N=12288 K=4096 bytes=26.8MB ===
cfg=(64, 128, 1, 4, 2) ms=0.0813 gbps=329.4
cfg=(64, 128, 1, 4, 3) ms=0.0805 gbps=332.5
cfg=(64, 128, 1, 4, 4) ms=0.0858 gbps=311.9
cfg=(64, 128, 1, 8, 2) ms=0.1210 gbps=221.3
cfg=(64, 128, 1, 8, 3) ms=0.1260 gbps=212.4
cfg=(64, 128, 1, 8, 4) ms=0.1206 gbps=221.9
cfg=(64, 128, 2, 4, 2) ms=0.0623 gbps=429.7
cfg=(64, 128, 2, 4, 3) ms=0.0715 gbps=374.3
cfg=(64, 128, 2, 4, 4) ms=0.0621 gbps=430.8
cfg=(64, 128, 2, 8, 2) ms=0.0977 gbps=274.0
cfg=(64, 128, 2, 8, 3) ms=0.0921 gbps=290.8
cfg=(64, 128, 2, 8, 4) ms=0.0912 gbps=293.4
cfg=(64, 256, 1, 4, 2) ms=0.0644 gbps=415.4
cfg=(64, 256, 1, 4, 3) ms=0.0634 gbps=422.3
cfg=(64, 256, 1, 4, 4) ms=0.0644 gbps=415.6
cfg=(64, 256, 1, 8, 2) ms=0.0778 gbps=344.0
cfg=(64, 256, 1, 8, 3) ms=0.0758 gbps=353.3
cfg=(64, 256, 1, 8, 4) ms=0.0749 gbps=357.2
cfg=(64, 256, 2, 4, 2) ms=0.0473 gbps=566.0
cfg=(64, 256, 2, 4, 3) ms=0.0596 gbps=449.5
cfg=(64, 256, 2, 4, 4) ms=0.0484 gbps=553.7
cfg=(64, 256, 2, 8, 2) ms=0.0644 gbps=415.4
cfg=(64, 256, 2, 8, 3) ms=0.0717 gbps=373.3
cfg=(64, 256, 2, 8, 4) ms=0.0640 gbps=418.1
cfg=(64, 512, 1, 4, 2) ms=0.0628 gbps=426.6
cfg=(64, 512, 1, 4, 3) ms=0.0628 gbps=426.4
cfg=(64, 512, 1, 4, 4) ms=0.0636 gbps=420.6
cfg=(64, 512, 1, 8, 2) ms=0.0643 gbps=416.6
cfg=(64, 512, 1, 8, 3) ms=0.0577 gbps=464.0
cfg=(64, 512, 1, 8, 4) ms=0.0589 gbps=454.4
cfg=(64, 512, 2, 4, 2) ms=0.0591 gbps=453.0
cfg=(64, 512, 2, 4, 3) m
... (6374 more chars){
"command": "timeout 600 uv run python _tune_gemv.py 3 2>&1 | tail -50",
"description": "Tune shape 3"
}cfg=(256, 128, 4, 8, 3) ms=0.0579 gbps=154.3
cfg=(256, 128, 4, 8, 4) ms=0.0583 gbps=153.2
cfg=(256, 128, 8, 4, 2) ms=0.0382 gbps=233.9
cfg=(256, 128, 8, 4, 3) ms=0.0446 gbps=200.3
cfg=(256, 128, 8, 4, 4) ms=0.0471 gbps=189.4
cfg=(256, 128, 8, 8, 2) ms=0.0390 gbps=228.9
cfg=(256, 128, 8, 8, 3) ms=0.0393 gbps=227.2
cfg=(256, 128, 8, 8, 4) ms=0.0388 gbps=230.0
cfg=(256, 256, 2, 4, 2) ms=0.0971 gbps=92.0
cfg=(256, 256, 2, 4, 3) ms=0.0909 gbps=98.2
cfg=(256, 256, 2, 4, 4) ms=0.0897 gbps=99.5
cfg=(256, 256, 2, 8, 2) ms=0.0688 gbps=129.8
cfg=(256, 256, 2, 8, 3) ms=0.0739 gbps=120.8
cfg=(256, 256, 2, 8, 4) ms=0.0691 gbps=129.2
cfg=(256, 256, 4, 4, 2) ms=0.0588 gbps=151.9
cfg=(256, 256, 4, 4, 3) ms=0.0585 gbps=152.7
cfg=(256, 256, 4, 4, 4) ms=0.0592 gbps=150.9
cfg=(256, 256, 4, 8, 2) ms=0.0521 gbps=171.5
cfg=(256, 256, 4, 8, 3) ms=0.0448 gbps=199.2
cfg=(256, 256, 4, 8, 4) ms=0.0444 gbps=201.3
cfg=(256, 256, 8, 4, 2) ms=0.0468 gbps=191.0
cfg=(256, 256, 8, 4, 3) ms=0.0463 gbps=192.8
cfg=(256, 256, 8, 4, 4) ms=0.0526 gbps=169.7
cfg=(256, 256, 8, 8, 2) ms=0.0330 gbps=270.4
cfg=(256, 256, 8, 8, 3) ms=0.0329 gbps=271.7
cfg=(256, 256, 8, 8, 4) ms=0.0322 gbps=277.7
cfg=(256, 512, 2, 4, 2) ms=0.3341 gbps=26.7
cfg=(256, 512, 2, 4, 3) ms=0.3217 gbps=27.8
cfg=(256, 512, 2, 4, 4) ms=0.3291 gbps=27.1
cfg=(256, 512, 2, 8, 2) ms=0.0779 gbps=114.6
cfg=(256, 512, 2, 8, 3) ms=0.0771 gbps=115.9
cfg=(256, 512, 2, 8, 4) ms=0.0810 gbps=110.2
cf
... (812 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemm.py
@@ -0,0 +1,192 @@
+"""Sweep GEMM configs for M>=16 shapes."""
+import statistics
+import sys
+
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+SM_COUNT = 188
+
+
+@triton.jit
+def gemm_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ n_iters = K // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
+
+ for k_iter in range(n_iters):
+ kbase_packed = k_iter * BLOCK_KH
+ kbase_full = k_iter * BLOCK_K
+
+ w_ptrs = (
+ w_ptr
+ + (kbase_packed + offs_kh)[:, None] * stride_wk
+ + offs_n[None, :] * stride_wn
+ )
+ w_packed = tl.load(w_ptrs)
+ w_lo = (w_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
+ w_lo_bf = (w_lo - z[None, :]) * s[None, :]
+ w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(
+ s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ )
+ z = tl.load(
+ z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
+ )
+ w_lo_bf = (w_lo - z) * s
+ w_hi_bf = (w_hi - z) * s
+
+ w_join = tl.join(w_lo_bf, w_hi_bf)
+ w_perm = tl.permute(w_join, (0, 2, 1))
+ w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N))
+
+ x_ptrs = (
+ x_ptr
+ + offs_m[:, None] * stride_xm
+ + (kbase_full + offs_k)[None, :] * stride_xk
+ )
+ x_mask = offs_m[:, None] < M
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ accumulator = tl.dot(x, w, accumulator)
+
+ out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_mask = offs_m[:, None] < M
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=out_mask)
+
+
+def run(M, N, K, w_q, scales, zeros, x, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ grid = ((M + BLOCK_M - 1) // BLOCK_M, (N + BLOCK_N - 1) // BLOCK_N)
+ gemm_kernel[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ GROUP_SIZE=GROUP_SIZE,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return out
+
+
+def time_call(fn, args, iters=15, warmup=4):
+ scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
+ for _ in range(warmup):
+ fn(*args)
+ torch.cuda.synchronize()
+ times = []
+ for _ in range(iters):
+ scratch.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ fn(*args)
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+def main():
+ import reference
+
+ only_shape = int(sys.argv[1]) if len(sys.argv) > 1 else -1
+
+ cases = [
+ (1, 32, 12288, 4096),
+ (2, 256, 12288, 4096),
+ (4, 16, 14336, 4096),
+ ]
+ for sh, M, N, K in cases:
+ if only_shape >= 0 and sh != only_shape:
+ continue
+ reference.M = M
+ reference.N = N
+ reference.K = K
+
+ ref = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+ ref_out = ref(x).clone()
+
+ bytes_moved = M * K * 2 + (K // 2) * N + (K // 128) * N * 2 + (K // 128) * N * 2 + M * N * 2
+ flops = 2 * M * N * K
+ print(f"\n=== shape {sh} M={M} N={N} K={K} bytes={bytes_moved/1e6:.1f}MB flops={flops/1e9:.2f}G ===", flush=True)
+
+ bm_options = [16, 32, 64, 128]
+ bm_options = [b for b in bm_options if b >= min(16, M) and b <= max(M * 2, 16)]
+ configs = []
+ for bm in bm_options:
+ for bn in [64, 128, 256]:
+ for bk in [128, 256]:
+ if K % bk != 0:
+ continue
+ for nw in [4, 8]:
+ for ns in [2, 3, 4]:
+ configs.append((bm, bn, bk, nw, ns))
+
+ results = []
+ for cfg in configs:
+ bm, bn, bk, nw, ns = cfg
+ try:
+ sol_out = run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bm, bn, bk, nw, ns)
+ if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
+ continue
+ ms = time_call(run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bm, bn, bk, nw, ns))
+ gbps = bytes_move
... (693 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemm.py",
"content": "<6388 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_tune_gemm.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python _tune_gemm.py 1 2>&1 | tail -30",
"description": "Tune shape 1"
}cfg=(32, 256, 128, 8, 3) ms=0.1963 gbps=141.5 tflops=16.4
cfg=(32, 256, 128, 8, 4) ms=0.2038 gbps=136.3 tflops=15.8
cfg=(64, 64, 128, 4, 2) ms=0.0852 gbps=326.2 tflops=37.8
cfg=(64, 64, 128, 4, 3) ms=0.0832 gbps=334.0 tflops=38.7
cfg=(64, 64, 128, 4, 4) ms=0.1127 gbps=246.5 tflops=28.6
cfg=(64, 64, 128, 8, 2) ms=0.1287 gbps=215.9 tflops=25.0
cfg=(64, 64, 128, 8, 3) ms=0.1123 gbps=247.5 tflops=28.7
cfg=(64, 64, 128, 8, 4) ms=0.1044 gbps=266.1 tflops=30.8
cfg=(64, 64, 256, 4, 2) ms=0.0929 gbps=299.0 tflops=34.7
cfg=(64, 64, 256, 4, 3) ms=0.1838 gbps=151.1 tflops=17.5
cfg=(64, 64, 256, 8, 2) ms=0.0855 gbps=325.0 tflops=37.7
cfg=(64, 64, 256, 8, 3) ms=0.1014 gbps=274.0 tflops=31.8
cfg=(64, 128, 128, 4, 2) ms=0.1188 gbps=234.0 tflops=27.1
cfg=(64, 128, 128, 4, 3) ms=0.1266 gbps=219.6 tflops=25.5
cfg=(64, 128, 128, 4, 4) ms=0.1255 gbps=221.4 tflops=25.7
cfg=(64, 128, 128, 8, 2) ms=0.1153 gbps=241.1 tflops=27.9
cfg=(64, 128, 128, 8, 3) ms=0.1005 gbps=276.4 tflops=32.0
cfg=(64, 128, 128, 8, 4) ms=0.1013 gbps=274.4 tflops=31.8
cfg=(64, 128, 256, 4, 2) ms=0.3664 gbps=75.8 tflops=8.8
cfg=(64, 128, 256, 8, 2) ms=0.0812 gbps=342.0 tflops=39.6
cfg=(64, 256, 128, 4, 2) ms=0.3086 gbps=90.0 tflops=10.4
cfg=(64, 256, 128, 4, 3) ms=0.3816 gbps=72.8 tflops=8.4
cfg=(64, 256, 128, 8, 2) ms=0.2167 gbps=128.2 tflops=14.9
cfg=(64, 256, 128, 8, 3) ms=0.2321 gbps=119.7 tflops=13.9
>> Best 5:
cfg=(32, 64, 256, 4, 2) ms=0.0650 gbps=427.3 tflops=49.5
... (246 more chars){
"command": "timeout 600 uv run python _tune_gemm.py 2 2>&1 | tail -30",
"description": "Tune shape 2"
}cfg=(64, 128, 128, 4, 2) ms=0.2575 gbps=136.4 tflops=100.1
cfg=(64, 128, 128, 4, 3) ms=0.3292 gbps=106.7 tflops=78.3
cfg=(64, 128, 128, 4, 4) ms=0.3441 gbps=102.1 tflops=74.9
cfg=(64, 128, 128, 8, 2) ms=0.2843 gbps=123.6 tflops=90.7
cfg=(64, 128, 128, 8, 3) ms=0.2633 gbps=133.4 tflops=97.9
cfg=(64, 128, 128, 8, 4) ms=0.2620 gbps=134.1 tflops=98.4
cfg=(64, 128, 256, 4, 2) ms=0.9611 gbps=36.6 tflops=26.8
cfg=(64, 128, 256, 8, 2) ms=0.2140 gbps=164.2 tflops=120.4
cfg=(64, 256, 128, 4, 2) ms=0.5595 gbps=62.8 tflops=46.1
cfg=(64, 256, 128, 4, 3) ms=0.7321 gbps=48.0 tflops=35.2
cfg=(64, 256, 128, 8, 2) ms=0.4044 gbps=86.9 tflops=63.7
cfg=(64, 256, 128, 8, 3) ms=0.4379 gbps=80.2 tflops=58.9
cfg=(128, 64, 128, 4, 2) ms=0.2127 gbps=165.1 tflops=121.2
cfg=(128, 64, 128, 4, 3) ms=0.2884 gbps=121.8 tflops=89.4
cfg=(128, 64, 128, 8, 2) ms=0.2366 gbps=148.5 tflops=108.9
cfg=(128, 64, 128, 8, 3) ms=0.1977 gbps=177.7 tflops=130.3
cfg=(128, 64, 256, 4, 2) ms=0.2555 gbps=137.5 tflops=100.9
cfg=(128, 64, 256, 8, 2) ms=0.1802 gbps=194.9 tflops=143.0
cfg=(128, 128, 128, 4, 2) ms=0.3344 gbps=105.1 tflops=77.1
cfg=(128, 128, 128, 4, 3) ms=0.3856 gbps=91.1 tflops=66.8
cfg=(128, 128, 128, 8, 2) ms=0.2647 gbps=132.7 tflops=97.4
cfg=(128, 128, 128, 8, 3) ms=0.2897 gbps=121.3 tflops=89.0
cfg=(128, 256, 128, 4, 2) ms=4.5599 gbps=7.7 tflops=5.7
cfg=(128, 256, 128, 8, 2) ms=0.3103 gbps=113.2 tflops=83.0
>> Best 5:
cfg=(128, 64, 256, 8, 2) ms=0.1802 gbps=
... (271 more chars){
"command": "timeout 600 uv run python _tune_gemm.py 4 2>&1 | tail -20",
"description": "Tune shape 4"
}cfg=(32, 128, 128, 4, 4) ms=0.1132 gbps=280.8 tflops=16.6
cfg=(32, 128, 128, 8, 2) ms=0.1037 gbps=306.4 tflops=18.1
cfg=(32, 128, 128, 8, 3) ms=0.0951 gbps=334.3 tflops=19.8
cfg=(32, 128, 128, 8, 4) ms=0.0874 gbps=363.6 tflops=21.5
cfg=(32, 128, 256, 4, 2) ms=0.1490 gbps=213.3 tflops=12.6
cfg=(32, 128, 256, 4, 3) ms=0.4548 gbps=69.9 tflops=4.1
cfg=(32, 128, 256, 8, 2) ms=0.0692 gbps=459.6 tflops=27.2
cfg=(32, 128, 256, 8, 3) ms=0.1054 gbps=301.5 tflops=17.8
cfg=(32, 256, 128, 4, 2) ms=0.2023 gbps=157.1 tflops=9.3
cfg=(32, 256, 128, 4, 3) ms=0.2410 gbps=131.9 tflops=7.8
cfg=(32, 256, 128, 4, 4) ms=0.4268 gbps=74.5 tflops=4.4
cfg=(32, 256, 128, 8, 2) ms=0.1946 gbps=163.4 tflops=9.7
cfg=(32, 256, 128, 8, 3) ms=0.2015 gbps=157.7 tflops=9.3
cfg=(32, 256, 128, 8, 4) ms=0.2041 gbps=155.8 tflops=9.2
>> Best 5:
cfg=(16, 64, 256, 4, 2) ms=0.0600 gbps=529.5 tflops=31.3
cfg=(16, 64, 256, 8, 2) ms=0.0601 gbps=528.6 tflops=31.3
cfg=(32, 64, 256, 4, 2) ms=0.0624 gbps=509.4 tflops=30.1
cfg=(16, 128, 256, 8, 2) ms=0.0632 gbps=502.7 tflops=29.7
cfg=(16, 64, 128, 8, 3) ms=0.0681 gbps=466.5 tflops=27.6--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -27,7 +27,7 @@
# ---------------------------------------------------------------------------
-# GEMM kernel (M >= 16): tensor-core path
+# GEMM kernel (M >= 16): tensor-core path. BLOCK_K may be a multiple of GROUP_SIZE.
# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemm_kernel(
@@ -56,6 +56,7 @@
offs_k = tl.arange(0, BLOCK_K)
n_iters = K // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
for k_iter in range(n_iters):
kbase_packed = k_iter * BLOCK_KH
@@ -71,12 +72,22 @@
w_lo = (w_packed & 0xF).to(tl.bfloat16)
w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- group_idx = kbase_full // GROUP_SIZE
- s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
- z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
-
- w_lo_bf = (w_lo - z[None, :]) * s[None, :]
- w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
+ w_lo_bf = (w_lo - z[None, :]) * s[None, :]
+ w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(
+ s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ )
+ z = tl.load(
+ z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
+ )
+ w_lo_bf = (w_lo - z) * s
+ w_hi_bf = (w_hi - z) * s
w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
@@ -98,7 +109,7 @@
# ---------------------------------------------------------------------------
-# GEMV kernel (M == 1) with split-K. Writes fp32 partial sums of shape (SPLIT_K, N).
+# GEMV kernel (M == 1) with split-K. Writes (SPLIT_K, N) fp32 partials.
# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemv_kernel(
@@ -125,6 +136,7 @@
k_start_full = pid_k * K_PER_SPLIT
n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
for k_iter in range(n_iters):
kbase_full = k_start_full + k_iter * BLOCK_K
@@ -143,12 +155,22 @@
x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
- group_idx = kbase_full // GROUP_SIZE
- s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
- z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
-
- w_lo_f = (w_lo - z[None, :]) * s[None, :]
- w_hi_f = (w_hi - z[None, :]) * s[None, :]
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
+ w_lo_f = (w_lo - z[None, :]) * s[None, :]
+ w_hi_f = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(
+ s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ ).to(tl.float32)
+ z = tl.load(
+ z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
+ ).to(tl.float32)
+ w_lo_f = (w_lo - z) * s
+ w_hi_f = (w_hi - z) * s
accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
@@ -160,17 +182,35 @@
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
-def _choose_split_k(K: int, n_tiles: int, block_k: int) -> int:
- n_k_iters = K // block_k
- candidates = [s for s in [1, 2, 4, 8, 16, 32] if n_k_iters % s == 0]
- target = SM_COUNT
- chosen = 1
- for s in candidates:
- if n_tiles * s >= target:
- chosen = s
- break
- chosen = s
- return chosen
+def _gemv_config(N: int, K: int) -> tuple[int, int, int, int, int]:
+ """Return (BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages) for M=1 GEMV."""
+ BLOCK_N = 64
+ n_tiles = (N + BLOCK_N - 1) // BLOCK_N
+
+ if n_tiles <= 64: # small N (e.g., shape 3: N=4096)
+ BLOCK_K = 512 if K % 512 == 0 else 256
+ SPLIT_K = 8
+ num_warps = 8
+ elif n_tiles <= 128:
+ BLOCK_K = 256
+ SPLIT_K = 4
+ num_warps = 4
+ else: # large N (e.g., shape 0: N=12288)
+ BLOCK_K = 256
+ SPLIT_K = 2
+ num_warps = 4
+ return BLOCK_N, BLOCK_K, SPLIT_K, num_warps, 2
+
+
+def _gemm_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
+ """Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) for GEMM."""
+ if M <= 16:
+ return 16, 64, 256, 4, 2
+ if M <= 32:
+ return 32, 64, 256, 4, 2
+ if M <= 64:
+ return 64, 64, 256, 4, 2
+ return 128, 64, 256, 8, 2
def w4a16_gemm(
@@ -186,12 +226,9 @@
assert K % group_size == 0
if M == 1:
- BLOCK_N = 128
- BLOCK_K = group_size
+ BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = _gemv_config(N, K)
+
... (1534 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<10266 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness"
}PASS{
"command": "uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Run benchmark with tuned configs"
}shape=0 variant=eager tflops=0.131 gbps=34.907 ms=0.767
shape=0 variant=compiled tflops=0.655 gbps=174.130 ms=0.154
shape=0 variant=sota tflops=2.186 gbps=581.382 ms=0.046
shape=0 variant=solution tflops=1.840 gbps=489.244 ms=0.055
shape=0 solution_peak_fraction=0.2718
shape=1 variant=eager tflops=4.140 gbps=35.715 ms=0.778
shape=1 variant=compiled tflops=19.699 gbps=169.932 ms=0.164
shape=1 variant=sota tflops=22.829 gbps=196.928 ms=0.141
shape=1 variant=solution tflops=46.798 gbps=403.697 ms=0.069
shape=1 solution_peak_fraction=0.2243
shape=2 variant=eager tflops=31.079 gbps=42.364 ms=0.829
shape=2 variant=compiled tflops=119.260 gbps=162.566 ms=0.216
shape=2 variant=sota tflops=130.923 gbps=178.463 ms=0.197
shape=2 variant=solution tflops=136.585 gbps=186.182 ms=0.189
shape=2 solution_peak_fraction=0.1034
shape=3 variant=eager tflops=0.177 gbps=47.032 ms=0.190
shape=3 variant=compiled tflops=0.510 gbps=135.786 ms=0.066
shape=3 variant=sota tflops=1.242 gbps=330.420 ms=0.027
shape=3 variant=solution tflops=1.014 gbps=269.734 ms=0.033
shape=3 solution_peak_fraction=0.1499
shape=4 variant=eager tflops=2.045 gbps=34.595 ms=0.919
shape=4 variant=compiled tflops=10.122 gbps=171.211 ms=0.186
shape=4 variant=sota tflops=11.277 gbps=190.759 ms=0.167
shape=4 variant=solution tflops=26.379 gbps=446.217 ms=0.071
shape=4 solution_peak_fraction=0.2479
peak_fraction: 0.1879
RESULT: OK{
"command": "<1360 chars \u2014 see below>",
"description": "Measure forward overhead"
}uv run python -c "
import statistics
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
# Warmup
for _ in range(20): sol(x)
torch.cuda.synchronize()
scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
# Time forward
times_fwd = []
for _ in range(30):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); sol(x); e.record()
torch.cuda.synchronize()
times_fwd.append(s.elapsed_time(e))
# Time direct call
def direct():
return solution.w4a16_gemm(x, sol.w_q, sol.scales, sol.zeros)
times_dir = []
for _ in range(30):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); direct(); e.record()
torch.cuda.synchronize()
times_dir.append(s.elapsed_time(e))
print(f'forward median: {statistics.median(times_fwd):.4f} ms')
print(f'direct median: {statistics.median(times_dir):.4f} ms')
print(f'overhead: {statistics.median(times_fwd)-statistics.median(times_dir):.4f} ms')
"forward median: 0.0592 ms
direct median: 0.0535 ms
overhead: 0.0057 ms{
"command": "<1183 chars \u2014 see below>",
"description": "Test more shape 0 configs"
}uv run python -c "
import statistics
import torch
import _tune_gemv as tg
import reference
torch.manual_seed(42)
# shape 0
M, N, K = 1, 12288, 4096
reference.M = M; reference.N = N; reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
ref_out = ref(x).clone()
# Test more configs SPLIT_K=4
for cfg in [(64, 256, 4, 4, 2), (64, 256, 4, 4, 3), (64, 256, 4, 4, 4), (64, 256, 4, 8, 2), (64, 128, 4, 4, 2), (64, 128, 4, 8, 2), (32, 256, 2, 4, 2), (32, 256, 4, 4, 2)]:
try:
bn, bk, sk, nw, ns = cfg
sol_out = tg.run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns)
if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
print(f'cfg={cfg} FAIL')
continue
ms = tg.time_call(tg.run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns), iters=30, warmup=8)
bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
gbps = bytes_moved/(ms*1e-3)/1e9
print(f'cfg={cfg} ms={ms:.4f} gbps={gbps:.1f}')
except Exception as e:
print(f'cfg={cfg} ERR: {e}')
" 2>&1 | tail -15cfg=(64, 256, 4, 4, 2) ms=0.0476 gbps=561.9
cfg=(64, 256, 4, 4, 3) ms=0.0466 gbps=574.4
cfg=(64, 256, 4, 4, 4) ms=0.0501 gbps=534.1
cfg=(64, 256, 4, 8, 2) ms=0.0564 gbps=475.1
cfg=(64, 128, 4, 4, 2) ms=0.0540 gbps=495.6
cfg=(64, 128, 4, 8, 2) ms=0.0838 gbps=319.6
cfg=(32, 256, 2, 4, 2) ms=0.0564 gbps=474.8
cfg=(32, 256, 4, 4, 2) ms=0.0480 gbps=557.7{
"command": "<985 chars \u2014 see below>",
"description": "Detailed timing distribution"
}uv run python -c "
import statistics
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
# Warmup
for _ in range(20): sol(x)
torch.cuda.synchronize()
scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
# More iters
times = []
for _ in range(100):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); sol(x); e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
import numpy as np
arr = np.array(times)
print(f'min: {arr.min():.4f}, median: {np.median(arr):.4f}, mean: {arr.mean():.4f}, max: {arr.max():.4f}')
print(f'p25: {np.percentile(arr, 25):.4f}, p75: {np.percentile(arr, 75):.4f}')
"min: 0.0489, median: 0.0523, mean: 0.0591, max: 0.2093
p25: 0.0509, p75: 0.0568{
"command": "<2634 chars \u2014 see below>",
"description": "Profile breakdown"
}uv run python -c "
import statistics
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
# Warmup
for _ in range(20): sol(x)
torch.cuda.synchronize()
# Profile breakdown using events around individual ops
# Use raw call to see kernel + reduce timing
def run_raw():
BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = solution._gemv_config(N, K)
K_PER_SPLIT = K // SPLIT_K
n_tiles = (N + BLOCK_N - 1) // BLOCK_N
partials = torch.empty((SPLIT_K, N), dtype=torch.float32, device=x.device)
grid = (n_tiles, SPLIT_K)
solution.w4a16_gemv_kernel[grid](
x, sol.w_q, sol.scales, sol.zeros, partials,
N, K,
x.stride(1),
sol.w_q.stride(0), sol.w_q.stride(1),
sol.scales.stride(0), sol.scales.stride(1),
sol.zeros.stride(0), sol.zeros.stride(1),
partials.stride(0), partials.stride(1),
GROUP_SIZE=128,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, K_PER_SPLIT=K_PER_SPLIT,
num_warps=num_warps, num_stages=num_stages,
)
return partials
scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
import numpy as np
# Time just kernel (no reduce)
times_k = []
for _ in range(50):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); _ = run_raw(); e.record()
torch.cuda.synchronize()
times_k.append(s.elapsed_time(e))
print(f'kernel only median: {np.median(times_k):.4f} ms (min={min(times_k):.4f})')
# Time kernel + reduce
def kernel_plus_reduce():
p = run_raw()
return p.sum(dim=0).to(torch.bfloat16).view(M, N)
times_kr = []
for _ in range(50):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); _ = kernel_plus_reduce(); e.record()
torch.cuda.synchronize()
times_kr.append(s.elapsed_time(e))
print(f'kernel+reduce median: {np.median(times_kr):.4f} ms (min={min(times_kr):.4f})')
# Time forward
times_f = []
for _ in range(50):
scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
s.record(); sol(x); e.record()
torch.cuda.synchronize()
times_f.append(s.elapsed_time(e))
print(f'forward median: {np.median(times_f):.4f} ms (min={min(times_f):.4f})')
"kernel only median: 0.0470 ms (min=0.0448)
kernel+reduce median: 0.0529 ms (min=0.0479)
forward median: 0.0510 ms (min=0.0484){
"command": "<459 chars \u2014 see below>",
"description": "Test bf16 atomic_add support"
}uv run python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def test_atomic(out_ptr, BLOCK_N: tl.constexpr):
offs = tl.arange(0, BLOCK_N)
val = tl.zeros((BLOCK_N,), dtype=tl.bfloat16) + 1.0
tl.atomic_add(out_ptr + offs, val)
out = torch.zeros(64, dtype=torch.bfloat16, device='cuda')
test_atomic[(8,)](out, BLOCK_N=64)
torch.cuda.synchronize()
print(out)
print('bf16 atomic_add works:', torch.all(out == 8.0).item())
"Exit code 1
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line
... (160 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_atomic.py
@@ -0,0 +1,38 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def test_atomic_bf16(out_ptr, BLOCK_N: tl.constexpr):
+ offs = tl.arange(0, BLOCK_N)
+ val = tl.zeros((BLOCK_N,), dtype=tl.bfloat16) + 1.0
+ tl.atomic_add(out_ptr + offs, val)
+
+
+@triton.jit
+def test_atomic_fp32(out_ptr, BLOCK_N: tl.constexpr):
+ offs = tl.arange(0, BLOCK_N)
+ val = tl.zeros((BLOCK_N,), dtype=tl.float32) + 1.0
+ tl.atomic_add(out_ptr + offs, val)
+
+
+def main():
+ print("Test bf16 atomic_add")
+ try:
+ out = torch.zeros(64, dtype=torch.bfloat16, device="cuda")
+ test_atomic_bf16[(8,)](out, BLOCK_N=64)
+ torch.cuda.synchronize()
+ print(" result:", out[:4].tolist())
+ print(" ok:", bool(torch.all(out == 8.0).item()))
+ except Exception as e:
+ print(f" ERR: {type(e).__name__}: {e}")
+
+ print("Test fp32 atomic_add")
+ out = torch.zeros(64, dtype=torch.float32, device="cuda")
+ test_atomic_fp32[(8,)](out, BLOCK_N=64)
+ torch.cuda.synchronize()
+ print(" result:", out[:4].tolist())
+
+
+if __name__ == "__main__":
+ main()
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_atomic.py",
"content": "<1082 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_atomic.py (file state is current in your context — no need to Read it back){
"command": "uv run python _test_atomic.py 2>&1",
"description": "Test atomic_add support"
}Test bf16 atomic_add
result: [8.0, 8.0, 8.0, 8.0]
ok: True
Test fp32 atomic_add
result: [8.0, 8.0, 8.0, 8.0]--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -119,11 +119,12 @@
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
- stride_obs, stride_on,
+ stride_on,
GROUP_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
K_PER_SPLIT: tl.constexpr,
+ SPLIT_K: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_k = tl.program_id(1)
@@ -175,8 +176,11 @@
accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
- out_ptrs = out_ptr + pid_k * stride_obs + offs_n * stride_on
- tl.store(out_ptrs, accumulator)
+ out_ptrs = out_ptr + offs_n * stride_on
+ if SPLIT_K == 1:
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16))
+ else:
+ tl.atomic_add(out_ptrs, accumulator.to(tl.bfloat16))
# ---------------------------------------------------------------------------
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<2411 chars \u2014 see diff>",
"new_string": "<2517 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -233,26 +233,28 @@
BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = _gemv_config(N, K)
K_PER_SPLIT = K // SPLIT_K
n_tiles = triton.cdiv(N, BLOCK_N)
- partials = torch.empty((SPLIT_K, N), dtype=torch.float32, device=x.device)
+ if SPLIT_K == 1:
+ out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
+ else:
+ out = torch.zeros((1, N), dtype=torch.bfloat16, device=x.device)
grid = (n_tiles, SPLIT_K)
w4a16_gemv_kernel[grid](
- x, w_q, scales, zeros, partials,
+ x, w_q, scales, zeros, out,
N, K,
x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
- partials.stride(0), partials.stride(1),
+ out.stride(1),
GROUP_SIZE=group_size,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
K_PER_SPLIT=K_PER_SPLIT,
+ SPLIT_K=SPLIT_K,
num_warps=num_warps,
num_stages=num_stages,
)
- if SPLIT_K == 1:
- return partials.to(torch.bfloat16).view(1, N)
- return partials.sum(dim=0).to(torch.bfloat16).view(1, N)
+ return out
BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _gemm_config(M, N, K)
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<951 chars \u2014 see diff>",
"new_string": "<931 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness with bf16 atomic"
}PASS{
"command": "uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Benchmark with atomic_add"
}shape=0 variant=eager tflops=0.131 gbps=34.900 ms=0.767
shape=0 variant=compiled tflops=0.653 gbps=173.750 ms=0.154
shape=0 variant=sota tflops=2.258 gbps=600.580 ms=0.045
shape=0 variant=solution tflops=1.863 gbps=495.474 ms=0.054
shape=0 solution_peak_fraction=0.2753
shape=1 variant=eager tflops=4.144 gbps=35.748 ms=0.777
shape=1 variant=compiled tflops=19.884 gbps=171.526 ms=0.162
shape=1 variant=sota tflops=22.862 gbps=197.218 ms=0.141
shape=1 variant=solution tflops=46.679 gbps=402.667 ms=0.069
shape=1 solution_peak_fraction=0.2237
shape=2 variant=eager tflops=31.182 gbps=42.505 ms=0.826
shape=2 variant=compiled tflops=119.775 gbps=163.267 ms=0.215
shape=2 variant=sota tflops=133.010 gbps=181.308 ms=0.194
shape=2 variant=solution tflops=136.643 gbps=186.261 ms=0.189
shape=2 solution_peak_fraction=0.1035
shape=3 variant=eager tflops=0.178 gbps=47.303 ms=0.189
shape=3 variant=compiled tflops=0.434 gbps=115.401 ms=0.077
shape=3 variant=sota tflops=1.283 gbps=341.542 ms=0.026
shape=3 variant=solution tflops=0.822 gbps=218.769 ms=0.041
shape=3 solution_peak_fraction=0.1215
shape=4 variant=eager tflops=2.054 gbps=34.739 ms=0.915
shape=4 variant=compiled tflops=10.120 gbps=171.181 ms=0.186
shape=4 variant=sota tflops=11.177 gbps=189.070 ms=0.168
shape=4 variant=solution tflops=25.704 gbps=434.791 ms=0.073
shape=4 solution_peak_fraction=0.2416
peak_fraction: 0.1796
RESULT: OK--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -180,7 +180,8 @@
if SPLIT_K == 1:
tl.store(out_ptrs, accumulator.to(tl.bfloat16))
else:
- tl.atomic_add(out_ptrs, accumulator.to(tl.bfloat16))
+ # fp32 atomic_add into pre-zeroed fp32 buffer
+ tl.atomic_add(out_ptrs, accumulator)
# ---------------------------------------------------------------------------
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<191 chars \u2014 see diff>",
"new_string": "<229 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -234,19 +234,36 @@
BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = _gemv_config(N, K)
K_PER_SPLIT = K // SPLIT_K
n_tiles = triton.cdiv(N, BLOCK_N)
+ grid = (n_tiles, SPLIT_K)
if SPLIT_K == 1:
out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
- else:
- out = torch.zeros((1, N), dtype=torch.bfloat16, device=x.device)
- grid = (n_tiles, SPLIT_K)
+ w4a16_gemv_kernel[grid](
+ x, w_q, scales, zeros, out,
+ N, K,
+ x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(1),
+ GROUP_SIZE=group_size,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ K_PER_SPLIT=K_PER_SPLIT,
+ SPLIT_K=SPLIT_K,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return out
+ # SPLIT_K > 1: fp32 atomic_add into pre-zeroed fp32 buffer, then cast.
+ out_fp32 = torch.zeros((1, N), dtype=torch.float32, device=x.device)
w4a16_gemv_kernel[grid](
- x, w_q, scales, zeros, out,
+ x, w_q, scales, zeros, out_fp32,
N, K,
x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
- out.stride(1),
+ out_fp32.stride(1),
GROUP_SIZE=group_size,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
@@ -255,7 +272,7 @@
num_warps=num_warps,
num_stages=num_stages,
)
- return out
+ return out_fp32.to(torch.bfloat16)
BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _gemm_config(M, N, K)
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<931 chars \u2014 see diff>",
"new_string": "<1633 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Test fp32 atomic"
}PASS
shape=0 variant=eager tflops=0.131 gbps=34.864 ms=0.768
shape=0 variant=compiled tflops=0.638 gbps=169.697 ms=0.158
shape=0 variant=sota tflops=1.997 gbps=531.180 ms=0.050
shape=0 variant=solution tflops=1.883 gbps=500.813 ms=0.053
shape=0 solution_peak_fraction=0.2782
shape=1 variant=eager tflops=4.134 gbps=35.660 ms=0.779
shape=1 variant=compiled tflops=19.861 gbps=171.323 ms=0.162
shape=1 variant=sota tflops=22.335 gbps=192.667 ms=0.144
shape=1 variant=solution tflops=46.496 gbps=401.086 ms=0.069
shape=1 solution_peak_fraction=0.2228
shape=2 variant=eager tflops=31.140 gbps=42.448 ms=0.828
shape=2 variant=compiled tflops=119.633 gbps=163.073 ms=0.215
shape=2 variant=sota tflops=130.056 gbps=177.282 ms=0.198
shape=2 variant=solution tflops=135.642 gbps=184.896 ms=0.190
shape=2 solution_peak_fraction=0.1027
shape=3 variant=eager tflops=0.177 gbps=46.972 ms=0.190
shape=3 variant=compiled tflops=0.508 gbps=135.194 ms=0.066
shape=3 variant=sota tflops=1.197 gbps=318.539 ms=0.028
shape=3 variant=solution tflops=1.044 gbps=277.928 ms=0.032
shape=3 solution_peak_fraction=0.1544
shape=4 variant=eager tflops=2.046 gbps=34.610 ms=0.918
shape=4 variant=compiled tflops=10.034 gbps=169.733 ms=0.187
shape=4 variant=sota tflops=11.148 gbps=188.568 ms=0.169
shape=4 variant=solution tflops=27.922 gbps=472.316 ms=0.067
shape=4 solution_peak_fraction=0.2624
peak_fraction: 0.1916
RESULT: OK--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_variants.py
@@ -0,0 +1,173 @@
+"""Test cache modifier and various tweaks."""
+import statistics
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def gemv_v(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ K_PER_SPLIT: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVICT: tl.constexpr,
+):
+ pid_n = tl.program_id(0)
+ pid_k = tl.program_id(1)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+
+ k_start_full = pid_k * K_PER_SPLIT
+ n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
+
+ for k_iter in range(n_iters):
+ kbase_full = k_start_full + k_iter * BLOCK_K
+ kbase_packed = kbase_full // 2
+
+ w_ptrs = (
+ w_ptr
+ + (kbase_packed + offs_kh)[:, None] * stride_wk
+ + offs_n[None, :] * stride_wn
+ )
+ if EVICT == 1:
+ w_packed = tl.load(w_ptrs, eviction_policy="evict_first")
+ elif EVICT == 2:
+ w_packed = tl.load(w_ptrs, cache_modifier=".cg")
+ elif EVICT == 3:
+ w_packed = tl.load(w_ptrs, cache_modifier=".ca")
+ else:
+ w_packed = tl.load(w_ptrs)
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
+
+ x_offs_even = kbase_full + 2 * offs_kh
+ x_even = tl.load(x_ptr + x_offs_even * stride_xk).to(tl.float32)
+ x_odd = tl.load(x_ptr + (x_offs_even + 1) * stride_xk).to(tl.float32)
+
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn).to(tl.float32)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn).to(tl.float32)
+ w_lo_f = (w_lo - z[None, :]) * s[None, :]
+ w_hi_f = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn).to(tl.float32)
+ z = tl.load(z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn).to(tl.float32)
+ w_lo_f = (w_lo - z) * s
+ w_hi_f = (w_hi - z) * s
+
+ accumulator += tl.sum(x_even[:, None] * w_lo_f, axis=0)
+ accumulator += tl.sum(x_odd[:, None] * w_hi_f, axis=0)
+
+ out_ptrs = out_ptr + offs_n * stride_on
+ if SPLIT_K == 1:
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16))
+ else:
+ tl.atomic_add(out_ptrs, accumulator)
+
+
+def run(M, N, K, w_q, scales, zeros, x, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages, evict):
+ K_PER_SPLIT = K // SPLIT_K
+ n_tiles = (N + BLOCK_N - 1) // BLOCK_N
+ if SPLIT_K == 1:
+ out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
+ else:
+ out = torch.zeros((1, N), dtype=torch.float32, device=x.device)
+ grid = (n_tiles, SPLIT_K)
+ gemv_v[grid](
+ x, w_q, scales, zeros, out,
+ N, K,
+ x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(1),
+ GROUP_SIZE=GROUP_SIZE,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ K_PER_SPLIT=K_PER_SPLIT,
+ SPLIT_K=SPLIT_K,
+ EVICT=evict,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ if SPLIT_K > 1:
+ return out.to(torch.bfloat16)
+ return out
+
+
+def time_call(fn, args, iters=30, warmup=8):
+ scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
+ for _ in range(warmup):
+ fn(*args)
+ torch.cuda.synchronize()
+ times = []
+ for _ in range(iters):
+ scratch.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ fn(*args)
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+def main():
+ import reference
+
+ for sh, (M, N, K) in [(0, (1, 12288, 4096)), (3, (1, 4096, 4096))]:
+ reference.M = M
+ reference.N = N
+ reference.K = K
+
+ ref = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+ ref_out = ref(x).clone()
+
+ bytes_moved = M * K * 2 + (K // 2) * N + (K // 128) * N * 2 + (K // 128) * N * 2 + M * N * 2
+ print(f"\n=== shape {sh} M={M} N={N} K={K} ===", flush=True)
+
+ # Same config family from previous tuning
+ if N == 12288:
+ base_cfgs = [(64, 256, 2, 4, 2), (64, 256, 4, 4, 3)]
+ else:
+ base_cfgs = [(64, 512, 8, 8, 2), (64, 256, 8, 8, 2)]
+
+ for cfg in base_cfgs:
+ for evict in [0, 1, 2, 3]:
+ try:
+ bn, bk, sk, nw, ns = cfg
+ sol_out = run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns, evict)
+ if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
+ print(f" cfg={cfg} evict={evict} FAIL")
+ continue
+ ms = time_call(run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns, evict))
+ gbps = bytes_moved / (ms * 1e-3) / 1e9
+ ev
... (317 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_variants.py",
"content": "<6027 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_variants.py (file state is current in your context — no need to Read it back){
"command": "timeout 120 uv run python _test_variants.py 2>&1 | tail -30",
"description": "Test cache modifiers"
}=== shape 0 M=1 N=12288 K=4096 ===
cfg=(64, 256, 2, 4, 2) evict=default ms=0.0552 gbps=485.0
cfg=(64, 256, 2, 4, 2) evict=evict_first ms=0.0490 gbps=546.3
cfg=(64, 256, 2, 4, 2) evict=cg (L2 only) ms=0.0479 gbps=559.2
cfg=(64, 256, 2, 4, 2) evict=ca (L1+L2) ms=0.0477 gbps=560.9
cfg=(64, 256, 4, 4, 3) evict=default ms=0.0567 gbps=472.0
cfg=(64, 256, 4, 4, 3) evict=evict_first ms=0.0575 gbps=465.3
cfg=(64, 256, 4, 4, 3) evict=cg (L2 only) ms=0.0470 gbps=569.9
cfg=(64, 256, 4, 4, 3) evict=ca (L1+L2) ms=0.0462 gbps=579.0
=== shape 3 M=1 N=4096 K=4096 ===
cfg=(64, 512, 8, 8, 2) evict=default ms=0.0350 gbps=255.1
cfg=(64, 512, 8, 8, 2) evict=evict_first ms=0.0284 gbps=314.2
cfg=(64, 512, 8, 8, 2) evict=cg (L2 only) ms=0.0292 gbps=306.0
cfg=(64, 512, 8, 8, 2) evict=ca (L1+L2) ms=0.0290 gbps=308.2
cfg=(64, 256, 8, 8, 2) evict=default ms=0.0312 gbps=285.9
cfg=(64, 256, 8, 8, 2) evict=evict_first ms=0.0321 gbps=278.3
cfg=(64, 256, 8, 8, 2) evict=cg (L2 only) ms=0.0317 gbps=281.7
cfg=(64, 256, 8, 8, 2) evict=ca (L1+L2) ms=0.0384 gbps=232.6--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -148,7 +148,7 @@
+ (kbase_packed + offs_kh)[:, None] * stride_wk
+ offs_n[None, :] * stride_wn
)
- w_packed = tl.load(w_ptrs)
+ w_packed = tl.load(w_ptrs, eviction_policy="evict_first")
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<610 chars \u2014 see diff>",
"new_string": "<641 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Bench evict_first"
}PASS
shape=0 variant=eager tflops=0.131 gbps=34.948 ms=0.766
shape=0 variant=compiled tflops=0.640 gbps=170.337 ms=0.157
shape=0 variant=sota tflops=2.198 gbps=584.428 ms=0.046
shape=0 variant=solution tflops=1.874 gbps=498.426 ms=0.054
shape=0 solution_peak_fraction=0.2769
shape=1 variant=eager tflops=4.147 gbps=35.774 ms=0.777
shape=1 variant=compiled tflops=19.880 gbps=171.492 ms=0.162
shape=1 variant=sota tflops=22.362 gbps=192.903 ms=0.144
shape=1 variant=solution tflops=45.395 gbps=391.591 ms=0.071
shape=1 solution_peak_fraction=0.2176
shape=2 variant=eager tflops=31.267 gbps=42.621 ms=0.824
shape=2 variant=compiled tflops=118.332 gbps=161.300 ms=0.218
shape=2 variant=sota tflops=131.232 gbps=178.885 ms=0.196
shape=2 variant=solution tflops=136.597 gbps=186.198 ms=0.189
shape=2 solution_peak_fraction=0.1034
shape=3 variant=eager tflops=0.177 gbps=47.024 ms=0.190
shape=3 variant=compiled tflops=0.486 gbps=129.305 ms=0.069
shape=3 variant=sota tflops=1.251 gbps=332.983 ms=0.027
shape=3 variant=solution tflops=0.990 gbps=263.494 ms=0.034
shape=3 solution_peak_fraction=0.1464
shape=4 variant=eager tflops=2.053 gbps=34.723 ms=0.915
shape=4 variant=compiled tflops=10.291 gbps=174.076 ms=0.183
shape=4 variant=sota tflops=11.274 gbps=190.704 ms=0.167
shape=4 variant=solution tflops=26.618 gbps=450.263 ms=0.071
shape=4 solution_peak_fraction=0.2501
peak_fraction: 0.1869
RESULT: OK--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -148,7 +148,7 @@
+ (kbase_packed + offs_kh)[:, None] * stride_wk
+ offs_n[None, :] * stride_wn
)
- w_packed = tl.load(w_ptrs, eviction_policy="evict_first")
+ w_packed = tl.load(w_ptrs)
w_lo = (w_packed & 0xF).to(tl.float32)
w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": " w_packed = tl.load(w_ptrs, eviction_policy=\"evict_first\")",
"new_string": " w_packed = tl.load(w_ptrs)"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "find / -name \"kernels.cu\" 2>/dev/null | grep -i bnb | head -3",
"description": "Find bnb kernels source"
}(Bash completed with no output){
"command": "ls .venv/lib/python3.11/site-packages/bitsandbytes/ 2>&1 | head ; find .venv -name \"*.cu\" 2>/dev/null | head -10",
"description": "Look for bnb sources"
}lsd: .venv/lib/python3.11/site-packages/bitsandbytes/: No such file or directory (os error 2).{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/ 2>&1 | head",
"description": "Find bnb dir"
}__init__.py
__main__.py
__pycache__
_ops.py
autograd
backends
cextension.py
consts.py
cuda_specs.py
diagnostics--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -0,0 +1,153 @@
+"""Test a CUDA C++ GEMV via load_inline."""
+import statistics
+import torch
+from torch.utils.cpp_extension import load_inline
+
+
+CUDA_SRC = r"""
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <cstdint>
+
+extern "C" __global__ void w4a16_gemv_v1(
+ const __nv_bfloat16* __restrict__ x,
+ const uint8_t* __restrict__ w_q,
+ const __nv_bfloat16* __restrict__ scales,
+ const __nv_bfloat16* __restrict__ zeros,
+ float* __restrict__ out_fp32,
+ int N, int K
+) {
+ constexpr int BLOCK_N = 64;
+ constexpr int K_PER_SPLIT = 2048;
+ constexpr int BLOCK_K = 256;
+ constexpr int GROUP_SIZE = 128;
+
+ int n_block = blockIdx.x; // [0, N/BLOCK_N)
+ int k_block = blockIdx.y; // [0, SPLIT_K)
+
+ int n0 = n_block * BLOCK_N;
+ int k_start = k_block * K_PER_SPLIT;
+ int k_end = k_start + K_PER_SPLIT;
+
+ int tid = threadIdx.x;
+ int n_threads = blockDim.x; // 128 = 4 warps
+
+ // Each thread accumulates BLOCK_N / n_threads = 0.5 cols. So we have 2 threads per col.
+ // Or: distribute work as (k_chunk, n_chunk) per thread.
+ // Simpler: each thread takes BLOCK_N/n_threads N-cols and full K range.
+ // BLOCK_N=64, n_threads=128 → 0.5 col per thread. So 2 threads per N col.
+
+ // Actually: assign 1 N col per thread (use first 64 threads) for simplicity.
+ if (tid >= BLOCK_N) return;
+
+ int n = n0 + tid;
+ if (n >= N) return;
+
+ float acc = 0.0f;
+
+ for (int k = k_start; k < k_end; k += GROUP_SIZE) {
+ int g = k / GROUP_SIZE;
+ float s = (float)scales[g * N + n];
+ float z = (float)zeros[g * N + n];
+
+ // Process GROUP_SIZE K rows in this group
+ for (int kk = 0; kk < GROUP_SIZE; kk += 2) {
+ int k_row = k + kk;
+ uint8_t b = w_q[(k_row / 2) * N + n];
+ float w_lo = ((float)(b & 0xF) - z) * s;
+ float w_hi = ((float)((b >> 4) & 0xF) - z) * s;
+ float xa = (float)x[k_row];
+ float xb = (float)x[k_row + 1];
+ acc += xa * w_lo + xb * w_hi;
+ }
+ }
+
+ atomicAdd(&out_fp32[n], acc);
+}
+
+void launch(
+ const void* x, const void* w_q, const void* scales, const void* zeros,
+ void* out_fp32, int N, int K
+) {
+ constexpr int BLOCK_N = 64;
+ constexpr int K_PER_SPLIT = 2048;
+ int SPLIT_K = K / K_PER_SPLIT;
+ int n_blocks = (N + BLOCK_N - 1) / BLOCK_N;
+
+ dim3 grid(n_blocks, SPLIT_K);
+ int threads = 128;
+
+ w4a16_gemv_v1<<<grid, threads>>>(
+ (const __nv_bfloat16*)x,
+ (const uint8_t*)w_q,
+ (const __nv_bfloat16*)scales,
+ (const __nv_bfloat16*)zeros,
+ (float*)out_fp32,
+ N, K
+ );
+}
+"""
+
+
+CPP_SRC = r"""
+#include <torch/extension.h>
+extern void launch(const void* x, const void* w_q, const void* scales, const void* zeros, void* out_fp32, int N, int K);
+
+torch::Tensor gemv_cuda(torch::Tensor x, torch::Tensor w_q, torch::Tensor scales, torch::Tensor zeros, int N, int K) {
+ auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(x.device());
+ auto out = torch::zeros({1, N}, opts);
+ launch(x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(), out.data_ptr(), N, K);
+ return out;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("gemv_cuda", &gemv_cuda, "");
+}
+"""
+
+
+def main():
+ mod = load_inline(
+ name="w4a16_cuda_test",
+ cpp_sources=[CPP_SRC],
+ cuda_sources=[CUDA_SRC],
+ functions=["gemv_cuda"],
+ with_cuda=True,
+ extra_cuda_cflags=["-O3", "-arch=sm_120a"],
+ verbose=False,
+ )
+
+ import reference
+ M, N, K = 1, 12288, 4096
+ reference.M = M; reference.N = N; reference.K = K
+ ref = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+ ref_out = ref(x).clone()
+
+ out = mod.gemv_cuda(x, ref.w_q, ref.scales, ref.zeros, N, K).to(torch.bfloat16)
+ print('shape:', out.shape, 'dtype:', out.dtype)
+ print('max diff:', (out.float() - ref_out.float()).abs().max().item())
+ print('correctness:', torch.allclose(out.float(), ref_out.float(), atol=0.1, rtol=0.1))
+
+ # Warmup
+ for _ in range(20): mod.gemv_cuda(x, ref.w_q, ref.scales, ref.zeros, N, K)
+ torch.cuda.synchronize()
+
+ scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
+ times = []
+ for _ in range(30):
+ scratch.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
+ s.record(); mod.gemv_cuda(x, ref.w_q, ref.scales, ref.zeros, N, K); e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+
+ ms = statistics.median(times)
+ bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
+ gbps = bytes_moved/(ms*1e-3)/1e9
+ print(f'CUDA gemv: ms={ms:.4f} gbps={gbps:.1f}')
+
+
+if __name__ == "__main__":
+ main()
{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"content": "<4857 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py (file state is current in your context — no need to Read it back){
"command": "timeout 180 uv run python _test_cuda.py 2>&1 | tail -10",
"description": "Test CUDA C++ baseline"
}13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
| ^~~~~~~~~~~~~~~
/home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test/main.cpp:17:1: error: redefinition of ‘void pybind11_init_w4a16_cuda_test(pybind11::module_&)’
17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
| ^~~~~~~~~~~~~~~
/home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test/main.cpp:13:1: note: ‘void pybind11_init_w4a16_cuda_test(pybind11::module_&)’ previously defined here
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
| ^~~~~~~~~~~~~~~
ninja: build stopped: subcommand failed.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -91,17 +91,13 @@
CPP_SRC = r"""
#include <torch/extension.h>
-extern void launch(const void* x, const void* w_q, const void* scales, const void* zeros, void* out_fp32, int N, int K);
+void launch(const void* x, const void* w_q, const void* scales, const void* zeros, void* out_fp32, int N, int K);
-torch::Tensor gemv_cuda(torch::Tensor x, torch::Tensor w_q, torch::Tensor scales, torch::Tensor zeros, int N, int K) {
+torch::Tensor gemv_cuda(torch::Tensor x, torch::Tensor w_q, torch::Tensor scales, torch::Tensor zeros, int64_t N, int64_t K) {
auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(x.device());
auto out = torch::zeros({1, N}, opts);
- launch(x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(), out.data_ptr(), N, K);
+ launch(x.data_ptr(), w_q.data_ptr(), scales.data_ptr(), zeros.data_ptr(), out.data_ptr(), (int)N, (int)K);
return out;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("gemv_cuda", &gemv_cuda, "");
}
"""
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"old_string": "<618 chars \u2014 see diff>",
"new_string": "<543 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test 2>&1 ; timeout 180 uv run python _test_cuda.py 2>&1 | tail -10",
"description": "Retry CUDA test"
}^
/home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test/cuda.cu(57): error: no suitable conversion function from "const __nv_bfloat16" to "float" exists
float xb = (float)x[k_row + 1];
^
4 errors detected in the compilation of "/home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test/cuda.cu".
[2/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_test -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -fPIC -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test/main.cpp -o main.o
ninja: build stopped: subcommand failed.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -56,7 +56,7 @@
uint8_t b = w_q[(k_row / 2) * N + n];
float w_lo = ((float)(b & 0xF) - z) * s;
float w_hi = ((float)((b >> 4) & 0xF) - z) * s;
- float xa = (float)x[k_row];
+ float xa = __bfloat162float(x[k_row]);
float xb = (float)x[k_row + 1];
acc += xa * w_lo + xb * w_hi;
}
{
"replace_all": true,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"old_string": "(float)x[k_row]",
"new_string": "__bfloat162float(x[k_row])"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py has been updated. All occurrences were successfully replaced. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -57,7 +57,7 @@
float w_lo = ((float)(b & 0xF) - z) * s;
float w_hi = ((float)((b >> 4) & 0xF) - z) * s;
float xa = __bfloat162float(x[k_row]);
- float xb = (float)x[k_row + 1];
+ float xb = __bfloat162float(x[k_row + 1]);
acc += xa * w_lo + xb * w_hi;
}
}
{
"replace_all": true,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"old_string": "(float)x[k_row + 1]",
"new_string": "__bfloat162float(x[k_row + 1])"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py has been updated. All occurrences were successfully replaced. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -47,7 +47,7 @@
for (int k = k_start; k < k_end; k += GROUP_SIZE) {
int g = k / GROUP_SIZE;
- float s = (float)scales[g * N + n];
+ float s = __bfloat162float(scales[g * N + n]);
float z = (float)zeros[g * N + n];
// Process GROUP_SIZE K rows in this group
{
"replace_all": true,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"old_string": "(float)scales[g * N + n]",
"new_string": "__bfloat162float(scales[g * N + n])"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py has been updated. All occurrences were successfully replaced. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py
@@ -48,7 +48,7 @@
for (int k = k_start; k < k_end; k += GROUP_SIZE) {
int g = k / GROUP_SIZE;
float s = __bfloat162float(scales[g * N + n]);
- float z = (float)zeros[g * N + n];
+ float z = __bfloat162float(zeros[g * N + n]);
// Process GROUP_SIZE K rows in this group
for (int kk = 0; kk < GROUP_SIZE; kk += 2) {
{
"replace_all": true,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py",
"old_string": "(float)zeros[g * N + n]",
"new_string": "__bfloat162float(zeros[g * N + n])"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_cuda.py has been updated. All occurrences were successfully replaced. (file state is current in your context — no need to Read it back){
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/w4a16_cuda_test 2>&1 ; timeout 180 uv run python _test_cuda.py 2>&1 | tail -10",
"description": "Retry CUDA test"
}shape: torch.Size([1, 12288]) dtype: torch.bfloat16
max diff: 0.03125
correctness: True
CUDA gemv: ms=0.1355 gbps=197.6--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -196,15 +196,18 @@
BLOCK_K = 512 if K % 512 == 0 else 256
SPLIT_K = 8
num_warps = 8
+ num_stages = 2
elif n_tiles <= 128:
BLOCK_K = 256
SPLIT_K = 4
num_warps = 4
+ num_stages = 3
else: # large N (e.g., shape 0: N=12288)
BLOCK_K = 256
- SPLIT_K = 2
+ SPLIT_K = 4
num_warps = 4
- return BLOCK_N, BLOCK_K, SPLIT_K, num_warps, 2
+ num_stages = 3
+ return BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
def _gemm_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<606 chars \u2014 see diff>",
"new_string": "<684 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Bench tweaked config"
}PASS
shape=0 variant=eager tflops=0.131 gbps=34.785 ms=0.770
shape=0 variant=compiled tflops=0.636 gbps=169.166 ms=0.158
shape=0 variant=sota tflops=2.253 gbps=599.075 ms=0.045
shape=0 variant=solution tflops=1.598 gbps=424.890 ms=0.063
shape=0 solution_peak_fraction=0.2360
shape=1 variant=eager tflops=4.146 gbps=35.761 ms=0.777
shape=1 variant=compiled tflops=20.224 gbps=174.455 ms=0.159
shape=1 variant=sota tflops=23.009 gbps=198.480 ms=0.140
shape=1 variant=solution tflops=45.766 gbps=394.795 ms=0.070
shape=1 solution_peak_fraction=0.2193
shape=2 variant=eager tflops=31.118 gbps=42.417 ms=0.828
shape=2 variant=compiled tflops=118.340 gbps=161.312 ms=0.218
shape=2 variant=sota tflops=132.072 gbps=180.029 ms=0.195
shape=2 variant=solution tflops=137.436 gbps=187.342 ms=0.188
shape=2 solution_peak_fraction=0.1041
shape=3 variant=eager tflops=0.178 gbps=47.415 ms=0.188
shape=3 variant=compiled tflops=0.467 gbps=124.183 ms=0.072
shape=3 variant=sota tflops=1.273 gbps=338.846 ms=0.026
shape=3 variant=solution tflops=0.990 gbps=263.494 ms=0.034
shape=3 solution_peak_fraction=0.1464
shape=4 variant=eager tflops=2.036 gbps=34.440 ms=0.923
shape=4 variant=compiled tflops=10.027 gbps=169.617 ms=0.187
shape=4 variant=sota tflops=11.189 gbps=189.268 ms=0.168
shape=4 variant=solution tflops=24.503 gbps=414.471 ms=0.077
shape=4 solution_peak_fraction=0.2303
peak_fraction: 0.1786
RESULT: OK--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -204,9 +204,9 @@
num_stages = 3
else: # large N (e.g., shape 0: N=12288)
BLOCK_K = 256
- SPLIT_K = 4
+ SPLIT_K = 2
num_warps = 4
- num_stages = 3
+ num_stages = 2
return BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<192 chars \u2014 see diff>",
"new_string": "<192 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Re-bench with reverted shape0 config"
}PASS
shape=0 variant=eager tflops=0.131 gbps=34.862 ms=0.768
shape=0 variant=compiled tflops=0.638 gbps=169.577 ms=0.158
shape=0 variant=sota tflops=2.145 gbps=570.479 ms=0.047
shape=0 variant=solution tflops=1.881 gbps=500.214 ms=0.054
shape=0 solution_peak_fraction=0.2779
shape=1 variant=eager tflops=4.144 gbps=35.747 ms=0.777
shape=1 variant=compiled tflops=20.054 gbps=172.996 ms=0.161
shape=1 variant=sota tflops=22.273 gbps=192.135 ms=0.145
shape=1 variant=solution tflops=46.271 gbps=399.151 ms=0.070
shape=1 solution_peak_fraction=0.2218
shape=2 variant=eager tflops=31.168 gbps=42.486 ms=0.827
shape=2 variant=compiled tflops=120.114 gbps=163.730 ms=0.215
shape=2 variant=sota tflops=132.180 gbps=180.177 ms=0.195
shape=2 variant=solution tflops=138.155 gbps=188.322 ms=0.187
shape=2 solution_peak_fraction=0.1046
shape=3 variant=eager tflops=0.177 gbps=47.091 ms=0.190
shape=3 variant=compiled tflops=0.478 gbps=127.125 ms=0.070
shape=3 variant=sota tflops=1.224 gbps=325.601 ms=0.027
shape=3 variant=solution tflops=0.937 gbps=249.477 ms=0.036
shape=3 solution_peak_fraction=0.1386
shape=4 variant=eager tflops=2.043 gbps=34.558 ms=0.920
shape=4 variant=compiled tflops=10.124 gbps=171.255 ms=0.186
shape=4 variant=sota tflops=11.080 gbps=187.429 ms=0.170
shape=4 variant=solution tflops=28.484 gbps=481.824 ms=0.066
shape=4 solution_peak_fraction=0.2677
peak_fraction: 0.1887
RESULT: OK{
"command": "<1173 chars \u2014 see below>",
"description": "Measure overheads"
}uv run python -c "
import torch
import statistics
t0 = torch.cuda.Event(enable_timing=True)
t1 = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(20): torch.zeros((1, 12288), dtype=torch.float32, device='cuda')
torch.cuda.synchronize()
times = []
for _ in range(50):
torch.cuda.synchronize()
t0.record(); x = torch.zeros((1, 12288), dtype=torch.float32, device='cuda'); t1.record()
torch.cuda.synchronize()
times.append(t0.elapsed_time(t1))
print(f'torch.zeros (12288,) fp32: {statistics.median(times)*1000:.2f} us')
times = []
for _ in range(50):
torch.cuda.synchronize()
t0.record(); x.zero_(); t1.record()
torch.cuda.synchronize()
times.append(t0.elapsed_time(t1))
print(f'tensor.zero_(): {statistics.median(times)*1000:.2f} us')
# Time .to(bf16)
y = torch.zeros((1, 12288), dtype=torch.float32, device='cuda')
for _ in range(20): y.to(torch.bfloat16)
torch.cuda.synchronize()
times = []
for _ in range(50):
torch.cuda.synchronize()
t0.record(); _ = y.to(torch.bfloat16); t1.record()
torch.cuda.synchronize()
times.append(t0.elapsed_time(t1))
print(f'.to(bf16): {statistics.median(times)*1000:.2f} us')
"torch.zeros (12288,) fp32: 9.74 us
tensor.zero_(): 8.03 us
.to(bf16): 6.34 us{
"command": "which ncu 2>&1 ; ncu --version 2>&1 | head -3",
"description": "Check ncu"
}/usr/local/cuda-13/bin/ncu
NVIDIA (R) Nsight Compute Command Line Profiler
Copyright (c) 2018-2026 NVIDIA Corporation
Version 2026.1.0.0 (build 37166530) (public-release){
"command": "<584 chars \u2014 see below>",
"description": "Profile GEMV with ncu"
}cat > /tmp/profile_gemv.py << 'EOF'
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
for _ in range(20): sol(x)
torch.cuda.synchronize()
# Just one launch under ncu
sol(x)
torch.cuda.synchronize()
EOF
ncu --set basic --target-processes all --launch-skip 25 --launch-count 1 -k regex:gemv uv run python /tmp/profile_gemv.py 2>&1 | head -80Traceback (most recent call last):
File "/tmp/profile_gemv.py", line 2, in <module>
import solution
ModuleNotFoundError: No module named 'solution'
==ERROR== The application returned an error code (1).{
"command": "<738 chars \u2014 see below>",
"description": "Profile GEMV with ncu"
}cat > /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_profile.py << 'EOF'
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
for _ in range(20): sol(x)
torch.cuda.synchronize()
torch.cuda.cudart().cudaProfilerStart()
sol(x)
torch.cuda.synchronize()
torch.cuda.cudart().cudaProfilerStop()
EOF
ncu --set basic --target-processes all --launch-skip 0 --launch-count 1 -k regex:gemv uv run python /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_profile.py 2>&1 | head -80==PROF== Connected to process 305111 (/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==ERROR== ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
==PROF== Disconnected from process 305111--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -111,6 +111,18 @@
# ---------------------------------------------------------------------------
# GEMV kernel (M == 1) with split-K. Writes (SPLIT_K, N) fp32 partials.
# ---------------------------------------------------------------------------
+@triton.jit
+def cast_fp32_to_bf16_kernel(
+ src_ptr, dst_ptr, N,
+ BLOCK_N: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask = offs < N
+ val = tl.load(src_ptr + offs, mask=mask, other=0.0)
+ tl.store(dst_ptr + offs, val.to(tl.bfloat16), mask=mask)
+
+
@triton.jit
def w4a16_gemv_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<34 chars \u2014 see diff>",
"new_string": "<346 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -271,6 +271,7 @@
return out
# SPLIT_K > 1: fp32 atomic_add into pre-zeroed fp32 buffer, then cast.
out_fp32 = torch.zeros((1, N), dtype=torch.float32, device=x.device)
+ out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
w4a16_gemv_kernel[grid](
x, w_q, scales, zeros, out_fp32,
N, K,
@@ -287,7 +288,10 @@
num_warps=num_warps,
num_stages=num_stages,
)
- return out_fp32.to(torch.bfloat16)
+ cast_fp32_to_bf16_kernel[(triton.cdiv(N, 1024),)](
+ out_fp32, out, N, BLOCK_N=1024, num_warps=4, num_stages=2,
+ )
+ return out
BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _gemm_config(M, N, K)
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<724 chars \u2014 see diff>",
"new_string": "<913 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Bench with custom cast kernel"
}PASS
shape=0 variant=eager tflops=0.130 gbps=34.645 ms=0.773
shape=0 variant=compiled tflops=0.640 gbps=170.111 ms=0.157
shape=0 variant=sota tflops=2.253 gbps=599.289 ms=0.045
shape=0 variant=solution tflops=1.859 gbps=494.449 ms=0.054
shape=0 solution_peak_fraction=0.2747
shape=1 variant=eager tflops=4.144 gbps=35.746 ms=0.777
shape=1 variant=compiled tflops=19.216 gbps=165.763 ms=0.168
shape=1 variant=sota tflops=22.290 gbps=192.283 ms=0.145
shape=1 variant=solution tflops=46.165 gbps=398.235 ms=0.070
shape=1 solution_peak_fraction=0.2212
shape=2 variant=eager tflops=31.297 gbps=42.662 ms=0.823
shape=2 variant=compiled tflops=117.761 gbps=160.522 ms=0.219
shape=2 variant=sota tflops=131.393 gbps=179.104 ms=0.196
shape=2 variant=solution tflops=137.871 gbps=187.935 ms=0.187
shape=2 solution_peak_fraction=0.1044
shape=3 variant=eager tflops=0.178 gbps=47.259 ms=0.189
shape=3 variant=compiled tflops=0.503 gbps=133.832 ms=0.067
shape=3 variant=sota tflops=0.951 gbps=253.098 ms=0.035
shape=3 variant=solution tflops=1.002 gbps=266.514 ms=0.034
shape=3 solution_peak_fraction=0.1481
shape=4 variant=eager tflops=2.046 gbps=34.605 ms=0.918
shape=4 variant=compiled tflops=10.045 gbps=169.908 ms=0.187
shape=4 variant=sota tflops=11.082 gbps=187.464 ms=0.170
shape=4 variant=solution tflops=28.519 gbps=482.409 ms=0.066
shape=4 solution_peak_fraction=0.2680
peak_fraction: 0.1906
RESULT: OK--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_gemm_m1.py
@@ -0,0 +1,168 @@
+"""For M=1, test the tensor-core GEMM kernel padded to BLOCK_M=16 with split-K."""
+import statistics
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def gemm_split_k_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ K_PER_SPLIT: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ pid_k = tl.program_id(2)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ BLOCK_KH: tl.constexpr = BLOCK_K // 2
+ offs_kh = tl.arange(0, BLOCK_KH)
+ offs_k = tl.arange(0, BLOCK_K)
+ n_iters: tl.constexpr = K_PER_SPLIT // BLOCK_K
+ n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
+ k_start_full = pid_k * K_PER_SPLIT
+
+ for k_iter in range(n_iters):
+ kbase_full = k_start_full + k_iter * BLOCK_K
+ kbase_packed = kbase_full // 2
+
+ w_ptrs = w_ptr + (kbase_packed + offs_kh)[:, None] * stride_wk + offs_n[None, :] * stride_wn
+ w_packed = tl.load(w_ptrs)
+ w_lo = (w_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
+ w_lo_bf = (w_lo - z[None, :]) * s[None, :]
+ w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn)
+ z = tl.load(z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn)
+ w_lo_bf = (w_lo - z) * s
+ w_hi_bf = (w_hi - z) * s
+
+ w_join = tl.join(w_lo_bf, w_hi_bf)
+ w_perm = tl.permute(w_join, (0, 2, 1))
+ w = tl.reshape(w_perm, (BLOCK_K, BLOCK_N))
+
+ x_ptrs = x_ptr + offs_m[:, None] * stride_xm + (kbase_full + offs_k)[None, :] * stride_xk
+ x_mask = offs_m[:, None] < M
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ accumulator = tl.dot(x, w, accumulator)
+
+ out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_mask = offs_m[:, None] < M
+ if SPLIT_K == 1:
+ tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=out_mask)
+ else:
+ tl.atomic_add(out_ptrs, accumulator, mask=out_mask)
+
+
+def run(M, N, K, w_q, scales, zeros, x, BM, BN, BK, SK, NW, NS):
+ K_PER_SPLIT = K // SK
+ if SK == 1:
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ else:
+ out = torch.zeros((M, N), dtype=torch.float32, device=x.device)
+
+ grid = ((M + BM - 1) // BM, (N + BN - 1) // BN, SK)
+ gemm_split_k_kernel[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ GROUP_SIZE=GROUP_SIZE,
+ BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK,
+ K_PER_SPLIT=K_PER_SPLIT, SPLIT_K=SK,
+ num_warps=NW, num_stages=NS,
+ )
+ if SK > 1:
+ out = out.to(torch.bfloat16)
+ return out
+
+
+def time_call(fn, args, iters=20, warmup=5):
+ scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device="cuda")
+ for _ in range(warmup):
+ fn(*args)
+ torch.cuda.synchronize()
+ times = []
+ for _ in range(iters):
+ scratch.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ fn(*args)
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+def main():
+ import reference
+ for sh, (M, N, K) in [(0, (1, 12288, 4096)), (3, (1, 4096, 4096))]:
+ reference.M = M
+ reference.N = N
+ reference.K = K
+ ref = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+ ref_out = ref(x).clone()
+
+ bytes_moved = M * K * 2 + (K // 2) * N + (K // 128) * N * 2 + (K // 128) * N * 2 + M * N * 2
+ print(f"\n=== shape {sh} M={M} N={N} K={K} ===", flush=True)
+
+ for cfg in [
+ (16, 64, 128, 1, 4, 2),
+ (16, 64, 128, 2, 4, 2),
+ (16, 64, 128, 4, 4, 2),
+ (16, 64, 256, 1, 4, 2),
+ (16, 64, 256, 2, 4, 2),
+ (16, 64, 256, 4, 4, 2),
+ (16, 128, 128, 1, 4, 2),
+ (16, 128, 128, 2, 4, 2),
+ (16, 128, 256, 2, 8, 2),
+ (16, 256, 128, 4, 8, 2),
+ ]:
+ try:
+ BM, BN, BK, SK, NW, NS = cfg
+ if K % (BK * SK) != 0:
+ continue
+ sol_out = run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, BM, BN, BK, SK, NW, NS)
+ if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
+ print(f"cfg={cfg} FAIL")
+ continue
+ ms = time_call(run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, BM, BN, BK, SK, NW, NS))
+ gbps = bytes_moved / (ms * 1e-3) / 1e9
+
... (181 more chars){
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_gemm_m1.py",
"content": "<5897 chars \u2014 see diff>"
}File created successfully at: /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/_test_gemm_m1.py (file state is current in your context — no need to Read it back){
"command": "timeout 200 uv run python _test_gemm_m1.py 2>&1 | tail -30",
"description": "Test GEMM for M=1"
}=== shape 0 M=1 N=12288 K=4096 ===
cfg=(16, 64, 128, 1, 4, 2) ms=0.0653 gbps=409.9
cfg=(16, 64, 128, 2, 4, 2) ms=0.0580 gbps=461.8
cfg=(16, 64, 128, 4, 4, 2) ms=0.0481 gbps=556.3
cfg=(16, 64, 256, 1, 4, 2) ms=0.0594 gbps=450.9
cfg=(16, 64, 256, 2, 4, 2) ms=0.0533 gbps=502.0
cfg=(16, 64, 256, 4, 4, 2) ms=0.0484 gbps=552.6
cfg=(16, 128, 128, 1, 4, 2) ms=0.0980 gbps=273.0
cfg=(16, 128, 128, 2, 4, 2) ms=0.0747 gbps=358.4
cfg=(16, 128, 256, 2, 8, 2) ms=0.0809 gbps=330.7
cfg=(16, 256, 128, 4, 8, 2) ms=0.1015 gbps=263.7
=== shape 3 M=1 N=4096 K=4096 ===
cfg=(16, 64, 128, 1, 4, 2) ms=0.0547 gbps=163.3
cfg=(16, 64, 128, 2, 4, 2) ms=0.0420 gbps=212.4
cfg=(16, 64, 128, 4, 4, 2) ms=0.0334 gbps=267.5
cfg=(16, 64, 256, 1, 4, 2) ms=0.0483 gbps=184.9
cfg=(16, 64, 256, 2, 4, 2) ms=0.0348 gbps=256.6
cfg=(16, 64, 256, 4, 4, 2) ms=0.0275 gbps=325.2
cfg=(16, 128, 128, 1, 4, 2) ms=0.0916 gbps=97.4
cfg=(16, 128, 128, 2, 4, 2) ms=0.0580 gbps=154.0
cfg=(16, 128, 256, 2, 8, 2) ms=0.0532 gbps=168.0
cfg=(16, 256, 128, 4, 8, 2) ms=0.0647 gbps=138.0{
"command": "<1534 chars \u2014 see below>",
"description": "Test extra shape 2 configs"
}cat > /tmp/_tune_s2.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import _tune_gemm
_tune_gemm.main()
EOF
timeout 600 uv run python -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import _tune_gemm
import torch
import statistics
import reference
M, N, K = 256, 12288, 4096
reference.M = M; reference.N = N; reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
ref_out = ref(x).clone()
bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
print(f'shape 2 M={M} N={N} K={K}')
# Larger configs not in original tuner
configs = []
for bm in [128, 256]:
for bn in [64, 128]:
for bk in [128, 256]:
for nw in [4, 8]:
for ns in [2, 3]:
configs.append((bm, bn, bk, nw, ns))
for cfg in configs:
bm, bn, bk, nw, ns = cfg
try:
sol = _tune_gemm.run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bm, bn, bk, nw, ns)
if not torch.allclose(ref_out.float(), sol.float(), atol=0.1, rtol=0.1):
continue
ms = _tune_gemm.time_call(_tune_gemm.run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bm, bn, bk, nw, ns), iters=20, warmup=8)
gbps = bytes_moved/(ms*1e-3)/1e9
flops = 2*M*N*K
tflops = flops/(ms*1e-3)/1e12
print(f' cfg={cfg} ms={ms:.4f} gbps={gbps:.1f} tflops={tflops:.1f}')
except Exception as e:
pass
"shape 2 M=256 N=12288 K=4096
cfg=(128, 64, 128, 4, 2) ms=0.2171 gbps=161.8 tflops=118.7
cfg=(128, 64, 128, 4, 3) ms=0.2818 gbps=124.6 tflops=91.4
cfg=(128, 64, 128, 8, 2) ms=0.2344 gbps=149.9 tflops=110.0
cfg=(128, 64, 128, 8, 3) ms=0.2052 gbps=171.2 tflops=125.6
cfg=(128, 64, 256, 4, 2) ms=0.2535 gbps=138.6 tflops=101.7
cfg=(128, 64, 256, 8, 2) ms=0.1822 gbps=192.8 tflops=141.4
cfg=(128, 128, 128, 4, 2) ms=0.3330 gbps=105.5 tflops=77.4
cfg=(128, 128, 128, 4, 3) ms=0.3825 gbps=91.8 tflops=67.4
cfg=(128, 128, 128, 8, 2) ms=0.2671 gbps=131.5 tflops=96.5
cfg=(128, 128, 128, 8, 3) ms=0.2991 gbps=117.4 tflops=86.2
cfg=(256, 64, 128, 4, 2) ms=0.4340 gbps=80.9 tflops=59.4
cfg=(256, 64, 128, 8, 2) ms=0.2720 gbps=129.2 tflops=94.7
cfg=(256, 128, 128, 4, 2) ms=3.0895 gbps=11.4 tflops=8.3
cfg=(256, 128, 128, 8, 2) ms=0.2899 gbps=121.2 tflops=88.9{
"command": "<709 chars \u2014 see below>",
"description": "Inspect compiled Triton cache"
}uv run python -c "
import torch
import solution
import reference
torch.manual_seed(42)
M, N, K = 1, 12288, 4096
ref = reference.Model(M, N, K).cuda().eval()
sol = solution.Model(M, N, K).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
sol(x); torch.cuda.synchronize()
# Get the kernel and its compiled binaries
import triton
# Show compiled cache
print(list(solution.w4a16_gemv_kernel.cache.keys())[0])
key = list(solution.w4a16_gemv_kernel.cache.keys())[0]
v = solution.w4a16_gemv_kernel.cache[key]
print(v)
print(list(v.keys())[0])
kbin = list(v.values())[0]
print('Type:', type(kbin))
print('Attrs:', dir(kbin))
" 2>&1 | tail -20Traceback (most recent call last):
File "<string>", line 16, in <module>
AttributeError: 'JITFunction' object has no attribute 'cache'{
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"offset": 240,
"limit": 50
}240 zeros: torch.Tensor,
241 group_size: int = GROUP_SIZE,
242 ) -> torch.Tensor:
243 M, K = x.shape
244 K_packed, N = w_q.shape
245 assert K_packed * 2 == K
246 assert K % group_size == 0
247
248 if M == 1:
249 BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = _gemv_config(N, K)
250 K_PER_SPLIT = K // SPLIT_K
251 n_tiles = triton.cdiv(N, BLOCK_N)
252 grid = (n_tiles, SPLIT_K)
253 if SPLIT_K == 1:
254 out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
255 w4a16_gemv_kernel[grid](
256 x, w_q, scales, zeros, out,
257 N, K,
258 x.stride(1),
259 w_q.stride(0), w_q.stride(1),
260 scales.stride(0), scales.stride(1),
261 zeros.stride(0), zeros.stride(1),
262 out.stride(1),
263 GROUP_SIZE=group_size,
264 BLOCK_N=BLOCK_N,
265 BLOCK_K=BLOCK_K,
266 K_PER_SPLIT=K_PER_SPLIT,
267 SPLIT_K=SPLIT_K,
268 num_warps=num_warps,
269 num_stages=num_stages,
270 )
271 return out
272 # SPLIT_K > 1: fp32 atomic_add into pre-zeroed fp32 buffer, then cast.
273 out_fp32 = torch.zeros((1, N), dtype=torch.float32, device=x.device)
274 out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
275 w4a16_gemv_kernel[grid](
... (539 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -58,6 +58,9 @@
n_iters = K // BLOCK_K
n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
+ GROUPS_PER_BLOCK: tl.constexpr = BLOCK_K // GROUP_SIZE
+ KH_PER_GROUP: tl.constexpr = GROUP_SIZE // 2
+
for k_iter in range(n_iters):
kbase_packed = k_iter * BLOCK_KH
kbase_full = k_iter * BLOCK_K
@@ -72,22 +75,22 @@
w_lo = (w_packed & 0xF).to(tl.bfloat16)
w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- if n_groups_per_block == 1:
- group_idx = kbase_full // GROUP_SIZE
- s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
- z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
- w_lo_bf = (w_lo - z[None, :]) * s[None, :]
- w_hi_bf = (w_hi - z[None, :]) * s[None, :]
- else:
- group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
- s = tl.load(
- s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
- )
- z = tl.load(
- z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
- )
- w_lo_bf = (w_lo - z) * s
- w_hi_bf = (w_hi - z) * s
+ # Load (GROUPS_PER_BLOCK, BLOCK_N) of (s, z), then broadcast to BLOCK_KH rows.
+ group_base = kbase_full // GROUP_SIZE
+ group_offs = group_base + tl.arange(0, GROUPS_PER_BLOCK)
+ s_g = tl.load(s_ptr + group_offs[:, None] * stride_sg + offs_n[None, :] * stride_sn)
+ z_g = tl.load(z_ptr + group_offs[:, None] * stride_zg + offs_n[None, :] * stride_zn)
+ # Broadcast to (BLOCK_KH, BLOCK_N): each row repeats KH_PER_GROUP times.
+ s_b = tl.reshape(
+ tl.broadcast_to(s_g[:, None, :], (GROUPS_PER_BLOCK, KH_PER_GROUP, BLOCK_N)),
+ (BLOCK_KH, BLOCK_N),
+ )
+ z_b = tl.reshape(
+ tl.broadcast_to(z_g[:, None, :], (GROUPS_PER_BLOCK, KH_PER_GROUP, BLOCK_N)),
+ (BLOCK_KH, BLOCK_N),
+ )
+ w_lo_bf = (w_lo - z_b) * s_b
+ w_hi_bf = (w_hi - z_b) * s_b
w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<1901 chars \u2014 see diff>",
"new_string": "<2077 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness"
}PASS{
"command": "uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Bench with refactored multi-group"
}shape=0 variant=eager tflops=0.131 gbps=34.823 ms=0.769
shape=0 variant=compiled tflops=0.634 gbps=168.654 ms=0.159
shape=0 variant=sota tflops=2.257 gbps=600.149 ms=0.045
shape=0 variant=solution tflops=1.787 gbps=475.345 ms=0.056
shape=0 solution_peak_fraction=0.2641
shape=1 variant=eager tflops=4.145 gbps=35.758 ms=0.777
shape=1 variant=compiled tflops=19.995 gbps=172.480 ms=0.161
shape=1 variant=sota tflops=22.596 gbps=194.916 ms=0.143
shape=1 variant=solution tflops=42.845 gbps=369.590 ms=0.075
shape=1 solution_peak_fraction=0.2053
shape=2 variant=eager tflops=31.249 gbps=42.596 ms=0.825
shape=2 variant=compiled tflops=120.078 gbps=163.681 ms=0.215
shape=2 variant=sota tflops=132.561 gbps=180.696 ms=0.194
shape=2 variant=solution tflops=120.953 gbps=164.874 ms=0.213
shape=2 solution_peak_fraction=0.0916
shape=3 variant=eager tflops=0.177 gbps=47.119 ms=0.190
shape=3 variant=compiled tflops=0.450 gbps=119.863 ms=0.074
shape=3 variant=sota tflops=1.238 gbps=329.445 ms=0.027
shape=3 variant=solution tflops=0.862 gbps=229.474 ms=0.039
shape=3 solution_peak_fraction=0.1275
shape=4 variant=eager tflops=2.050 gbps=34.671 ms=0.917
shape=4 variant=compiled tflops=10.298 gbps=174.199 ms=0.182
shape=4 variant=sota tflops=11.286 gbps=190.905 ms=0.166
shape=4 variant=solution tflops=24.955 gbps=422.133 ms=0.075
shape=4 solution_peak_fraction=0.2345
peak_fraction: 0.1715
RESULT: OK--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -58,9 +58,6 @@
n_iters = K // BLOCK_K
n_groups_per_block: tl.constexpr = BLOCK_K // GROUP_SIZE
- GROUPS_PER_BLOCK: tl.constexpr = BLOCK_K // GROUP_SIZE
- KH_PER_GROUP: tl.constexpr = GROUP_SIZE // 2
-
for k_iter in range(n_iters):
kbase_packed = k_iter * BLOCK_KH
kbase_full = k_iter * BLOCK_K
@@ -75,22 +72,22 @@
w_lo = (w_packed & 0xF).to(tl.bfloat16)
w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Load (GROUPS_PER_BLOCK, BLOCK_N) of (s, z), then broadcast to BLOCK_KH rows.
- group_base = kbase_full // GROUP_SIZE
- group_offs = group_base + tl.arange(0, GROUPS_PER_BLOCK)
- s_g = tl.load(s_ptr + group_offs[:, None] * stride_sg + offs_n[None, :] * stride_sn)
- z_g = tl.load(z_ptr + group_offs[:, None] * stride_zg + offs_n[None, :] * stride_zn)
- # Broadcast to (BLOCK_KH, BLOCK_N): each row repeats KH_PER_GROUP times.
- s_b = tl.reshape(
- tl.broadcast_to(s_g[:, None, :], (GROUPS_PER_BLOCK, KH_PER_GROUP, BLOCK_N)),
- (BLOCK_KH, BLOCK_N),
- )
- z_b = tl.reshape(
- tl.broadcast_to(z_g[:, None, :], (GROUPS_PER_BLOCK, KH_PER_GROUP, BLOCK_N)),
- (BLOCK_KH, BLOCK_N),
- )
- w_lo_bf = (w_lo - z_b) * s_b
- w_hi_bf = (w_hi - z_b) * s_b
+ if n_groups_per_block == 1:
+ group_idx = kbase_full // GROUP_SIZE
+ s = tl.load(s_ptr + group_idx * stride_sg + offs_n * stride_sn)
+ z = tl.load(z_ptr + group_idx * stride_zg + offs_n * stride_zn)
+ w_lo_bf = (w_lo - z[None, :]) * s[None, :]
+ w_hi_bf = (w_hi - z[None, :]) * s[None, :]
+ else:
+ group_idx_per_row = (kbase_full // GROUP_SIZE) + (2 * offs_kh) // GROUP_SIZE
+ s = tl.load(
+ s_ptr + group_idx_per_row[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ )
+ z = tl.load(
+ z_ptr + group_idx_per_row[:, None] * stride_zg + offs_n[None, :] * stride_zn
+ )
+ w_lo_bf = (w_lo - z) * s
+ w_hi_bf = (w_hi - z) * s
w_join = tl.join(w_lo_bf, w_hi_bf) # (BLOCK_KH, BLOCK_N, 2)
w_perm = tl.permute(w_join, (0, 2, 1)) # (BLOCK_KH, 2, BLOCK_N)
{
"replace_all": false,
"file_path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"old_string": "<2077 chars \u2014 see diff>",
"new_string": "<1901 chars \u2014 see diff>"
}The file /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | tail -3 && uv run python benchmark.py 2>&1 | grep -E \"shape=|peak_fraction|RESULT\"",
"description": "Verify reverted"
}PASS
shape=0 variant=eager tflops=0.131 gbps=34.936 ms=0.766
shape=0 variant=compiled tflops=0.602 gbps=160.162 ms=0.167
shape=0 variant=sota tflops=2.278 gbps=605.799 ms=0.044
shape=0 variant=solution tflops=1.770 gbps=470.666 ms=0.057
shape=0 solution_peak_fraction=0.2615
shape=1 variant=eager tflops=4.139 gbps=35.707 ms=0.778
shape=1 variant=compiled tflops=20.094 gbps=173.341 ms=0.160
shape=1 variant=sota tflops=23.170 gbps=199.874 ms=0.139
shape=1 variant=solution tflops=46.282 gbps=399.242 ms=0.070
shape=1 solution_peak_fraction=0.2218
shape=2 variant=eager tflops=31.237 gbps=42.580 ms=0.825
shape=2 variant=compiled tflops=117.221 gbps=159.786 ms=0.220
shape=2 variant=sota tflops=131.511 gbps=179.265 ms=0.196
shape=2 variant=solution tflops=136.655 gbps=186.277 ms=0.189
shape=2 solution_peak_fraction=0.1035
shape=3 variant=eager tflops=0.177 gbps=46.984 ms=0.190
shape=3 variant=compiled tflops=0.503 gbps=133.928 ms=0.067
shape=3 variant=sota tflops=0.959 gbps=255.297 ms=0.035
shape=3 variant=solution tflops=1.005 gbps=267.536 ms=0.033
shape=3 solution_peak_fraction=0.1486
shape=4 variant=eager tflops=2.042 gbps=34.542 ms=0.920
shape=4 variant=compiled tflops=9.964 gbps=168.553 ms=0.189
shape=4 variant=sota tflops=11.116 gbps=188.032 ms=0.169
shape=4 variant=solution tflops=28.340 gbps=479.382 ms=0.066
shape=4 solution_peak_fraction=0.2663
peak_fraction: 0.1884
RESULT: OK{
"command": "<1220 chars \u2014 see below>",
"description": "More configs for shape 3"
}uv run python -c "
import statistics
import torch
import _tune_gemv as tg
import reference
torch.manual_seed(42)
M, N, K = 1, 4096, 4096
reference.M = M; reference.N = N; reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
ref_out = ref(x).clone()
# Test best configs explicitly
for cfg in [(64, 512, 8, 8, 2), (64, 256, 8, 4, 2), (64, 256, 8, 8, 2), (64, 256, 4, 4, 2), (64, 512, 4, 4, 2), (64, 128, 8, 4, 2), (64, 128, 16, 4, 2), (32, 512, 8, 4, 2), (32, 512, 8, 8, 2)]:
try:
bn, bk, sk, nw, ns = cfg
if K % (bk * sk) != 0: continue
sol_out = tg.run(M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns)
if not torch.allclose(ref_out.float(), sol_out.float(), atol=0.1, rtol=0.1):
print(f'{cfg} FAIL')
continue
ms = tg.time_call(tg.run, (M, N, K, ref.w_q, ref.scales, ref.zeros, x, bn, bk, sk, nw, ns), iters=30, warmup=8)
bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
gbps = bytes_moved/(ms*1e-3)/1e9
print(f'{cfg} ms={ms:.4f} gbps={gbps:.1f}')
except Exception as e:
print(f'{cfg} ERR: {e}')
" 2>&1 | tail -15