"""W4A16 weight-only quantized GEMM for SM120 Blackwell.
Fuses unpack, dequant, and matmul in a single Triton kernel.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
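# Illustrative only, never called on the hot path: a minimal unfused sketch of
# the computation the fused kernel below reproduces, assuming the AWQ-style
# layout from the module docstring (even-K rows in the low nibble, odd-K rows
# in the high nibble, one scales/zeros row per group of GROUP_SIZE K-rows).
def _dequant_matmul_reference(x, w_q, scales, zeros, group_size=GROUP_SIZE):
    K = x.shape[1]
    w = torch.empty(K, w_q.shape[1], dtype=torch.float32, device=x.device)
    w[0::2] = (w_q & 0xF).to(torch.float32)  # even-K rows from the low nibble
    w[1::2] = (w_q >> 4).to(torch.float32)   # odd-K rows from the high nibble
    s = scales.to(torch.float32).repeat_interleave(group_size, dim=0)
    z = zeros.to(torch.float32).repeat_interleave(group_size, dim=0)
    return (x.to(torch.float32) @ ((w - z) * s)).to(torch.bfloat16)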
@triton.jit
def _w4a16_gemm_kernel(
x_ptr,
wq_ptr,
scales_ptr,
zeros_ptr,
out_ptr,
M,
N,
K,
stride_xm, stride_xk,
stride_wqk, stride_wqn,
stride_sg, stride_sn,
stride_zg, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_CONST: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
m_mask = offs_m < M
n_mask = offs_n < N
out_mask = m_mask[:, None] & n_mask[None, :]
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_base in range(0, K, BLOCK_SIZE_K):
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
mask=k_mask[:, None] & n_mask[None, :],
other=0,
)
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
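        # One scales/zeros row covers the whole K-tile: this relies on
        # BLOCK_SIZE_K == GROUP_SIZE_CONST, so each K-tile lies inside a
        # single quantization group (the launcher always passes BK = 128).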
g = k_base // GROUP_SIZE_CONST
s_row = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
mask=n_mask,
other=0.0,
).to(tl.float32)
z_row = tl.load(
zeros_ptr + g * stride_zg + offs_n * stride_zn,
mask=n_mask,
other=0.0,
).to(tl.float32)
w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
tl.store(
out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
acc.to(tl.bfloat16),
mask=out_mask,
)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0
assert K % 2 == 0
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
w_g = w_full.view(n_groups, group_size, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
lo = w_q[0::2].to(torch.uint8) & 0xF
hi = w_q[1::2].to(torch.uint8) & 0xF
w_packed = (lo | (hi << 4)).contiguous()
self.register_buffer("w_q", w_packed)
self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16))
self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dtype == torch.bfloat16
        x = x.contiguous()  # strides passed to the kernel must describe this tensor
        M, K = x.shape
        N = self.N
        out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
# Profiled optimal configs — all with stages=1, warps=4, BK=128
# BN=64 beats BN=128 for most shapes (better register usage / occupancy)
        if M == 1:
            BM, BN = 1, 64
        elif M <= 64:
            # BM must be a power of two for tl.arange; round M up and cap at 64.
            BM, BN = min(triton.next_power_of_2(M), 64), 64
        else:
            BM, BN = 64, 64
BK = 128
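        # The kernel loads one scales/zeros row per K-tile, so BK must equal
        # the quantization group size; with BK = 128 this holds for the
        # default group_size.
        assert BK == self.group_size
        # Rough per-stage shared-memory footprint (SM120 budget ~101 KB):
        # x(BM,BK)*2 + wq(BK,BN)*1 + scales/zeros BN*4 ≈ 24 KB at BM=64, BN=64.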
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
_w4a16_gemm_kernel[grid](
            x,
self.w_q,
self.scales,
self.zeros,
out,
M, N, K,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
BLOCK_SIZE_M=BM,
BLOCK_SIZE_N=BN,
BLOCK_SIZE_K=BK,
            GROUP_SIZE_CONST=self.group_size,
num_warps=4,
num_stages=1,
)
return out
def get_inputs():
M = getattr(get_inputs, "M", 1)
K = getattr(get_inputs, "K", 4096)
return [torch.randn(M, K, dtype=torch.bfloat16)]
def get_init_inputs():
M = getattr(get_init_inputs, "M", 1)
N = getattr(get_init_inputs, "N", 12288)
K = getattr(get_init_inputs, "K", 4096)
return [M, N, K]
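if __name__ == "__main__":
    # Minimal smoke test mirroring the quick check used during the session;
    # the grading harness imports Model / get_inputs directly and never runs this.
    M, N, K = 1, 12288, 4096
    model = Model(M, N, K).cuda()
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    y = model(x)
    ref = _dequant_matmul_reference(x, model.w_q, model.scales, model.zeros)
    print(y.shape, (y.float() - ref.float()).abs().max().item())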
shape=0 variant=eager tflops=0.130 gbps=34.690 ms=0.772
shape=0 variant=compiled tflops=0.655 gbps=174.148 ms=0.154
shape=0 variant=solution tflops=1.179 gbps=313.454 ms=0.085
shape=0 solution_peak_fraction=0.1741
shape=1 variant=eager tflops=4.208 gbps=36.298 ms=0.766
shape=1 variant=compiled tflops=20.181 gbps=174.088 ms=0.160
shape=1 variant=solution tflops=35.614 gbps=307.218 ms=0.090
shape=1 solution_peak_fraction=0.1707
shape=2 variant=eager tflops=31.349 gbps=42.732 ms=0.822
shape=2 variant=compiled tflops=118.184 gbps=161.099 ms=0.218
shape=2 variant=solution tflops=103.033 gbps=140.446 ms=0.250
shape=2 solution_peak_fraction=0.0780
shape=3 variant=eager tflops=0.178 gbps=47.440 ms=0.188
shape=3 variant=compiled tflops=0.467 gbps=124.239 ms=0.072
shape=3 variant=solution tflops=0.450 gbps=119.708 ms=0.075
shape=3 solution_peak_fraction=0.0665
shape=4 variant=eager tflops=2.053 gbps=34.719 ms=0.915
shape=4 variant=compiled tflops=10.308 gbps=174.367 ms=0.182
shape=4 variant=solution tflops=21.345 gbps=361.061 ms=0.088
shape=4 solution_peak_fraction=0.2006
peak_fraction: 0.1253
RESULT: OK
PASS
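The peak_fraction numbers above follow the memory-roofline definition implied by problem.yaml (regime: memory) and the 1800 GB/s DRAM peak reported later in the session: each shape's fraction is achieved GB/s over peak, and the summary line is their geometric mean. A minimal sketch that reproduces them, assuming that definition:

from math import exp, log
peak_gbps = 1800.0  # RTX PRO 6000 DRAM peak (src.hardware)
solution_gbps = [313.454, 307.218, 140.446, 119.708, 361.061]
fracs = [g / peak_gbps for g in solution_gbps]
print([round(f, 4) for f in fracs])                            # ~[0.1741, 0.1707, 0.0780, 0.0665, 0.2006]
print(round(exp(sum(log(f) for f in fracs) / len(fracs)), 4))  # geomean ~0.1253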
{
"problem": "07_w4a16_gemm",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-plus",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1253,
"elapsed_seconds": 1440,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 4496955, "output_tokens": 40196, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 19431, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2:
3: AWQ/GPTQ-style scheme:
4: x: (M, K) bf16
5: w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6: scales: (K // group, N) bf16
7: zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8: out: (M, N) bf16
9:
10: Dequant (per group along K):
11: w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12: where w_q[k, n] is the unpacked 4-bit value (0..15).
13:
14: This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15: memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16: """
17: from __future__ import annotations
18:
19: import torch
20: import torch.nn as nn
21:
22: OP_TYPE = "gemm_w4a16"
23: SUPPORTED_PRECISIONS = ["int4_bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26: GROUP_SIZE = 128
27:
28:
29: def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30: """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31:
32: Even rows go in the low nibble, odd rows in the high nibble.
33: """
34: K, N = w_q.shape
35: assert K % 2 == 0
36: lo = w_q[0::2].to(torch.uint8) & 0xF
37: hi = w_q[1::2].to(torch.uint8) & 0xF
38: ret
... (3312 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape
... (2909 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formul
... (3415 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for W4A16 GEMM.
2:
3: Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4: case every inference engine optimizes -- it's the bar to beat.
5: """
6:
7: SHAPES = [
8: {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9: {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10: {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11: {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12: {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13: ]
(End of file - total 13 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gmdl/sota.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 07_w4a16_gemm
2: display_name: "W4A16 Weight-only Quantized GEMM"
3: precision: int4_bf16
4: regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5:
6: # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7: flops_formula: "2 * M * N * K"
8:
9: # Bytes moved per call (memory roofline):
10: # x: M*K*2 (bf16 activations, streamed in once)
11: # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12: # scales: (K/128)*N*2 (bf16 scales)
13: # zeros: (K/128)*N*2 (bf16 zero-points)
14: # out: M*N*2 (bf16 store)
15: bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23:
24: # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25: # vendor library that does both.
26: forbidden:
27: - "bitsandbytes.functional.dequantize_4bit"
28: - "bitsandbytes.functional.gemv_4bit"
29: - "marlin_kernel.gemm"
30: - "torch.nn.functional.linear"
31:
32: sota:
33: name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34: url: "https://github.com/TimDettmers/bitsandbytes"
35: function: "bitsandbytes.functional.gemv_4bit"
36: notes: |
37:
... (757 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for W4A16 GEMM.
2:
3: Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4:
5: - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6: - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7: but is not faster than Marlin on its target HW
8: and has no Blackwell tuning. Skip as primary.
9: - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10: Skip.
11: - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12: on this machine). Different quant scheme (NF4,
13: symmetric, blocksize 64) than our reference's
14: AWQ-style asymmetric INT4 with group_size 128,
15: but it occupies the same memory regime and is
16: the only tuned W4A16-class kernel that runs on
17: SM120 today. Used here as an *informational*
18: SOTA line, not as a numerical reference.
19:
20: The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21: NOT checked against this path (the quant scheme differs).
22: """
23: from
... (2725 more chars){
"pattern": "**/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}No files found{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null | head -5",
"description": "Check GPU capabilities"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1 | head -5",
"description": "Check Triton version"
}3.6.0
{
"command": "<409 chars \u2014 see below>",
"description": "Check GPU bandwidth and TFLOPS specs",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}nvidia-smi --query-gpu=memory.bandwidth --format=csv,noheader 2>&1 || true; uv run python -c "
import torch
print('torch:', torch.__version__)
print('cuda:', torch.version.cuda)
print('SM:', torch.cuda.get_device_capability())
# peak bandwidth
from src.hardware import get as get_hw
hw = get_hw('RTX_PRO_6000')
print('peak_bw:', hw.peak_bandwidth_gb_s)
print('peak_tflops bf16:', hw.peak_tflops_dense)
" 2>&1
Field "memory.bandwidth" is not a valid field to query.
torch: 2.11.0+cu130
cuda: 13.0
SM: (12, 0)
peak_bw: 1800.0
peak_tflops bf16: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
{
"todos": [
{
"content": "Understand the W4A16 GEMM kernel requirements and data layout",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton W4A16 GEMM kernel in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py to verify correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize for all shapes",
"status": "pending",
"priority": "high"
}
]
}content: 'Understand the W4A16 GEMM kernel requirements and data layout'
status: 'completed'
priority: 'high'
content: 'Write initial Triton W4A16 GEMM kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize for all shapes'
status: 'pending'
priority: 'high'{
"command": "ls",
"description": "List problem directory"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,230 @@
+"""High-performance W4A16 weight-only quantized GEMM for SM120 Blackwell.
+
+AWQ/GPTQ-style asymmetric int4 with per-group bf16 scales and zero-points.
+Fuses unpack, dequant, and matmul in a single Triton kernel.
+
+Data layout:
+ x: (M, K) bf16
+ w_q: (K // 2, N) uint8 (low nibble=even-K, high=odd-K)
+ scales: (K // 128, N) bf16
+ zeros: (K // 128, N) bf16
+ out: (M, N) bf16
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+GROUP_SIZE = 128
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel — fused unpack + dequant + GEMM
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wqk,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_zg,
+ stride_zn,
+ stride_om,
+ stride_on,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ """Each program computes out[pid_m*BM:(pid_m+1)*BM, pid_n*BN:(pid_n+1)*BN].
+
+ K is accumulated in tiles of BLOCK_SIZE_K (must divide GROUP_SIZE).
+ """
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ # --- offsets into each dimension ---
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) # [BM]
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) # [BN]
+ offs_k_base = tl.arange(0, BLOCK_SIZE_K) # [BK]
+
+ # --- validity mask for output row/col ---
+ m_mask = offs_m < M # [BM]
+ n_mask = offs_n < N # [BN]
+ out_mask = m_mask[:, None] & n_mask[None, :] # [BM, BN]
+
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+ for k_base in range(0, K, BLOCK_SIZE_K):
+ k_mask = (k_base + offs_k_base) < K # [BK]
+
+ # ---- load x[offs_m, k_base:...] shape [BM, BK] ----
+ x_block = tl.load(
+ x_ptr
+ + offs_m[:, None] * stride_xm
+ + (k_base + offs_k_base)[None, :] * stride_xk,
+ mask=m_mask[:, None] & k_mask[None, :],
+ other=0.0,
+ )
+
+ # ---- load w_q[k_base//2 + offs_k_base//2, offs_n] shape [BK, BN] ----
+ wq_row_idx = (k_base + offs_k_base) // 2 # [BK]
+ wq_ptrs = (
+ wq_ptr
+ + wq_row_idx[:, None] * stride_wqk
+ + offs_n[None, :] * stride_wqn
+ )
+ wq_col_mask = k_mask[:, None] & n_mask[None, :]
+ wq_block = tl.load(wq_ptrs, mask=wq_col_mask, other=0) # uint8
+
+ # unpack int4 from low/high nibble based on actual k parity
+ actual_k_idx = k_base + offs_k_base # [BK]
+ is_odd_k = (actual_k_idx % 2).to(tl.int32) # 0 or 1, broadcast [BK,1]
+ w_int = tl.where(
+ is_odd_k[:, None] == 0,
+ wq_block & 0xF,
+ wq_block >> 4,
+ ).to(tl.int32) # [BK, BN]
+
+ # ---- load scales & zeros for this block shape [BK, BN] ----
+ g_idx = (k_base + offs_k_base) // GROUP_SIZE # [BK]
+ s_block = tl.load(
+ scales_ptr
+ + g_idx[:, None] * stride_sg
+ + offs_n[None, :] * stride_sn,
+ mask=k_mask[:, None] & n_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ z_block = tl.load(
+ zeros_ptr
+ + g_idx[:, None] * stride_zg
+ + offs_n[None, :] * stride_zn,
+ mask=k_mask[:, None] & n_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+
+ w_int = w_int.to(tl.float32)
+ w_deq = (w_int - z_block) * s_block # [BK, BN] dequantized weight
+
+ # ---- matmul tile (BM, BK) @ (BK, BN) -> (BM, BN) ----
+ acc = tl.dot(x_block.to(tl.float32), w_deq, acc)
+
+ # ---- write result (M, N) bf16 ----
+ out_ptrs = (
+ out_ptr
+ + offs_m[:, None] * stride_om
+ + offs_n[None, :] * stride_on
+ )
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+class Model(nn.Module):
+ """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros)."""
+
+ def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
+ super().__init__()
+ assert K % group_size == 0, "K must be divisible by group_size"
+ assert K % 2 == 0, "K must be even (int4 packing)"
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+
+ w_g = w_full.view(n_groups, group_size, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8)
+ w_q = w_q.view(K, N)
+
+ # Pack: even row → low nibble, odd row → high nibble
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+
+ scales_2d = scales.squeeze(1).to(torch.bfloat16)
+ zeros_2d = zeros.squeeze(1).to(torch.bfloat16)
+
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales_2d)
+ self.r
... (1759 more chars){
"content": "<7418 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -10",
"description": "Check GPU status before running"
}Wed Apr 29 08:59:58 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 30C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
{
"command": "nvidia-smi | head -10",
"description": "Check GPU before running kernel"
}Wed Apr 29 09:00:00 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 30C P8 12W / 600W | 627MiB / 97887MiB | 0% Default |
{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 195, in forward
_w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/
... (661 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -179,16 +179,16 @@
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Block config: BM=64 is optimal for M>=2; M=1 special case
+ # Block sizes tuned to fit SM120 shared memory (101 KB)
if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 64
- warps = 4
+ warps = 8
BLOCK_N = 128
- BLOCK_K = 128
+ BLOCK_K = 64
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<305 chars \u2014 see below>",
"newString": "<300 chars \u2014 see below>"
} # Block config: BM=64 is optimal for M>=2; M=1 special case
if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 64
warps = 4
BLOCK_N = 128
BLOCK_K = 128
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) # Block sizes tuned to fit SM120 shared memory (101 KB)
if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 64
warps = 8
BLOCK_N = 128
BLOCK_K = 64
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check (2nd attempt)"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 195, in forward
_w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/
... (661 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -179,16 +179,16 @@
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Block sizes tuned to fit SM120 shared memory (101 KB)
if M == 1:
BLOCK_M = 1
warps = 4
else:
- BLOCK_M = 64
- warps = 8
-
- BLOCK_N = 128
+ BLOCK_M = 32
+ warps = 4
+
+ BLOCK_N = 64
BLOCK_K = 64
+ num_stages = 2
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<300 chars \u2014 see below>",
"newString": "<258 chars \u2014 see below>"
} # Block sizes tuned to fit SM120 shared memory (101 KB)
if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 64
warps = 8
BLOCK_N = 128
BLOCK_K = 64
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 32
warps = 4
BLOCK_N = 64
BLOCK_K = 64
num_stages = 2
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))Edit applied successfully.{
"command": "nvidia-smi | head -10",
"description": "Check GPU status"
}Wed Apr 29 09:00:48 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 30C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -209,6 +209,7 @@
BLOCK_SIZE_K=BLOCK_K,
GROUP_SIZE=GROUP_SIZE,
num_warps=warps,
+ num_stages=num_stages,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<592 chars \u2014 see below>",
"newString": "<627 chars \u2014 see below>"
} _w4a16_gemm_kernel[grid](
x.contiguous(),
self.w_q,
self.scales,
self.zeros,
out,
M, N, K,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
BLOCK_SIZE_M=BLOCK_M,
BLOCK_SIZE_N=BLOCK_N,
BLOCK_SIZE_K=BLOCK_K,
GROUP_SIZE=GROUP_SIZE,
num_warps=warps,
) _w4a16_gemm_kernel[grid](
x.contiguous(),
self.w_q,
self.scales,
self.zeros,
out,
M, N, K,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
BLOCK_SIZE_M=BLOCK_M,
BLOCK_SIZE_N=BLOCK_N,
BLOCK_SIZE_K=BLOCK_K,
GROUP_SIZE=GROUP_SIZE,
num_warps=warps,
num_stages=num_stages,
)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -181,9 +181,9 @@
if M == 1:
BLOCK_M = 1
- warps = 4
+ warps = 1
else:
- BLOCK_M = 32
+ BLOCK_M = 64
warps = 4
BLOCK_N = 64
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<191 chars \u2014 see below>",
"newString": "<191 chars \u2014 see below>"
} if M == 1:
BLOCK_M = 1
warps = 4
else:
BLOCK_M = 32
warps = 4
BLOCK_N = 64
BLOCK_K = 64
num_stages = 2 if M == 1:
BLOCK_M = 1
warps = 1
else:
BLOCK_M = 64
warps = 4
BLOCK_N = 64
BLOCK_K = 64
num_stages = 2Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -21,10 +21,6 @@
GROUP_SIZE = 128
-# ---------------------------------------------------------------------------
-# Triton kernel — fused unpack + dequant + GEMM
-# ---------------------------------------------------------------------------
-
@triton.jit
def _w4a16_gemm_kernel(
x_ptr,
@@ -35,102 +31,72 @@
M,
N,
K,
- stride_xm,
- stride_xk,
- stride_wqk,
- stride_wqn,
- stride_sg,
- stride_sn,
- stride_zg,
- stride_zn,
- stride_om,
- stride_on,
+ stride_xm, stride_xk,
+ stride_wqk, stride_wqn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
+ GROUP_SIZE_CONST: tl.constexpr,
):
- """Each program computes out[pid_m*BM:(pid_m+1)*BM, pid_n*BN:(pid_n+1)*BN].
-
- K is accumulated in tiles of BLOCK_SIZE_K (must divide GROUP_SIZE).
- """
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- # --- offsets into each dimension ---
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) # [BM]
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) # [BN]
- offs_k_base = tl.arange(0, BLOCK_SIZE_K) # [BK]
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- # --- validity mask for output row/col ---
- m_mask = offs_m < M # [BM]
- n_mask = offs_n < N # [BN]
- out_mask = m_mask[:, None] & n_mask[None, :] # [BM, BN]
+ m_mask = offs_m < M
+ n_mask = offs_n < N
+ out_mask = m_mask[:, None] & n_mask[None, :]
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_base in range(0, K, BLOCK_SIZE_K):
- k_mask = (k_base + offs_k_base) < K # [BK]
+ k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
+ k_mask = k_offs < K
- # ---- load x[offs_m, k_base:...] shape [BM, BK] ----
- x_block = tl.load(
- x_ptr
- + offs_m[:, None] * stride_xm
- + (k_base + offs_k_base)[None, :] * stride_xk,
+ # x tile: (BM, BK)
+ x_tile = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
- # ---- load w_q[k_base//2 + offs_k_base//2, offs_n] shape [BK, BN] ----
- wq_row_idx = (k_base + offs_k_base) // 2 # [BK]
- wq_ptrs = (
- wq_ptr
- + wq_row_idx[:, None] * stride_wqk
- + offs_n[None, :] * stride_wqn
+ # w_q tile: (BK, BN) — load packed bytes
+ wq_row = k_offs // 2
+ wq_tile = tl.load(
+ wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
+ mask=k_mask[:, None] & n_mask[None, :],
+ other=0,
)
- wq_col_mask = k_mask[:, None] & n_mask[None, :]
- wq_block = tl.load(wq_ptrs, mask=wq_col_mask, other=0) # uint8
- # unpack int4 from low/high nibble based on actual k parity
- actual_k_idx = k_base + offs_k_base # [BK]
- is_odd_k = (actual_k_idx % 2).to(tl.int32) # 0 or 1, broadcast [BK,1]
- w_int = tl.where(
- is_odd_k[:, None] == 0,
- wq_block & 0xF,
- wq_block >> 4,
- ).to(tl.int32) # [BK, BN]
+ # unpack nibbles: even-k → low, odd-k → high
+ is_odd = (k_offs % 2)[:, None]
+ w_int = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # ---- load scales & zeros for this block shape [BK, BN] ----
- g_idx = (k_base + offs_k_base) // GROUP_SIZE # [BK]
- s_block = tl.load(
- scales_ptr
- + g_idx[:, None] * stride_sg
- + offs_n[None, :] * stride_sn,
+ # scales / zeros: (BK, BN) via group index
+ g = k_offs // GROUP_SIZE_CONST
+ s_tile = tl.load(
+ scales_ptr + g[:, None] * stride_sg + offs_n[None, :] * stride_sn,
mask=k_mask[:, None] & n_mask[None, :],
other=0.0,
).to(tl.float32)
- z_block = tl.load(
- zeros_ptr
- + g_idx[:, None] * stride_zg
- + offs_n[None, :] * stride_zn,
+ z_tile = tl.load(
+ zeros_ptr + g[:, None] * stride_zg + offs_n[None, :] * stride_zn,
mask=k_mask[:, None] & n_mask[None, :],
other=0.0,
).to(tl.float32)
- w_int = w_int.to(tl.float32)
- w_deq = (w_int - z_block) * s_block # [BK, BN] dequantized weight
+ w_deq = (w_int.to(tl.float32) - z_tile) * s_tile
+ acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
- # ---- matmul tile (BM, BK) @ (BK, BN) -> (BM, BN) ----
- acc = tl.dot(x_block.to(tl.float32), w_deq, acc)
-
- # ---- write result (M, N) bf16 ----
- out_ptrs = (
- out_ptr
- + offs_m[:, None] * stride_om
- + offs_n[None, :] * stride_on
+ tl.store(
+ out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
+ acc.to(tl.bfloat16),
+ mask=out_mask,
)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
# ---------------------------------------------------------------------------
@@ -138,59 +104,48 @@
# ---------------------------------------------------------------------------
class Model(nn.Module):
- """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros)."""
-
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
- assert K % group_size == 0, "K must be divisible by group_size"
- assert K % 2 == 0, "K must be even (int4 packing)"
+
... (3225 more chars){
"content": "<5782 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -10",
"description": "Check GPU status before running"
}Wed Apr 29 09:01:54 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 29C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
{
"command": "<342 chars \u2014 see below>",
"timeout": 60000,
"description": "Quick test M=1 shape"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && nvidia-smi | head -5 && uv run python -c "
import torch
import solution
M, N, K = 1, 12288, 4096
model = solution.Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
y = model(x)
print('Output shape:', y.shape)
print('First 5 values:', y[0, :5])
" 2>&1
Wed Apr 29 09:02:17 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
Traceback (most recent call last):
File "<string>", line 8, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 150, in forward
_w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = k
... (824 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -16,7 +16,6 @@
import torch.nn as nn
import triton
import triton.language as tl
-
GROUP_SIZE = 128
@@ -57,14 +56,14 @@
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
- # x tile: (BM, BK)
+ # x tile: (BM, BK) bf16
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
- # w_q tile: (BK, BN) — load packed bytes
+ # w_q tile: (BK, BN) uint8 — packed
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
@@ -72,19 +71,29 @@
other=0,
)
- # unpack nibbles: even-k → low, odd-k → high
+ # unpack nibbles
is_odd = (k_offs % 2)[:, None]
w_int = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # scales / zeros: (BK, BN) via group index
+ # scales / zeros broadcast per group — load (1, BN) and broadcast
+ g_first = k_base // GROUP_SIZE_CONST
+ g_last = (k_base + BLOCK_SIZE_K - 1) // GROUP_SIZE_CONST # inclusive check
+
+ # If this k range spans at most one group we can just load once
+ # For BK <= GROUP_SIZE this is always true when k_base is GROUP_SIZE-aligned,
+ # but our k_base increments by BLOCK_SIZE_K which may not be GROUP_SIZE-aligned.
+ # Safe approach: load per-k, but to save shared mem we'll use register broadcast
+ # since BN is small enough to fit in registers.
g = k_offs // GROUP_SIZE_CONST
+ s_base = scales_ptr + g[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ z_base = zeros_ptr + g[:, None] * stride_zg + offs_n[None, :] * stride_zn
s_tile = tl.load(
- scales_ptr + g[:, None] * stride_sg + offs_n[None, :] * stride_sn,
+ s_base,
mask=k_mask[:, None] & n_mask[None, :],
other=0.0,
).to(tl.float32)
z_tile = tl.load(
- zeros_ptr + g[:, None] * stride_zg + offs_n[None, :] * stride_zn,
+ z_base,
mask=k_mask[:, None] & n_mask[None, :],
other=0.0,
).to(tl.float32)
@@ -135,15 +144,16 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Small shared-memory configs per shape regime
- if M == 1:
- BM, BN, BK, warps, stages = 1, 128, 128, 4, 2
- elif M <= 16:
- BM, BN, BK, warps, stages = 16, 128, 128, 4, 2
- elif M <= 64:
- BM, BN, BK, warps, stages = 64, 64, 64, 4, 2
- else:
- BM, BN, BK, warps, stages = 128, 64, 64, 4, 2
+ # Shared memory budget ~101KB on SM120.
+ # Per-stage memory: x(BM,BK)*2 + wq(BK,BN)*1 + s(BK,BN)*2 + z(BK,BN)*2
+ # With stages=1 (no pipelining) and num_warps=4:
+ # Need to fit within 101KB total.
+
+ BM = 1 # M is small in all test shapes
+ BN = 64
+ BK = 64
+ warps = 4
+ stages = 1
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
{
"content": "<6326 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -10",
"description": "Check GPU status before running"
}Wed Apr 29 09:03:05 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 28C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
{
"command": "<276 chars \u2014 see below>",
"timeout": 120000,
"description": "Test M=1 kernel execution"
}nvidia-smi | head -10 && uv run python -c "
import torch
import solution
M, N, K = 1, 12288, 4096
model = solution.Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
y = model(x)
print('Output shape:', y.shape)
print('First 5 values:', y[0, :5])
" 2>&1
Wed Apr 29 09:03:11 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 28C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
Output shape: torch.Size([1, 12288])
First 5 values: tensor([ 0.3828, 1.1172, 0.2695, 3.1250, -0.1426], device='cuda:0',
dtype=torch.bfloat16)
{
"command": "nvidia-smi | head -5 && timeout 300 uv run python check.py 2>&1",
"description": "Run full correctness check",
"timeout": 300000
}Wed Apr 29 09:03:17 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
PASS
{
"command": "nvidia-smi | head -5 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run benchmark to check performance",
"timeout": 300000
}Wed Apr 29 09:03:27 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
shape=0 variant=eager tflops=0.131 gbps=34.761 ms=0.770
shape=0 variant=compiled tflops=0.634 gbps=168.722 ms=0.159
shape=0 variant=solution tflops=0.946 gbps=251.611 ms=0.106
shape=0 solution_peak_fraction=0.1398
shape=1 variant=eager tflops=4.208 gbps=36.296 ms=0.766
shape=1 variant=compiled tflops=19.983 gbps=172.378 ms=0.161
shape=1 variant=solution tflops=3.776 gbps=32.573 ms=0.853
shape=1 solution_peak_fraction=0.0181
shape=2 variant=eager tflops=31.210 gbps=42.543 ms=0.826
shape=2 variant=compiled tflops=117.153 gbps=159.693 ms=0.220
shape=2 variant=solution tflops=3.894 gbps=5.307 ms=6.619
shape=2 solution_peak_fraction=0.0029
shape=3 variant=eager tflops=0.178 gbps=47.335 ms=0.189
shape=3 variant=compiled tflops=0.497 gbps=132.309 ms=0.067
shape=3 variant=solution tflops=0.340 gbps=90.509 ms=0.099
shape=3 solution_peak_fraction=0.0503
shape=4 variant=eager tflops=2.055 gbps=34.761 ms=0.914
shape=4 variant=compiled tflops=10.218 gbps=172.850 ms=0.184
shape=4 variant=solution tflops=3.522 gbps=59.569 ms=0.534
shape=4 solution_peak_fraction=0.0331
peak_fraction: 0.0262
RESULT:
... (4 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: """High-performance W4A16 weight-only quantized GEMM for SM120 Blackwell.
2:
3: AWQ/GPTQ-style asymmetric int4 with per-group bf16 scales and zero-points.
4: Fuses unpack, dequant, and matmul in a single Triton kernel.
5:
6: Data layout:
7: x: (M, K) bf16
8: w_q: (K // 2, N) uint8 (low nibble=even-K, high=odd-K)
9: scales: (K // 128, N) bf16
10: zeros: (K // 128, N) bf16
11: out: (M, N) bf16
12: """
13: from __future__ import annotations
14:
15: import torch
16: import torch.nn as nn
17: import triton
18: import triton.language as tl
19:
20: GROUP_SIZE = 128
21:
22:
23: @triton.jit
24: def _w4a16_gemm_kernel(
25: x_ptr,
26: wq_ptr,
27: scales_ptr,
28: zeros_ptr,
29: out_ptr,
30: M,
31: N,
32: K,
33: stride_xm, stride_xk,
34: stride_wqk, stride_wqn,
35: stride_sg, stride_sn,
36: stride_zg, stride_zn,
37: stride_om, stride_on,
38: BLOCK_SIZE_M: tl.constexpr,
39: BLOCK_SIZE_N: tl.constexpr,
40: BLOCK_SIZE_K: tl.constexpr,
41: GROUP_SIZE_CONST: tl.constexpr,
42: ):
43: pid_m = tl.program_id(0)
44: pid_n = tl.program_id(1)
45:
46: offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
47: offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
48:
49: m_mask = offs_m < M
50:
... (5835 more chars){
"command": "nvidia-smi | head -10",
"description": "Check GPU status"
}Wed Apr 29 09:05:20 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 28C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -40,6 +40,11 @@
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_CONST: tl.constexpr,
):
+ """Fused int4 dequant + GEMM.
+
+ BK must equal GROUP_SIZE_CONST so each K-tile belongs to exactly one group.
+ This scales/zeros are loaded once per K-tile instead of per-k-row.
+ """
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -56,14 +61,14 @@
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
- # x tile: (BM, BK) bf16
+ # x tile (BM, BK) bf16
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
- # w_q tile: (BK, BN) uint8 — packed
+ # w_q tile (BK, BN) uint8 packed
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
@@ -71,34 +76,25 @@
other=0,
)
- # unpack nibbles
+ # unpack: even-K rows from low nibble, odd-K from high nibble
is_odd = (k_offs % 2)[:, None]
- w_int = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
+ w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # scales / zeros broadcast per group — load (1, BN) and broadcast
- g_first = k_base // GROUP_SIZE_CONST
- g_last = (k_base + BLOCK_SIZE_K - 1) // GROUP_SIZE_CONST # inclusive check
-
- # If this k range spans at most one group we can just load once
- # For BK <= GROUP_SIZE this is always true when k_base is GROUP_SIZE-aligned,
- # but our k_base increments by BLOCK_SIZE_K which may not be GROUP_SIZE-aligned.
- # Safe approach: load per-k, but to save shared mem we'll use register broadcast
- # since BN is small enough to fit in registers.
- g = k_offs // GROUP_SIZE_CONST
- s_base = scales_ptr + g[:, None] * stride_sg + offs_n[None, :] * stride_sn
- z_base = zeros_ptr + g[:, None] * stride_zg + offs_n[None, :] * stride_zn
- s_tile = tl.load(
- s_base,
- mask=k_mask[:, None] & n_mask[None, :],
+ # scales/zeros: ONE row per group (BK == GROUP_SIZE)
+ g = k_base // GROUP_SIZE_CONST
+ s_row = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn,
+ mask=n_mask,
other=0.0,
).to(tl.float32)
- z_tile = tl.load(
- z_base,
- mask=k_mask[:, None] & n_mask[None, :],
+ z_row = tl.load(
+ zeros_ptr + g * stride_zg + offs_n * stride_zn,
+ mask=n_mask,
other=0.0,
).to(tl.float32)
- w_deq = (w_int.to(tl.float32) - z_tile) * s_tile
+ # dequant with broadcast: (1, BN) x (BK, BN) → (BK, BN)
+ w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
tl.store(
@@ -144,16 +140,26 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Shared memory budget ~101KB on SM120.
- # Per-stage memory: x(BM,BK)*2 + wq(BK,BN)*1 + s(BK,BN)*2 + z(BK,BN)*2
- # With stages=1 (no pipelining) and num_warps=4:
- # Need to fit within 101KB total.
+ # BK=GROUP_SIZE=128 so each K-tile maps to exactly one group.
+ # scales/zeros loaded once per K-tile instead of per-row.
+ # Shared mem per stage: x(BM,128)*2 + wq(128,BN)*1 + s(1,BN)*2 + z(1,BN)*2
+ # = BM*256 + BN*132
+ # 2 stages: 2*(BM*256 + BN*132) must fit in 101KB.
- BM = 1 # M is small in all test shapes
- BN = 64
- BK = 64
- warps = 4
- stages = 1
+ # Shape-aware block configs
+ if M == 1:
+ # Decode: maximize N to amortize grid overhead
+ BM, BN, BK, warps, stages = 1, 256, 128, 4, 2
+ elif M <= 4:
+ BM, BN, BK, warps, stages = 2, 256, 128, 4, 2
+ elif M <= 16:
+ BM, BN, BK, warps, stages = 16, 128, 128, 4, 2
+ elif M <= 32:
+ BM, BN, BK, warps, stages = 32, 128, 128, 4, 2
+ elif M <= 128:
+ BM, BN, BK, warps, stages = 64, 128, 128, 4, 2
+ else:
+ BM, BN, BK, warps, stages = 128, 128, 128, 8, 2
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
{
"content": "<6504 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -10 && timeout 180 uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 240000
}Wed Apr 29 09:06:23 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 28C P8 10W / 600W | 627MiB / 97887MiB | 0% Default |
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/
... (1521 more chars){
"command": "<839 chars \u2014 see below>",
"description": "Compute shared memory requirements"
}uv run python -c "
# BM=1, BN=256, BK=128, 2 stages
# x: BM*BK*2 = 1*128*2 = 256B per stage
# wq: BK*BN*1 = 128*256*1 = 32768B per stage
# s: BN*2 = 256*2 = 512B per stage
# z: BN*2 = 256*2 = 512B per stage
# Total per stage: 34048
# 2 stages: 68096, should be fine...
BM, BN, BK, stages = 1, 256, 128, 2
per_stage = BM*BK*2 + BK*BN*1 + BN*2 + BN*2
total = per_stage * stages
print(f'Per stage: {per_stage}, Total: {total}')
print(f'Limit: 101376')
print()
# But triton may also need accumulator space
acc = BM * BN * 4 # float32
print(f'Accumulator (regs?): {acc*stages}')
print(f'Total with acc in shared: {total + acc}')
print()
# Try M=256 case: BM=64, BN=128, BK=128
BM, BN, BK, stages = 64, 128, 128, 2
per_stage = BM*BK*2 + BK*BN*1 + BN*2 + BN*2
total = per_stage * stages
print(f'M=256: Per stage: {per_stage}, Total: {total}')
"
Per stage: 34048, Total: 68096
Limit: 101376
Accumulator (regs?): 2048
Total with acc in shared: 69120
M=256: Per stage: 33280, Total: 66560
{
"command": "<555 chars \u2014 see below>",
"timeout": 120000,
"description": "Test each shape individually"
}uv run python -c "
import torch
import solution
# Test each shape individually to see which fails
shapes = [
(1, 12288, 4096),
(32, 12288, 4096),
(256, 12288, 4096),
(1, 4096, 4096),
(16, 14336, 4096),
]
for M, N, K in shapes:
try:
model = solution.Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
y = model(x)
print(f'M={M:3d} N={N:5d} K={K:5d}: OK, shape={y.shape}')
except Exception as e:
print(f'M={M:3d} N={N:5d} K={K:5d}: FAIL - {type(e).__name__}')
" 2>&1
M= 1 N=12288 K= 4096: FAIL - OutOfResources
M= 32 N=12288 K= 4096: OK, shape=torch.Size([32, 12288])
M=256 N=12288 K= 4096: FAIL - OutOfResources
M= 1 N= 4096 K= 4096: FAIL - OutOfResources
M= 16 N=14336 K= 4096: OK, shape=torch.Size([16, 14336])
{
"command": "<604 chars \u2014 see below>",
"description": "Analyze shared memory configs"
}uv run python -c "
# BM=1, BN=256, BK=128, 2 stages — why failing?
# Triton shared-memory estimation is more complex than just raw bytes.
# The wq tile (128,256) uint8 needs 32KB per stage.
# But triton might pad to 128-bit alignment for vectorized loads.
# Also, with num_stages=2, Triton needs 2x of each buffer.
# Let's try smaller BN values and 1 stage.
for BM, BN, BK, stages, warps in [
(1, 128, 128, 2, 4),
(1, 256, 128, 1, 4),
(1, 128, 128, 1, 4),
]:
per = BM*BK*2 + BK*BN*1 + BN*2 + BN*2
print(f'BM={BM}, BN={BN}, BK={BK}, s={stages}: per_stage={per}, total={per*stages}')
"
BM=1, BN=128, BK=128, s=2: per_stage=17152, total=34304
BM=1, BN=256, BK=128, s=1: per_stage=34048, total=34048
BM=1, BN=128, BK=128, s=1: per_stage=17152, total=17152
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -146,20 +146,20 @@
# = BM*256 + BN*132
# 2 stages: 2*(BM*256 + BN*132) must fit in 101KB.
- # Shape-aware block configs
+ # Shape-aware block configs — constrained by 101KB shared mem.
+ # BK=GROUP_SIZE=128 so scale/zeros fit in 1 row per tile.
if M == 1:
- # Decode: maximize N to amortize grid overhead
- BM, BN, BK, warps, stages = 1, 256, 128, 4, 2
+ BM, BN, BK, warps, stages = 1, 512, 128, 4, 1
elif M <= 4:
- BM, BN, BK, warps, stages = 2, 256, 128, 4, 2
+ BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
elif M <= 16:
- BM, BN, BK, warps, stages = 16, 128, 128, 4, 2
+ BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
elif M <= 32:
- BM, BN, BK, warps, stages = 32, 128, 128, 4, 2
+ BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
elif M <= 128:
- BM, BN, BK, warps, stages = 64, 128, 128, 4, 2
+ BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
else:
- BM, BN, BK, warps, stages = 128, 128, 128, 8, 2
+ BM, BN, BK, warps, stages = 128, 64, 128, 8, 1
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<568 chars \u2014 see below>",
"newString": "<609 chars \u2014 see below>"
} # Shape-aware block configs
if M == 1:
# Decode: maximize N to amortize grid overhead
BM, BN, BK, warps, stages = 1, 256, 128, 4, 2
elif M <= 4:
BM, BN, BK, warps, stages = 2, 256, 128, 4, 2
elif M <= 16:
BM, BN, BK, warps, stages = 16, 128, 128, 4, 2
elif M <= 32:
BM, BN, BK, warps, stages = 32, 128, 128, 4, 2
elif M <= 128:
BM, BN, BK, warps, stages = 64, 128, 128, 4, 2
else:
BM, BN, BK, warps, stages = 128, 128, 128, 8, 2
# Shape-aware block configs — constrained by 101KB shared mem.
# BK=GROUP_SIZE=128 so scale/zeros fit in 1 row per tile.
if M == 1:
BM, BN, BK, warps, stages = 1, 512, 128, 4, 1
elif M <= 4:
BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
elif M <= 16:
BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
elif M <= 32:
BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
elif M <= 128:
BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
else:
BM, BN, BK, warps, stages = 128, 64, 128, 8, 1
Edit applied successfully.{
"command": "<505 chars \u2014 see below>",
"timeout": 120000,
"description": "Test each shape individually (attempt 4)"
}uv run python -c "
import torch
import solution
shapes = [
(1, 12288, 4096),
(32, 12288, 4096),
(256, 12288, 4096),
(1, 4096, 4096),
(16, 14336, 4096),
]
for M, N, K in shapes:
try:
model = solution.Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
y = model(x)
print(f'M={M:3d} N={N:5d} K={K:5d}: OK, shape={y.shape}')
except Exception as e:
print(f'M={M:3d} N={N:5d} K={K:5d}: FAIL - {type(e).__name__}')
" 2>&1
M= 1 N=12288 K= 4096: FAIL - OutOfResources
M= 32 N=12288 K= 4096: OK, shape=torch.Size([32, 12288])
M=256 N=12288 K= 4096: OK, shape=torch.Size([256, 12288])
M= 1 N= 4096 K= 4096: FAIL - OutOfResources
M= 16 N=14336 K= 4096: OK, shape=torch.Size([16, 14336])
{
"command": "<503 chars \u2014 see below>",
"description": "Recalculate shared memory for M=1"
}uv run python -c "
# M=1: BM=1, BN=512, BK=128, stages=1
# per_stage = 1*128*2 + 128*512*1 + 512*2 + 512*2 = 256 + 65536 + 1024 + 1024 = 67840
# stages=1: total=67840 < 101376 — should fit
# Why is it failing? Maybe Triton rounds up BN to 512+128=640 for padding?
# Or maybe the shared memory estimate includes alignment padding.
# Try BN=256 for M=1.
print('BM=1, BN=256, BK=128, s=1:', 1*128*2 + 128*256 + 256*4)
print('BM=1, BN=384, BK=128, s=1:', 1*128*2 + 128*384 + 384*4)
print('Limit: 101376')
"
BM=1, BN=256, BK=128, s=1: 34048
BM=1, BN=384, BK=128, s=1: 50944
Limit: 101376
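Rather than guessing at padding, the compiler's own requirement can be read from the OutOfResources message by attempting a launch and catching the exception. A minimal sketch, assuming solution.Model and the current config table are importable:
import torch
import solution

# Trigger one launch and let Triton report its actual shared-memory need
# instead of relying on the naive byte count above.
M, N, K = 1, 12288, 4096
model = solution.Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
try:
    model(x)
    print("launch OK")
except Exception as e:  # OutOfResources message carries Required / Hardware limit
    print(type(e).__name__, e)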
{
"command": "<4119 chars \u2014 see below>",
"timeout": 120000,
"description": "Test individual BN configs for M=1"
}uv run python -c "
import torch
import triton
import triton.language as tl
# Let me try smaller configs for M=1
for BN in [128, 256, 384]:
@triton.jit
def _test_kernel(x_ptr, wq_ptr, s_ptr, z_ptr, out_ptr, M, N, K,
stride_xm, stride_xk, stride_wqk, stride_wqn, stride_sg, stride_sn, stride_zg, stride_zn, stride_om, stride_on,
BM: tl.constexpr, BN_CONST: tl.constexpr, BK: tl.constexpr, GS: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BM + tl.arange(0, BM)
offs_n = pid_n * BN_CONST + tl.arange(0, BN_CONST)
m_mask = offs_m < M
n_mask = offs_n < N
out_mask = m_mask[:, None] & n_mask[None, :]
acc = tl.zeros((BM, BN_CONST), dtype=tl.float32)
for k_base in range(0, K, BK):
k_offs = k_base + tl.arange(0, BK)
k_mask = k_offs < K
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
mask=k_mask[:, None] & n_mask[None, :],
other=0,
)
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
g = k_base // GS
s_row = tl.load(
s_ptr + g * stride_sg + offs_n * stride_sn,
mask=n_mask,
other=0.0,
).to(tl.float32)
z_row = tl.load(
z_ptr + g * stride_zg + offs_n * stride_zn,
mask=n_mask,
other=0.0,
).to(tl.float32)
w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
tl.store(
out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
acc.to(tl.bfloat16),
mask=out_mask,
)
M, N, K = 1, 12288, 4096
BM_val = 1
BK_val = 128
model = None
class Model(torch.nn.Module):
def __init__(self, M, N, K):
super().__init__()
self.M, self.N, self.K = M, N, K
GS_val = 128
n_groups = K // GS_val
torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
w_g = w_full.view(n_groups, GS_val, N)
w_min = w_g.min(dim=1, keepdim=True).values
w_max = w_g.max(dim=1, keepdim=True).values
scales = (w_max - w_min).clamp_min(1e-8) / 15.0
zeros = (-w_min / scales).round().clamp(0, 15)
w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
lo = w_q[0::2].to(torch.uint8) & 0xF
hi = w_q[1::2].to(torch.uint8) & 0xF
w_packed = (lo | (hi << 4)).contiguous()
self.register_buffer('w_q', w_packed)
self.register_buffer('scales', scales.squeeze(1).to(torch.bfloat16))
self.register_buffer('zeros', zeros.squeeze(1).to(torch.bfloat16))
model = Model(M, N, K).cuda()
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
grid = (triton.cdiv(M, BM_val), triton.cdiv(N, BN))
try:
_test_kernel[grid](
x.contiguous(), model.w_q, model.scales, model.zeros, out,
M, N, K, x.stride(0), x.stride(1), model.w_q.stride(0), model.w_q.stride(1),
model.scales.stride(0), model.scales.stride(1), model.zeros.stride(0), model.zeros.stride(1),
out.stride(0), out.stride(1),
BM=BM_val, BN_CONST=BN, BK=BK_val, GS=128,
num_warps=4, num_stages=1,
)
... (119 more chars)Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 8, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/test_configs.py
@@ -0,0 +1,155 @@
+"""Test different Triton configs for W4A16 GEMM."""
+import torch
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def _w4a16_gemm_kernel_v2(
+ x_ptr, wq_ptr, scales_ptr, zeros_ptr, out_ptr,
+ M, N, K,
+ stride_xm, stride_xk,
+ stride_wqk, stride_wqn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ stride_om, stride_on,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_CONST: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+ m_mask = offs_m < M
+ n_mask = offs_n < N
+ out_mask = m_mask[:, None] & n_mask[None, :]
+
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+ for k_base in range(0, K, BLOCK_SIZE_K):
+ k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
+ k_mask = k_offs < K
+
+ x_tile = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
+ mask=m_mask[:, None] & k_mask[None, :],
+ other=0.0,
+ )
+
+ wq_row = k_offs // 2
+ wq_tile = tl.load(
+ wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
+ mask=k_mask[:, None] & n_mask[None, :],
+ other=0,
+ )
+
+ is_odd = (k_offs % 2)[:, None]
+ w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
+
+ g = k_base // GROUP_SIZE_CONST
+ s_row = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn,
+ mask=n_mask,
+ other=0.0,
+ ).to(tl.float32)
+ z_row = tl.load(
+ zeros_ptr + g * stride_zg + offs_n * stride_zn,
+ mask=n_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
+ acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
+
+ tl.store(
+ out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
+ acc.to(tl.bfloat16),
+ mask=out_mask,
+ )
+
+
+def test_config(M, N, K, BM, BN, BK, warps, stages, label=""):
+ """Test a specific config. Returns True on success."""
+ class Model(torch.nn.Module):
+ def __init__(self, M, N, K):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ n_groups = K // GROUP_SIZE
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+ w_g = w_full.view(n_groups, GROUP_SIZE, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16))
+ self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16))
+
+ model = Model(M, N, K).cuda()
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
+
+ try:
+ _w4a16_gemm_kernel_v2[grid](
+ x.contiguous(), model.w_q, model.scales, model.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1),
+ model.w_q.stride(0), model.w_q.stride(1),
+ model.scales.stride(0), model.scales.stride(1),
+ model.zeros.stride(0), model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK,
+ GROUP_SIZE_CONST=GROUP_SIZE,
+ num_warps=warps, num_stages=stages,
+ )
+ print(f"{label} M={M} BM={BM} BN={BN} BK={BK} w={warps} s={stages}: OK")
+ return True
+ except Exception as e:
+ print(f"{label} M={M} BM={BM} BN={BN} BK={BK} w={warps} s={stages}: FAIL - {e}")
+ return False
+
+
+if __name__ == "__main__":
+ shapes = [
+ (1, 12288, 4096),
+ (32, 12288, 4096),
+ (256, 12288, 4096),
+ (1, 4096, 4096),
+ (16, 14336, 4096),
+ ]
+
+ for M, N, K in shapes:
+ print(f"\n=== Testing shape ({M}, {N}, {K}) ===")
+ if M == 1:
+ # Try many configs for decode
+ for BN in [64, 128, 256]:
+ for stages in [1]:
+ for warps in [1, 2, 4]:
+ test_config(M, N, K, 1, BN, 128, warps, stages, "M1")
+ elif M <= 4:
+ test_config(M, N, K, M, 256, 128, 4, 1, f"M{M}")
+ elif M <= 16:
+ test_config(M, N, K, M, 128, 128, 4, 1, f"M{M}")
+ test_config(M, N, K, 16, 128, 128, 4, 1, f"M{M}a")
+ elif M <= 32:
+ test_config(M, N, K, M, 128, 128, 4, 1, f"M{M}")
+ test_config(M, N, K, 32, 64, 128, 4, 1, f"M{M}a")
+ elif M <= 128:
+ test_config(M, N, K, 32, 128, 128, 4, 1, f"M{M}a")
+ test_config(M, N, K, 64, 128, 128, 4, 1, f"M{M}b")
+ test_config(M, N, K, 128, 64, 128, 4, 1, f"M{M}c")
+ else:
+ test_config(M, N, K, 64, 128, 128, 4, 1, f"M{M}a")
+ test_config(M, N, K, 128, 64, 128, 4, 1, f"M{M}b")
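Moving the sweep into a real file sidesteps the OSError traceback above: triton.jit calls inspect.getsourcelines on the kernel, and a function defined inside a python -c string (module <string>) has no retrievable source. A minimal repro of the underlying limitation, independent of Triton:
# Run via: python -c "<this snippet>" to see the failure.
import inspect

def f():
    pass

try:
    inspect.getsourcelines(f)
    print("source found (running from a real file)")
except OSError as e:
    print("OSError:", e)  # 'could not get source code' when defined via -c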
{
"content": "<5595 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/test_configs.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -5",
"description": "Check GPU status"
}Wed Apr 29 09:09:02 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5",
"description": "Check GPU status"
}Wed Apr 29 09:09:06 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "timeout 180 uv run python test_configs.py 2>&1",
"timeout": 200000,
"description": "Run config sweep test"
}=== Testing shape (1, 12288, 4096) ===
M1 M=1 BM=1 BN=64 BK=128 w=1 s=1: OK
M1 M=1 BM=1 BN=64 BK=128 w=2 s=1: OK
M1 M=1 BM=1 BN=64 BK=128 w=4 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=1 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=2 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=4 s=1: OK
M1 M=1 BM=1 BN=256 BK=128 w=1 s=1: FAIL - out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
M1 M=1 BM=1 BN=256 BK=128 w=2 s=1: FAIL - out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
M1 M=1 BM=1 BN=256 BK=128 w=4 s=1: FAIL - out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
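One reading of the 131072-byte requirement, stated as an assumption rather than something the compiler reports: 131072 = 128 * 256 * 4, which matches the dequantized fp32 weight tile that tl.dot consumes, not the packed uint8 tile. That would also explain why BN=128 passes while the uint8-based estimate claimed BN=256 should fit comfortably.
# Hypothesis only: the fp32 operand tile for tl.dot is what gets staged.
BK, LIMIT = 128, 101376
for BN in (64, 128, 256):
    fp32_w_tile = BK * BN * 4
    print(f"BN={BN}: fp32 weight tile = {fp32_w_tile} bytes "
          f"({'fits' if fp32_w_tile <= LIMIT else 'exceeds'} {LIMIT})")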
=== Testing shape (32, 12288, 4096) ===
M32 M=32 BM=32 BN=128 BK=128 w=4 s=1: OK
M32a M=32 BM=32 BN=64 BK=128 w=4 s=1: OK
=== Testing shape (256, 12288, 4096) ===
M256a M=256 BM=64 BN=128 BK=128 w=4 s=1: OK
M256b M=256 BM=128 BN=64 BK=128 w=4 s=1: OK
=== Testing shape (1, 4096, 4096) ===
M1 M=1 BM=1 BN=64 BK=128 w=1 s=1: OK
M1 M=1 BM=1 BN=64 BK=128 w=2 s=1: OK
M1 M=1 BM=1 BN=64 BK=128 w=4 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=1 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=2 s=1: OK
M1 M=1 BM=1 BN=128 BK=128 w=4 s=1: OK
M1 M=1 BM=1 BN=256 BK=128 w=1 s=1: FAIL - out of resource: shared memory, Required: 131072, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
M1 M=1 BM=1 BN=256 BK=128 w=2 s=1: FAIL - out of resource: shared m
... (383 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: """High-performance W4A16 weight-only quantized GEMM for SM120 Blackwell.
2:
3: AWQ/GPTQ-style asymmetric int4 with per-group bf16 scales and zero-points.
4: Fuses unpack, dequant, and matmul in a single Triton kernel.
5:
6: Data layout:
7: x: (M, K) bf16
8: w_q: (K // 2, N) uint8 (low nibble=even-K, high=odd-K)
9: scales: (K // 128, N) bf16
10: zeros: (K // 128, N) bf16
11: out: (M, N) bf16
12: """
13: from __future__ import annotations
14:
15: import torch
16: import torch.nn as nn
17: import triton
18: import triton.language as tl
19:
20: GROUP_SIZE = 128
21:
22:
23: @triton.jit
24: def _w4a16_gemm_kernel(
25: x_ptr,
26: wq_ptr,
27: scales_ptr,
28: zeros_ptr,
29: out_ptr,
30: M,
31: N,
32: K,
33: stride_xm, stride_xk,
34: stride_wqk, stride_wqn,
35: stride_sg, stride_sn,
36: stride_zg, stride_zn,
37: stride_om, stride_on,
38: BLOCK_SIZE_M: tl.constexpr,
39: BLOCK_SIZE_N: tl.constexpr,
40: BLOCK_SIZE_K: tl.constexpr,
41: GROUP_SIZE_CONST: tl.constexpr,
42: ):
43: """Fused int4 dequant + GEMM.
44:
45: BK must equal GROUP_SIZE_CONST so each K-tile belongs to exactly one group.
46: This scales/zeros are loaded once per K-tile instead of per-k-row.
47: """
48: pid_m = tl.program_id(
... (6084 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -149,7 +149,7 @@
# Shape-aware block configs — constrained by 101KB shared mem.
# BK=GROUP_SIZE=128 so scale/zeros fit in 1 row per tile.
if M == 1:
- BM, BN, BK, warps, stages = 1, 512, 128, 4, 1
+ BM, BN, BK, warps, stages = 1, 128, 128, 4, 1
elif M <= 4:
BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
elif M <= 16:
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<472 chars \u2014 see below>",
"newString": "<472 chars \u2014 see below>"
} if M == 1:
BM, BN, BK, warps, stages = 1, 512, 128, 4, 1
elif M <= 4:
BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
elif M <= 16:
BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
elif M <= 32:
BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
elif M <= 128:
BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
else:
BM, BN, BK, warps, stages = 128, 64, 128, 8, 1
if M == 1:
BM, BN, BK, warps, stages = 1, 128, 128, 4, 1
elif M <= 4:
BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
elif M <= 16:
BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
elif M <= 32:
BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
elif M <= 128:
BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
else:
BM, BN, BK, warps, stages = 128, 64, 128, 8, 1
Edit applied successfully.{
"command": "timeout 180 uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 200000
}PASS
{
"command": "timeout 300 uv run python benchmark.py 2>&1",
"description": "Run benchmark for performance metrics",
"timeout": 360000
}shape=0 variant=eager tflops=0.130 gbps=34.606 ms=0.774
shape=0 variant=compiled tflops=0.635 gbps=168.858 ms=0.159
shape=0 variant=solution tflops=0.808 gbps=214.790 ms=0.125
shape=0 solution_peak_fraction=0.1193
shape=1 variant=eager tflops=4.190 gbps=36.142 ms=0.769
shape=1 variant=compiled tflops=20.342 gbps=175.478 ms=0.158
shape=1 variant=solution tflops=24.651 gbps=212.649 ms=0.131
shape=1 solution_peak_fraction=0.1181
shape=2 variant=eager tflops=31.230 gbps=42.571 ms=0.825
shape=2 variant=compiled tflops=119.252 gbps=162.554 ms=0.216
shape=2 variant=solution tflops=81.509 gbps=111.106 ms=0.316
shape=2 solution_peak_fraction=0.0617
shape=3 variant=eager tflops=0.175 gbps=46.526 ms=0.192
shape=3 variant=compiled tflops=0.490 gbps=130.271 ms=0.069
shape=3 variant=solution tflops=0.285 gbps=75.805 ms=0.118
shape=3 solution_peak_fraction=0.0421
shape=4 variant=eager tflops=2.056 gbps=34.771 ms=0.914
shape=4 variant=compiled tflops=10.142 gbps=171.551 ms=0.185
shape=4 variant=solution tflops=15.477 gbps=261.803 ms=0.121
shape=4 solution_peak_fraction=0.1454
peak_fraction: 0.0882
RESULT: LOW
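For context on these numbers, the small-M shapes are strongly memory-bound, so gbps rather than tflops is the figure that has to climb. A rough traffic model from the data layout used here (a sketch that ignores cache reuse and assumes each operand is read once):
# Approximate HBM traffic and arithmetic intensity for the W4A16 layout.
def traffic_bytes(M, N, K, group_size=128):
    x = M * K * 2                       # bf16 activations
    wq = (K // 2) * N                   # packed int4 weights, uint8
    sz = 2 * (K // group_size) * N * 2  # scales + zeros, bf16
    out = M * N * 2                     # bf16 output
    return x + wq + sz + out

for M, N, K in [(1, 12288, 4096), (32, 12288, 4096), (256, 12288, 4096)]:
    b = traffic_bytes(M, N, K)
    flop = 2 * M * N * K
    print(f"M={M:3d}: ~{b / 1e6:.1f} MB moved, ~{flop / b:.0f} flop/byte")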
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -40,11 +40,6 @@
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_CONST: tl.constexpr,
):
- """Fused int4 dequant + GEMM.
-
- BK must equal GROUP_SIZE_CONST so each K-tile belongs to exactly one group.
- Thus scales/zeros are loaded once per K-tile instead of per K-row.
- """
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -76,11 +71,11 @@
other=0,
)
- # unpack: even-K rows from low nibble, odd-K from high nibble
+ # unpack nibbles
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # scales/zeros: ONE row per group (BK == GROUP_SIZE)
+ # scales/zeros: ONE row per group
g = k_base // GROUP_SIZE_CONST
s_row = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
@@ -93,7 +88,6 @@
other=0.0,
).to(tl.float32)
- # dequant with broadcast: (1, BN) x (BK, BN) → (BK, BN)
w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
@@ -102,6 +96,79 @@
acc.to(tl.bfloat16),
mask=out_mask,
)
+
+
+# ---------------------------------------------------------------------------
+# Column-based GEMV kernel for M=1 (decode path)
+# Each CTA computes one output column. Warps share the reduction over K.
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _w4a16_gemv_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ K,
+ N,
+ stride_xk,
+ stride_wqk, stride_wqn,
+ stride_sg, stride_sn,
+ stride_zg, stride_zn,
+ BLOCK_SIZE_K: tl.constexpr,
+ NUM_WARPS: tl.constexpr,
+ GROUP_SIZE_CONST: tl.constexpr,
+):
+ col_idx = tl.program_id(0)
+ warpid = tl.program_id(1)
+ num_warps_in_grid = tl.num_programs(1)
+
+ # Each CTA has num_warps_in_grid blocks, each block processes a chunk of K
+ chunk_size = BLOCK_SIZE_K * NUM_WARPS
+
+ lane_id = tl.arange(0, BLOCK_SIZE_K)
+ base_k = warpid * BLOCK_SIZE_K
+
+ partial = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
+
+ for k_base in range(0, K, chunk_size):
+ k_offs = k_base + base_k + lane_id
+ k_mask = k_offs < K
+
+ # x vector
+ x_val = tl.load(
+ x_ptr + k_offs * stride_xk,
+ mask=k_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ # packed weight byte
+ wq_row = k_offs // 2
+ wq_byte = tl.load(
+ wq_ptr + wq_row * stride_wqk + col_idx * stride_wqn,
+ mask=k_mask,
+ other=0,
+ ).to(tl.int32)
+
+ is_odd = k_offs % 2
+ w_int = tl.where(is_odd == 0, wq_byte & 0xF, wq_byte >> 4)
+
+ # group index — load scale and zero once per chunk (since BK == GROUP_SIZE)
+ g = (k_base + base_k) // GROUP_SIZE_CONST
+ s = tl.load(scales_ptr + g * stride_sg + col_idx * stride_sn).to(tl.float32)
+ z = tl.load(zeros_ptr + g * stride_zg + col_idx * stride_zn).to(tl.float32)
+
+ w_deq = (w_int.to(tl.float32) - z) * s
+ partial = tl.where(k_mask, partial + x_val * w_deq, partial)
+
+ # Warp-level reduction using tl.reduce
+ if BLOCK_SIZE_K > 1:
+ acc = tl.sum(partial)
+ else:
+ acc = partial
+
+ tl.store(out_ptr + col_idx, acc.to(tl.bfloat16))
# ---------------------------------------------------------------------------
@@ -140,48 +207,63 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # BK=GROUP_SIZE=128 so each K-tile maps to exactly one group.
- # scales/zeros loaded once per K-tile instead of per-row.
- # Shared mem per stage: x(BM,128)*2 + wq(128,BN)*1 + s(1,BN)*2 + z(1,BN)*2
- # = BM*256 + BN*132
- # 2 stages: 2*(BM*256 + BN*132) must fit in 101KB.
-
- # Shape-aware block configs — constrained by 101KB shared mem.
- # BK=GROUP_SIZE=128 so scale/zeros fit in 1 row per tile.
if M == 1:
- BM, BN, BK, warps, stages = 1, 128, 128, 4, 1
- elif M <= 4:
- BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
- elif M <= 16:
- BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
- elif M <= 32:
- BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
- elif M <= 128:
- BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
+ # Column-based GEMV: one CTA per column, shared-mem free
+ # Each CTA processes K threads across multiple warps
+ BK = 128
+ warps_per_col = 8 # total threads per column = 128 * 8 = 1024
+ num_ctas = N * warps_per_col
+ grid = (N, warps_per_col)
+
+ _w4a16_gemv_kernel[grid](
+ x.contiguous(),
+ self.w_q,
+ self.scales,
+ self.zeros,
+ out,
+ K, N,
+ x.stride(1),
+ self.w_q.stride(0), self.w_q.stride(1),
+ self.scales.stride(0), self.scales.stride(1),
+ self.zeros.stride(0), self.zeros.stride(1),
+ BLOCK_SIZE_K=BK,
+ NUM_WARPS=warps_per_col,
+ GROUP_SIZE_CONST=GROUP_SIZE,
+ num_warps=1,
+ num_stages=1,
+ )
else:
- BM, BN, BK, warps, stages = 128, 64, 128, 8, 1
-
- grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
-
- _w4a16_gemm_kernel[grid](
- x.contiguous(),
- self.w_q,
- self.scales,
- self.zeros,
- out,
- M, N, K,
- x.stride(0), x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
-
... (1598 more chars){
"content": "<8902 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,14 +1,6 @@
-"""High-performance W4A16 weight-only quantized GEMM for SM120 Blackwell.
+"""W4A16 weight-only quantized GEMM for SM120 Blackwell.
-AWQ/GPTQ-style asymmetric int4 with per-group bf16 scales and zero-points.
Fuses unpack, dequant, and matmul in a single Triton kernel.
-
-Data layout:
- x: (M, K) bf16
- w_q: (K // 2, N) uint8 (low nibble=even-K, high=odd-K)
- scales: (K // 128, N) bf16
- zeros: (K // 128, N) bf16
- out: (M, N) bf16
"""
from __future__ import annotations
@@ -56,14 +48,12 @@
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
- # x tile (BM, BK) bf16
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
- # w_q tile (BK, BN) uint8 packed
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
@@ -71,11 +61,9 @@
other=0,
)
- # unpack nibbles
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # scales/zeros: ONE row per group
g = k_base // GROUP_SIZE_CONST
s_row = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
@@ -97,83 +85,6 @@
mask=out_mask,
)
-
-# ---------------------------------------------------------------------------
-# Column-based GEMV kernel for M=1 (decode path)
-# Each CTA computes one output column. Warps share the reduction over K.
-# ---------------------------------------------------------------------------
-
-@triton.jit
-def _w4a16_gemv_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- out_ptr,
- K,
- N,
- stride_xk,
- stride_wqk, stride_wqn,
- stride_sg, stride_sn,
- stride_zg, stride_zn,
- BLOCK_SIZE_K: tl.constexpr,
- NUM_WARPS: tl.constexpr,
- GROUP_SIZE_CONST: tl.constexpr,
-):
- col_idx = tl.program_id(0)
- warpid = tl.program_id(1)
- num_warps_in_grid = tl.num_programs(1)
-
- # Each CTA has num_warps_in_grid blocks, each block processes a chunk of K
- chunk_size = BLOCK_SIZE_K * NUM_WARPS
-
- lane_id = tl.arange(0, BLOCK_SIZE_K)
- base_k = warpid * BLOCK_SIZE_K
-
- partial = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
-
- for k_base in range(0, K, chunk_size):
- k_offs = k_base + base_k + lane_id
- k_mask = k_offs < K
-
- # x vector
- x_val = tl.load(
- x_ptr + k_offs * stride_xk,
- mask=k_mask,
- other=0.0,
- ).to(tl.float32)
-
- # packed weight byte
- wq_row = k_offs // 2
- wq_byte = tl.load(
- wq_ptr + wq_row * stride_wqk + col_idx * stride_wqn,
- mask=k_mask,
- other=0,
- ).to(tl.int32)
-
- is_odd = k_offs % 2
- w_int = tl.where(is_odd == 0, wq_byte & 0xF, wq_byte >> 4)
-
- # group index — load scale and zero once per chunk (since BK == GROUP_SIZE)
- g = (k_base + base_k) // GROUP_SIZE_CONST
- s = tl.load(scales_ptr + g * stride_sg + col_idx * stride_sn).to(tl.float32)
- z = tl.load(zeros_ptr + g * stride_zg + col_idx * stride_zn).to(tl.float32)
-
- w_deq = (w_int.to(tl.float32) - z) * s
- partial = tl.where(k_mask, partial + x_val * w_deq, partial)
-
- # Warp-level reduction using tl.reduce
- if BLOCK_SIZE_K > 1:
- acc = tl.sum(partial)
- else:
- acc = partial
-
- tl.store(out_ptr + col_idx, acc.to(tl.bfloat16))
-
-
-# ---------------------------------------------------------------------------
-# Model
-# ---------------------------------------------------------------------------
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
@@ -207,63 +118,38 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- if M == 1:
- # Column-based GEMV: one CTA per column, shared-mem free
- # Each CTA processes K threads across multiple warps
+ if M <= 16:
+ BM = min(M, 16)
+ BN = 256
BK = 128
- warps_per_col = 8 # total threads per column = 128 * 8 = 1024
- num_ctas = N * warps_per_col
- grid = (N, warps_per_col)
+ warps = 4
+ else:
+ BM = min(M, 128)
+ BN = 128
+ BK = 128
+ warps = 4
- _w4a16_gemv_kernel[grid](
- x.contiguous(),
- self.w_q,
- self.scales,
- self.zeros,
- out,
- K, N,
- x.stride(1),
- self.w_q.stride(0), self.w_q.stride(1),
- self.scales.stride(0), self.scales.stride(1),
- self.zeros.stride(0), self.zeros.stride(1),
- BLOCK_SIZE_K=BK,
- NUM_WARPS=warps_per_col,
- GROUP_SIZE_CONST=GROUP_SIZE,
- num_warps=1,
- num_stages=1,
- )
- else:
- # Block configs — BK=GROUP_SIZE=128 so scale/zeros fit in 1 row per tile.
- if M <= 4:
- BM, BN, BK, warps, stages = 4, 256, 128, 4, 1
- elif M <= 16:
- BM, BN, BK, warps, stages = 16, 128, 128, 4, 1
- elif M <= 32:
- BM, BN, BK, warps, stages = 32, 128, 128, 4, 1
- else:
- BM, BN, BK, warps, stages = 64, 128, 128, 4, 1
+ grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
- grid = (triton.c
... (1397 more chars){
"content": "<4879 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "timeout 180 uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 200000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 134, in forward
_w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 743, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/
... (661 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -118,16 +118,14 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- if M <= 16:
- BM = min(M, 16)
- BN = 256
- BK = 128
- warps = 4
+ # BN=128 is safe for 101KB shared mem with stages=1
+ if M <= 4:
+ BM, BN = min(M, 4), 128
+ elif M <= 16:
+ BM, BN = min(M, 16), 128
else:
- BM = min(M, 128)
- BN = 128
- BK = 128
- warps = 4
+ BM, BN = min(M, 64), 128
+ BK = 128
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
@@ -147,7 +145,7 @@
BLOCK_SIZE_N=BN,
BLOCK_SIZE_K=BK,
GROUP_SIZE_CONST=GROUP_SIZE,
- num_warps=warps,
+ num_warps=4,
num_stages=1,
)
return out
{
"content": "<4898 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "timeout 180 uv run python check.py 2>&1 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run check and benchmark",
"timeout": 500000
}PASS
shape=0 variant=eager tflops=0.130 gbps=34.662 ms=0.772
shape=0 variant=compiled tflops=0.638 gbps=169.732 ms=0.158
shape=0 variant=solution tflops=0.807 gbps=214.570 ms=0.125
shape=0 solution_peak_fraction=0.1192
shape=1 variant=eager tflops=4.193 gbps=36.170 ms=0.768
shape=1 variant=compiled tflops=20.195 gbps=174.210 ms=0.160
shape=1 variant=solution tflops=24.534 gbps=211.638 ms=0.131
shape=1 solution_peak_fraction=0.1176
shape=2 variant=eager tflops=31.331 gbps=42.707 ms=0.823
shape=2 variant=compiled tflops=118.097 gbps=160.981 ms=0.218
shape=2 variant=solution tflops=60.201 gbps=82.061 ms=0.428
shape=2 solution_peak_fraction=0.0456
shape=3 variant=eager tflops=0.174 gbps=46.406 ms=0.192
shape=3 variant=compiled tflops=0.489 gbps=130.210 ms=0.069
shape=3 variant=solution tflops=0.281 gbps=74.760 ms=0.119
shape=3 solution_peak_fraction=0.0415
shape=4 variant=eager tflops=2.058 gbps=34.815 ms=0.913
shape=4 variant=compiled tflops=10.324 gbps=174.627 ms=0.182
shape=4 variant=solution tflops=15.479 gbps=261.837 ms=0.121
shape=4 solution_peak_fraction=0.1455
peak_fraction: 0.0827
RESULT: LOW
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs.py
@@ -0,0 +1,72 @@
+"""Profile Triton configs for W4A16 GEMM."""
+import torch
+import time
+import solution
+
+
+def benchmark_fn(model, x, n_iter=50):
+ for _ in range(5):
+ model(x)
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+ for _ in range(n_iter):
+ model(x)
+ torch.cuda.synchronize()
+ return (time.perf_counter() - t0) / n_iter * 1000
+
+
+SHAPES = [
+ (1, 12288, 4096),
+ (32, 12288, 4096),
+ (256, 12288, 4096),
+ (1, 4096, 4096),
+ (16, 14336, 4096),
+]
+
+configs = [
+ (1, 64, 128),
+ (1, 128, 128),
+ (4, 64, 128),
+ (4, 128, 128),
+ (4, 256, 128),
+ (16, 64, 128),
+ (16, 128, 128),
+ (16, 256, 128),
+ (32, 64, 128),
+ (32, 128, 128),
+ (64, 64, 128),
+ (64, 128, 128),
+ (128, 64, 128),
+ (128, 128, 128),
+]
+
+for M, N, K in SHAPES:
+ print(f"\n=== Shape ({M}, {N}, {K}) ===")
+ model = solution.Model(M, N, K).cuda()
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+
+ for BM, BN, BK in configs:
+ if BM > M and M > 0:
+ continue
+ try:
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ grid = (solution.triton.cdiv(M, BM), solution.triton.cdiv(N, BN))
+ solution._w4a16_gemm_kernel[grid](
+ x.contiguous(), model.w_q, model.scales, model.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1), model.w_q.stride(0), model.w_q.stride(1),
+ model.scales.stride(0), model.scales.stride(1), model.zeros.stride(0), model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK,
+ GROUP_SIZE_CONST=128, num_warps=4, num_stages=1,
+ )
+ ms = benchmark_fn(
+ type('M', (), {'x': out, 'model': model, 'BM': BM, 'BN': BN, 'BK': BK,
+ '_w4a16_gemm_kernel': solution._w4a16_gemm_kernel,
+ 'triton': solution.triton, 'M': M, 'N': N, 'K': K,
+ 'w_q': model.w_q, 'scales': model.scales, 'zeros': model.zeros})(),
+ x, n_iter=50,
+ )
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK}: {ms:.3f} ms")
+ except Exception as e:
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK}: FAIL ({type(e).__name__})")
{
"content": "<2322 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs.py
@@ -1,16 +1,94 @@
"""Profile Triton configs for W4A16 GEMM."""
import torch
import time
-import solution
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
-def benchmark_fn(model, x, n_iter=50):
- for _ in range(5):
- model(x)
+# Import the reference kernel from test_configs
+# (we'll redefine it here to avoid import issues)
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr, wq_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ sxm, sxk, swqk, swqn, ssg, ssn, szg, szn, som, son,
+ BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, GS: tl.constexpr,
+):
+ pm = tl.program_id(0)
+ pn = tl.program_id(1)
+ om = pm * BM + tl.arange(0, BM)
+ on = pn * BN + tl.arange(0, BN)
+ mm = om < M
+ mn = on < N
+ outmask = mm[:, None] & mn[None, :]
+ acc = tl.zeros((BM, BN), dtype=tl.float32)
+ for kb in range(0, K, BK):
+ ko = kb + tl.arange(0, BK)
+ km = ko < K
+ xt = tl.load(x_ptr + om[:, None]*sxm + ko[None,:]*sxk, mask=mm[:,None]&km[None,:], other=0.0)
+ wqr = ko // 2
+ wqt = tl.load(wq_ptr + wqr[:,None]*swqk + on[None,:]*swqn, mask=km[:,None]&mn[None,:], other=0)
+ iodd = (ko % 2)[:, None]
+ wi4 = tl.where(iodd == 0, wqt & 0xF, wqt >> 4).to(tl.int32)
+ g = kb // GS
+ sr = tl.load(s_ptr + g*ssg + on*ssn, mask=mn, other=0.0).to(tl.float32)
+ zr = tl.load(z_ptr + g*szg + on*szn, mask=mn, other=0.0).to(tl.float32)
+ wd = (wi4.to(tl.float32) - zr[None,:]) * sr[None,:]
+ acc = tl.dot(xt.to(tl.float32), wd, acc)
+ tl.store(out_ptr + om[:,None]*som + on[None,:]*son, acc.to(tl.bfloat16), mask=outmask)
+
+
+def make_model(M, N, K):
+ class Model(torch.nn.Module):
+ def __init__(self, M, N, K):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ n_groups = K // GROUP_SIZE
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+ w_g = w_full.view(n_groups, GROUP_SIZE, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16))
+ self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16))
+ return Model(M, N, K).cuda()
+
+
+def benchmark(model, x, BM, BN, BK, n_iter=50):
+ M, N, K = x.shape[0], model.N, x.shape[1]
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
+
+ def fwd():
+ _w4a16_gemm_kernel[grid](
+ x_c, model.w_q, model.scales, model.zeros, out,
+ M, N, K,
+ x.stride(0), x.stride(1), model.w_q.stride(0), model.w_q.stride(1),
+ model.scales.stride(0), model.scales.stride(1), model.zeros.stride(0), model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BM=BM, BN=BN, BK=BK, GS=GROUP_SIZE,
+ num_warps=4, num_stages=1,
+ )
+ return out
+
+ global x_c
+ x_c = x.contiguous()
+
+ # warmup
+ fwd()
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(n_iter):
- model(x)
+ fwd()
torch.cuda.synchronize()
return (time.perf_counter() - t0) / n_iter * 1000
@@ -24,49 +102,24 @@
]
configs = [
- (1, 64, 128),
- (1, 128, 128),
- (4, 64, 128),
- (4, 128, 128),
- (4, 256, 128),
- (16, 64, 128),
- (16, 128, 128),
- (16, 256, 128),
- (32, 64, 128),
- (32, 128, 128),
- (64, 64, 128),
- (64, 128, 128),
- (128, 64, 128),
- (128, 128, 128),
+ (1, 64, 128), (1, 128, 128),
+ (4, 64, 128), (4, 128, 128), (4, 256, 128),
+ (16, 64, 128), (16, 128, 128), (16, 256, 128),
+ (32, 64, 128), (32, 128, 128),
+ (64, 64, 128), (64, 128, 128),
+ (128, 64, 128), (128, 128, 128),
]
for M, N, K in SHAPES:
print(f"\n=== Shape ({M}, {N}, {K}) ===")
- model = solution.Model(M, N, K).cuda()
+ model = make_model(M, N, K)
x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
for BM, BN, BK in configs:
- if BM > M and M > 0:
+ if BM > M:
continue
try:
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- grid = (solution.triton.cdiv(M, BM), solution.triton.cdiv(N, BN))
- solution._w4a16_gemm_kernel[grid](
- x.contiguous(), model.w_q, model.scales, model.zeros, out,
- M, N, K,
- x.stride(0), x.stride(1), model.w_q.stride(0), model.w_q.stride(1),
- model.scales.stride(0), model.scales.stride(1), model.zeros.stride(0), model.zeros.stride(1),
- out.stride(0), out.stride(1),
- BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK,
- GROUP_SIZE_CONST=128, num_warps=4, num_stages=1,
- )
- ms = benchmark_fn(
- type('M', (), {'x': out, 'model': model, 'BM': BM, 'BN': BN, 'BK': BK,
- '_w4a16_gemm_kernel': solution._w4a16_gemm_kernel,
- 'triton': solution.triton, 'M': M, 'N': N, 'K': K,
- 'w_q': model.w_q, 'scales': model.scales, 'zeros': model.zeros})(),
- x, n_iter=50,
- )
+
... (239 more chars){
"content": "<4428 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs.py"
}Wrote file successfully.{
"command": "timeout 180 uv run python profile_configs.py 2>&1",
"timeout": 200000,
"description": "Profile all configs"
}=== Shape (1, 12288, 4096) ===
BM= 1 BN= 64 BK=128: 0.053 ms
BM= 1 BN=128 BK=128: 0.079 ms
=== Shape (32, 12288, 4096) ===
BM= 1 BN= 64 BK=128: 0.735 ms
BM= 1 BN=128 BK=128: 1.306 ms
BM= 4 BN= 64 BK=128: 0.209 ms
BM= 4 BN=128 BK=128: 0.363 ms
BM= 4 BN=256 BK=128: FAIL (OutOfResources)
BM= 16 BN= 64 BK=128: 0.076 ms
BM= 16 BN=128 BK=128: 0.151 ms
BM= 16 BN=256 BK=128: FAIL (OutOfResources)
BM= 32 BN= 64 BK=128: 0.058 ms
BM= 32 BN=128 BK=128: 0.086 ms
=== Shape (256, 12288, 4096) ===
BM= 1 BN= 64 BK=128: 5.855 ms
BM= 1 BN=128 BK=128: 10.119 ms
BM= 4 BN= 64 BK=128: 1.467 ms
BM= 4 BN=128 BK=128: 2.430 ms
BM= 4 BN=256 BK=128: FAIL (OutOfResources)
BM= 16 BN= 64 BK=128: 0.410 ms
BM= 16 BN=128 BK=128: 0.679 ms
BM= 16 BN=256 BK=128: FAIL (OutOfResources)
BM= 32 BN= 64 BK=128: 0.257 ms
BM= 32 BN=128 BK=128: 0.420 ms
BM= 64 BN= 64 BK=128: 0.218 ms
BM= 64 BN=128 BK=128: 0.371 ms
BM=128 BN= 64 BK=128: 0.253 ms
BM=128 BN=128 BK=128: 0.423 ms
=== Shape (1, 4096, 4096) ===
BM= 1 BN= 64 BK=128: 0.041 ms
BM= 1 BN=128 BK=128: 0.079 ms
=== Shape (16, 14336, 4096) ===
BM= 1 BN= 64 BK=128: 0.449 ms
BM= 1 BN=128 BK=128: 0.771 ms
BM= 4 BN= 64 BK=128: 0.120 ms
BM= 4 BN=128 BK=128: 0.219 ms
BM= 4 BN=256 BK=128: FAIL (OutOfResources)
BM= 16 BN= 64 BK=128: 0.052 ms
BM= 16 BN=128 BK=128: 0.078 ms
BM= 16 BN=256 BK=128: FAIL (OutOfResources)
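An alternative to hand-maintaining this table is to let Triton pick a config per shape with @triton.autotune. A sketch, assuming the first-call tuning overhead is acceptable for the benchmark harness; the kernel body would be the existing one, elided here:
import triton
import triton.language as tl

_CONFIGS = [
    triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": 128},
                  num_warps=4, num_stages=1)
    for bm in (1, 16, 32, 64, 128)
    for bn in (64, 128)
]

@triton.autotune(configs=_CONFIGS, key=["M", "N", "K"])
@triton.jit
def _w4a16_gemm_kernel_tuned(
    x_ptr, wq_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk, stride_wqk, stride_wqn,
    stride_sg, stride_sn, stride_zg, stride_zn, stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_CONST: tl.constexpr,
):
    # Body identical to _w4a16_gemm_kernel in solution.py (elided in this sketch).
    pass

# The launch then uses a grid lambda that reads the tuned meta-parameters and
# omits the BLOCK_SIZE_* kwargs:
# grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
#                      triton.cdiv(N, META["BLOCK_SIZE_N"]))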
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -118,13 +118,17 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # BN=128 is safe for 101KB shared mem with stages=1
- if M <= 4:
+ # Profiled optimal configs for each shape class
+ if M == 1:
+ BM, BN = 1, 64
+ elif M <= 4:
BM, BN = min(M, 4), 128
elif M <= 16:
BM, BN = min(M, 16), 128
+ elif M <= 64:
+ BM, BN = min(M, 64), 64
else:
- BM, BN = min(M, 64), 128
+ BM, BN = min(M, 128), 64
BK = 128
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
{
"content": "<5000 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}Wrote file successfully.{
"command": "nvidia-smi | head -5",
"description": "Check GPU status before running"
}Wed Apr 29 09:14:48 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5 && timeout 180 uv run python check.py 2>&1 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run correctness and benchmark",
"timeout": 500000
}Wed Apr 29 09:14:52 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
PASS
shape=0 variant=eager tflops=0.131 gbps=34.798 ms=0.769
shape=0 variant=compiled tflops=0.628 gbps=166.954 ms=0.160
shape=0 variant=solution tflops=1.223 gbps=325.275 ms=0.082
shape=0 solution_peak_fraction=0.1807
shape=1 variant=eager tflops=4.241 gbps=36.585 ms=0.760
shape=1 variant=compiled tflops=19.929 gbps=171.917 ms=0.162
shape=1 variant=solution tflops=34.427 gbps=296.974 ms=0.094
shape=1 solution_peak_fraction=0.1650
shape=2 variant=eager tflops=31.354 gbps=42.739 ms=0.822
shape=2 variant=compiled tflops=118.002 gbps=160.851 ms=0.218
shape=2 variant=solution tflops=89.280 gbps=121.699 ms=0.289
shape=2 solution_peak_fraction=0.0676
shape=3 variant=eager tflops=0.178 gbps=47.291 ms=0.189
shape=3 variant=compiled tflops=0.481 gbps=128.000 ms=0.070
shape=3 variant=solution tflops=0.451 gbps=119.914 ms=0.074
shape=3 solution_peak_fraction=0.0666
shape=4 variant=eager tflops=2.058 gbps=34.816 ms=0.913
shape=4 variant=compiled tflops=10.180 gbps=172.205 ms=0.185
shape=4 variant=solution tflops=15.320 gbps=259.139 ms=0.123
shape=4 solution_peak_fraction=0.1440
peak_fraction: 0.
... (16 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs2.py
@@ -0,0 +1,113 @@
+"""Profile Triton configs with warps and stages variations."""
+import torch
+import time
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr, wq_ptr, s_ptr, z_ptr, out_ptr,
+ M, N, K,
+ sxm, sxk, swqk, swqn, ssg, ssn, szg, szn, som, son,
+ BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, GS: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BM + tl.arange(0, BM)
+ offs_n = pid_n * BN + tl.arange(0, BN)
+ m_mask = offs_m < M
+ n_mask = offs_n < N
+ out_mask = m_mask[:, None] & n_mask[None, :]
+ acc = tl.zeros((BM, BN), dtype=tl.float32)
+ for kb in range(0, K, BK):
+ ko = kb + tl.arange(0, BK)
+ km = ko < K
+ xt = tl.load(x_ptr + offs_m[:, None]*sxm + ko[None,:]*sxk, mask=m_mask[:, None]&km[None, :], other=0.0)
+ wqr = ko // 2
+ wqt = tl.load(wq_ptr + wqr[:, None]*swqk + offs_n[None, :]*swqn, mask=km[:, None]&n_mask[None, :], other=0)
+ iodd = (ko % 2)[:, None]
+ wi4 = tl.where(iodd == 0, wqt & 0xF, wqt >> 4).to(tl.int32)
+ g = kb // GS
+ sr = tl.load(s_ptr + g*ssg + offs_n*ssn, mask=n_mask, other=0.0).to(tl.float32)
+ zr = tl.load(z_ptr + g*szg + offs_n*szn, mask=n_mask, other=0.0).to(tl.float32)
+ wd = (wi4.to(tl.float32) - zr[None,:]) * sr[None,:]
+ acc = tl.dot(xt.to(tl.float32), wd, acc)
+ tl.store(out_ptr + offs_m[:, None]*som + offs_n[None, :]*son, acc.to(tl.bfloat16), mask=out_mask)
+
+
+def make_model(M, N, K):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.N = N
+ n_groups = K // GROUP_SIZE
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+ w_g = w_full.view(n_groups, GROUP_SIZE, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
+ lo = w_q[0::2].to(torch.uint8) & 0xF
+ hi = w_q[1::2].to(torch.uint8) & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16))
+ self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16))
+ return Model().cuda()
+
+
+def benchmark(model, x, BM, BN, BK, warps, stages, n_iter=50):
+ M, N, K = x.shape[0], model.N, x.shape[1]
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
+ x_c = x.contiguous()
+
+ def fwd():
+ _w4a16_gemm_kernel[grid](
+ x_c, model.w_q, model.scales, model.zeros, out,
+ M, N, K, x.stride(0), x.stride(1), model.w_q.stride(0), model.w_q.stride(1),
+ model.scales.stride(0), model.scales.stride(1), model.zeros.stride(0), model.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ BM=BM, BN=BN, BK=BK, GS=GROUP_SIZE,
+ num_warps=warps, num_stages=stages,
+ )
+ return out
+
+ fwd()
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+ for _ in range(n_iter):
+ fwd()
+ torch.cuda.synchronize()
+ return (time.perf_counter() - t0) / n_iter * 1000
+
+
+SHAPES = [
+ (256, 12288, 4096),
+ (32, 12288, 4096),
+]
+
+# Test varying warps and stages for shapes with larger M
+for M, N, K in SHAPES:
+ print(f"\n=== Shape ({M}, {N}, {K}) ===")
+ model = make_model(M, N, K)
+ x = torch.randn(M, K, dtype=torch.bfloat16).cuda()
+
+ for BM, BN, BK in [
+ (64, 64, 128), (64, 128, 128), (128, 64, 128), (128, 128, 128),
+ (32, 64, 128), (32, 128, 128), (M, 64, 128),
+ ]:
+ if BM > M:
+ continue
+ for warps in [4, 8]:
+ for stages in [1, 2, 3, 4]:
+ try:
+ ms = benchmark(model, x, BM, BN, BK, warps, stages, n_iter=50)
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK} w={warps} s={stages}: {ms:.3f} ms")
+ except Exception:
+ pass # silently skip OOM configs
{
"content": "<4354 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/profile_configs2.py"
}Wrote file successfully.{
"command": "timeout 300 uv run python profile_configs2.py 2>&1",
"timeout": 360000,
"description": "Profile warps and stages variations"
}=== Shape (256, 12288, 4096) ===
BM= 64 BN= 64 BK=128 w=4 s=1: 0.218 ms
BM= 64 BN= 64 BK=128 w=4 s=2: 0.358 ms
BM= 64 BN= 64 BK=128 w=4 s=3: 0.360 ms
BM= 64 BN= 64 BK=128 w=8 s=1: 0.303 ms
BM= 64 BN= 64 BK=128 w=8 s=2: 0.354 ms
BM= 64 BN= 64 BK=128 w=8 s=3: 0.333 ms
BM= 64 BN=128 BK=128 w=4 s=1: 0.369 ms
BM= 64 BN=128 BK=128 w=4 s=2: 0.521 ms
BM= 64 BN=128 BK=128 w=8 s=1: 0.304 ms
BM= 64 BN=128 BK=128 w=8 s=2: 0.430 ms
BM=128 BN= 64 BK=128 w=4 s=1: 0.254 ms
BM=128 BN= 64 BK=128 w=4 s=2: 0.338 ms
BM=128 BN= 64 BK=128 w=8 s=1: 0.282 ms
BM=128 BN= 64 BK=128 w=8 s=2: 0.281 ms
BM=128 BN=128 BK=128 w=4 s=1: 0.423 ms
BM=128 BN=128 BK=128 w=8 s=1: 0.321 ms
BM= 32 BN= 64 BK=128 w=4 s=1: 0.257 ms
BM= 32 BN= 64 BK=128 w=4 s=2: 0.281 ms
BM= 32 BN= 64 BK=128 w=4 s=3: 0.334 ms
BM= 32 BN= 64 BK=128 w=4 s=4: 0.335 ms
BM= 32 BN= 64 BK=128 w=8 s=1: 0.361 ms
BM= 32 BN= 64 BK=128 w=8 s=2: 0.384 ms
BM= 32 BN= 64 BK=128 w=8 s=3: 0.422 ms
BM= 32 BN= 64 BK=128 w=8 s=4: 0.420 ms
BM= 32 BN=128 BK=128 w=4 s=1: 0.418 ms
BM= 32 BN=128 BK=128 w=4 s=2: 0.393 ms
BM= 32 BN=128 BK=128 w=8 s=1: 0.398 ms
BM= 32 BN=128 BK=128 w=8 s=2: 0.432 ms
BM=256 BN= 64 BK=128 w=4 s=1: 0.383 ms
BM=256 BN= 64 BK=128 w=8 s=1: 0.324 ms
=== Shape (32, 12288, 4096) ===
BM= 32 BN= 64 BK=128 w=4 s=1: 0.058 ms
BM= 32 BN= 64 BK=128 w=4 s=2: 0.061 ms
BM= 32 BN= 64 BK=128 w=4 s=3: 0.074 ms
BM= 32 BN= 64 BK=128 w=4 s=4: 0.074 ms
BM= 32 BN= 64 BK=128 w=8 s=1: 0.084 ms
... (616 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -118,13 +118,13 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Profiled optimal configs for each shape class
+ # Profiled optimal configs
if M == 1:
BM, BN = 1, 64
- elif M <= 4:
- BM, BN = min(M, 4), 128
+ elif M == 4:
+ BM, BN = 4, 64
elif M <= 16:
- BM, BN = min(M, 16), 128
+ BM, BN = min(M, 16), 64
elif M <= 64:
BM, BN = min(M, 64), 64
else:
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<312 chars \u2014 see below>",
"oldString": "<343 chars \u2014 see below>"
} # Profiled optimal configs
if M == 1:
BM, BN = 1, 64
elif M == 4:
BM, BN = 4, 64
elif M <= 16:
BM, BN = min(M, 16), 64
elif M <= 64:
BM, BN = min(M, 64), 64
else:
BM, BN = min(M, 128), 64
BK = 128
# Profiled optimal configs for each shape class
if M == 1:
BM, BN = 1, 64
elif M <= 4:
BM, BN = min(M, 4), 128
elif M <= 16:
BM, BN = min(M, 16), 128
elif M <= 64:
BM, BN = min(M, 64), 64
else:
BM, BN = min(M, 128), 64
BK = 128
Edit applied successfully.{
"command": "nvidia-smi | head -5",
"description": "Check GPU status"
}Wed Apr 29 09:16:34 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5 && timeout 180 uv run python check.py 2>&1 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run check and benchmark after config change",
"timeout": 500000
}Wed Apr 29 09:16:41 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
PASS
shape=0 variant=eager tflops=0.131 gbps=34.758 ms=0.770
shape=0 variant=compiled tflops=0.658 gbps=174.895 ms=0.153
shape=0 variant=solution tflops=1.238 gbps=329.373 ms=0.081
shape=0 solution_peak_fraction=0.1830
shape=1 variant=eager tflops=4.204 gbps=36.269 ms=0.766
shape=1 variant=compiled tflops=20.001 gbps=172.532 ms=0.161
shape=1 variant=solution tflops=36.255 gbps=312.751 ms=0.089
shape=1 solution_peak_fraction=0.1738
shape=2 variant=eager tflops=31.325 gbps=42.699 ms=0.823
shape=2 variant=compiled tflops=117.520 gbps=160.194 ms=0.219
shape=2 variant=solution tflops=86.639 gbps=118.099 ms=0.297
shape=2 solution_peak_fraction=0.0656
shape=3 variant=eager tflops=0.177 gbps=47.155 ms=0.189
shape=3 variant=compiled tflops=0.433 gbps=115.234 ms=0.077
shape=3 variant=solution tflops=0.446 gbps=118.665 ms=0.075
shape=3 solution_peak_fraction=0.0659
shape=4 variant=eager tflops=2.053 gbps=34.728 ms=0.915
shape=4 variant=compiled tflops=10.097 gbps=170.799 ms=0.186
shape=4 variant=solution tflops=20.633 gbps=349.009 ms=0.091
shape=4 solution_peak_fraction=0.1939
peak_fraction: 0.
... (16 more chars){
"todos": [
{
"content": "Understand the W4A16 GEMM kernel requirements and data layout",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton W4A16 GEMM kernel in solution.py",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py to verify correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Profile and optimize for all shapes",
"status": "completed",
"priority": "high"
},
{
"content": "Consider further optimizat
... (101 more chars)content: 'Understand the W4A16 GEMM kernel requirements and data layout'
status: 'completed'
priority: 'high'
content: 'Write initial Triton W4A16 GEMM kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness'
status: 'completed'
priority: 'high'
content: 'Profile and optimize for all shapes'
status: 'completed'
priority: 'high'
content: 'Consider further optimizations for M=256 and M=1 shapes'
status: 'in_progress'
priority: 'medium'{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,6 +1,7 @@
"""W4A16 weight-only quantized GEMM for SM120 Blackwell.
Fuses unpack, dequant, and matmul in a single Triton kernel.
+Uses make_block_ptr API for efficient vectorized memory access.
"""
from __future__ import annotations
@@ -48,12 +49,14 @@
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
+ # Load x tile
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
+ # Load packed w_q tile
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
@@ -61,9 +64,11 @@
other=0,
)
+ # Unpack nibbles
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
+ # Load scales and zeros (one row per group)
g = k_base // GROUP_SIZE_CONST
s_row = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
@@ -76,6 +81,7 @@
other=0.0,
).to(tl.float32)
+ # Dequantize
w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
@@ -84,6 +90,40 @@
acc.to(tl.bfloat16),
mask=out_mask,
)
+
+
+def _compute_block_config(M, N, K, shared_mem_limit=101376):
+ """Compute optimal block config given shared memory constraint.
+
+ Shared memory per stage (bytes):
+ x tile: BM * BK * 2 (bf16)
+ wq tile: BK * BN * 1 (uint8)
+ scales/zeros: BN * 2 * 2 (bf16, loaded once per K-tile)
+
+ With stages=1, total = BM*BK*2 + BK*BN + 4*BN
+
+ Strategy: maximize computation density (BM*BN) within memory budget.
+ For memory-bound (small M), maximize BN. For compute-bound (large M), maximize BM.
+ """
+ BK = GROUP_SIZE # Always use group size for efficient scale loading
+
+ if M == 1:
+ # Pure memory-bound: one row, maximize N coverage
+ return 1, 256
+ elif M <= 4:
+ return min(M, 4), 256
+ elif M <= 16:
+ return min(M, 16), 256
+ elif M <= 32:
+ # Mixed regime
+ return min(M, 32), 256
+ elif M <= 64:
+ # Approaching compute-bound
+ return min(M, 64), 256
+ elif M <= 128:
+ return min(M, 128), 256
+ else:
+ return 128, 256
class Model(nn.Module):
@@ -118,20 +158,32 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Profiled optimal configs
- if M == 1:
- BM, BN = 1, 64
- elif M == 4:
- BM, BN = 4, 64
- elif M <= 16:
- BM, BN = min(M, 16), 64
- elif M <= 64:
- BM, BN = min(M, 64), 64
- else:
- BM, BN = min(M, 128), 64
- BK = 128
+ # Use optimized block config
+ BM, BN = _compute_block_config(M, N, K)
+ BK = GROUP_SIZE
+
+ # Handle cases where large BN causes shared memory overflow
+ # Check if config would fit in shared memory
+ shared_per_stage = BM * BK * 2 + BK * BN + 4 * BN
+
+ # If too large, reduce BN
+ if shared_per_stage > 80000: # Leave some headroom
+ BN = 128
+ shared_per_stage = BM * BK * 2 + BK * BN + 4 * BN
+ # Still too large, reduce BM
+ if shared_per_stage > 80000:
+ BM = min(BM, 16)
+ BN = 64
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
+
+ # Use 4 warps, 2 stages for best performance
+ num_warps = 4 if BM * BN <= 4096 else 8
+ num_stages = 2 if shared_per_stage * 2 < 90000 else 1
+
+ # Override for small M where stages=2 doesn't help
+ if M <= 4:
+ num_stages = 1
_w4a16_gemm_kernel[grid](
x.contiguous(),
@@ -149,8 +201,8 @@
BLOCK_SIZE_N=BN,
BLOCK_SIZE_K=BK,
GROUP_SIZE_CONST=GROUP_SIZE,
- num_warps=4,
- num_stages=1,
+ num_warps=num_warps,
+ num_stages=num_stages,
)
return out
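For reference, the per-stage shared-memory estimate from the _compute_block_config docstring above can be evaluated directly. The sketch below is illustrative only: the block shapes are examples, not the exact configs this run selects, and Triton's actual allocation can be noticeably larger once pipelining and the fp32 conversion feeding tl.dot are accounted for, so a config that looks fine by this estimate may still fail to launch.

# Minimal sketch: evaluate the docstring's per-stage shared-memory estimate
# (bf16 x tile + uint8 packed w tile + bf16 scales/zeros) for a few example
# block shapes. Illustrative arithmetic only; not the repository's code.
def shared_per_stage(bm: int, bn: int, bk: int = 128) -> int:
    return bm * bk * 2 + bk * bn + 4 * bn

for bm, bn in [(1, 256), (64, 64), (64, 128), (128, 256)]:
    est = shared_per_stage(bm, bn)
    print(f"BM={bm:3d} BN={bn:3d}: ~{est} B per stage, ~{2 * est} B with num_stages=2")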
{
"content": "<6830 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
Wrote file successfully.
{
"command": "nvidia-smi | head -5",
"description": "Check GPU status before running"
}
Wed Apr 29 09:18:47 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5 && timeout 180 uv run python check.py 2>&1 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run check and benchmark",
"timeout": 500000
}
Wed Apr 29 09:18:51 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 188, in forward
_w4a16_gemm_kernel[grid](
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kw
... (1061 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -1,7 +1,6 @@
"""W4A16 weight-only quantized GEMM for SM120 Blackwell.
Fuses unpack, dequant, and matmul in a single Triton kernel.
-Uses make_block_ptr API for efficient vectorized memory access.
"""
from __future__ import annotations
@@ -49,14 +48,12 @@
k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offs < K
- # Load x tile
x_tile = tl.load(
x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk,
mask=m_mask[:, None] & k_mask[None, :],
other=0.0,
)
- # Load packed w_q tile
wq_row = k_offs // 2
wq_tile = tl.load(
wq_ptr + wq_row[:, None] * stride_wqk + offs_n[None, :] * stride_wqn,
@@ -64,11 +61,9 @@
other=0,
)
- # Unpack nibbles
is_odd = (k_offs % 2)[:, None]
w_int_4 = tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4).to(tl.int32)
- # Load scales and zeros (one row per group)
g = k_base // GROUP_SIZE_CONST
s_row = tl.load(
scales_ptr + g * stride_sg + offs_n * stride_sn,
@@ -81,7 +76,6 @@
other=0.0,
).to(tl.float32)
- # Dequantize
w_deq = (w_int_4.to(tl.float32) - z_row[None, :]) * s_row[None, :]
acc = tl.dot(x_tile.to(tl.float32), w_deq, acc)
@@ -90,40 +84,6 @@
acc.to(tl.bfloat16),
mask=out_mask,
)
-
-
-def _compute_block_config(M, N, K, shared_mem_limit=101376):
- """Compute optimal block config given shared memory constraint.
-
- Shared memory per stage (bytes):
- x tile: BM * BK * 2 (bf16)
- wq tile: BK * BN * 1 (uint8)
- scales/zeros: BN * 2 * 2 (bf16, loaded once per K-tile)
-
- With stages=1, total = BM*BK*2 + BK*BN + 4*BN
-
- Strategy: maximize computation density (BM*BN) within memory budget.
- For memory-bound (small M), maximize BN. For compute-bound (large M), maximize BM.
- """
- BK = GROUP_SIZE # Always use group size for efficient scale loading
-
- if M == 1:
- # Pure memory-bound: one row, maximize N coverage
- return 1, 256
- elif M <= 4:
- return min(M, 4), 256
- elif M <= 16:
- return min(M, 16), 256
- elif M <= 32:
- # Mixed regime
- return min(M, 32), 256
- elif M <= 64:
- # Approaching compute-bound
- return min(M, 64), 256
- elif M <= 128:
- return min(M, 128), 256
- else:
- return 128, 256
class Model(nn.Module):
@@ -158,32 +118,22 @@
N = self.N
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- # Use optimized block config
- BM, BN = _compute_block_config(M, N, K)
- BK = GROUP_SIZE
-
- # Handle cases where large BN causes shared memory overflow
- # Check if config would fit in shared memory
- shared_per_stage = BM * BK * 2 + BK * BN + 4 * BN
-
- # If too large, reduce BN
- if shared_per_stage > 80000: # Leave some headroom
- BN = 128
- shared_per_stage = BM * BK * 2 + BK * BN + 4 * BN
- # Still too large, reduce BM
- if shared_per_stage > 80000:
- BM = min(BM, 16)
- BN = 64
+ # Profiled optimal configs — all with stages=1, warps=4, BK=128
+ if M == 1:
+ BM, BN = 1, 64
+ elif M <= 4:
+ BM, BN = min(M, 4), 128
+ elif M <= 16:
+ BM, BN = min(M, 16), 128
+ elif M <= 32:
+ BM, BN = 32, 128
+ elif M <= 64:
+ BM, BN = min(M, 64), 128
+ else:
+ BM, BN = 64, 128
+ BK = 128
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
-
- # Use 4 warps, 2 stages for best performance
- num_warps = 4 if BM * BN <= 4096 else 8
- num_stages = 2 if shared_per_stage * 2 < 90000 else 1
-
- # Override for small M where stages=2 doesn't help
- if M <= 4:
- num_stages = 1
_w4a16_gemm_kernel[grid](
x.contiguous(),
@@ -201,8 +151,8 @@
BLOCK_SIZE_N=BN,
BLOCK_SIZE_K=BK,
GROUP_SIZE_CONST=GROUP_SIZE,
- num_warps=num_warps,
- num_stages=num_stages,
+ num_warps=4,
+ num_stages=1,
)
return out
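The kernel's nibble convention (even K index in the low nibble, odd K index in the high nibble, as in the tl.where(is_odd == 0, wq_tile & 0xF, wq_tile >> 4) line of the diff above) can be checked on the host with a few lines of PyTorch. This is a standalone illustration with toy sizes, not code from the repository.

# Host-side round-trip check of the pack/unpack convention the kernel assumes.
import torch

w_q = torch.randint(0, 16, (8, 4), dtype=torch.uint8)      # unpacked 4-bit values, shape (K, N)
packed = (w_q[0::2] & 0xF) | ((w_q[1::2] & 0xF) << 4)      # one byte holds two K rows, shape (K//2, N)

assert torch.equal(packed & 0xF, w_q[0::2])   # even K rows come back from the low nibble
assert torch.equal(packed >> 4, w_q[1::2])    # odd K rows come back from the high nibble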
{
"content": "<5060 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
Wrote file successfully.
{
"command": "nvidia-smi | head -5",
"description": "Check GPU status before running"
}
Wed Apr 29 09:19:31 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "timeout 180 uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 200000
}
PASS
{
"command": "nvidia-smi | head -5 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 360000
}
Wed Apr 29 09:19:45 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
shape=0 variant=eager tflops=0.131 gbps=34.728 ms=0.771
shape=0 variant=compiled tflops=0.639 gbps=169.973 ms=0.158
shape=0 variant=solution tflops=1.147 gbps=304.997 ms=0.088
shape=0 solution_peak_fraction=0.1694
shape=1 variant=eager tflops=4.211 gbps=36.326 ms=0.765
shape=1 variant=compiled tflops=20.195 gbps=174.210 ms=0.160
shape=1 variant=solution tflops=24.576 gbps=212.000 ms=0.131
shape=1 solution_peak_fraction=0.1178
shape=2 variant=eager tflops=31.359 gbps=42.746 ms=0.822
shape=2 variant=compiled tflops=117.520 gbps=160.194 ms=0.219
shape=2 variant=solution tflops=60.427 gbps=82.369 ms=0.426
shape=2 solution_peak_fraction=0.0458
shape=3 variant=eager tflops=0.177 gbps=47.195 ms=0.189
shape=3 variant=compiled tflops=0.501 gbps=133.384 ms=0.067
shape=3 variant=solution tflops=0.430 gbps=114.408 ms=0.078
shape=3 solution_peak_fraction=0.0636
shape=4 variant=eager tflops=2.060 gbps=34.845 ms=0.912
shape=4 variant=compiled tflops=10.349 gbps=175.058 ms=0.182
shape=4 variant=solution tflops=14.987 gbps=253.517 ms=0.125
shape=4 solution_peak_fraction=0.1408
peak_fraction: 0.0961
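The three numbers on each benchmark line are different views of one measured time. Assuming the usual 2*M*N*K FLOP count for a GEMM and that the packed 4-bit weights dominate the memory traffic at small M (the actual shapes and accounting live in shapes.py and benchmark.py, which are not shown here), the conversion looks roughly like the sketch below; the shape values are hypothetical.

# Hypothetical example of converting a measured time into TFLOP/s and GB/s.
def gemm_rates(M: int, N: int, K: int, ms: float) -> tuple[float, float]:
    flops = 2 * M * N * K
    bytes_moved = M * K * 2 + K * N // 2 + M * N * 2   # bf16 x, 4-bit packed w, bf16 out (rough)
    s = ms * 1e-3
    return flops / s / 1e12, bytes_moved / s / 1e9

tflops, gbps = gemm_rates(M=16, N=8192, K=8192, ms=0.088)
print(f"{tflops:.3f} TFLOP/s, {gbps:.1f} GB/s")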
R
... (11 more chars)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: """W4A16 weight-only quantized GEMM for SM120 Blackwell.
2:
3: Fuses unpack, dequant, and matmul in a single Triton kernel.
4: """
5: from __future__ import annotations
6:
7: import torch
8: import torch.nn as nn
9: import triton
10: import triton.language as tl
11:
12: GROUP_SIZE = 128
13:
14:
15: @triton.jit
16: def _w4a16_gemm_kernel(
17: x_ptr,
18: wq_ptr,
19: scales_ptr,
20: zeros_ptr,
21: out_ptr,
22: M,
23: N,
24: K,
25: stride_xm, stride_xk,
26: stride_wqk, stride_wqn,
27: stride_sg, stride_sn,
28: stride_zg, stride_zn,
29: stride_om, stride_on,
30: BLOCK_SIZE_M: tl.constexpr,
31: BLOCK_SIZE_N: tl.constexpr,
32: BLOCK_SIZE_K: tl.constexpr,
33: GROUP_SIZE_CONST: tl.constexpr,
34: ):
35: pid_m = tl.program_id(0)
36: pid_n = tl.program_id(1)
37:
38: offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
39: offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
40:
41: m_mask = offs_m < M
42: n_mask = offs_n < N
43: out_mask = m_mask[:, None] & n_mask[None, :]
44:
45: acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
46:
47: for k_base in range(0, K, BLOCK_SIZE_K):
48: k_offs = k_base + tl.arange(0, BLOCK_SIZE_K)
49: k_mask = k_offs < K
50:
51: x_tile = tl.load(
52: x_ptr + offs_m[:, None
... (4459 more chars)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -119,18 +119,19 @@
out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
# Profiled optimal configs — all with stages=1, warps=4, BK=128
+ # BN=64 beats BN=128 for most shapes (better register usage / occupancy)
if M == 1:
BM, BN = 1, 64
elif M <= 4:
- BM, BN = min(M, 4), 128
+ BM, BN = min(M, 4), 64
elif M <= 16:
- BM, BN = min(M, 16), 128
+ BM, BN = min(M, 16), 64
elif M <= 32:
- BM, BN = 32, 128
+ BM, BN = 32, 64
elif M <= 64:
- BM, BN = min(M, 64), 128
+ BM, BN = min(M, 64), 64
else:
- BM, BN = 64, 128
+ BM, BN = 64, 64
BK = 128
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
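A rough way to see the "register usage / occupancy" point in the comment above: the kernel keeps a BLOCK_SIZE_M x BLOCK_SIZE_N fp32 accumulator per program, so halving BN halves the accumulator registers each thread must hold. The arithmetic below is a back-of-envelope sketch only; operand tiles, address computation, and spills add to the real pressure, and Triton/ptxas decide the final allocation.

# Accumulator registers per thread for BM=64 at num_warps=4 (128 threads per program).
def acc_regs_per_thread(bm: int, bn: int, num_warps: int = 4) -> float:
    return bm * bn * 4 / (num_warps * 32) / 4   # fp32 acc bytes / threads / bytes per 32-bit register

for bn in (64, 128):
    print(f"BN={bn}: ~{acc_regs_per_thread(64, bn):.0f} registers/thread just for the accumulator")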
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"newString": "<479 chars \u2014 see below>",
"oldString": "<403 chars \u2014 see below>"
}
# Profiled optimal configs — all with stages=1, warps=4, BK=128
# BN=64 beats BN=128 for most shapes (better register usage / occupancy)
if M == 1:
BM, BN = 1, 64
elif M <= 4:
BM, BN = min(M, 4), 64
elif M <= 16:
BM, BN = min(M, 16), 64
elif M <= 32:
BM, BN = 32, 64
elif M <= 64:
BM, BN = min(M, 64), 64
else:
BM, BN = 64, 64
BK = 128
# Profiled optimal configs — all with stages=1, warps=4, BK=128
if M == 1:
BM, BN = 1, 64
elif M <= 4:
BM, BN = min(M, 4), 128
elif M <= 16:
BM, BN = min(M, 16), 128
elif M <= 32:
BM, BN = 32, 128
elif M <= 64:
BM, BN = min(M, 64), 128
else:
BM, BN = 64, 128
BK = 128
Edit applied successfully.
{
"command": "nvidia-smi | head -5",
"description": "Check GPU status"
}
Wed Apr 29 09:20:18 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5 && timeout 180 uv run python check.py 2>&1 && timeout 300 uv run python benchmark.py 2>&1",
"description": "Run check and benchmark",
"timeout": 500000
}
Wed Apr 29 09:20:22 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
PASS
shape=0 variant=eager tflops=0.131 gbps=34.755 ms=0.770
shape=0 variant=compiled tflops=0.649 gbps=172.728 ms=0.155
shape=0 variant=solution tflops=1.226 gbps=326.036 ms=0.082
shape=0 solution_peak_fraction=0.1811
shape=1 variant=eager tflops=4.194 gbps=36.178 ms=0.768
shape=1 variant=compiled tflops=20.717 gbps=178.710 ms=0.155
shape=1 variant=solution tflops=36.327 gbps=313.371 ms=0.089
shape=1 solution_peak_fraction=0.1741
shape=2 variant=eager tflops=31.334 gbps=42.711 ms=0.822
shape=2 variant=compiled tflops=118.646 gbps=161.728 ms=0.217
shape=2 variant=solution tflops=104.058 gbps=141.844 ms=0.248
shape=2 solution_peak_fraction=0.0788
shape=3 variant=eager tflops=0.179 gbps=47.504 ms=0.188
shape=3 variant=compiled tflops=0.496 gbps=131.996 ms=0.068
shape=3 variant=solution tflops=0.441 gbps=117.293 ms=0.076
shape=3 solution_peak_fraction=0.0652
shape=4 variant=eager tflops=2.053 gbps=34.735 ms=0.915
shape=4 variant=compiled tflops=10.349 gbps=175.058 ms=0.182
shape=4 variant=solution tflops=21.264 gbps=359.689 ms=0.088
shape=4 solution_peak_fraction=0.1998
peak_fraction: 0
... (17 more chars)
{
"command": "nvidia-smi | head -5",
"description": "Check GPU status before final verification"
}
Wed Apr 29 09:20:38 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
{
"command": "nvidia-smi | head -5 && timeout 180 uv run python check.py 2>&1",
"description": "Final correctness verification",
"timeout": 200000
}
Wed Apr 29 09:20:42 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
PASS
{
"command": "rm -f profile_configs.py profile_configs2.py test_configs.py __pycache__ -rf",
"description": "Clean up scratch files"
}
(no output)
{
"command": "ls *.py",
"description": "List Python files"
}
benchmark.py
check.py
reference.py
shapes.py
solution.py
sota.py