"""W4A16 weight-only quantized GEMM for RTX PRO 6000 (SM120).
AWQ/GPTQ-style asymmetric int4 with bf16 per-group scales/zeros.
Fused unpack + GEMM via Triton.
The scheme:
w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n]
out[m, n] = sum_k x[m, k] * w_bf[k, n]
Key optimizations:
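- Two-dot pattern: each packed byte holds an even-K weight (low nibble) and an
  odd-K weight (high nibble), so every K block is processed as two dots: the
  even-K columns of x against the low nibbles, and the odd-K columns of x
  against the high nibbles. Each group of 128 K rows therefore becomes two
  64-wide dots and the unpacked weights never have to be interleaved.
  Dequantization is done in bf16 with fp32 accumulation, which is intended to
  track the rounding of the reference's cuBLAS bf16 GEMM.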
- Triton autotune over a focused set of configs to find the best per shape.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Problem metadata. Not read anywhere in the visible part of this file;
# presumably consumed by the benchmark harness -- TODO confirm.
OP_TYPE = "gemm_w4a16"
SUPPORTED_PRECISIONS = ["int4_bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Number of consecutive K rows that share one (scale, zero-point) pair; also
# the default `group_size` of `Model`.
GROUP_SIZE = 128
# ---------------------------------------------------------------------------
# Triton kernel
# ---------------------------------------------------------------------------
# Autotune search space. Each row of the table below is
# (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages); GROUP_SZ is 128 for
# every config. Row order is preserved because it is the order in which the
# autotuner sees the candidates.
_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "GROUP_SZ": 128},
        num_warps=warps,
        num_stages=stages,
    )
    for block_m, block_n, block_k, warps, stages in (
        # M=1 / decode variants -- small BN, BLOCK_M=1
        (1, 32, 128, 4, 2),
        (1, 32, 256, 4, 2),
        (1, 32, 256, 4, 3),
        (1, 64, 128, 4, 2),
        (1, 64, 128, 4, 3),
        (1, 64, 256, 4, 2),
        (1, 64, 256, 4, 3),
        (1, 64, 512, 4, 2),
        (1, 64, 1024, 4, 2),
        (1, 128, 128, 4, 2),
        (1, 128, 256, 4, 3),
        (1, 128, 512, 4, 2),
        (1, 128, 1024, 4, 2),
        # M>=16 / prefill (BLOCK_M=16 is tensor-core minimum)
        (16, 32, 128, 4, 2),
        (16, 64, 128, 4, 2),
        (16, 64, 256, 4, 2),
        (16, 64, 128, 4, 3),
        (16, 128, 128, 4, 2),
        (16, 128, 256, 4, 3),
        (16, 128, 512, 4, 2),
        (32, 128, 128, 4, 2),
        (32, 128, 256, 4, 2),
        (32, 256, 128, 8, 2),
        (64, 128, 128, 8, 2),
        (64, 128, 256, 8, 2),
        (128, 128, 128, 8, 2),
    )
]
@triton.autotune(configs=_CONFIGS, key=["M", "N", "K"])
@triton.jit
def w4a16_gemm_kernel(
    X, WQ, S, Z, OUT,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sg, stride_sn,
    stride_zg, stride_zn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SZ: tl.constexpr,
):
    """Fused int4 unpack + dequant + GEMM; one program per (BLOCK_M, BLOCK_N) tile.

    Computes, with fp32 accumulation and a bf16 store:

        OUT[m, n] = sum_k X[m, k] * (w[k, n] - Z[k // GROUP_SZ, n]) * S[k // GROUP_SZ, n]

    where w[2*i, n] is the low nibble and w[2*i + 1, n] the high nibble of
    WQ[i, n].

    Args:
        X: activations, addressed as [m, k] through stride_xm / stride_xk.
        WQ: packed weights, addressed as [k // 2, n] through stride_wk / stride_wn.
        S, Z: per-group scales and zero-points, addressed as [group, n].
        OUT: output, addressed as [m, n]; written as bf16.
        M, N, K: problem sizes. K is expected to be even and a multiple of
            GROUP_SZ (the caller checks this); it does NOT have to be a
            multiple of BLOCK_K.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes. BLOCK_K must be a multiple of
            GROUP_SZ so the per-group scale broadcast reshapes exactly.
        GROUP_SZ: number of K rows sharing one scale / zero-point.

    All loads are masked along K as well as along M / N. Out-of-range
    elements load as 0 for x, the packed byte, the scale and the zero-point,
    so a partial last K block contributes exactly 0 to the accumulator instead
    of reading out of bounds.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_mask = offs_m < M
    offs_n_mask = offs_n < N

    n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
    GROUP_HALF: tl.constexpr = GROUP_SZ // 2
    BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_blk in tl.range(0, K, BLOCK_K):
        # Packed rows: row i of WQ carries K rows 2*i (low nibble) and 2*i + 1 (high).
        wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
        wq_mask = (wq_offs[:, None] < (K // 2)) & offs_n_mask[None, :]
        wq = tl.load(
            WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
            mask=wq_mask, other=0,
        )
        w_lo = (wq & 0xF).to(tl.bfloat16)
        w_hi = (wq >> 4).to(tl.bfloat16)

        # One (scale, zero) row per quantization group covered by this K block.
        g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
        g_mask = (g_idx[:, None] < (K // GROUP_SZ)) & offs_n_mask[None, :]
        s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
                    mask=g_mask, other=0.0)
        z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
                    mask=g_mask, other=0.0)

        # Repeat each group's row GROUP_HALF times so it lines up with the
        # packed (half-K) weight rows.
        s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
        s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
        z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
        z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))

        w_lo = (w_lo - z) * s
        w_hi = (w_hi - z) * s

        # Even K columns of x pair with the low nibbles, odd ones with the high.
        x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
        x_offs_odd = x_offs_even + 1
        x_even = tl.load(
            X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
            mask=offs_m_mask[:, None] & (x_offs_even[None, :] < K), other=0.0,
        )
        x_odd = tl.load(
            X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
            mask=offs_m_mask[:, None] & (x_offs_odd[None, :] < K), other=0.0,
        )

        acc += tl.dot(x_even, w_lo)
        acc += tl.dot(x_odd, w_hi)

    out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, acc.to(tl.bfloat16),
             mask=offs_m_mask[:, None] & offs_n_mask[None, :])
# ---------------------------------------------------------------------------
# Module
# ---------------------------------------------------------------------------
class Model(nn.Module):
    """W4A16 linear layer: bf16 activations times group-quantized int4 weights.

    Buffers (all created zero-filled on the "cuda" device):
        w_q:    (K // 2, N) uint8, two 4-bit weights per byte.
        scales: (K // group_size, N) bf16.
        zeros:  (K // group_size, N) bf16.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Allocate the quantized-weight buffers for an (M, K) x (K, N) product.

        Args:
            M: number of activation rows expected by `forward`.
            N: number of output columns.
            K: reduction length; must be even and a multiple of `group_size`.
            group_size: K rows per (scale, zero-point) pair. Only
                `GROUP_SIZE` is supported.

        Raises:
            ValueError: if `group_size` differs from `GROUP_SIZE`.
            AssertionError: if K is odd or not a multiple of `group_size`.
        """
        super().__init__()
        # Every autotune config of the kernel hard-codes GROUP_SZ=GROUP_SIZE,
        # so another group size would index scales/zeros with the wrong
        # stride and silently produce wrong results.
        if group_size != GROUP_SIZE:
            raise ValueError(
                f"group_size={group_size} is not supported: the Triton kernel "
                f"is built for group_size={GROUP_SIZE} only"
            )
        assert K % group_size == 0
        assert K % 2 == 0
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        w_q = torch.zeros(K // 2, N, dtype=torch.uint8, device="cuda")
        scales = torch.zeros(n_groups, N, dtype=torch.bfloat16, device="cuda")
        zeros = torch.zeros(n_groups, N, dtype=torch.bfloat16, device="cuda")
        self.register_buffer("w_q", w_q)
        self.register_buffer("scales", scales)
        self.register_buffer("zeros", zeros)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (M, N) bf16 product of `x` with the dequantized weights.

        Args:
            x: bf16 tensor of shape (M, K); it is made contiguous before the
                launch.

        Raises:
            AssertionError: if `x` has the wrong shape or dtype.
        """
        M, N, K = self.M, self.N, self.K
        assert x.shape == (M, K), f"x shape mismatch: {x.shape} vs ({M},{K})"
        assert x.dtype == torch.bfloat16
        x = x.contiguous()
        out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

        def grid(meta):
            # One program per output tile; tile sizes come from the autotuned config.
            return (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        w4a16_gemm_kernel[grid](
            x, self.w_q, self.scales, self.zeros, out,
            M, N, K,
            x.stride(0), x.stride(1),
            self.w_q.stride(0), self.w_q.stride(1),
            self.scales.stride(0), self.scales.stride(1),
            self.zeros.stride(0), self.zeros.stride(1),
            out.stride(0), out.stride(1),
        )
        return out
# Default problem shape read by get_inputs() and get_init_inputs().
M, N, K = 1, 12288, 4096


def get_inputs():
    """Return the positional arguments for ``Model.forward``.

    The list holds a single random bf16 tensor of shape (M, K), created on
    the default device.
    """
    return [torch.randn(M, K, dtype=torch.bfloat16)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``: [M, N, K]."""
    dims = (M, N, K)
    return list(dims)
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-13T20:31:55.495268+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-13T20:31:58.538041+00:00 elapsed_s=3.043 ms=0.060832
shape=0 variant=solution tflops=1.655 gbps=440.088 ms=0.061
shape=0 solution_peak_fraction=0.2445
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:31:58.867348+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:32:01.929327+00:00 elapsed_s=3.062 ms=0.098944
shape=1 variant=solution tflops=32.556 gbps=280.838 ms=0.099
shape=1 solution_peak_fraction=0.1560
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:32:02.262575+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:32:05.733288+00:00 elapsed_s=3.471 ms=0.292048
shape=2 variant=solution tflops=88.238 gbps=120.279 ms=0.292
shape=2 solution_peak_fraction=0.0668
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:32:05.835753+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:32:08.655421+00:00 elapsed_s=2.820 ms=0.042032
shape=3 variant=solution tflops=0.798 gbps=212.440 ms=0.042
shape=3 solution_peak_fraction=0.1180
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:32:09.038821+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:32:12.014589+00:00 elapsed_s=2.976 ms=0.082976
shape=4 variant=solution tflops=22.646 gbps=383.062 ms=0.083
shape=4 solution_peak_fraction=0.2128
peak_fraction: 0.1450
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm/repo
Installed 56 packages in 65ms
PASS
{
"run_id": "20260613_095207_minimax-claude_MiniMax-M3_07_w4a16_gemm",
"run_group": "",
"problem": "07_w4a16_gemm",
"harness": "minimax-claude",
"model": "MiniMax-M3",
"reasoning_effort": "",
"started_at": "2026-06-13T09:52:07-06:00",
"harness_finished_at": "2026-06-13T14:31:21-06:00",
"finished_at": "2026-06-13T14:32:12-06:00",
"start_epoch": 1781365927,
"harness_end_epoch": 1781382681,
"end_epoch": 1781382732,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.1450,
"template_mutated": false,
"elapsed_seconds": 16754,
"total_elapsed_seconds": 16805,
"check_elapsed_seconds": 33,
"benchmark_elapsed_seconds": 18,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 7.877103975170109,
"usage": {"input_tokens": 17033774, "output_tokens": 131973, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": 89.25434400000002}
}
{
"command": "ls -la /workspace/problems/07_w4a16_gemm/",
"description": "List workspace contents"
}total 48
drwxrwxr-x 2 1002 1002 4096 Jun 13 15:52 .
drwxrwxr-x 3 1002 1002 4096 Jun 13 15:52 ..
-rw-rw-r-- 1 1002 1002 2421 Apr 27 09:21 PROMPT.txt
-rw-rw-r-- 1 1002 1002 5146 Jun 1 21:05 benchmark.py
-rw-rw-r-- 1 1002 1002 4205 Jun 2 01:08 check.py
-rw-rw-r-- 1 1002 1002 1913 Apr 26 04:22 problem.yaml
-rw-rw-r-- 1 1002 1002 4202 Apr 26 04:25 reference.py
-rw-rw-r-- 1 1002 1002 552 Apr 26 04:22 shapes.py
-rw-rw-r-- 1 1002 1002 3734 Apr 26 04:24 sota.py{
"file_path": "/workspace/problems/07_w4a16_gemm/reference.py"
}1 """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2
3 AWQ/GPTQ-style scheme:
4 x: (M, K) bf16
5 w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6 scales: (K // group, N) bf16
7 zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8 out: (M, N) bf16
9
10 Dequant (per group along K):
11 w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12 where w_q[k, n] is the unpacked 4-bit value (0..15).
13
14 This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15 memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16 """
17 from __future__ import annotations
18
19 import torch
20 import torch.nn as nn
21
22 OP_TYPE = "gemm_w4a16"
23 SUPPORTED_PRECISIONS = ["int4_bf16"]
24 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25
26 GROUP_SIZE = 128
27
28
29 def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30 """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31
32 Even rows go in the low nibble, odd rows in the high nibble.
33 """
34 K, N = w_q.shape
35 assert K % 2 == 0
36 lo = w_q[0::2].to(torch.uint8) & 0xF
37 hi = w_q[1::2].to(torch.uint8) & 0xF
38 return (lo | (hi << 4)).contiguous()
39
40
41 def _unpack_int4(w_packed: torch.Tensor, K: int) -> torch.Tensor:
42 """Unpack (K//2, N) uint8 -> (K, N)
... (3046 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/check.py"
}1 """Correctness runner for FP8 GEMM.
2
3 Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4 each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5 """
6 import re
7 import sys
8 from pathlib import Path
9
10 import torch
11 import yaml
12
13 # Make the repo's src/ importable
14 REPO_ROOT = Path(__file__).resolve().parents[2]
15 sys.path.insert(0, str(REPO_ROOT))
16
17 from src.eval.correctness import check_correctness # noqa: E402
18 from src.eval.numeric_stress import ( # noqa: E402
19 numeric_stress_cases,
20 numeric_stress_context,
21 tolerance_for_case,
22 )
23
24
25 def main():
26 try:
27 import reference
28 import shapes
29 import solution
30 except Exception as e:
31 print(f"FAIL: import error: {e}")
32 sys.exit(1)
33
34 problem_yaml = Path("problem.yaml")
35 meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
36
37 # --- Forbidden-op check ------------------------------------------------
38 sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
39 for forbidden in meta.get("forbidden", []):
40 pat = re.escape(forbidden)
41 if re.search(pat, sol_src):
42 print(f"FAIL: forbidden op used: {forbidden}")
43 sys.exit(1)
44
45 device = torch.device("cuda:0")
46 tol_override = meta.get("tolerance") or None
47
48 # --- Per-shape correc
... (3075 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/shapes.py"
}1 """Shape sweep for W4A16 GEMM.
2
3 Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4 case every inference engine optimizes -- it's the bar to beat.
5 """
6
7 SHAPES = [
8 {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9 {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10 {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11 {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12 {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13 ]
14 {
"file_path": "/workspace/problems/07_w4a16_gemm/benchmark.py"
}1 """Roofline benchmark for FP8 GEMM.
2
3 For each shape: times the agent's solution first. Optional eager, compiled,
4 and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
5
6 Output lines the harness picks up:
7 shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8 peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9 """
10 import sys
11 from math import exp, log
12 from pathlib import Path
13
14 import torch
15 import yaml
16
17 REPO_ROOT = Path(__file__).resolve().parents[2]
18 sys.path.insert(0, str(REPO_ROOT))
19
20 from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21 from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
22 from src.hardware import get as get_hw # noqa: E402
23
24
25 def _eval_formula(expr: str, vars: dict) -> float:
26 # Very small eval: only names from `vars` are valid.
27 return float(eval(expr, {"__builtins__": {}}, vars))
28
29
30 def main():
31 import reference
32 import shapes
33 import solution
34
35 meta = yaml.safe_load(Path("problem.yaml").read_text())
36 hw = get_hw(meta["hardware"][0])
37 peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38 peak_gbps = hw.peak_bandwidth_gb_s
39 regime = meta.get("regime", "compute")
40 flops_formula = meta["flops_formula"]
41 bytes_formula = meta["bytes_formula"]
42 num_perf
... (4130 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/sota.py"
}1 """SOTA reference for W4A16 GEMM.
2
3 Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4
5 - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6 - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7 but is not faster than Marlin on its target HW
8 and has no Blackwell tuning. Skip as primary.
9 - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10 Skip.
11 - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12 on this machine). Different quant scheme (NF4,
13 symmetric, blocksize 64) than our reference's
14 AWQ-style asymmetric INT4 with group_size 128,
15 but it occupies the same memory regime and is
16 the only tuned W4A16-class kernel that runs on
17 SM120 today. Used here as an *informational*
18 SOTA line, not as a numerical reference.
19
20 The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21 NOT checked against this path (the quant scheme differs).
22 """
23 from __future__ import annotations
24
25 import torch
26
27 _BNB_OK: bool | None = None
28
29
30 def is_available() -> bool:
31 gl
... (2489 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/problem.yaml"
}1 name: 07_w4a16_gemm
2 display_name: "W4A16 Weight-only Quantized GEMM"
3 precision: int4_bf16
4 regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5
6 # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7 flops_formula: "2 * M * N * K"
8
9 # Bytes moved per call (memory roofline):
10 # x: M*K*2 (bf16 activations, streamed in once)
11 # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12 # scales: (K/128)*N*2 (bf16 scales)
13 # zeros: (K/128)*N*2 (bf16 zero-points)
14 # out: M*N*2 (bf16 store)
15 bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16
17 hardware: [RTX_PRO_6000]
18 peak_tflops_key: bf16
19 peak_bandwidth_key: dram
20
21 tolerance:
22 bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23
24 # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25 # vendor library that does both.
26 forbidden:
27 - "bitsandbytes.functional.dequantize_4bit"
28 - "bitsandbytes.functional.gemv_4bit"
29 - "marlin_kernel.gemm"
30 - "torch.nn.functional.linear"
31
32 sota:
33 name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34 url: "https://github.com/TimDettmers/bitsandbytes"
35 function: "bitsandbytes.functional.gemv_4bit"
36 notes: |
37 Marlin (IST-DASLab) is the W4A16 SOTA on Ampere/Hopper but does not have
38 SM120 (Blackwell consumer) kernels yet. GPTQ-Triton is unmaintained and
... (554 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/PROMPT.txt"
}1 I need you to write a weight-only int4 quantized GEMM (W4A16) for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2
3 The scheme is AWQ/GPTQ-style asymmetric int4 with explicit zero-points and per-group bf16 scales. Inputs are x of shape (M, K) bf16, w_q of shape (K // 2, N) uint8 (two int4 weights packed per byte, low nibble = even-K row, high nibble = odd-K row), scales of shape (K // 128, N) bf16, and zeros of shape (K // 128, N) bf16. Group size is 128 along K. Dequant per group is w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n], and the output is (M, N) bf16. Correctness tolerance is 0.10 abs/rel — group-quant adds noise on top of bf16 accumulator slop. The shapes you have to handle are M=1 N=12288 K=4096 (decode, memory-bound on the int4 weight read), M=32 N=12288 K=4096 (small prefill, mixed regime), M=256 N=12288 K=4096 (larger prefill, approaching compute-bound), M=1 N=4096 K=4096 (decode, square), and M=16 N=14336 K=4096 (speculative-decode-ish).
4
5 This needs to be a real custom kernel that fuses unpack and GEMM in the same pass — a separate dequant-then-matmul wastes the entire bandwidth advan
... (923 more chars){
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0))\"",
"description": "Check torch and GPU info"
}Using CPython 3.11.15
Creating virtual environment at: /workspace/.venv
Building kernelbench-hard @ file:///workspace
Built kernelbench-hard @ file:///workspace
warning: Failed to hardlink files; falling back to full copy. This may lead to degraded performance.
If the cache and target directories are on different filesystems, hardlinking may not be supported.
If this is intentional, set `export UV_LINK_MODE=copy` or use `--link-mode=copy` to suppress this warning.
Installed 56 packages in 871ms
2.11.0+cu130
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
(12, 0){
"command": "uv run python -c \"import triton; print('triton:', triton.__version__)\"",
"description": "Check triton version"
}triton: 3.6.0{
"command": "<486 chars \u2014 see below>",
"description": "Check available libraries"
}uv run python -c "
import torch
print('CUDA:', torch.version.cuda)
print('Device:', torch.cuda.get_device_name(0))
print('CC:', torch.cuda.get_device_capability(0))
# Check available libraries
try:
import triton
print('triton:', triton.__version__)
except: pass
try:
import bitsandbytes as bnb
print('bnb:', bnb.__version__)
except Exception as e: print('bnb err:', e)
try:
from torch.utils.cpp_extension import load_inline
print('load_inline OK')
except: pass
"CUDA: 13.0
Device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
CC: (12, 0)
triton: 3.6.0
bnb err: No module named 'bitsandbytes'
load_inline OK{
"command": "which nvcc && nvcc --version",
"description": "Check nvcc availability"
}/kbh/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0{
"command": "ls /usr/local/cuda* 2>/dev/null; ls /opt/cuda* 2>/dev/null; echo \"---\"; env | grep -i cuda",
"description": "Find CUDA installation"
}/usr/local/cuda:
NsightSystems-cli-2025.3.1
bin
compat
compute-sanitizer
doc
extras
gds
include
lib64
nvml
nvvm
share
src
targets
/usr/local/cuda-12:
NsightSystems-cli-2025.3.1
bin
compat
compute-sanitizer
doc
extras
gds
include
lib64
nvml
nvvm
share
src
targets
/usr/local/cuda-12.9:
NsightSystems-cli-2025.3.1
bin
compat
compute-sanitizer
doc
extras
gds
include
lib64
nvml
nvvm
share
src
targets
/usr/local/cuda-host:
DOCS
EULA.txt
README
bin
compute-sanitizer
doc
extras
gds
include
lib64
nsightee_plugins
nvml
nvvm
share
src
targets
tools
version.json
---
NVIDIA_REQUIRE_CUDA=cuda>=9.0
TORCH_CUDA_ARCH_LIST=7.5 8.0 8.6 9.0 10.0 12.0+PTX
CUDA_ARCH_LIST=7.5 8.0 8.6 9.0 10.0 12.0
CUDA_VERSION=12.9.1.010
PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.99999
LIBRARY_PATH=/usr/local/cuda/lib64/stubs:
LD_LIBRARY_PATH=/usr/local/tensorrt/lib:/usr/local/cuda/lib64:/usr/local/lib/python3.12/dist-packages/torch/lib:/usr/local/lib/python3.12/dist-packages/torch_tensorrt/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
CUDA_DRIVER_VERSION=575.57.08
_CUDA_COMPAT_PATH=/usr/local/cuda/compat
CUDA_HOME=/usr/local/cuda-host
CUDA_MODULE_LOADING=LAZY
PATH=/kbh/bin:/usr/local/cuda-host/bin:/usr/local/bin:/usr/bin:/bin{
"command": "ls /usr/local/cuda/include/cute 2>/dev/null; echo \"---\"; ls /usr/local/cuda/include/cutlass 2>/dev/null; echo \"---\"; find / -name \"cute\" -type d 2>/dev/null | head -10",
"description": "Check if CUTLASS headers are available"
}---
---
/usr/include/cute
/opt/pytorch/ao/third_party/cutlass/test/unit/cute
/uv-cache/archive-v0/vRmFa6d-Xjr_fQ1_9lPaP/nvidia_cutlass_dsl/python_packages/cutlass/cute
/uv-cache/archive-v0/FHHqsxOe4DHzv5CTTSOyA/flashinfer/data/cutlass/include/cute
/uv-cache/archive-v0/FHHqsxOe4DHzv5CTTSOyA/flashinfer/data/cutlass/python/CuTeDSL/cutlass/cute
/uv-cache/archive-v0/FHHqsxOe4DHzv5CTTSOyA/flashinfer/data/cutlass/examples/python/CuTeDSL/cute
/opt/pytorch/ao/third_party/cutlass/examples/cute
/opt/pytorch/ao/third_party/cutlass/include/cute
/opt/pytorch/ao/third_party/cutlass/media/images/cute
/opt/pytorch/ao/third_party/cutlass/media/docs/cpp/cute{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/ | head -30",
"description": "Check cutlass include path"
}cute
cutlass{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/ | head; echo \"---\"; ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/ 2>/dev/null | head",
"description": "Check cutlass gemm headers"
}aligned_buffer.h
arch
array.h
array_planar_complex.h
array_subbyte.h
barrier.h
bfloat16.h
blas3.h
blas3_types.h
block_striped.h
---
collective
device
dispatch_policy.hpp
gemm.h
gemm_enumerated_types.h
group_array_problem_shape.hpp
kernel
thread
threadblock
warp{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/device/ | head -20",
"description": "Check cutlass device gemm"
}base_grouped.h
default_gemm_configuration.h
ell_gemm.h
gemm.h
gemm_array.h
gemm_batched.h
gemm_complex.h
gemm_grouped.h
gemm_layernorm_mainloop_fusion.h
gemm_sparse.h
gemm_sparse_universal.h
gemm_sparse_universal_with_absmax.h
gemm_sparse_with_absmax.h
gemm_sparse_with_visitor.h
gemm_splitk_parallel.h
gemm_universal.h
gemm_universal_adapter.h
gemm_universal_base.h
gemm_universal_streamk_with_broadcast.h
gemm_universal_with_absmax.h{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cute/ | head -20",
"description": "List CuTe directory"
}Command running in background with ID: b71uahqk9. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b71uahqk9.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cute/ 2>/dev/null | head -10; echo \"---\"; ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/collective/ 2>/dev/null | head -20",
"description": "List CuTe and gemm collective"
}algorithm
arch
atom
config.hpp
container
int_tuple.hpp
layout.hpp
layout_composed.hpp
numeric
pointer.hpp
---
builders
collective_builder.hpp
collective_builder_decl.hpp
collective_mma.hpp
collective_mma_decl.hpp
fp8_accumulation.hpp
sm100_blockscaled_mma_array_warpspecialized.hpp
sm100_blockscaled_mma_warpspecialized.hpp
sm100_blockscaled_sparse_mma_warpspecialized.hpp
sm100_mma_array_warpspecialized.hpp
sm100_mma_array_warpspecialized_blockwise_scaling.hpp
sm100_mma_array_warpspecialized_emulated.hpp
sm100_mma_warpspecialized.hpp
sm100_mma_warpspecialized_blockwise_scaling.hpp
sm100_mma_warpspecialized_emulated.hpp
sm100_mma_warpspecialized_mixed_input.hpp
sm100_sparse_mma_warpspecialized.hpp
sm120_blockscaled_mma_array_tma.hpp
sm120_blockscaled_mma_tma.hpp
sm120_blockscaled_sparse_mma_tma.hpp{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/collective/ | grep -i sm120",
"description": "List sm120 collective files"
}sm120_blockscaled_mma_array_tma.hpp
sm120_blockscaled_mma_tma.hpp
sm120_blockscaled_sparse_mma_tma.hpp
sm120_mma_tma.hpp
sm120_sparse_mma_tma.hpp{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/ | head -30; echo \"---\"; ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/kernel/ 2>/dev/null | head -20",
"description": "List gemm kernel directory"
}Command running in background with ID: bjynfbtb5. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bjynfbtb5.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/ | head -30",
"description": "List gemm directory"
}collective
device
dispatch_policy.hpp
gemm.h
gemm_enumerated_types.h
group_array_problem_shape.hpp
kernel
thread
threadblock
warp{
"command": "ls /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/kernel/ | grep -i sm120",
"description": "List sm120 kernel files"
}sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp{
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp | head -100",
"description": "Look at sm120 gemm kernel"
}/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED
... (3088 more chars){
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp | head -100",
"description": "Look at sm120 mma tma collective"
}/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED
... (2222 more chars){
"command": "ls /opt/pytorch/ao/third_party/cutlass/examples/ 2>/dev/null | head -30",
"description": "Check cutlass examples"
}00_basic_gemm
01_cutlass_utilities
02_dump_reg_shmem
03_visualize_layout
04_tile_iterator
05_batched_gemm
06_splitK_gemm
07_volta_tensorop_gemm
08_turing_tensorop_gemm
09_turing_tensorop_conv2dfprop
10_planar_complex
11_planar_complex_array
12_gemm_bias_relu
13_two_tensor_op_fusion
14_ampere_tf32_tensorop_gemm
15_ampere_sparse_tensorop_gemm
16_ampere_tensorop_conv2dfprop
17_fprop_per_channel_bias
18_ampere_fp64_tensorop_affine2_gemm
19_tensorop_canonical
20_simt_canonical
21_quaternion_gemm
22_quaternion_conv
23_ampere_gemm_operand_reduction_fusion
24_gemm_grouped
25_ampere_fprop_mainloop_fusion
26_ampere_wgrad_mainloop_fusion
27_ampere_3xtf32_fast_accurate_tensorop_gemm
28_ampere_3xtf32_fast_accurate_tensorop_fprop
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm{
"command": "ls /opt/pytorch/ao/third_party/cutlass/examples/ | grep -i blackwell; echo \"---\"; ls /opt/pytorch/ao/third_party/cutlass/examples/ | grep -i sm120; echo \"---\"; ls /opt/pytorch/ao/third_party/cutlass/examples/ | grep -i 12 2>/dev/null",
"description": "Look for blackwell/sm120 examples"
}70_blackwell_gemm
71_blackwell_gemm_with_collective_builder
72_blackwell_narrow_precision_gemm
73_blackwell_gemm_preferred_cluster
74_blackwell_gemm_streamk
75_blackwell_grouped_gemm
76_blackwell_conv
77_blackwell_fmha
78_blackwell_emulated_bf16x9_gemm
79_blackwell_geforce_gemm
80_blackwell_geforce_sparse_gemm
81_blackwell_gemm_blockwise
82_blackwell_distributed_gemm
83_blackwell_sparse_gemm
84_blackwell_narrow_precision_sparse_gemm
---
---
12_gemm_bias_relu{
"command": "ls /opt/pytorch/ao/third_party/cutlass/examples/79_blackwell_geforce_gemm/",
"description": "Look at SM120 example"
}79a_blackwell_geforce_nvfp4_bf16_gemm.cu
79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu
79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu
79d_blackwell_geforce_nvfp4_grouped_gemm.cu
CMakeLists.txt{
"command": "cat /opt/pytorch/ao/third_party/cutlass/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu | head -200",
"description": "Look at SM120 nvfp4 example"
}/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED
... (8783 more chars){
"command": "find /opt/pytorch/ao/third_party/cutlass -name \"*.hpp\" -o -name \"*.h\" 2>/dev/null | xargs grep -l \"int4\\|int4b\\|mxfp4\\|nv_float4\\|float_e2m1\" 2>/dev/null | head -20",
"description": "Find int4 related headers"
}/opt/pytorch/ao/third_party/cutlass/test/unit/conv/cache_testbed_output.h
/opt/pytorch/ao/third_party/cutlass/include/cute/pointer.hpp
/opt/pytorch/ao/third_party/cutlass/include/cutlass/numeric_conversion.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/core_io.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/float_subbyte.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/integer_subbyte.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/relatively_equal.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/numeric_types.h
/opt/pytorch/ao/third_party/cutlass/include/cutlass/subbyte_reference.h
/opt/pytorch/ao/third_party/cutlass/test/unit/epilogue/threadblock/testbed.h
/opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm100.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm120.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm80.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm100_desc.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/arch/copy_sm90_desc.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120_sparse.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/numeric/numeric_types.hpp
/opt/pytorch/ao/third_party/cutlass/include/cute/numeric/int.hpp{
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm120.hpp | head -200",
"description": "Look at SM120 MMA traits"
}/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED
... (7124 more chars){
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/atom/mma_traits_sm120.hpp | sed -n '200,400p'",
"description": "Continue reading sm120 mma traits"
}// Transform if needed
template<class MMA_Op, class Tensor>
CUTLASS_DEVICE void
fp4_shift_A(MMA_Op const& op, Tensor&& tensor) {
}
template<class MMA_Op, class Tensor>
CUTLASS_DEVICE void
fp4_shift_B(MMA_Op const& op, Tensor&& tensor) {
}
// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix.
// ld.matrix b4x16_p64 places FP4 data at the first four bits in each
// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in
// the middle of the eight-bit container. Thus, e2m1 operands being fed
// to MMA F8F6F4 must be shifted left by two bits.
// 0b0000ABCD --> 0b00ABCD00
// NOTE: Same transformation is NOT needed for FP6 and FP8.
template<class AType, class BType, class... MMAArgs, class Tensor>
CUTLASS_DEVICE void
fp4_shift_A(SM120_16x8x32_TN<AType, BType, MMAArgs ...> const&, Tensor&& tensor) {
using RegisterTypeA = typename remove_extent<typename
SM120_16x8x32_TN<AType, BType, MMAArgs ...>::ARegisters>::type;
if constexpr (cute::is_same_v<AType, cutlass::float_e2m1_t>) {
cute::transform(recast<RegisterTypeA>(tensor), [](RegisterTypeA& v){ return v << 2; });
}
}
template<class AType, class BType, class... MMAArgs, class Tensor>
CUTLASS_DEVICE void
fp4_shift_B(SM120_16x8x32_TN<AType, BType, MMAArgs ...> const&, Tensor&& tensor) {
using RegisterTypeB = typename remove_extent<typename
SM120_16x8x32_TN<AType, BType, MMAArgs ...>::BRegisters>::type;
if constexpr (cute::is_same_v
... (1465 more chars){
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | head -100",
"description": "Look at sm120 mma intrinsics"
}Command running in background with ID: baof55ipz. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/baof55ipz.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | grep -E \"mma_sync|sm120|MMAOp|class SM120\" | head -30",
"description": "Find mma intrinsics in sm120 header"
}Command running in background with ID: bhapzqxsg. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bhapzqxsg.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"command": "grep -E \"SM120_16x8x32_TN|SM120_16x8x64|class SM120|mma_sync\" /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | head -20",
"description": "Look for SM120 mma instructions"
}struct SM120_16x8x32_TN
static_assert(cutlass::detail::dependent_false<a_type>, "No MMA matches SM120_16x8x32_TN for given data types.");
struct SM120_16x8x32_TN<float_e2m1_t, float_e2m1_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e2m1_t, float_e3m2_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e2m1_t, float_e2m3_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e2m1_t, float_e4m3_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e2m1_t, float_e5m2_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e3m2_t, float_e2m1_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e3m2_t, float_e3m2_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e3m2_t, float_e2m3_t, float>
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
struct SM120_16x8x32_TN<float_e3m2_t, fl
... (124 more chars){
"command": "grep -E \"mma\\.sync\\.\" /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | head -20",
"description": "Look for raw mma.sync instructions"
}"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 "{
"command": "grep -E \"mma\\.sync\\.\" /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | grep -E \"block_scale|mxfp\" | head -20",
"description": "Look for block scale mma instructions"
}"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e3m2.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m3.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e4m3.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e5m2.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m1.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e3m2.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m3.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e4m3.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e5m2.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m1.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e3m2.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m3.f32.ue8m0 "
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e4m3.f32.ue8m0 "
"mma.sync.aligne
... (615 more chars){
"command": "grep -E \"mma\\.sync\\.\" /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | grep -E \"16x8x32|16x8x64\" | head -20",
"description": "Look for sm120 mma sizes"
}(Bash completed with no output){
"command": "grep -E \"kind::\" /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | head -30",
"description": "Look for kinds of mma"
}"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 "
"mma.sync.aligned.ki
... (713 more chars){
"command": "<450 chars \u2014 see below>",
"description": "Check tooling availability"
}uv run python -c "
import torch
# Check what dtypes triton supports
import triton
import triton.language as tl
print('triton:', triton.__version__)
# Check SM
print('SM:', torch.cuda.get_device_capability(0))
# Check ncu
import subprocess
r = subprocess.run(['which', 'ncu'], capture_output=True, text=True)
print('ncu:', r.stdout, r.stderr)
r = subprocess.run(['which', 'nsys'], capture_output=True, text=True)
print('nsys:', r.stdout, r.stderr)
"triton: 3.6.0
SM: (12, 0)
ncu: /kbh/bin/ncu
nsys: /kbh/bin/nsys{
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | grep -E \"m16n8k32|f16|bf16\" | head -20",
"description": "Look at sm120 mma instructions"
}"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 "{
"command": "cat /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | grep -E \"f8f6f4\" | head -5; echo \"---\"; cat /opt/pytorch/ao/third_party/cutlass/include/cute/arch/mma_sm120.hpp | head -300 | tail -150",
"description": "Look at sm120 mma intrinsics details"
}"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 "
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 "
---
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x32 TN E2M1 x E4M3
template <>
struct SM120_16x8x32_TN<float_e2m1_t, float_e4m3_t, float>
{
using DRegisters = float[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = float[4];
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
uint32_t const& b0, uint32_t const& b1,
float const & c0, float const & c1, float const & c2, float const & c3)
{
#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_INVALID_CONTROL_PATH("A
... (4191 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,287 @@
+"""W4A16 weight-only quantized GEMM for RTX PRO 6000 (SM120).
+
+AWQ/GPTQ-style asymmetric int4 with bf16 per-group scales/zeros.
+Fused unpack + GEMM via Triton.
+
+The scheme:
+ w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n]
+ out[m, n] = sum_k x[m, k] * w_bf[k, n]
+
+Key optimizations:
+ - Two-dot pattern: process even/odd K rows with two separate dots. Each
+ group_size=128 K rows become two 64-wide dots that both read the same x row
+ span but different nibbles of the packed weight byte.
+ - Decode path (M=1) uses a memory-bound split-K GEMV kernel that loads
+ w_q, scales, and zeros on the fly and accumulates in fp32.
+ - Prefill path (M>=16) uses a tile-based GEMM kernel that dequantizes
+ BLOCK_K rows at a time and uses tensor cores.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+import triton
+import triton.language as tl
+
+OP_TYPE = "gemm_w4a16"
+SUPPORTED_PRECISIONS = ["int4_bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
+GROUP_SIZE = 128
+
+
+# ---------------------------------------------------------------------------
+# Triton kernels
+# ---------------------------------------------------------------------------
+
+# Autotune configs. We keep this small so compilation stays fast.
+_GEMM_CONFIGS = [
+ triton.Config({"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_SZ": 128},
+ num_warps=nw, num_stages=ns)
+ for bm in (16, 32, 64, 128)
+ for bn in (64, 128, 256)
+ for bk in (128, 256)
+ for nw in (4, 8)
+ for ns in (2, 3)
+ if bm * bn <= 32768
+]
+
+
+@triton.autotune(configs=_GEMM_CONFIGS, key=["M", "N", "K"])
+@triton.jit
+def w4a16_gemm_kernel(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_m_mask = offs_m < M
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+ N_GROUPS: tl.constexpr = K // GROUP_SZ
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # Loop over K in chunks of BLOCK_K. We do two tl.dot calls per chunk: one
+ # for the even-K rows of the packed weights, one for the odd-K rows.
+ for k_blk in tl.range(0, K, BLOCK_K):
+ # Load packed weights: (BLOCK_K // 2, BLOCK_N) uint8
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N,
+ other=0,
+ )
+
+ # Unpack
+ w_lo = (wq & 0xF).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k even
+ w_hi = (wq >> 4).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k odd
+
+ # Load scales/zeros for the groups covered by this K block.
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0) # (n_groups_per_blk, BLOCK_N)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ # Broadcast to (BLOCK_K_HALF, BLOCK_N) by repeating each group GROUP_HALF times.
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ # Dequant: (w_int - zero) * scale
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ # Load x for the even/odd K rows in this block.
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+
+ # Two dots per block.
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
+
+ # Store
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
+
+
+_GEMV_CONFIGS = [
+ triton.Config({"BLOCK_N": bn, "BLOCK_K": bk, "GROUP_SZ": 128},
+ num_warps=nw, num_stages=ns)
+ for bn in (64, 128, 256)
+ for bk in (128, 256, 512)
+ for nw in (4, 8)
+ for ns in (2, 3, 4)
+]
+
+
+@triton.autotune(configs=_GEMV_CONFIGS, key=["M", "N", "K"])
+@triton.jit
+def w4a16_gemv_kernel(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ """M=1 specialized: each program computes a (1, BLOCK_N) output tile."""
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ # Load x once (full K) -- for M=1, this is small (K * 2 bytes).
+ offs_k = tl.arange(0, BLOCK_K)
+
+ for k_b
... (4572 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<10201 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/solution.py (file state is current in your context — no need to Read it back){
"command": "uv run python check.py 2>&1 | head -100",
"description": "Run check.py to see if it passes"
}Command running in background with ID: bpd213mh4. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bpd213mh4.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bpd213mh4",
"block": true,
"timeout": 300000
}<retrieval_status>timeout</retrieval_status>
<task_id>bpd213mh4</task_id>
<task_type>local_bash</task_type>
<status>running</status>{
"command": "cat /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bpd213mh4.output 2>/dev/null | head -100",
"description": "Check check.py output"
}Command running in background with ID: bgqlj5bqy. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bgqlj5bqy.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bgqlj5bqy",
"block": true,
"timeout": 300000
}<retrieval_status>success</retrieval_status>
<task_id>bgqlj5bqy</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>{
"command": "ls /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/ | head -10; echo \"---\"; cat /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bpd213mh4.output 2>/dev/null | head -50",
"description": "Check output files"
}b71uahqk9.output
baof55ipz.output
bb9wj0tls.output
bgqlj5bqy.output
bhapzqxsg.output
bjynfbtb5.output
bpd213mh4.output
---{
"task_id": "bpd213mh4",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bpd213mh4</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42 case large_activation: tolerance exceeded: max_abs_diff=2 max_rel_diff=70.1111 bad=10/12288 worst_idx=(0, 281) (atol=1.0, rtol=0.05)
</output>{
"command": "find /workspace -name \"correctness.py\" -path \"*/eval/*\" 2>/dev/null | head; echo \"---\"; find /workspace -name \"numeric_stress*\" -path \"*/eval/*\" 2>/dev/null | head",
"description": "Find eval module files"
}/workspace/src/eval/correctness.py
---
/workspace/src/eval/numeric_stress.py
/workspace/src/eval/__pycache__/numeric_stress.cpython-311.pyc{
"file_path": "/workspace/src/eval/correctness.py"
}1 """Per-dtype correctness tolerance.
2
3 Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4 We extend to fp8 (0.1, permissive) and int (exact).
5 """
6 from __future__ import annotations
7
8 import torch
9
10 DEFAULT_TOLERANCE = {
11 torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12 torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13 torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14 torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15 torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16 torch.int8: {"atol": 0, "rtol": 0},
17 torch.int32: {"atol": 0, "rtol": 0},
18 torch.int64: {"atol": 0, "rtol": 0},
19 }
20
21
22 def _coerce_float(x):
23 """Accept int/float/str scientific-notation as float; pass through dicts."""
24 if isinstance(x, (int, float)):
25 return float(x)
26 if isinstance(x, str):
27 try:
28 return float(x)
29 except ValueError:
30 return x
31 return x
32
33
34 def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35 """Lookup atol/rtol for a given dtype, with optional per-problem override.
36
37 The override dict is keyed by str(dtype). Values may be:
38 - a single number (or a string parseable as float, e.g. "5e-2") -> used
39 as both atol and rtol
40 - a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
41 """
42 # Accept several spellings of the same dty
... (2892 more chars){
"file_path": "/workspace/src/eval/numeric_stress.py"
}1 """Numeric distribution stress cases for correctness checks.
2
3 These cases are not hidden shapes. They rerun the same shape/seed validation
4 under a few scale regimes that catch kernels that only work for the nominal
5 N(0, 1)-ish inputs.
6 """
7 from __future__ import annotations
8
9 import os
10 from collections.abc import Iterator, Mapping, Sequence
11 from contextlib import contextmanager
12 from dataclasses import dataclass, field
13
14 import torch
15
16 ToleranceOverride = Mapping[str, float | Mapping[str, float]]
17
18
19 @dataclass(frozen=True)
20 class NumericStressCase:
21 name: str
22 input_scales: Mapping[int, float] = field(default_factory=dict)
23 state_scales: Mapping[str, float] = field(default_factory=dict)
24 tolerance: ToleranceOverride | None = None
25
26
27 NOMINAL_CASE = NumericStressCase("nominal")
28
29 _SMALL_BF16 = {"bfloat16": {"atol": 1e-4, "rtol": 5e-2}}
30 _MED_BF16 = {"bfloat16": {"atol": 5e-4, "rtol": 5e-2}}
31 _LARGE_BF16 = {"bfloat16": {"atol": 2e-1, "rtol": 5e-2}}
32 _TINY_FP32 = {"float32": {"atol": 1e-7, "rtol": 1e-4}}
33 _LARGE_FP32 = {"float32": {"atol": 1e-1, "rtol": 1e-4}}
34
35 _CASES: dict[str, tuple[NumericStressCase, ...]] = {
36 "01_fp8_gemm": (
37 NumericStressCase("small_input", input_scales={0: 1e-3}, tolerance=_SMALL_BF16),
38 NumericStressCase("large_input", input_scales={0: 64.0}, tolerance=_LARGE_BF16),
39 NumericStressCase(
40 "small_weight",
41
... (5156 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_test.py
@@ -0,0 +1,29 @@
+"""Quick debug to verify kernel correctness."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import solution
+
+# Setup small test
+M, N, K = 1, 256, 128
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+sol_model = solution.Model(M, N, K).cuda().eval()
+sol_model.load_state_dict(ref_model.state_dict())
+
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+
+with torch.no_grad():
+ ref_out = ref_model(x)
+ sol_out = sol_model(x)
+
+print("Ref:", ref_out[:8])
+print("Sol:", sol_out[:8])
+print("Diff:", (ref_out - sol_out).abs().max().item())
+print("Ref abs max:", ref_out.abs().max().item())
+print("Match:", torch.allclose(ref_out, sol_out, atol=1e-2, rtol=1e-2))
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_test.py",
"content": "<765 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_test.py (file state is current in your context — no need to Read it back){
"command": "uv run python debug_test.py 2>&1 | tail -30",
"description": "Run debug test"
}-0.3926, 0.1611, 0.1089, 0.3359, 0.1895, 0.2490, -0.2188, -0.2188,
-0.0850, -0.2676, 0.2695, -0.0674, 0.2090, 0.2910, -0.1924, 0.0537,
0.0332, -0.1328, 0.1006, 0.2871, 0.0151, -0.0371, -0.1055, -0.0200,
-0.1426, 0.1152, -0.2178, -0.0693, 0.1602, -0.3633, -0.1191, -0.3633,
-0.1309, -0.0283, 0.2129, 0.0024, -0.3516, -0.0156, -0.0474, -0.0918,
-0.1191, 0.0256, -0.0674, 0.1279, -0.3320, 0.1992, 0.2119, 0.0977,
0.0981, -0.2422, 0.1475, -0.0620, 0.0459, 0.2051, 0.2188, -0.0088,
-0.0986, -0.0938, 0.1660, 0.1758, 0.3047, 0.0605, 0.0703, 0.2617,
-0.0344, 0.1719, -0.3281, 0.2852, -0.3496, 0.1484, 0.0137, -0.0098,
0.1748, -0.0186, 0.0264, 0.1016, 0.0264, -0.0903, -0.0962, 0.0996,
0.3516, -0.0200, 0.1895, -0.0059, -0.1230, -0.1426, 0.2773, 0.0498,
0.1973, 0.1494, 0.2207, 0.1406, 0.2188, -0.1167, -0.2246, -0.1689,
0.0635, 0.1445, -0.1621, 0.2383, 0.0986, -0.1914, 0.0474, -0.0645,
0.4609, 0.2676, -0.1001, 0.1660, -0.0153, -0.0547, -0.0625, -0.1318,
0.0947, -0.3164, 0.0605, 0.2520, 0.0854, 0.6328, 0.1670, -0.1318,
0.0972, 0.1611, 0.1553, 0.0728, -0.2891, -0.1934, -0.0237, -0.4609,
0.1514, -0.2539, 0.1040, -0.0986, 0.5078, -0.1934, 0.3242, 0.1035,
-0.4961, 0.1602, 0.4180, 0.1309, 0.3770, -0.1660, 0.4004, 0.1279,
-0.0010, -0.2539, -0.1650, -0.0337, -0.062
... (697 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_test2.py
@@ -0,0 +1,35 @@
+"""Quick debug to verify kernel correctness on actual sizes."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import solution
+
+# Test the actual problem size
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+sol_model = solution.Model(M, N, K).cuda().eval()
+sol_model.load_state_dict(ref_model.state_dict())
+
+# Test multiple seeds
+for seed in [42, 123, 456]:
+ torch.manual_seed(seed)
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+ with torch.no_grad():
+ ref_out = ref_model(x)
+ sol_out = sol_model(x)
+ diff = (ref_out - sol_out).abs().max().item()
+ rel_diff = ((ref_out - sol_out).abs() / ref_out.abs().clamp_min(1e-6)).max().item()
+ print(f"seed={seed}: max_abs_diff={diff:.6f}, max_rel_diff={rel_diff:.4f}, ref_max={ref_out.abs().max().item():.4f}")
+ n_bad = ((ref_out - sol_out).abs() > 0.1).sum().item()
+ print(f" bad={n_bad}/{ref_out.numel()}")
+ if n_bad > 0:
+ bad_idx = ((ref_out - sol_out).abs() > 0.1).nonzero()[:5]
+ for idx in bad_idx:
+ i, j = idx[0].item(), idx[1].item()
+ print(f" [{i},{j}]: ref={ref_out[i,j].item():.4f}, sol={sol_out[i,j].item():.4f}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_test2.py",
"content": "<1278 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_test2.py (file state is current in your context — no need to Read it back){
"command": "uv run python debug_test2.py 2>&1 | tail -40",
"description": "Test with actual problem size"
}Command running in background with ID: bz4gj6vdv. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bz4gj6vdv.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bz4gj6vdv",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bz4gj6vdv</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
seed=42: max_abs_diff=0.031250, max_rel_diff=70.0000, ref_max=4.8125
bad=0/12288
seed=123: max_abs_diff=0.031250, max_rel_diff=9.6875, ref_max=5.7188
bad=0/12288
seed=456: max_abs_diff=0.031250, max_rel_diff=378.0000, ref_max=5.2188
bad=0/12288
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_test3.py
@@ -0,0 +1,45 @@
+"""Quick debug - test specific worst-case idx."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import solution
+
+# Test the actual problem size
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+sol_model = solution.Model(M, N, K).cuda().eval()
+sol_model.load_state_dict(ref_model.state_dict())
+
+# Use the failing seed
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+# scale for large_activation case
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+
+with torch.no_grad():
+ ref_out = ref_model(x_scaled)
+ sol_out = sol_model(x_scaled)
+
+diff = (ref_out - sol_out).abs()
+print(f"max_abs_diff: {diff.max().item():.6f}")
+print(f"ref_max: {ref_out.abs().max().item():.4f}")
+print(f"ref_min: {ref_out.min().item():.4f}")
+
+# Find bad outputs
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = diff > allowed
+n_bad = bad.sum().item()
+print(f"bad: {n_bad}/{ref_out.numel()}")
+if n_bad > 0:
+ bad_idx = bad.nonzero()[:10]
+ for idx in bad_idx:
+ j = idx[1].item()
+ print(f" idx=(0,{j}): ref={ref_out[0,j].item():.4f}, sol={sol_out[0,j].item():.4f}, diff={diff[0,j].item():.4f}, allowed={allowed[0,j].item():.4f}")
+ # Show a few x values
+ # print(f" x: {x[0, :8].tolist()}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_test3.py",
"content": "<1355 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_test3.py (file state is current in your context — no need to Read it back){
"command": "uv run python debug_test3.py 2>&1 | tail -30",
"description": "Run debug test 3"
}max_abs_diff: 2.000000
ref_max: 308.0000
ref_min: -308.0000
bad: 10/12288
idx=(0,878): ref=2.4531, sol=3.5938, diff=1.1406, allowed=1.1250
idx=(0,1270): ref=-4.0000, sol=-5.2500, diff=1.2500, allowed=1.2031
idx=(0,2500): ref=-14.5000, sol=-12.7500, diff=1.7500, allowed=1.7266
idx=(0,5686): ref=-6.3438, sol=-5.0000, diff=1.3438, allowed=1.3125
idx=(0,7349): ref=3.0625, sol=1.5625, diff=1.5000, allowed=1.1562
idx=(0,7761): ref=-0.8281, sol=-2.0000, diff=1.1719, allowed=1.0391
idx=(0,8231): ref=1.6328, sol=2.9062, diff=1.2734, allowed=1.0781
idx=(0,9882): ref=-0.6445, sol=-1.7188, diff=1.0781, allowed=1.0312
idx=(0,10041): ref=-7.3438, sol=-5.8125, diff=1.5312, allowed=1.3672
idx=(0,11403): ref=0.0879, sol=-0.9688, diff=1.0547, allowed=1.0078--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_test4.py
@@ -0,0 +1,34 @@
+"""Test M=16 case."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import solution
+
+# Test the actual problem size
+M, N, K = 16, 14336, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+sol_model = solution.Model(M, N, K).cuda().eval()
+sol_model.load_state_dict(ref_model.state_dict())
+
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+
+with torch.no_grad():
+ ref_out = ref_model(x_scaled)
+ sol_out = sol_model(x_scaled)
+
+diff = (ref_out - sol_out).abs()
+print(f"max_abs_diff: {diff.max().item():.6f}")
+print(f"ref_max: {ref_out.abs().max().item():.4f}")
+
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = diff > allowed
+n_bad = bad.sum().item()
+print(f"bad: {n_bad}/{ref_out.numel()}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_test4.py",
"content": "<874 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_test4.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_test4.py 2>&1 | tail -20",
"description": "Test M=16 case"
}Command running in background with ID: b8tpz1iji. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b8tpz1iji.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b8tpz1iji",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b8tpz1iji</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
max_abs_diff: 1.000000
ref_max: 378.0000
bad: 0/229376
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_simple.py
@@ -0,0 +1,104 @@
+"""Test with a fixed config and very small M=1 case."""
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def w4a16_gemv_kernel_simple(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ """Simplified gemv for M=1."""
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, BLOCK_K // GROUP_SZ)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(X + x_offs_even * stride_xk)
+ x_odd = tl.load(X + x_offs_odd * stride_xk)
+
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+
+ out_ptrs = OUT + offs_n * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
+
+
+# Test
+M, N, K = 1, 128, 128
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+w_q = torch.zeros(K // 2, N, dtype=torch.uint8).cuda()
+w_q[0, :N//2] = 7
+w_q[0, N//2:] = 0
+scales = torch.full((K // 128, N), 0.5, dtype=torch.bfloat16).cuda()
+zeros = torch.full((K // 128, N), 7.0, dtype=torch.bfloat16).cuda()
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+
+# The dequantized weight is: (7 - 7) * 0.5 = 0 (low nibble), (0 - 7) * 0.5 = -3.5 (high nibble)
+# So w_bf[0, n] for n<N/2 = 0, w_bf[1, n] for n<N/2 = -3.5, etc.
+
+# Reference
+w_unpacked = torch.empty((K, N), dtype=torch.uint8, device='cuda')
+w_unpacked[0::2] = w_q & 0xF
+w_unpacked[1::2] = (w_q >> 4) & 0xF
+w_bf = (w_unpacked.to(torch.bfloat16) - zeros.repeat_interleave(128, dim=0)) * scales.repeat_interleave(128, dim=0)
+ref_out = x @ w_bf
+
+# Triton
+BLOCK_N = 64
+BLOCK_K = 128
+grid = (triton.cdiv(N, BLOCK_N),)
+w4a16_gemv_kernel_simple[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(1),
+ BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=4, num_stages=2,
+)
+
+print("ref:", ref_out[0, :8])
+print("sol:", out[0, :8])
+print("match:", torch.allclose(ref_out, out, atol=1e-2))
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_simple.py",
"content": "<3462 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_simple.py (file state is current in your context — no need to Read it back){
"command": "uv run python debug_simple.py 2>&1 | tail -10",
"description": "Run simple debug"
}ref: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
sol: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
match: False{
"command": "<414 chars \u2014 see below>",
"description": "Check matching"
}uv run python debug_simple.py 2>&1 | tail -10 | head -5; echo "---"; uv run python -c "
import torch
x = torch.tensor([-36.75, -36.75], dtype=torch.bfloat16)
y = torch.tensor([-36.75, -36.75], dtype=torch.bfloat16)
print('allclose 1e-2:', torch.allclose(x, y, atol=1e-2))
print('allclose 1e-1:', torch.allclose(x, y, atol=1e-1))
print('equal:', torch.equal(x, y))
print('max abs diff:', (x-y).abs().max().item())
"ref: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
sol: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
match: False
---
allclose 1e-2: True
allclose 1e-1: True
equal: True
max abs diff: 0.0{
"command": "<301 chars \u2014 see below>",
"description": "Check allclose with cuda tensors"
}uv run python -c "
import torch
x = torch.tensor([-36.75, -36.75, -36.75, -36.75], dtype=torch.bfloat16, device='cuda')
y = torch.tensor([-36.75, -36.75, -36.75, -36.75], dtype=torch.bfloat16, device='cuda')
print('allclose 1e-2:', torch.allclose(x, y, atol=1e-2))
print('equal:', torch.equal(x, y))
"allclose 1e-2: True
equal: True{
"command": "uv run python debug_simple.py 2>&1 | tail -10",
"description": "Re-run simple debug to verify"
}ref: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
sol: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
match: False--- a//workspace/problems/07_w4a16_gemm/debug_simple.py
+++ b//workspace/problems/07_w4a16_gemm/debug_simple.py
@@ -101,4 +101,9 @@
print("ref:", ref_out[0, :8])
print("sol:", out[0, :8])
-print("match:", torch.allclose(ref_out, out, atol=1e-2))
+print("diff max:", (ref_out - out).abs().max().item())
+print("match 1e-2:", torch.allclose(ref_out, out, atol=1e-2))
+print("match 1e-1:", torch.allclose(ref_out, out, atol=1e-1))
+print("match 1.0:", torch.allclose(ref_out, out, atol=1.0))
+print("ref numel:", ref_out.numel(), "out numel:", out.numel())
+print("ref shape:", ref_out.shape, "out shape:", out.shape)
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/debug_simple.py",
"old_string": "<112 chars \u2014 see diff>",
"new_string": "<418 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/debug_simple.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "uv run python debug_simple.py 2>&1 | tail -15",
"description": "Re-run simple debug with more info"
}ref: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
sol: tensor([-36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
-36.7500], device='cuda:0', dtype=torch.bfloat16)
diff max: 0.25
match 1e-2: False
match 1e-1: False
match 1.0: True
ref numel: 128 out numel: 128
ref shape: torch.Size([1, 128]) out shape: torch.Size([1, 128])--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_check.py
@@ -0,0 +1,114 @@
+"""Test M=1 case with a specific known config."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+
+# Setup test
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+# Build a custom solution that just uses our simple kernel
+import triton
+import triton.language as tl
+
+@triton.jit
+def w4a16_gemv_kernel_simple(
+ X, WQ, S, Z, OUT,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ """Simplified gemv for M=1."""
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, BLOCK_K // GROUP_SZ)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(X + x_offs_even * stride_xk)
+ x_odd = tl.load(X + x_offs_odd * stride_xk)
+
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+
+ out_ptrs = OUT + offs_n * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
+
+
+# Reference
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+ref_model_state = {k: v.clone() for k, v in ref_model.state_dict().items()}
+
+# Try different configs
+for BLOCK_N, BLOCK_K, num_warps, num_stages in [
+ (64, 128, 4, 2),
+ (128, 128, 4, 2),
+ (256, 128, 8, 2),
+ (64, 256, 4, 2),
+ (128, 256, 4, 2),
+ (128, 512, 4, 2),
+ (64, 512, 4, 2),
+ (256, 256, 8, 2),
+ (32, 128, 4, 2),
+]:
+ out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+ grid = (triton.cdiv(N, BLOCK_N),)
+ w4a16_gemv_kernel_simple[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ N, K,
+ x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(1),
+ BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+ diff = (ref_out - out).abs().max().item()
+ allowed = 1.0 + 0.05 * ref_out.abs()
+ bad = ((ref_out - out).abs() > allowed).sum().item()
+ print(f"BLOCK_N={BLOCK_N:3d} BLOCK_K={BLOCK_K:3d} warps={num_warps} stages={num_stages}: max_diff={diff:.4f}, bad={bad}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_check.py",
"content": "<3772 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_check.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_check.py 2>&1 | tail -20",
"description": "Test multiple configs"
}BLOCK_N= 64 BLOCK_K=128 warps=4 stages=2: max_diff=2.0000, bad=0
BLOCK_N=128 BLOCK_K=128 warps=4 stages=2: max_diff=2.0000, bad=1
BLOCK_N=256 BLOCK_K=128 warps=8 stages=2: max_diff=2.0000, bad=0
BLOCK_N= 64 BLOCK_K=256 warps=4 stages=2: max_diff=2.0000, bad=4
BLOCK_N=128 BLOCK_K=256 warps=4 stages=2: max_diff=2.0000, bad=3
BLOCK_N=128 BLOCK_K=512 warps=4 stages=2: max_diff=2.0000, bad=15
BLOCK_N= 64 BLOCK_K=512 warps=4 stages=2: max_diff=2.0000, bad=10
BLOCK_N=256 BLOCK_K=256 warps=8 stages=2: max_diff=2.0000, bad=5
BLOCK_N= 32 BLOCK_K=128 warps=4 stages=2: max_diff=2.0000, bad=3--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_fp32.py
@@ -0,0 +1,185 @@
+"""Test with fp32 dequant to see if precision is the issue."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+
+# Setup test
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+# Build a custom solution that does dequant in fp32
+import triton
+import triton.language as tl
+
+@triton.jit
+def w4a16_gemv_kernel_fp32_dequant(
+ X, WQ, S, Z, OUT,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, BLOCK_K // GROUP_SZ)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ # Dequant in fp32
+ w_lo_f32 = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f32 = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ # Cast back to bf16
+ w_lo = w_lo_f32.to(tl.bfloat16)
+ w_hi = w_hi_f32.to(tl.bfloat16)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(X + x_offs_even * stride_xk)
+ x_odd = tl.load(X + x_offs_odd * stride_xk)
+
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+
+ out_ptrs = OUT + offs_n * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
+
+
+# Reference
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Test
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+BLOCK_N = 64
+BLOCK_K = 128
+grid = (triton.cdiv(N, BLOCK_N),)
+w4a16_gemv_kernel_fp32_dequant[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ N, K,
+ x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(1),
+ BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=4, num_stages=2,
+)
+
+diff = (ref_out - out).abs().max().item()
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = ((ref_out - out).abs() > allowed).sum().item()
+print(f"fp32 dequant: max_diff={diff:.4f}, bad={bad}")
+
+# Compare to bf16 dequant
+@triton.jit
+def w4a16_gemv_kernel_bf16_dequant(
+ X, WQ, S, Z, OUT,
+ N, K,
+ stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, BLOCK_K // GROUP_SZ)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ # Dequant in bf16
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(X + x_offs_even * stride_xk)
+ x_odd = tl.load(X + x_offs_odd * stride_xk)
+
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+
+ out_ptrs = OUT + offs_n * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
+
+
+out2 = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+w4a16_gemv_kernel_bf16_dequant[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out2,
+ N
... (507 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/debug_fp32.py",
"content": "<6236 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_fp32.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_fp32.py 2>&1 | tail -10",
"description": "Test fp32 vs bf16 dequant"
}Command running in background with ID: bv0cx7m3g. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bv0cx7m3g.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bv0cx7m3g",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bv0cx7m3g</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
fp32 dequant: max_diff=2.0000, bad=0
bf16 dequant: max_diff=2.0000, bad=0
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_ieee.py
@@ -0,0 +1,118 @@
+"""Test M=1 with different precision settings."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+
+# Setup test
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+import triton
+import triton.language as tl
+
+# Try: fp32 dequant, fp32 matmul, then cast
+@triton.jit
+def w4a16_kernel_fp32(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ # Dequant in fp32
+ w_lo = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+ acc += tl.dot(x_even_f, w_lo, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi, input_precision="ieee")
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+# Reference
+torch.manual_seed(42)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Test
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+BLOCK_M = 16
+BLOCK_N = 64
+BLOCK_K = 128
+grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+w4a16_kernel_fp32[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x_scaled.stride(0), x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=4, num_stages=2,
+)
+
+diff = (ref_out - out).abs().max().item()
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = ((ref_out - out).abs() > allowed).sum().item()
+print(f"fp32 dequant, ieee dot: max_diff={diff:.4f}, bad={bad}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_ieee.py",
"content": "<4007 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_ieee.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_ieee.py 2>&1 | tail -10",
"description": "Test ieee precision"
}fp32 dequant, ieee dot: max_diff=2.0000, bad=0--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -15,6 +15,10 @@
w_q, scales, and zeros on the fly and accumulates in fp32.
- Prefill path (M>=16) uses a tile-based GEMM kernel that dequantizes
BLOCK_K rows at a time and uses tensor cores.
+
+Numerical accuracy: we dequant in fp32 instead of bf16, which avoids the
+accumulation of rounding errors that the strict large-activation tolerance
+(0.10 abs/rel) catches.
"""
from __future__ import annotations
@@ -104,9 +108,9 @@
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
- # Dequant: (w_int - zero) * scale
- w_lo = (w_lo - z) * s
- w_hi = (w_hi - z) * s
+ # Dequant in fp32 to avoid bf16 rounding error accumulation.
+ w_lo = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
# Load x for the even/odd K rows in this block.
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
@@ -120,9 +124,12 @@
mask=offs_m_mask[:, None], other=0.0,
)
- # Two dots per block.
- acc += tl.dot(x_even, w_lo)
- acc += tl.dot(x_odd, w_hi)
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+
+ # Two dots per block. Use input_precision="ieee" for highest accuracy.
+ acc += tl.dot(x_even_f, w_lo, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi, input_precision="ieee")
# Store
out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
@@ -163,9 +170,6 @@
acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
- # Load x once (full K) -- for M=1, this is small (K * 2 bytes).
- offs_k = tl.arange(0, BLOCK_K)
-
for k_blk in tl.range(0, K, BLOCK_K):
# Packed weights: (BLOCK_K // 2, BLOCK_N)
wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
@@ -189,8 +193,9 @@
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
- w_lo = (w_lo - z) * s
- w_hi = (w_hi - z) * s
+ # Dequant in fp32
+ w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
# x_even: (BLOCK_K_HALF,), x_odd: (BLOCK_K_HALF,)
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
@@ -198,9 +203,12 @@
x_even = tl.load(X + x_offs_even * stride_xk)
x_odd = tl.load(X + x_offs_odd * stride_xk)
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+
# acc[n] += sum_k x_even[k] * w_lo[k, n] + x_odd[k] * w_hi[k, n]
- acc += tl.sum(x_even[:, None] * w_lo, axis=0)
- acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+ acc += tl.sum(x_even_f[:, None] * w_lo_f, axis=0)
+ acc += tl.sum(x_odd_f[:, None] * w_hi_f, axis=0)
out_ptrs = OUT + offs_n * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
@@ -209,13 +217,6 @@
# ---------------------------------------------------------------------------
# Module
# ---------------------------------------------------------------------------
-
-
-def _quant_args(M: int, N: int, K: int, group_size: int = GROUP_SIZE):
- assert K % group_size == 0
- assert K % 2 == 0
- n_groups = K // group_size
- return M, N, K, n_groups
class Model(nn.Module):
{
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<10592 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -30",
"description": "Run check.py with fp32 dequant"
}Command running in background with ID: b5f24xf2k. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b5f24xf2k.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b5f24xf2k",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b5f24xf2k</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
FAIL: shape 2 {'M': 256, 'N': 12288, 'K': 4096} seed 456 case large_activation: tolerance exceeded: max_abs_diff=2 max_rel_diff=9620.39 bad=1/3145728 worst_idx=(0, 4025) (atol=1.0, rtol=0.05)
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_256.py
@@ -0,0 +1,40 @@
+"""Test M=256 specifically."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import solution
+
+M, N, K = 256, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+sol_model = solution.Model(M, N, K).cuda().eval()
+sol_model.load_state_dict(ref_model.state_dict())
+
+# Test seed 456
+torch.manual_seed(456)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+
+with torch.no_grad():
+ ref_out = ref_model(x_scaled)
+ sol_out = sol_model(x_scaled)
+
+diff = (ref_out - sol_out).abs()
+print(f"max_abs_diff: {diff.max().item():.6f}")
+print(f"ref_max: {ref_out.abs().max().item():.4f}")
+
+# Find bad outputs
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = diff > allowed
+n_bad = bad.sum().item()
+print(f"bad: {n_bad}/{ref_out.numel()}")
+if n_bad > 0:
+ bad_idx = bad.nonzero()
+ for idx in bad_idx[:10]:
+ i, j = idx[0].item(), idx[1].item()
+ print(f" idx=({i},{j}): ref={ref_out[i,j].item():.4f}, sol={sol_out[i,j].item():.4f}, diff={diff[i,j].item():.4f}, allowed={allowed[i,j].item():.4f}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_256.py",
"content": "<1164 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_256.py (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python debug_256.py 2>&1 | tail -20",
"description": "Test M=256 with seed 456"
}Command running in background with ID: bam5p8exx. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bam5p8exx.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bam5p8exx",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bam5p8exx</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
max_abs_diff: 2.000000
ref_max: 406.0000
bad: 1/3145728
idx=(129,5606): ref=-3.5000, sol=-2.2500, diff=1.2500, allowed=1.1719
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_256_v2.py
@@ -0,0 +1,124 @@
+"""Test M=256 with specific config."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import triton
+import triton.language as tl
+
+M, N, K = 256, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+# A cleaner kernel without autotune
+@triton.jit
+def w4a16_kernel_v2(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+ acc += tl.dot(x_even_f, w_lo_f, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi_f, input_precision="ieee")
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+# Reference
+torch.manual_seed(456)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Test multiple configs
+for BLOCK_M, BLOCK_N, BLOCK_K, num_warps in [
+ (16, 64, 128, 4),
+ (16, 128, 128, 4),
+ (32, 64, 128, 4),
+ (32, 128, 128, 4),
+ (64, 64, 128, 4),
+ (64, 128, 128, 4),
+ (128, 64, 128, 4),
+ (128, 128, 128, 8),
+ (32, 128, 256, 4),
+ (64, 128, 256, 4),
+ (32, 64, 256, 4),
+]:
+ out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ w4a16_kernel_v2[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x_scaled.stride(0), x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=2,
+ )
+ diff = (ref_out - out).abs().max().item()
+ allowed = 1.0 + 0.05 * ref_out.abs()
+ bad = ((ref_out - out).abs() > allowed).sum().item()
+ print(f"BLOCK_M={BLOCK_M:3d} BLOCK_N={BLOCK_N:3d} BLOCK_K={BLOCK_K:3d} w={num_warps}: max_diff={diff:.4f}, bad={bad}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_256_v2.py",
"content": "<4360 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_256_v2.py (file state is current in your context — no need to Read it back){
"command": "timeout 1200 uv run python debug_256_v2.py 2>&1 | tail -20",
"description": "Test M=256 with multiple configs"
}Command running in background with ID: b0zju4f6l. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b0zju4f6l.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b0zju4f6l",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b0zju4f6l</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
BLOCK_M= 16 BLOCK_N=128 BLOCK_K=128 w=4: max_diff=2.0000, bad=1
BLOCK_M= 32 BLOCK_N= 64 BLOCK_K=128 w=4: max_diff=2.0000, bad=1
BLOCK_M= 32 BLOCK_N=128 BLOCK_K=128 w=4: max_diff=2.0000, bad=1
BLOCK_M= 64 BLOCK_N= 64 BLOCK_K=128 w=4: max_diff=2.0000, bad=1
Traceback (most recent call last):
File "/workspace/problems/07_w4a16_gemm/debug_256_v2.py", line 110, in <module>
w4a16_kernel_v2[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/
... (216 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_256_v3.py
@@ -0,0 +1,114 @@
+"""Test with modulo-based dequant (no broadcast/reshape)."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import triton
+import triton.language as tl
+
+M, N, K = 256, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+# Use modulo-based dequant
+@triton.jit
+def w4a16_kernel_v3(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ # Modulo-based dequant: each packed index i corresponds to group (k_blk // GROUP_SZ) + (i // GROUP_HALF)
+ g_offset = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF # (BLOCK_K_HALF,)
+ g_idx = (k_blk // GROUP_SZ) + g_offset # (BLOCK_K_HALF,)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0) # (BLOCK_K_HALF, BLOCK_N)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+ acc += tl.dot(x_even_f, w_lo_f, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi_f, input_precision="ieee")
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+# Reference
+torch.manual_seed(456)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Test multiple configs
+for BLOCK_M, BLOCK_N, BLOCK_K, num_warps in [
+ (16, 64, 128, 4),
+ (16, 128, 128, 4),
+ (32, 64, 128, 4),
+ (32, 128, 128, 4),
+ (64, 64, 128, 4),
+]:
+ out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ w4a16_kernel_v3[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x_scaled.stride(0), x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=2,
+ )
+ diff = (ref_out - out).abs().max().item()
+ allowed = 1.0 + 0.05 * ref_out.abs()
+ bad = ((ref_out - out).abs() > allowed).sum().item()
+ print(f"BLOCK_M={BLOCK_M:3d} BLOCK_N={BLOCK_N:3d} BLOCK_K={BLOCK_K:3d} w={num_warps}: max_diff={diff:.4f}, bad={bad}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_256_v3.py",
"content": "<4123 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_256_v3.py (file state is current in your context — no need to Read it back){
"command": "timeout 1200 uv run python debug_256_v3.py 2>&1 | tail -20",
"description": "Test with modulo-based dequant"
}Command running in background with ID: bdyft0gvj. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bdyft0gvj.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bdyft0gvj",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bdyft0gvj</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
BLOCK_M= 16 BLOCK_N= 64 BLOCK_K=128 w=4: max_diff=2.0000, bad=1
Traceback (most recent call last):
File "/workspace/problems/07_w4a16_gemm/debug_256_v3.py", line 100, in <module>
w4a16_kernel_v3[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 482, in launch_metadata
self._init_handles()
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 456, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/workspace/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 448, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_sta
... (24 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_corr.py
@@ -0,0 +1,131 @@
+"""Compare bad output more carefully."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import triton
+import triton.language as tl
+
+M, N, K = 256, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+@triton.jit
+def w4a16_kernel_v3(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_offset = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
+ g_idx = (k_blk // GROUP_SZ) + g_offset
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+ acc += tl.dot(x_even_f, w_lo_f, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi_f, input_precision="ieee")
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+# Reference
+torch.manual_seed(456)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Test with BLOCK_M=16, BLOCK_N=64, BLOCK_K=128
+BLOCK_M, BLOCK_N, BLOCK_K = 16, 64, 128
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+w4a16_kernel_v3[grid](
+ x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x_scaled.stride(0), x_scaled.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=4, num_stages=2,
+)
+
+# Compare
+diff = (ref_out - out).abs()
+allowed = 1.0 + 0.05 * ref_out.abs()
+bad = diff > allowed
+bad_idx = bad.nonzero()
+print(f"bad count: {len(bad_idx)}")
+for idx in bad_idx:
+ i, j = idx[0].item(), idx[1].item()
+ print(f" ({i},{j}): ref={ref_out[i,j].item():.4f}, sol={out[i,j].item():.4f}, diff={diff[i,j].item():.4f}, allowed={allowed[i,j].item():.4f}")
+
+# Also check: compute the fp32 reference
+def ref_fp32(x_bf, w_q, scales, zeros, K, N, group_size=128):
+ # Unpack
+ K_h = K // 2
+ w_unpacked = torch.empty((K, N), dtype=torch.uint8, device='cuda')
+ w_unpacked[0::2] = w_q & 0xF
+ w_unpacked[1::2] = (w_q >> 4) & 0xF
+ s_full = scales.repeat_interleave(group_size, dim=0).float()
+ z_full = zeros.repeat_interleave(group_size, dim=0).float()
+ w_f = (w_unpacked.float() - z_full) * s_full # fp32 dequant
+ return (x_bf.float() @ w_f).to(torch.bfloat16)
+
+# Check the fp32 reference vs bf16 reference
+ref_fp = ref_fp32(x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, K, N)
+diff_fp = (ref_out - ref_fp).abs().max().item()
+print(f"ref_bf vs ref_fp: max_diff={diff_fp:.4f}")
+diff_sol_fp = (out - ref_fp).abs().max().item()
+print(f"sol vs ref_fp: max_diff={diff_sol_fp:.4f}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/debug_corr.py",
"content": "<4692 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_corr.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_corr.py 2>&1 | tail -20",
"description": "Check correlation between ref, sol, and fp32 ref"
}Command running in background with ID: b1sd0em11. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b1sd0em11.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b1sd0em11",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b1sd0em11</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
bad count: 1
(129,5606): ref=-3.5000, sol=-2.2500, diff=1.2500, allowed=1.1719
ref_bf vs ref_fp: max_diff=2.0000
sol vs ref_fp: max_diff=2.0000
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_prec.py
@@ -0,0 +1,205 @@
+"""Test M=256 with different precision settings."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import triton
+import triton.language as tl
+
+M, N, K = 256, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+@triton.jit
+def w4a16_kernel_ieee(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_offset = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
+ g_idx = (k_blk // GROUP_SZ) + g_offset
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ # Dequant in fp32
+ w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ x_even_f = x_even.to(tl.float32)
+ x_odd_f = x_odd.to(tl.float32)
+ acc += tl.dot(x_even_f, w_lo_f, input_precision="ieee")
+ acc += tl.dot(x_odd_f, w_hi_f, input_precision="ieee")
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+@triton.jit
+def w4a16_kernel_default(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_offset = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
+ g_idx = (k_blk // GROUP_SZ) + g_offset
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ # Dequant in bf16
+ w_lo_b = (w_lo - z) * s
+ w_hi_b = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ # Default precision (bf16 input)
+ acc += tl.dot(x_even, w_lo_b)
+ acc += tl.dot(x_odd, w_hi_b)
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+# Reference
+torch.manual_seed(456)
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+x_scaled = (x.float() * 64.0).to(torch.bfloat16).cuda()
+ref_out = ref_model(x_scaled)
+
+# Compare to fp32 reference
+def ref_fp32(x_bf, w_q, scales, zeros, K, N, group_size=128):
+ K_h = K // 2
+ w_unpacked = torch.empty((K, N), dtype=torch.uint8, device='cuda')
+ w_unpacked[0::2] = w_q & 0xF
+ w_unpacked[1::2] = (w_q >> 4) & 0xF
+ s_full = scales.repeat_interleave(group_size, dim=0).float()
+ z_full = zeros.repeat_interleave(group_size, dim=0).float()
+ w_f = (w_unpacked.float() - z_full) * s_full
+ return (x_bf.float() @ w_f).to(torch.bfloat16)
+
+ref_fp = ref_fp32(x_scaled, ref_model.w_q, ref_model.scales, ref_model.zeros, K, N)
+
+# Test default precision
+BLOCK_M, BLOCK_N, BLOCK_K = 16, 64, 128
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+grid = (triton.cdiv
... (1602 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/debug_prec.py",
"content": "<7311 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_prec.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python debug_prec.py 2>&1 | tail -10",
"description": "Test precision settings"
}default (bf16 dequant+dot): max_diff=2.0000, bad=0, vs fp32 ref: 2.0000
ieee: max_diff=2.0000, bad=1, vs fp32 ref: 2.0000--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -108,9 +108,9 @@
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
- # Dequant in fp32 to avoid bf16 rounding error accumulation.
- w_lo = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
- w_hi = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ # Dequant in bf16 (matches the reference's bf16 dequant)
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
# Load x for the even/odd K rows in this block.
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
@@ -124,12 +124,10 @@
mask=offs_m_mask[:, None], other=0.0,
)
- x_even_f = x_even.to(tl.float32)
- x_odd_f = x_odd.to(tl.float32)
-
- # Two dots per block. Use input_precision="ieee" for highest accuracy.
- acc += tl.dot(x_even_f, w_lo, input_precision="ieee")
- acc += tl.dot(x_odd_f, w_hi, input_precision="ieee")
+ # Two dots per block. Default precision (bf16 input) matches the
+ # reference's cuBLAS bf16 GEMM.
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
# Store
out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<987 chars \u2014 see diff>",
"new_string": "<794 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -191,9 +191,9 @@
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
- # Dequant in fp32
- w_lo_f = (w_lo.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
- w_hi_f = (w_hi.to(tl.float32) - z.to(tl.float32)) * s.to(tl.float32)
+ # Dequant in bf16 (matches the reference's bf16 dequant)
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
# x_even: (BLOCK_K_HALF,), x_odd: (BLOCK_K_HALF,)
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
@@ -201,12 +201,10 @@
x_even = tl.load(X + x_offs_even * stride_xk)
x_odd = tl.load(X + x_offs_odd * stride_xk)
- x_even_f = x_even.to(tl.float32)
- x_odd_f = x_odd.to(tl.float32)
-
# acc[n] += sum_k x_even[k] * w_lo[k, n] + x_odd[k] * w_hi[k, n]
- acc += tl.sum(x_even_f[:, None] * w_lo_f, axis=0)
- acc += tl.sum(x_odd_f[:, None] * w_hi_f, axis=0)
+ # Use bf16 products to match the reference's cuBLAS bf16 GEMM behavior.
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
out_ptrs = OUT + offs_n * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<712 chars \u2014 see diff>",
"new_string": "<648 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -10",
"description": "Run check.py with default precision"
}Command running in background with ID: b0o341dba. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b0o341dba.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b0o341dba",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b0o341dba</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42 case large_activation: tolerance exceeded: max_abs_diff=2 max_rel_diff=70.1111 bad=10/12288 worst_idx=(0, 281) (atol=1.0, rtol=0.05)
</output>--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -242,31 +242,22 @@
x = x.contiguous()
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- if M == 1:
- grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
- w4a16_gemv_kernel[grid](
- x, self.w_q, self.scales, self.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0), self.zeros.stride(1),
- out.stride(0), out.stride(1),
- )
- else:
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- triton.cdiv(N, meta["BLOCK_N"]),
- )
- w4a16_gemm_kernel[grid](
- x, self.w_q, self.scales, self.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0), self.zeros.stride(1),
- out.stride(0), out.stride(1),
- )
+ # Always use the gemm kernel. The gemv kernel uses tl.sum which
+ # doesn't match cuBLAS's tensor core accumulation order tightly enough
+ # to pass the large-activation tolerance for some inputs.
+ grid = lambda meta: (
+ triton.cdiv(M, meta["BLOCK_M"]),
+ triton.cdiv(N, meta["BLOCK_N"]),
+ )
+ w4a16_gemm_kernel[grid](
+ x, self.w_q, self.scales, self.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ self.w_q.stride(0), self.w_q.stride(1),
+ self.scales.stride(0), self.scales.stride(1),
+ self.zeros.stride(0), self.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ )
return out
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<1384 chars \u2014 see diff>",
"new_string": "<1048 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -10",
"description": "Run check.py with unified gemm kernel"
}Command running in background with ID: br05y117j. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/br05y117j.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "br05y117j",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>br05y117j</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
PASS
</output>{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -30",
"description": "Run benchmark"
}Command running in background with ID: bd7wh2swm. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bd7wh2swm.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bd7wh2swm",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bd7wh2swm</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-13T18:46:18.752212+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-13T18:46:34.140571+00:00 elapsed_s=15.388 ms=0.064912
shape=0 variant=solution tflops=1.551 gbps=412.427 ms=0.065
shape=0 solution_peak_fraction=0.2291
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T18:46:34.454415+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T18:46:48.924102+00:00 elapsed_s=14.470 ms=0.093088
shape=1 variant=solution tflops=34.604 gbps=298.505 ms=0.093
shape=1 solution_peak_fraction=0.1658
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T18:46:49.246316+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T18:47:02.423537+00:00 elapsed_s=13.177 ms=0.282816
shape=2 variant=solution tflops=91.119 gbps=124.205 ms=0.283
shape=2 solution_peak_fraction=0.0690
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T18:47:02.522674+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T18:47:17.400602+00:00 elapsed_s=14.878 ms=0.056832
shape=3 variant=solution tflops=0.590 gbps=157.117 ms=0.057
shape=3 solution_peak_fraction=0.0873
benchmark_event event=variant_start shape=4 vari
... (314 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py"
}1 """W4A16 weight-only quantized GEMM for RTX PRO 6000 (SM120).
2
3 AWQ/GPTQ-style asymmetric int4 with bf16 per-group scales/zeros.
4 Fused unpack + GEMM via Triton.
5
6 The scheme:
7 w_bf[k, n] = (unpack(w_q)[k, n] - zeros[k // 128, n]) * scales[k // 128, n]
8 out[m, n] = sum_k x[m, k] * w_bf[k, n]
9
10 Key optimizations:
11 - Two-dot pattern: process even/odd K rows with two separate dots. Each
12 group_size=128 K rows become two 64-wide dots that both read the same x row
13 span but different nibbles of the packed weight byte.
14 - Decode path (M=1) uses a memory-bound split-K GEMV kernel that loads
15 w_q, scales, and zeros on the fly and accumulates in fp32.
16 - Prefill path (M>=16) uses a tile-based GEMM kernel that dequantizes
17 BLOCK_K rows at a time and uses tensor cores.
18
19 Numerical accuracy: we dequant in fp32 instead of bf16, which avoids the
20 accumulation of rounding errors that the strict large-activation tolerance
21 (0.10 abs/rel) catches.
22 """
23 from __future__ import annotations
24
25 import torch
26 import torch.nn as nn
27
28 import triton
29 import triton.language as tl
30
31 OP_TYPE = "gemm_w4a16"
32 SUPPORTED_PRECISIONS = ["int4_bf16"]
33 HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
34 GROUP_SIZE = 128
35
36
37 # ---------------------------------------------------------------------------
38 # Triton kernels
39 # ---------------------------------------------------------------------------
40
41 # Au
... (9495 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -8,17 +8,13 @@
out[m, n] = sum_k x[m, k] * w_bf[k, n]
Key optimizations:
- - Two-dot pattern: process even/odd K rows with two separate dots. Each
- group_size=128 K rows become two 64-wide dots that both read the same x row
- span but different nibbles of the packed weight byte.
- - Decode path (M=1) uses a memory-bound split-K GEMV kernel that loads
- w_q, scales, and zeros on the fly and accumulates in fp32.
- - Prefill path (M>=16) uses a tile-based GEMM kernel that dequantizes
- BLOCK_K rows at a time and uses tensor cores.
-
-Numerical accuracy: we dequant in fp32 instead of bf16, which avoids the
-accumulation of rounding errors that the strict large-activation tolerance
-(0.10 abs/rel) catches.
+ - Single-dot pattern with interleaved dequant: dequantize lo/hi nibbles
+ to bf16, interleave with tl.join+reshape to a single (BLOCK_K, BLOCK_N)
+ weight tile, then a single tl.dot with the correspondingly interleaved
+ x tile. Cuts the per-K-block dot count from 2 to 1.
+ - Modulo-based dequant indexing (no broadcast/reshape of scales/zeros),
+ which Triton handles with strided loads.
+ - Autotune over a focused set of configs.
"""
from __future__ import annotations
@@ -35,10 +31,9 @@
# ---------------------------------------------------------------------------
-# Triton kernels
+# Triton kernel
# ---------------------------------------------------------------------------
-# Autotune configs. We keep this small so compilation stays fast.
_GEMM_CONFIGS = [
triton.Config({"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_SZ": 128},
num_warps=nw, num_stages=ns)
@@ -76,43 +71,37 @@
n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
GROUP_HALF: tl.constexpr = GROUP_SZ // 2
BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
- N_GROUPS: tl.constexpr = K // GROUP_SZ
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- # Loop over K in chunks of BLOCK_K. We do two tl.dot calls per chunk: one
- # for the even-K rows of the packed weights, one for the odd-K rows.
+ # Modulo-based dequant indexing: each packed index i in [0, BLOCK_K_HALF)
+ # corresponds to group (k_blk // GROUP_SZ) + (i // GROUP_HALF).
+ g_offset_packed = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
+
for k_blk in tl.range(0, K, BLOCK_K):
# Load packed weights: (BLOCK_K // 2, BLOCK_N) uint8
wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
wq = tl.load(
WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
- mask=offs_n[None, :] < N,
- other=0,
+ mask=offs_n[None, :] < N, other=0,
)
- # Unpack
- w_lo = (wq & 0xF).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k even
- w_hi = (wq >> 4).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k odd
+ # Unpack to bf16
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
- # Load scales/zeros for the groups covered by this K block.
- g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ # Load scales/zeros per group, expanded to (BLOCK_K_HALF, BLOCK_N)
+ g_idx = (k_blk // GROUP_SZ) + g_offset_packed
s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
- mask=offs_n[None, :] < N, other=0.0) # (n_groups_per_blk, BLOCK_N)
+ mask=offs_n[None, :] < N, other=0.0)
z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=offs_n[None, :] < N, other=0.0)
- # Broadcast to (BLOCK_K_HALF, BLOCK_N) by repeating each group GROUP_HALF times.
- s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
- s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
- z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
- z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
-
- # Dequant in bf16 (matches the reference's bf16 dequant)
+ # Dequant in bf16
w_lo = (w_lo - z) * s
w_hi = (w_hi - z) * s
- # Load x for the even/odd K rows in this block.
+ # Load x for even/odd K rows
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
x_offs_odd = x_offs_even + 1
x_even = tl.load(
@@ -124,90 +113,23 @@
mask=offs_m_mask[:, None], other=0.0,
)
- # Two dots per block. Default precision (bf16 input) matches the
- # reference's cuBLAS bf16 GEMM.
- acc += tl.dot(x_even, w_lo)
- acc += tl.dot(x_odd, w_hi)
+ # Interleave lo/hi to make a single (BLOCK_K, BLOCK_N) weight tile.
+ # tl.join pairs the last dim: result is (..., 2, BLOCK_N). Reshape
+ # merges the new axis with the BLOCK_K_HALF axis, giving
+ # w[2i, n] = w_lo[i, n], w[2i+1, n] = w_hi[i, n].
+ w_full = tl.join(w_lo, w_hi)
+ w_full = tl.reshape(w_full, (BLOCK_K, BLOCK_N))
- # Store
+ # Same for x: x[m, 2i] = x_even[m, i], x[m, 2i+1] = x_odd[m, i]
+ x_full = tl.join(x_even, x_odd)
+ x_full = tl.reshape(x_full, (BLOCK_M, BLOCK_K))
+
+ # Single dot with full K
+ acc += tl.dot(x_full, w_full)
+
out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16),
mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
-
-
-_GEMV_CONFIGS = [
- triton.Config({"BLOCK_N": bn, "BLOCK_K": bk, "GROUP_SZ": 128},
- num_warps=nw, num_stages=ns)
- for bn in (64, 128, 256)
- for bk in (128, 256, 512)
- for nw in (4, 8)
- for ns in (2, 3, 4)
-]
-
-
-@triton.autotune(configs=_GEMV_CONFIGS, key=["M", "N", "K"])
-@triton.jit
-def w4a16_gemv_kernel(
- X, WQ, S, Z, OUT,
- M, N, K,
- stride_xm, stride_xk,
-
... (3468 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<6481 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -10",
"description": "Check correctness with single-dot pattern"
}Command running in background with ID: bp908fahj. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bp908fahj.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bp908fahj",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bp908fahj</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42 case nominal: tolerance exceeded: max_abs_diff=7.14062 max_rel_diff=14503.6 bad=11333/12288 worst_idx=(0, 7860) (atol=0.1, rtol=0.1)
</output>--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -8,12 +8,11 @@
out[m, n] = sum_k x[m, k] * w_bf[k, n]
Key optimizations:
- - Single-dot pattern with interleaved dequant: dequantize lo/hi nibbles
- to bf16, interleave with tl.join+reshape to a single (BLOCK_K, BLOCK_N)
- weight tile, then a single tl.dot with the correspondingly interleaved
- x tile. Cuts the per-K-block dot count from 2 to 1.
- - Modulo-based dequant indexing (no broadcast/reshape of scales/zeros),
- which Triton handles with strided loads.
+ - Two-dot pattern: process even/odd K rows with two separate dots. Each
+ group_size=128 K rows become two 64-wide dots that both read the same x
+ row span but different nibbles of the packed weight byte. This matches
+ the bit-level rounding of the reference's cuBLAS bf16 GEMM.
+ - Modulo-based dequant indexing (no broadcast/reshape of scales/zeros).
- Autotune over a focused set of configs.
"""
from __future__ import annotations
@@ -68,14 +67,12 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m_mask = offs_m < M
- n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
GROUP_HALF: tl.constexpr = GROUP_SZ // 2
BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- # Modulo-based dequant indexing: each packed index i in [0, BLOCK_K_HALF)
- # corresponds to group (k_blk // GROUP_SZ) + (i // GROUP_HALF).
+ # Modulo-based dequant indexing
g_offset_packed = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
for k_blk in tl.range(0, K, BLOCK_K):
@@ -97,7 +94,7 @@
z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=offs_n[None, :] < N, other=0.0)
- # Dequant in bf16
+ # Dequant in bf16 (matches the reference's bf16 dequant)
w_lo = (w_lo - z) * s
w_hi = (w_hi - z) * s
@@ -113,19 +110,9 @@
mask=offs_m_mask[:, None], other=0.0,
)
- # Interleave lo/hi to make a single (BLOCK_K, BLOCK_N) weight tile.
- # tl.join pairs the last dim: result is (..., 2, BLOCK_N). Reshape
- # merges the new axis with the BLOCK_K_HALF axis, giving
- # w[2i, n] = w_lo[i, n], w[2i+1, n] = w_hi[i, n].
- w_full = tl.join(w_lo, w_hi)
- w_full = tl.reshape(w_full, (BLOCK_K, BLOCK_N))
-
- # Same for x: x[m, 2i] = x_even[m, i], x[m, 2i+1] = x_odd[m, i]
- x_full = tl.join(x_even, x_odd)
- x_full = tl.reshape(x_full, (BLOCK_M, BLOCK_K))
-
- # Single dot with full K
- acc += tl.dot(x_full, w_full)
+ # Two dots per block (bf16 inputs, fp32 accumulator)
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16),
{
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<5839 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness"
}Command running in background with ID: b802ei5xx. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b802ei5xx.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b802ei5xx",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b802ei5xx</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
PASS
</output>{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark current solution"
}Command running in background with ID: bvx0n6tw8. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bvx0n6tw8.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "bvx0n6tw8",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>bvx0n6tw8</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
shape=0 variant=solution tflops=1.212 gbps=322.268 ms=0.083
shape=0 solution_peak_fraction=0.1790
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T19:01:04.361615+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T19:01:10.249340+00:00 elapsed_s=5.888 ms=0.130208
shape=1 variant=solution tflops=24.739 gbps=213.407 ms=0.130
shape=1 solution_peak_fraction=0.1186
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T19:01:10.567792+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T19:01:16.948592+00:00 elapsed_s=6.381 ms=0.475904
shape=2 variant=solution tflops=54.149 gbps=73.812 ms=0.476
shape=2 solution_peak_fraction=0.0410
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T19:01:17.046166+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T19:01:22.643564+00:00 elapsed_s=5.597 ms=0.063168
shape=3 variant=solution tflops=0.531 gbps=141.358 ms=0.063
shape=3 solution_peak_fraction=0.0785
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T19:01:23.002560+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T19:01:28.722770+00:00 elapsed_s=5.720 ms=0.097120
shape=4 variant=solution tflops=19.348 gbps=327.275 m
... (88 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -67,38 +67,43 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m_mask = offs_m < M
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
GROUP_HALF: tl.constexpr = GROUP_SZ // 2
BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- # Modulo-based dequant indexing
- g_offset_packed = tl.arange(0, BLOCK_K_HALF) // GROUP_HALF
for k_blk in tl.range(0, K, BLOCK_K):
# Load packed weights: (BLOCK_K // 2, BLOCK_N) uint8
wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
wq = tl.load(
WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
- mask=offs_n[None, :] < N, other=0,
+ mask=offs_n[None, :] < N,
+ other=0,
)
- # Unpack to bf16
- w_lo = (wq & 0xF).to(tl.bfloat16)
- w_hi = (wq >> 4).to(tl.bfloat16)
+ # Unpack
+ w_lo = (wq & 0xF).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k even
+ w_hi = (wq >> 4).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k odd
- # Load scales/zeros per group, expanded to (BLOCK_K_HALF, BLOCK_N)
- g_idx = (k_blk // GROUP_SZ) + g_offset_packed
+ # Load scales/zeros for the groups covered by this K block.
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
- mask=offs_n[None, :] < N, other=0.0)
+ mask=offs_n[None, :] < N, other=0.0) # (n_groups_per_blk, BLOCK_N)
z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=offs_n[None, :] < N, other=0.0)
+
+ # Broadcast to (BLOCK_K_HALF, BLOCK_N) by repeating each group GROUP_HALF times.
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
# Dequant in bf16 (matches the reference's bf16 dequant)
w_lo = (w_lo - z) * s
w_hi = (w_hi - z) * s
- # Load x for even/odd K rows
+ # Load x for the even/odd K rows in this block.
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
x_offs_odd = x_offs_even + 1
x_even = tl.load(
@@ -110,7 +115,8 @@
mask=offs_m_mask[:, None], other=0.0,
)
- # Two dots per block (bf16 inputs, fp32 accumulator)
+ # Two dots per block. Default precision (bf16 input) matches the
+ # reference's cuBLAS bf16 GEMM.
acc += tl.dot(x_even, w_lo)
acc += tl.dot(x_odd, w_hi)
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<2005 chars \u2014 see diff>",
"new_string": "<2505 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Re-benchmark with broadcast+reshape"
}Command running in background with ID: b0kpfsauw. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b0kpfsauw.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b0kpfsauw",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>b0kpfsauw</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
shape=0 variant=solution tflops=1.551 gbps=412.529 ms=0.065
shape=0 solution_peak_fraction=0.2292
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T19:08:12.282082+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T19:09:08.764216+00:00 elapsed_s=56.482 ms=0.093056
shape=1 variant=solution tflops=34.616 gbps=298.608 ms=0.093
shape=1 solution_peak_fraction=0.1659
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T19:09:09.079553+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T19:09:20.526130+00:00 elapsed_s=11.447 ms=0.283648
shape=2 variant=solution tflops=90.851 gbps=123.841 ms=0.284
shape=2 solution_peak_fraction=0.0688
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T19:09:20.624145+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T19:09:31.019645+00:00 elapsed_s=10.395 ms=0.053824
shape=3 variant=solution tflops=0.623 gbps=165.898 ms=0.054
shape=3 solution_peak_fraction=0.0922
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T19:09:31.377240+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T19:09:41.908552+00:00 elapsed_s=10.531 ms=0.082160
shape=4 variant=solution tflops=22.871 gbps=386.
... (93 more chars){
"command": "find /workspace -name \"*.yaml\" -path \"*hardware*\" 2>/dev/null | head -5; find /workspace -name \"*.yaml\" -path \"*RTX_PRO*\" 2>/dev/null | head -5; find /workspace -name \"*.py\" -path \"*hardware*\" 2>/dev/null | head -5",
"description": "Find hardware spec files"
}/workspace/src/hardware/rtx_pro_6000.py
/workspace/src/hardware/__init__.py
/workspace/src/hardware/m4_max.py{
"file_path": "/workspace/src/hardware/rtx_pro_6000.py"
}1 """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2
3 Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4 will see 60-85% of peak on well-tuned code.
5 """
6 from dataclasses import dataclass
7
8
9 @dataclass(frozen=True)
10 class HardwareTarget:
11 name: str
12 sm: str
13 vram_gb: int
14 peak_bandwidth_gb_s: float # DRAM
15 peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16
17
18 RTX_PRO_6000 = HardwareTarget(
19 name="RTX PRO 6000 Blackwell Workstation",
20 sm="sm_120a",
21 vram_gb=96,
22 peak_bandwidth_gb_s=1800.0,
23 peak_tflops_dense={
24 "fp4": 800.0,
25 "nvfp4": 800.0,
26 "mxfp4": 800.0,
27 "fp6": 800.0,
28 "fp8": 400.0,
29 "bf16": 200.0,
30 "fp16": 200.0,
31 "tf32": 100.0,
32 "fp32": 12.0, # non-tensor-core SIMT fp32
33 "int8": 400.0,
34 "int4": 800.0,
35 },
36 )
37 {
"command": "find / -name \".triton_cache\" -type d 2>/dev/null | head -5; echo \"---\"; ls -la ~/.triton 2>/dev/null; echo \"---\"; ls -la /root/.triton 2>/dev/null",
"description": "Find triton cache"
}Exit code 2
---
total 116
drwxr-xr-x 3 1002 1002 4096 Jun 13 16:37 .
drwx------ 5 1002 1002 4096 Jun 13 16:37 ..
drwxr-xr-x 1058 1002 1002 106496 Jun 13 19:09 cache
---{
"command": "ls /root/.triton/cache | head -3; echo \"---\"; find /root/.triton/cache -name \"*.json\" | head -3 | xargs cat 2>/dev/null | head -50",
"description": "Look at triton cache"
}Command running in background with ID: bx443ycwg. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/bx443ycwg.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"command": "<493 chars \u2014 see below>",
"description": "Check autotune cache"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import solution
M, N, K = 1, 12288, 4096
m = solution.Model(M, N, K).cuda().eval()
# Trigger autotune
import reference
reference.M = M
reference.N = N
reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
m.load_state_dict(ref.state_dict())
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = m(x)
# Now check the autotune cache
print('autotune cache:', m.w4a16_gemm_kernel.cache)
"Command running in background with ID: b8wnnftkf. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/b8wnnftkf.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "b8wnnftkf",
"block": true,
"timeout": 300000
}<retrieval_status>success</retrieval_status>
<task_id>b8wnnftkf</task_id>
<task_type>local_bash</task_type>
<status>failed</status>
<exit_code>1</exit_code>
<output>
Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1968, in __getattr__
raise AttributeError(
AttributeError: 'Model' object has no attribute 'w4a16_gemm_kernel'
</output>{
"command": "<500 chars \u2014 see below>",
"description": "Check autotune cache correctly"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import solution
M, N, K = 1, 12288, 4096
m = solution.Model(M, N, K).cuda().eval()
# Trigger autotune
import reference
reference.M = M
reference.N = N
reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
m.load_state_dict(ref.state_dict())
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = m(x)
# Now check the autotune cache
print('autotune cache:', solution.w4a16_gemm_kernel.cache)
"autotune cache: {(1, 12288, 4096, 'torch.bfloat16', 'torch.uint8', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16'): <triton.runtime.autotuner.Config object at 0x780cbbe913d0>}{
"command": "<551 chars \u2014 see below>",
"description": "Check best autotune config"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import solution
M, N, K = 1, 12288, 4096
m = solution.Model(M, N, K).cuda().eval()
# Trigger autotune
import reference
reference.M = M
reference.N = N
reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
m.load_state_dict(ref.state_dict())
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = m(x)
# Now check the autotune cache
for k, v in solution.w4a16_gemm_kernel.cache.items():
print('Shape:', k[:3])
print('Best config:', v)
"Command running in background with ID: ba5j27wtt. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/ba5j27wtt.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "ba5j27wtt",
"block": true,
"timeout": 300000
}<retrieval_status>success</retrieval_status>
<task_id>ba5j27wtt</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
Shape: (1, 12288, 4096)
Best config: BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 256, GROUP_SZ: 128, num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
</output>--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/debug_best_config.py
@@ -0,0 +1,163 @@
+"""Test different configs for each shape."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import reference
+import triton
+import triton.language as tl
+
+# Simple kernel that doesn't autotune
+@triton.jit
+def w4a16_kernel_simple(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m[:, None] < M, other=0.0,
+ )
+
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
+
+
+def time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, x, w_q, scales, zeros, out, n_iters=20):
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ # Warmup
+ for _ in range(3):
+ w4a16_kernel_simple[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(n_iters):
+ w4a16_kernel_simple[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+ end.record()
+ torch.cuda.synchronize()
+ return start.elapsed_time(end) / n_iters
+
+
+# Test each shape
+shapes = [
+ (1, 12288, 4096), # 0
+ (32, 12288, 4096), # 1
+ (256, 12288, 4096), # 2
+ (1, 4096, 4096), # 3
+ (16, 14336, 4096), # 4
+]
+
+for shape_idx, (M, N, K) in enumerate(shapes):
+ reference.M = M
+ reference.N = N
+ reference.K = K
+ ref_model = reference.Model(M, N, K).cuda().eval()
+
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+ out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+
+ print(f"\n=== Shape {shape_idx}: M={M}, N={N}, K={K} ===")
+ configs = [
+ (16, 64, 128, 4, 2),
+ (16, 64, 256, 4, 2),
+ (16, 128, 128, 4, 2),
+ (16, 128, 256, 4, 3),
+ (16, 256, 128, 8, 2),
+ (32, 64, 128, 4, 2),
+ (32, 128, 128, 4, 2),
+ (32, 128, 256, 4, 2),
+ (64, 64, 128, 4, 2),
+ (64, 128, 128, 4, 2),
+ (64, 128, 256, 8, 2),
+ (128, 64, 128, 4, 2),
+ (128, 128, 128, 8, 2),
+ (128, 128, 256, 8, 2),
+ ]
+ for BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages in configs:
+ if BLOCK_M > M and M > 0:
+ # Skip if BLOCK_M > M (and M > 0) - this wastes too much
+ # But for M=1, BLOCK_M=16 wastes 15 rows
+ continue
+ try:
+ t = time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages,
+ x, ref_model.w_q, ref_model.scales, ref_model.zeros, out)
+ # Compute throughput
+ bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
+ gbps = bytes_moved / (t * 1e-3) / 1e9
+ print(f" BM={BLOCK_M:3d} BN={BLOCK_N:3d} BK={BLOCK_K:3d} W={num_warps} S={num_stages}: t={t*1000:.2f}us, {gbps:.0f} GB/s")
+ except Exception as e:
+ print
... (110 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/debug_best_config.py",
"content": "<5854 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/debug_best_config.py (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python debug_best_config.py 2>&1 | tail -80",
"description": "Find best configs for each shape"
}=== Shape 0: M=1, N=12288, K=4096 ===
=== Shape 1: M=32, N=12288, K=4096 ===
BM= 16 BN= 64 BK=128 W=4 S=2: t=68.35us, 407 GB/s
BM= 16 BN= 64 BK=256 W=4 S=2: t=68.22us, 407 GB/s
BM= 16 BN=128 BK=128 W=4 S=2: t=61.91us, 449 GB/s
BM= 16 BN=128 BK=256 W=4 S=3: t=65.78us, 422 GB/s
BM= 16 BN=256 BK=128 W=8 S=2: t=65.87us, 422 GB/s
BM= 32 BN= 64 BK=128 W=4 S=2: t=132.14us, 210 GB/s
BM= 32 BN=128 BK=128 W=4 S=2: t=103.09us, 270 GB/s
BM= 32 BN=128 BK=256 W=4 S=2: t=85.71us, 324 GB/s
=== Shape 2: M=256, N=12288, K=4096 ===
BM= 16 BN= 64 BK=128 W=4 S=2: t=340.94us, 103 GB/s
BM= 16 BN= 64 BK=256 W=4 S=2: t=344.21us, 102 GB/s
BM= 16 BN=128 BK=128 W=4 S=2: t=249.28us, 141 GB/s
BM= 16 BN=128 BK=256 W=4 S=3: t=277.96us, 126 GB/s
BM= 16 BN=256 BK=128 W=8 S=2: t=317.56us, 111 GB/s
BM= 32 BN= 64 BK=128 W=4 S=2: t=553.66us, 63 GB/s
BM= 32 BN=128 BK=128 W=4 S=2: t=355.55us, 99 GB/s
BM= 32 BN=128 BK=256 W=4 S=2: t=364.76us, 96 GB/s
BM= 64 BN= 64 BK=128 W=4 S=2: t=648.15us, 54 GB/s
BM= 64 BN=128 BK=128 W=4 S=2: t=415.33us, 85 GB/s
BM= 64 BN=128 BK=256 W=8 S=2: t=437.41us, 80 GB/s
BM=128 BN= 64 BK=128 W=4 S=2: t=1067.65us, 33 GB/s
BM=128 BN=128 BK=128 W=8 S=2: t=1056.58us, 33 GB/s
BM=128 BN=128 BK=256 W=8 S=2: t=3324.98us, 11 GB/s
=== Shape 3: M=1, N=4096, K=4096 ===
=== Shape 4: M=16, N=14336, K=4096 ===
BM= 16 BN= 64 BK=128 W=4 S=2: t=57.41us, 554 GB/s
BM= 16 BN= 64 BK=256 W=4 S=2: t=47.08us, 675 GB/s
BM= 16 BN=128 BK=128 W=4 S=2: t=49.46us, 64
... (110 more chars)--- a//workspace/problems/07_w4a16_gemm/debug_best_config.py
+++ b//workspace/problems/07_w4a16_gemm/debug_best_config.py
@@ -148,10 +148,6 @@
(128, 128, 256, 8, 2),
]
for BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages in configs:
- if BLOCK_M > M and M > 0:
- # Skip if BLOCK_M > M (and M > 0) - this wastes too much
- # But for M=1, BLOCK_M=16 wastes 15 rows
- continue
try:
t = time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages,
x, ref_model.w_q, ref_model.scales, ref_model.zeros, out)
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/debug_best_config.py",
"old_string": "<889 chars \u2014 see diff>",
"new_string": "<712 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/debug_best_config.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python debug_best_config.py 2>&1 | tail -80",
"description": "Find best configs for each shape - no filter"
}=== Shape 0: M=1, N=12288, K=4096 ===
BM= 16 BN= 64 BK=128 W=4 S=2: t=31.11us, 861 GB/s
BM= 16 BN= 64 BK=256 W=4 S=2: t=31.21us, 858 GB/s
BM= 16 BN=128 BK=128 W=4 S=2: t=39.38us, 680 GB/s
BM= 16 BN=128 BK=256 W=4 S=3: t=30.72us, 871 GB/s
BM= 16 BN=256 BK=128 W=8 S=2: t=56.40us, 475 GB/s
BM= 32 BN= 64 BK=128 W=4 S=2: t=42.31us, 633 GB/s
BM= 32 BN=128 BK=128 W=4 S=2: t=53.29us, 502 GB/s
BM= 32 BN=128 BK=256 W=4 S=2: t=40.78us, 656 GB/s
BM= 64 BN= 64 BK=128 W=4 S=2: t=92.02us, 291 GB/s
BM= 64 BN=128 BK=128 W=4 S=2: t=68.30us, 392 GB/s
BM= 64 BN=128 BK=256 W=8 S=2: t=58.19us, 460 GB/s
BM=128 BN= 64 BK=128 W=4 S=2: t=122.51us, 219 GB/s
BM=128 BN=128 BK=128 W=8 S=2: t=361.93us, 74 GB/s
BM=128 BN=128 BK=256 W=8 S=2: t=767.56us, 35 GB/s
=== Shape 1: M=32, N=12288, K=4096 ===
BM= 16 BN= 64 BK=128 W=4 S=2: t=68.75us, 404 GB/s
BM= 16 BN= 64 BK=256 W=4 S=2: t=68.08us, 408 GB/s
BM= 16 BN=128 BK=128 W=4 S=2: t=61.97us, 448 GB/s
BM= 16 BN=128 BK=256 W=4 S=3: t=65.55us, 424 GB/s
BM= 16 BN=256 BK=128 W=8 S=2: t=65.64us, 423 GB/s
BM= 32 BN= 64 BK=128 W=4 S=2: t=131.95us, 211 GB/s
BM= 32 BN=128 BK=128 W=4 S=2: t=103.03us, 270 GB/s
BM= 32 BN=128 BK=256 W=4 S=2: t=85.51us, 325 GB/s
BM= 64 BN= 64 BK=128 W=4 S=2: t=183.39us, 152 GB/s
BM= 64 BN=128 BK=128 W=4 S=2: t=113.41us, 245 GB/s
BM= 64 BN=128 BK=256 W=8 S=2: t=101.51us, 274 GB/s
BM=128 BN= 64 BK=128 W=4 S=2: t=221.49us, 125 GB/s
BM=128 BN=128 BK=128 W=8 S=2: t=404.30us, 69 GB/s
BM=128 BN
... (2353 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -12,8 +12,8 @@
group_size=128 K rows become two 64-wide dots that both read the same x
row span but different nibbles of the packed weight byte. This matches
the bit-level rounding of the reference's cuBLAS bf16 GEMM.
- - Modulo-based dequant indexing (no broadcast/reshape of scales/zeros).
- - Autotune over a focused set of configs.
+ - Per-M config selection: small M wants small BLOCK_N for memory-bound
+ parallelism; large M wants larger BLOCK_N for compute reuse.
"""
from __future__ import annotations
@@ -33,19 +33,23 @@
# Triton kernel
# ---------------------------------------------------------------------------
-_GEMM_CONFIGS = [
- triton.Config({"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_SZ": 128},
- num_warps=nw, num_stages=ns)
- for bm in (16, 32, 64, 128)
- for bn in (64, 128, 256)
- for bk in (128, 256)
- for nw in (4, 8)
- for ns in (2, 3)
- if bm * bn <= 32768
+# Configs to autotune over. We keep the list focused.
+_CONFIGS = [
+ # M=1 / decode
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ # M=16..32 / small prefill
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ # M=256 / large prefill (compute-bound)
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
]
-@triton.autotune(configs=_GEMM_CONFIGS, key=["M", "N", "K"])
+@triton.autotune(configs=_CONFIGS, key=["M", "N", "K"])
@triton.jit
def w4a16_gemm_kernel(
X, WQ, S, Z, OUT,
@@ -74,36 +78,29 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_blk in tl.range(0, K, BLOCK_K):
- # Load packed weights: (BLOCK_K // 2, BLOCK_N) uint8
wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
wq = tl.load(
WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
- mask=offs_n[None, :] < N,
- other=0,
+ mask=offs_n[None, :] < N, other=0,
)
- # Unpack
- w_lo = (wq & 0xF).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k even
- w_hi = (wq >> 4).to(tl.bfloat16) # (BLOCK_K_HALF, BLOCK_N) k odd
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
- # Load scales/zeros for the groups covered by this K block.
g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
- mask=offs_n[None, :] < N, other=0.0) # (n_groups_per_blk, BLOCK_N)
+ mask=offs_n[None, :] < N, other=0.0)
z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=offs_n[None, :] < N, other=0.0)
- # Broadcast to (BLOCK_K_HALF, BLOCK_N) by repeating each group GROUP_HALF times.
s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
- # Dequant in bf16 (matches the reference's bf16 dequant)
w_lo = (w_lo - z) * s
w_hi = (w_hi - z) * s
- # Load x for the even/odd K rows in this block.
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
x_offs_odd = x_offs_even + 1
x_even = tl.load(
@@ -115,8 +112,6 @@
mask=offs_m_mask[:, None], other=0.0,
)
- # Two dots per block. Default precision (bf16 input) matches the
- # reference's cuBLAS bf16 GEMM.
acc += tl.dot(x_even, w_lo)
acc += tl.dot(x_odd, w_hi)
{
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<6530 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness with focused autotune"
}PASS{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark with focused autotune"
}shape=0 variant=solution tflops=1.560 gbps=414.778 ms=0.065
shape=0 solution_peak_fraction=0.2304
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:08:55.399380+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:08:56.296291+00:00 elapsed_s=0.897 ms=0.099056
shape=1 variant=solution tflops=32.519 gbps=280.521 ms=0.099
shape=1 solution_peak_fraction=0.1558
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:08:56.628725+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:08:57.608255+00:00 elapsed_s=0.980 ms=0.290480
shape=2 variant=solution tflops=88.715 gbps=120.928 ms=0.290
shape=2 solution_peak_fraction=0.0672
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:08:57.708711+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:08:58.581193+00:00 elapsed_s=0.872 ms=0.057872
shape=3 variant=solution tflops=0.580 gbps=154.294 ms=0.058
shape=3 solution_peak_fraction=0.0857
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:08:58.950065+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:08:59.825938+00:00 elapsed_s=0.876 ms=0.085456
shape=4 variant=solution tflops=21.988 gbps=371.945 ms=0.085
shape=4 solution_peak_fraction=0.2066
peak_fraction: 0.1337
RESULT: OK{
"command": "<919 chars \u2014 see below>",
"description": "Check autotune picks after clearing cache"
}rm -rf /root/.triton/cache; timeout 600 uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import solution
import reference
import shapes
for shape in shapes.SHAPES:
M, N, K = shape['M'], shape['N'], shape['K']
reference.M = M
reference.N = N
reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
m = solution.Model(M, N, K).cuda().eval()
m.load_state_dict(ref.state_dict())
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = m(x)
torch.cuda.synchronize()
# Get the picked config
key = (M, N, K, 'torch.bfloat16', 'torch.uint8', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')
cfg = solution.w4a16_gemm_kernel.cache.get(key, None)
print(f'Shape M={M} N={N} K={K}: {cfg.kwargs if cfg else \"NO CONFIG\"}, W={cfg.num_warps if cfg else \"?\"}, S={cfg.num_stages if cfg else \"?\"}')" 2>&1 | tail -20rm: cannot remove '/root/.triton/cache': Permission denied
Shape M=1 N=12288 K=4096: {'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256, 'GROUP_SZ': 128}, W=4, S=3
Shape M=32 N=12288 K=4096: {'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256, 'GROUP_SZ': 128}, W=4, S=3
Shape M=256 N=12288 K=4096: {'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128, 'GROUP_SZ': 128}, W=4, S=2
Shape M=1 N=4096 K=4096: {'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256, 'GROUP_SZ': 128}, W=4, S=3
Shape M=16 N=14336 K=4096: {'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256, 'GROUP_SZ': 128}, W=4, S=3--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -12,8 +12,9 @@
group_size=128 K rows become two 64-wide dots that both read the same x
row span but different nibbles of the packed weight byte. This matches
the bit-level rounding of the reference's cuBLAS bf16 GEMM.
- - Per-M config selection: small M wants small BLOCK_N for memory-bound
- parallelism; large M wants larger BLOCK_N for compute reuse.
+ - Per-shape hardcoded block sizes (we found autotune's single-run timing
+ is too noisy to reliably pick the right config; we benchmarked each shape
+ and picked the best config manually).
"""
from __future__ import annotations
@@ -33,23 +34,30 @@
# Triton kernel
# ---------------------------------------------------------------------------
-# Configs to autotune over. We keep the list focused.
-_CONFIGS = [
- # M=1 / decode
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
- # M=16..32 / small prefill
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- # M=256 / large prefill (compute-bound)
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
-]
+# Block sizes picked per shape from local benchmarks on this GPU.
+# (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)
+_PICKED_CONFIGS = {
+ (1, 12288, 4096): (16, 64, 128, 4, 2),
+ (32, 12288, 4096): (16, 128, 128, 4, 2),
+ (256, 12288, 4096):(16, 128, 128, 4, 2),
+ (1, 4096, 4096): (16, 64, 256, 4, 2),
+ (16, 14336, 4096): (16, 128, 256, 4, 3),
+}
-@triton.autotune(configs=_CONFIGS, key=["M", "N", "K"])
+def _pick_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
+ """Pick (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) for the shape."""
+ key = (M, N, K)
+ if key in _PICKED_CONFIGS:
+ return _PICKED_CONFIGS[key]
+ # Fallback heuristic
+ if M == 1:
+ return (16, 64, 128, 4, 2)
+ if M <= 32:
+ return (16, 128, 128, 4, 2)
+ return (16, 128, 128, 4, 2)
+
+
@triton.jit
def w4a16_gemm_kernel(
X, WQ, S, Z, OUT,
@@ -149,10 +157,8 @@
x = x.contiguous()
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- triton.cdiv(N, meta["BLOCK_N"]),
- )
+ BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _pick_config(M, N, K)
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
M, N, K,
@@ -161,6 +167,8 @@
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
)
return out
{
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<6392 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness with hardcoded configs"
}PASS{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark with hardcoded configs"
}shape=0 variant=solution tflops=1.423 gbps=378.556 ms=0.071
shape=0 solution_peak_fraction=0.2103
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:10:53.857768+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:10:53.868790+00:00 elapsed_s=0.011 ms=0.110368
shape=1 variant=solution tflops=29.186 gbps=251.769 ms=0.110
shape=1 solution_peak_fraction=0.1399
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:10:54.194552+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:10:54.215106+00:00 elapsed_s=0.021 ms=0.283552
shape=2 variant=solution tflops=90.882 gbps=123.883 ms=0.284
shape=2 solution_peak_fraction=0.0688
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:10:54.316708+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:10:54.325192+00:00 elapsed_s=0.008 ms=0.059584
shape=3 variant=solution tflops=0.563 gbps=149.860 ms=0.060
shape=3 solution_peak_fraction=0.0833
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:10:54.693533+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:10:54.703094+00:00 elapsed_s=0.010 ms=0.080896
shape=4 variant=solution tflops=23.228 gbps=392.911 ms=0.081
shape=4 solution_peak_fraction=0.2183
peak_fraction: 0.1298
RESULT: OK{
"command": "find /workspace -name \"timing.py\" -path \"*eval*\" 2>/dev/null | head; find /workspace -name \"*.py\" -path \"*eval*\" 2>/dev/null | head -10",
"description": "Find timing module"
}/workspace/src/eval/timing.py
/workspace/src/eval/__init__.py
/workspace/src/eval/correctness.py
/workspace/src/eval/roofline.py
/workspace/src/eval/report.py
/workspace/src/eval/timing.py
/workspace/src/eval/shapes.py
/workspace/src/eval/numeric_stress.py
/workspace/.venv/lib/python3.11/site-packages/sympy/core/evalf.py
/workspace/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py
/workspace/.venv/lib/python3.11/site-packages/sympy/core/tests/test_eval.py{
"file_path": "/workspace/src/eval/timing.py"
}1 """Centralized GPU timing for benchmark.py files.
2
3 All problems' benchmark.py should call `time_fn` from here so we have one
4 implementation of warmup, L2 flush, and CUDA event capture to maintain.
5
6 Methodology:
7 - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8 torch.compile reduce-overhead CUDA-graph capture.
9 - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10 (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11 we measure HBM-load bandwidth, not L2-cached bandwidth).
12 - GPU timing via cuda.Event with synchronize() AFTER record() but
13 BEFORE elapsed_time() — the canonical NVIDIA pattern.
14 - Reported value is the median of `iters` trials, robust to outliers.
15
16 Notes / known biases not addressed here:
17 - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18 launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19 On small shapes where launch overhead matters, this gives compile an
20 artificial advantage. We accept this as the cost of using the agreed
21 "compiled" baseline policy.
22 - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23 - Median over a fairly small number of trials (default 30) is fine for
24 headline numbers but won't catch bimodal distributions. Use --extra
25 instrumentation if you ever care.
26 """
27 from __future__ import annotations
28
29 import os
30 import re
31
... (3182 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -10,11 +10,11 @@
Key optimizations:
- Two-dot pattern: process even/odd K rows with two separate dots. Each
group_size=128 K rows become two 64-wide dots that both read the same x
- row span but different nibbles of the packed weight byte. This matches
- the bit-level rounding of the reference's cuBLAS bf16 GEMM.
- - Per-shape hardcoded block sizes (we found autotune's single-run timing
- is too noisy to reliably pick the right config; we benchmarked each shape
- and picked the best config manually).
+ row span but different nibbles of the packed weight byte.
+ - Per-shape hardcoded block sizes, picked from local benchmarks.
+ - M=1 kernel uses a smaller BLOCK_M=1 path to avoid wasted M tiles while
+ still feeding the tensor core's M=16 minimum (Triton handles the padding).
+ - For very small N (shape 3), we split-K to raise SM occupancy.
"""
from __future__ import annotations
@@ -31,32 +31,8 @@
# ---------------------------------------------------------------------------
-# Triton kernel
-# ---------------------------------------------------------------------------
-
-# Block sizes picked per shape from local benchmarks on this GPU.
-# (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)
-_PICKED_CONFIGS = {
- (1, 12288, 4096): (16, 64, 128, 4, 2),
- (32, 12288, 4096): (16, 128, 128, 4, 2),
- (256, 12288, 4096):(16, 128, 128, 4, 2),
- (1, 4096, 4096): (16, 64, 256, 4, 2),
- (16, 14336, 4096): (16, 128, 256, 4, 3),
-}
-
-
-def _pick_config(M: int, N: int, K: int) -> tuple[int, int, int, int, int]:
- """Pick (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) for the shape."""
- key = (M, N, K)
- if key in _PICKED_CONFIGS:
- return _PICKED_CONFIGS[key]
- # Fallback heuristic
- if M == 1:
- return (16, 64, 128, 4, 2)
- if M <= 32:
- return (16, 128, 128, 4, 2)
- return (16, 128, 128, 4, 2)
-
+# Triton kernel — main path
+# ---------------------------------------------------------------------------
@triton.jit
def w4a16_gemm_kernel(
@@ -129,8 +105,90 @@
# ---------------------------------------------------------------------------
+# Split-K kernel for M=1, small N
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def w4a16_splitk_kernel(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ """Split-K version for very small N (e.g., N=4096, M=1)."""
+ pid_n = tl.program_id(0)
+ pid_k = tl.program_id(1) # split-K index
+
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ # Each split processes K / SPLIT_K elements of K
+ K_per_split: tl.constexpr = K // SPLIT_K
+ k_start = pid_k * K_per_split
+ k_end = k_start + K_per_split
+
+ for k_blk in tl.range(k_start, k_end, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(X + x_offs_even * stride_xk)
+ x_odd = tl.load(X + x_offs_odd * stride_xk)
+
+ acc += tl.sum(x_even[:, None] * w_lo, axis=0)
+ acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
+
+ # Atomic add (since multiple K splits write to the same output)
+ out_ptrs = OUT + offs_n * stride_on
+ tl.atomic_add(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
+
+
+# ---------------------------------------------------------------------------
# Module
# ---------------------------------------------------------------------------
+
+
+# Picked configs from local benchmarks.
+_PICKED_CONFIGS = {
+ (1, 12288, 4096): (16, 64, 128, 4, 2),
+ (32, 12288, 4096): (16, 128, 128, 4, 2),
+ (256, 12288, 4096):(16, 128, 128, 4, 2),
+ (1, 4096, 4096): (16, 64, 128, 4, 2), # use split-K path
+ (16, 14336, 4096): (16, 128, 256, 4, 3),
+}
class Model(nn.Module):
@@ -155,21 +213,42 @@
assert x.dtype == torch.bfloat16
x = x.contiguous()
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
-
- BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _pick_config(M, N, K)
- grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
- w4a16_gemm_kernel[grid](
- x, self.w_q, self.scales, self.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0), self.zeros.s
... (1938 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<9501 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify with split-K path"
}FAIL: shape 3 {'M': 1, 'N': 4096, 'K': 4096} seed 123 case large_activation: tolerance exceeded: max_abs_diff=2 max_rel_diff=6.42029 bad=3/4096 worst_idx=(0, 1195) (atol=1.0, rtol=0.05)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -10,11 +10,9 @@
Key optimizations:
- Two-dot pattern: process even/odd K rows with two separate dots. Each
group_size=128 K rows become two 64-wide dots that both read the same x
- row span but different nibbles of the packed weight byte.
+ row span but different nibbles of the packed weight byte. This matches
+ the bit-level rounding of the reference's cuBLAS bf16 GEMM.
- Per-shape hardcoded block sizes, picked from local benchmarks.
- - M=1 kernel uses a smaller BLOCK_M=1 path to avoid wasted M tiles while
- still feeding the tensor core's M=16 minimum (Triton handles the padding).
- - For very small N (shape 3), we split-K to raise SM occupancy.
"""
from __future__ import annotations
@@ -31,7 +29,7 @@
# ---------------------------------------------------------------------------
-# Triton kernel — main path
+# Triton kernel
# ---------------------------------------------------------------------------
@triton.jit
@@ -105,88 +103,16 @@
# ---------------------------------------------------------------------------
-# Split-K kernel for M=1, small N
-# ---------------------------------------------------------------------------
-
-@triton.jit
-def w4a16_splitk_kernel(
- X, WQ, S, Z, OUT,
- M, N, K,
- stride_xm, stride_xk,
- stride_wk, stride_wn,
- stride_sg, stride_sn,
- stride_zg, stride_zn,
- stride_om, stride_on,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- SPLIT_K: tl.constexpr,
- GROUP_SZ: tl.constexpr,
-):
- """Split-K version for very small N (e.g., N=4096, M=1)."""
- pid_n = tl.program_id(0)
- pid_k = tl.program_id(1) # split-K index
-
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- GROUP_HALF: tl.constexpr = GROUP_SZ // 2
- BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
- n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
-
- acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
-
- # Each split processes K / SPLIT_K elements of K
- K_per_split: tl.constexpr = K // SPLIT_K
- k_start = pid_k * K_per_split
- k_end = k_start + K_per_split
-
- for k_blk in tl.range(k_start, k_end, BLOCK_K):
- wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
- wq = tl.load(
- WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
- mask=offs_n[None, :] < N, other=0,
- )
-
- w_lo = (wq & 0xF).to(tl.bfloat16)
- w_hi = (wq >> 4).to(tl.bfloat16)
-
- g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
- s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
- mask=offs_n[None, :] < N, other=0.0)
- z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
- mask=offs_n[None, :] < N, other=0.0)
-
- s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
- s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
- z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
- z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
-
- w_lo = (w_lo - z) * s
- w_hi = (w_hi - z) * s
-
- x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
- x_offs_odd = x_offs_even + 1
- x_even = tl.load(X + x_offs_even * stride_xk)
- x_odd = tl.load(X + x_offs_odd * stride_xk)
-
- acc += tl.sum(x_even[:, None] * w_lo, axis=0)
- acc += tl.sum(x_odd[:, None] * w_hi, axis=0)
-
- # Atomic add (since multiple K splits write to the same output)
- out_ptrs = OUT + offs_n * stride_on
- tl.atomic_add(out_ptrs, acc.to(tl.bfloat16), mask=offs_n < N)
-
-
-# ---------------------------------------------------------------------------
# Module
# ---------------------------------------------------------------------------
-# Picked configs from local benchmarks.
+# Picked configs from local benchmarks (per-shape).
_PICKED_CONFIGS = {
(1, 12288, 4096): (16, 64, 128, 4, 2),
(32, 12288, 4096): (16, 128, 128, 4, 2),
(256, 12288, 4096):(16, 128, 128, 4, 2),
- (1, 4096, 4096): (16, 64, 128, 4, 2), # use split-K path
+ (1, 4096, 4096): (16, 64, 128, 4, 2),
(16, 14336, 4096): (16, 128, 256, 4, 3),
}
@@ -213,42 +139,23 @@
assert x.dtype == torch.bfloat16
x = x.contiguous()
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # For small N with M=1, use split-K to raise occupancy
- if M == 1 and N <= 8192:
- out = torch.zeros((M, N), dtype=torch.bfloat16, device=x.device)
- BLOCK_N = 64
- BLOCK_K = 128
- SPLIT_K = 4
- grid = (triton.cdiv(N, BLOCK_N), SPLIT_K)
- w4a16_splitk_kernel[grid](
- x, self.w_q, self.scales, self.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0), self.zeros.stride(1),
- out.stride(0), out.stride(1),
- BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, SPLIT_K=SPLIT_K, GROUP_SZ=128,
- num_warps=4, num_stages=2,
- )
- else:
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _PICKED_CONFIGS.get(
- (M, N, K), (16, 128, 128, 4, 2)
- )
- grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
- w4a16_gemm_kernel[grid](
- x, self.w_q, self.scales, self.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0),
... (965 more chars){
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<5845 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness"
}PASS{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark current solution"
}shape=0 variant=solution tflops=1.416 gbps=376.596 ms=0.071
shape=0 solution_peak_fraction=0.2092
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:15:06.584187+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:15:06.595308+00:00 elapsed_s=0.011 ms=0.110496
shape=1 variant=solution tflops=29.152 gbps=251.478 ms=0.110
shape=1 solution_peak_fraction=0.1397
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:15:06.918593+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:15:06.939480+00:00 elapsed_s=0.021 ms=0.285792
shape=2 variant=solution tflops=90.170 gbps=122.912 ms=0.286
shape=2 solution_peak_fraction=0.0683
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:15:07.040514+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:15:07.048304+00:00 elapsed_s=0.008 ms=0.061152
shape=3 variant=solution tflops=0.549 gbps=146.018 ms=0.061
shape=3 solution_peak_fraction=0.0811
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:15:07.414579+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:15:07.423812+00:00 elapsed_s=0.009 ms=0.080848
shape=4 variant=solution tflops=23.242 gbps=393.145 ms=0.081
shape=4 solution_peak_fraction=0.2184
peak_fraction: 0.1287
RESULT: OK--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -12,7 +12,7 @@
group_size=128 K rows become two 64-wide dots that both read the same x
row span but different nibbles of the packed weight byte. This matches
the bit-level rounding of the reference's cuBLAS bf16 GEMM.
- - Per-shape hardcoded block sizes, picked from local benchmarks.
+ - Triton autotune over a focused set of configs to find the best per shape.
"""
from __future__ import annotations
@@ -32,6 +32,29 @@
# Triton kernel
# ---------------------------------------------------------------------------
+_CONFIGS = [
+ # M=1 / decode variants — small BN, more programs
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ # M=16-32 / small prefill
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ # M=256 / large prefill
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2),
+]
+
+
+@triton.autotune(configs=_CONFIGS, key=["M", "N", "K"])
@triton.jit
def w4a16_gemm_kernel(
X, WQ, S, Z, OUT,
@@ -107,16 +130,6 @@
# ---------------------------------------------------------------------------
-# Picked configs from local benchmarks (per-shape).
-_PICKED_CONFIGS = {
- (1, 12288, 4096): (16, 64, 128, 4, 2),
- (32, 12288, 4096): (16, 128, 128, 4, 2),
- (256, 12288, 4096):(16, 128, 128, 4, 2),
- (1, 4096, 4096): (16, 64, 128, 4, 2),
- (16, 14336, 4096): (16, 128, 256, 4, 3),
-}
-
-
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
@@ -141,10 +154,10 @@
x = x.contiguous()
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = _PICKED_CONFIGS.get(
- (M, N, K), (16, 128, 128, 4, 2)
+ grid = lambda meta: (
+ triton.cdiv(M, meta["BLOCK_M"]),
+ triton.cdiv(N, meta["BLOCK_N"]),
)
- grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
M, N, K,
@@ -153,8 +166,6 @@
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
- num_warps=num_warps, num_stages=num_stages,
)
return out
{
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"content": "<7213 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness"
}PASS{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark with wider autotune"
}shape=0 variant=solution tflops=1.576 gbps=419.142 ms=0.064
shape=0 solution_peak_fraction=0.2329
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:17:50.256853+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:17:51.933440+00:00 elapsed_s=1.677 ms=0.099440
shape=1 variant=solution tflops=32.394 gbps=279.437 ms=0.099
shape=1 solution_peak_fraction=0.1552
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:17:52.267641+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:17:54.096995+00:00 elapsed_s=1.829 ms=0.282224
shape=2 variant=solution tflops=91.310 gbps=124.466 ms=0.282
shape=2 solution_peak_fraction=0.0691
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:17:54.199085+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:17:55.818701+00:00 elapsed_s=1.620 ms=0.055712
shape=3 variant=solution tflops=0.602 gbps=160.276 ms=0.056
shape=3 solution_peak_fraction=0.0890
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:17:56.189301+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:17:57.829709+00:00 elapsed_s=1.640 ms=0.083920
shape=4 variant=solution tflops=22.391 gbps=378.753 ms=0.084
shape=4 solution_peak_fraction=0.2104
peak_fraction: 0.1362
RESULT: OK{
"command": "<3481 chars \u2014 see below>",
"description": "Test BLOCK_M=1"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import triton
import triton.language as tl
@triton.jit
def kernel_test(
X, WQ, S, Z, OUT,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SZ: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m_mask = offs_m < M
n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
GROUP_HALF: tl.constexpr = GROUP_SZ // 2
BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_blk in tl.range(0, K, BLOCK_K):
wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
wq = tl.load(
WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
mask=offs_n[None, :] < N, other=0,
)
w_lo = (wq & 0xF).to(tl.bfloat16)
w_hi = (wq >> 4).to(tl.bfloat16)
g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
mask=offs_n[None, :] < N, other=0.0)
z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=offs_n[None, :] < N, other=0.0)
s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
w_lo = (w_lo - z) * s
w_hi = (w_hi - z) * s
x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
x_offs_odd = x_offs_even + 1
x_even = tl.load(
X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
mask=offs_m_mask[:, None], other=0.0,
)
x_odd = tl.load(
X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
mask=offs_m_mask[:, None], other=0.0,
)
acc += tl.dot(x_even, w_lo)
acc += tl.dot(x_odd, w_hi)
out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, acc.to(tl.bfloat16),
mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
import reference
M, N, K = 1, 4096, 4096
reference.M = M
reference.N = N
reference.K = K
ref_model = reference.Model(M, N, K).cuda().eval()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
# Try BLOCK_M=1
BLOCK_M, BLOCK_N, BLOCK_K = 1, 64, 128
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
kernel_test[grid](
x, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
M, N, K,
x.stride(0), x.stride(1),
ref_model.w_q.stride(0), ref_model.w_q.stride(1),
ref_model.scales.stride(0), ref_model.scales.stride(1),
ref_model.zeros.stride(0), ref_model.zeros.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
num_warps=4, num_stages=2,
)
ref_out = ref_model(x)
print(f'BLOCK_M=1, BLOCK_N=64: max diff = {(ref_out - out).abs().max().item():.4f}')
" 2>&1 | tail -10return decorator(fn)
^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/test_bm1.py
@@ -0,0 +1,113 @@
+"""Test BLOCK_M=1."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import triton
+import triton.language as tl
+
+@triton.jit
+def kernel_test(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_m_mask = offs_m < M
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
+
+
+import reference
+M, N, K = 1, 4096, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+
+# Try BLOCK_M=1
+for BLOCK_M, BLOCK_N, BLOCK_K, nw, ns in [
+ (1, 64, 128, 4, 2),
+ (1, 128, 128, 4, 2),
+ (1, 64, 256, 4, 2),
+ (1, 32, 128, 4, 2),
+]:
+ try:
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ kernel_test[grid](
+ x, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=nw, num_stages=ns,
+ )
+ ref_out = ref_model(x)
+ diff = (ref_out - out).abs().max().item()
+ print(f'BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K}, W={nw}, S={ns}: max diff = {diff:.4f}')
+ except Exception as e:
+ print(f'BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K}, W={nw}, S={ns}: ERROR {e}')
{
"file_path": "/workspace/problems/07_w4a16_gemm/test_bm1.py",
"content": "<3892 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/test_bm1.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python test_bm1.py 2>&1 | tail -10",
"description": "Test BLOCK_M=1"
}BLOCK_M=1, BLOCK_N=64, BLOCK_K=128, W=4, S=2: max diff = 0.0156
BLOCK_M=1, BLOCK_N=128, BLOCK_K=128, W=4, S=2: max diff = 0.0156
BLOCK_M=1, BLOCK_N=64, BLOCK_K=256, W=4, S=2: max diff = 0.0156
BLOCK_M=1, BLOCK_N=32, BLOCK_K=128, W=4, S=2: max diff = 0.0156--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/test_bm1_time.py
@@ -0,0 +1,161 @@
+"""Time BLOCK_M=1 for M=1 cases."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import triton
+import triton.language as tl
+import statistics
+import reference
+
+@triton.jit
+def kernel_test(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_m_mask = offs_m < M
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+
+ acc += tl.dot(x_even, w_lo)
+ acc += tl.dot(x_odd, w_hi)
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
+
+
+def time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, x, w_q, scales, zeros, out, n_iters=30):
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+ # Warmup
+ for _ in range(10):
+ kernel_test[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+ torch.cuda.synchronize()
+
+ # L2 flush
+ l2 = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
+
+ times = []
+ for _ in range(n_iters):
+ l2.zero_()
+ torch.cuda.synchronize()
+ s = torch.cuda.Event(enable_timing=True)
+ e = torch.cuda.Event(enable_timing=True)
+ s.record()
+ kernel_test[grid](
+ x, w_q, scales, zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ w_q.stride(0), w_q.stride(1),
+ scales.stride(0), scales.stride(1),
+ zeros.stride(0), zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+ e.record()
+ torch.cuda.synchronize()
+ times.append(s.elapsed_time(e))
+ return statistics.median(times)
+
+
+# Test each shape
+shapes = [
+ (1, 12288, 4096),
+ (1, 4096, 4096),
+]
+
+for M, N, K in shapes:
+ reference.M = M
+ reference.N = N
+ reference.K = K
+ ref_model = reference.Model(M, N, K).cuda().eval()
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+ out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+
+ print(f"\n=== M={M}, N={N}, K={K} ===")
+ for BLOCK_M, BLOCK_N, BLOCK_K, nw, ns in [
+ (1, 32, 128, 4, 2),
+ (1, 32, 256, 4, 2),
+ (1, 64, 128, 4, 2),
+ (1, 64, 256, 4, 2),
+ (1, 128, 128, 4, 2),
+ (1, 128, 256, 4, 2),
+ (1, 64, 128, 4, 3),
+ (1, 64, 128, 4, 4),
+ (1, 64, 256, 4, 3),
+ (1, 32, 128, 4, 3),
+ (1, 32, 256, 4, 3),
+ (1, 64, 512, 4, 2),
+ ]:
+ try:
+ t = time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, nw, ns,
+ x, ref_model.w_q, ref_model.scales, ref_model.zeros, out)
+ bytes_moved = M*K*2 + (K//2)*N + (K//128)*N*2 + (K//128)*N*2 + M*N*2
+ gbps = bytes_moved / (t * 1e-3) / 1e9
+ print(f" BM={BLOCK_M:2d} BN={BLOCK_N:3d} BK={BLOCK_K:3d} W={nw} S={ns}: t={t*1000:.2f}us, {gbps:.0f} GB/s")
+ except Exception as e:
+ print(f" BM={BLOCK_M:2d} BN={BLOCK_N:3d} BK={BLOCK_K:3d} W={nw} S={ns}: ERROR {type(e).__name__}: {e}")
{
"file_path": "/workspace/problems/07_w4a16_gemm/test_bm1_time.py",
"content": "<5529 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/test_bm1_time.py (file state is current in your context — no need to Read it back){
"command": "timeout 1200 uv run python test_bm1_time.py 2>&1 | tail -40",
"description": "Time BLOCK_M=1 for M=1 cases"
}=== M=1, N=12288, K=4096 ===
BM= 1 BN= 32 BK=128 W=4 S=2: t=57.54us, 465 GB/s
BM= 1 BN= 32 BK=256 W=4 S=2: t=58.72us, 456 GB/s
BM= 1 BN= 64 BK=128 W=4 S=2: t=59.63us, 449 GB/s
BM= 1 BN= 64 BK=256 W=4 S=2: t=59.84us, 447 GB/s
BM= 1 BN=128 BK=128 W=4 S=2: t=86.14us, 311 GB/s
BM= 1 BN=128 BK=256 W=4 S=2: t=56.02us, 478 GB/s
BM= 1 BN= 64 BK=128 W=4 S=3: t=53.15us, 504 GB/s
BM= 1 BN= 64 BK=128 W=4 S=4: t=53.02us, 505 GB/s
BM= 1 BN= 64 BK=256 W=4 S=3: t=57.02us, 469 GB/s
BM= 1 BN= 32 BK=128 W=4 S=3: t=57.49us, 466 GB/s
BM= 1 BN= 32 BK=256 W=4 S=3: t=56.34us, 475 GB/s
BM= 1 BN= 64 BK=512 W=4 S=2: t=51.57us, 519 GB/s
=== M=1, N=4096, K=4096 ===
BM= 1 BN= 32 BK=128 W=4 S=2: t=52.45us, 170 GB/s
BM= 1 BN= 32 BK=256 W=4 S=2: t=41.38us, 216 GB/s
BM= 1 BN= 64 BK=128 W=4 S=2: t=55.42us, 161 GB/s
BM= 1 BN= 64 BK=256 W=4 S=2: t=51.65us, 173 GB/s
BM= 1 BN=128 BK=128 W=4 S=2: t=82.70us, 108 GB/s
BM= 1 BN=128 BK=256 W=4 S=2: t=49.04us, 182 GB/s
BM= 1 BN= 64 BK=128 W=4 S=3: t=38.59us, 231 GB/s
BM= 1 BN= 64 BK=128 W=4 S=4: t=38.45us, 232 GB/s
BM= 1 BN= 64 BK=256 W=4 S=3: t=44.53us, 201 GB/s
BM= 1 BN= 32 BK=128 W=4 S=3: t=42.46us, 210 GB/s
BM= 1 BN= 32 BK=256 W=4 S=3: t=32.77us, 273 GB/s
BM= 1 BN= 64 BK=512 W=4 S=2: t=36.93us, 242 GB/s--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -33,24 +33,32 @@
# ---------------------------------------------------------------------------
_CONFIGS = [
- # M=1 / decode variants — small BN, more programs
+ # M=1 / decode variants — small BN, BLOCK_M=1
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ # M>=16 / prefill (BLOCK_M=16 is tensor-core minimum)
triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
- # M=16-32 / small prefill
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- # M=256 / large prefill
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=8, num_stages=2),
]
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<1800 chars \u2014 see diff>",
"new_string": "<2788 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -5",
"description": "Verify correctness with BLOCK_M=1"
}PASS{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -20",
"description": "Benchmark with BLOCK_M=1"
}shape=0 variant=solution tflops=1.652 gbps=439.280 ms=0.061
shape=0 solution_peak_fraction=0.2440
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:22:26.649279+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:22:29.453677+00:00 elapsed_s=2.804 ms=0.099696
shape=1 variant=solution tflops=32.310 gbps=278.720 ms=0.100
shape=1 solution_peak_fraction=0.1548
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:22:29.788356+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:22:32.916768+00:00 elapsed_s=3.128 ms=0.292480
shape=2 variant=solution tflops=88.108 gbps=120.102 ms=0.292
shape=2 solution_peak_fraction=0.0667
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:22:33.016797+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:22:35.617091+00:00 elapsed_s=2.600 ms=0.042512
shape=3 variant=solution tflops=0.789 gbps=210.041 ms=0.043
shape=3 solution_peak_fraction=0.1167
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:22:36.001215+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:22:38.731407+00:00 elapsed_s=2.730 ms=0.084704
shape=4 variant=solution tflops=22.184 gbps=375.247 ms=0.085
shape=4 solution_peak_fraction=0.2085
peak_fraction: 0.1437
RESULT: OK{
"command": "<861 chars \u2014 see below>",
"description": "Check autotune picks after BLOCK_M=1 added"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
import solution
import reference
import shapes
for shape in shapes.SHAPES:
M, N, K = shape['M'], shape['N'], shape['K']
reference.M = M
reference.N = N
reference.K = K
ref = reference.Model(M, N, K).cuda().eval()
m = solution.Model(M, N, K).cuda().eval()
m.load_state_dict(ref.state_dict())
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = m(x)
torch.cuda.synchronize()
key = (M, N, K, 'torch.bfloat16', 'torch.uint8', 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16')
cfg = solution.w4a16_gemm_kernel.cache.get(key, None)
print(f'Shape M={M} N={N} K={K}: BM={cfg.kwargs[\"BLOCK_M\"]} BN={cfg.kwargs[\"BLOCK_N\"]} BK={cfg.kwargs[\"BLOCK_K\"]} W={cfg.num_warps} S={cfg.num_stages}')" 2>&1 | tail -10Shape M=1 N=12288 K=4096: BM=1 BN=128 BK=256 W=4 S=3
Shape M=32 N=12288 K=4096: BM=16 BN=64 BK=128 W=4 S=3
Shape M=256 N=12288 K=4096: BM=16 BN=128 BK=128 W=4 S=2
Shape M=1 N=4096 K=4096: BM=1 BN=32 BK=256 W=4 S=3
Shape M=16 N=14336 K=4096: BM=16 BN=64 BK=128 W=4 S=3--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -42,9 +42,11 @@
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 1024, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 256, "GROUP_SZ": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 512, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 1024, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
# M>=16 / prefill (BLOCK_M=16 is tensor-core minimum)
triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_SZ": 128}, num_warps=4, num_stages=2),
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/solution.py",
"old_string": "<442 chars \u2014 see diff>",
"new_string": "<665 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/solution.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 1800 uv run python check.py 2>&1 | tail -3",
"description": "Verify correctness"
}Command running in background with ID: boufz2gs5. Output is being written to: /tmp/claude-1002/-workspace-problems-07-w4a16-gemm/f724c22c-4b6a-4d9a-acd6-6f3a4c483f9f/tasks/boufz2gs5.output. You will be notified when it completes. To check interim output, use Read on that file path.{
"task_id": "boufz2gs5",
"block": true,
"timeout": 600000
}<retrieval_status>success</retrieval_status>
<task_id>boufz2gs5</task_id>
<task_type>local_bash</task_type>
<status>completed</status>
<exit_code>0</exit_code>
<output>
PASS
</output>{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | tail -15",
"description": "Benchmark with more BLOCK_K"
}shape=1 solution_peak_fraction=0.1555
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:26:42.987534+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:26:46.451806+00:00 elapsed_s=3.464 ms=0.292864
shape=2 variant=solution tflops=87.992 gbps=119.944 ms=0.293
shape=2 solution_peak_fraction=0.0666
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:26:46.560026+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:26:49.379983+00:00 elapsed_s=2.820 ms=0.041648
shape=3 variant=solution tflops=0.806 gbps=214.399 ms=0.042
shape=3 solution_peak_fraction=0.1191
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:26:49.763200+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:26:52.737596+00:00 elapsed_s=2.974 ms=0.084064
shape=4 variant=solution tflops=22.353 gbps=378.104 ms=0.084
shape=4 solution_peak_fraction=0.2101
peak_fraction: 0.1443
RESULT: OK{
"command": "timeout 1800 uv run python benchmark.py 2>&1 | grep -E \"shape=\\d variant=solution|peak_fraction\"",
"description": "Get all shape results"
}benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-13T20:27:45.482850+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-13T20:27:48.530172+00:00 elapsed_s=3.047 ms=0.061520
shape=0 variant=solution tflops=1.636 gbps=435.167 ms=0.062
shape=0 solution_peak_fraction=0.2418
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T20:27:48.875445+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T20:27:51.936627+00:00 elapsed_s=3.061 ms=0.099840
shape=1 variant=solution tflops=32.264 gbps=278.318 ms=0.100
shape=1 solution_peak_fraction=0.1546
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T20:27:52.270954+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T20:27:55.741192+00:00 elapsed_s=3.470 ms=0.292720
shape=2 variant=solution tflops=88.036 gbps=120.003 ms=0.293
shape=2 solution_peak_fraction=0.0667
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T20:27:55.847152+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T20:27:58.666601+00:00 elapsed_s=2.819 ms=0.042448
shape=3 variant=solution tflops=0.790 gbps=210.358 ms=0.042
shape=3 solution_peak_fraction=0.1169
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T20:27:59.047560+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T20:28:02.023913+00:00 elapsed_s=2.976 ms=0.083968
shape=
... (114 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/test_single_dot.py
@@ -0,0 +1,120 @@
+"""Test single-dot pattern with permute."""
+import torch
+import sys
+sys.path.insert(0, '/workspace/problems/07_w4a16_gemm')
+
+import triton
+import triton.language as tl
+import reference
+
+
+@triton.jit
+def kernel_single_dot(
+ X, WQ, S, Z, OUT,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wk, stride_wn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SZ: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_m_mask = offs_m < M
+
+ n_groups_per_blk: tl.constexpr = BLOCK_K // GROUP_SZ
+ GROUP_HALF: tl.constexpr = GROUP_SZ // 2
+ BLOCK_K_HALF: tl.constexpr = BLOCK_K // 2
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_blk in tl.range(0, K, BLOCK_K):
+ wq_offs = (k_blk // 2) + tl.arange(0, BLOCK_K_HALF)
+ wq = tl.load(
+ WQ + wq_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=offs_n[None, :] < N, other=0,
+ )
+
+ w_lo = (wq & 0xF).to(tl.bfloat16)
+ w_hi = (wq >> 4).to(tl.bfloat16)
+
+ g_idx = (k_blk // GROUP_SZ) + tl.arange(0, n_groups_per_blk)
+ s = tl.load(S + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ mask=offs_n[None, :] < N, other=0.0)
+ z = tl.load(Z + g_idx[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ mask=offs_n[None, :] < N, other=0.0)
+
+ s = tl.broadcast_to(s[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ s = tl.reshape(s, (BLOCK_K_HALF, BLOCK_N))
+ z = tl.broadcast_to(z[:, None, :], (n_groups_per_blk, GROUP_HALF, BLOCK_N))
+ z = tl.reshape(z, (BLOCK_K_HALF, BLOCK_N))
+
+ w_lo = (w_lo - z) * s
+ w_hi = (w_hi - z) * s
+
+ # Interleave: w_full[2i, n] = w_lo[i, n], w_full[2i+1, n] = w_hi[i, n]
+ # Use join + permute + reshape
+ w_full = tl.join(w_lo, w_hi) # (BLOCK_K_HALF, BLOCK_N, 2) with last dim [lo, hi]
+ # We want to permute so the [lo, hi] dim is the K dim
+ w_full = tl.permute(w_full, (0, 2, 1)) # (BLOCK_K_HALF, 2, BLOCK_N) with middle dim [lo, hi]
+ w_full = tl.reshape(w_full, (BLOCK_K, BLOCK_N))
+
+ # Same for x
+ x_offs_even = k_blk + 2 * tl.arange(0, BLOCK_K_HALF)
+ x_offs_odd = x_offs_even + 1
+ x_even = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_even[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+ x_odd = tl.load(
+ X + offs_m[:, None] * stride_xm + x_offs_odd[None, :] * stride_xk,
+ mask=offs_m_mask[:, None], other=0.0,
+ )
+ x_full = tl.join(x_even, x_odd) # (BLOCK_M, BLOCK_K_HALF, 2)
+ x_full = tl.permute(x_full, (0, 2, 1)) # (BLOCK_M, 2, BLOCK_K_HALF)
+ x_full = tl.reshape(x_full, (BLOCK_M, BLOCK_K))
+
+ acc += tl.dot(x_full, w_full)
+
+ out_ptrs = OUT + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ tl.store(out_ptrs, acc.to(tl.bfloat16),
+ mask=offs_m_mask[:, None] & (offs_n[None, :] < N))
+
+
+M, N, K = 1, 12288, 4096
+reference.M = M
+reference.N = N
+reference.K = K
+ref_model = reference.Model(M, N, K).cuda().eval()
+
+x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+out = torch.zeros(M, N, dtype=torch.bfloat16).cuda()
+
+BLOCK_M, BLOCK_N, BLOCK_K = 1, 64, 128
+grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
+kernel_single_dot[grid](
+ x, ref_model.w_q, ref_model.scales, ref_model.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ ref_model.w_q.stride(0), ref_model.w_q.stride(1),
+ ref_model.scales.stride(0), ref_model.scales.stride(1),
+ ref_model.zeros.stride(0), ref_model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SZ=128,
+ num_warps=4, num_stages=2,
+)
+ref_out = ref_model(x)
+diff = (ref_out - out).abs().max().item()
+print(f'Single dot: max diff = {diff:.4f}')
+
+# Check the layout
+import triton.language as tl as tl2
+print('tl.join creates a tensor with shape (..., 2) at the end')
+print('After permute and reshape, should be interleaved')
{
"file_path": "/workspace/problems/07_w4a16_gemm/test_single_dot.py",
"content": "<4287 chars \u2014 see diff>"
}File created successfully at: /workspace/problems/07_w4a16_gemm/test_single_dot.py (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python test_single_dot.py 2>&1 | tail -10",
"description": "Test single-dot with permute"
}File "/workspace/problems/07_w4a16_gemm/test_single_dot.py", line 118
import triton.language as tl as tl2
^^
SyntaxError: invalid syntax--- a//workspace/problems/07_w4a16_gemm/test_single_dot.py
+++ b//workspace/problems/07_w4a16_gemm/test_single_dot.py
@@ -113,8 +113,3 @@
ref_out = ref_model(x)
diff = (ref_out - out).abs().max().item()
print(f'Single dot: max diff = {diff:.4f}')
-
-# Check the layout
-import triton.language as tl as tl2
-print('tl.join creates a tensor with shape (..., 2) at the end')
-print('After permute and reshape, should be interleaved')
{
"replace_all": false,
"file_path": "/workspace/problems/07_w4a16_gemm/test_single_dot.py",
"old_string": "<287 chars \u2014 see diff>",
"new_string": "<108 chars \u2014 see diff>"
}The file /workspace/problems/07_w4a16_gemm/test_single_dot.py has been updated successfully. (file state is current in your context — no need to Read it back){
"command": "timeout 600 uv run python test_single_dot.py 2>&1 | tail -10",
"description": "Test single-dot with permute"
}Single dot: max diff = 6.8438