"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default quantization group size: the number of consecutive K elements that
# share one scale / zero-point row (scales and zeros are shaped (K // group, N)).
GROUP_SIZE = 128
# ---------------------------------------------------------------------------
# Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group).
# ---------------------------------------------------------------------------
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 256}, num_warps=4),
        triton.Config({"BLOCK_N": 512}, num_warps=8),
        triton.Config({"BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_N": 1024}, num_warps=8),
    ],
    key=["N", "K"],
)
@triton.jit
def _w4a16_gemv_splitk_kernel(
    x_ptr,
    wq_ptr,
    scales_ptr,
    zeros_ptr,
    partial_ptr,
    N,
    K,
    stride_wqn,
    stride_sg,
    stride_sn,
    stride_pg,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    PACKED_PER_GROUP: tl.constexpr,
):
    """Per-group partial dot products for the single-row (M == 1) case.

    Program axis 0 selects a tile of BLOCK_N output columns and program axis 1
    selects one quantization group. Each program dequantizes its group's
    weights for those columns, multiplies them with the matching activations,
    and writes an fp32 partial sum to row `pid_g` of `partial_ptr`.

    `stride_wqn` is the stride between consecutive packed rows of `wq_ptr`;
    columns are addressed with unit stride. `stride_sg` / `stride_sn` are used
    for both `scales_ptr` and `zeros_ptr`. `K` is not read in the body; it only
    serves as an autotune key.
    """
    col_block = tl.program_id(0)
    group_id = tl.program_id(1)

    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    col_ok = cols < N

    # Scale and zero-point for this group, widened to fp32 for accumulation.
    scales_row = scales_ptr + group_id * stride_sg
    zeros_row = zeros_ptr + group_id * stride_sg
    scale = tl.load(scales_row + cols * stride_sn, mask=col_ok, other=0.0).to(tl.float32)
    zero = tl.load(zeros_row + cols * stride_sn, mask=col_ok, other=0.0).to(tl.float32)

    first_k = group_id * GROUP_SIZE
    first_packed_row = group_id * PACKED_PER_GROUP
    partial_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # Unrolled at compile time: one packed byte row (two K elements) per step.
    for step in tl.static_range(PACKED_PER_GROUP):
        wq_row = wq_ptr + (first_packed_row + step) * stride_wqn
        byte_vals = tl.load(wq_row + cols, mask=col_ok, other=0).to(tl.int32)

        # Low nibble pairs with the even K index, high nibble with the odd one.
        q_even = (byte_vals & 0xF).to(tl.float32)
        q_odd = ((byte_vals >> 4) & 0xF).to(tl.float32)

        k_even = first_k + 2 * step
        x_even = tl.load(x_ptr + k_even).to(tl.float32)
        x_odd = tl.load(x_ptr + k_even + 1).to(tl.float32)

        partial_sum += x_even * (q_even - zero) * scale + x_odd * (q_odd - zero) * scale

    tl.store(partial_ptr + group_id * stride_pg + cols, partial_sum, mask=col_ok)
@triton.jit
def _reduce_partial_kernel(
    partial_ptr,
    out_ptr,
    N,
    n_groups,
    stride_pg,
    BLOCK_N: tl.constexpr,
):
    """Sum the per-group partial rows and store the result as bf16.

    Each program handles BLOCK_N columns: it adds rows 0..n_groups-1 of
    `partial_ptr` (row stride `stride_pg`) in an fp32 accumulator and writes
    the total, cast to bf16, to `out_ptr`.
    """
    tile = tl.program_id(0)
    cols = tile * BLOCK_N + tl.arange(0, BLOCK_N)
    col_ok = cols < N

    total = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for group_id in tl.range(0, n_groups):
        row_ptr = partial_ptr + group_id * stride_pg
        total += tl.load(row_ptr + cols, mask=col_ok, other=0.0)

    tl.store(out_ptr + cols, total.to(tl.bfloat16), mask=col_ok)
# ---------------------------------------------------------------------------
# General GEMM path: M > 1 — two tl.dot calls per quant group (even-K and odd-K halves).
# ---------------------------------------------------------------------------
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def _w4a16_gemm_kernel(
    x_ptr,
    wq_ptr,
    scales_ptr,
    zeros_ptr,
    out_ptr,
    M,
    N,
    K,
    stride_xm,
    stride_xk,
    stride_wqn,
    stride_sg,
    stride_sn,
    stride_om,
    stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    PACKED_PER_GROUP: tl.constexpr,
):
    """Fused unpack + dequant + matmul for one (BLOCK_M, BLOCK_N) output tile.

    The K dimension is walked one quantization group at a time. For each group
    the packed bytes are split into low and high nibbles, dequantized in bf16
    with that group's scale and zero-point, and multiplied against the even-K
    and odd-K activation columns with two `tl.dot` calls that accumulate in
    fp32. The tile is stored to `out_ptr` as bf16.

    `stride_wqn` is the stride between packed rows of `wq_ptr`; columns are
    addressed with unit stride. `stride_sg` / `stride_sn` are used for both
    `scales_ptr` and `zeros_ptr`.
    """
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)

    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    row_ok = rows < M
    col_ok = cols < N

    # Loop-invariant pieces: index of each packed byte inside a group and the
    # per-row base pointers into x.
    pair_idx = tl.arange(0, PACKED_PER_GROUP)
    x_rows = x_ptr + rows[:, None] * stride_xm

    out_tile = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    group_count = K // GROUP_SIZE
    for group_id in tl.range(0, group_count):
        # Per-column quantization parameters of this group.
        scale = tl.load(
            scales_ptr + group_id * stride_sg + cols * stride_sn, mask=col_ok, other=0.0
        ).to(tl.bfloat16)
        zero = tl.load(
            zeros_ptr + group_id * stride_sg + cols * stride_sn, mask=col_ok, other=0.0
        ).to(tl.bfloat16)

        # (PACKED_PER_GROUP, BLOCK_N) tile of packed bytes.
        packed_rows = group_id * PACKED_PER_GROUP + pair_idx
        packed = tl.load(
            wq_ptr + packed_rows[:, None] * stride_wqn + cols[None, :],
            mask=col_ok[None, :],
            other=0,
        ).to(tl.int32)

        # Low nibble -> even K, high nibble -> odd K; dequantize in bf16.
        q_even = (packed & 0xF).to(tl.bfloat16)
        q_odd = ((packed >> 4) & 0xF).to(tl.bfloat16)
        deq_even = (q_even - zero[None, :]) * scale[None, :]
        deq_odd = (q_odd - zero[None, :]) * scale[None, :]

        # Matching activation columns, gathered as separate even / odd tiles.
        k_even = group_id * GROUP_SIZE + 2 * pair_idx
        k_odd = k_even + 1
        x_even = tl.load(
            x_rows + k_even[None, :] * stride_xk,
            mask=row_ok[:, None] & (k_even[None, :] < K),
            other=0.0,
        ).to(tl.bfloat16)
        x_odd = tl.load(
            x_rows + k_odd[None, :] * stride_xk,
            mask=row_ok[:, None] & (k_odd[None, :] < K),
            other=0.0,
        ).to(tl.bfloat16)

        out_tile += tl.dot(x_even, deq_even, out_dtype=tl.float32)
        out_tile += tl.dot(x_odd, deq_odd, out_dtype=tl.float32)

    dst = out_ptr + rows[:, None] * stride_om + cols[None, :] * stride_on
    tl.store(dst, out_tile.to(tl.bfloat16), mask=row_ok[:, None] & col_ok[None, :])
