"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel, optimized.
Persistent kernel for M=1 (amortize launch overhead), standard tiled for M>1.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
OP_TYPE = "gemm_w4a16"
SUPPORTS_PRECISIONS = ["int4_bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
GROUP_SIZE = 128
@triton.autotune(
    configs=[
        # BLOCK_N x BLOCK_K_HALF tiles.  Narrow-N configs suit the square
        # decode shape; wide-N configs suit the large projections (N=12288+).
        triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_N': 512, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K_half'],
)
@triton.jit
def _w4a16_gemv_persistent_kernel(
    X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
    M, N, K_half, n_groups,
    stride_xm, stride_xk,
    stride_wq_kh, stride_wq_n,
    stride_s_g, stride_s_n,
    stride_z_g, stride_z_n,
    stride_om, stride_on,
    GROUP_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K_HALF: tl.constexpr,
):
    """Persistent W4A16 GEMV for the M == 1 (decode) case.

    The host launches a fixed-size grid (one program per SM) and each
    program grid-strides over the N-blocks, amortizing launch overhead.
    X and Out are indexed without an M term because M == 1, so stride_xm
    and stride_om are unused here.
    """
    num_programs = tl.num_programs(0)
    total_n_blocks = tl.cdiv(N, BLOCK_N)
    # Grid-stride loop over output column blocks: pid, pid+P, pid+2P, ...
    for pid in range(tl.program_id(0), total_n_blocks, num_programs):
        n_start = pid * BLOCK_N
        n_off = n_start + tl.arange(0, BLOCK_N)
        n_mask = n_off < N
        # fp32 accumulator for the reduction along K.
        acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
        for kh_start in range(0, K_half, BLOCK_K_HALF):
            kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
            kh_mask = kh_off < K_half
            # Each packed byte at row kh holds weight k=2*kh in the low
            # nibble and k=2*kh+1 in the high nibble.
            k_even = kh_off * 2
            k_odd = k_even + 1
            x_even = tl.load(
                X_ptr + k_even * stride_xk,
                mask=kh_mask,
                other=0.0,
            )
            x_odd = tl.load(
                X_ptr + k_odd * stride_xk,
                mask=kh_mask,
                other=0.0,
            )
            wq_packed = tl.load(
                Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n,
                mask=kh_mask[:, None] & n_mask[None, :],
                other=0,
            )
            # Unpack the two 4-bit values (0..15) from each byte.
            wq_lo = wq_packed & 0x0F
            wq_hi = (wq_packed >> 4) & 0x0F
            # Scalar group index for this entire K-block.  Valid because in
            # every autotune config 2*BLOCK_K_HALF divides GROUP_SIZE and
            # kh_start*2 is a multiple of 2*BLOCK_K_HALF, so one K-block
            # never straddles a group boundary.
            g_idx = (kh_start * 2) // GROUP_SIZE
            s = tl.load(
                S_ptr + g_idx * stride_s_g + n_off * stride_s_n,
                mask=n_mask,
                other=1.0,
            )[None, :]
            z = tl.load(
                Z_ptr + g_idx * stride_z_g + n_off * stride_z_n,
                mask=n_mask,
                other=0.0,
            )[None, :]
            s_bf = s.to(tl.bfloat16)
            z_bf = z.to(tl.bfloat16)
            # Asymmetric dequant in bf16: w = (q - zero) * scale.
            w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
            w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
            x_even_bf = x_even.to(tl.bfloat16)[:, None]
            x_odd_bf = x_odd.to(tl.bfloat16)[:, None]
            # bf16 multiply, reduce over the K axis; added into the fp32 acc.
            # NOTE(review): the product/sum happen in bf16 before the fp32
            # add — acceptable under the 0.10 tolerance, but a precision
            # trade-off to be aware of.
            lo = tl.sum(x_even_bf * w_deq_lo, axis=0)
            hi = tl.sum(x_odd_bf * w_deq_hi, axis=0)
            acc += lo + hi
        tl.store(
            Out_ptr + n_off * stride_on,
            acc.to(tl.bfloat16),
            mask=n_mask,
        )
@triton.autotune(
    configs=[
        # Standard 2D tiling over (M, N); K is walked sequentially in
        # half-K (packed-byte) steps.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
    ],
    key=['M', 'N', 'K_half'],
)
@triton.jit
def _w4a16_gemm_kernel(
    X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
    M, N, K_half, n_groups,
    stride_xm, stride_xk,
    stride_wq_kh, stride_wq_n,
    stride_s_g, stride_s_n,
    stride_z_g, stride_z_n,
    stride_om, stride_on,
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K_HALF: tl.constexpr,
):
    """Tiled W4A16 GEMM for M > 1 (prefill).

    Each program owns one (BLOCK_M, BLOCK_N) output tile.  The K reduction
    is split into even/odd K columns so that the activations line up with
    the low/high nibbles of the packed weight bytes, letting both halves
    feed tl.dot (tensor cores) without an interleaving shuffle.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m_start = pid_m * BLOCK_M
    n_start = pid_n * BLOCK_N
    # 2D broadcastable offsets for the output tile.
    m_off = (m_start + tl.arange(0, BLOCK_M))[:, None]
    n_off = (n_start + tl.arange(0, BLOCK_N))[None, :]
    m_mask = m_off < M
    n_mask = n_off < N
    # 1D variants used for the per-group scale/zero row loads.
    n_off_1d = n_start + tl.arange(0, BLOCK_N)
    n_mask_1d = n_off_1d < N
    # fp32 accumulator for the tl.dot reduction.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for kh_start in range(0, K_half, BLOCK_K_HALF):
        kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
        kh_mask = kh_off < K_half
        # Byte row kh maps to K rows 2*kh (low nibble) and 2*kh+1 (high).
        k_even = (kh_off * 2)
        k_odd = k_even + 1
        x_even = tl.load(
            X_ptr + m_off * stride_xm + k_even[None, :] * stride_xk,
            mask=m_mask & kh_mask[None, :],
            other=0.0,
        )
        x_odd = tl.load(
            X_ptr + m_off * stride_xm + k_odd[None, :] * stride_xk,
            mask=m_mask & kh_mask[None, :],
            other=0.0,
        )
        wq_packed = tl.load(
            Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off * stride_wq_n,
            mask=kh_mask[:, None] & n_mask,
            other=0,
        )
        # Unpack both 4-bit values (0..15) from each byte.
        wq_lo = wq_packed & 0x0F
        wq_hi = (wq_packed >> 4) & 0x0F
        # Scalar group index for the whole K-block: safe because every
        # config has 2*BLOCK_K_HALF dividing GROUP_SIZE, so a K-block
        # never crosses a group boundary.
        g_idx = (kh_start * 2) // GROUP_SIZE
        s = tl.load(
            S_ptr + g_idx * stride_s_g + n_off_1d * stride_s_n,
            mask=n_mask_1d,
            other=1.0,
        )[None, :]
        z = tl.load(
            Z_ptr + g_idx * stride_z_g + n_off_1d * stride_z_n,
            mask=n_mask_1d,
            other=0.0,
        )[None, :]
        s_bf = s.to(tl.bfloat16)
        z_bf = z.to(tl.bfloat16)
        # Asymmetric dequant in bf16: w = (q - zero) * scale.
        w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
        w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
        # Two bf16 dots (even-K and odd-K halves), accumulated in fp32.
        acc += tl.dot(x_even.to(tl.bfloat16), w_deq_lo).to(tl.float32)
        acc += tl.dot(x_odd.to(tl.bfloat16), w_deq_hi).to(tl.float32)
    tl.store(
        Out_ptr + m_off * stride_om + n_off * stride_on,
        acc.to(tl.bfloat16),
        mask=m_mask & n_mask,
    )
class Model(nn.Module):
    """Holds AWQ/GPTQ-style asymmetric int4 weights for the W4A16 GEMM.

    Buffers registered on construction:
        w_q:    (K // 2, N) uint8 — two int4 weights per byte
                (even-K row in the low nibble, odd-K row in the high nibble)
        scales: (K // group_size, N) bf16 per-group scales
        zeros:  (K // group_size, N) bf16 per-group zero-points
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        super().__init__()
        assert K % group_size == 0
        assert K % 2 == 0
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size

        # Deterministic synthetic weights: seed derived from the shape.
        torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
        dense = torch.randn(K, N, dtype=torch.float32) * 0.02

        # Per-group asymmetric int4 quantization along K (range 0..15).
        grouped = dense.view(n_groups, group_size, N)
        lo_val = grouped.min(dim=1, keepdim=True).values
        hi_val = grouped.max(dim=1, keepdim=True).values
        scale = (hi_val - lo_val).clamp_min(1e-8) / 15.0
        zero = (-lo_val / scale).round().clamp(0, 15)
        q = ((grouped / scale) + zero).round().clamp(0, 15).to(torch.uint8)
        q = q.view(K, N)

        # Pack two int4 values per byte: even K rows -> low nibble,
        # odd K rows -> high nibble.
        even_nib = q[0::2].to(torch.uint8) & 0xF
        odd_nib = q[1::2].to(torch.uint8) & 0xF
        self.register_buffer("w_q", (even_nib | (odd_nib << 4)).contiguous())
        self.register_buffer("scales", scale.squeeze(1).to(torch.bfloat16))
        self.register_buffer("zeros", zero.squeeze(1).to(torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fused dequant GEMM of x against the stored quantized weights."""
        return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.M, self.N, self.K, self.group_size)
def w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size):
    """Dispatch the fused W4A16 dequant GEMM.

    Args:
        x: (M, K) activations; cast to contiguous bf16.
        w_q: (K // 2, N) uint8, two int4 weights packed per byte.
        scales: (K // group_size, N) bf16 per-group scales.
        zeros: (K // group_size, N) bf16 per-group zero-points.
        M, N, K: problem shape; group_size: quant group length along K.

    Returns:
        (M, N) bf16 output tensor.

    M == 1 (decode) takes the persistent GEMV path; larger M takes the
    autotuned 2D-tiled GEMM path.
    """
    x = x.contiguous().to(torch.bfloat16)
    out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
    K_half = K // 2
    n_groups = K // group_size
    if M == 1:
        # Persistent kernel: one program per SM, each grid-striding over
        # N-blocks.  Query the real SM count instead of hard-coding 128 —
        # the supported GPUs differ (e.g. H100 vs B200 vs RTX PRO 6000),
        # and under/over-subscribing SMs wastes the persistent launch.
        if x.is_cuda:
            num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count
        else:
            num_sms = 128  # fallback when device properties are unavailable
        _w4a16_gemv_persistent_kernel[(num_sms,)](
            x, w_q, scales, zeros, out,
            M, N, K_half, n_groups,
            x.stride(0), x.stride(1),
            w_q.stride(0), w_q.stride(1),
            scales.stride(0), scales.stride(1),
            zeros.stride(0), zeros.stride(1),
            out.stride(0), out.stride(1),
            GROUP_SIZE=group_size,
        )
    else:
        # Grid depends on the autotuned block sizes, so compute it lazily.
        grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
        _w4a16_gemm_kernel[grid](
            x, w_q, scales, zeros, out,
            M, N, K_half, n_groups,
            x.stride(0), x.stride(1),
            w_q.stride(0), w_q.stride(1),
            scales.stride(0), scales.stride(1),
            zeros.stride(0), zeros.stride(1),
            out.stride(0), out.stride(1),
            GROUP_SIZE=group_size,
        )
    return out
def get_inputs():
    """Return forward() inputs: one (M, K) bf16 activation tensor.

    M and K are module-level shape globals supplied by the harness.
    """
    return [torch.randn(M, K, dtype=torch.bfloat16)]
def get_init_inputs():
    """Return Model constructor args (M, N, K harness-supplied globals)."""
    return [M, N, K]
shape=0 variant=eager tflops=0.131 gbps=34.718 ms=0.771
shape=0 variant=compiled tflops=0.628 gbps=167.038 ms=0.160
shape=0 variant=solution tflops=0.870 gbps=231.299 ms=0.116
shape=0 solution_peak_fraction=0.1285
shape=1 variant=eager tflops=4.158 gbps=35.865 ms=0.775
shape=1 variant=compiled tflops=20.495 gbps=176.800 ms=0.157
shape=1 variant=solution tflops=34.801 gbps=300.208 ms=0.093
shape=1 solution_peak_fraction=0.1668
shape=2 variant=eager tflops=31.107 gbps=42.402 ms=0.828
shape=2 variant=compiled tflops=119.446 gbps=162.819 ms=0.216
shape=2 variant=solution tflops=86.629 gbps=118.086 ms=0.297
shape=2 solution_peak_fraction=0.0656
shape=3 variant=eager tflops=0.178 gbps=47.419 ms=0.188
shape=3 variant=compiled tflops=0.460 gbps=122.520 ms=0.073
shape=3 variant=solution tflops=0.380 gbps=101.120 ms=0.088
shape=3 solution_peak_fraction=0.0562
shape=4 variant=eager tflops=2.052 gbps=34.706 ms=0.916
shape=4 variant=compiled tflops=10.262 gbps=173.590 ms=0.183
shape=4 variant=solution tflops=21.396 gbps=361.917 ms=0.088
shape=4 solution_peak_fraction=0.2011
peak_fraction: 0.1097
RESULT: OK
PASS
{
"problem": "07_w4a16_gemm",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-max-preview",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1097,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 4330754, "output_tokens": 43661, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 13640, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2:
3: AWQ/GPTQ-style scheme:
4: x: (M, K) bf16
5: w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6: scales: (K // group, N) bf16
7: zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8: out: (M, N) bf16
9:
10: Dequant (per group along K):
11: w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12: where w_q[k, n] is the unpacked 4-bit value (0..15).
13:
14: This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15: memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16: """
17: from __future__ import annotations
18:
19: import torch
20: import torch.nn as nn
21:
22: OP_TYPE = "gemm_w4a16"
23: SUPPORTED_PRECISIONS = ["int4_bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26: GROUP_SIZE = 128
27:
28:
29: def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30: """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31:
32: Even rows go in the low nibble, odd rows in the high nibble.
33: """
34: K, N = w_q.shape
35: assert K % 2 == 0
36: lo = w_q[0::2].to(torch.uint8) & 0xF
37: hi = w_q[1::2].to(torch.uint8) & 0xF
38: ret
... (3312 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape
... (2909 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formul
... (3415 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for W4A16 GEMM.
2:
3: Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4: case every inference engine optimizes -- it's the bar to beat.
5: """
6:
7: SHAPES = [
8: {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9: {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10: {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11: {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12: {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13: ]
(End of file - total 13 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to write a weight-only int4 quantized GEMM (W4A16) for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The scheme is AWQ/GPTQ-style asymmetric int4 with explicit zero-points and per-group bf16 scales. Inputs are x of shape (M, K) bf16, w_q of shape (K // 2, N) uint8 (two int4 weights packed per byte, low nibble = even-K row, high nibble = odd-K row), scales of shape (K // 128, N) bf16, and zeros of shape (K // 128, N) bf16. Group size is 128 along K. Dequant per group is w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n], and the output is (M, N) bf16. Correctness tolerance is 0.10 abs/rel — group-quant adds noise on top of bf16 accumulator slop. The shapes you have to handle are M=1 N=12288 K=4096 (decode, memory-bound on the int4 weight read), M=32 N=12288 K=4096 (small prefill, mixed regime), M=256 N=12288 K=4096 (larger prefill, approaching compute-bound), M=1 N=4096 K=4096 (decode, square), and M=16 N=14336 K=4096 (speculative-decode-ish).
4:
5: This needs to be a real custom
... (1082 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 07_w4a16_gemm
2: display_name: "W4A16 Weight-only Quantized GEMM"
3: precision: int4_bf16
4: regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5:
6: # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7: flops_formula: "2 * M * N * K"
8:
9: # Bytes moved per call (memory roofline):
10: # x: M*K*2 (bf16 activations, streamed in once)
11: # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12: # scales: (K/128)*N*2 (bf16 scales)
13: # zeros: (K/128)*N*2 (bf16 zero-points)
14: # out: M*N*2 (bf16 store)
15: bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23:
24: # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25: # vendor library that does both.
26: forbidden:
27: - "bitsandbytes.functional.dequantize_4bit"
28: - "bitsandbytes.functional.gemv_4bit"
29: - "marlin_kernel.gemm"
30: - "torch.nn.functional.linear"
31:
32: sota:
33: name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34: url: "https://github.com/TimDettmers/bitsandbytes"
35: function: "bitsandbytes.functional.gemv_4bit"
36: notes: |
37:
... (757 more chars){
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB, 595.58.03
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for W4A16 GEMM.
2:
3: Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4:
5: - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6: - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7: but is not faster than Marlin on its target HW
8: and has no Blackwell tuning. Skip as primary.
9: - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10: Skip.
11: - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12: on this machine). Different quant scheme (NF4,
13: symmetric, blocksize 64) than our reference's
14: AWQ-style asymmetric INT4 with group_size 128,
15: but it occupies the same memory regime and is
16: the only tuned W4A16-class kernel that runs on
17: SM120 today. Used here as an *informational*
18: SOTA line, not as a numerical reference.
19:
20: The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21: NOT checked against this path (the quant scheme differs).
22: """
23: from
... (2725 more chars){
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"command": "python3 -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check Triton version"
}3.6.0
{
"command": "nvcc --version 2>&1 | head -5",
"description": "Check nvcc version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "python3 -c \"import torch; print(torch.__version__); print('CUDA:', torch.version.cuda)\" 2>&1",
"description": "Check PyTorch version"
}2.11.0+cu130
CUDA: 13.0
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/src/ 2>&1",
"description": "List src directory"
}eval
hardware
harness
__init__.py
__pycache__
sandbox
viewer
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
correctness.py
report.py
roofline.py
shapes.py
timing.py
(7 entries)
</entries>{
"todos": [
{
"content": "Understand problem, interfaces, and correctness constraints",
"status": "completed",
"priority": "high"
},
{
"content": "Implement initial Triton fused W4A16 GEMM kernel",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py for baseline performance",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize kernel for SM1
... (213 more chars)content: 'Understand problem, interfaces, and correctness constraints'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton fused W4A16 GEMM kernel'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for baseline performance'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel for SM120'
status: 'pending'
priority: 'medium'
content: 'Iterate until PASS on check.py and good benchmark numbers'
status: 'pending'
priority: 'high'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,164 @@
+"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
+
+Same interface as reference.py:
+ Model(M, N, K), forward(x), get_inputs(), get_init_inputs()
+
+Fused approach: each program handles a (M_block, N_block) tile,
+iterates over K sequentially with on-the-fly unpack + dequant + dot.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+OP_TYPE = "gemm_w4a16"
+SUPPORTS_PRECISIONS = ["int4_bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def _w4a16_gemm_kernel(
+ X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wq_kh, stride_wq_n,
+ stride_s_g, stride_s_n,
+ stride_z_g, stride_z_n,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ m_start = pid_m * BLOCK_M
+ n_start = pid_n * BLOCK_N
+
+ m_off = m_start + tl.arange(0, BLOCK_M)
+ n_off = n_start + tl.arange(0, BLOCK_N)
+ m_mask = m_off < M
+ n_mask = n_off < N
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, K, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+ k_mask = k_off < K
+
+ x_ptrs = X_ptr + m_off[:, None] * stride_xm + k_off[None, :] * stride_xk
+ x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ k_half_off = k_start // 2 + tl.arange(0, BLOCK_K // 2)
+ k_half_mask = k_half_off < (K // 2)
+
+ wq_ptrs = Wq_ptr + k_half_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n
+ wq_packed = tl.load(wq_ptrs, mask=k_half_mask[:, None] & n_mask[None, :], other=0)
+
+ wq_lo = (wq_packed & 0xF).to(tl.float32)
+ wq_hi = ((wq_packed >> 4) & 0xF).to(tl.float32)
+
+ wq_unpacked = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)
+ even_indices = tl.arange(0, BLOCK_K // 2) * 2
+ odd_indices = even_indices + 1
+
+ wq_unpacked = tl.where(
+ tl.arange(0, BLOCK_K)[:, None] % 2 == 0,
+ tl.gather(wq_lo, tl.arange(0, BLOCK_K // 2)[:, None], 0),
+ tl.gather(wq_hi, tl.arange(0, BLOCK_K // 2)[:, None], 0),
+ )
+
+ g_idx = k_start // GROUP_SIZE
+
+ s_ptrs = S_ptr + g_idx * stride_s_g + n_off * stride_s_n
+ s = tl.load(s_ptrs, mask=n_mask, other=1.0)
+
+ z_ptrs = Z_ptr + g_idx * stride_z_g + n_off * stride_z_n
+ z = tl.load(z_ptrs, mask=n_mask, other=0.0)
+
+ s = s.to(tl.float32)
+ z = z.to(tl.float32)
+
+ w_dequant = (wq_unpacked - z[None, :]) * s[None, :]
+
+ w_dequant = w_dequant.to(x.dtype)
+ acc += tl.dot(x, w_dequant)
+
+ out_ptrs = Out_ptr + m_off[:, None] * stride_om + n_off[None, :] * stride_on
+ tl.store(out_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
+
+
+class Model(nn.Module):
+ def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
+ super().__init__()
+ assert K % group_size == 0
+ assert K % 2 == 0
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+
+ w_g = w_full.view(n_groups, group_size, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
+ w_q = w_q.view(K, N)
+
+ scales_2d = scales.squeeze(1).to(torch.bfloat16)
+ zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
+
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales_2d)
+ self.register_buffer("zeros", zeros_2d)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.M, self.N, self.K, self.group_size)
+
+
+def w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size):
+ x = x.contiguous().to(torch.bfloat16)
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+
+ BLOCK_M = min(16, M)
+ BLOCK_N = 128
+ BLOCK_K = 128
+
+ grid_m = triton.cdiv(M, BLOCK_M)
+ grid_n = triton.cdiv(N, BLOCK_N)
+
+ _w4a16_gemm_kernel[(grid_m, grid_n)](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ )
+
+ return out
+
+
+def get_inputs():
+ x = torch.randn(M, K, dtype=torch.bfloat16)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
{
"content": "<5077 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 2806, in gather
return _semantic.gather(src, index, axis)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 1754, in gather
assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: index dim 0 must match the corresponding source dim
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
... (2501 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,12 +1,12 @@
"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-Same interface as reference.py:
- Model(M, N, K), forward(x), get_inputs(), get_init_inputs()
-
+Same interface as reference.py.
Fused approach: each program handles a (M_block, N_block) tile,
-iterates over K sequentially with on-the-fly unpack + dequant + dot.
+iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
"""
from __future__ import annotations
+
+import math
import torch
import torch.nn as nn
@@ -24,7 +24,7 @@
@triton.jit
def _w4a16_gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
- M, N, K,
+ M, N, K_half,
stride_xm, stride_xk,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
@@ -32,7 +32,7 @@
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
+ BLOCK_K_HALF: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -47,50 +47,53 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K):
- k_off = k_start + tl.arange(0, BLOCK_K)
- k_mask = k_off < K
+ for kh_start in range(0, K_half, BLOCK_K_HALF):
+ kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
+ kh_mask = kh_off < K_half
- x_ptrs = X_ptr + m_off[:, None] * stride_xm + k_off[None, :] * stride_xk
- x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+ k_even_idx = kh_off * 2
+ k_odd_idx = kh_off * 2 + 1
- k_half_off = k_start // 2 + tl.arange(0, BLOCK_K // 2)
- k_half_mask = k_half_off < (K // 2)
+ x_even_ptrs = X_ptr + m_off[:, None] * stride_xm + k_even_idx[None, :] * stride_xk
+ x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
- wq_ptrs = Wq_ptr + k_half_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n
- wq_packed = tl.load(wq_ptrs, mask=k_half_mask[:, None] & n_mask[None, :], other=0)
+ x_odd_ptrs = X_ptr + m_off[:, None] * stride_xm + k_odd_idx[None, :] * stride_xk
+ x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
- wq_lo = (wq_packed & 0xF).to(tl.float32)
- wq_hi = ((wq_packed >> 4) & 0xF).to(tl.float32)
+ wq_ptrs = Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n
+ wq_packed = tl.load(wq_ptrs, mask=kh_mask[:, None] & n_mask[None, :], other=0)
- wq_unpacked = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)
- even_indices = tl.arange(0, BLOCK_K // 2) * 2
- odd_indices = even_indices + 1
+ wq_lo = wq_packed & 0x0F
+ wq_hi = (wq_packed >> 4) & 0x0F
- wq_unpacked = tl.where(
- tl.arange(0, BLOCK_K)[:, None] % 2 == 0,
- tl.gather(wq_lo, tl.arange(0, BLOCK_K // 2)[:, None], 0),
- tl.gather(wq_hi, tl.arange(0, BLOCK_K // 2)[:, None], 0),
- )
+ g_start = (kh_start * 2) // GROUP_SIZE
+ g_idx = g_start + tl.arange(0, BLOCK_K_HALF * 2 // GROUP_SIZE)
+ g_mask = g_idx < (K_half * 2 // GROUP_SIZE)
- g_idx = k_start // GROUP_SIZE
+ s_ptrs = S_ptr + g_idx[:, None] * stride_s_g + n_off[None, :] * stride_s_n
+ s = tl.load(s_ptrs, mask=g_mask[:, None] & n_mask[None, :], other=1.0)
- s_ptrs = S_ptr + g_idx * stride_s_g + n_off * stride_s_n
- s = tl.load(s_ptrs, mask=n_mask, other=1.0)
+ z_ptrs = Z_ptr + g_idx[:, None] * stride_z_g + n_off[None, :] * stride_z_n
+ z = tl.load(z_ptrs, mask=g_mask[:, None] & n_mask[None, :], other=0.0)
- z_ptrs = Z_ptr + g_idx * stride_z_g + n_off * stride_z_n
- z = tl.load(z_ptrs, mask=n_mask, other=0.0)
+ s = s.to(tl.bfloat16)
+ z = z.to(tl.bfloat16)
- s = s.to(tl.float32)
- z = z.to(tl.float32)
+ w_lo_bf = wq_lo.to(tl.bfloat16)
+ w_hi_bf = wq_hi.to(tl.bfloat16)
- w_dequant = (wq_unpacked - z[None, :]) * s[None, :]
+ w_lo_bf = (w_lo_bf - z) * s
- w_dequant = w_dequant.to(x.dtype)
- acc += tl.dot(x, w_dequant)
+ w_hi_bf = (w_hi_bf - z) * s
+
+ x_even_bf = x_even.to(tl.bfloat16)
+ x_odd_bf = x_odd.to(tl.bfloat16)
+
+ acc += tl.dot(x_even_bf, w_lo_bf.to(tl.bfloat16)).to(tl.float32)
+ acc += tl.dot(x_odd_bf, w_hi_bf.to(tl.bfloat16)).to(tl.float32)
out_ptrs = Out_ptr + m_off[:, None] * stride_om + n_off[None, :] * stride_on
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
class Model(nn.Module):
@@ -132,16 +135,18 @@
x = x.contiguous().to(torch.bfloat16)
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- BLOCK_M = min(16, M)
+ K_half = K // 2
+
+ BLOCK_M = max(1, min(16, M))
BLOCK_N = 128
- BLOCK_K = 128
+ BLOCK_K_HALF = 64
grid_m = triton.cdiv(M, BLOCK_M)
grid_n = triton.cdiv(N, BLOCK_N)
_w4a16_gemm_kernel[(grid_m, grid_n)](
x, w_q, scales, zeros, out,
- M, N, K,
+ M, N, K_half,
x.stride(0), x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
@@ -149,7 +154,7 @@
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
- BLOCK_K=BLOCK_K,
+ BLOCK_K_HALF=BLOCK_K_HALF,
)
return out
{
"content": "<5322 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -5,8 +5,6 @@
iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
"""
from __future__ import annotations
-
-import math
import torch
import torch.nn as nn
@@ -24,7 +22,7 @@
@triton.jit
def _w4a16_gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
- M, N, K_half,
+ M, N, K_half, n_groups,
stride_xm, stride_xk,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
@@ -66,31 +64,28 @@
wq_lo = wq_packed & 0x0F
wq_hi = (wq_packed >> 4) & 0x0F
- g_start = (kh_start * 2) // GROUP_SIZE
- g_idx = g_start + tl.arange(0, BLOCK_K_HALF * 2 // GROUP_SIZE)
- g_mask = g_idx < (K_half * 2 // GROUP_SIZE)
+ g_idx = (kh_start * 2) // GROUP_SIZE
- s_ptrs = S_ptr + g_idx[:, None] * stride_s_g + n_off[None, :] * stride_s_n
- s = tl.load(s_ptrs, mask=g_mask[:, None] & n_mask[None, :], other=1.0)
+ s_ptrs = S_ptr + g_idx * stride_s_g + n_off * stride_s_n
+ s = tl.load(s_ptrs, mask=n_mask, other=1.0)
- z_ptrs = Z_ptr + g_idx[:, None] * stride_z_g + n_off[None, :] * stride_z_n
- z = tl.load(z_ptrs, mask=g_mask[:, None] & n_mask[None, :], other=0.0)
+ z_ptrs = Z_ptr + g_idx * stride_z_g + n_off * stride_z_n
+ z = tl.load(z_ptrs, mask=n_mask, other=0.0)
- s = s.to(tl.bfloat16)
- z = z.to(tl.bfloat16)
+ s_bf = s.to(tl.bfloat16)
+ z_bf = z.to(tl.bfloat16)
w_lo_bf = wq_lo.to(tl.bfloat16)
w_hi_bf = wq_hi.to(tl.bfloat16)
- w_lo_bf = (w_lo_bf - z) * s
-
- w_hi_bf = (w_hi_bf - z) * s
+ w_deq_lo = (w_lo_bf - z_bf[None, :]) * s_bf[None, :]
+ w_deq_hi = (w_hi_bf - z_bf[None, :]) * s_bf[None, :]
x_even_bf = x_even.to(tl.bfloat16)
x_odd_bf = x_odd.to(tl.bfloat16)
- acc += tl.dot(x_even_bf, w_lo_bf.to(tl.bfloat16)).to(tl.float32)
- acc += tl.dot(x_odd_bf, w_hi_bf.to(tl.bfloat16)).to(tl.float32)
+ acc += tl.dot(x_even_bf, w_deq_lo).to(tl.float32)
+ acc += tl.dot(x_odd_bf, w_deq_hi).to(tl.float32)
out_ptrs = Out_ptr + m_off[:, None] * stride_om + n_off[None, :] * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@@ -136,17 +131,18 @@
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
K_half = K // 2
+ n_groups = K // group_size
BLOCK_M = max(1, min(16, M))
BLOCK_N = 128
- BLOCK_K_HALF = 64
+ BLOCK_K_HALF = group_size // 2 # 64
grid_m = triton.cdiv(M, BLOCK_M)
grid_n = triton.cdiv(N, BLOCK_N)
_w4a16_gemm_kernel[(grid_m, grid_n)](
x, w_q, scales, zeros, out,
- M, N, K_half,
+ M, N, K_half, n_groups,
x.stride(0), x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
{
"content": "<5189 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 126, in forward
return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.M, self.N, self.K, self.group_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 143, in w4a16_gemm
_w4a16_gemm_kernel[(grid_m, grid_n)](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site
... (2157 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"offset": 1,
"limit": 30
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: """W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
2:
3: Same interface as reference.py.
4: Fused approach: each program handles a (M_block, N_block) tile,
5: iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn as nn
11: import triton
12: import triton.language as tl
13:
14:
15: OP_TYPE = "gemm_w4a16"
16: SUPPORTS_PRECISIONS = ["int4_bf16"]
17: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
18:
19: GROUP_SIZE = 128
20:
21:
22: @triton.jit
23: def _w4a16_gemm_kernel(
24: X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
25: M, N, K_half, n_groups,
26: stride_xm, stride_xk,
27: stride_wq_kh, stride_wq_n,
28: stride_s_g, stride_s_n,
29: stride_z_g, stride_z_n,
30: stride_om, stride_on,
(Showing lines 1-30 of 165. Use offset=31 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -28,6 +28,7 @@
stride_s_g, stride_s_n,
stride_z_g, stride_z_n,
stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K_HALF: tl.constexpr,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<333 chars \u2014 see below>",
"newString": "<363 chars \u2014 see below>"
}@triton.jit
def _w4a16_gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
M, N, K_half, n_groups,
stride_xm, stride_xk,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
stride_z_g, stride_z_n,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K_HALF: tl.constexpr,
):@triton.jit
def _w4a16_gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
M, N, K_half, n_groups,
stride_xm, stride_xk,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
stride_z_g, stride_z_n,
stride_om, stride_on,
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K_HALF: tl.constexpr,
):Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"limit": 25,
"offset": 135
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
135: n_groups = K // group_size
136:
137: BLOCK_M = max(1, min(16, M))
138: BLOCK_N = 128
139: BLOCK_K_HALF = group_size // 2 # 64
140:
141: grid_m = triton.cdiv(M, BLOCK_M)
142: grid_n = triton.cdiv(N, BLOCK_N)
143:
144: _w4a16_gemm_kernel[(grid_m, grid_n)](
145: x, w_q, scales, zeros, out,
146: M, N, K_half, n_groups,
147: x.stride(0), x.stride(1),
148: w_q.stride(0), w_q.stride(1),
149: scales.stride(0), scales.stride(1),
150: zeros.stride(0), zeros.stride(1),
151: out.stride(0), out.stride(1),
152: BLOCK_M=BLOCK_M,
153: BLOCK_N=BLOCK_N,
154: BLOCK_K_HALF=BLOCK_K_HALF,
155: )
156:
157: return out
158:
159:
(Showing lines 135-159 of 166. Use offset=160 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -149,6 +149,7 @@
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
out.stride(0), out.stride(1),
+ GROUP_SIZE=group_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K_HALF=BLOCK_K_HALF,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<396 chars \u2014 see below>",
"newString": "<427 chars \u2014 see below>"
} _w4a16_gemm_kernel[(grid_m, grid_n)](
x, w_q, scales, zeros, out,
M, N, K_half, n_groups,
x.stride(0), x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K_HALF=BLOCK_K_HALF,
) _w4a16_gemm_kernel[(grid_m, grid_n)](
x, w_q, scales, zeros, out,
M, N, K_half, n_groups,
x.stride(0), x.stride(1),
w_q.stride(0), w_q.stride(1),
scales.stride(0), scales.stride(1),
zeros.stride(0), zeros.stride(1),
out.stride(0), out.stride(1),
GROUP_SIZE=group_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K_HALF=BLOCK_K_HALF,
)Edit applied successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
{
"todos": [
{
"content": "Understand problem, interfaces, and correctness constraints",
"status": "completed",
"priority": "high"
},
{
"content": "Implement initial Triton fused W4A16 GEMM kernel",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py for baseline performance",
"status": "in_progress",
"priority": "high"
},
{
"content": "Profile and optimize kernel for
... (217 more chars)content: 'Understand problem, interfaces, and correctness constraints'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton fused W4A16 GEMM kernel'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for baseline performance'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize kernel for SM120'
status: 'pending'
priority: 'medium'
content: 'Iterate until PASS on check.py and good benchmark numbers'
status: 'pending'
priority: 'high'{
"command": "python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.710 ms=0.771
shape=0 variant=compiled tflops=0.677 gbps=180.168 ms=0.149
shape=0 variant=sota tflops=2.015 gbps=535.944 ms=0.050
shape=0 variant=solution tflops=1.178 gbps=313.219 ms=0.085
shape=0 solution_peak_fraction=0.1740
shape=1 variant=eager tflops=4.154 gbps=35.833 ms=0.775
shape=1 variant=compiled tflops=20.483 gbps=176.692 ms=0.157
shape=1 variant=sota tflops=23.165 gbps=199.828 ms=0.139
shape=1 variant=solution tflops=30.444 gbps=262.620 ms=0.106
shape=1 solution_peak_fraction=0.1459
shape=2 variant=eager tflops=31.274 gbps=42.630 ms=0.824
shape=2 variant=compiled tflops=118.619 gbps=161.692 ms=0.217
shape=2 variant=sota tflops=130.902 gbps=178.434 ms=0.197
shape=2 variant=solution tflops=89.245 gbps=121.652 ms=0.289
shape=2 solution_peak_fraction=0.0676
shape=3 variant=eager tflops=0.181 gbps=48.065 ms=0.186
shape=3 variant=compiled tflops=0.538 gbps=143.171 ms=0.062
shape=3 variant=sota tflops=1.117 gbp
... (423 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/.claude/skills/debug-gpu/SKILL.md"
}{
"command": "head -200 /home/infatoshi/.claude/skills/perf-optim/SKILL.md 2>&1",
"description": "Read kernel optimization skill"
}---
name: perf-optim
description: "[SUB-REFERENCE of kernel skill] Performance optimization. Do NOT invoke directly -- use /kernel instead, which reads this file on demand."
---
# Performance Optimization Atlas
Comprehensive, hierarchical reference for optimizing GPU workloads. Pick combinations from this atlas based on your bottleneck profile. **Always profile before optimizing.**
---
## 0. FIRST PRINCIPLES: SPEED OF LIGHT, THEN PROFILE
### Rule
Never optimize blind. But before you even profile, compute the **speed of light** -- the theoretical physical limit for this operation. Then profile to measure the gap. The gap IS the optimization opportunity.
### Speed of Light Analysis
1. **Compute the limit**: What does physics allow? Peak FLOPS, peak bandwidth, minimum latency for this exact workload
2. **Profile current state**: Where are we now?
3. **Measure the gap**: speed_of_light - measured = optimization budget
4. **Don't continuous-improve from current state**: Strip to zero. If you built this from scratch knowing only the physical limits, what would you build? Then reason about every piece of the gap between that and current
- "It takes 74ms" -- don't optimize to 72ms. Compute that physics allows 6ms. Now reason about the 6-to-74 gap. The rest might be well-reasoned compromises, but at least you know what they are
### NVIDIA Profiling Tools
| Tool | What it tells you | When to use |
|------|------------------|-------------|
| `nvidia-smi` | GPU util, memory, te
... (11568 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -3,6 +3,7 @@
Same interface as reference.py.
Fused approach: each program handles a (M_block, N_block) tile,
iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
+Uses block pointers and autotuning for performance.
"""
from __future__ import annotations
@@ -19,6 +20,18 @@
GROUP_SIZE = 128
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K_HALF': bkh, 'num_stages': ns, 'num_warps': nw},
+ num_warps=nw, num_stages=ns)
+ for bm in [1, 4, 8, 16]
+ for bn in [64, 128, 256]
+ for bkh in [32, 64]
+ for ns in [1, 2, 3, 4]
+ for nw in [4, 8]
+ ],
+ key=['M', 'N', 'K_half'],
+)
@triton.jit
def _w4a16_gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
@@ -136,7 +149,7 @@
BLOCK_M = max(1, min(16, M))
BLOCK_N = 128
- BLOCK_K_HALF = group_size // 2 # 64
+ BLOCK_K_HALF = group_size // 2
grid_m = triton.cdiv(M, BLOCK_M)
grid_n = triton.cdiv(N, BLOCK_N)
{
"content": "<5677 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,9 +1,7 @@
"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
Same interface as reference.py.
-Fused approach: each program handles a (M_block, N_block) tile,
-iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
-Uses block pointers and autotuning for performance.
+Each program handles (M_block, N_block) tile, iterates K with fused dequant.
"""
from __future__ import annotations
@@ -22,18 +20,25 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K_HALF': bkh, 'num_stages': ns, 'num_warps': nw},
- num_warps=nw, num_stages=ns)
- for bm in [1, 4, 8, 16]
- for bn in [64, 128, 256]
- for bkh in [32, 64]
- for ns in [1, 2, 3, 4]
- for nw in [4, 8]
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
],
key=['M', 'N', 'K_half'],
)
@triton.jit
-def _w4a16_gemm_kernel(
+def _w4a16_gemm_kernel_autotune(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
M, N, K_half, n_groups,
stride_xm, stride_xk,
@@ -69,7 +74,7 @@
x_even_ptrs = X_ptr + m_off[:, None] * stride_xm + k_even_idx[None, :] * stride_xk
x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
- x_odd_ptrs = X_ptr + m_off[:, None] * stride_xm + k_odd_idx[None, :] * stride_xk
+ x_odd_ptrs = X_ptr + m_off[:, None] * stride_xk + k_odd_idx[None, :] * stride_xk
x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
wq_ptrs = Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n
@@ -147,14 +152,9 @@
K_half = K // 2
n_groups = K // group_size
- BLOCK_M = max(1, min(16, M))
- BLOCK_N = 128
- BLOCK_K_HALF = group_size // 2
+ grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
- grid_m = triton.cdiv(M, BLOCK_M)
- grid_n = triton.cdiv(N, BLOCK_N)
-
- _w4a16_gemm_kernel[(grid_m, grid_n)](
+ _w4a16_gemm_kernel_autotune[grid](
x, w_q, scales, zeros, out,
M, N, K_half, n_groups,
x.stride(0), x.stride(1),
@@ -163,9 +163,6 @@
zeros.stride(0), zeros.stride(1),
out.stride(0), out.stride(1),
GROUP_SIZE=group_size,
- BLOCK_M=BLOCK_M,
- BLOCK_N=BLOCK_N,
- BLOCK_K_HALF=BLOCK_K_HALF,
)
return out
{
"content": "<6539 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check after autotune refactor",
"timeout": 120000
}FAIL: shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=6.09375 (atol=0.1, rtol=0.1)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -74,7 +74,7 @@
x_even_ptrs = X_ptr + m_off[:, None] * stride_xm + k_even_idx[None, :] * stride_xk
x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
- x_odd_ptrs = X_ptr + m_off[:, None] * stride_xk + k_odd_idx[None, :] * stride_xk
+ x_odd_ptrs = X_ptr + m_off[:, None] * stride_xm + k_odd_idx[None, :] * stride_xk
x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)
wq_ptrs = Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<176 chars \u2014 see below>",
"newString": "<176 chars \u2014 see below>"
} x_odd_ptrs = X_ptr + m_off[:, None] * stride_xk + k_odd_idx[None, :] * stride_xk
x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0) x_odd_ptrs = X_ptr + m_off[:, None] * stride_xm + k_odd_idx[None, :] * stride_xk
x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & kh_mask[None, :], other=0.0)Edit applied successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check after stride fix",
"timeout": 120000
}PASS
{
"command": "python benchmark.py 2>&1",
"description": "Run benchmark with autotune",
"timeout": 600000
}/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.722 ms=0.771
shape=0 variant=compiled tflops=0.659 gbps=175.316 ms=0.153
shape=0 variant=sota tflops=2.271 gbps=604.049 ms=0.044
shape=0 variant=solution tflops=1.456 gbps=387.229 ms=0.069
shape=0 solution_peak_fraction=0.2151
shape=1 variant=eager tflops=4.210 gbps=36.314 ms=0.765
shape=1 variant=compiled tflops=20.514 gbps=176.962 ms=0.157
shape=1 variant=sota tflops=23.410 gbps=201.942 ms=0.138
shape=1 variant=solution tflops=34.158 gbps=294.656 ms=0.094
shape=1 solution_peak_fraction=0.1637
shape=2 variant=eager tflops=31.358 gbps=42.745 ms=0.822
shape=2 variant=compiled tflops=119.793 gbps=163.292 ms=0.215
shape=2 variant=sota tflops=133.130 gbps=181.473 ms=0.194
shape=2 variant=solution tflops=86.727 gbps=118.220 ms=0.297
shape=2 solution_peak_fraction=0.0657
shape=3 variant=eager tflops=0.181 gbps=48.273 ms=0.185
shape=3 variant=compiled tflops=0.514 gbps=136.717 ms=0.065
shape=3 variant=sota tflops=1.346 gbp
... (423 more chars){
"command": "<965 chars \u2014 see below>",
"description": "Get autotune best configs per shape",
"timeout": 120000
}python -c "
import torch
import solution
import reference
shapes = [
{'M': 1, 'N': 12288, 'K': 4096},
{'M': 32, 'N': 12288, 'K': 4096},
{'M': 256, 'N': 12288, 'K': 4096},
{'M': 1, 'N': 4096, 'K': 4096},
{'M': 16, 'N': 14336, 'K': 4096},
]
device = torch.device('cuda:0')
for shape in shapes:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
# warmup
for _ in range(10):
sol_model(*inputs)
torch.cuda.synchronize()
# get the best config
kernel = solution._w4a16_gemm_kernel_autotune
print(f'Shape {shape}: best_config = {kernel.best_config}')
" 2>&1Shape {'M': 1, 'N': 12288, 'K': 4096}: best_config = BLOCK_M: 8, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: best_config = BLOCK_M: 8, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: best_config = BLOCK_M: 16, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: best_config = BLOCK_M: 1, BLOCK_N: 64, BLOCK_K_HALF: 64, num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: best_config = BLOCK_M: 8, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
{
"command": "<757 chars \u2014 see below>",
"description": "Profile M=1 shape with ncu",
"timeout": 120000
}ncu --target-processes all --set full --metrics gpu__time_duration.sum,sm__throughput.avg.pct_of_peak_sustained_elapsed,lts__throughput.avg.pct_of_peak_sustained_elapsed,smsp__throughput.avg.pct_of_peak_sustained_active \
python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M, reference.N, reference.K = 1, 12288, 4096
init_args = reference.get_init_inputs()
ref = reference.Model(*init_args).to(device).eval()
sol = solution.Model(*init_args).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(2026)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
for _ in range(5): sol(x)
torch.cuda.synchronize()
for _ in range(20): sol(x)
torch.cuda.synchronize()
" 2>&1 | tail -40==PROF== Connected to process 943536 (/usr/bin/python3.12)
==ERROR== ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
==PROF== Disconnected from process 943536
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,14 +1,13 @@
-"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-
-Same interface as reference.py.
-Each program handles (M_block, N_block) tile, iterates K with fused dequant.
+"""W4A16 weight-only quantized GEMM — Custom CUDA C++ fused kernel.
+
+Uses torch.utils.cpp_extension.load_inline for JIT compilation.
+Specialized GEMV kernel for M=1, tiled GEMM kernel for M>1.
"""
from __future__ import annotations
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
+from torch.utils.cpp_extension import load_inline
OP_TYPE = "gemm_w4a16"
@@ -17,97 +16,316 @@
GROUP_SIZE = 128
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+CUDA_SRC = r"""
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+using bf16 = __nv_bfloat16;
+using bf16x2 = __nv_bfloat162;
+
+// GEMV kernel: M=1, single row-vector @ packed-int4 weight matrix
+// Each thread handles one output column (N element).
+// Threads within a warp load consecutive packed bytes -> perfect coalescing.
+template <int BLOCK_N, int UNROLL>
+__global__ void gemv_w4a16_kernel(
+ const bf16* __restrict__ x, // (1, K) bf16
+ const uint8_t* __restrict__ w_q, // (K/2, N) uint8 packed
+ const bf16* __restrict__ scales, // (n_groups, N) bf16
+ const bf16* __restrict__ zeros, // (n_groups, N) bf16
+ bf16* __restrict__ out, // (1, N) bf16
+ int N, int K, int n_groups
+) {
+ int n_col = blockIdx.x * BLOCK_N + threadIdx.x;
+ if (n_col >= N) return;
+
+ float acc = 0.0f;
+ int K_half = K / 2;
+
+ __shared__ bf16 x_shared[4096]; // max K
+ if (threadIdx.x < K) {
+ x_shared[threadIdx.x] = x[threadIdx.x];
+ }
+ __syncthreads();
+
+ for (int kh = 0; kh < K_half; kh += UNROLL) {
+ int k_idx = kh * 2;
+ int g_idx = k_idx / 128;
+
+ bf16 s = scales[g_idx * N + n_col];
+ bf16 z = zeros[g_idx * N + n_col];
+ float s_f = __bfloat162float(s);
+ float z_f = __bfloat162float(z);
+
+ #pragma unroll
+ for (int u = 0; u < UNROLL; u++) {
+ int cur_kh = kh + u;
+ if (cur_kh >= K_half) break;
+ uint8_t packed = w_q[cur_kh * N + n_col];
+
+ // even K row
+ float w_lo = (float)(packed & 0x0F);
+ float x_lo = __bfloat162float(x_shared[cur_kh * 2]);
+ acc += (w_lo - z_f) * s_f * x_lo;
+
+ // odd K row
+ float w_hi = (float)((packed >> 4) & 0x0F);
+ float x_hi = __bfloat162float(x_shared[cur_kh * 2 + 1]);
+ acc += (w_hi - z_f) * s_f * x_hi;
+ }
+ }
+
+ out[n_col] = __float2bfloat16(acc);
+}
+
+// Tiled GEMM kernel for M > 1
+// Each thread block handles BM x BN tile, iterates over K.
+// Uses shared memory for activations, registers for weights.
+template <int BM, int BN, int BK, int UNROLL>
+__global__ void gemm_w4a16_kernel(
+ const bf16* __restrict__ x,
+ const uint8_t* __restrict__ w_q,
+ const bf16* __restrict__ scales,
+ const bf16* __restrict__ zeros,
+ bf16* __restrict__ out,
+ int M, int N, int K, int n_groups
+) {
+ int m_block = blockIdx.x;
+ int n_block = blockIdx.y;
+
+ int tid = threadIdx.x;
+ int warp_id = tid / 32;
+ int lane_id = tid % 32;
+
+ __shared__ bf16 s_x[BM][BK];
+ float acc[BM / 4]; // each thread accumulates 4 M-rows worth
+ for (int i = 0; i < BM / 4; i++) acc[i] = 0.0f;
+
+ int m_start = m_block * BM;
+ int n_start = n_block * BN;
+ int K_half = K / 2;
+ int BK_half = BK / 2;
+
+ for (int kh_start = 0; kh_start < K_half; kh_start += BK_half) {
+ // Load x into shared memory (cooperatively)
+ for (int i = tid; i < BM * BK; i += blockDim.x) {
+ int m_idx = i / BK;
+ int k_idx = i % BK;
+ int global_m = m_start + m_idx;
+ int global_k = kh_start * 2 + k_idx;
+ if (global_m < M && global_k < K) {
+ s_x[m_idx][k_idx] = x[global_m * K + global_k];
+ } else {
+ s_x[m_idx][k_idx] = __float2bfloat16(0.0f);
+ }
+ }
+ __syncthreads();
+
+ int g_idx = (kh_start * 2) / 128;
+
+ // Each thread handles specific N columns
+ // We assign each thread to process n_per_thread columns,
+ // and reduce across threads within the block a
... (10689 more chars){
"content": "<12527 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"limit": 20,
"offset": 200
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
200: for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
201: int n_idx = tid * N_PER_THREAD + n_local;
202: int global_n = n_start + n_idx;
203: if (global_n >= N) continue;
204:
205: for (int kh = 0; kh < K_half; kh++) {
206: int k_start = kh * 2;
207: int g_idx = k_start / 128;
208:
209: float s_f = __bfloat162float(scales[g_idx * N + global_n]);
210: float z_f = __bfloat162float(zeros[g_idx * N + global_n]);
211:
212: uint8_t packed = w_q[kh * N + global_n];
213: float w_lo = (float)(packed & 0x0F);
214: float w_hi = (float)((packed >> 4) & 0x0F);
215: float w_lo_dq = (w_lo - z_f) * s_f;
216: float w_hi_dq = (w_hi - z_f) * s_f;
217:
218: for (int m_local = 0; m_local < BM; m_local++) {
219: int global_m = m_start + m_local;
(Showing lines 200-219 of 382. Use offset=220 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -300,7 +300,7 @@
}
"""
-PY BIND_SRC = r"""
+PYBIND_SRC = r"""
#include <torch/extension.h>
std::vector<at::Tensor> w4a16_gemm_cuda(
at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "PY BIND_SRC = r\"\"\"",
"newString": "PYBIND_SRC = r\"\"\""
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -24,58 +24,150 @@
using bf16 = __nv_bfloat16;
using bf16x2 = __nv_bfloat162;
-// GEMV kernel: M=1, single row-vector @ packed-int4 weight matrix
-// Each thread handles one output column (N element).
-// Threads within a warp load consecutive packed bytes -> perfect coalescing.
-template <int BLOCK_N, int UNROLL>
+template <int MAX_K = 4096>
__global__ void gemv_w4a16_kernel(
- const bf16* __restrict__ x, // (1, K) bf16
- const uint8_t* __restrict__ w_q, // (K/2, N) uint8 packed
- const bf16* __restrict__ scales, // (n_groups, N) bf16
- const bf16* __restrict__ zeros, // (n_groups, N) bf16
- bf16* __restrict__ out, // (1, N) bf16
- int N, int K, int n_groups
-) {
- int n_col = blockIdx.x * BLOCK_N + threadIdx.x;
+ const bf16* __restrict__ x,
+ const uint8_t* __restrict__ w_q,
+ const bf16* __restrict__ scales,
+ const bf16* __restrict__ zeros,
+ bf16* __restrict__ out,
+ int N, int K, int K_half
+) {
+ int n_col = blockIdx.x * blockDim.x + threadIdx.x;
if (n_col >= N) return;
float acc = 0.0f;
+
+ extern __shared__ bf16 x_shared[];
+ for (int i = threadIdx.x; i < K; i += blockDim.x) {
+ x_shared[i] = x[i];
+ }
+ __syncthreads();
+
+ for (int kh = 0; kh < K_half; kh++) {
+ int k_even = kh * 2;
+ int g_idx = k_even / 128;
+
+ float s_f = __bfloat162float(scales[g_idx * N + n_col]);
+ float z_f = __bfloat162float(zeros[g_idx * N + n_col]);
+
+ uint8_t packed = w_q[kh * N + n_col];
+
+ float w_lo = (float)(packed & 0x0F);
+ float w_hi = (float)((packed >> 4) & 0x0F);
+
+ float x_lo = __bfloat162float(x_shared[k_even]);
+ float x_hi = __bfloat162float(x_shared[k_even + 1]);
+
+ acc += x_lo * (w_lo - z_f) * s_f;
+ acc += x_hi * (w_hi - z_f) * s_f;
+ }
+
+ out[n_col] = __float2bfloat16(acc);
+}
+
+template <int BM, int BN, int THREADS>
+__global__ void gemm_w4a16_kernel(
+ const bf16* __restrict__ x,
+ const uint8_t* __restrict__ w_q,
+ const bf16* __restrict__ scales,
+ const bf16* __restrict__ zeros,
+ bf16* __restrict__ out,
+ int M, int N, int K, int n_groups
+) {
+ int m_block = blockIdx.x;
+ int n_block = blockIdx.y;
+ int tid = threadIdx.x;
+
+ int m_start = m_block * BM;
+ int n_start = n_block * BN;
+
+ constexpr int N_PER_THREAD = (BN + THREADS - 1) / THREADS;
int K_half = K / 2;
- __shared__ bf16 x_shared[4096]; // max K
- if (threadIdx.x < K) {
- x_shared[threadIdx.x] = x[threadIdx.x];
- }
- __syncthreads();
-
- for (int kh = 0; kh < K_half; kh += UNROLL) {
- int k_idx = kh * 2;
- int g_idx = k_idx / 128;
-
- bf16 s = scales[g_idx * N + n_col];
- bf16 z = zeros[g_idx * N + n_col];
- float s_f = __bfloat162float(s);
- float z_f = __bfloat162float(z);
-
- #pragma unroll
- for (int u = 0; u < UNROLL; u++) {
- int cur_kh = kh + u;
- if (cur_kh >= K_half) break;
- uint8_t packed = w_q[cur_kh * N + n_col];
-
- // even K row
- float w_lo = (float)(packed & 0x0F);
- float x_lo = __bfloat162float(x_shared[cur_kh * 2]);
- acc += (w_lo - z_f) * s_f * x_lo;
-
- // odd K row
- float w_hi = (float)((packed >> 4) & 0x0F);
- float x_hi = __bfloat162float(x_shared[cur_kh * 2 + 1]);
- acc += (w_hi - z_f) * s_f * x_hi;
- }
- }
-
- out[n_col] = __float2bfloat16(acc);
+ float acc[BM];
+ for (int i = 0; i < BM; i++) acc[i] = 0.0f;
+
+ __shared__ bf16 s_x[BM][4096];
+ __shared__ uint8_t s_wq[BN];
+ __shared__ bf16 s_s[2];
+ __shared__ bf16 s_z[2];
+
+ for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
+ int n_idx = tid * N_PER_THREAD + n_local;
+ int global_n = n_start + n_idx;
+ if (global_n >= N) continue;
+
+ for (int kh = 0; kh < K_half; kh++) {
+ int k_even = kh * 2;
+ int g_idx = k_even / 128;
+
+ float s_f = __bfloat162float(scales[g_idx * N + global_n]);
+ float z_f = __bfloat162float(zeros[g_idx * N + global_n]);
+
+ uint8_t packed = w_q[kh * N + global_n];
+ float w_lo = (float)(packed & 0x0F) - z_f;
+ float w_hi = (float)((packed >> 4) & 0x0F) - z_f;
+ float w_lo_dq = w_lo * s_f;
+ float w_hi_dq = w_hi * s_f;
+
+ for (int m_local = 0; m_local < BM; m_local++) {
+ int global_m = m_start + m_local;
+ if (global_m >= M) continue;
+
+ float x_lo = __bfloat162float(x[global_m * K + k_even]);
+ float x_hi = __bfloat162float(x[global_m * K + k_even + 1]);
+ acc[m_local] += x_lo * w_lo_dq + x_hi * w_hi_dq;
+ }
+ }
+ }
+
+ for (int m_local = 0; m_local < BM; m_local++) {
+ int global_m = m_start + m_local;
+ if (global_m >= M) continue;
+ for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
+ int n_idx = tid * N_PER_THREAD + n_local;
+ int global_n = n_start + n_idx;
+ if (global_n < N) {
+ out[global_m * N + global_n] = __float2bfloat16(acc[m_local]);
+ }
+ }
+ }
+}
+
+void gemv_w4a16_launch(
+ const bf16* x, const uint8_t* w_q,
+ const bf16* scales, const bf16* zeros,
+ bf16* out, int N, int K, int n_groups
+) {
+ int K_half = K / 2;
+ constexpr int BLOCK_N = 256;
+ int grid_n = (N + BLOCK_N - 1) / BLOCK_N;
+ int shared_mem_size = K * sizeof(bf16);
+
+ gemv_w4a16_kernel<<<grid_n, BLOCK_N, shared_mem_size>>>(
+ x, w_q, scales, zeros, out, N, K, K_half
+ );
+}
+
+void gemm_w4a16_launch(
+ const bf16* x,
... (507 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<1950 chars \u2014 see below>",
"newString": "<4485 chars \u2014 see below>"
}CUDA_SRC = r"""
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
using bf16 = __nv_bfloat16;
using bf16x2 = __nv_bfloat162;
// GEMV kernel: M=1, single row-vector @ packed-int4 weight matrix
// Each thread handles one output column (N element).
// Threads within a warp load consecutive packed bytes -> perfect coalescing.
template <int BLOCK_N, int UNROLL>
__global__ void gemv_w4a16_kernel(
const bf16* __restrict__ x, // (1, K) bf16
const uint8_t* __restrict__ w_q, // (K/2, N) uint8 packed
const bf16* __restrict__ scales, // (n_groups, N) bf16
const bf16* __restrict__ zeros, // (n_groups, N) bf16
bf16* __restrict__ out, // (1, N) bf16
int N, int K, int n_groups
) {
int n_col = blockIdx.x * BLOCK_N + threadIdx.x;
if (n_col >= N) return;
float acc = 0.0f;
int K_half = K / 2;
__shared__ bf16 x_shared[4096]; // max K
if (threadIdx.x < K) {
x_shared[threadIdx.x] = x[threadIdx.x];
}
__syncthreads();
for (int kh = 0; kh < K_half; kh += UNROLL) {
int k_idx = kh * 2;
int g_idx = k_idx / 128;
bf16 s = scales[g_idx * N + n_col];
bf16 z = zeros[g_idx * N + n_col];
float s_f = __bfloat162float(s);
float z_f = __bfloat162float(z);
#pragma unroll
for (int u = 0; u < UNROLL; u++) {
int cur_kh = kh + u;
if (cur_kh >= K_half) break;
uint8_t packed = w_q[cur_kh * N + n_col];
// even K row
float w_lo = (float)(packed & 0x0F);
float x_lo = __bfloat162float(x_shared[cur_kh * 2]);
acc += (w_lo - z_f) * s_f * x_lo;
// odd K row
float w_hi = (float)((packed >> 4) & 0x0F);
float x_hi = __bfloat162float(x_shared[cur_kh * 2 + 1]);
acc += (w_hi - z_f) * s_f * x_hi;
}
}
out[n_col] = __float2bfloat16(acc);
}CUDA_SRC = r"""
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
using bf16 = __nv_bfloat16;
using bf16x2 = __nv_bfloat162;
template <int MAX_K = 4096>
__global__ void gemv_w4a16_kernel(
const bf16* __restrict__ x,
const uint8_t* __restrict__ w_q,
const bf16* __restrict__ scales,
const bf16* __restrict__ zeros,
bf16* __restrict__ out,
int N, int K, int K_half
) {
int n_col = blockIdx.x * blockDim.x + threadIdx.x;
if (n_col >= N) return;
float acc = 0.0f;
extern __shared__ bf16 x_shared[];
for (int i = threadIdx.x; i < K; i += blockDim.x) {
x_shared[i] = x[i];
}
__syncthreads();
for (int kh = 0; kh < K_half; kh++) {
int k_even = kh * 2;
int g_idx = k_even / 128;
float s_f = __bfloat162float(scales[g_idx * N + n_col]);
float z_f = __bfloat162float(zeros[g_idx * N + n_col]);
uint8_t packed = w_q[kh * N + n_col];
float w_lo = (float)(packed & 0x0F);
float w_hi = (float)((packed >> 4) & 0x0F);
float x_lo = __bfloat162float(x_shared[k_even]);
float x_hi = __bfloat162float(x_shared[k_even + 1]);
acc += x_lo * (w_lo - z_f) * s_f;
acc += x_hi * (w_hi - z_f) * s_f;
}
out[n_col] = __float2bfloat16(acc);
}
template <int BM, int BN, int THREADS>
__global__ void gemm_w4a16_kernel(
const bf16* __restrict__ x,
const uint8_t* __restrict__ w_q,
const bf16* __restrict__ scales,
const bf16* __restrict__ zeros,
bf16* __restrict__ out,
int M, int N, int K, int n_groups
) {
int m_block = blockIdx.x;
int n_block = blockIdx.y;
int tid = threadIdx.x;
int m_start = m_block * BM;
int n_start = n_block * BN;
constexpr int N_PER_THREAD = (BN + THREADS - 1) / THREADS;
int K_half = K / 2;
float acc[BM];
for (int i = 0; i < BM; i++) acc[i] = 0.0f;
__shared__ bf16 s_x[BM][4096];
__shared__ uint8_t s_wq[BN];
__shared__ bf16 s_s[2];
__shared__ bf16 s_z[2];
for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
int n_idx = tid * N_PER_THREAD + n_local;
int global_n = n_start + n_idx;
if (global_n >= N) continue;
for (int kh = 0; kh < K_half; kh++) {
int k_even = kh * 2;
int g_idx = k_even / 128;
float s_f = __bfloat162float(scales[g_idx * N + global_n]);
float z_f = __bfloat162float(zeros[g_idx * N + global_n]);
uint8_t packed = w_q[kh * N + global_n];
float w_lo = (float)(packed & 0x0F) - z_f;
float w_hi = (float)((packed >> 4) & 0x0F) - z_f;
float w_lo_dq = w_lo * s_f;
float w_hi_dq = w_hi * s_f;
for (int m_local = 0; m_local < BM; m_local++) {
int global_m = m_start + m_local;
if (global_m >= M) continue;
float x_lo = __bfloat162float(x[global_m * K + k_even]);
float x_hi = __bfloat162float(x[global_m * K + k_even + 1]);
acc[m_local] += x_lo * w_lo_dq + x_hi * w_hi_dq;
}
}
}
for (int m_local = 0; m_local < BM; m_local++) {
int global_m = m_start + m_local;
if (global_m >= M) continue;
for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
int n_idx = tid * N_PER_THREAD + n_local;
int global_n = n_start + n_idx;
if (global_n < N) {
out[global_m * N + global_n] = __float2bfloat16(acc[m_local]);
}
}
}
}
void gemv_w4a16_launch(
const bf16* x, const uint8_t* w_q,
const bf16* scales, const bf16* zeros,
bf16* out, int N, int K, int n_groups
) {
int K_half = K / 2;
constexpr int BLOCK_N = 256;
int grid_n = (N + BLOCK_N - 1) / BLOCK_N;
int shared_mem_size = K * sizeof(bf16);
gemv_w4a16_kernel<<<grid_n, BLOCK_N, shared_mem_size>>>(
x, w_q, scales, zeros, out, N, K, K_half
);
}
void gemm_w4
... (485 more chars)Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: """W4A16 weight-only quantized GEMM — Custom CUDA C++ fused kernel.
2:
3: Uses torch.utils.cpp_extension.load_inline for JIT compilation.
4: Specialized GEMV kernel for M=1, tiled GEMM kernel for M>1.
5: """
6: from __future__ import annotations
7:
8: import torch
9: import torch.nn as nn
10: from torch.utils.cpp_extension import load_inline
11:
12:
13: OP_TYPE = "gemm_w4a16"
14: SUPPORTS_PRECISIONS = ["int4_bf16"]
15: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
16:
17: GROUP_SIZE = 128
18:
19: CUDA_SRC = r"""
20: #include <cuda_bf16.h>
21: #include <cuda_fp16.h>
22: #include <cstdint>
23:
24: using bf16 = __nv_bfloat16;
25: using bf16x2 = __nv_bfloat162;
26:
27: template <int MAX_K = 4096>
28: __global__ void gemv_w4a16_kernel(
29: const bf16* __restrict__ x,
30: const uint8_t* __restrict__ w_q,
31: const bf16* __restrict__ scales,
32: const bf16* __restrict__ zeros,
33: bf16* __restrict__ out,
34: int N, int K, int K_half
35: ) {
36: int n_col = blockIdx.x * blockDim.x + threadIdx.x;
37: if (n_col >= N) return;
38:
39: float acc = 0.0f;
40:
41: extern __shared__ bf16 x_shared[];
42: for (int i = threadIdx.x; i < K; i += blockDim.x) {
43: x_shared[i] = x[i];
44: }
45: __syncthreads();
46:
47: for (int kh = 0; kh < K_half; kh++) {
48: int k_even = kh * 2;
49: int
... (15980 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -20,9 +20,9 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
+#include <torch/extension.h>
using bf16 = __nv_bfloat16;
-using bf16x2 = __nv_bfloat162;
template <int MAX_K = 4096>
__global__ void gemv_w4a16_kernel(
@@ -88,11 +88,6 @@
float acc[BM];
for (int i = 0; i < BM; i++) acc[i] = 0.0f;
- __shared__ bf16 s_x[BM][4096];
- __shared__ uint8_t s_wq[BN];
- __shared__ bf16 s_s[2];
- __shared__ bf16 s_z[2];
-
for (int n_local = 0; n_local < N_PER_THREAD; n_local++) {
int n_idx = tid * N_PER_THREAD + n_local;
int global_n = n_start + n_idx;
@@ -135,250 +130,35 @@
}
}
-void gemv_w4a16_launch(
- const bf16* x, const uint8_t* w_q,
- const bf16* scales, const bf16* zeros,
- bf16* out, int N, int K, int n_groups
-) {
- int K_half = K / 2;
- constexpr int BLOCK_N = 256;
- int grid_n = (N + BLOCK_N - 1) / BLOCK_N;
- int shared_mem_size = K * sizeof(bf16);
-
- gemv_w4a16_kernel<<<grid_n, BLOCK_N, shared_mem_size>>>(
- x, w_q, scales, zeros, out, N, K, K_half
- );
-}
-
-void gemm_w4a16_launch(
- const bf16* x, const uint8_t* w_q,
- const bf16* scales, const bf16* zeros,
- bf16* out, int M, int N, int K, int n_groups
-) {
- constexpr int BM = 16;
- constexpr int BN = 128;
- constexpr int THREADS = 256;
-
- int grid_m = (M + BM - 1) / BM;
- int grid_n = (N + BN - 1) / BN;
-
- dim3 grid(grid_m, grid_n);
- dim3 block(THREADS);
-
- gemm_w4a16_kernel<BM, BN, THREADS><<<grid, block>>>(
- x, w_q, scales, zeros, out, M, N, K, n_groups
- );
-}
-
-// Tiled GEMM kernel for M > 1
-// Each thread block handles BM x BN tile, iterates over K.
-// Uses shared memory for activations, registers for weights.
-template <int BM, int BN, int BK, int UNROLL>
-__global__ void gemm_w4a16_kernel(
- const bf16* __restrict__ x,
- const uint8_t* __restrict__ w_q,
- const bf16* __restrict__ scales,
- const bf16* __restrict__ zeros,
- bf16* __restrict__ out,
- int M, int N, int K, int n_groups
-) {
- int m_block = blockIdx.x;
- int n_block = blockIdx.y;
-
- int tid = threadIdx.x;
- int warp_id = tid / 32;
- int lane_id = tid % 32;
-
- __shared__ bf16 s_x[BM][BK];
- float acc[BM / 4]; // each thread accumulates 4 M-rows worth
- for (int i = 0; i < BM / 4; i++) acc[i] = 0.0f;
-
- int m_start = m_block * BM;
- int n_start = n_block * BN;
- int K_half = K / 2;
- int BK_half = BK / 2;
-
- for (int kh_start = 0; kh_start < K_half; kh_start += BK_half) {
- // Load x into shared memory (cooperatively)
- for (int i = tid; i < BM * BK; i += blockDim.x) {
- int m_idx = i / BK;
- int k_idx = i % BK;
- int global_m = m_start + m_idx;
- int global_k = kh_start * 2 + k_idx;
- if (global_m < M && global_k < K) {
- s_x[m_idx][k_idx] = x[global_m * K + global_k];
- } else {
- s_x[m_idx][k_idx] = __float2bfloat16(0.0f);
- }
- }
- __syncthreads();
-
- int g_idx = (kh_start * 2) / 128;
-
- // Each thread handles specific N columns
- // We assign each thread to process n_per_thread columns,
- // and reduce across threads within the block at the end.
- // Actually, let's do it simpler: each thread processes 4 consecutive N cols
- #pragma unroll
- for (int n_local = 0; n_local < 4; n_local++) {
- int n_idx = tid * 4 + n_local;
- int global_n = n_start + n_idx;
- if (global_n >= N) continue;
-
- if (n_idx >= BN) continue;
-
- float s_f = __bfloat162float(scales[g_idx * N + global_n]);
- float z_f = __bfloat162float(zeros[g_idx * N + global_n]);
-
- for (int kh = 0; kh < BK_half; kh++) {
- uint8_t packed = w_q[(kh_start + kh) * N + global_n];
-
- float w_lo = (float)(packed & 0x0F);
- float w_hi = (float)((packed >> 4) & 0x0F);
-
- float w_lo_dq = (w_lo - z_f) * s_f;
- float w_hi_dq = (w_hi - z_f) * s_f;
-
- // Dot product over the unrolled K dimension
- for (int m_idx = 0; m_idx < BM; m_idx += 4) {
- int local_m = m_idx / 4;
- float x_lo = __bfloat162float(s_x[m_idx][kh * 2]);
- float x_hi = __bfloat162float(s_x[m_idx][kh * 2 + 1]);
- acc[local_m] += x_lo * w_lo_dq + x_hi * w_hi_dq;
- }
- }
- }
- __syncthreads();
- }
-
- // Write output - each thread writes its 4 N-cols x BM/4 M-rows
- for (int m_idx = 0; m_idx < BM; m_idx += 4) {
- int local_m = m_idx / 4;
- for (int n_local = 0; n_local < 4; n_local++) {
- int n_idx = tid * 4 + n_local;
- int global_m = m_start + m_idx + n_local; // wrong
- int global_n = n_start + n_idx;
- if (global_m < M && global_n < N) {
- out[global_m * N + global_n] = __float2bfloat16(acc[local_m]);
- }
- }
- }
-}
-
-// Simple tiled GEMM: each program handles its own tile, no cross-block reduction
-template <int BM, int BN, int THREADS_PER_BLOCK>
-__global__ void gemm_w4a16_simple_kernel(
- const bf16* __restrict__ x,
- const uint8_t* __restrict__ w_q,
- const bf16* __restrict__ scales,
- const bf16* __restrict__ zeros,
- bf16* __restrict__ out,
- int M, int N, int K, int n_groups
-) {
- int m_block = blockIdx.x;
- int n_block = blockIdx.y;
- int tid = threadIdx.x;
-
- int m_start = m_block * BM;
- int n_start = n_block * BN;
- int K_half = K / 2;
-
- float acc[BM];
- for (int i = 0; i < BM; i++) acc[i] = 0.0f;
-
- // Each thread handles a
... (4693 more chars){
"content": "<7580 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python -c \"import solution\" 2>&1 | tail -50",
"description": "Test compilation",
"timeout": 120000
}[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -O3 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/main.cpp -o main.o
FAILED: main.o
c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -O3 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/main.cpp -o main.o
/home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/main.cpp: In function ‘void pybind11_init_w4a16_cuda_v2(pybind11::module_&)’:
/home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/main.cpp:4:49: error: ‘w4a16_gemm’ was not declared in this scope
4 | m.def("w4a16_gemm", torch::wrap_pybind_function(w4a16_gemm), "w4a16_gemm");
| ^~~~~~~~~~
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/pyth
... (1946 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -15,6 +15,15 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
GROUP_SIZE = 128
+
+CPP_SRC = r"""
+#include <torch/extension.h>
+#include <cstdint>
+std::vector<at::Tensor> w4a16_gemm_cuda(
+ at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
+ int M, int N, int K, int group_size
+);
+"""
CUDA_SRC = r"""
#include <cuda_bf16.h>
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<109 chars \u2014 see below>",
"newString": "<332 chars \u2014 see below>"
}CUDA_SRC = r"""
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <torch/extension.h>CPP_SRC = r"""
#include <torch/extension.h>
#include <cstdint>
std::vector<at::Tensor> w4a16_gemm_cuda(
at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
int M, int N, int K, int group_size
);
"""
CUDA_SRC = r"""
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <torch/extension.h>Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -179,10 +179,6 @@
return {out};
}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("w4a16_gemm", &w4a16_gemm_cuda, "W4A16 fused GEMM");
-}
"""
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<132 chars \u2014 see below>",
"newString": "<23 chars \u2014 see below>"
} return {out};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("w4a16_gemm", &w4a16_gemm_cuda, "W4A16 fused GEMM");
}
""" return {out};
}
"""Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -184,9 +184,9 @@
_module = load_inline(
name="w4a16_cuda_v2",
- cpp_sources=[""],
+ cpp_sources=[CPP_SRC],
cuda_sources=[CUDA_SRC],
- functions=["w4a16_gemm"],
+ functions=["w4a16_gemm_cuda"],
extra_cflags=["-O3"],
extra_cuda_cflags=[
"-O3",
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<304 chars \u2014 see below>",
"oldString": "<294 chars \u2014 see below>"
}_module = load_inline(
name="w4a16_cuda_v2",
cpp_sources=[CPP_SRC],
cuda_sources=[CUDA_SRC],
functions=["w4a16_gemm_cuda"],
extra_cflags=["-O3"],
extra_cuda_cflags=[
"-O3",
"--use_fast_math",
"-arch=sm_120",
"-std=c++17",
],
verbose=True,
)_module = load_inline(
name="w4a16_cuda_v2",
cpp_sources=[""],
cuda_sources=[CUDA_SRC],
functions=["w4a16_gemm"],
extra_cflags=["-O3"],
extra_cuda_cflags=[
"-O3",
"--use_fast_math",
"-arch=sm_120",
"-std=c++17",
],
verbose=True,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -231,7 +231,7 @@
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.contiguous().to(torch.bfloat16)
- result = _module.w4a16_gemm(
+ result = _module.w4a16_gemm_cuda(
x, self.w_q, self.scales, self.zeros,
self.M, self.N, self.K, self.group_size
)
@@ -240,7 +240,7 @@
def w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size):
x = x.contiguous().to(torch.bfloat16)
- return _module.w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size)[0]
+ return _module.w4a16_gemm_cuda(x, w_q, scales, zeros, M, N, K, group_size)[0]
def get_inputs():
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<466 chars \u2014 see below>",
"oldString": "<456 chars \u2014 see below>"
} def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.contiguous().to(torch.bfloat16)
result = _module.w4a16_gemm_cuda(
x, self.w_q, self.scales, self.zeros,
self.M, self.N, self.K, self.group_size
)
return result[0]
def w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size):
x = x.contiguous().to(torch.bfloat16)
return _module.w4a16_gemm_cuda(x, w_q, scales, zeros, M, N, K, group_size)[0] def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.contiguous().to(torch.bfloat16)
result = _module.w4a16_gemm(
x, self.w_q, self.scales, self.zeros,
self.M, self.N, self.K, self.group_size
)
return result[0]
def w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size):
x = x.contiguous().to(torch.bfloat16)
return _module.w4a16_gemm(x, w_q, scales, zeros, M, N, K, group_size)[0]Edit applied successfully.{
"command": "python -c \"import solution; print('compiled ok')\" 2>&1",
"description": "Test compilation",
"timeout": 120000
}[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -O3 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/main.cpp -o main.o
[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=w4a16_cuda_v2 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include -isystem /home/infatoshi/.local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /usr/include/python3.12 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -arch=sm_120 -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py312_cu130/w4a16_cuda_v2/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -L/home/infatoshi/.local/lib/python3.12/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda-13/lib64 -lcudart -o w4a16_cuda_v2.so
compiled ok
{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}ninja: no work to do.
PASS
{
"command": "python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 600000
}ninja: no work to do.
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.130 gbps=34.677 ms=0.772
shape=0 variant=compiled tflops=0.642 gbps=170.719 ms=0.157
shape=0 variant=sota tflops=2.272 gbps=604.267 ms=0.044
shape=0 variant=solution tflops=0.430 gbps=114.338 ms=0.234
shape=0 solution_peak_fraction=0.0635
shape=1 variant=eager tflops=4.160 gbps=35.882 ms=0.774
shape=1 variant=compiled tflops=20.105 gbps=173.428 ms=0.160
shape=1 variant=sota tflops=22.865 gbps=197.241 ms=0.141
shape=1 variant=solution tflops=1.084 gbps=9.350 ms=2.972
shape=1 solution_peak_fraction=0.0052
shape=2 variant=eager tflops=31.307 gbps=42.676 ms=0.823
shape=2 variant=compiled tflops=120.736 gbps=164.577 ms=0.213
shape=2 variant=sota tflops=132.104 gbps=180.073 ms=0.195
shape=2 variant=solution tflops=4.687 gbps=6.388 ms=5.499
shape=2 solution_peak_fraction=0.0035
shape=3 variant=eager tflops=0.182 gbps=48.398 ms=0.184
shape=3 variant=compiled tflops=0.458 gbps=121.985 ms=0.073
shape=3 variant=sota
... (437 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,13 +1,13 @@
-"""W4A16 weight-only quantized GEMM — Custom CUDA C++ fused kernel.
+"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-Uses torch.utils.cpp_extension.load_inline for JIT compilation.
-Specialized GEMV kernel for M=1, tiled GEMM kernel for M>1.
+Optimized version with autotuning, block pointers, and persistent scheduling.
"""
from __future__ import annotations
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
+import triton
+import triton.language as tl
OP_TYPE = "gemm_w4a16"
@@ -16,186 +16,111 @@
GROUP_SIZE = 128
-CPP_SRC = r"""
-#include <torch/extension.h>
-#include <cstdint>
-std::vector<at::Tensor> w4a16_gemm_cuda(
- at::Tensor x, at::Tensor w_q, at::Tensor scales, at::Tensor zeros,
- int M, int N, int K, int group_size
-);
-"""
-CUDA_SRC = r"""
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cstdint>
-#include <torch/extension.h>
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
+ ],
+ key=['M', 'N', 'K_half'],
+)
+@triton.jit
+def _w4a16_gemm_kernel(
+ X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
+ M, N, K_half, n_groups,
+ stride_xm, stride_xk,
+ stride_wq_kh, stride_wq_n,
+ stride_s_g, stride_s_n,
+ stride_z_g, stride_z_n,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K_HALF: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
-using bf16 = __nv_bfloat16;
+ m_start = pid_m * BLOCK_M
+ n_start = pid_n * BLOCK_N
-template <int MAX_K = 4096>
-__global__ void gemv_w4a16_kernel(
- const bf16* __restrict__ x,
- const uint8_t* __restrict__ w_q,
- const bf16* __restrict__ scales,
- const bf16* __restrict__ zeros,
- bf16* __restrict__ out,
- int N, int K, int K_half
-) {
- int n_col = blockIdx.x * blockDim.x + threadIdx.x;
- if (n_col >= N) return;
+ offs_m = (m_start + tl.arange(0, BLOCK_M))[:, None]
+ offs_n = (n_start + tl.arange(0, BLOCK_N))[None, :]
+ m_mask = offs_m < M
+ n_mask = offs_n < N
- float acc = 0.0f;
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- extern __shared__ bf16 x_shared[];
- for (int i = threadIdx.x; i < K; i += blockDim.x) {
- x_shared[i] = x[i];
- }
- __syncthreads();
+ for kh_start in range(0, K_half, BLOCK_K_HALF):
+ kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
+ kh_mask = kh_off < K_half
- for (int kh = 0; kh < K_half; kh++) {
- int k_even = kh * 2;
- int g_idx = k_even / 128;
+ k_even = (kh_off * 2)
+ k_odd = k_even + 1
- float s_f = __bfloat162float(scales[g_idx * N + n_col]);
- float z_f = __bfloat162float(zeros[g_idx * N + n_col]);
+ x_even = tl.load(
+ X_ptr + offs_m * stride_xm + k_even[None, :] * stride_xk,
+ mask=m_mask & kh_mask[None, :],
+ other=0.0,
+ )
+ x_odd = tl.load(
+ X_ptr + offs_m * stride_xm + k_odd[None, :] * stride_xk,
+ mask=m_mask & kh_mask[None, :],
+ other=0.0,
+ )
- uint8_t packed = w_q[kh * N + n_col];
+ wq_packed = tl.load(
+ Wq_ptr + kh_off[:, None] * stride_wq_kh + offs_n * stride_wq_n,
+ mask=kh_mask[:, None] & n_mask,
+ other=0,
+ )
- float w_lo = (float)(packed & 0x0F);
- float w_hi = (float)((packed >> 4) & 0x0F);
+ wq_lo = wq_packed & 0x0F
+ wq_hi = (wq_packed >> 4) & 0x0F
- float x_lo = __bfloat162float(x_shared[k_even]);
- float x_hi = __bfloat162float(x_shared[k_even + 1]);
+ g_idx = (kh_start * 2) // GROUP_SIZE
- acc += x_lo * (w_lo - z_f) * s_f;
- acc += x_hi * (w_hi - z_f) * s_f;
- }
+ s = tl.load(
+ S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
+ mask=n_mask.reshape((BLOCK_N,)),
+ other=1.0,
+ )[None, :]
+ z = tl.load(
+ Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
+ mask=n_mask.reshape((BLOCK_N,)),
+ other=0.0,
+ )[None, :]
-
... (5915 more chars){
"content": "<6716 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check on cleaned Triton version",
"timeout": 120000
}FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=2.63281 (atol=0.1, rtol=0.1)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -94,19 +94,21 @@
g_idx = (kh_start * 2) // GROUP_SIZE
+ n_mask_1d = (tl.arange(0, BLOCK_N) < N - n_start)
+
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
- mask=n_mask.reshape((BLOCK_N,)),
+ mask=n_mask_1d,
other=1.0,
- )[None, :]
+ )
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
- mask=n_mask.reshape((BLOCK_N,)),
+ mask=n_mask_1d,
other=0.0,
- )[None, :]
+ )
- s_bf = s.to(tl.bfloat16)
- z_bf = z.to(tl.bfloat16)
+ s_bf = s.to(tl.bfloat16)[None, :]
+ z_bf = z.to(tl.bfloat16)[None, :]
w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<507 chars \u2014 see below>",
"oldString": "<482 chars \u2014 see below>"
} g_idx = (kh_start * 2) // GROUP_SIZE
n_mask_1d = (tl.arange(0, BLOCK_N) < N - n_start)
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask_1d,
other=1.0,
)
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
mask=n_mask_1d,
other=0.0,
)
s_bf = s.to(tl.bfloat16)[None, :]
z_bf = z.to(tl.bfloat16)[None, :] g_idx = (kh_start * 2) // GROUP_SIZE
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask.reshape((BLOCK_N,)),
other=1.0,
)[None, :]
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
mask=n_mask.reshape((BLOCK_N,)),
other=0.0,
)[None, :]
s_bf = s.to(tl.bfloat16)
z_bf = z.to(tl.bfloat16)Edit applied successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check after scale/zero mask fix",
"timeout": 120000
}FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=2.70703 (atol=0.1, rtol=0.1)
{
"command": "<686 chars \u2014 see below>",
"description": "Debug correctness for shape 0",
"timeout": 60000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M, reference.N, reference.K = 1, 12288, 4096
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
ref_out = ref(x)
sol_out = sol(x)
print(f'ref_out[0,:5]: {ref_out[0,:5]}')
print(f'sol_out[0,:5]: {sol_out[0,:5]}')
diff = (ref_out.float() - sol_out.float()).abs()
print(f'max_abs_diff: {diff.max().item()}')
print(f'best_config: {solution._w4a16_gemm_kernel.best_config}')
" 2>&1Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 191, in get_init_inputs
return [M, N, K]
^
NameError: name 'M' is not defined
{
"command": "<742 chars \u2014 see below>",
"description": "Debug correctness for shape 0",
"timeout": 60000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M = 1; reference.N = 12288; reference.K = 4096
solution.M = 1; solution.N = 12288; solution.K = 4096
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
ref_out = ref(x)
sol_out = sol(x)
print(f'ref_out[0,:5]: {ref_out[0,:5]}')
print(f'sol_out[0,:5]: {sol_out[0,:5]}')
diff = (ref_out.float() - sol_out.float()).abs()
print(f'max_abs_diff: {diff.max().item()}')
print(f'best_config: {solution._w4a16_gemm_kernel.best_config}')
" 2>&1ref_out[0,:5]: tensor([-1.5703, -0.2441, -2.7500, 2.3750, 1.6406], device='cuda:0',
dtype=torch.bfloat16)
sol_out[0,:5]: tensor([-1.5703, -0.2441, -2.7500, 2.3750, 1.6406], device='cuda:0',
dtype=torch.bfloat16)
max_abs_diff: 1.9375
best_config: BLOCK_M: 8, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
{
"command": "<884 chars \u2014 see below>",
"description": "Find where max diff occurs",
"timeout": 60000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M = 1; reference.N = 12288; reference.K = 4096
solution.M = 1; solution.N = 12288; solution.K = 4096
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
ref_out = ref(x)
sol_out = sol(x)
diff = (ref_out.float() - sol_out.float()).abs()
max_idx = diff.argmax().item()
print(f'max diff at index {max_idx}: ref={ref_out[0,max_idx]}, sol={sol_out[0,max_idx]}, diff={diff[0,max_idx]}')
print(f'surrounding: ref={ref_out[0,max(max_idx-3,0):max_idx+4]}')
print(f'surrounding: sol={sol_out[0,max(max_idx-3,0):max_idx+4]}')
print(f'num mismatches: {(diff > 0.5).sum().item()}')
" 2>&1max diff at index 10402: ref=-1.96875, sol=-3.90625, diff=1.9375
surrounding: ref=tensor([-2.7812, -0.3691, -0.6133, -1.9688, 2.2031, 0.2617, -0.7578],
device='cuda:0', dtype=torch.bfloat16)
surrounding: sol=tensor([-3.3438, -0.8320, -0.6367, -3.9062, 1.9219, 1.0000, -1.2422],
device='cuda:0', dtype=torch.bfloat16)
num mismatches: 3782
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1 +1 @@
- for bm in [1, 4, 8, 16]+ for bm in [1, 2, 4, 8, 16, 32, 64]{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": " for bm in [1, 2, 4, 8, 16, 32, 64]",
"oldString": " for bm in [1, 4, 8, 16]"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"limit": 40,
"offset": 17
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
17: GROUP_SIZE = 128
18:
19:
20: @triton.autotune(
21: configs=[
22: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
23: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
24: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
25: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
26: triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
27: triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
28: triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
29: triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
30: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
31: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
32: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
33: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
34: triton.Config({'BLOCK_M': 32
... (993 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,6 +1,8 @@
"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-Optimized version with autotuning, block pointers, and persistent scheduling.
+Fused approach: each program handles a (M_block, N_block) tile,
+iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
+Separate configs for M=1 (GEMV) and M>1 (GEMM).
"""
from __future__ import annotations
@@ -19,14 +21,102 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
+ ],
+ key=['M', 'N', 'K_half'],
+)
+@triton.jit
+def _gemv_kernel(
+ X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
+ M, N, K_half, n_groups,
+ stride_xm, stride_xk,
+ stride_wq_kh, stride_wq_n,
+ stride_s_g, stride_s_n,
+ stride_z_g, stride_z_n,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K_HALF: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ n_start = pid * BLOCK_N
+
+ n_off = n_start + tl.arange(0, BLOCK_N)
+ n_mask = n_off < N
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for kh_start in range(0, K_half, BLOCK_K_HALF):
+ kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
+ kh_mask = kh_off < K_half
+
+ k_even = kh_off * 2
+ k_odd = k_even + 1
+
+ x_even = tl.load(
+ X_ptr + k_even * stride_xk,
+ mask=kh_mask,
+ other=0.0,
+ )
+ x_odd = tl.load(
+ X_ptr + k_odd * stride_xk,
+ mask=kh_mask,
+ other=0.0,
+ )
+
+ wq_packed = tl.load(
+ Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n,
+ mask=kh_mask[:, None] & n_mask[None, :],
+ other=0,
+ )
+
+ wq_lo = wq_packed & 0x0F
+ wq_hi = (wq_packed >> 4) & 0x0F
+
+ g_idx = (kh_start * 2) // GROUP_SIZE
+
+ s = tl.load(
+ S_ptr + g_idx * stride_s_g + n_off * stride_s_n,
+ mask=n_mask,
+ other=1.0,
+ )[None, :]
+ z = tl.load(
+ Z_ptr + g_idx * stride_z_g + n_off * stride_z_n,
+ mask=n_mask,
+ other=0.0,
+ )[None, :]
+
+ s_bf = s.to(tl.bfloat16)
+ z_bf = z.to(tl.bfloat16)
+
+ w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
+ w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
+
+ x_even_bf = x_even.to(tl.bfloat16)[:, None]
+ x_odd_bf = x_odd.to(tl.bfloat16)[:, None]
+
+ lo_contrib = tl.sum(x_even_bf * w_deq_lo, axis=0)
+ hi_contrib = tl.sum(x_odd_bf * w_deq_hi, axis=0)
+ acc += lo_contrib + hi_contrib
+
+ tl.store(
+ Out_ptr + n_off * stride_on,
+ acc.to(tl.bfloat16),
+ mask=n_mask,
+ )
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
@@ -39,7 +129,7 @@
key=['M', 'N', 'K_half'],
)
@triton.jit
-def _w4a16_gemm_kernel(
+def _gemm_kernel(
X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
M, N, K_half, n_groups,
stride_xm, stride_xk,
@@ -69,7 +159,7 @@
kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
kh_mask = kh_off < K_half
- k_even = (kh_off * 2)
+ k_even = kh_off * 2
k_odd = k_even + 1
x_even = tl.load(
@@ -94,21 +184,20 @@
g_idx = (kh_start * 2) // GROUP_SIZE
- n_mask_1d = (tl.arange(0, BLOCK_N) < N - n_start)
-
+ n_mask_1d = tl.arange(0, BLOCK_N) < (N - n_start)
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask_1d,
other=1.0,
- )
+ )[None, :]
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
mask=n_mask_1d,
other=0.0,
- )
-
- s_bf = s.to(tl.bfloat16)[None, :]
-
... (1719 more chars){
"content": "<9649 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check with separate GEMV/GEMM kernels",
"timeout": 120000
}FAIL: shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=3.23438 (atol=0.1, rtol=0.1)
{
"command": "<703 chars \u2014 see below>",
"description": "Verify M=1 is correct",
"timeout": 60000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M = 1; reference.N = 12288; reference.K = 4096
solution.M = 1; solution.N = 12288; solution.K = 4096
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
ref_out = ref(x)
sol_out = sol(x)
diff = (ref_out.float() - sol_out.float()).abs()
print(f'M=1: max_abs_diff={diff.max().item()}, mismatches={(diff > 0.5).sum().item()}')
print(f'GEMV best_config: {solution._gemv_kernel.best_config}')
" 2>&1M=1: max_abs_diff=0.03125, mismatches=0
GEMV best_config: BLOCK_N: 64, BLOCK_K_HALF: 64, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
{
"command": "<893 chars \u2014 see below>",
"description": "Debug M=32 correctness issue",
"timeout": 60000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
reference.M = 32; reference.N = 12288; reference.K = 4096
solution.M = 32; solution.N = 12288; solution.K = 4096
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = torch.randn(32, 4096, dtype=torch.bfloat16, device=device)
ref_out = ref(x)
sol_out = sol(x)
diff = (ref_out.float() - sol_out.float()).abs()
print(f'M=32: max_abs_diff={diff.max().item()}, mismatches={(diff > 0.5).sum().item()}')
max_pos = diff.argmax()
m_idx = max_pos // 12288
n_idx = max_pos % 12288
print(f'Max diff at m={m_idx.item()}, n={n_idx.item()}: ref={ref_out[m_idx,n_idx]}, sol={sol_out[m_idx,n_idx]}')
print(f'GEMM best_config: {solution._gemm_kernel.best_config}')
" 2>&1M=32: max_abs_diff=2.96875, mismatches=151835
Max diff at m=27, n=10068: ref=-3.15625, sol=-6.125
GEMM best_config: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K_HALF: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -184,7 +184,8 @@
g_idx = (kh_start * 2) // GROUP_SIZE
- n_mask_1d = tl.arange(0, BLOCK_N) < (N - n_start)
+ n_mask_1d = (tl.arange(0, BLOCK_N) + n_start) < N
+
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask_1d,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<394 chars \u2014 see below>",
"oldString": "<393 chars \u2014 see below>"
} n_mask_1d = (tl.arange(0, BLOCK_N) + n_start) < N
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask_1d,
other=1.0,
)[None, :]
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
mask=n_mask_1d,
other=0.0,
)[None, :] n_mask_1d = tl.arange(0, BLOCK_N) < (N - n_start)
s = tl.load(
S_ptr + g_idx * stride_s_g + tl.arange(0, BLOCK_N) * stride_s_n,
mask=n_mask_1d,
other=1.0,
)[None, :]
z = tl.load(
Z_ptr + g_idx * stride_z_g + tl.arange(0, BLOCK_N) * stride_z_n,
mask=n_mask_1d,
other=0.0,
)[None, :]Edit applied successfully.{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}FAIL: shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=3.23438 (atol=0.1, rtol=0.1)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,8 +1,7 @@
"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-Fused approach: each program handles a (M_block, N_block) tile,
-iterates over K pairs (packed bytes), dequantizes on-the-fly, accumulates.
-Separate configs for M=1 (GEMV) and M>1 (GEMM).
+Single unified kernel: each program handles a (M_block, N_block) tile,
+iterates over K with on-the-fly dequant. M=1 uses smaller block sizes.
"""
from __future__ import annotations
@@ -21,115 +20,26 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
+ # M=1 configs (GEMV)
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ # M>1 configs (GEMM)
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
],
key=['M', 'N', 'K_half'],
)
@triton.jit
-def _gemv_kernel(
- X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
- M, N, K_half, n_groups,
- stride_xm, stride_xk,
- stride_wq_kh, stride_wq_n,
- stride_s_g, stride_s_n,
- stride_z_g, stride_z_n,
- stride_om, stride_on,
- GROUP_SIZE: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K_HALF: tl.constexpr,
-):
- pid = tl.program_id(0)
- n_start = pid * BLOCK_N
-
- n_off = n_start + tl.arange(0, BLOCK_N)
- n_mask = n_off < N
-
- acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
-
- for kh_start in range(0, K_half, BLOCK_K_HALF):
- kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
- kh_mask = kh_off < K_half
-
- k_even = kh_off * 2
- k_odd = k_even + 1
-
- x_even = tl.load(
- X_ptr + k_even * stride_xk,
- mask=kh_mask,
- other=0.0,
- )
- x_odd = tl.load(
- X_ptr + k_odd * stride_xk,
- mask=kh_mask,
- other=0.0,
- )
-
- wq_packed = tl.load(
- Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n,
- mask=kh_mask[:, None] & n_mask[None, :],
- other=0,
- )
-
- wq_lo = wq_packed & 0x0F
- wq_hi = (wq_packed >> 4) & 0x0F
-
- g_idx = (kh_start * 2) // GROUP_SIZE
-
- s = tl.load(
- S_ptr + g_idx * stride_s_g + n_off * stride_s_n,
- mask=n_mask,
- other=1.0,
- )[None, :]
- z = tl.load(
- Z_ptr + g_idx * stride_z_g + n_off * stride_z_n,
- mask=n_mask,
- other=0.0,
- )[None, :]
-
- s_bf = s.to(tl.bfloat16)
- z_bf = z.to(tl.bfloat16)
-
- w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
- w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
-
- x_even_bf = x_even.to(tl.bfloat16)[:, None]
- x_odd_bf = x_odd.to(tl.bfloat16)[:, None]
-
- lo_contrib = tl.sum(x_even_bf * w_deq_lo, axis=0)
- hi_contrib = tl.sum(x_odd_bf * w_deq_hi, axis=0)
- acc += lo_contrib + hi_contrib
-
- tl.store(
- Out_ptr + n_off * stride_on,
- acc.to(tl.bfloat16),
- mask=n_mask,
- )
-
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF':
... (4555 more chars){
"content": "<6502 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "rm -rf ~/.triton/cache && python check.py 2>&1",
"description": "Clear Triton cache and run correctness check",
"timeout": 120000
}{
"command": "python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
{
"command": "python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 600000
}/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.775 ms=0.770
shape=0 variant=compiled tflops=0.653 gbps=173.624 ms=0.154
shape=0 variant=sota tflops=2.282 gbps=606.897 ms=0.044
shape=0 variant=solution tflops=1.443 gbps=383.853 ms=0.070
shape=0 solution_peak_fraction=0.2133
shape=1 variant=eager tflops=4.169 gbps=35.966 ms=0.773
shape=1 variant=compiled tflops=20.493 gbps=176.782 ms=0.157
shape=1 variant=sota tflops=23.064 gbps=198.958 ms=0.140
shape=1 variant=solution tflops=34.880 gbps=300.884 ms=0.092
shape=1 solution_peak_fraction=0.1672
shape=2 variant=eager tflops=31.328 gbps=42.703 ms=0.823
shape=2 variant=compiled tflops=120.402 gbps=164.122 ms=0.214
shape=2 variant=sota tflops=132.746 gbps=180.949 ms=0.194
shape=2 variant=solution tflops=86.699 gbps=118.181 ms=0.297
shape=2 solution_peak_fraction=0.0657
shape=3 variant=eager tflops=0.181 gbps=48.289 ms=0.185
shape=3 variant=compiled tflops=0.537 gbps=142.914 ms=0.062
shape=3 variant=sota tflops=1.160 gbp
... (423 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,7 +1,7 @@
"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-Single unified kernel: each program handles a (M_block, N_block) tile,
-iterates over K with on-the-fly dequant. M=1 uses smaller block sizes.
+Optimized version with block pointers, 2D accumulation for K-loop,
+and autotuned configs per shape regime.
"""
from __future__ import annotations
@@ -20,21 +20,30 @@
@triton.autotune(
configs=[
- # M=1 configs (GEMV)
+ # GEMV (M=1) — small block, high stages
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- # M>1 configs (GEMM)
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ # Small M (4-32)
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ # Medium M (32-64)
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ # Large M (64+)
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
],
key=['M', 'N', 'K_half'],
)
@@ -58,10 +67,13 @@
m_start = pid_m * BLOCK_M
n_start = pid_n * BLOCK_N
- m_off = m_start + tl.arange(0, BLOCK_M)
- n_off = n_start + tl.arange(0, BLOCK_N)
- m_mask = (m_off < M)[:, None]
- n_mask = (n_off < N)[None, :]
+ m_off = (m_start + tl.arange(0, BLOCK_M))[:, None]
+ n_off = (n_start + tl.arange(0, BLOCK_N))[None, :]
+ m_mask = m_off < M
+ n_mask = n_off < N
+
+ n_off_1d = n_start + tl.arange(0, BLOCK_N)
+ n_mask_1d = n_off_1d < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -69,22 +81,22 @@
kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
kh_mask = kh_off < K_half
- k_even = kh_off * 2
+ k_even = (kh_off * 2)
k_odd = k_even + 1
x_even = tl.load(
- X_ptr + m_off[:, None] * stride_xm + k_even[None, :] * stride_xk,
+ X_ptr + m_off * stride_xm + k_even[None, :] * stride_xk,
mask=m_mask & kh_mask[None, :],
other=0.0,
)
x_odd = tl.load(
- X_ptr + m_off[:, None] * stride_xm + k_odd[None, :] * stride_xk,
+ X_ptr + m_off * stride_xm + k_odd[None, :] * stride_xk,
mask=m_mask & kh_mask[None, :],
other=0.0,
)
wq_packed = tl.load(
- Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n,
+ Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off * stride_wq_n,
mask=kh_mask[:, None] & n_mask,
other=0,
)
@@ -95,30 +107,27 @@
g_idx = (kh_start * 2) // GROUP_SIZE
s = tl.load(
- S_ptr + g_idx * stride_s_g + n_off * stride_s_n,
- mask=(n_off < N),
+ S_ptr + g_idx * stride_s_g + n_off_1d * stride_s_n,
+ mask=n_mask_1d,
other=1.0,
- )
+ )[None, :]
z = tl.load(
- Z_ptr + g_idx * stride_z_g + n_off * stride_z_n,
- mask=(n_off < N),
+ Z_ptr + g_idx * stride_z_g + n_off_1d * stride_z_n,
+ mask=n_mask_1d,
other=0.0,
- )
+ )[None, :]
- s_bf = s.to(tl.bfloat16)[None, :]
- z_bf = z.to(tl.bfloat16)[None, :]
+ s_bf = s.to(tl.bfloat16)
+ z_bf = z.to(tl.bfloat16)
w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
- x_even_bf = x_even.to(tl.bfloat16)
- x_odd_bf = x_odd.to(tl.bfloat16)
-
- acc += tl.dot(x_even_bf, w_deq_lo).to(tl.float32)
- acc += tl.dot(x_odd_bf, w_deq_hi
... (379 more chars){
"content": "<7230 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "python check.py 2>&1 && python benchmark.py 2>&1",
"description": "Run correctness and benchmark",
"timeout": 600000
}PASS
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.749 ms=0.770
shape=0 variant=compiled tflops=0.665 gbps=176.854 ms=0.151
shape=0 variant=sota tflops=2.175 gbps=578.567 ms=0.046
shape=0 variant=solution tflops=1.475 gbps=392.313 ms=0.068
shape=0 solution_peak_fraction=0.2180
shape=1 variant=eager tflops=4.161 gbps=35.897 ms=0.774
shape=1 variant=compiled tflops=20.462 gbps=176.512 ms=0.157
shape=1 variant=sota tflops=23.022 gbps=198.594 ms=0.140
shape=1 variant=solution tflops=31.881 gbps=275.013 ms=0.101
shape=1 solution_peak_fraction=0.1528
shape=2 variant=eager tflops=31.192 gbps=42.519 ms=0.826
shape=2 variant=compiled tflops=120.402 gbps=164.122 ms=0.214
shape=2 variant=sota tflops=133.872 gbps=182.483 ms=0.192
shape=2 variant=solution tflops=86.088 gbps=117.348 ms=0.299
shape=2 solution_peak_fraction=0.0652
shape=3 variant=eager tflops=0.179 gbps=47.662 ms=0.187
shape=3 variant=compiled tflops=0.522 gbps=139.033 ms=0.064
shape=3 variant=sota tflops=1.28
... (428 more chars){
"command": "<1069 chars \u2014 see below>",
"description": "Get selected configs",
"timeout": 120000
}python -c "
import torch, solution, reference
device = torch.device('cuda:0')
shapes = [
{'M': 1, 'N': 12288, 'K': 4096},
{'M': 32, 'N': 12288, 'K': 4096},
{'M': 256, 'N': 12288, 'K': 4096},
{'M': 1, 'N': 4096, 'K': 4096},
{'M': 16, 'N': 14336, 'K': 4096},
]
for shape in shapes:
reference.M = shape['M']; reference.N = shape['N']; reference.K = shape['K']
solution.M = shape['M']; solution.N = shape['N']; solution.K = shape['K']
ref = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol = solution.Model(*solution.get_init_inputs()).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(2026)
x = torch.randn(shape['M'], 4096, dtype=torch.bfloat16, device=device)
for _ in range(3): sol(x)
torch.cuda.synchronize()
cfg = solution._w4a16_gemm_kernel.best_config
print(f'Shape {shape}: BLOCK_M={cfg.kwargs[\"BLOCK_M\"]}, BLOCK_N={cfg.kwargs[\"BLOCK_N\"]}, BLOCK_K_HALF={cfg.kwargs[\"BLOCK_K_HALF\"]}, warps={cfg.num_warps}, stages={cfg.num_stages}')
" 2>&1Shape {'M': 1, 'N': 12288, 'K': 4096}: BLOCK_M=1, BLOCK_N=128, BLOCK_K_HALF=64, warps=8, stages=3
Shape {'M': 32, 'N': 12288, 'K': 4096}: BLOCK_M=8, BLOCK_N=128, BLOCK_K_HALF=64, warps=8, stages=3
Shape {'M': 256, 'N': 12288, 'K': 4096}: BLOCK_M=16, BLOCK_N=128, BLOCK_K_HALF=64, warps=8, stages=3
Shape {'M': 1, 'N': 4096, 'K': 4096}: BLOCK_M=1, BLOCK_N=64, BLOCK_K_HALF=64, warps=4, stages=4
Shape {'M': 16, 'N': 14336, 'K': 4096}: BLOCK_M=8, BLOCK_N=128, BLOCK_K_HALF=64, warps=8, stages=3
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,7 +1,6 @@
-"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel.
-
-Optimized version with block pointers, 2D accumulation for K-loop,
-and autotuned configs per shape regime.
+"""W4A16 weight-only quantized GEMM — Triton fused dequant kernel, optimized.
+
+Persistent kernel for M=1 (amortize launch overhead), standard tiled for M>1.
"""
from __future__ import annotations
@@ -20,29 +19,107 @@
@triton.autotune(
configs=[
- # GEMV (M=1) — small block, high stages
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- # Small M (4-32)
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
- triton.Config({'BLOCK_M': 4, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 8, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 64, 'BLOCK_K_HALF': 32}, num_warps=4, num_stages=4),
+ triton.Config({'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_N': 512, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ ],
+ key=['M', 'N', 'K_half'],
+)
+@triton.jit
+def _w4a16_gemv_persistent_kernel(
+ X_ptr, Wq_ptr, S_ptr, Z_ptr, Out_ptr,
+ M, N, K_half, n_groups,
+ stride_xm, stride_xk,
+ stride_wq_kh, stride_wq_n,
+ stride_s_g, stride_s_n,
+ stride_z_g, stride_z_n,
+ stride_om, stride_on,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K_HALF: tl.constexpr,
+):
+ num_programs = tl.num_programs(0)
+ total_n_blocks = tl.cdiv(N, BLOCK_N)
+
+ for pid in range(tl.program_id(0), total_n_blocks, num_programs):
+ n_start = pid * BLOCK_N
+ n_off = n_start + tl.arange(0, BLOCK_N)
+ n_mask = n_off < N
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for kh_start in range(0, K_half, BLOCK_K_HALF):
+ kh_off = kh_start + tl.arange(0, BLOCK_K_HALF)
+ kh_mask = kh_off < K_half
+
+ k_even = kh_off * 2
+ k_odd = k_even + 1
+
+ x_even = tl.load(
+ X_ptr + k_even * stride_xk,
+ mask=kh_mask,
+ other=0.0,
+ )
+ x_odd = tl.load(
+ X_ptr + k_odd * stride_xk,
+ mask=kh_mask,
+ other=0.0,
+ )
+
+ wq_packed = tl.load(
+ Wq_ptr + kh_off[:, None] * stride_wq_kh + n_off[None, :] * stride_wq_n,
+ mask=kh_mask[:, None] & n_mask[None, :],
+ other=0,
+ )
+
+ wq_lo = wq_packed & 0x0F
+ wq_hi = (wq_packed >> 4) & 0x0F
+
+ g_idx = (kh_start * 2) // GROUP_SIZE
+
+ s = tl.load(
+ S_ptr + g_idx * stride_s_g + n_off * stride_s_n,
+ mask=n_mask,
+ other=1.0,
+ )[None, :]
+ z = tl.load(
+ Z_ptr + g_idx * stride_z_g + n_off * stride_z_n,
+ mask=n_mask,
+ other=0.0,
+ )[None, :]
+
+ s_bf = s.to(tl.bfloat16)
+ z_bf = z.to(tl.bfloat16)
+
+ w_deq_lo = (wq_lo.to(tl.bfloat16) - z_bf) * s_bf
+ w_deq_hi = (wq_hi.to(tl.bfloat16) - z_bf) * s_bf
+
+ x_even_bf = x_even.to(tl.bfloat16)[:, None]
+ x_odd_bf = x_odd.to(tl.bfloat16)[:, None]
+
+ lo = tl.sum(x_even_bf * w_deq_lo, axis=0)
+ hi = tl.sum(x_odd_bf * w_deq_hi, axis=0)
+ acc += lo + hi
+
+ tl.store(
+ Out_ptr + n_off * stride_on,
+ acc.to(tl.bfloat16),
+ mask=n_mask,
+ )
+
+
+@triton.autotune(
+ configs=[
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- # Medium M (32-64)
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- # Large M (64+)
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K_HALF': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K_HALF': 32}, num_warps=8, num_stages=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K_HAL
... (1533 more chars){
"content": "<9368 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.