"""W4A16 GEMM via Triton with fused int4 unpack + dequant.
AWQ/GPTQ-style asymmetric int4 with group_size=128.
Uses coalesced x loads via reshape+split and tuned num_stages for pipelining.
"""
from __future__ import annotations
import torch
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.autotune(
configs=[
# Decode configs (small M) - num_stages=4 is key for bandwidth
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=2, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=4),
# Small prefill configs
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
# Large prefill configs
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
],
key=["M", "N"],
)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
# Load x contiguously
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
# Split into even/odd K elements for correct int4 unpacking
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
class Model(torch.nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0
assert K % 2 == 0
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
w_g = w_full.view(n_groups, group_size, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = _pack_int4(w_q)
self.register_buffer("w_q", w_packed)
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
# Lazy-allocated output buffer (not registered to avoid state_dict mismatch)
self._out_buffer: torch.Tensor | None = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N = self.N
if x.dtype != torch.bfloat16:
x = x.to(torch.bfloat16)
# Lazy allocate and reuse output buffer
if self._out_buffer is None or self._out_buffer.shape != (M, N) or self._out_buffer.device != x.device:
self._out_buffer = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
out = self._out_buffer
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
M, N, K,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
)
return out
def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
K, N = w_q.shape
assert K % 2 == 0
lo = w_q[0::2].to(torch.uint8) & 0xF
hi = w_q[1::2].to(torch.uint8) & 0xF
return (lo | (hi << 4)).contiguous()
M = 1
N = 12288
K = 4096
def get_inputs():
x = torch.randn(M, K, dtype=torch.bfloat16)
return [x]
def get_init_inputs():
return [M, N, K]
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.923 ms=0.767
shape=0 variant=compiled tflops=0.659 gbps=175.316 ms=0.153
shape=0 variant=sota tflops=2.289 gbps=608.885 ms=0.044
shape=0 variant=solution tflops=2.200 gbps=585.041 ms=0.046
shape=0 solution_peak_fraction=0.3250
shape=1 variant=eager tflops=4.190 gbps=36.141 ms=0.769
shape=1 variant=compiled tflops=20.666 gbps=178.270 ms=0.156
shape=1 variant=sota tflops=23.170 gbps=199.874 ms=0.139
shape=1 variant=solution tflops=59.794 gbps=515.802 ms=0.054
shape=1 solution_peak_fraction=0.2866
shape=2 variant=eager tflops=31.308 gbps=42.677 ms=0.823
shape=2 variant=compiled tflops=120.808 gbps=164.676 ms=0.213
shape=2 variant=sota tflops=132.093 gbps=180.059 ms=0.195
shape=2 variant=solution tflops=153.934 gbps=209.830 ms=0.167
shape=2 solution_peak_fraction=0.1166
shape=3 variant=eager tflops=0.180 gbps=47.908 ms=0.186
shape=3 variant=compiled tflops=0.511 gbps=136.051 ms=0.066
shape=3 variant=sota tflops=1.299 gbps=345.774 ms=0.026
shape=3 variant=solution tflops=0.880 gbps=234.291 ms=0.038
shape=3 solution_peak_fraction=0.1302
shape=4 variant=eager tflops=2.058 gbps=34.816 ms=0.913
shape=4 variant=compiled tflops=10.380 gbps=175.584 ms=0.181
shape=4 variant=sota tflops=11.265 gbps=190.557 ms=0.167
shape=4 variant=solution tflops=39.030 gbps=660.206 ms=0.048
shape=4 solution_peak_fraction=0.3668
peak_fraction: 0.2203
RESULT: OK
PASS
{
"problem": "07_w4a16_gemm",
"harness": "kimi",
"model": "kimi-k2.6",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.2203,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": null, "output_tokens": null, "cache_read_tokens": null, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}{
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}<system>112 lines read from file starting from line 1. Total lines in file: 112. End of file reached.</system>
1 """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2
3 AWQ/GPTQ-style scheme:
4 x: (M, K) bf16
5 w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6 scales: (K // group, N) bf16
7 zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8 out: (M, N) bf16
9
10 Dequant (per group along K):
11 w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12 where w_q[k, n] is the unpacked 4-bit value (0..15).
13
14 This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15 memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16 """
17 from __future__ import annotations
18
19 import torch
20 import torch.nn as nn
21
22 OP_TYPE = "gemm_w4a16"
23 SUPPORTED_PRECISIONS = ["int4_bf16"]
24 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25
26 GROUP_SIZE = 128
27
28
29 def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30 """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31
32 Even rows go in the low nibble, odd rows in the high nibble.
33 """
34 K, N = w_q.shape
35 assert K
... (3597 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a"
}<system>112 lines read from file starting from line 1. Total lines in file: 112. End of file reached.</system>
1 """Correctness runner for FP8 GEMM.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 # Make the repo's src/ importable
14 REPO_ROOT = Path(__file__).resolve().parents[2]
15 sys.path.insert(0, str(REPO_ROOT))
16
17 from src.eval.correctness import check_correctness # noqa: E402
18
19
20 def main():
21 try:
22 import reference
23 import shapes
24 import solution
25 except Exception as e:
26 print(f"FAIL: import error: {e}")
27 sys.exit(1)
28
29 problem_yaml = Path("problem.yaml")
30 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31
32 # --- Forbidden-op check ------------------------------------------------
33 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34 for forbidden in meta.get("forbidden", []):
35 pat = re.escape(forbidden)
36 if re.search(pat, sol_src):
37 print(f"FAIL: forbidden op used: {forbidden}")
38 sys.exit(1)
... (3198 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}<system>128 lines read from file starting from line 1. Total lines in file: 128. End of file reached.</system>
1 """Roofline benchmark for FP8 GEMM.
2
3 For each shape: times eager reference, compiled reference, SOTA (if available),
4 and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Output lines the harness picks up:
7 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9 """
10 import sys
11 from math import exp, log
12 from pathlib import Path
13
14 import torch
15 import yaml
16
17 REPO_ROOT = Path(__file__).resolve().parents[2]
18 sys.path.insert(0, str(REPO_ROOT))
19
20 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21 from src.eval.timing import time_fn # noqa: E402
22 from src.hardware import get as get_hw # noqa: E402
23
24
25 def _eval_formula(expr: str, vars: dict) -> float:
26 # Very small eval: only names from `vars` are valid.
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def main():
31 import reference
32 import shapes
33 import solution
34
35 meta = yaml.safe_load(Path("problem.yaml").read_text())
36 hw = get_hw(meta["hardware"][0])
37 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"],
... (3732 more chars){
"raw": "{\"path\": \"/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a"
}<system>49 lines read from file starting from line 1. Total lines in file: 49. End of file reached.</system>
1 name: 07_w4a16_gemm
2 display_name: "W4A16 Weight-only Quantized GEMM"
3 precision: int4_bf16
4 regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5
6 # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7 flops_formula: "2 * M * N * K"
8
9 # Bytes moved per call (memory roofline):
10 # x: M*K*2 (bf16 activations, streamed in once)
11 # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12 # scales: (K/128)*N*2 (bf16 scales)
13 # zeros: (K/128)*N*2 (bf16 zero-points)
14 # out: M*N*2 (bf16 store)
15 bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16
17 hardware: [RTX_PRO_6000]
18 peak_tflops_key: bf16
19 peak_bandwidth_key: dram
20
21 tolerance:
22 bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23
24 # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25 # vendor library that does both.
26 forbidden:
27 - "bitsandbytes.functional.dequantize_4bit"
28 - "bitsandbytes.functional.gemv_4bit"
29 - "marlin_kernel.gemm"
30 - "torch.nn.functional.linear"
31
32 sota:
33 name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34 url: "https://github.
... (865 more chars)<system>13 lines read from file starting from line 1. Total lines in file: 13. End of file reached.</system>
1 """Shape sweep for W4A16 GEMM.
2
3 Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4 case every inference engine optimizes -- it's the bar to beat.
5 """
6
7 SHAPES = [
8 {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9 {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10 {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11 {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12 {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13 ]
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}{
"raw": "{\"command\": \"cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm &&"
}<system>87 lines read from file starting from line 1. Total lines in file: 87. End of file reached.</system>
1 """SOTA reference for W4A16 GEMM.
2
3 Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4
5 - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6 - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7 but is not faster than Marlin on its target HW
8 and has no Blackwell tuning. Skip as primary.
9 - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10 Skip.
11 - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12 on this machine). Different quant scheme (NF4,
13 symmetric, blocksize 64) than our reference's
14 AWQ-style asymmetric INT4 with group_size 128,
15 but it occupies the same memory regime and is
16 the only tuned W4A16-class kernel that runs on
17 SM120 today. Used here as an *informational*
18 SOTA line, not as a numerical reference.
19
20 The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21 NOT
... (2952 more chars)<system>Command executed successfully.</system>
Tue Apr 28 13:46:30 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 34C P8 10W / 600W | 704MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================
... (929 more chars){
"command": "python -c \"import triton; print(triton.__version__); import triton.language as tl; print('Triton OK')\""
}{
"raw": "{\"command\": \"ls -"
}<system>Command executed successfully.</system>
3.6.0
Triton OK
{
"raw": ""
}<system>Command executed successfully.</system>
total 40
drwxrwxr-x 10 infatoshi infatoshi 4096 Apr 24 00:30 .
drwxrwxr-x 12 infatoshi infatoshi 4096 Apr 27 21:56 ..
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 09:58 01_fp8_gemm
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 11:00 02_kda_cutlass
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 11:45 03_paged_attention
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 12:11 04_kahan_softmax
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 12:57 05_topk_bitonic
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 13:46 06_sonic_moe_swiglu
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 28 09:24 07_w4a16_gemm
drwxrwxr-x 2 infatoshi infatoshi 4096 Apr 27 01:28 08_metal_lightning_attn
<system>Command executed successfully.</system>
True
{
"path": "/home/infatoshi/.claude/skills/kernel/SKILL.md"
}<system>56 lines read from file starting from line 1. Total lines in file: 56. End of file reached.</system>
1 ---
2 name: kernel-optimization
3 description: Unified GPU/TPU kernel skill. Covers CUDA, Metal, Triton, AMD HIP, TPU Pallas, kernel porting, debugging, numerical precision, and performance optimization. Use when writing, optimizing, debugging, porting, or benchmarking GPU/TPU kernels on any platform.
4 ---
5
6 # Kernel Optimization
7
8 Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
9
10 ## When this skill fires
11
12 - Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
13 - Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
14 - Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
15 - Performance profiling and optimization
16 - Numerical precision issues (FP formats, mixed precision, tolerance)
17 - AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
18
19 ## Sub-references (read on demand)
20
21 The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
22
23 | File | When to read |
24 |------|-------------|
25 | `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
26 | `~/.claude/skills/tpu-opt
... (1947 more chars){
"path": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}<system>1000 lines read from file starting from line 1. Total lines in file: 1002.</system>
1 ---
2 name: perf-optim
3 description: "[SUB-REFERENCE of kernel skill] Performance optimization. Do NOT invoke directly -- use /kernel instead, which reads this file on demand."
4 ---
5
6 # Performance Optimization Atlas
7
8 Comprehensive, hierarchical reference for optimizing GPU workloads. Pick combinations from this atlas based on your bottleneck profile. **Always profile before optimizing.**
9
10 ---
11
12 ## 0. FIRST PRINCIPLES: SPEED OF LIGHT, THEN PROFILE
13
14 ### Rule
15 Never optimize blind. But before you even profile, compute the **speed of light** -- the theoretical physical limit for this operation. Then profile to measure the gap. The gap IS the optimization opportunity.
16
17 ### Speed of Light Analysis
18 1. **Compute the limit**: What does physics allow? Peak FLOPS, peak bandwidth, minimum latency for this exact workload
19 2. **Profile current state**: Where are we now?
20 3. **Measure the gap**: speed_of_light - measured = optimization budget
21 4. **Don't continuous-improve from current state**: Strip to zero. If you built this from scratch knowing only the physical limits, what would you build? Then reason about every piece of the gap between that and current
22 - "It takes 74ms" -- don't optimize to 72ms. Compute that physics allows 6ms. Now reason about th
... (70359 more chars){
"command": "<352 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
print('CUDA version:', torch.version.cuda)
print('Triton available:', end=' ')
try:
import triton
print(triton.__version__)
except:
print('No')
print('CUTLASS available:', end=' ')
try:
import cutlass
print('Yes')
except:
print('No')
"
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1 | head -40",
"timeout": 120
}<system>Command executed successfully.</system>
CUDA version: 13.0
Triton available: 3.6.0
CUTLASS available: No
<system>Command executed successfully.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py", line 128, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py", line 33, in main
import solution
ModuleNotFoundError: No module named 'solution'
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,192 @@
+"""W4A16 GEMM via Triton with fused int4 unpack + dequant.
+
+AWQ/GPTQ-style asymmetric int4 with group_size=128.
+"""
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def w4a16_gemm_kernel(
+ x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ # K must be multiple of GROUP_SIZE and BLOCK_K
+ num_k_tiles = K // BLOCK_K
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_tile in range(num_k_tiles):
+ k_start = k_tile * BLOCK_K
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+
+ # Load x tile: (BLOCK_M, BLOCK_K)
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+ x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ # Load packed weights: (BLOCK_K // 2, BLOCK_N)
+ # w_q layout: (K//2, N)
+ packed_k_offs = k_start // 2 + tl.arange(0, BLOCK_K // 2)
+ w_ptrs = w_ptr + (packed_k_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn)
+ w_mask = (packed_k_offs[:, None] < K // 2) & (offs_n[None, :] < N)
+ w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
+
+ # Unpack int4: low nibble = even K, high nibble = odd K
+ # w_packed shape: (BLOCK_K//2, BLOCK_N)
+ # We need: (BLOCK_K, BLOCK_N)
+ lo = w_packed & 0xF # even rows
+ hi = (w_packed >> 4) & 0xF # odd rows
+
+ # Interleave lo and hi to get (BLOCK_K, BLOCK_N)
+ # Using tl.reshape and tl.join-like operations
+ # In Triton, we can construct it with broadcast and select
+ w_unpacked = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.uint8)
+ # Need to assign even indices from lo, odd from hi
+ # lo[i, j] -> w_unpacked[2*i, j]
+ # hi[i, j] -> w_unpacked[2*i+1, j]
+ # We can use tl.reshape after stacking
+ # Alternative: use broadcasting with a (BLOCK_K, 1) selector
+
+ # Simpler: create (BLOCK_K//2, BLOCK_N, 2) then reshape
+ # But Triton reshape is limited. Let's use a different approach.
+ # We can do: create two (BLOCK_K, BLOCK_N) halves and interleave with where
+
+ # Actually, a cleaner approach in Triton:
+ # w_unpacked = tl.join(lo, hi) doesn't exist directly.
+ # Let's use index-based construction:
+
+ # Even K rows: indices 0, 2, 4, ...
+ # Odd K rows: indices 1, 3, 5, ...
+ # We can create a (BLOCK_K, BLOCK_N) by broadcasting a selector
+
+ # Since BLOCK_K is compile-time constant, we can do this:
+ k_idx = tl.arange(0, BLOCK_K)
+ is_even = (k_idx[:, None] % 2) == 0
+ # lo has shape (BLOCK_K//2, BLOCK_N), we need to map k_idx//2 to lo/hi indices
+ half_idx = k_idx[:, None] // 2
+ lo_broadcast = tl.load(w_ptrs, mask=w_mask, other=0) & 0xF
+ hi_broadcast = (tl.load(w_ptrs, mask=w_mask, other=0) >> 4) & 0xF
+ # Hmm, this loads twice. Not ideal but let's get correctness first.
+
+ # Better: load once, then use tl.reshape if available or clever indexing
+ w_unpacked = tl.where(is_even,
+ tl.reshape(lo, (BLOCK_K, BLOCK_N)),
+ tl.reshape(hi, (BLOCK_K, BLOCK_N)))
+
+ # Load scales and zeros for this group
+ group_idx = k_start // GROUP_SIZE
+ s_ptrs = s_ptr + (group_idx * stride_sg + offs_n[None, :] * stride_sn)
+ z_ptrs = z_ptr + (group_idx * stride_zg + offs_n[None, :] * stride_zn)
+
+ s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0)
+
+ # Dequant: (w - z) * s
+ # w_unpacked is uint8, need to convert to float/bf16
+ w_bf = (w_unpacked.to(tl.bfloat16) - z.to(tl.bfloat16)) * s.to(tl.bfloat16)
+
+ # Dot product
+ acc += tl.dot(x, w_bf)
+
+ # Store output
+ out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
+ out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+class Model(torch.nn.Module):
+ def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
+ super().__init__()
+ assert K % group_size == 0
+ assert K % 2 == 0
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+ w_g = w_full.view(n_groups, group_size, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
+ w_q = w_q.view(K, N)
+ scales_2d = scales.squeeze(1).to(torch.bfloat16)
+ zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
+ w_packed = _pack_int4(w_q)
+
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales_2d)
+ self.register_buffer("zeros", zeros_2d)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ M,
... (1245 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<6942 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6942 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 120
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1924, in reshape
return _semantic.reshape(input, shape, can_reorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 630, in reshape
raise ValueError("reshape() cannot change total number of elements in tensor")
ValueError: reshape() cannot change total number of elements in tensor
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
... (2367 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -31,84 +31,47 @@
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- # K must be multiple of GROUP_SIZE and BLOCK_K
num_k_tiles = K // BLOCK_K
-
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
- offs_k = k_start + tl.arange(0, BLOCK_K)
+ # Packed K offset
+ pk_start = k_start // 2
+ offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
- # Load x tile: (BLOCK_M, BLOCK_K)
- x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
- x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
- x = tl.load(x_ptrs, mask=x_mask, other=0.0)
-
- # Load packed weights: (BLOCK_K // 2, BLOCK_N)
- # w_q layout: (K//2, N)
- packed_k_offs = k_start // 2 + tl.arange(0, BLOCK_K // 2)
- w_ptrs = w_ptr + (packed_k_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn)
- w_mask = (packed_k_offs[:, None] < K // 2) & (offs_n[None, :] < N)
+ # Load packed weights: (BLOCK_K//2, BLOCK_N)
+ w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
+ w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
- # Unpack int4: low nibble = even K, high nibble = odd K
- # w_packed shape: (BLOCK_K//2, BLOCK_N)
- # We need: (BLOCK_K, BLOCK_N)
- lo = w_packed & 0xF # even rows
- hi = (w_packed >> 4) & 0xF # odd rows
+ lo = (w_packed & 0xF).to(tl.bfloat16)
+ hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Interleave lo and hi to get (BLOCK_K, BLOCK_N)
- # Using tl.reshape and tl.join-like operations
- # In Triton, we can construct it with broadcast and select
- w_unpacked = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.uint8)
- # Need to assign even indices from lo, odd from hi
- # lo[i, j] -> w_unpacked[2*i, j]
- # hi[i, j] -> w_unpacked[2*i+1, j]
- # We can use tl.reshape after stacking
- # Alternative: use broadcasting with a (BLOCK_K, 1) selector
-
- # Simpler: create (BLOCK_K//2, BLOCK_N, 2) then reshape
- # But Triton reshape is limited. Let's use a different approach.
- # We can do: create two (BLOCK_K, BLOCK_N) halves and interleave with where
-
- # Actually, a cleaner approach in Triton:
- # w_unpacked = tl.join(lo, hi) doesn't exist directly.
- # Let's use index-based construction:
-
- # Even K rows: indices 0, 2, 4, ...
- # Odd K rows: indices 1, 3, 5, ...
- # We can create a (BLOCK_K, BLOCK_N) by broadcasting a selector
-
- # Since BLOCK_K is compile-time constant, we can do this:
- k_idx = tl.arange(0, BLOCK_K)
- is_even = (k_idx[:, None] % 2) == 0
- # lo has shape (BLOCK_K//2, BLOCK_N), we need to map k_idx//2 to lo/hi indices
- half_idx = k_idx[:, None] // 2
- lo_broadcast = tl.load(w_ptrs, mask=w_mask, other=0) & 0xF
- hi_broadcast = (tl.load(w_ptrs, mask=w_mask, other=0) >> 4) & 0xF
- # Hmm, this loads twice. Not ideal but let's get correctness first.
-
- # Better: load once, then use tl.reshape if available or clever indexing
- w_unpacked = tl.where(is_even,
- tl.reshape(lo, (BLOCK_K, BLOCK_N)),
- tl.reshape(hi, (BLOCK_K, BLOCK_N)))
-
# Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
- s_ptrs = s_ptr + (group_idx * stride_sg + offs_n[None, :] * stride_sn)
- z_ptrs = z_ptr + (group_idx * stride_zg + offs_n[None, :] * stride_zn)
-
- s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0)
- z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0)
-
- # Dequant: (w - z) * s
- # w_unpacked is uint8, need to convert to float/bf16
- w_bf = (w_unpacked.to(tl.bfloat16) - z.to(tl.bfloat16)) * s.to(tl.bfloat16)
-
- # Dot product
- acc += tl.dot(x, w_bf)
-
+ s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
+ z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
+ s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
+ z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
+
+ lo_deq = (lo - z) * s
+ hi_deq = (hi - z) * s
+
+ # Load x even and odd K slices
+ offs_k_even = k_start + 2 * tl.arange(0, BLOCK_K // 2)
+ offs_k_odd = k_start + 2 * tl.arange(0, BLOCK_K // 2) + 1
+
+ x_even_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk)
+ x_odd_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk)
+ x_mask = (offs_m[:, None] < M)
+
+ x_even = tl.load(x_even_ptrs, mask=x_mask, other=0.0)
+ x_odd = tl.load(x_odd_ptrs, mask=x_mask, other=0.0)
+
+ acc += tl.dot(x_even, lo_deq)
+ acc += tl.dot(x_odd, hi_deq)
+
# Store output
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
@@ -145,17 +108,13 @@
M, K = x.shape
N = self.N
x = x.to(torch.bfloat16)
-
- # Output
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
-
- # Grid
+
BLOCK_M = 16
BLOCK_N = 128
BLOCK_K = 128
-
+
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
-
w4a16_gemm_kernel[grid](
... (71 more chars){
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<4969 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 4969 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.835 ms=0.769
shape=0 variant=compiled tflops=0.642 gbps=170.667 ms=0.157
shape=0 variant=sota tflops=2.070 gbps=550.400 ms=0.049
shape=0 variant=solution tflops=0.909 gbps=241.829 ms=0.111
shape=0 solution_peak_fraction=0.1343
shape=1 variant=eager tflops=4.221 gbps=36.413 ms=0.763
shape=1 variant=compiled tflops=20.346 gbps=175.513 ms=0.158
shape=1 variant=sota tflops=22.293 gbps=192.305 ms=0.144
shape=1 variant=solution tflops=16.677 gbps=143.862 ms=0.193
shape=1 solution_peak_fraction=0.0799
shape=2 variant=eager tflops=31.106 gbps=42.401 ms=0.828
shape=2 variant=compiled tflops=119.891 gbps=163.425 ms=0.215
shape=2 variant=sota tflops=131.683 gbps=179.499 ms=0.196
shape=2 variant=solution tflops=41.473 gbps=56.533 ms=0.621
shape=2 solution_peak_fraction=0.0314
shape=3 variant=eager tflops=0.181 gbps=48.273 ms=0.185
shape=3 variant=compiled tflops=0.504 gbps=134.057
... (470 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -11,6 +11,31 @@
GROUP_SIZE = 128
+@triton.autotune(
+ configs=[
+ # Decode configs (small M)
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=4, num_stages=3),
+ # Small prefill configs
+ triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ # Large prefill configs
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=8, num_stages=3),
+ ],
+ key=["M"],
+)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
@@ -36,11 +61,9 @@
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
- # Packed K offset
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
- # Load packed weights: (BLOCK_K//2, BLOCK_N)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
@@ -48,7 +71,6 @@
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
@@ -58,7 +80,6 @@
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
- # Load x even and odd K slices
offs_k_even = k_start + 2 * tl.arange(0, BLOCK_K // 2)
offs_k_odd = k_start + 2 * tl.arange(0, BLOCK_K // 2) + 1
@@ -72,7 +93,6 @@
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
- # Store output
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
@@ -110,11 +130,7 @@
x = x.to(torch.bfloat16)
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- BLOCK_M = 16
- BLOCK_N = 128
- BLOCK_K = 128
-
- grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
M, N, K,
@@ -123,7 +139,6 @@
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
GROUP_SIZE=self.group_size,
)
return out
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<6835 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6835 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>ERROR: Command failed with exit code: 1.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py", line 128, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py", line 98, in main
ms_sol = time_fn(sol_model, inputs, iters=num_perf_trials)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py", line 55, in time_fn
fn(*inputs)
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 134, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=gr
... (816 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -14,25 +14,25 @@
@triton.autotune(
configs=[
# Decode configs (small M)
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
# Small prefill configs
- triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
# Large prefill configs
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SIZE": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
],
key=["M"],
)
@@ -48,7 +48,7 @@
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -139,7 +139,6 @@
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
- GROUP_SIZE=self.group_size,
)
return out
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<6478 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6478 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.860 ms=0.768
shape=0 variant=compiled tflops=0.637 gbps=169.320 ms=0.158
shape=0 variant=sota tflops=2.304 gbps=612.675 ms=0.044
shape=0 variant=solution tflops=1.468 gbps=390.300 ms=0.069
shape=0 solution_peak_fraction=0.2168
shape=1 variant=eager tflops=4.197 gbps=36.205 ms=0.767
shape=1 variant=compiled tflops=20.562 gbps=177.378 ms=0.157
shape=1 variant=sota tflops=23.101 gbps=199.278 ms=0.139
shape=1 variant=solution tflops=28.017 gbps=241.679 ms=0.115
shape=1 solution_peak_fraction=0.1343
shape=2 variant=eager tflops=31.255 gbps=42.605 ms=0.824
shape=2 variant=compiled tflops=120.078 gbps=163.681 ms=0.215
shape=2 variant=sota tflops=133.153 gbps=181.503 ms=0.194
shape=2 variant=solution tflops=63.050 gbps=85.945 ms=0.409
shape=2 solution_peak_fraction=0.0477
shape=3 variant=eager tflops=0.183 gbps=48.601 ms=0.184
shape=3 variant=compiled tflops=0.470 gbps=125.018
... (470 more chars){
"command": "<121 chars \u2014 see below>"
}python -c "
import triton.language as tl
print([x for x in dir(tl) if not x.startswith('_')])
" | tr ',' '\n' | head -80
<system>Command executed successfully.</system>
['PropagateNan'
'TRITON_MAX_TENSOR_NUMEL'
'abs'
'add'
'advance'
'arange'
'argmax'
'argmin'
'associative_scan'
'assume'
'atomic_add'
'atomic_and'
'atomic_cas'
'atomic_max'
'atomic_min'
'atomic_or'
'atomic_xchg'
'atomic_xor'
'bfloat16'
'bitonic_merge'
'block_type'
'broadcast'
'broadcast_to'
'cast'
'cat'
'cdiv'
'ceil'
'clamp'
'condition'
'const'
'constexpr'
'constexpr_type'
'core'
'cos'
'cumprod'
'cumsum'
'debug_barrier'
'device_assert'
'device_print'
'div_rn'
'dot'
'dot_scaled'
'dtype'
'erf'
'exp'
'exp2'
'expand_dims'
'extra'
'fdiv'
'flip'
'float16'
'float32'
'float64'
'float8e4b15'
'float8e4b8'
'float8e4nv'
'float8e5'
'float8e5b16'
'floor'
'fma'
'full'
'gather'
'histogram'
'inline_asm_elementwise'
'int1'
'int16'
'int32'
'int64'
'int8'
'interleave'
'join'
'load'
'load_tensor_descriptor'
'log'
'log2'
'make_block_ptr'
'make_tensor_descriptor'
'map_elementwise'
'math'
'max'
{
"command": "<79 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.interleave)
" 2>&1 | head -40
{
"raw": "{\"command\": \"python -c \\\"\\nimport triton.language"
}<system>Command executed successfully.</system>
Help on JITFunction in module triton.language.standard:
interleave = JITFunction(triton.language.standard:interleave)
Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
:param a: The first input tensor.
:type a: Tensor
:param b: The second input tensor.
:type b: Tensor
<system>Command executed successfully.</system>
Help on function join in module triton.language.core:
join(a, b, _semantic=None)
Join the given tensors in a new, minor dimension.
For example, given two tensors of shape (4,8), produces a new tensor of
shape (4,8,2). Given two scalars, returns a tensor of shape (2).
The two inputs are broadcasted to be the same shape.
If you want to join more than two elements, you can use multiple calls to
this function. This reflects the constraint in Triton that tensors must
have power-of-two sizes.
join is the inverse of split.
:param a: The first input tensor.
:type a: Tensor
:param b: The second input tensor.
:type b: Tensor
{
"command": "python -c \"import triton.language as tl; print(hasattr(tl, 'trans'))\""
}{
"raw": ""
}<system>Command executed successfully.</system>
True
<system>Command executed successfully.</system>
True
{
"command": "<689 chars \u2014 see below>"
}python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def test_join(a_ptr, b_ptr, out_ptr):
a = tl.load(a_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :])
b = tl.load(b_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :])
joined = tl.join(a, b) # (4, 4, 2)
# Try reshape to (8, 4)
c = tl.reshape(joined, (8, 4))
tl.store(out_ptr + tl.arange(0, 8)[:, None] * 4 + tl.arange(0, 4)[None, :], c)
a = torch.arange(16).reshape(4, 4).float().cuda()
b = torch.arange(16, 32).reshape(4, 4).float().cuda()
out = torch.empty(8, 4).float().cuda()
test_join[(1,)](a, b, out)
print('a:', a)
print('b:', b)
print('out:', out)
"
<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
{
"command": "<713 chars \u2014 see below>"
}cat > /tmp/test_triton.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_join(a_ptr, b_ptr, out_ptr):
a = tl.load(a_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :])
b = tl.load(b_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :])
joined = tl.join(a, b) # (4, 4, 2)
c = tl.reshape(joined, (8, 4))
tl.store(out_ptr + tl.arange(0, 8)[:, None] * 4 + tl.arange(0, 4)[None, :], c)
a = torch.arange(16).reshape(4, 4).float().cuda()
b = torch.arange(16, 32).reshape(4, 4).float().cuda()
out = torch.empty(8, 4).float().cuda()
test_join[(1,)](a, b, out)
print('a:', a)
print('b:', b)
print('out:', out)
EOF
python /tmp/test_triton.py
<system>Command executed successfully.</system>
a: tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]], device='cuda:0')
b: tensor([[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.]], device='cuda:0')
out: tensor([[ 0., 16., 1., 17.],
[ 2., 18., 3., 19.],
[ 4., 20., 5., 21.],
[ 6., 22., 7., 23.],
[ 8., 24., 9., 25.],
[10., 26., 11., 27.],
[12., 28., 13., 29.],
[14., 30., 15., 31.]], device='cuda:0')
{
"command": "<845 chars \u2014 see below>"
}cat > /tmp/test_triton2.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_join2(a_ptr, b_ptr, out_ptr):
a = tl.load(a_ptr + tl.arange(0, 2)[:, None] * 4 + tl.arange(0, 4)[None, :])
b = tl.load(b_ptr + tl.arange(0, 2)[:, None] * 4 + tl.arange(0, 4)[None, :])
a_t = tl.trans(a)
b_t = tl.trans(b)
joined = tl.join(a_t, b_t) # (4, 2, 2)
c = tl.reshape(joined, (4, 4))
d = tl.trans(c)
tl.store(out_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], d)
a = torch.arange(8).reshape(2, 4).float().cuda()
b = torch.arange(8, 16).reshape(2, 4).float().cuda()
out = torch.empty(4, 4).float().cuda()
test_join2[(1,)](a, b, out)
print('a:', a)
print('b:', b)
print('out:', out)
print('Expected: row0=a[0], row1=b[0], row2=a[1], row3=b[1]')
EOF
python /tmp/test_triton2.py
<system>Command executed successfully.</system>
a: tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]], device='cuda:0')
b: tensor([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]], device='cuda:0')
out: tensor([[ 0., 1., 2., 3.],
[ 8., 9., 10., 11.],
[ 4., 5., 6., 7.],
[12., 13., 14., 15.]], device='cuda:0')
Expected: row0=a[0], row1=b[0], row2=a[1], row3=b[1]
{
"command": "<2154 chars \u2014 see below>"
}cat > /tmp/test_triton3.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K):
BLOCK_M = 4
BLOCK_N = 4
BLOCK_K = 4
GROUP_SIZE = 4
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
k_start = 0
pk_start = 0
offs_pk = tl.arange(0, BLOCK_K // 2)
w_packed = tl.load(w_ptr + offs_pk[:, None] * 4 + offs_n[None, :])
lo = w_packed & 0xF
hi = (w_packed >> 4) & 0xF
s = tl.load(s_ptr + offs_n[None, :]).to(tl.bfloat16)
z = tl.load(z_ptr + offs_n[None, :]).to(tl.bfloat16)
lo_bf = lo.to(tl.bfloat16)
hi_bf = hi.to(tl.bfloat16)
lo_t = tl.trans(lo_bf)
hi_t = tl.trans(hi_bf)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
w_deq = (w_unpacked - z) * s
x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :]).to(tl.bfloat16)
acc += tl.dot(x, w_deq)
tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16))
# Test data
M, N, K = 4, 4, 4
x = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]], dtype=torch.bfloat16).cuda()
w_q = torch.tensor([[0x01, 0x23, 0x45, 0x67]], dtype=torch.uint8).cuda() # K//2=2, but we need 2x4 for BLOCK_K=4
# Actually need (K//2, N) = (2, 4)
w_q = torch.tensor([[0x10, 0x32, 0x54, 0x76], [0x98, 0xBA, 0xDC, 0xFE]], dtype=torch.uint8).cuda()
s = torch.ones(1, N, dtype=torch.bfloat16).cuda()
z = torch.zeros(1, N, dtype=torch.bfloat16).cuda()
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
test_w4a16[(1,)](x, w_q, s, z, out, M, N, K)
# Reference: unpack and matmul
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_q & 0xF
w_unpacked[1::2] = (w_q >> 4) & 0xF
w_bf = w_unpacked.to(torch.bfloat16)
ref = x @ w_bf
print('out:', out)
print('ref:', ref)
print('match:', torch.allclose(out.float(), ref.float(), atol=0.1))
EOF
python /tmp/test_triton3.py
<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_triton3.py", line 55, in <module>
test_w4a16[(1,)](x, w_q, s, z, out, M, N, K)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.1
... (1054 more chars){
"command": "<2090 chars \u2014 see below>"
}cat > /tmp/test_triton3.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
k_start = 0
pk_start = 0
offs_pk = tl.arange(0, BLOCK_K // 2)
w_packed = tl.load(w_ptr + offs_pk[:, None] * 4 + offs_n[None, :])
lo = w_packed & 0xF
hi = (w_packed >> 4) & 0xF
s = tl.load(s_ptr + offs_n[None, :]).to(tl.bfloat16)
z = tl.load(z_ptr + offs_n[None, :]).to(tl.bfloat16)
lo_bf = lo.to(tl.bfloat16)
hi_bf = hi.to(tl.bfloat16)
lo_t = tl.trans(lo_bf)
hi_t = tl.trans(hi_bf)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
w_deq = (w_unpacked - z) * s
x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :]).to(tl.bfloat16)
acc += tl.dot(x, w_deq)
tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16))
# Test data
M, N, K = 4, 4, 4
x = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]], dtype=torch.bfloat16).cuda()
w_q = torch.tensor([[0x10, 0x32, 0x54, 0x76], [0x98, 0xBA, 0xDC, 0xFE]], dtype=torch.uint8).cuda()
s = torch.ones(1, N, dtype=torch.bfloat16).cuda()
z = torch.zeros(1, N, dtype=torch.bfloat16).cuda()
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
test_w4a16[(1,)](x, w_q, s, z, out, M, N, K, BLOCK_M=4, BLOCK_N=4, BLOCK_K=4, GROUP_SIZE=4)
# Reference: unpack and matmul
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_q & 0xF
w_unpacked[1::2] = (w_q >> 4) & 0xF
w_bf = w_unpacked.to(torch.bfloat16)
ref = x @ w_bf
print('out:', out)
print('ref:', ref)
print('match:', torch.allclose(out.float(), ref.float(), atol=0.1))
EOF
python /tmp/test_triton3.py
<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 2054, in dot
res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 1528, in dot
assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Input shapes should have M >= 1, N >= 1 and K >= 16
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_triton3.py", line 49, in <module>
test_w4a16[(1,)](x, w_q, s, z, out, M, N, K, BLOCK_M=4, BLOCK_N=4, BLOCK_K=4, GROUP_SIZE=4)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/sit
... (1523 more chars){
"command": "<2085 chars \u2014 see below>"
}cat > /tmp/test_triton3.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
k_start = 0
offs_pk = tl.arange(0, BLOCK_K // 2)
w_packed = tl.load(w_ptr + offs_pk[:, None] * N + offs_n[None, :])
lo = w_packed & 0xF
hi = (w_packed >> 4) & 0xF
s = tl.load(s_ptr + offs_n[None, :]).to(tl.bfloat16)
z = tl.load(z_ptr + offs_n[None, :]).to(tl.bfloat16)
lo_bf = lo.to(tl.bfloat16)
hi_bf = hi.to(tl.bfloat16)
lo_t = tl.trans(lo_bf)
hi_t = tl.trans(hi_bf)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
w_deq = (w_unpacked - z) * s
x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :]).to(tl.bfloat16)
acc += tl.dot(x, w_deq)
tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16))
# Test data
M, N, K = 4, 4, 16
x = torch.ones(M, K, dtype=torch.bfloat16).cuda()
w_q = torch.ones(K // 2, N, dtype=torch.uint8).cuda()
# Make w_q have recognizable values: low nibble = 1, high nibble = 2
w_q[:, :] = 0x21 # lo=1, hi=2
s = torch.ones(1, N, dtype=torch.bfloat16).cuda()
z = torch.zeros(1, N, dtype=torch.bfloat16).cuda()
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
test_w4a16[(1,)](x, w_q, s, z, out, M, N, K, BLOCK_M=4, BLOCK_N=4, BLOCK_K=16, GROUP_SIZE=16)
# Reference: unpack and matmul
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_q & 0xF
w_unpacked[1::2] = (w_q >> 4) & 0xF
w_bf = w_unpacked.to(torch.bfloat16)
ref = x @ w_bf
print('out:', out)
print('ref:', ref)
print('match:', torch.allclose(out.float(), ref.float(), atol=0.1))
EOF
python /tmp/test_triton3.py
<system>Command executed successfully.</system>
out: tensor([[24., 24., 24., 24.],
[24., 24., 24., 24.],
[24., 24., 24., 24.],
[24., 24., 24., 24.]], device='cuda:0', dtype=torch.bfloat16)
ref: tensor([[24., 24., 24., 24.],
[24., 24., 24., 24.],
[24., 24., 24., 24.],
[24., 24., 24., 24.]], device='cuda:0', dtype=torch.bfloat16)
match: True
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -20,12 +20,14 @@
triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 256}, num_warps=8, num_stages=2),
# Small prefill configs
triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
# Large prefill configs
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
@@ -33,6 +35,8 @@
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
],
key=["M"],
)
@@ -71,27 +75,29 @@
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+ # Interleave lo and hi along K dimension to get (BLOCK_K, BLOCK_N)
+ lo_t = tl.trans(lo)
+ hi_t = tl.trans(hi)
+ joined = tl.join(lo_t, hi_t) # (BLOCK_N, BLOCK_K//2, 2)
+ w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
+ w_unpacked = tl.trans(w_reshaped) # (BLOCK_K, BLOCK_N)
+
+ # Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
- lo_deq = (lo - z) * s
- hi_deq = (hi - z) * s
+ w_deq = (w_unpacked - z) * s
- offs_k_even = k_start + 2 * tl.arange(0, BLOCK_K // 2)
- offs_k_odd = k_start + 2 * tl.arange(0, BLOCK_K // 2) + 1
+ # Load x tile contiguously: (BLOCK_M, BLOCK_K)
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+ x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0)
- x_even_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk)
- x_odd_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk)
- x_mask = (offs_m[:, None] < M)
-
- x_even = tl.load(x_even_ptrs, mask=x_mask, other=0.0)
- x_odd = tl.load(x_odd_ptrs, mask=x_mask, other=0.0)
-
- acc += tl.dot(x_even, lo_deq)
- acc += tl.dot(x_odd, hi_deq)
+ acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7000 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7000 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 120
}<system>ERROR: Command failed with exit code: 1.</system>
FAIL: shape 4 {'M': 16, 'N': 14336, 'K': 4096} seed 42: tolerance exceeded: max_abs_diff=2.28125 (atol=0.1, rtol=0.1)
{
"command": "<4067 chars \u2014 see below>"
}cat > /tmp/test_triton4.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
num_k_tiles = K // BLOCK_K
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * (N // 2) + offs_n[None, :] * 1) # stride_wk = N//2, stride_wn = 1
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
lo_t = tl.trans(lo)
hi_t = tl.trans(hi)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * N + offs_n[None, :]
z_ptrs = z_ptr + group_idx * N + offs_n[None, :]
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
w_deq = (w_unpacked - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * K + offs_k[None, :])
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * N + offs_n[None, :])
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Match shape 4: M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
# Generate weights like reference
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
# Pack
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Reference
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
s_full = scales_2d.repeat_interleave(GROUP_SIZE, dim=0)
z_full = zeros_2d.repeat_interleave(GROUP_SIZE, dim=0)
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
ref = x @ w_bf
# Test kernel
for BLOCK_M in [8, 16]:
for BLOCK_N in [128, 256]:
for BLOCK_K in [128, 256]:
if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
continue
test_w4a16[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE=GROUP_SIZE)
max_diff = (out - ref).abs().max().item()
mean_diff = (out - ref).abs().mean().item()
print(f"BM={BLOCK_M}, BN={BLOCK_N}, BK={BLOCK_K}: max_diff={max_diff
... (67 more chars)<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/tmp/test_triton4.py", line 87, in <module>
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
{
"command": "<4047 chars \u2014 see below>"
}cat > /tmp/test_triton4.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
num_k_tiles = K // BLOCK_K
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * (N // 2) + offs_n[None, :] * 1)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
lo_t = tl.trans(lo)
hi_t = tl.trans(hi)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * N + offs_n[None, :]
z_ptrs = z_ptr + group_idx * N + offs_n[None, :]
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
w_deq = (w_unpacked - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * K + offs_k[None, :])
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * N + offs_n[None, :])
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Match shape 4: M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
# Generate weights like reference
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
# Pack
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Reference
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
s_full = scales_2d.repeat_interleave(GROUP_SIZE, dim=0)
z_full = zeros_2d.repeat_interleave(GROUP_SIZE, dim=0)
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
ref = x @ w_bf
# Test kernel
for BLOCK_M in [8, 16]:
for BLOCK_N in [128, 256]:
for BLOCK_K in [128, 256]:
if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
continue
test_w4a16[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE=GROUP_SIZE)
max_diff = (out - ref).abs().max().item()
mean_diff = (out - ref).abs().mean().item()
print(f"BM={BLOCK_M}, BN={BLOCK_N}, BK={BLOCK_K}: max_diff={max_diff:.4f}, mean_diff={me
... (47 more chars)<system>ERROR: Command failed with exit code: 1.</system>
BM=8, BN=128, BK=128: max_diff=8.7500, mean_diff=1.507812
BM=8, BN=128, BK=256: max_diff=8.2500, mean_diff=1.500000
BM=8, BN=256, BK=128: max_diff=8.7500, mean_diff=1.507812
Traceback (most recent call last):
File "/tmp/test_triton4.py", line 96, in <module>
test_w4a16[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/compiler/compiler.py", line 448, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 141312,
... (71 more chars){
"command": "python -c \"import triton.language as tl; help(tl.dot_scaled)\" 2>&1 | head -40"
}<system>Command executed successfully.</system>
Help on function dot_scaled in module triton.language.core:
dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32, _semantic=None)
Returns the matrix product of two blocks in microscaling format.
lhs and rhs use microscaling formats described here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
Software emulation enables targeting hardware architectures without native microscaling
operation support. Right now for such case, microscaled lhs/rhs are upcasted to
:code:`bf16` element type beforehand for dot computation, with one exception:
for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
the other input is also upcasted to :code:`fp16` element type instead.
This behavior is experimental and may be subject to change in the future.
:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
:param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 if scales type are `e8m0`.
:type lhs_scale: e8m0 type represented as an uint8 tensor, or None.
:param lhs_fo
... (1146 more chars){
"command": "python -c \"import triton.language as tl; print(hasattr(tl, 'split'))\""
}<system>Command executed successfully.</system>
True
{
"command": "<1048 chars \u2014 see below>"
}cat > /tmp/test_triton5.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_split(x_ptr, out1_ptr, out2_ptr):
x = tl.load(x_ptr + tl.arange(0, 4)[:, None] * 8 + tl.arange(0, 8)[None, :])
# x shape: (4, 8)
# We want to split along last dimension into (4, 4, 2) and extract each slice
# Actually, split takes a tensor and splits it along a dimension into equal halves
a, b = tl.split(x, axis=1) # Split along axis 1 (N dimension) into two (4, 4) tensors
# But this splits [0,1,2,3,4,5,6,7] into [0,1,2,3] and [4,5,6,7]
# We want even/odd: [0,2,4,6] and [1,3,5,7]
tl.store(out1_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], a)
tl.store(out2_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], b)
x = torch.arange(32).reshape(4, 8).float().cuda()
out1 = torch.empty(4, 4).float().cuda()
out2 = torch.empty(4, 4).float().cuda()
test_split[(1,)](x, out1, out2)
print('x:', x)
print('out1:', out1)
print('out2:', out2)
EOF
python /tmp/test_triton5.py
<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
TypeError: split() got an unexpected keyword argument 'axis'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/test_triton5.py", line 20, in <module>
test_split[(1,)](x, out1, out2)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module
... (914 more chars){
"command": "python -c \"import triton.language as tl; help(tl.split)\" 2>&1 | head -40"
}<system>Command executed successfully.</system>
Help on function split in module triton.language.core:
split(a, _semantic=None, _generator=None) -> 'tuple[tensor, tensor]'
Split a tensor in two along its last dim, which must have size 2.
For example, given a tensor of shape (4,8,2), produces two tensors of shape
(4,8). Given a tensor of shape (2), returns two scalars.
If you want to split into more than two pieces, you can use multiple calls
to this function (probably plus calling reshape). This reflects the
constraint in Triton that tensors must have power-of-two sizes.
split is the inverse of join.
:param a: The tensor to split.
:type a: Tensor
This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.split()` instead of
:code:`split(x)`.
{
"command": "<750 chars \u2014 see below>"
}cat > /tmp/test_triton6.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_reshape_split(x_ptr, out1_ptr, out2_ptr):
x = tl.load(x_ptr + tl.arange(0, 4)[:, None] * 8 + tl.arange(0, 8)[None, :])
x_3d = tl.reshape(x, (4, 4, 2))
a, b = tl.split(x_3d)
tl.store(out1_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], a)
tl.store(out2_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], b)
x = torch.arange(32).reshape(4, 8).float().cuda()
out1 = torch.empty(4, 4).float().cuda()
out2 = torch.empty(4, 4).float().cuda()
test_reshape_split[(1,)](x, out1, out2)
print('x:', x)
print('out1 (even cols):', out1)
print('out2 (odd cols):', out2)
EOF
python /tmp/test_triton6.py
<system>Command executed successfully.</system>
x: tensor([[ 0., 1., 2., 3., 4., 5., 6., 7.],
[ 8., 9., 10., 11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20., 21., 22., 23.],
[24., 25., 26., 27., 28., 29., 30., 31.]], device='cuda:0')
out1 (even cols): tensor([[ 0., 2., 4., 6.],
[ 8., 10., 12., 14.],
[16., 18., 20., 22.],
[24., 26., 28., 30.]], device='cuda:0')
out2 (odd cols): tensor([[ 1., 3., 5., 7.],
[ 9., 11., 13., 15.],
[17., 19., 21., 23.],
[25., 27., 29., 31.]], device='cuda:0')
{
"command": "<4014 chars \u2014 see below>"
}cat > /tmp/test_triton7.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
num_k_tiles = K // BLOCK_K
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * (N // 2) + offs_n[None, :] * 1)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * N + offs_n[None, :]
z_ptrs = z_ptr + group_idx * N + offs_n[None, :]
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * K + offs_k[None, :])
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * N + offs_n[None, :])
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Match shape 4: M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
# Generate weights like reference
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
# Pack
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Reference
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
s_full = scales_2d.repeat_interleave(GROUP_SIZE, dim=0)
z_full = zeros_2d.repeat_interleave(GROUP_SIZE, dim=0)
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
ref = x @ w_bf
# Test kernel
for BLOCK_M in [8, 16]:
for BLOCK_N in [128, 256]:
for BLOCK_K in [128, 256]:
if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
continue
test_w4a16[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE=GROUP_SIZE)
max_diff = (out - ref).abs().max().item()
mean_diff = (out - ref).abs().mean().item()
print(f"BM={BLOCK_M}, BN={BLOCK_N}, BK={BLOCK_K}: max_diff={max_diff:.4f}, mean_diff={mean_diff:.6f}")
EOF
python /tmp/te
... (14 more chars)<system>Command executed successfully.</system>
BM=8, BN=128, BK=128: max_diff=8.7500, mean_diff=1.507812
BM=8, BN=128, BK=256: max_diff=8.2500, mean_diff=1.500000
BM=8, BN=256, BK=128: max_diff=8.7500, mean_diff=1.507812
BM=8, BN=256, BK=256: max_diff=8.2500, mean_diff=1.500000
BM=16, BN=128, BK=128: max_diff=8.7500, mean_diff=1.507812
BM=16, BN=128, BK=256: max_diff=8.2500, mean_diff=1.500000
BM=16, BN=256, BK=128: max_diff=8.7500, mean_diff=1.507812
BM=16, BN=256, BK=256: max_diff=8.2500, mean_diff=1.500000
{
"command": "<4031 chars \u2014 see below>"
}cat > /tmp/test_triton7.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
num_k_tiles = K // BLOCK_K
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * N + offs_n[None, :] * 1) # FIXED: stride_wk = N
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * N + offs_n[None, :]
z_ptrs = z_ptr + group_idx * N + offs_n[None, :]
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * K + offs_k[None, :])
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * N + offs_n[None, :])
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Match shape 4: M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
# Generate weights like reference
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
# Pack
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Reference
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
s_full = scales_2d.repeat_interleave(GROUP_SIZE, dim=0)
z_full = zeros_2d.repeat_interleave(GROUP_SIZE, dim=0)
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
ref = x @ w_bf
# Test kernel
for BLOCK_M in [8, 16]:
for BLOCK_N in [128, 256]:
for BLOCK_K in [128, 256]:
if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
continue
test_w4a16[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE=GROUP_SIZE)
max_diff = (out - ref).abs().max().item()
mean_diff = (out - ref).abs().mean().item()
print(f"BM={BLOCK_M}, BN={BLOCK_N}, BK={BLOCK_K}: max_diff={max_diff:.4f}, mean_diff={mean_diff:.6f}")
E
... (31 more chars)<system>Command executed successfully.</system>
BM=8, BN=128, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=8, BN=128, BK=256: max_diff=2.4062, mean_diff=0.320312
BM=8, BN=256, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=8, BN=256, BK=256: max_diff=2.4062, mean_diff=0.320312
BM=16, BN=128, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=16, BN=128, BK=256: max_diff=2.4062, mean_diff=0.320312
BM=16, BN=256, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=16, BN=256, BK=256: max_diff=2.4062, mean_diff=0.320312
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -20,14 +20,12 @@
triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 256}, num_warps=8, num_stages=2),
# Small prefill configs
triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
# Large prefill configs
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
@@ -35,8 +33,6 @@
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 256}, num_warps=8, num_stages=2),
],
key=["M"],
)
@@ -75,29 +71,27 @@
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Interleave lo and hi along K dimension to get (BLOCK_K, BLOCK_N)
- lo_t = tl.trans(lo)
- hi_t = tl.trans(hi)
- joined = tl.join(lo_t, hi_t) # (BLOCK_N, BLOCK_K//2, 2)
- w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
- w_unpacked = tl.trans(w_reshaped) # (BLOCK_K, BLOCK_N)
-
- # Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
- w_deq = (w_unpacked - z) * s
+ lo_deq = (lo - z) * s
+ hi_deq = (hi - z) * s
- # Load x tile contiguously: (BLOCK_M, BLOCK_K)
+ # Load x contiguously
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
- acc += tl.dot(x, w_deq)
+ # Split into even/odd K elements
+ x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
+ x_even, x_odd = tl.split(x_3d)
+
+ acc += tl.dot(x_even, lo_deq)
+ acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<6415 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6415 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.796 ms=0.769
shape=0 variant=compiled tflops=0.644 gbps=171.313 ms=0.156
shape=0 variant=sota tflops=2.329 gbps=619.480 ms=0.043
shape=0 variant=solution tflops=1.524 gbps=405.333 ms=0.066
shape=0 solution_peak_fraction=0.2252
shape=1 variant=eager tflops=4.165 gbps=35.930 ms=0.773
shape=1 variant=compiled tflops=20.647 gbps=178.105 ms=0.156
shape=1 variant=sota tflops=23.146 gbps=199.667 ms=0.139
shape=1 variant=solution tflops=45.242 gbps=390.271 ms=0.071
shape=1 solution_peak_fraction=0.2168
shape=2 variant=eager tflops=31.192 gbps=42.519 ms=0.826
shape=2 variant=compiled tflops=118.567 gbps=161.621 ms=0.217
shape=2 variant=sota tflops=131.179 gbps=178.812 ms=0.196
shape=2 variant=solution tflops=145.205 gbps=197.931 ms=0.177
shape=2 solution_peak_fraction=0.1100
shape=3 variant=eager tflops=0.179 gbps=47.662 ms=0.187
shape=3 variant=compiled tflops=0.487 gbps=129.5
... (472 more chars){
"command": "<4079 chars \u2014 see below>"
}cat > /tmp/test_triton8.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_w4a16_interleave(x_ptr, w_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
num_k_tiles = K // BLOCK_K
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * N + offs_n[None, :] * 1)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
lo_t = tl.trans(lo)
hi_t = tl.trans(hi)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * N + offs_n[None, :]
z_ptrs = z_ptr + group_idx * N + offs_n[None, :]
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
w_deq = (w_unpacked - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * K + offs_k[None, :])
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * N + offs_n[None, :])
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Match shape 4: M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
# Generate weights like reference
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
# Pack
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Reference
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device='cuda')
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
s_full = scales_2d.repeat_interleave(GROUP_SIZE, dim=0)
z_full = zeros_2d.repeat_interleave(GROUP_SIZE, dim=0)
w_bf = (w_unpacked.to(torch.bfloat16) - z_full) * s_full
ref = x @ w_bf
# Test kernel
for BLOCK_M in [8, 16]:
for BLOCK_N in [128, 256]:
for BLOCK_K in [128]:
if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
continue
test_w4a16_interleave[(M//BLOCK_M, N//BLOCK_N)](x, w_packed, scales_2d, zeros_2d, out, M, N, K,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE=GROUP_SIZE)
max_diff = (out - ref).abs().max().item()
mean_diff = (out - ref).abs().mean().item()
print(f"BM={BLOCK_M}, BN={BLOCK_N}, BK={BLOCK_K}: max_di
... (79 more chars)<system>Command executed successfully.</system>
BM=8, BN=128, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=8, BN=256, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=16, BN=128, BK=128: max_diff=0.0156, mean_diff=0.000002
BM=16, BN=256, BK=128: max_diff=0.0156, mean_diff=0.000002
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,6 +1,7 @@
"""W4A16 GEMM via Triton with fused int4 unpack + dequant.
AWQ/GPTQ-style asymmetric int4 with group_size=128.
+Uses a single tl.dot per K tile by interleaving unpacked int4 weights.
"""
from __future__ import annotations
@@ -71,27 +72,29 @@
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+ # Interleave lo and hi along K dimension: (BLOCK_K, BLOCK_N)
+ lo_t = tl.trans(lo)
+ hi_t = tl.trans(hi)
+ joined = tl.join(lo_t, hi_t)
+ w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
+ w_unpacked = tl.trans(w_reshaped)
+
+ # Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
- lo_deq = (lo - z) * s
- hi_deq = (hi - z) * s
+ w_deq = (w_unpacked - z) * s
- # Load x contiguously
+ # Load x tile contiguously: (BLOCK_M, BLOCK_K)
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
- # Split into even/odd K elements
- x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
- x_even, x_odd = tl.split(x_3d)
-
- acc += tl.dot(x_even, lo_deq)
- acc += tl.dot(x_odd, hi_deq)
+ acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<6619 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 6619 bytes.</system>{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 120
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.130 gbps=34.662 ms=0.772
shape=0 variant=compiled tflops=0.664 gbps=176.555 ms=0.152
shape=0 variant=sota tflops=2.174 gbps=578.167 ms=0.046
shape=0 variant=solution tflops=1.359 gbps=361.308 ms=0.074
shape=0 solution_peak_fraction=0.2007
shape=1 variant=eager tflops=4.186 gbps=36.114 ms=0.769
shape=1 variant=compiled tflops=20.657 gbps=178.197 ms=0.156
shape=1 variant=sota tflops=22.782 gbps=196.526 ms=0.141
shape=1 variant=solution tflops=28.225 gbps=243.475 ms=0.114
shape=1 solution_peak_fraction=0.1353
shape=2 variant=eager tflops=31.148 gbps=42.458 ms=0.827
shape=2 variant=compiled tflops=120.618 gbps=164.417 ms=0.214
shape=2 variant=sota tflops=132.626 gbps=180.785 ms=0.194
shape=2 variant=solution tflops=108.818 gbps=148.332 ms=0.237
shape=2 solution_peak_fraction=0.0824
shape=3 variant=eager tflops=0.181 gbps=48.127 ms=0.186
shape=3 variant=compiled tflops=0.483 gbps=128.5
... (472 more chars){
"command": "<1045 chars \u2014 see below>"
}cat > /tmp/test_permute.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_permute(a_ptr, b_ptr, out_ptr):
a = tl.load(a_ptr + tl.arange(0, 2)[:, None] * 4 + tl.arange(0, 4)[None, :])
b = tl.load(b_ptr + tl.arange(0, 2)[:, None] * 4 + tl.arange(0, 4)[None, :])
# a, b are (2, 4)
# Stack them: (2, 4, 2)
stacked = tl.join(a, b) # Actually join creates (2, 4, 2) with last dim = 2
# permute from (2, 4, 2) to (4, 4)
# Wait, permute changes axis order. (2, 4, 2) -> permute to what?
# We want (4, 4) where result[2*i+c, n] = a[i, n] if c=0 else b[i, n]
# Let's try reshape first: (2, 4, 2) -> (4, 4)
c = tl.reshape(stacked, (4, 4))
tl.store(out_ptr + tl.arange(0, 4)[:, None] * 4 + tl.arange(0, 4)[None, :], c)
a = torch.arange(8).reshape(2, 4).float().cuda()
b = torch.arange(8, 16).reshape(2, 4).float().cuda()
out = torch.empty(4, 4).float().cuda()
test_permute[(1,)](a, b, out)
print('a:', a)
print('b:', b)
print('out:', out)
EOF
python /tmp/test_permute.py
<system>Command executed successfully.</system>
a: tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]], device='cuda:0')
b: tensor([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]], device='cuda:0')
out: tensor([[ 0., 8., 1., 9.],
[ 2., 10., 3., 11.],
[ 4., 12., 5., 13.],
[ 6., 14., 7., 15.]], device='cuda:0')
{
"command": "<725 chars \u2014 see below>"
}cat > /tmp/check_configs.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import shapes
import torch
device = torch.device('cuda')
for shape in shapes.SHAPES:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Warmup / autotune
for _ in range(3):
_ = model(x)
print(f"Shape {shape}: best config = {solution.w4a16_gemm_kernel.best_config}")
EOF
python /tmp/check_configs.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: best config = BLOCK_M: 1, BLOCK_N: 64, BLOCK_K: 128, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: best config = BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: best config = BLOCK_M: 64, BLOCK_N: 128, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: best config = BLOCK_M: 1, BLOCK_N: 64, BLOCK_K: 128, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: best config = BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 128, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
{
"command": "<5179 chars \u2014 see below>"
}cat > /tmp/test_perf.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Generate test data
M, N, K = 256, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Time different configs
import time
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
grid = (M // BLOCK_M, N // BLOCK_N)
# Warmup
for _ in range(5):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
start = time.time()
for _ in range(50):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w
... (1179 more chars)<system>Command executed successfully.</system>
BM=32, BN=128, BK=64, warps=4, stages=2: 0.192ms, 133.9 TFLOPS
BM=32, BN=128, BK=64, warps=4, stages=3: 0.188ms, 137.2 TFLOPS
BM=32, BN=128, BK=64, warps=8, stages=2: 0.145ms, 177.4 TFLOPS
BM=32, BN=128, BK=64, warps=8, stages=3: 0.142ms, 181.9 TFLOPS
BM=32, BN=128, BK=128, warps=4, stages=2: 0.210ms, 122.9 TFLOPS
BM=32, BN=128, BK=128, warps=4, stages=3: 0.210ms, 122.6 TFLOPS
BM=32, BN=128, BK=128, warps=8, stages=2: 0.166ms, 154.9 TFLOPS
BM=32, BN=128, BK=128, warps=8, stages=3: 0.132ms, 195.4 TFLOPS
BM=32, BN=256, BK=64, warps=4, stages=2: 0.162ms, 159.2 TFLOPS
BM=32, BN=256, BK=64, warps=4, stages=3: 0.141ms, 182.3 TFLOPS
BM=32, BN=256, BK=64, warps=8, stages=2: 0.296ms, 87.0 TFLOPS
BM=32, BN=256, BK=64, warps=8, stages=3: 0.284ms, 90.8 TFLOPS
BM=32, BN=256, BK=128, warps=4, stages=2: 0.203ms, 127.0 TFLOPS
BM=32, BN=256, BK=128, warps=4, stages=3: 0.186ms, 138.6 TFLOPS
BM=32, BN=256, BK=128, warps=8, stages=2: 0.301ms, 85.5 TFLOPS
BM=32, BN=256, BK=128, warps=8, stages=3: 0.285ms, 90.5 TFLOPS
BM=64, BN=128, BK=64, warps=4, stages=2: 0.147ms, 175.4 TFLOPS
BM=64, BN=128, BK=64, warps=4, stages=3: 0.146ms, 176.6 TFLOPS
BM=64, BN=128, BK=64, warps=8, stages=2: 0.164ms, 157.2 TFLOPS
BM=64, BN=128, BK=64, warps=8, stages=3: 0.145ms, 177.9 TFLOPS
BM=64, BN=128, BK=128, warps=4, stages=2: 0.187ms, 137.8 TFLOPS
BM=64, BN=128, BK=128, warps=4, stages=3: 0.219ms, 117.6 TFLOPS
BM=64, BN=128, BK=128, warps=8, stages=2: 0.174ms, 148.0 TFL
... (1703 more chars){
"command": "<5241 chars \u2014 see below>"
}cat > /tmp/test_perf_m1.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# Generate test data for M=1, N=12288, K=4096
M, N, K = 1, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
# Time different configs
import time
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
grid = (M // BLOCK_M, N // BLOCK_N)
# Warmup
for _ in range(5):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0),
... (1241 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=64, warps=4, stages=2: 0.024ms, 1114.5 GB/s
BM=1, BN=32, BK=64, warps=4, stages=3: 0.020ms, 1340.4 GB/s
BM=1, BN=32, BK=64, warps=8, stages=2: 0.039ms, 683.1 GB/s
BM=1, BN=32, BK=64, warps=8, stages=3: 0.037ms, 722.3 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.018ms, 1452.4 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.018ms, 1516.0 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.032ms, 837.1 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.030ms, 879.3 GB/s
BM=1, BN=64, BK=64, warps=4, stages=2: 0.027ms, 1008.7 GB/s
BM=1, BN=64, BK=64, warps=4, stages=3: 0.023ms, 1167.1 GB/s
BM=1, BN=64, BK=64, warps=8, stages=2: 0.029ms, 920.8 GB/s
BM=1, BN=64, BK=64, warps=8, stages=3: 0.029ms, 939.0 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.022ms, 1211.4 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.023ms, 1169.7 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.026ms, 1012.3 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.023ms, 1189.6 GB/s
BM=1, BN=128, BK=64, warps=4, stages=2: 0.054ms, 500.1 GB/s
BM=1, BN=128, BK=64, warps=4, stages=3: 0.051ms, 529.4 GB/s
BM=1, BN=128, BK=64, warps=8, stages=2: 0.036ms, 736.4 GB/s
BM=1, BN=128, BK=64, warps=8, stages=3: 0.028ms, 958.2 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.055ms, 488.5 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.046ms, 585.8 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.029ms, 922.1 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.024ms, 1093.6 GB/s
BM=1,
... (479 more chars){
"command": "<5184 chars \u2014 see below>"
}cat > /tmp/test_perf_m1_sq.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import time
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(5):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.strid
... (1184 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=64, warps=4, stages=2: 0.020ms, 441.6 GB/s
BM=1, BN=32, BK=64, warps=4, stages=3: 0.013ms, 710.0 GB/s
BM=1, BN=32, BK=64, warps=8, stages=2: 0.024ms, 367.2 GB/s
BM=1, BN=32, BK=64, warps=8, stages=3: 0.018ms, 500.4 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.015ms, 596.8 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.011ms, 841.4 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.017ms, 528.2 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.014ms, 630.4 GB/s
BM=1, BN=64, BK=64, warps=4, stages=2: 0.024ms, 377.5 GB/s
BM=1, BN=64, BK=64, warps=4, stages=3: 0.017ms, 525.3 GB/s
BM=1, BN=64, BK=64, warps=8, stages=2: 0.025ms, 359.1 GB/s
BM=1, BN=64, BK=64, warps=8, stages=3: 0.018ms, 509.5 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.018ms, 500.1 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.017ms, 520.5 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.019ms, 478.9 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.015ms, 594.7 GB/s
BM=1, BN=128, BK=64, warps=4, stages=2: 0.053ms, 167.8 GB/s
BM=1, BN=128, BK=64, warps=4, stages=3: 0.050ms, 178.5 GB/s
BM=1, BN=128, BK=64, warps=8, stages=2: 0.036ms, 247.4 GB/s
BM=1, BN=128, BK=64, warps=8, stages=3: 0.028ms, 321.8 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.054ms, 164.0 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.045ms, 196.3 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.029ms, 311.4 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.023ms, 390.4 GB/s
BM=1, BN=256, BK
... (468 more chars){
"command": "<5254 chars \u2014 see below>"
}cat > /tmp/test_perf_m16.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=16, N=14336, K=4096
M, N, K = 16, 14336, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import time
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(5):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.str
... (1254 more chars)<system>Command executed successfully.</system>
BM=8, BN=128, BK=64, warps=4, stages=2: 0.081ms, 391.4 GB/s, 23.1 TFLOPS
BM=8, BN=128, BK=64, warps=4, stages=3: 0.069ms, 458.6 GB/s, 27.1 TFLOPS
BM=8, BN=128, BK=64, warps=8, stages=2: 0.050ms, 631.4 GB/s, 37.3 TFLOPS
BM=8, BN=128, BK=64, warps=8, stages=3: 0.045ms, 711.1 GB/s, 42.0 TFLOPS
BM=8, BN=128, BK=128, warps=4, stages=2: 0.068ms, 467.9 GB/s, 27.7 TFLOPS
BM=8, BN=128, BK=128, warps=4, stages=3: 0.068ms, 464.6 GB/s, 27.5 TFLOPS
BM=8, BN=128, BK=128, warps=8, stages=2: 0.038ms, 843.5 GB/s, 49.9 TFLOPS
BM=8, BN=128, BK=128, warps=8, stages=3: 0.040ms, 794.8 GB/s, 47.0 TFLOPS
BM=8, BN=256, BK=64, warps=4, stages=2: 0.063ms, 500.7 GB/s, 29.6 TFLOPS
BM=8, BN=256, BK=64, warps=4, stages=3: 0.040ms, 802.0 GB/s, 47.4 TFLOPS
BM=8, BN=256, BK=64, warps=8, stages=2: 0.090ms, 354.4 GB/s, 21.0 TFLOPS
BM=8, BN=256, BK=64, warps=8, stages=3: 0.082ms, 388.8 GB/s, 23.0 TFLOPS
BM=8, BN=256, BK=128, warps=4, stages=2: 0.054ms, 588.5 GB/s, 34.8 TFLOPS
BM=8, BN=256, BK=128, warps=4, stages=3: 0.037ms, 855.0 GB/s, 50.5 TFLOPS
BM=8, BN=256, BK=128, warps=8, stages=2: 0.086ms, 371.7 GB/s, 22.0 TFLOPS
BM=8, BN=256, BK=128, warps=8, stages=3: 0.079ms, 402.4 GB/s, 23.8 TFLOPS
BM=16, BN=128, BK=64, warps=4, stages=2: 0.063ms, 507.3 GB/s, 30.0 TFLOPS
BM=16, BN=128, BK=64, warps=4, stages=3: 0.049ms, 649.4 GB/s, 38.4 TFLOPS
BM=16, BN=128, BK=64, warps=8, stages=2: 0.034ms, 925.4 GB/s, 54.7 TFLOPS
BM=16, BN=128, BK=64, warps=8, stages=3: 0.027ms, 1169
... (919 more chars){
"command": "<5148 chars \u2014 see below>"
}cat > /tmp/test_perf_m32.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=32, N=12288, K=4096
M, N, K = 32, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import time
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(5):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.str
... (1148 more chars)<system>Command executed successfully.</system>
BM=16, BN=128, BK=64, warps=4, stages=2: 0.079ms, 40.7 TFLOPS
BM=16, BN=128, BK=64, warps=4, stages=3: 0.072ms, 45.0 TFLOPS
BM=16, BN=128, BK=64, warps=8, stages=2: 0.049ms, 65.7 TFLOPS
BM=16, BN=128, BK=64, warps=8, stages=3: 0.043ms, 74.6 TFLOPS
BM=16, BN=128, BK=128, warps=4, stages=2: 0.069ms, 46.4 TFLOPS
BM=16, BN=128, BK=128, warps=4, stages=3: 0.072ms, 44.5 TFLOPS
BM=16, BN=128, BK=128, warps=8, stages=2: 0.040ms, 80.6 TFLOPS
BM=16, BN=128, BK=128, warps=8, stages=3: 0.042ms, 77.4 TFLOPS
BM=16, BN=256, BK=64, warps=4, stages=2: 0.049ms, 65.7 TFLOPS
BM=16, BN=256, BK=64, warps=4, stages=3: 0.043ms, 74.6 TFLOPS
BM=16, BN=256, BK=64, warps=8, stages=2: 0.092ms, 35.2 TFLOPS
BM=16, BN=256, BK=64, warps=8, stages=3: 0.083ms, 38.9 TFLOPS
BM=16, BN=256, BK=128, warps=4, stages=2: 0.055ms, 58.9 TFLOPS
BM=16, BN=256, BK=128, warps=4, stages=3: 0.037ms, 87.3 TFLOPS
BM=16, BN=256, BK=128, warps=8, stages=2: 0.086ms, 37.5 TFLOPS
BM=16, BN=256, BK=128, warps=8, stages=3: 0.082ms, 39.5 TFLOPS
BM=32, BN=128, BK=64, warps=4, stages=2: 0.062ms, 51.6 TFLOPS
BM=32, BN=128, BK=64, warps=4, stages=3: 0.056ms, 57.2 TFLOPS
BM=32, BN=128, BK=64, warps=8, stages=2: 0.039ms, 82.7 TFLOPS
BM=32, BN=128, BK=64, warps=8, stages=3: 0.033ms, 97.4 TFLOPS
BM=32, BN=128, BK=128, warps=4, stages=2: 0.058ms, 55.8 TFLOPS
BM=32, BN=128, BK=128, warps=4, stages=3: 0.057ms, 57.0 TFLOPS
BM=32, BN=128, BK=128, warps=8, stages=2: 0.033ms, 96.8 TFLOPS
BM=32, BN=128,
... (549 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,7 +1,6 @@
"""W4A16 GEMM via Triton with fused int4 unpack + dequant.
AWQ/GPTQ-style asymmetric int4 with group_size=128.
-Uses a single tl.dot per K tile by interleaving unpacked int4 weights.
"""
from __future__ import annotations
@@ -14,26 +13,34 @@
@triton.autotune(
configs=[
- # Decode configs (small M)
+ # Decode configs (small M) - prioritize BN=32 for better parallelism
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=3),
# Small prefill configs
- triton.Config({"BLOCK_M": 8, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
# Large prefill configs
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
],
key=["M"],
)
@@ -72,29 +79,27 @@
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- # Interleave lo and hi along K dimension: (BLOCK_K, BLOCK_N)
- lo_t = tl.trans(lo)
- hi_t = tl.trans(hi)
- joined = tl.join(lo_t, hi_t)
- w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
- w_unpacked = tl.trans(w_reshaped)
-
- # Load scales and zeros for this group
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
- w_deq = (w_unpacked - z) * s
+ lo_deq = (lo - z) * s
+ hi_deq = (hi - z) * s
- # Load x tile contiguously: (BLOCK_M, BLOCK_K)
+ # Load x contiguously
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
- acc += tl.dot(x, w_deq)
+ # Split into even/odd K elements for correct int4 unpacking
+ x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
+ x_even, x_odd = tl.split(x_3d)
+
+ acc += tl.dot(x_even, lo_deq)
+ acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7268 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7268 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.810 ms=0.769
shape=0 variant=compiled tflops=0.637 gbps=169.285 ms=0.158
shape=0 variant=sota tflops=2.300 gbps=611.556 ms=0.044
shape=0 variant=solution tflops=1.865 gbps=496.062 ms=0.054
shape=0 solution_peak_fraction=0.2756
shape=1 variant=eager tflops=4.165 gbps=35.933 ms=0.773
shape=1 variant=compiled tflops=20.094 gbps=173.341 ms=0.160
shape=1 variant=sota tflops=22.641 gbps=195.311 ms=0.142
shape=1 variant=solution tflops=52.347 gbps=451.561 ms=0.062
shape=1 solution_peak_fraction=0.2509
shape=2 variant=eager tflops=31.031 gbps=42.298 ms=0.830
shape=2 variant=compiled tflops=118.201 gbps=161.123 ms=0.218
shape=2 variant=sota tflops=132.834 gbps=181.069 ms=0.194
shape=2 variant=solution tflops=151.816 gbps=206.943 ms=0.170
shape=2 solution_peak_fraction=0.1150
shape=3 variant=eager tflops=0.178 gbps=47.399 ms=0.188
shape=3 variant=compiled tflops=0.495 gbps=131.7
... (472 more chars){
"command": "<733 chars \u2014 see below>"
}cat > /tmp/check_configs2.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import shapes
import torch
device = torch.device('cuda')
for shape in shapes.SHAPES:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Warmup / autotune
for _ in range(3):
_ = model(x)
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {config}")
EOF
python /tmp/check_configs2.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<system>70 lines read from file starting from line 1. Total lines in file: 70. End of file reached.</system>
1 """Centralized GPU timing for benchmark.py files.
2
3 All problems' benchmark.py should call `time_fn` from here so we have one
4 implementation of warmup, L2 flush, and CUDA event capture to maintain.
5
6 Methodology:
7 - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8 torch.compile reduce-overhead CUDA-graph capture.
9 - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10 (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11 we measure HBM-load bandwidth, not L2-cached bandwidth).
12 - GPU timing via cuda.Event with synchronize() AFTER record() but
13 BEFORE elapsed_time() — the canonical NVIDIA pattern.
14 - Reported value is the median of `iters` trials, robust to outliers.
15
16 Notes / known biases not addressed here:
17 - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18 launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19 On small shapes where launch overhead matters, this gives compile an
20 artificial advantage. We accept this as the cost of using the agreed
21 "compiled" baseline policy.
22 - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23 - Median over a fairly small number of trials (default
... (1666 more chars){
"command": "<5761 chars \u2014 see below>"
}cat > /tmp/test_block_ptr.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
# Use block pointer for weights
w_block_ptr = tl.make_block_ptr(
base=w_ptr,
shape=(K // 2, N),
strides=(stride_wk, stride_wn),
offsets=(pk_start, pid_n * BLOCK_N),
block_shape=(BLOCK_K // 2, BLOCK_N),
order=(1, 0),
)
w_packed = tl.load(w_block_ptr, boundary_check=(0, 1))
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
# Block pointer for x
x_block_ptr = tl.make_block_ptr(
base=x_ptr,
shape=(M, K),
strides=(stride_xm, stride_xk),
offsets=(pid_m * BLOCK_M, k_start),
block_shape=(BLOCK_M, BLOCK_K),
order=(1, 0),
)
x = tl.load(x_block_ptr, boundary_check=(0, 1))
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=12288, K=4096
M, N, K = 1, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import time
import statistics
# L2 flush
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
# Warmup
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_
... (1761 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=128, warps=4, stages=2: 0.063ms, 422.0 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.059ms, 452.5 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.062ms, 431.9 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.048ms, 561.5 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.061ms, 435.8 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.056ms, 479.7 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.058ms, 463.5 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.044ms, 614.5 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.104ms, 257.5 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.086ms, 310.4 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.055ms, 488.7 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.043ms, 622.0 GB/s
{
"command": "<5385 chars \u2014 see below>"
}cat > /tmp/test_block_ptr_sq.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(
... (1385 more chars)<system>Command executed successfully.</system>
FAILED: at 27:26:
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
^
NameError('offs_pk is not defined')
FAILED: at 27:26:
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
^
NameError('offs_pk is not defined')
FAILED: at 27:26:
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
^
NameError('offs_pk is not defined')
FAILED: at 27:26:
offs_m = pid_m *
... (9852 more chars){
"command": "<5433 chars \u2014 see below>"
}cat > /tmp/test_block_ptr_sq.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(e
... (1433 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=64, warps=4, stages=2: 0.063ms, 142.7 GB/s
BM=1, BN=32, BK=64, warps=4, stages=3: 0.042ms, 215.1 GB/s
BM=1, BN=32, BK=64, warps=8, stages=2: 0.062ms, 143.1 GB/s
BM=1, BN=32, BK=64, warps=8, stages=3: 0.047ms, 191.7 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.050ms, 177.5 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.045ms, 196.5 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.045ms, 197.6 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.041ms, 218.6 GB/s
BM=1, BN=64, BK=64, warps=4, stages=2: 0.060ms, 149.2 GB/s
BM=1, BN=64, BK=64, warps=4, stages=3: 0.039ms, 226.7 GB/s
BM=1, BN=64, BK=64, warps=8, stages=2: 0.062ms, 145.1 GB/s
BM=1, BN=64, BK=64, warps=8, stages=3: 0.039ms, 230.0 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.052ms, 172.7 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.047ms, 189.4 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.046ms, 194.8 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.040ms, 224.4 GB/s
BM=1, BN=128, BK=64, warps=4, stages=2: 0.092ms, 97.1 GB/s
BM=1, BN=128, BK=64, warps=4, stages=3: 0.077ms, 115.6 GB/s
BM=1, BN=128, BK=64, warps=8, stages=2: 0.070ms, 128.1 GB/s
BM=1, BN=128, BK=64, warps=8, stages=3: 0.045ms, 199.0 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.092ms, 96.7 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.070ms, 126.8 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.056ms, 159.5 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.035ms, 253.8 GB/s
{
"command": "<1777 chars \u2014 see below>"
}cat > /tmp/test_actual_model.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import torch
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_fn(fn, inputs, iters=50, warmup=10):
for _ in range(warmup):
with torch.no_grad():
fn(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
fn(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return statistics.median(times)
device = torch.device('cuda')
for shape in [
{"M": 1, "N": 12288, "K": 4096},
{"M": 1, "N": 4096, "K": 4096},
{"M": 256, "N": 12288, "K": 4096},
]:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
ms = time_fn(model, [x], iters=50, warmup=10)
M, N, K = shape['M'], shape['N'], shape['K']
bytes_moved = M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2
gbps = bytes_moved / (ms * 1e-3) / 1e9
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {ms:.3f}ms, {gbps:.1f} GB/s, {tflops:.1f} TFLOPS, config={config}")
EOF
rm -rf ~/.triton/cache && python /tmp/test_actual_model.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: 0.055ms, 486.7 GB/s, 1.8 TFLOPS, config=BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: 0.048ms, 185.0 GB/s, 0.7 TFLOPS, config=BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: 0.171ms, 205.2 GB/s, 150.6 TFLOPS, config=BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -13,15 +13,15 @@
@triton.autotune(
configs=[
- # Decode configs (small M) - prioritize BN=32 for better parallelism
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ # Decode configs (small M)
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
# Small prefill configs
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
@@ -132,11 +132,20 @@
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
+ # Pre-allocate output buffer to avoid allocation overhead in forward
+ self.register_buffer("_out_buffer", torch.empty(M, N, dtype=torch.bfloat16))
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N = self.N
- x = x.to(torch.bfloat16)
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ if x.dtype != torch.bfloat16:
+ x = x.to(torch.bfloat16)
+
+ # Reuse pre-allocated buffer if shapes match
+ if self._out_buffer.shape == (M, N):
+ out = self._out_buffer
+ else:
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
w4a16_gemm_kernel[grid](
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7592 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7592 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>ERROR: Command failed with exit code: 1.</system>
FAIL: state_dict mismatch at shape 0 ({'M': 1, 'N': 12288, 'K': 4096}): Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "_out_buffer".
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -132,8 +132,8 @@
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
- # Pre-allocate output buffer to avoid allocation overhead in forward
- self.register_buffer("_out_buffer", torch.empty(M, N, dtype=torch.bfloat16))
+ # Pre-allocate output buffer (not registered to avoid state_dict mismatch)
+ self._out_buffer = torch.empty(M, N, dtype=torch.bfloat16)
def forward(self, x: torch.Tensor) -> torch.Tensor:
M, K = x.shape
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7580 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7580 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>ERROR: Command failed with exit code: 1.</system>
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 151, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
benchmark()
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
timings = {config: self._bench(*args, config=config, **kwargs) for config
... (1075 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -132,20 +132,13 @@
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
- # Pre-allocate output buffer (not registered to avoid state_dict mismatch)
- self._out_buffer = torch.empty(M, N, dtype=torch.bfloat16)
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N = self.N
if x.dtype != torch.bfloat16:
x = x.to(torch.bfloat16)
- # Reuse pre-allocated buffer if shapes match
- if self._out_buffer.shape == (M, N):
- out = self._out_buffer
- else:
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
w4a16_gemm_kernel[grid](
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7278 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7278 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.972 ms=0.766
shape=0 variant=compiled tflops=0.673 gbps=178.954 ms=0.150
shape=0 variant=sota tflops=2.360 gbps=627.613 ms=0.043
shape=0 variant=solution tflops=1.943 gbps=516.744 ms=0.052
shape=0 solution_peak_fraction=0.2871
shape=1 variant=eager tflops=4.192 gbps=36.165 ms=0.768
shape=1 variant=compiled tflops=20.785 gbps=179.300 ms=0.155
shape=1 variant=sota tflops=23.321 gbps=201.170 ms=0.138
shape=1 variant=solution tflops=54.634 gbps=471.290 ms=0.059
shape=1 solution_peak_fraction=0.2618
shape=2 variant=eager tflops=31.329 gbps=42.706 ms=0.823
shape=2 variant=compiled tflops=122.146 gbps=166.499 ms=0.211
shape=2 variant=sota tflops=133.473 gbps=181.939 ms=0.193
shape=2 variant=solution tflops=155.916 gbps=212.532 ms=0.165
shape=2 solution_peak_fraction=0.1181
shape=3 variant=eager tflops=0.183 gbps=48.588 ms=0.184
shape=3 variant=compiled tflops=0.512 gbps=136.2
... (472 more chars){
"command": "<261 chars \u2014 see below>"
}python -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print(f'Peak TFLOPS (bf16): {hw.peak_tflops_dense}')
print(f'Peak bandwidth: {hw.peak_bandwidth_gb_s} GB/s')
"
<system>Command executed successfully.</system>
Peak TFLOPS (bf16): {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
Peak bandwidth: 1800.0 GB/s
{
"command": "cat /home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<system>Command executed successfully.</system>
"""Roofline math: achieved TFLOPS / GB/s, peak fraction.
FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
dense-equivalent algorithmic work; kernels that skip work (sparsity, early
exit) are still measured against the dense cost.
"""
from __future__ import annotations
def compute_tflops(flops: float, time_ms: float) -> float:
"""flops per invocation and wall time in ms -> TFLOPS."""
if time_ms <= 0:
return 0.0
return flops / (time_ms * 1e-3) / 1e12
def compute_gbps(bytes_moved: float, time_ms: float) -> float:
"""bytes per invocation and wall time in ms -> GB/s."""
if time_ms <= 0:
return 0.0
return bytes_moved / (time_ms * 1e-3) / 1e9
def peak_fraction(
achieved: float,
peak: float,
) -> float:
"""Return achieved / peak, clamped to [0, infinity)."""
if peak <= 0:
return 0.0
return max(0.0, achieved / peak)
{
"command": "<319 chars \u2014 see below>"
}python -c "
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
from src.eval.roofline import peak_fraction, compute_tflops
flops = 2 * 256 * 12288 * 4096
ms = 0.165
tflops = compute_tflops(flops, ms)
print(f'TFLOPS: {tflops}')
print(f'peak_fraction(155.916, 200): {peak_fraction(155.916, 200)}')
"
<system>Command executed successfully.</system>
TFLOPS: 156.18062894545454
peak_fraction(155.916, 200): 0.7795799999999999
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -132,13 +132,19 @@
self.register_buffer("scales", scales_2d)
self.register_buffer("zeros", zeros_2d)
+ # Lazy-allocated output buffer (not registered to avoid state_dict mismatch)
+ self._out_buffer: torch.Tensor | None = None
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
M, K = x.shape
N = self.N
if x.dtype != torch.bfloat16:
x = x.to(torch.bfloat16)
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ # Lazy allocate and reuse output buffer
+ if self._out_buffer is None or self._out_buffer.shape != (M, N) or self._out_buffer.device != x.device:
+ self._out_buffer = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ out = self._out_buffer
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
w4a16_gemm_kernel[grid](
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7625 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7625 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.923 ms=0.767
shape=0 variant=compiled tflops=0.671 gbps=178.476 ms=0.150
shape=0 variant=sota tflops=2.322 gbps=617.423 ms=0.043
shape=0 variant=solution tflops=2.002 gbps=532.532 ms=0.050
shape=0 solution_peak_fraction=0.2959
shape=1 variant=eager tflops=4.189 gbps=36.139 ms=0.769
shape=1 variant=compiled tflops=20.800 gbps=179.430 ms=0.155
shape=1 variant=sota tflops=23.269 gbps=200.729 ms=0.138
shape=1 variant=solution tflops=55.707 gbps=480.549 ms=0.058
shape=1 solution_peak_fraction=0.2670
shape=2 variant=eager tflops=31.339 gbps=42.719 ms=0.822
shape=2 variant=compiled tflops=122.229 gbps=166.613 ms=0.211
shape=2 variant=sota tflops=133.384 gbps=181.818 ms=0.193
shape=2 variant=solution tflops=156.203 gbps=212.924 ms=0.165
shape=2 solution_peak_fraction=0.1183
shape=3 variant=eager tflops=0.180 gbps=47.933 ms=0.186
shape=3 variant=compiled tflops=0.523 gbps=139.0
... (472 more chars){
"command": "<733 chars \u2014 see below>"
}cat > /tmp/check_configs3.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import shapes
import torch
device = torch.device('cuda')
for shape in shapes.SHAPES:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Warmup / autotune
for _ in range(3):
_ = model(x)
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {config}")
EOF
python /tmp/check_configs3.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
{
"command": "<1537 chars \u2014 see below>"
}cat > /tmp/test_shape3.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import torch
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_fn(fn, inputs, iters=50, warmup=10):
for _ in range(warmup):
with torch.no_grad():
fn(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
fn(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return statistics.median(times)
device = torch.device('cuda')
shape = {"M": 1, "N": 4096, "K": 4096}
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
ms = time_fn(model, [x], iters=50, warmup=10)
M, N, K = shape['M'], shape['N'], shape['K']
bytes_moved = M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2
gbps = bytes_moved / (ms * 1e-3) / 1e9
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {ms:.3f}ms, {gbps:.1f} GB/s, config={config}")
EOF
rm -rf ~/.triton/cache && python /tmp/test_shape3.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 4096, 'K': 4096}: 0.039ms, 226.1 GB/s, config=BLOCK_M: 1, BLOCK_N: 64, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -42,7 +42,7 @@
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
],
- key=["M"],
+ key=["M", "N"],
)
@triton.jit
def w4a16_gemm_kernel(
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<7630 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 7630 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.960 ms=0.766
shape=0 variant=compiled tflops=0.668 gbps=177.775 ms=0.151
shape=0 variant=sota tflops=2.340 gbps=622.245 ms=0.043
shape=0 variant=solution tflops=1.969 gbps=523.534 ms=0.051
shape=0 solution_peak_fraction=0.2909
shape=1 variant=eager tflops=4.224 gbps=36.436 ms=0.763
shape=1 variant=compiled tflops=20.887 gbps=180.175 ms=0.154
shape=1 variant=sota tflops=23.361 gbps=201.521 ms=0.138
shape=1 variant=solution tflops=55.955 gbps=482.686 ms=0.058
shape=1 solution_peak_fraction=0.2682
shape=2 variant=eager tflops=31.320 gbps=42.693 ms=0.823
shape=2 variant=compiled tflops=121.647 gbps=165.820 ms=0.212
shape=2 variant=sota tflops=133.528 gbps=182.014 ms=0.193
shape=2 variant=solution tflops=156.735 gbps=213.649 ms=0.164
shape=2 solution_peak_fraction=0.1187
shape=3 variant=eager tflops=0.180 gbps=47.871 ms=0.187
shape=3 variant=compiled tflops=0.521 gbps=138.5
... (472 more chars){
"command": "<733 chars \u2014 see below>"
}cat > /tmp/check_configs4.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import shapes
import torch
device = torch.device('cuda')
for shape in shapes.SHAPES:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Warmup / autotune
for _ in range(3):
_ = model(x)
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {config}")
EOF
python /tmp/check_configs4.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 64, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
{
"command": "<5520 chars \u2014 see below>"
}cat > /tmp/test_single_dot_m1.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel_single(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
# Interleave along K dimension
lo_t = tl.trans(lo)
hi_t = tl.trans(hi)
joined = tl.join(lo_t, hi_t)
w_reshaped = tl.reshape(joined, (BLOCK_N, BLOCK_K))
w_unpacked = tl.trans(w_reshaped)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
w_deq = (w_unpacked - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel_single[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
... (1520 more chars)<system>Command executed successfully.</system>
BM=1, BN=64, BK=128, warps=4, stages=2: 0.053ms, 169.8 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.054ms, 164.7 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.052ms, 173.3 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.051ms, 176.1 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.081ms, 110.5 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.089ms, 100.3 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.077ms, 116.3 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.086ms, 103.5 GB/s
{
"command": "<5611 chars \u2014 see below>"
}cat > /tmp/test_strided_m1.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel_strided(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
# Strided x loads
offs_k_even = k_start + 2 * tl.arange(0, BLOCK_K // 2)
offs_k_odd = k_start + 2 * tl.arange(0, BLOCK_K // 2) + 1
x_even_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk)
x_odd_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M)
x_even = tl.load(x_even_ptrs, mask=x_mask, other=0.0)
x_odd = tl.load(x_odd_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel_strided[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
... (1611 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=64, warps=4, stages=2: 0.071ms, 125.3 GB/s
BM=1, BN=32, BK=64, warps=4, stages=3: 0.054ms, 163.9 GB/s
BM=1, BN=32, BK=64, warps=8, stages=2: 0.076ms, 118.2 GB/s
BM=1, BN=32, BK=64, warps=8, stages=3: 0.057ms, 157.6 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.050ms, 180.0 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.042ms, 211.3 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.058ms, 153.5 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.043ms, 209.6 GB/s
BM=1, BN=64, BK=64, warps=4, stages=2: 0.069ms, 130.2 GB/s
BM=1, BN=64, BK=64, warps=4, stages=3: 0.051ms, 173.6 GB/s
BM=1, BN=64, BK=64, warps=8, stages=2: 0.071ms, 125.6 GB/s
BM=1, BN=64, BK=64, warps=8, stages=3: 0.056ms, 158.3 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.050ms, 178.9 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.038ms, 232.1 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.050ms, 179.1 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.040ms, 225.6 GB/s
BM=1, BN=128, BK=64, warps=4, stages=2: 0.101ms, 88.7 GB/s
BM=1, BN=128, BK=64, warps=4, stages=3: 0.081ms, 110.0 GB/s
BM=1, BN=128, BK=64, warps=8, stages=2: 0.070ms, 127.7 GB/s
BM=1, BN=128, BK=64, warps=8, stages=3: 0.057ms, 155.5 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.069ms, 129.2 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.066ms, 134.4 GB/s
BM=1, BN=128, BK=128, warps=8, stages=2: 0.051ms, 173.5 GB/s
BM=1, BN=128, BK=128, warps=8, stages=3: 0.043ms, 209.6 GB/s
{
"command": "<5891 chars \u2014 see below>"
}cat > /tmp/test_unpacked_cache.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel_v2(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
offs_k = k_start + tl.arange(0, BLOCK_K)
w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
w_unpacked = tl.load(w_ptrs, mask=w_mask, other=0)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
w_deq = (w_unpacked.to(tl.bfloat16) - z) * s
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
acc += tl.dot(x, w_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
def unpack_weights(w_packed, K, N):
w_unpacked = torch.empty(K, N, dtype=torch.uint8, device=w_packed.device)
w_unpacked[0::2] = w_packed & 0xF
w_unpacked[1::2] = (w_packed >> 4) & 0xF
return w_unpacked
# Test for M=1, N=4096 and M=256, N=12288
for M, N, K in [(1, 4096, 4096), (256, 12288, 4096)]:
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
w_unpacked = unpack_weights(w_packed, K, N)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel_v2[grid](
x, w_unpacked, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_unpacked.stride(0), w_unpacked.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
... (1891 more chars)<system>Command executed successfully.</system>
M=1, N=4096, K=4096:
BM=1, BN=64, BK=64, warps=4, stages=2: 0.076ms, 117.7 GB/s, 0.4 TFLOPS
BM=1, BN=64, BK=64, warps=4, stages=3: 0.061ms, 147.0 GB/s, 0.6 TFLOPS
BM=1, BN=64, BK=64, warps=8, stages=2: 0.076ms, 117.7 GB/s, 0.4 TFLOPS
BM=1, BN=64, BK=64, warps=8, stages=3: 0.059ms, 152.6 GB/s, 0.6 TFLOPS
BM=1, BN=64, BK=128, warps=4, stages=2: 0.059ms, 150.5 GB/s, 0.6 TFLOPS
BM=1, BN=64, BK=128, warps=4, stages=3: 0.047ms, 189.0 GB/s, 0.7 TFLOPS
BM=1, BN=64, BK=128, warps=8, stages=2: 0.058ms, 153.4 GB/s, 0.6 TFLOPS
BM=1, BN=64, BK=128, warps=8, stages=3: 0.048ms, 187.7 GB/s, 0.7 TFLOPS
BM=1, BN=128, BK=64, warps=4, stages=2: 0.097ms, 91.7 GB/s, 0.3 TFLOPS
BM=1, BN=128, BK=64, warps=4, stages=3: 0.076ms, 116.8 GB/s, 0.4 TFLOPS
BM=1, BN=128, BK=64, warps=8, stages=2: 0.079ms, 112.8 GB/s, 0.4 TFLOPS
BM=1, BN=128, BK=64, warps=8, stages=3: 0.063ms, 141.8 GB/s, 0.5 TFLOPS
BM=1, BN=128, BK=128, warps=4, stages=2: 0.085ms, 105.2 GB/s, 0.4 TFLOPS
BM=1, BN=128, BK=128, warps=4, stages=3: 0.066ms, 134.6 GB/s, 0.5 TFLOPS
BM=1, BN=128, BK=128, warps=8, stages=2: 0.059ms, 151.8 GB/s, 0.6 TFLOPS
BM=1, BN=128, BK=128, warps=8, stages=3: 0.048ms, 186.3 GB/s, 0.7 TFLOPS
BM=1, BN=256, BK=64, warps=4, stages=2: 0.101ms, 88.6 GB/s, 0.3 TFLOPS
BM=1, BN=256, BK=64, warps=4, stages=3: 0.081ms, 110.5 GB/s, 0.4 TFLOPS
BM=1, BN=256, BK=64, warps=8, stages=2: 0.130ms, 68.8 GB/s, 0.3 TFLOPS
BM=1, BN=256, BK=64, warps
... (8198 more chars){
"command": "<5439 chars \u2014 see below>"
}cat > /tmp/test_more_configs.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=4096, K=4096
M, N, K = 1, 4096, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(e
... (1439 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=64, warps=2, stages=2: 0.072ms, 124.2 GB/s
BM=1, BN=32, BK=64, warps=2, stages=3: 0.057ms, 157.6 GB/s
BM=1, BN=32, BK=64, warps=2, stages=4: 0.056ms, 158.9 GB/s
BM=1, BN=32, BK=64, warps=4, stages=2: 0.058ms, 153.4 GB/s
BM=1, BN=32, BK=64, warps=4, stages=3: 0.042ms, 213.9 GB/s
BM=1, BN=32, BK=64, warps=4, stages=4: 0.036ms, 247.9 GB/s
BM=1, BN=32, BK=64, warps=8, stages=2: 0.061ms, 146.6 GB/s
BM=1, BN=32, BK=64, warps=8, stages=3: 0.042ms, 211.3 GB/s
BM=1, BN=32, BK=64, warps=8, stages=4: 0.038ms, 235.8 GB/s
BM=1, BN=32, BK=128, warps=2, stages=2: 0.045ms, 197.7 GB/s
BM=1, BN=32, BK=128, warps=2, stages=3: 0.036ms, 250.4 GB/s
BM=1, BN=32, BK=128, warps=2, stages=4: 0.030ms, 297.8 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.050ms, 179.0 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.042ms, 213.9 GB/s
BM=1, BN=32, BK=128, warps=4, stages=4: 0.042ms, 214.6 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.044ms, 201.0 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.036ms, 251.0 GB/s
BM=1, BN=32, BK=128, warps=8, stages=4: 0.030ms, 298.1 GB/s
BM=1, BN=64, BK=64, warps=2, stages=2: 0.097ms, 92.3 GB/s
BM=1, BN=64, BK=64, warps=2, stages=3: 0.075ms, 119.2 GB/s
BM=1, BN=64, BK=64, warps=2, stages=4: 0.076ms, 118.0 GB/s
BM=1, BN=64, BK=64, warps=4, stages=2: 0.059ms, 151.1 GB/s
BM=1, BN=64, BK=64, warps=4, stages=3: 0.038ms, 233.6 GB/s
BM=1, BN=64, BK=64, warps=4, stages=4: 0.032ms, 279.6 GB/s
BM=1, BN=64, BK=64, warps=8,
... (1775 more chars){
"command": "<5429 chars \u2014 see below>"
}cat > /tmp/test_m1_large.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=1, N=12288, K=4096
M, N, K = 1, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(ena
... (1429 more chars)<system>Command executed successfully.</system>
BM=1, BN=32, BK=128, warps=2, stages=2: 0.058ms, 462.7 GB/s
BM=1, BN=32, BK=128, warps=2, stages=3: 0.057ms, 467.0 GB/s
BM=1, BN=32, BK=128, warps=2, stages=4: 0.056ms, 476.0 GB/s
BM=1, BN=32, BK=128, warps=4, stages=2: 0.057ms, 466.1 GB/s
BM=1, BN=32, BK=128, warps=4, stages=3: 0.057ms, 468.7 GB/s
BM=1, BN=32, BK=128, warps=4, stages=4: 0.058ms, 458.5 GB/s
BM=1, BN=32, BK=128, warps=8, stages=2: 0.060ms, 444.5 GB/s
BM=1, BN=32, BK=128, warps=8, stages=3: 0.046ms, 579.0 GB/s
BM=1, BN=32, BK=128, warps=8, stages=4: 0.058ms, 459.5 GB/s
BM=1, BN=64, BK=128, warps=2, stages=2: 0.082ms, 326.0 GB/s
BM=1, BN=64, BK=128, warps=2, stages=3: 0.062ms, 429.9 GB/s
BM=1, BN=64, BK=128, warps=2, stages=4: 0.063ms, 424.7 GB/s
BM=1, BN=64, BK=128, warps=4, stages=2: 0.056ms, 474.7 GB/s
BM=1, BN=64, BK=128, warps=4, stages=3: 0.055ms, 483.7 GB/s
BM=1, BN=64, BK=128, warps=4, stages=4: 0.056ms, 477.2 GB/s
BM=1, BN=64, BK=128, warps=8, stages=2: 0.057ms, 471.2 GB/s
BM=1, BN=64, BK=128, warps=8, stages=3: 0.043ms, 629.7 GB/s
BM=1, BN=64, BK=128, warps=8, stages=4: 0.040ms, 664.2 GB/s
BM=1, BN=128, BK=128, warps=2, stages=2: 0.073ms, 368.4 GB/s
BM=1, BN=128, BK=128, warps=2, stages=3: 0.048ms, 557.0 GB/s
BM=1, BN=128, BK=128, warps=2, stages=4: 0.045ms, 596.1 GB/s
BM=1, BN=128, BK=128, warps=4, stages=2: 0.101ms, 264.2 GB/s
BM=1, BN=128, BK=128, warps=4, stages=3: 0.071ms, 379.2 GB/s
BM=1, BN=128, BK=128, warps=4, stages=4: 0.071ms, 377.4 GB/s
BM=1,
... (177 more chars){
"command": "<5511 chars \u2014 see below>"
}cat > /tmp/test_m256_stages4.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# M=256, N=12288, K=4096
M, N, K = 256, 12288, 4096
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.E
... (1511 more chars)<system>Command executed successfully.</system>
BM=32, BN=128, BK=128, warps=8, stages=2: 0.209ms, 167.7 GB/s, 123.0 TFLOPS
BM=32, BN=128, BK=128, warps=8, stages=3: 0.153ms, 230.1 GB/s, 168.8 TFLOPS
BM=32, BN=128, BK=128, warps=8, stages=4: 0.172ms, 204.7 GB/s, 150.2 TFLOPS
BM=64, BN=128, BK=128, warps=8, stages=2: 0.197ms, 178.3 GB/s, 130.8 TFLOPS
BM=64, BN=128, BK=128, warps=8, stages=3: 0.152ms, 231.0 GB/s, 169.4 TFLOPS
BM=64, BN=128, BK=128, warps=8, stages=4: 0.152ms, 231.3 GB/s, 169.7 TFLOPS
BM=128, BN=128, BK=128, warps=8, stages=2: 0.870ms, 40.4 GB/s, 29.6 TFLOPS
BM=128, BN=128, BK=128, warps=8, stages=3: 0.267ms, 131.7 GB/s, 96.6 TFLOPS
FAILED: out of resource: shared memory, Required: 124416, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
{
"command": "<5635 chars \u2014 see below>"
}cat > /tmp/test_m16_m32.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, s_ptr, z_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE: tl.constexpr = 128,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_k_tiles = K // BLOCK_K
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_tile in range(num_k_tiles):
k_start = k_tile * BLOCK_K
pk_start = k_start // 2
offs_pk = pk_start + tl.arange(0, BLOCK_K // 2)
w_ptrs = w_ptr + (offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn)
w_mask = (offs_pk[:, None] < K // 2) & (offs_n[None, :] < N)
w_packed = tl.load(w_ptrs, mask=w_mask, other=0)
lo = (w_packed & 0xF).to(tl.bfloat16)
hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
group_idx = k_start // GROUP_SIZE
s_ptrs = s_ptr + group_idx * stride_sg + offs_n[None, :] * stride_sn
z_ptrs = z_ptr + group_idx * stride_zg + offs_n[None, :] * stride_zn
s = tl.load(s_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
z = tl.load(z_ptrs, mask=offs_n[None, :] < N, other=0.0).to(tl.bfloat16)
lo_deq = (lo - z) * s
hi_deq = (hi - z) * s
offs_k = k_start + tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
x_3d = tl.reshape(x, (BLOCK_M, BLOCK_K // 2, 2))
x_even, x_odd = tl.split(x_3d)
acc += tl.dot(x_even, lo_deq)
acc += tl.dot(x_odd, hi_deq)
out_ptrs = out_ptr + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)
out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, iters=50):
GROUP_SIZE = 128
n_groups = K // GROUP_SIZE
torch.manual_seed(42)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32, device='cuda') * 0.02
w_g = w_full.view(n_groups, GROUP_SIZE, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
w_q = w_q.view(K, N)
scales_2d = scales.squeeze(1).to(torch.bfloat16)
zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
w_packed = torch.empty(K // 2, N, dtype=torch.uint8, device='cuda')
w_packed[:, :] = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)
out = torch.empty(M, N, dtype=torch.bfloat16).cuda()
grid = (M // BLOCK_M, N // BLOCK_N)
for _ in range(10):
w4a16_gemm_kernel[grid](
x, w_packed, scales_2d, zeros_2d, out,
M, N, K,
x.stride(0), x.stride(1),
w_packed.stride(0), w_packed.stride(1),
scales_2d.stride(0), scales_2d.stride(1),
zeros_2d.stride(0), zeros_2d.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps, num_stages=num_stages,
)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch
... (1635 more chars)<system>Command executed successfully.</system>
M=16, N=14336: BM=16, BN=128, BK=128, warps=8, stages=2: 0.058ms, 550.0 GB/s, 32.5 TFLOPS
M=16, N=14336: BM=16, BN=128, BK=128, warps=8, stages=3: 0.046ms, 690.5 GB/s, 40.8 TFLOPS
M=16, N=14336: BM=16, BN=128, BK=128, warps=8, stages=4: 0.037ms, 867.5 GB/s, 51.3 TFLOPS
M=32, N=12288: BM=16, BN=128, BK=128, warps=8, stages=2: 0.061ms, 454.3 GB/s, 52.7 TFLOPS
M=32, N=12288: BM=16, BN=128, BK=128, warps=8, stages=3: 0.055ms, 508.4 GB/s, 58.9 TFLOPS
M=32, N=12288: BM=16, BN=128, BK=128, warps=8, stages=4: 0.054ms, 509.9 GB/s, 59.1 TFLOPS
M=32, N=12288: BM=32, BN=128, BK=128, warps=8, stages=2: 0.062ms, 447.1 GB/s, 51.8 TFLOPS
M=32, N=12288: BM=32, BN=128, BK=128, warps=8, stages=3: 0.049ms, 572.0 GB/s, 66.3 TFLOPS
M=32, N=12288: BM=32, BN=128, BK=128, warps=8, stages=4: 0.042ms, 662.9 GB/s, 76.8 TFLOPS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,6 +1,7 @@
"""W4A16 GEMM via Triton with fused int4 unpack + dequant.
AWQ/GPTQ-style asymmetric int4 with group_size=128.
+Uses coalesced x loads via reshape+split and tuned num_stages for pipelining.
"""
from __future__ import annotations
@@ -13,20 +14,25 @@
@triton.autotune(
configs=[
- # Decode configs (small M)
+ # Decode configs (small M) - num_stages=4 is key for bandwidth
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128}, num_warps=2, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=4),
# Small prefill configs
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=3),
{
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<8234 chars \u2014 see diff>"
}<system>File successfully overwritten. Current size: 8234 bytes.</system>{
"command": "rm -rf ~/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 180
}<system>Command executed successfully.</system>
PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300
}<system>Command executed successfully.</system>
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.948 ms=0.766
shape=0 variant=compiled tflops=0.667 gbps=177.435 ms=0.151
shape=0 variant=sota tflops=2.354 gbps=625.969 ms=0.043
shape=0 variant=solution tflops=2.275 gbps=605.141 ms=0.044
shape=0 solution_peak_fraction=0.3362
shape=1 variant=eager tflops=4.192 gbps=36.162 ms=0.768
shape=1 variant=compiled tflops=20.980 gbps=180.982 ms=0.154
shape=1 variant=sota tflops=23.396 gbps=201.825 ms=0.138
shape=1 variant=solution tflops=61.175 gbps=527.713 ms=0.053
shape=1 solution_peak_fraction=0.2932
shape=2 variant=eager tflops=31.311 gbps=42.680 ms=0.823
shape=2 variant=compiled tflops=121.693 gbps=165.883 ms=0.212
shape=2 variant=sota tflops=133.783 gbps=182.362 ms=0.193
shape=2 variant=solution tflops=157.240 gbps=214.337 ms=0.164
shape=2 solution_peak_fraction=0.1191
shape=3 variant=eager tflops=0.180 gbps=47.920 ms=0.186
shape=3 variant=compiled tflops=0.526 gbps=139.9
... (472 more chars){
"command": "<733 chars \u2014 see below>"
}cat > /tmp/check_configs5.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import shapes
import torch
device = torch.device('cuda')
for shape in shapes.SHAPES:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Warmup / autotune
for _ in range(3):
_ = model(x)
config = solution.w4a16_gemm_kernel.best_config
print(f"Shape {shape}: {config}")
EOF
python /tmp/check_configs5.py
<system>Command executed successfully.</system>
Shape {'M': 1, 'N': 12288, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Shape {'M': 32, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Shape {'M': 256, 'N': 12288, 'K': 4096}: BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Shape {'M': 1, 'N': 4096, 'K': 4096}: BLOCK_M: 1, BLOCK_N: 64, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Shape {'M': 16, 'N': 14336, 'K': 4096}: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 128, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
{
"command": "<1946 chars \u2014 see below>"
}cat > /tmp/test_compile.py << 'EOF'
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm')
import solution
import reference
import torch
import statistics
_l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_fn(fn, inputs, iters=50, warmup=10):
for _ in range(warmup):
with torch.no_grad():
fn(*inputs)
torch.cuda.synchronize()
times = []
for _ in range(iters):
_l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
fn(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return statistics.median(times)
device = torch.device('cuda')
for shape in [
{"M": 1, "N": 4096, "K": 4096},
{"M": 1, "N": 12288, "K": 4096},
]:
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
model = solution.Model(*init_args).to(device).eval()
x = torch.randn(shape['M'], shape['K'], dtype=torch.bfloat16, device=device)
# Baseline
ms_base = time_fn(model, [x], iters=50, warmup=10)
# Compiled
try:
compiled = torch.compile(model, mode="reduce-overhead")
ms_comp = time_fn(compiled, [x], iters=50, warmup=10)
except Exception as e:
ms_comp = None
print(f"Compile failed: {e}")
M, N, K = shape['M'], shape['N'], shape['K']
bytes_moved = M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2
print(f"Shape {shape}:")
print(f" Baseline: {ms_base:.3f}ms, {bytes_moved/(ms_base*1e-3)/1e9:.1f} GB/s")
if ms_comp:
print(f" Compiled: {ms_comp:.3f}ms, {bytes_moved/(ms_comp*1e-3)/1e9:.1f} GB/s")
EOF
python /tmp/test_compile.py
<system>Command executed successfully.</system>
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] Encountered an exception in identify_mutated_tensors, assuming every input is mutated
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] Traceback (most recent call last):
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 994, in identify_mutated_tensors
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] ttir_module, ordered_tensor_names = generate_ttir(
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] ^^^^^^^^^^^^^^
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 289, in generate_ttir
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] raise ValueError(
W0428 14:30:09.467000 491041 torch/_higher_order_ops/triton_kernel_wrap.py:1026] [0/0] ValueError: Incorrect number of arguments passed to kernel: passed ['x_ptr', 'w_ptr', 's_ptr', 'z_ptr', 'out_ptr', 'M', 'N', 'K', 'stride_xm', 'stride_xk', 'stride_wk', 'stride_wn', 'stride_sg', 'stride_sn', 'stride_zg', '
... (28446 more chars)