def w4a16_gemm(
    x: torch.Tensor,
    w_q: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    group_size: int = GROUP_SIZE,
) -> torch.Tensor:
    """x: (M,K) bf16; w_q: (K//2,N) uint8; scales/zeros: (K//group,N) bf16 -> (M,N) bf16.

    Dispatches to the split-K GEMV kernels when ``M == 1`` (per-group fp32
    partial sums followed by a reduction) and to the tiled GEMM kernel
    otherwise.

    Args:
        x: Activations of shape (M, K). Must be contiguous.
        w_q: Packed weights of shape (K // 2, N), two 4-bit values per byte
            (low nibble = even K index, high nibble = odd K index). Must be
            contiguous.
        scales: Per-group scales of shape (K // group_size, N). Must be
            contiguous.
        zeros: Per-group zero-points with the same shape as ``scales``. Must
            be contiguous.
        group_size: Number of K elements sharing one scale / zero-point. Must
            be positive, even, and divide K. NOTE: the M > 1 kernel builds
            ``tl.arange(0, group_size // 2)``, which Triton requires to be a
            power of two.

    Returns:
        A new (M, N) bf16 tensor on ``x.device``.

    Raises:
        AssertionError: If a layout or shape precondition is violated.
    """
    # The kernels index w_q along N (and x along K in the GEMV path) with unit
    # stride and share one stride pair between scales and zeros, so every
    # operand has to be contiguous.
    assert x.is_contiguous(), "x must be contiguous"
    assert w_q.is_contiguous(), "w_q must be contiguous"
    assert scales.is_contiguous(), "scales must be contiguous"
    assert zeros.is_contiguous(), "zeros must be contiguous"

    M, K = x.shape
    Kh, N = w_q.shape
    assert Kh * 2 == K, "w_q must have K // 2 rows"

    # Without these two checks the kernels would silently produce wrong
    # results: K // group_size floors (dropping the trailing K elements) and
    # group_size // 2 floors (misaligning packed rows against their group).
    assert group_size > 0 and group_size % 2 == 0, "group_size must be positive and even"
    assert K % group_size == 0, "K must be a multiple of group_size"

    n_groups = K // group_size
    packed_per_group = group_size // 2

    assert scales.shape == zeros.shape, "scales and zeros must have the same shape"
    assert scales.shape[0] == n_groups, "scales must have K // group_size rows"
    assert scales.shape[1] == N, "scales must have N columns"

    if M == 1:
        # Decode path: one fp32 partial row per quant group, then a reduction.
        partial = torch.empty((n_groups, N), dtype=torch.float32, device=x.device)
        gemv_grid = lambda meta: (  # noqa: E731
            triton.cdiv(N, meta["BLOCK_N"]),
            n_groups,
        )
        _w4a16_gemv_splitk_kernel[gemv_grid](
            x,
            w_q,
            scales,
            zeros,
            partial,
            N,
            K,
            w_q.stride(0),
            scales.stride(0),
            scales.stride(1),
            partial.stride(0),
            GROUP_SIZE=group_size,
            PACKED_PER_GROUP=packed_per_group,
        )

        out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
        reduce_grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)  # noqa: E731
        _reduce_partial_kernel[reduce_grid](
            partial,
            out,
            N,
            n_groups,
            partial.stride(0),
            BLOCK_N=256,
        )
        return out

    # General path: tiled GEMM over (M, N) with one loop iteration per group.
    out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
    gemm_grid = lambda meta: (  # noqa: E731
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    _w4a16_gemm_kernel[gemm_grid](
        x,
        w_q,
        scales,
        zeros,
        out,
        M,
        N,
        K,
        x.stride(0),
        x.stride(1),
        w_q.stride(0),
        scales.stride(0),
        scales.stride(1),
        out.stride(0),
        out.stride(1),
        GROUP_SIZE=group_size,
        PACKED_PER_GROUP=packed_per_group,
    )
    return out
class Model(nn.Module):
    """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros).

    The constructor draws a deterministic random (K, N) weight matrix,
    quantizes it to 4-bit codes with one asymmetric scale / zero-point per
    group of `group_size` rows, packs two codes per byte, and registers the
    packed weights, scales, and zero-points as buffers.
    """

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        super().__init__()
        assert K % group_size == 0
        assert K % 2 == 0
        self.M = M
        self.N = N
        self.K = K
        self.group_size = group_size

        # Shape-dependent seed so the same (M, N, K) always yields the same weights.
        seed = 0xC0DE ^ (M * 1315423911 + N * 2654435761 + K)
        torch.manual_seed(seed)
        dense = torch.randn(K, N, dtype=torch.float32) * 0.02

        # Per-group min/max define the 16-level asymmetric quantization grid.
        grouped = dense.view(K // group_size, group_size, N)
        group_min, _ = grouped.min(dim=1, keepdim=True)
        group_max, _ = grouped.max(dim=1, keepdim=True)
        step = (group_max - group_min).clamp_min(1e-8) / 15.0
        zero_point = (-group_min / step).round().clamp(0, 15)
        codes = (grouped / step + zero_point).round().clamp(0, 15)
        codes = codes.to(torch.uint8).view(K, N)

        # Even K rows go to the low nibble, odd K rows to the high nibble.
        even_k = codes[0::2] & 0xF
        odd_k = codes[1::2] & 0xF
        packed = (even_k | (odd_k << 4)).contiguous()

        self.register_buffer("w_q", packed)
        self.register_buffer("scales", step.squeeze(1).to(torch.bfloat16))
        self.register_buffer("zeros", zero_point.squeeze(1).to(torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on `x` with this module's quantized weights."""
        return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.group_size)
# Default problem shape: get_inputs() builds an (M, K) activation tensor and
# get_init_inputs() returns [M, N, K] (presumably forwarded to Model.__init__).
M = 1
N = 12288
K = 4096
def get_inputs():
    """Return the forward-pass inputs: a single random (M, K) bf16 tensor."""
    return [torch.randn(M, K, dtype=torch.bfloat16)]
def get_init_inputs():
    """Return the module-level problem shape as the list [M, N, K]."""
    init_args = [M, N, K]
    return init_args
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-15T18:57:14.630788+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-15T18:57:15.212475+00:00 elapsed_s=0.582 ms=0.042656
shape=0 variant=solution tflops=2.360 gbps=627.613 ms=0.043
shape=0 solution_peak_fraction=0.3487
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-15T18:57:15.850310+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-15T18:57:16.890642+00:00 elapsed_s=1.040 ms=0.108672
shape=1 variant=solution tflops=29.642 gbps=255.698 ms=0.109
shape=1 solution_peak_fraction=0.1421
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-15T18:57:17.531937+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-15T18:57:18.650366+00:00 elapsed_s=1.118 ms=0.304448
shape=2 variant=solution tflops=84.644 gbps=115.380 ms=0.304
shape=2 solution_peak_fraction=0.0641
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-15T18:57:18.866714+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-15T18:57:19.285467+00:00 elapsed_s=0.419 ms=0.032896
shape=3 variant=solution tflops=1.020 gbps=271.440 ms=0.033
shape=3 solution_peak_fraction=0.1508
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-15T18:57:20.047627+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-15T18:57:21.072748+00:00 elapsed_s=1.025 ms=0.107936
shape=4 variant=solution tflops=17.409 gbps=294.480 ms=0.108
shape=4 solution_peak_fraction=0.1636
peak_fraction: 0.1509
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm/repo
Installed 56 packages in 66ms
PASS
{
"run_id": "20260615_123336_cursor_composer-2.5-fast_07_w4a16_gemm",
"run_group": "",
"problem": "07_w4a16_gemm",
"harness": "cursor",
"model": "composer-2.5-fast",
"reasoning_effort": "",
"started_at": "2026-06-15T12:33:36-06:00",
"harness_finished_at": "2026-06-15T12:56:58-06:00",
"finished_at": "2026-06-15T12:57:21-06:00",
"start_epoch": 1781548416,
"harness_end_epoch": 1781549818,
"end_epoch": 1781549841,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.1509,
"template_mutated": false,
"elapsed_seconds": 1402,
"total_elapsed_seconds": 1425,
"check_elapsed_seconds": 15,
"benchmark_elapsed_seconds": 8,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 18.950784593437945,
"usage": {"input_tokens": 58000, "output_tokens": 26569, "cache_read_tokens": 1082987, "cache_creation_tokens": 0, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"targetDirectory": "/workspace/problems/07_w4a16_gemm",
"globPattern": "**/*"
}{
"path": "/workspace/problems/07_w4a16_gemm/reference.py"
}{
"path": "/workspace/problems/07_w4a16_gemm/check.py"
}{
"path": "/workspace/problems/07_w4a16_gemm/benchmark.py"
}"""Naive W4A16 weight-only quantized GEMM reference (correctness only).
AWQ/GPTQ-style scheme:
x: (M, K) bf16
w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
scales: (K // group, N) bf16
zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
out: (M, N) bf16
Dequant (per group along K):
w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
where w_q[k, n] is the unpacked 4-bit value (0..15).
This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
"""
from __future__ import annotations
import torch
import torch.nn as nn
OP_TYPE = "gemm_w4a16"
SUPPORTED_PRECISIONS = ["int4_bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
GROUP_SIZE = 128
def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
"""Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
Even rows go in the low nibble, odd rows in the high nibble.
"""
K, N = w_q.shape
assert K % 2 == 0
lo = w_q[0::2].to(torch.uint8) & 0xF
hi = w_q[1::2].to(torch.uint8) & 0xF
return (lo | (hi << 4)).contiguous()
def _unpack_int4(w_packed: torch.Tensor, K: int) -> torch.Tensor:
"""Unpack (K//2, N) uint8 -> (K, N) uint8 in [0,15]."""
Kh, N = w_packed.shape
assert Kh * 2 == K
out = torch.empty((K, N), dtype=torch.uint
... (2702 more chars)"""Correctness runner for FP8 GEMM.
Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
"""
import re
import sys
from pathlib import Path
import torch
import yaml
# Make the repo's src/ importable
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.correctness import check_correctness # noqa: E402
from src.eval.numeric_stress import ( # noqa: E402
numeric_stress_cases,
numeric_stress_context,
tolerance_for_case,
)
def main():
try:
import reference
import shapes
import solution
except Exception as e:
print(f"FAIL: import error: {e}")
sys.exit(1)
problem_yaml = Path("problem.yaml")
meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
# --- Forbidden-op check ------------------------------------------------
sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
for forbidden in meta.get("forbidden", []):
pat = re.escape(forbidden)
if re.search(pat, sol_src):
print(f"FAIL: forbidden op used: {forbidden}")
sys.exit(1)
device = torch.device("cuda:0")
tol_override = meta.get("tolerance") or None
# --- Per-shape correctness --------------------------------------------
all_shapes = shapes.SHAPES
for shape_idx, shape in enumerate(all_shapes):
... (2703 more chars)"""Roofline benchmark for FP8 GEMM.
For each shape: times the agent's solution first. Optional eager, compiled,
and SOTA diagnostics are enabled with KBH_BENCHMARK_BASELINES=1. Reports achieved TFLOPS, GB/s, and peak_fraction.
Output lines the harness picks up:
shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
"""
import sys
from math import exp, log
from pathlib import Path
import torch
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
from src.eval.timing import benchmark_baselines_enabled, time_variant # noqa: E402
from src.hardware import get as get_hw # noqa: E402
def _eval_formula(expr: str, vars: dict) -> float:
# Very small eval: only names from `vars` are valid.
return float(eval(expr, {"__builtins__": {}}, vars))
def main():
import reference
import shapes
import solution
meta = yaml.safe_load(Path("problem.yaml").read_text())
hw = get_hw(meta["hardware"][0])
peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
peak_gbps = hw.peak_bandwidth_gb_s
regime = meta.get("regime", "compute")
flops_formula = meta["flops_formula"]
bytes_formula = meta["bytes_formula"]
num_perf_trials = int(meta.get("num_perf_trials", 30))
device = torch.device("cuda:0")
include_baselines = benchmark
... (3646 more chars)null{
"path": "/workspace/problems/07_w4a16_gemm/shapes.py"
}{
"path": "/workspace/problems/07_w4a16_gemm/problem.yaml"
}{
"path": "/workspace/problems/07_w4a16_gemm/sota.py"
}{
"targetDirectory": "/workspace/problems/07_w4a16_gemm",
"globPattern": "solution.py"
}{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0)); print(torch.cuda.get_device_properties(0))\"",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_41ce07eb-c154-433e-a97a-27bf78f5264",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
... (1273 more chars)"""Shape sweep for W4A16 GEMM.
Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
case every inference engine optimizes -- it's the bar to beat.
"""
SHAPES = [
{"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
{"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
{"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
{"M": 1, "N": 4096, "K": 4096}, # decode: square shape
{"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
]
name: 07_w4a16_gemm
display_name: "W4A16 Weight-only Quantized GEMM"
precision: int4_bf16
regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
# Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
flops_formula: "2 * M * N * K"
# Bytes moved per call (memory roofline):
# x: M*K*2 (bf16 activations, streamed in once)
# w_q: (K/2)*N (packed int4, 0.5 B/elem)
# scales: (K/128)*N*2 (bf16 scales)
# zeros: (K/128)*N*2 (bf16 zero-points)
# out: M*N*2 (bf16 store)
bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
hardware: [RTX_PRO_6000]
peak_tflops_key: bf16
peak_bandwidth_key: dram
tolerance:
bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
# Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
# vendor library that does both.
forbidden:
- "bitsandbytes.functional.dequantize_4bit"
- "bitsandbytes.functional.gemv_4bit"
- "marlin_kernel.gemm"
- "torch.nn.functional.linear"
sota:
name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
url: "https://github.com/TimDettmers/bitsandbytes"
function: "bitsandbytes.functional.gemv_4bit"
notes: |
Marlin (IST-DASLab) is the W4A16 SOTA on Ampere/Hopper but does not have
SM120 (Blackwell consumer) kernels yet. GPTQ-Triton is unmaintained and
does not target SM120. bitsandbytes 0.49.2 *does* run on SM120 -- it
autotunes its CUDA kernels f
... (413 more chars)"""SOTA reference for W4A16 GEMM.
Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
- Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
- GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
but is not faster than Marlin on its target HW
and has no Blackwell tuning. Skip as primary.
- AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
Skip.
- bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
on this machine). Different quant scheme (NF4,
symmetric, blocksize 64) than our reference's
AWQ-style asymmetric INT4 with group_size 128,
but it occupies the same memory regime and is
the only tuned W4A16-class kernel that runs on
SM120 today. Used here as an *informational*
SOTA line, not as a numerical reference.
The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
NOT checked against this path (the quant scheme differs).
"""
from __future__ import annotations
import torch
_BNB_OK: bool | None = None
def is_available() -> bool:
global _BNB_OK
if _BNB_OK is not None:
return _BNB_OK
try:
imp
... (2234 more chars)nullnull{
"taskId": "68350",
"blockUntilMs": 60000
}{
"pattern": "cutlass|triton",
"path": "/workspace",
"glob": "**/pyproject.toml",
"caseInsensitive": false,
"multiline": false,
"toolCallId": "tool_2f163c5e-9350-4e09-974d-3cab9759160",
"offset": 0
}{
"targetDirectory": "/workspace",
"globPattern": "**/*"
}nullnullnull{
"path": "/home/agent/.cursor/projects/workspace-problems-07-w4a16-gemm/terminals/68350.txt"
}{
"path": "/workspace/pyproject.toml"
}{
"path": "/workspace/src/hardware/rtx_pro_6000.py"
}---
pid: 367
cwd: "/workspace/problems/07_w4a16_gemm"
command: "cd /workspace/problems/07_w4a16_gemm && uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0)); print(torch.cuda.get_device_properties(0))\""
started_at: 2026-06-15T18:33:44.517Z
running_for_ms: 90005
---
[project]
name = "kernelbench-hard"
version = "0.1.0"
description = "Surgical GPU kernel benchmark. Small deck, frontier agents, roofline metric, link-don't-spoil problem briefs."
requires-python = ">=3.11"
dependencies = [
"torch>=2.11",
"numpy>=1.26",
"pyyaml>=6.0",
"rich>=13.7",
"matplotlib>=3.9",
"pytest>=8.0",
"ruff>=0.6",
"einops>=0.8.2",
"requests<2.33",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project.scripts]
kbh = "src.kbh.cli:main"
[tool.hatch.build.targets.wheel]
packages = ["src"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "W", "I", "B", "UP"]
ignore = ["E501"]
[tool.pytest.ini_options]
testpaths = ["tests"]
"""RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
will see 60-85% of peak on well-tuned code.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class HardwareTarget:
    """Immutable record of one GPU's headline specifications."""

    name: str  # human-readable device name
    sm: str  # SM architecture string, e.g. "sm_120a"
    vram_gb: int  # device memory capacity in GB
    peak_bandwidth_gb_s: float  # DRAM
    peak_tflops_dense: dict[str, float]  # dtype -> TFLOPS
RTX_PRO_6000 = HardwareTarget(
name="RTX PRO 6000 Blackwell Workstation",
sm="sm_120a",
vram_gb=96,
peak_bandwidth_gb_s=1800.0,
peak_tflops_dense={
# Blackwell GB202 dense tensor peaks. Derived from NVIDIA's headline
# 4000 fp4-sparse AI TOPS -> fp4 dense 2000 -> fp8 1000 -> bf16 500
# (each precision step = 2x; sparse = 2x dense). Verified empirically:
# cuBLAS hits fp8 773 / bf16 412 TFLOPS on 4096^3 (~77-82% of these,
# normal cuBLAS efficiency). The prior table was ~2.5x too low, which
# produced peak_fraction > 1.0 for real fp8 kernels. fp32 is the 125
# TFLOPS SIMT figure from the spec sheet.
"fp4": 2000.0,
"nvfp4": 2000.0,
"mxfp4": 2000.0,
"fp6": 1000.0,
"fp8": 1000.0,
"bf16": 500.0,
"fp16": 500.0,
"tf32": 250.0,
"fp32": 125.0, # non-tensor-core SIMT fp32 (spec sheet)
"int8": 1000.0,
"int4": 2000.0,
},
)
{
"taskId": "68350",
"blockUntilMs": 120000
}{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python -c \"import triton; print('triton', triton.__version__)\" 2>&1 | head -5",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_af5c3a14-ca25-45e7-bce0-15d04b9d0c2",
"simpleCommands": [
"cd",
"uv",
"head"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
... (1327 more chars)nullnull{
"path": "/home/agent/.cursor/projects/workspace-problems-07-w4a16-gemm/terminals/68350.txt"
}{
"command": "<239 chars \u2014 see below>",
"workingDirectory": "",
"timeout": 120000,
"toolCallId": "tool_45605db8-a7b2-459e-8881-57f06df523f",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": false,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /workspace/problems/07_w4a16_gemm"
},
{
"name": "uv"
... (1114 more chars)cd /workspace/problems/07_w4a16_gemm && uv run python -c "
import torch
print('cuda', torch.cuda.is_available())
if torch.cuda.is_available():
print(torch.cuda.get_device_name(0))
print('cap', torch.cuda.get_device_capability(0))
"---
pid: 367
cwd: "/workspace/problems/07_w4a16_gemm"
command: "cd /workspace/problems/07_w4a16_gemm && uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_capability(0)); print(torch.cuda.get_device_properties(0))\""
started_at: 2026-06-15T18:33:44.517Z
running_for_ms: 120006
---
Using CPython 3.11.15
Creating virtual environment at: /workspace/.venv
Building kernelbench-hard @ file:///workspace
Built kernelbench-hard @ file:///workspace
warning: Failed to hardlink files; falling back to full copy. This may lead to degraded performance.
If the cache and target directories are on different filesystems, hardlinking may not be supported.
If this is intentional, set `export UV_LINK_MODE=copy` or use `--link-mode=copy` to suppress this warning.
Installed 56 packages in 883ms
2.11.0+cu130
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
(12, 0)
_CudaDeviceProperties(name='NVIDIA RTX PRO 6000 Blackwell Workstation Edition', major=12, minor=0, total_memory=97288MB, multi_processor_count=188, uuid=d594fb7a-abcb-bad5-1fc4-675b478a47df, pci_bus_id=1, pci_device_id=0, pci_domain_id=0, L2_cache_size=128MB)
---
exit_code: 0
elapsed_ms: 121703
ended_at: 2026-06-15T18:35:46.220Z
---
null{
"path": "/workspace/src/eval/numeric_stress.py"
}{
"path": "/workspace/src/eval/correctness.py"
}"""Numeric distribution stress cases for correctness checks.
These cases are not hidden shapes. They rerun the same shape/seed validation
under a few scale regimes that catch kernels that only work for the nominal
N(0, 1)-ish inputs.
"""
from __future__ import annotations
import os
from collections.abc import Iterator, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
ToleranceOverride = Mapping[str, float | Mapping[str, float]]
@dataclass(frozen=True)
class NumericStressCase:
    """One named numeric-stress configuration for a correctness rerun."""

    name: str  # case label, e.g. "nominal" or "small_input"
    # int key -> scale factor; presumably keyed by input position — TODO confirm
    input_scales: Mapping[int, float] = field(default_factory=dict)
    # str key -> scale factor; presumably keyed by state/buffer name — TODO confirm
    state_scales: Mapping[str, float] = field(default_factory=dict)
    # Optional per-case tolerance override; None means no override is attached.
    tolerance: ToleranceOverride | None = None
NOMINAL_CASE = NumericStressCase("nominal")
_SMALL_BF16 = {"bfloat16": {"atol": 1e-4, "rtol": 5e-2}}
_MED_BF16 = {"bfloat16": {"atol": 5e-4, "rtol": 5e-2}}
_LARGE_BF16 = {"bfloat16": {"atol": 2e-1, "rtol": 5e-2}}
_TINY_FP32 = {"float32": {"atol": 1e-7, "rtol": 1e-4}}
_LARGE_FP32 = {"float32": {"atol": 1e-1, "rtol": 1e-4}}
_CASES: dict[str, tuple[NumericStressCase, ...]] = {
"01_fp8_gemm": (
# fp8 x fp8 has accumulation-order noise vs the bf16-matmul reference that
# scales with input magnitude; atols are calibrated to absorb it (measured
# fp8-MMA residual x ~1.5 margin) while rtol=5e-2 still catches gross error.
NumericStressCase("small_input", input_scales={0: 1e-3},
tolerance={
... (4968 more chars)"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
torch.int64: {"atol": 0, "rtol": 0},
}
def _coerce_float(x):
"""Accept int/float/str scientific-notation as float; pass through dicts."""
if isinstance(x, (int, float)):
return float(x)
if isinstance(x, str):
try:
return float(x)
except ValueError:
return x
return x
def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
"""Lookup atol/rtol for a given dtype, with optional per-problem override.
The override dict is keyed by str(dtype). Values may be:
- a single number (or a string parseable as float, e.g. "5e-2") -> used
as both atol and rtol
- a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
"""
# Accept several spellings of the same dtype key (PyYAML quirks):
# "torch.bfloat16", "bfloat16", torch.bfloat16
if override is not None:
type_
... (2572 more chars)--- /dev/null
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,179 @@
+"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 512, "BLOCK_K": 64}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wq_kh,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_om,
+ stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k0 in range(0, K, BLOCK_K):
+ offs_k = k0 + tl.arange(0, BLOCK_K)
+ k_mask = offs_k < K
+
+ x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
+ x_mask = (offs_m[:, None] < M) & k_mask[None, :]
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
+
+ packed_k = offs_k // 2
+ shift = (offs_k % 2) * 4
+
+ wq_ptrs = wq_ptr + packed_k[:, None] * stride_wq_kh + offs_n[None, :] * stride_wqn
+ wq_mask = k_mask[:, None] & (offs_n[None, :] < N)
+ wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
+
+ w_int = (wq_packed >> shift[:, None]) & 0xF
+ w_bf = w_int.to(tl.bfloat16)
+
+ g_idx = offs_k // GROUP_SIZE
+ s_ptrs = scales_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ z_ptrs = zeros_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ s = tl.load(s_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
+ z = tl.load(z_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
+
+ w_deq = (w_bf - z) * s
+ acc += tl.dot(x, w_deq, out_dtype=tl.float32)
+
+ out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+def w4a16_gemm(
+ x: torch.Tensor,
+ w_q: torch.Tensor,
+ scales: torch.Tensor,
+ zeros: torch.Tensor,
+ group_size: int = GROUP_SIZE,
+) -> torch.Tensor:
+ """x: (M,K) bf16; w_q: (K//2,N) uint8; scales/zeros: (K//group,N) bf16 -> (M,N) bf16."""
+ assert x.is_contiguous()
+ assert w_q.is_contiguous()
+ assert scales.is_contiguous()
+ assert zeros.is_contiguous()
+ M, K = x.shape
+ Kh, N = w_q.shape
+ assert Kh * 2 == K
+ assert scales.shape == zeros.shape
+ assert scales.shape[0] == K // group_size
+ assert scales.shape[1] == N
+
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+
+ grid = lambda meta: (
+ triton.cdiv(M, meta["BLOCK_M"]),
+ triton.cdiv(N, meta["BLOCK_N"]),
+ )
+ _w4a16_gemm_kernel[grid](
+ x,
+ w_q,
+ scales,
+ zeros,
+ out,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ w_q.stride(0),
+ w_q.stride(1),
+ scales.stride(0),
+ scales.stride(1),
+ out.stride(0),
+ out.stride(1),
+ GROUP_SIZE=group_size,
+ )
+ return out
+
+
+class Model(nn.Module):
+ """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros)."""
+
+ def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
+ super().__init__()
+ assert K % group_size == 0
+ assert K % 2 == 0
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+
+ torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
+ w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
+
+ w_g = w_full.view(n_groups, group_size, N)
+ w_min = w_g.min(dim=1, keepdim=True).values
+ w_max = w_g.max(dim=1, keepdim=True).values
+ scales = (w_max - w_min).clamp_min(1e-8) / 15.0
+ zeros = (-w_min / scales).round().clamp(0, 15)
+ w_q = ((w_g / scales) + zeros).round().clamp(0, 15).to(torch.uint8).view(K, N)
+
+ lo = w_q[0::2] & 0xF
+ hi = w_q[1::2] & 0xF
+ w_packed = (lo | (hi << 4)).contiguous()
+
+ self.register_buffer("w_q", w_packed)
+ self.register_buffer("scales", scales.squeeze(1).to(torch.bfloat16))
+ self.register_buffer("zeros", zeros.squeeze(1).to(torch.bfloat16))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.group_size)
+
+
+M = 1
+N = 12288
+K = 4096
+
+
+def get_inputs():
+ x = torch.randn(M, K, dtype=torch.bfloat16)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<5570 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_f3f5b42c-8822-44b7-8e50-17f9d893745",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /workspace/problems/07_w4a16_gemm"
... (905 more chars)failure:
command: 'cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/problems/07_w4a16_gemm/check.py", line 119, in <module>
main()
File "/workspace/problems/07_w4a16_gemm/check.py", line 80, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/07_w4a16_gemm/solution.py", line 165, in forward
return w4a16_gemm(x, self.w_q, self.scales, self.zeros, self.group_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/07_w4a16_gemm/solution.py", line 113, in w4a16_gemm
_w4a16_gemm_kernel[grid](
File "/workspace/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
... (3581 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,179 +1,13 @@
-"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-GROUP_SIZE = 128
-
-
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 512, "BLOCK_K": 64}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=2),
],
key=["M", "N", "K"],
-)
-@triton.jit
-def _w4a16_gemm_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- out_ptr,
- M,
- N,
- K,
- stride_xm,
- stride_xk,
- stride_wq_kh,
- stride_wqn,
- stride_sg,
- stride_sn,
- stride_om,
- stride_on,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
-):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k0 in range(0, K, BLOCK_K):
- offs_k = k0 + tl.arange(0, BLOCK_K)
- k_mask = offs_k < K
-
- x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
- x_mask = (offs_m[:, None] < M) & k_mask[None, :]
- x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
-
- packed_k = offs_k // 2
- shift = (offs_k % 2) * 4
-
- wq_ptrs = wq_ptr + packed_k[:, None] * stride_wq_kh + offs_n[None, :] * stride_wqn
- wq_mask = k_mask[:, None] & (offs_n[None, :] < N)
- wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
-
- w_int = (wq_packed >> shift[:, None]) & 0xF
- w_bf = w_int.to(tl.bfloat16)
-
- g_idx = offs_k // GROUP_SIZE
- s_ptrs = scales_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
- z_ptrs = zeros_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
- s = tl.load(s_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
- z = tl.load(z_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
-
- w_deq = (w_bf - z) * s
- acc += tl.dot(x, w_deq, out_dtype=tl.float32)
-
- out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
-
-
-def w4a16_gemm(
- x: torch.Tensor,
- w_q: torch.Tensor,
- scales: torch.Tensor,
- zeros: torch.Tensor,
- group_size: int = GROUP_SIZE,
-) -> torch.Tensor:
- """x: (M,K) bf16; w_q: (K//2,N) uint8; scales/zeros: (K//group,N) bf16 -> (M,N) bf16."""
- assert x.is_contiguous()
- assert w_q.is_contiguous()
- assert scales.is_contiguous()
- assert zeros.is_contiguous()
- M, K = x.shape
- Kh, N = w_q.shape
- assert Kh * 2 == K
- assert scales.shape == zeros.shape
- assert scales.shape[0] == K // group_size
- assert scales.shape[1] == N
-
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
-
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- triton.cdiv(N, meta["BLOCK_N"]),
- )
- _w4a16_gemm_kernel[grid](
- x,
- w_q,
- scales,
- zeros,
- out,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- w_q.stride(0),
- w_q.stride(1),
- scales.stride(0),
- scales.stride(1),
- out.stride(0),
- out.stride(1),
- GROUP_SIZE=group_size,
- )
- return out
-
-
-class Model(nn.Module):
- """W4A16 GEMM: y = x @ dequant(w_q, scales, zeros)."""
-
- def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
- super().__init__()
- assert K % group_size == 0
- assert K % 2 == 0
- self.M, self.N, self.K = M, N, K
- self.group_size = group_size
- n_groups = K // group_size
-
- torch.manual_seed(0xC0DE ^ (M * 1315423911 + N * 2654435761 + K))
- w_full = torch.randn(K, N, dtype=torch.float32) * 0.02
-
- w_g = w_full.view(n_groups, group_size, N)
- w_min = w_g.min(dim=1, keepdim=True).values
- w_max = w_g.max(dim=1, keepdim=True).values
- scales = (w_max - w_min).clamp_min(1e-8) / 15.0
- zeros = (-w_min / scales).round().clamp(0, 15)
- w_q = ((w_g / scales) + zeros).round().clamp(0,
... (660 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<843 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_4ad90896-ffb3-48e8-ab42-abd8832e2f6",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /workspace/problems/07_w4a16_gemm"
... (899 more chars)null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_a8d611f9-08c6-440d-9973-aaaa42e07df",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /workspace/problems/07_w4a16_g
... (909 more chars)null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,13 +1,265 @@
+"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+# ---------------------------------------------------------------------------
+# Decode path: M == 1 GEMV — one program per N-tile, K-reduction fused with dequant.
+# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_N": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_N": 1024}, num_warps=8, num_stages=1),
+ ],
+ key=["N", "K"],
+)
+@triton.jit
+def _w4a16_gemv_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ N,
+ K,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_on,
+ BLOCK_N: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ PACKED_PER_GROUP: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+ n_mask = offs_n < N
+
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ n_groups = K // GROUP_SIZE
+
+ for g in tl.range(0, n_groups):
+ s = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.float32)
+ z = tl.load(
+ zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.float32)
+
+ base_k = g * GROUP_SIZE
+ base_packed = g * PACKED_PER_GROUP
+
+ for pi in tl.static_range(PACKED_PER_GROUP):
+ packed_row = base_packed + pi
+ k0 = base_k + 2 * pi
+
+ w_packed = tl.load(
+ wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
+ ).to(tl.int32)
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
+
+ x0 = tl.load(x_ptr + k0).to(tl.float32)
+ x1 = tl.load(x_ptr + k0 + 1).to(tl.float32)
+
+ acc += x0 * (w_lo - z) * s + x1 * (w_hi - z) * s
+
+ tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=n_mask)
+
+
+# ---------------------------------------------------------------------------
+# General GEMM path: M > 1
+# ---------------------------------------------------------------------------
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
],
key=["M", "N", "K"],
-)+)
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wq_kh,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_om,
+ stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k0 in range(0, K, BLOCK_K):
+ offs_k = k0 + tl.arange(0, BLOCK_K)
+ k_mask = offs_k < K
+
+ x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
+ x_mask = (offs_m[:, None] < M) & k_mask[None, :]
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
+
+ packed_k = offs_k // 2
+ shift = (offs_k % 2) * 4
+
+ wq_ptrs = wq_ptr + packed_k[:, None] * stride_wq_kh + offs_n[None, :] * stride_wqn
+ wq_mask = k_mask[:, None] & (offs_n[None, :] < N)
+ wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
+
+ w_int = (wq_packed >> shift[:, None]) & 0xF
+ w_bf = w_int.to(tl.bfloat16)
+
+ g_idx = offs_k // GROUP_SIZE
+ s_ptrs = scales_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ z_ptrs = zeros_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
+ s = tl.load(s_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
+ z = tl.load(z_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
+
+ w_deq = (w_bf - z) * s
+ acc += tl.dot(x, w_deq, out_dtype=tl.float32)
+
+ out_ptrs = out_ptr + offs_m[:, None] * str
... (3432 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<8256 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_9dbf0708-e038-405c-89cb-16c59ac9aef",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1490 more chars)null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -10,19 +10,18 @@
# ---------------------------------------------------------------------------
-# Decode path: M == 1 GEMV — one program per N-tile, K-reduction fused with dequant.
+# Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group).
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
- triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_N": 512}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_N": 1024}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_N": 256}, num_warps=4),
+ triton.Config({"BLOCK_N": 128}, num_warps=4),
+ triton.Config({"BLOCK_N": 512}, num_warps=8),
],
key=["N", "K"],
)
@triton.jit
-def _w4a16_gemv_kernel(
+def _w4a16_gemv_splitk_kernel(
x_ptr,
wq_ptr,
scales_ptr,
@@ -33,45 +32,43 @@
stride_wqn,
stride_sg,
stride_sn,
- stride_on,
BLOCK_N: tl.constexpr,
GROUP_SIZE: tl.constexpr,
PACKED_PER_GROUP: tl.constexpr,
):
- pid = tl.program_id(0)
- offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+ pid_n = tl.program_id(0)
+ pid_g = tl.program_id(1)
+
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
+ s = tl.load(
+ scales_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.float32)
+ z = tl.load(
+ zeros_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.float32)
+
acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
- n_groups = K // GROUP_SIZE
-
- for g in tl.range(0, n_groups):
- s = tl.load(
- scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.float32)
- z = tl.load(
- zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.float32)
-
- base_k = g * GROUP_SIZE
- base_packed = g * PACKED_PER_GROUP
-
- for pi in tl.static_range(PACKED_PER_GROUP):
- packed_row = base_packed + pi
- k0 = base_k + 2 * pi
-
- w_packed = tl.load(
- wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
- ).to(tl.int32)
- w_lo = (w_packed & 0xF).to(tl.float32)
- w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
-
- x0 = tl.load(x_ptr + k0).to(tl.float32)
- x1 = tl.load(x_ptr + k0 + 1).to(tl.float32)
-
- acc += x0 * (w_lo - z) * s + x1 * (w_hi - z) * s
-
- tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=n_mask)
+ base_k = pid_g * GROUP_SIZE
+ base_packed = pid_g * PACKED_PER_GROUP
+
+ for pi in tl.static_range(PACKED_PER_GROUP):
+ packed_row = base_packed + pi
+ k0 = base_k + 2 * pi
+
+ w_packed = tl.load(
+ wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
+ ).to(tl.int32)
+ w_lo = (w_packed & 0xF).to(tl.float32)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
+
+ x0 = tl.load(x_ptr + k0).to(tl.float32)
+ x1 = tl.load(x_ptr + k0 + 1).to(tl.float32)
+
+ acc += x0 * (w_lo - z) * s + x1 * (w_hi - z) * s
+
+ tl.atomic_add(out_ptr + offs_n, acc, mask=n_mask)
# ---------------------------------------------------------------------------
@@ -79,13 +76,15 @@
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=8, num_stages=2),
],
key=["M", "N", "K"],
)
@@ -171,12 +170,15 @@
assert scales.shape[0] == K // group_size
assert scales.shape[1] == N
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
-
if M == 1:
+ out = torch.zeros((M, N), dtype=torch.float32, device=x.device)
+ n_groups = K // group_size
packed_per_group = group_size // 2
- grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
- _w4a16_gemv_kernel[grid](
+ grid = lambda meta: (
+ triton.cdiv(N, meta["BLOCK_N"]),
+ n_groups,
+ )
+ _w4a16_gemv_splitk_kernel[grid](
x,
w_q,
scales,
@@ -187,11 +189,12 @@
w_q.stride(0),
scales.stride(0),
scales.stride(1),
- out.stride(1),
GROUP_SIZE=group_size,
PACKED_PER_GROUP=packed_per_group,
)
+ return out.to(torch.bfloat16)
else:
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
@@ -215,7 +218,7 @@
out.stride(1),
GROUP_SIZE=group_size,
)
- return out
+ return out
class Model(
... (12 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<8389 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_cbfb669e-19e3-44cb-87c2-1655f1f9078",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1485 more chars)failure:
command: 'cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42 case nominal: tolerance exceeded: max_abs_diff=9659.19 max_rel_diff=10753 bad=12288/12288 worst_idx=(0, 3980) (atol=0.1, rtol=0.1)
stderr: ''
executionTime: 37427
interleavedOutput:
FAIL: shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42 case nominal: tolerance exceeded: max_abs_diff=9659.19 max_rel_diff=10753 bad=12288/12288 worst_idx=(0, 3980) (atol=0.1, rtol=0.1)
aborted: False
localExecutionTimeMs: 37359
isBackground: False--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,179 +1,7 @@
-"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-GROUP_SIZE = 128
-
-
-# ---------------------------------------------------------------------------
-# Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group).
-# ---------------------------------------------------------------------------
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_N": 256}, num_warps=4),
- triton.Config({"BLOCK_N": 128}, num_warps=4),
- triton.Config({"BLOCK_N": 512}, num_warps=8),
- ],
- key=["N", "K"],
-)
-@triton.jit
-def _w4a16_gemv_splitk_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- out_ptr,
- N,
- K,
- stride_wqn,
- stride_sg,
- stride_sn,
- BLOCK_N: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
- PACKED_PER_GROUP: tl.constexpr,
-):
- pid_n = tl.program_id(0)
- pid_g = tl.program_id(1)
-
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- n_mask = offs_n < N
-
- s = tl.load(
- scales_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.float32)
- z = tl.load(
- zeros_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.float32)
-
- acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
- base_k = pid_g * GROUP_SIZE
- base_packed = pid_g * PACKED_PER_GROUP
-
- for pi in tl.static_range(PACKED_PER_GROUP):
- packed_row = base_packed + pi
- k0 = base_k + 2 * pi
-
- w_packed = tl.load(
- wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
- ).to(tl.int32)
- w_lo = (w_packed & 0xF).to(tl.float32)
- w_hi = ((w_packed >> 4) & 0xF).to(tl.float32)
-
- x0 = tl.load(x_ptr + k0).to(tl.float32)
- x1 = tl.load(x_ptr + k0 + 1).to(tl.float32)
-
- acc += x0 * (w_lo - z) * s + x1 * (w_hi - z) * s
-
- tl.atomic_add(out_ptr + offs_n, acc, mask=n_mask)
-
-
-# ---------------------------------------------------------------------------
-# General GEMM path: M > 1
-# ---------------------------------------------------------------------------
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 256, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 1, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8, num_stages=2),
- ],
- key=["M", "N", "K"],
-)
-@triton.jit
-def _w4a16_gemm_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- out_ptr,
- M,
- N,
- K,
- stride_xm,
- stride_xk,
- stride_wq_kh,
- stride_wqn,
- stride_sg,
- stride_sn,
- stride_om,
- stride_on,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
-):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k0 in range(0, K, BLOCK_K):
- offs_k = k0 + tl.arange(0, BLOCK_K)
- k_mask = offs_k < K
-
- x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
- x_mask = (offs_m[:, None] < M) & k_mask[None, :]
- x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
-
- packed_k = offs_k // 2
- shift = (offs_k % 2) * 4
-
- wq_ptrs = wq_ptr + packed_k[:, None] * stride_wq_kh + offs_n[None, :] * stride_wqn
- wq_mask = k_mask[:, None] & (offs_n[None, :] < N)
- wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
-
- w_int = (wq_packed >> shift[:, None]) & 0xF
- w_bf = w_int.to(tl.bfloat16)
-
- g_idx = offs_k // GROUP_SIZE
- s_ptrs = scales_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
- z_ptrs = zeros_ptr + g_idx[:, None] * stride_sg + offs_n[None, :] * stride_sn
- s = tl.load(s_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
- z = tl.load(z_ptrs, mask=wq_mask, other=0.0).to(tl.bfloat16)
-
- w_deq = (w_bf - z) * s
- acc += tl.dot(x, w_deq, out_dtype=tl.float32)
-
- out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
-
-
-def w4a16_gemm(
- x: torch.Tensor,
- w_q: torch.Tensor,
- scales: torch.Tensor,
- zeros: torch.Tensor,
- group_size: int = GROUP_SIZE,
-) -> torch.Tensor:
- """x: (M,K) bf16; w_q: (K//2,N) uint8; scales/zeros: (K//group,N) bf16 -> (M,N) bf16."""
- assert x.is_contiguous()
- assert w_q.is_contiguous()
- assert scales.is_contiguous()
- assert zeros.is_contiguous()
- M, K = x.shape
- Kh, N = w_q.shape
- assert Kh * 2 == K
- assert scales.shape == zeros.shape
- assert scales.shape[0] == K // group_size
- assert scales.shape[1] == N
-
if M == 1:
- out = torch.zeros((M, N), dtype=torch.float32, device=x.devi
... (2938 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<724 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,24 +1,16 @@
- if M == 1:
- n_groups = K // group_size
- packed_per_group = group_size // 2
- partial = torch.empty((n_groups, N), dtype=torch.float32, device=x.device)
- grid = lambda meta: (
- triton.cdiv(N, meta["BLOCK_N"]),
- n_groups,
- )
- _w4a16_gemv_splitk_kernel[grid](
- x,
- w_q,
- scales,
- zeros,
- partial,
- N,
- K,
- w_q.stride(0),
- scales.stride(0),
- scales.stride(1),
- partial.stride(0),
- GROUP_SIZE=group_size,
- PACKED_PER_GROUP=packed_per_group,
- )
- return partial.sum(dim=0, dtype=torch.bfloat16).unsqueeze(0)+def _w4a16_gemv_splitk_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ partial_ptr,
+ N,
+ K,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_pg,
+ BLOCK_N: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ PACKED_PER_GROUP: tl.constexpr,
+):{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<272 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,16 +1 @@
-def _w4a16_gemv_splitk_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- partial_ptr,
- N,
- K,
- stride_wqn,
- stride_sg,
- stride_sn,
- stride_pg,
- BLOCK_N: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
- PACKED_PER_GROUP: tl.constexpr,
-):+ tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": " tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask)"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_0537af84-e520-4caa-a30a-9938e1d6711",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1481 more chars)null{
"command": "cd /workspace/problems/07_w4a16_gemm && KBH_BENCHMARK_BASELINES=1 uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_3864b180-e178-4a39-9c01-50d1553a8f0",
"simpleCommands": [
"cd",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /wor
... (949 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py"
}"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
# ---------------------------------------------------------------------------
# Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group).
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=8),
],
key=["N", "K"],
)
@triton.jit
def _w4a16_gemv_splitk_kernel(
x_ptr,
wq_ptr,
scales_ptr,
zeros_ptr,
partial_ptr,
N,
K,
stride_wqn,
stride_sg,
stride_sn,
stride_pg,
BLOCK_N: tl.constexpr,
GROUP_SIZE: tl.constexpr,
PACKED_PER_GROUP: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_g = tl.program_id(1)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
s = tl.load(
scales_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
).to(tl.float32)
z = tl.load(
zeros_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
).to(tl.float32)
acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
base_k = pid_g * GROUP_SIZE
base_packed = pid_g * PACKED_PER_GROUP
for pi in tl.static_range(PACKED_P
... (7004 more chars)null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1 +1,91 @@
- tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask)+# ---------------------------------------------------------------------------
+# General GEMM path: M > 1 — one tl.dot per quant group (128 K).
+# ---------------------------------------------------------------------------
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _w4a16_gemm_kernel(
+ x_ptr,
+ wq_ptr,
+ scales_ptr,
+ zeros_ptr,
+ out_ptr,
+ M,
+ N,
+ K,
+ stride_xm,
+ stride_xk,
+ stride_wqn,
+ stride_sg,
+ stride_sn,
+ stride_om,
+ stride_on,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ PACKED_PER_GROUP: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ m_mask = offs_m < M
+ n_mask = offs_n < N
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ n_groups = K // GROUP_SIZE
+
+ for g in tl.range(0, n_groups):
+ s = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.bfloat16)
+ z = tl.load(
+ zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.bfloat16)
+
+ base_k = g * GROUP_SIZE
+ base_packed = g * PACKED_PER_GROUP
+
+ w_deq = tl.zeros((GROUP_SIZE, BLOCK_N), dtype=tl.bfloat16)
+ for pi in tl.static_range(PACKED_PER_GROUP):
+ packed_row = base_packed + pi
+ k0 = base_k + 2 * pi
+ k1 = k0 + 1
+
+ w_packed = tl.load(
+ wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
+ ).to(tl.int32)
+ w_lo = (w_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+ w_deq = tl.where(
+ tl.arange(0, GROUP_SIZE)[:, None] == k0,
+ (w_lo - z) * s,
+ w_deq,
+ )
+ w_deq = tl.where(
+ tl.arange(0, GROUP_SIZE)[:, None] == k1,
+ (w_hi - z) * s,
+ w_deq,
+ )
+
+ offs_k = base_k + tl.arange(0, GROUP_SIZE)
+ x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
+ x_mask = m_mask[:, None] & (offs_k[None, :] < K)
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
+
+ acc += tl.dot(x, w_deq, out_dtype=tl.float32)
+
+ out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_mask = m_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<3184 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,91 +1,27 @@
-# ---------------------------------------------------------------------------
-# General GEMM path: M > 1 — one tl.dot per quant group (128 K).
-# ---------------------------------------------------------------------------
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2),
- ],
- key=["M", "N", "K"],
-)
-@triton.jit
-def _w4a16_gemm_kernel(
- x_ptr,
- wq_ptr,
- scales_ptr,
- zeros_ptr,
- out_ptr,
- M,
- N,
- K,
- stride_xm,
- stride_xk,
- stride_wqn,
- stride_sg,
- stride_sn,
- stride_om,
- stride_on,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- GROUP_SIZE: tl.constexpr,
- PACKED_PER_GROUP: tl.constexpr,
-):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- m_mask = offs_m < M
- n_mask = offs_n < N
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- n_groups = K // GROUP_SIZE
-
- for g in tl.range(0, n_groups):
- s = tl.load(
- scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.bfloat16)
- z = tl.load(
- zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.bfloat16)
-
- base_k = g * GROUP_SIZE
- base_packed = g * PACKED_PER_GROUP
-
- w_deq = tl.zeros((GROUP_SIZE, BLOCK_N), dtype=tl.bfloat16)
- for pi in tl.static_range(PACKED_PER_GROUP):
- packed_row = base_packed + pi
- k0 = base_k + 2 * pi
- k1 = k0 + 1
-
- w_packed = tl.load(
- wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
- ).to(tl.int32)
- w_lo = (w_packed & 0xF).to(tl.bfloat16)
- w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- w_deq = tl.where(
- tl.arange(0, GROUP_SIZE)[:, None] == k0,
- (w_lo - z) * s,
- w_deq,
- )
- w_deq = tl.where(
- tl.arange(0, GROUP_SIZE)[:, None] == k1,
- (w_hi - z) * s,
- w_deq,
- )
-
- offs_k = base_k + tl.arange(0, GROUP_SIZE)
- x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
- x_mask = m_mask[:, None] & (offs_k[None, :] < K)
- x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
-
- acc += tl.dot(x, w_deq, out_dtype=tl.float32)
-
- out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_mask = m_mask[:, None] & n_mask[None, :]
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)+ else:
+ out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
+ packed_per_group = group_size // 2
+ grid = lambda meta: (
+ triton.cdiv(M, meta["BLOCK_M"]),
+ triton.cdiv(N, meta["BLOCK_N"]),
+ )
+ _w4a16_gemm_kernel[grid](
+ x,
+ w_q,
+ scales,
+ zeros,
+ out,
+ M,
+ N,
+ K,
+ x.stride(0),
+ x.stride(1),
+ w_q.stride(0),
+ scales.stride(0),
+ scales.stride(1),
+ out.stride(0),
+ out.stride(1),
+ GROUP_SIZE=group_size,
+ PACKED_PER_GROUP=packed_per_group,
+ )
+ return out{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<724 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,27 +1,34 @@
- else:
- out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
- packed_per_group = group_size // 2
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- triton.cdiv(N, meta["BLOCK_N"]),
- )
- _w4a16_gemm_kernel[grid](
- x,
- w_q,
- scales,
- zeros,
- out,
- M,
- N,
- K,
- x.stride(0),
- x.stride(1),
- w_q.stride(0),
- scales.stride(0),
- scales.stride(1),
- out.stride(0),
- out.stride(1),
- GROUP_SIZE=group_size,
- PACKED_PER_GROUP=packed_per_group,
- )
- return out+ for g in tl.range(0, n_groups):
+ s = tl.load(
+ scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.bfloat16)
+ z = tl.load(
+ zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
+ ).to(tl.bfloat16)
+
+ base_k = g * GROUP_SIZE
+ base_packed = g * PACKED_PER_GROUP
+
+ w_deq_lo = tl.zeros((PACKED_PER_GROUP, BLOCK_N), dtype=tl.bfloat16)
+ w_deq_hi = tl.zeros((PACKED_PER_GROUP, BLOCK_N), dtype=tl.bfloat16)
+ for pi in tl.static_range(PACKED_PER_GROUP):
+ packed_row = base_packed + pi
+ w_packed = tl.load(
+ wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
+ ).to(tl.int32)
+ w_lo = (w_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
+ row = tl.full([1, 1], pi, tl.int32)
+ w_deq_lo = tl.where(row == pi, (w_lo - z) * s, w_deq_lo)
+ w_deq_hi = tl.where(row == pi, (w_hi - z) * s, w_deq_hi)
+
+ offs_k = base_k + tl.arange(0, GROUP_SIZE)
+ x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
+ x_mask = m_mask[:, None] & (offs_k[None, :] < K)
+ x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
+ x_grp = tl.reshape(x, (BLOCK_M, PACKED_PER_GROUP, 2))
+ x_even = x_grp[:, :, 0]
+ x_odd = x_grp[:, :, 1]
+
+ acc += tl.dot(x_even, w_deq_lo, out_dtype=tl.float32)
+ acc += tl.dot(x_odd, w_deq_hi, out_dtype=tl.float32){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<1565 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,34 +1,10 @@
- for g in tl.range(0, n_groups):
- s = tl.load(
- scales_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.bfloat16)
- z = tl.load(
- zeros_ptr + g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
- ).to(tl.bfloat16)
+ base_packed = g * PACKED_PER_GROUP
+ packed_rows = base_packed + tl.arange(0, PACKED_PER_GROUP)
+ wq_ptrs = wq_ptr + packed_rows[:, None] * stride_wqn + offs_n[None, :]
+ wq_mask = n_mask[None, :]
+ wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
- base_k = g * GROUP_SIZE
- base_packed = g * PACKED_PER_GROUP
-
- w_deq_lo = tl.zeros((PACKED_PER_GROUP, BLOCK_N), dtype=tl.bfloat16)
- w_deq_hi = tl.zeros((PACKED_PER_GROUP, BLOCK_N), dtype=tl.bfloat16)
- for pi in tl.static_range(PACKED_PER_GROUP):
- packed_row = base_packed + pi
- w_packed = tl.load(
- wq_ptr + packed_row * stride_wqn + offs_n, mask=n_mask, other=0
- ).to(tl.int32)
- w_lo = (w_packed & 0xF).to(tl.bfloat16)
- w_hi = ((w_packed >> 4) & 0xF).to(tl.bfloat16)
- row = tl.full([1, 1], pi, tl.int32)
- w_deq_lo = tl.where(row == pi, (w_lo - z) * s, w_deq_lo)
- w_deq_hi = tl.where(row == pi, (w_hi - z) * s, w_deq_hi)
-
- offs_k = base_k + tl.arange(0, GROUP_SIZE)
- x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
- x_mask = m_mask[:, None] & (offs_k[None, :] < K)
- x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.bfloat16)
- x_grp = tl.reshape(x, (BLOCK_M, PACKED_PER_GROUP, 2))
- x_even = x_grp[:, :, 0]
- x_odd = x_grp[:, :, 1]
-
- acc += tl.dot(x_even, w_deq_lo, out_dtype=tl.float32)
- acc += tl.dot(x_odd, w_deq_hi, out_dtype=tl.float32)+ w_lo = (wq_packed & 0xF).to(tl.bfloat16)
+ w_hi = ((wq_packed >> 4) & 0xF).to(tl.bfloat16)
+ w_deq_lo = (w_lo - z[None, :]) * s[None, :]
+ w_deq_hi = (w_hi - z[None, :]) * s[None, :]{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<505 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_a45c9944-0f0f-44a9-90b3-2f056642f4e",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1478 more chars)failure:
command: 'cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1'
workingDirectory: ''
exitCode: 1
signal: ''
stdout:
Traceback (most recent call last):
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: constexpr[0]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspace/problems/07_w4a16_gemm/check.py", line 119, in <module>
main()
File "/workspace/problems/07_w4a16_gemm/check.py", line 80, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/problems/07_w4a16_gemm/solution.py", line 261, in forward
return
... (8686 more chars)--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,10 +1,12 @@
- base_packed = g * PACKED_PER_GROUP
- packed_rows = base_packed + tl.arange(0, PACKED_PER_GROUP)
- wq_ptrs = wq_ptr + packed_rows[:, None] * stride_wqn + offs_n[None, :]
- wq_mask = n_mask[None, :]
- wq_packed = tl.load(wq_ptrs, mask=wq_mask, other=0).to(tl.int32)
-
- w_lo = (wq_packed & 0xF).to(tl.bfloat16)
- w_hi = ((wq_packed >> 4) & 0xF).to(tl.bfloat16)
- w_deq_lo = (w_lo - z[None, :]) * s[None, :]
- w_deq_hi = (w_hi - z[None, :]) * s[None, :]+ offs_k_even = base_k + 2 * tl.arange(0, PACKED_PER_GROUP)
+ offs_k_odd = offs_k_even + 1
+ x_even_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk
+ x_odd_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk
+ k_even_mask = offs_k_even[None, :] < K
+ k_odd_mask = offs_k_odd[None, :] < K
+ x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & k_even_mask, other=0.0).to(
+ tl.bfloat16
+ )
+ x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & k_odd_mask, other=0.0).to(
+ tl.bfloat16
+ ){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<621 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_dcd645cb-a94b-4873-98dc-32a8c73edea",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1478 more chars)null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,12 +1,14 @@
- offs_k_even = base_k + 2 * tl.arange(0, PACKED_PER_GROUP)
- offs_k_odd = offs_k_even + 1
- x_even_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_even[None, :] * stride_xk
- x_odd_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k_odd[None, :] * stride_xk
- k_even_mask = offs_k_even[None, :] < K
- k_odd_mask = offs_k_odd[None, :] < K
- x_even = tl.load(x_even_ptrs, mask=m_mask[:, None] & k_even_mask, other=0.0).to(
- tl.bfloat16
- )
- x_odd = tl.load(x_odd_ptrs, mask=m_mask[:, None] & k_odd_mask, other=0.0).to(
- tl.bfloat16
- )+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2),
+ ],
+ key=["M", "N", "K"],
+){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<812 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,14 +1,9 @@
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
- triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_N": 256}, num_warps=4),
+ triton.Config({"BLOCK_N": 512}, num_warps=8),
+ triton.Config({"BLOCK_N": 128}, num_warps=4),
+ triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
- key=["M", "N", "K"],
+ key=["N", "K"],
){
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<277 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,9 +1,22 @@
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_N": 256}, num_warps=4),
- triton.Config({"BLOCK_N": 512}, num_warps=8),
- triton.Config({"BLOCK_N": 128}, num_warps=4),
- triton.Config({"BLOCK_N": 1024}, num_warps=8),
- ],
- key=["N", "K"],
-)+ tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask)
+
+
+@triton.jit
+def _reduce_partial_kernel(
+ partial_ptr,
+ out_ptr,
+ N,
+ n_groups,
+ stride_pg,
+ BLOCK_N: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
+ n_mask = offs_n < N
+ acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ for g in tl.range(0, n_groups):
+ acc += tl.load(partial_ptr + g * stride_pg + offs_n, mask=n_mask, other=0.0)
+ tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=n_mask)
+
+
+# ---------------------------------------------------------------------------{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<627 chars \u2014 see diff>"
}null--- a//workspace/problems/07_w4a16_gemm/solution.py
+++ b//workspace/problems/07_w4a16_gemm/solution.py
@@ -1,22 +1,26 @@
- tl.store(partial_ptr + pid_g * stride_pg + offs_n, acc, mask=n_mask)
-
-
-@triton.jit
-def _reduce_partial_kernel(
- partial_ptr,
- out_ptr,
- N,
- n_groups,
- stride_pg,
- BLOCK_N: tl.constexpr,
-):
- pid = tl.program_id(0)
- offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
- n_mask = offs_n < N
- acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
- for g in tl.range(0, n_groups):
- acc += tl.load(partial_ptr + g * stride_pg + offs_n, mask=n_mask, other=0.0)
- tl.store(out_ptr + offs_n, acc.to(tl.bfloat16), mask=n_mask)
-
-
-# ---------------------------------------------------------------------------+ _w4a16_gemv_splitk_kernel[grid](
+ x,
+ w_q,
+ scales,
+ zeros,
+ partial,
+ N,
+ K,
+ w_q.stride(0),
+ scales.stride(0),
+ scales.stride(1),
+ partial.stride(0),
+ GROUP_SIZE=group_size,
+ PACKED_PER_GROUP=packed_per_group,
+ )
+ out = torch.empty((1, N), dtype=torch.bfloat16, device=x.device)
+ reduce_grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
+ _reduce_partial_kernel[reduce_grid](
+ partial,
+ out,
+ N,
+ n_groups,
+ partial.stride(0),
+ BLOCK_N=256,
+ )
+ return out{
"path": "/workspace/problems/07_w4a16_gemm/solution.py",
"streamContent": "<720 chars \u2014 see diff>"
}null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"workingDirectory": "",
"timeout": 300000,
"toolCallId": "tool_ce468aad-6900-400a-96cc-c9af010e4ba",
"simpleCommands": [
"cd",
"uv",
"uv"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"full
... (1476 more chars)null{
"command": "cd /workspace/problems/07_w4a16_gemm && uv run python check.py 2>&1 && cat framework.txt",
"workingDirectory": "",
"timeout": 30000,
"toolCallId": "tool_b2ef9d50-7d36-4ab7-ad52-bfa7adb96a4",
"simpleCommands": [
"cd",
"uv",
"cat"
],
"hasInputRedirect": false,
"hasOutputRedirect": true,
"parsingResult": {
"parsingFailed": false,
"executableCommands": [
{
"name": "cd",
"args": [
{
"type": "word",
"value": "/workspace/problems/07_w4a16_gemm"
}
],
"fullText": "cd /wo
... (1133 more chars){
"path": "/workspace/problems/07_w4a16_gemm/solution.py"
}"""Fused W4A16 GEMM: unpack int4 weights + dequant + matmul in one Triton kernel."""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
# ---------------------------------------------------------------------------
# Decode path: M == 1 split-K GEMV — parallelize over (N-tile, K-group).
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=8),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["N", "K"],
)
@triton.jit
def _w4a16_gemv_splitk_kernel(
x_ptr,
wq_ptr,
scales_ptr,
zeros_ptr,
partial_ptr,
N,
K,
stride_wqn,
stride_sg,
stride_sn,
stride_pg,
BLOCK_N: tl.constexpr,
GROUP_SIZE: tl.constexpr,
PACKED_PER_GROUP: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_g = tl.program_id(1)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
s = tl.load(
scales_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
).to(tl.float32)
z = tl.load(
zeros_ptr + pid_g * stride_sg + offs_n * stride_sn, mask=n_mask, other=0.0
).to(tl.float32)
acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
base_k = pid_g * GROUP_SIZE
base_packed = pid_g * P
... (8292 more chars)null