import torch
import torch.nn as nn
import triton
import triton.language as tl
# Quantization group length along K: one (scale, zero-point) row covers 128
# consecutive K elements (AWQ/GPTQ-style grouped asymmetric int4).
GROUP_SIZE = 128
@triton.autotune(
    configs=[
        # Tile sweep: BLOCK_M=1 rows target the decode (M=1) case, larger
        # BLOCK_M tiles target prefill-sized batches; autotuner picks per (M, N).
        triton.Config({'BLOCK_M': 1, 'BLOCK_N': 512, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2, 'num_stages': 2}),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 3}),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
    ],
    key=['M', 'N'],
)
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
    stride_x_m, stride_x_k,
    stride_wq_kh, stride_wq_n,
    stride_s_g, stride_s_n,
    stride_z_g, stride_z_n,
    stride_out_m, stride_out_n,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Fused W4A16 GEMM: out[M, N] = x[M, K] @ dequant(w_q)[K, N], all bf16 I/O.

    Layout (AWQ/GPTQ-style, see reference packing):
      x:      (M, K) bf16 activations.
      w_q:    (K // 2, N) uint8 -- two int4 weights per byte; low nibble holds
              the even-k weight, high nibble the odd-k weight.
      scales: (K // 128, N) bf16 per-group scales.
      zeros:  (K // 128, N) bf16 per-group float zero-points.

    Dequant per element: w_bf[k, n] = (nibble - zeros[k // 128, n]) * scales[k // 128, n].

    NOTE(review): the group size 128 is hardcoded here (loop step, the
    tl.arange(0, 64) half-group width, and g_idx) because tl.arange needs
    compile-time constant bounds; Model.group_size values other than 128 are
    silently ignored by this kernel.
    """
    # Each program instance computes one BLOCK_M x BLOCK_N output tile.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    m_mask = offs_m < M
    n_mask = offs_n < N
    # fp32 accumulator for the tl.dot partial sums.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk K one quantization group (128 elements = 64 packed bytes) per
    # iteration, so a single scales/zeros row applies to the whole step.
    for k_start in range(0, K, 128):
        # Split the K slice into even/odd lanes matching the nibble packing;
        # offs_kh indexes the packed byte rows (kh = k // 2).
        offs_k_even = k_start + tl.arange(0, 64) * 2
        offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
        offs_kh = k_start // 2 + tl.arange(0, 64)
        k_mask_even = offs_k_even < K
        k_mask_odd = offs_k_odd < K
        kh_mask = offs_kh < K // 2
        # Activations for the even-k and odd-k lanes (masked rows load 0.0,
        # which contributes nothing to the dot products).
        x_even = tl.load(
            x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
            mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
        )
        x_odd = tl.load(
            x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
            mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
        )
        # One byte per (kh, n) carries both the even and odd 4-bit weights.
        w_packed = tl.load(
            w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
            mask=kh_mask[:, None] & n_mask[None, :],
            other=0
        )
        # Unpack nibbles (values 0..15) straight to bf16 for the dequant math.
        w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
        w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
        # Per-group scale and zero-point row for this K slice.
        g_idx = k_start // 128
        s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
                    mask=n_mask, other=0.0)
        z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
                    mask=n_mask, other=0.0)
        # Dequantize, casting back to bf16 so tl.dot takes the tensor-core
        # path (tl.dot accepts fp16/bf16/fp8/int8 operands, not arbitrary fp32 mixes).
        w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
        w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
        # Two half-width dots per group: even lanes + odd lanes.
        acc += tl.dot(x_even, w_even_deq)
        acc += tl.dot(x_odd, w_odd_deq)
    # Downcast once at the end and store the masked output tile.
    c = acc.to(tl.bfloat16)
    tl.store(
        out_ptr + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n,
        c, mask=m_mask[:, None] & n_mask[None, :]
    )
class Model(nn.Module):
    """W4A16 weight-only quantized linear layer backed by the fused Triton kernel.

    Buffers (AWQ/GPTQ-style asymmetric int4, grouped along K):
        w_q:    (K // 2, N) uint8 -- two 4-bit weights per byte
                (low nibble = even k, high nibble = odd k).
        scales: (K // group_size, N) bf16 per-group scales.
        zeros:  (K // group_size, N) bf16 per-group float zero-points.
    """

    def __init__(self, M, N, K, group_size=GROUP_SIZE):
        """Allocate zero-initialized quantized-weight buffers.

        Args:
            M: nominal batch size (a hint only; forward reads M from the input).
            N: output features.
            K: input features; must be a multiple of ``group_size``.
            group_size: quantization group length along K. The Triton kernel
                hardcodes a 128-wide K loop, so only 128 is supported.

        Raises:
            ValueError: if ``group_size`` != 128 (the kernel would silently
                compute wrong results) or ``K`` is not a multiple of
                ``group_size`` (scales/zeros would be silently truncated).
        """
        super().__init__()
        if group_size != 128:
            raise ValueError(
                f"w4a16_gemm_kernel hardcodes group_size=128; got {group_size}"
            )
        if K % group_size != 0:
            raise ValueError(f"K={K} must be a multiple of group_size={group_size}")
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        self.register_buffer("w_q", torch.zeros(K // 2, N, dtype=torch.uint8))
        self.register_buffer("scales", torch.zeros(n_groups, N, dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.zeros(n_groups, N, dtype=torch.bfloat16))

    def forward(self, x):
        """Compute ``x @ dequant(w_q)``.

        Args:
            x: (M, K) bf16 activations on the same device as the buffers.

        Returns:
            (M, N) bf16 output tensor.
        """
        M, K = x.shape
        N = self.N
        out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

        def grid(META):
            # 2-D launch: one program per (BLOCK_M, BLOCK_N) output tile.
            return (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))

        w4a16_gemm_kernel[grid](
            x, self.w_q, self.scales, self.zeros, out,
            x.stride(0), x.stride(1),
            self.w_q.stride(0), self.w_q.stride(1),
            self.scales.stride(0), self.scales.stride(1),
            self.zeros.stride(0), self.zeros.stride(1),
            out.stride(0), out.stride(1),
            M, N, K,
        )
        return out
# Default harness shape: the decode case (M=1), bandwidth-bound on the
# int4 weight stream.
M = 1
N = 12288
K = 4096


def get_inputs():
    """Return the forward-pass inputs: one (M, K) bf16 activation tensor."""
    return [torch.randn(M, K, dtype=torch.bfloat16)]


def get_init_inputs():
    """Return the positional arguments for Model.__init__."""
    shape = [M, N, K]
    return shape
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.953 ms=0.766
shape=0 variant=compiled tflops=0.665 gbps=176.761 ms=0.151
shape=0 variant=sota tflops=2.303 gbps=612.451 ms=0.044
shape=0 variant=solution tflops=1.489 gbps=396.027 ms=0.068
shape=0 solution_peak_fraction=0.2200
shape=1 variant=eager tflops=4.235 gbps=36.534 ms=0.761
shape=1 variant=compiled tflops=20.956 gbps=180.775 ms=0.154
shape=1 variant=sota tflops=23.248 gbps=200.543 ms=0.139
shape=1 variant=solution tflops=33.421 gbps=288.297 ms=0.096
shape=1 solution_peak_fraction=0.1602
shape=2 variant=eager tflops=31.378 gbps=42.772 ms=0.821
shape=2 variant=compiled tflops=119.544 gbps=162.952 ms=0.216
shape=2 variant=sota tflops=134.330 gbps=183.107 ms=0.192
shape=2 variant=solution tflops=90.826 gbps=123.806 ms=0.284
shape=2 solution_peak_fraction=0.0688
shape=3 variant=eager tflops=0.182 gbps=48.550 ms=0.184
shape=3 variant=compiled tflops=0.498 gbps=132.561 ms=0.067
shape=3 variant=sota tflops=1.215 gbps=323.337 ms=0.028
shape=3 variant=solution tflops=0.580 gbps=154.336 ms=0.058
shape=3 solution_peak_fraction=0.0857
shape=4 variant=eager tflops=2.064 gbps=34.911 ms=0.910
shape=4 variant=compiled tflops=10.372 gbps=175.445 ms=0.181
shape=4 variant=sota tflops=11.374 gbps=192.403 ms=0.165
shape=4 variant=solution tflops=22.442 gbps=379.622 ms=0.084
shape=4 solution_peak_fraction=0.2109
peak_fraction: 0.1344
RESULT: OK
PASS
{
"problem": "07_w4a16_gemm",
"harness": "opencode",
"model": "deepseek/deepseek-v4-flash",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1344,
"elapsed_seconds": 1065,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 43184, "output_tokens": 21754, "cache_read_tokens": 3479296, "cache_creation_tokens": 0, "reasoning_tokens": 54869, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive W4A16 weight-only quantized GEMM reference (correctness only).
2:
3: AWQ/GPTQ-style scheme:
4: x: (M, K) bf16
5: w_q: (K // 2, N) uint8 -- two int4 weights packed per byte (low nibble = even-K, high = odd-K)
6: scales: (K // group, N) bf16
7: zeros: (K // group, N) bf16 -- asymmetric (stored already as float zero-point)
8: out: (M, N) bf16
9:
10: Dequant (per group along K):
11: w_bf[k, n] = (w_q[k, n] - zeros[k // group, n]) * scales[k // group, n]
12: where w_q[k, n] is the unpacked 4-bit value (0..15).
13:
14: This reference unpacks to a full bf16 matrix and then runs torch.matmul. Slow and
15: memory-heavy on the dequant; the agent's solution must fuse unpack+GEMM.
16: """
17: from __future__ import annotations
18:
19: import torch
20: import torch.nn as nn
21:
22: OP_TYPE = "gemm_w4a16"
23: SUPPORTED_PRECISIONS = ["int4_bf16"]
24: HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
25:
26: GROUP_SIZE = 128
27:
28:
29: def _pack_int4(w_q: torch.Tensor) -> torch.Tensor:
30: """Pack (K, N) uint8 in [0,15] into (K//2, N) uint8.
31:
32: Even rows go in the low nibble, odd rows in the high nibble.
33: """
34: K, N = w_q.shape
35: assert K % 2 == 0
36: lo = w_q[0::2].to(torch.uint8) & 0xF
37: hi = w_q[1::2].to(torch.uint8) & 0xF
38: ret
... (3312 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for FP8 GEMM.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or None
42:
43: # --- Per-shape
... (2909 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for FP8 GEMM.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime", "compute")
40: flops_formul
... (3415 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Shape sweep for W4A16 GEMM.
2:
3: Llama-style up_proj / qkv_proj shapes. Decode (M=1) is the bandwidth-bound
4: case every inference engine optimizes -- it's the bar to beat.
5: """
6:
7: SHAPES = [
8: {"M": 1, "N": 12288, "K": 4096}, # decode: memory-bound on int4 weight read
9: {"M": 32, "N": 12288, "K": 4096}, # small prefill: mixed regime
10: {"M": 256, "N": 12288, "K": 4096}, # larger prefill: approaching compute
11: {"M": 1, "N": 4096, "K": 4096}, # decode: square shape
12: {"M": 16, "N": 14336, "K": 4096}, # speculative-decode-ish
13: ]
(End of file - total 13 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 07_w4a16_gemm
2: display_name: "W4A16 Weight-only Quantized GEMM"
3: precision: int4_bf16
4: regime: memory # decode-dominant; M=1 is bandwidth-bound on the int4 weight stream
5:
6: # Dense-equivalent FLOPs (matmul work, ignoring dequant arithmetic).
7: flops_formula: "2 * M * N * K"
8:
9: # Bytes moved per call (memory roofline):
10: # x: M*K*2 (bf16 activations, streamed in once)
11: # w_q: (K/2)*N (packed int4, 0.5 B/elem)
12: # scales: (K/128)*N*2 (bf16 scales)
13: # zeros: (K/128)*N*2 (bf16 zero-points)
14: # out: M*N*2 (bf16 store)
15: bytes_formula: "M*K*2 + (K/2)*N + (K/128)*N*2 + (K/128)*N*2 + M*N*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.10 # group-quant adds noise on top of bf16 accumulator slop
23:
24: # Forbidden ops -- agent must write the unpack + GEMM themselves, not call a
25: # vendor library that does both.
26: forbidden:
27: - "bitsandbytes.functional.dequantize_4bit"
28: - "bitsandbytes.functional.gemv_4bit"
29: - "marlin_kernel.gemm"
30: - "torch.nn.functional.linear"
31:
32: sota:
33: name: "bitsandbytes NF4 (gemv_4bit / dequantize_4bit + matmul)"
34: url: "https://github.com/TimDettmers/bitsandbytes"
35: function: "bitsandbytes.functional.gemv_4bit"
36: notes: |
37:
... (757 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for W4A16 GEMM.
2:
3: Library survey on RTX PRO 6000 Blackwell (SM120, CC 12.0):
4:
5: - Marlin (IST-DASLab): no SM120 kernels (Ampere/Hopper only). Skip.
6: - GPTQ-Triton (fpgaminer): unmaintained; pure Triton path works on SM120
7: but is not faster than Marlin on its target HW
8: and has no Blackwell tuning. Skip as primary.
9: - AWQ (mit-han-lab/llm-awq): CUDA kernels not built for SM120 in the wheel.
10: Skip.
11: - bitsandbytes >= 0.49.2: CUDA kernels compile and run on SM120 (verified
12: on this machine). Different quant scheme (NF4,
13: symmetric, blocksize 64) than our reference's
14: AWQ-style asymmetric INT4 with group_size 128,
15: but it occupies the same memory regime and is
16: the only tuned W4A16-class kernel that runs on
17: SM120 today. Used here as an *informational*
18: SOTA line, not as a numerical reference.
19:
20: The benchmark calls `sota_forward(x, ref_model)` and times it; correctness is
21: NOT checked against this path (the quant scheme differs).
22: """
23: from
... (2725 more chars){
"command": "python -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check Triton version"
}3.6.0
{
"command": "nvidia-smi -q 2>/dev/null | head -20",
"description": "Check GPU info"
}==============NVSMI LOG==============
Timestamp : Mon Apr 27 14:09:14 2026
Driver Version : 595.58.03
CUDA Version : 13.2
Attached GPUs : 1
GPU 00000000:01:00.0
Product Name : NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Product Brand : NVIDIA RTX
Product Architecture : Blackwell
Display Mode : Requested functionality has been deprecated
Display Attached : Yes
Display Active : Enabled
Persistence Mode : Enabled
Addressing Mode : HMM
MIG Mode
Current : N/A
Pending : N/A
{
"command": "python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability()); print(torch.cuda.get_device_name())\" 2>&1",
"description": "Check PyTorch version and GPU"
}2.11.0+cu130
(12, 0)
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
{
"description": "Look for Triton matmul examples",
"command": "<70 chars \u2014 see below>"
}python -c "
import triton
import inspect
print(triton.__file__)
" 2>&1/home/infatoshi/.local/lib/python3.12/site-packages/triton/__init__.py
{
"description": "Check Triton dot documentation",
"command": "<71 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.dot)
" 2>&1 | head -40Help on function dot in module triton.language.core:
dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32, _semantic=None)
Returns the matrix product of two blocks.
The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
For three-dimensional blocks, `tl.dot` performs the batched matrix product,
where the first dimension of each block represents the batch dimension.
:param input: The first tensor to be multiplied.
:type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second tensor to be multiplied.
:type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
:type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available o
... (272 more chars){
"description": "Check Triton reshape API and basic operations",
"command": "<75 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.reshape)
" 2>&1 | head -30Help on function reshape in module triton.language.core:
reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None)
Returns a tensor with the same number of elements as input but with the
provided shape.
:param input: The input tensor.
:type input: Block
:param shape: The new shape.
:code:`shape` can be passed as a tuple or as individual parameters: ::
# These are equivalent
reshape(x, (32, 32))
reshape(x, 32, 32)
This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.reshape(...)` instead of
:code:`reshape(x, ...)`.
{
"description": "Check if tl.cat or tl.stack available",
"command": "<270 chars \u2014 see below>"
}python -c "
import triton.language as tl
# Check for interleave-like operations
for name in dir(tl):
if 'terleave' in name.lower() or 'stack' in name.lower() or 'cat' in name.lower() or 'permute' in name.lower() or 'trans' in name.lower():
print(name)
" 2>&1cat
interleave
permute
trans
{
"description": "Check interleave documentation",
"command": "<78 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.interleave)
" 2>&1 | head -40Help on JITFunction in module triton.language.standard:
interleave = JITFunction(triton.language.standard:interleave)
Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
:param a: The first input tensor.
:type a: Tensor
:param b: The second input tensor.
:type b: Tensor
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+GROUP_SIZE = 128
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4}),
+ ],
+ key=['M', 'N'],
+)
+@triton.jit
+def w4a16_gemm_kernel(
+ x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
+ stride_x_m, stride_x_k,
+ stride_wq_kh, stride_wq_n,
+ stride_s_g, stride_s_n,
+ stride_z_g, stride_z_n,
+ stride_out_m, stride_out_n,
+ M, N, K,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+):
+ BLOCK_K = GROUP_SIZE
+ BLOCK_KH = BLOCK_K // 2
+
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ m_mask = offs_m < M
+ n_mask = offs_n < N
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, K, BLOCK_K):
+ offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
+ offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
+ offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
+
+ k_mask_even = offs_k_even < K
+ k_mask_odd = offs_k_odd < K
+ kh_mask = offs_kh < K // 2
+
+ x_even = tl.load(
+ x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
+ mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
+ )
+ x_odd = tl.load(
+ x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
+ mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
+ )
+
+ w_packed = tl.load(
+ w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
+ mask=kh_mask[:, None] & n_mask[None, :],
+ other=0
+ )
+
+ w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
+ w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
+
+ g_idx = k_start // GROUP_SIZE
+ s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
+ mask=n_mask, other=0.0)
+ z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
+ mask=n_mask, other=0.0)
+
+ w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
+ w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
+
+ acc += tl.dot(x_even, w_even_deq)
+ acc += tl.dot(x_odd, w_odd_deq)
+
+ c = acc.to(tl.bfloat16)
+ tl.store(
+ out_ptr + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n,
+ c, mask=m_mask[:, None] & n_mask[None, :]
+ )
+
+
+class Model(nn.Module):
+ def __init__(self, M, N, K, group_size=GROUP_SIZE):
+ super().__init__()
+ self.M, self.N, self.K = M, N, K
+ self.group_size = group_size
+ n_groups = K // group_size
+ self.register_buffer("w_q", torch.zeros(K // 2, N, dtype=torch.uint8))
+ self.register_buffer("scales", torch.zeros(n_groups, N, dtype=torch.bfloat16))
+ self.register_buffer("zeros", torch.zeros(n_groups, N, dtype=torch.bfloat16))
+
+ def forward(self, x):
+ M, K = x.shape
+ N = self.N
+ out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+
+ def grid(META):
+ return (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ w4a16_gemm_kernel[grid](
+ x, self.w_q, self.scales, self.zeros, out,
+ x.stride(0), x.stride(1),
+ self.w_q.stride(0), self.w_q.stride(1),
+ self.scales.stride(0), self.scales.stride(1),
+ self.zeros.stride(0), self.zeros.stride(1),
+ out.stride(0), out.stride(1),
+ M, N, K,
+ )
+ return out
+
+
+M = 1
+N = 12288
+K = 4096
+
+
+def get_inputs():
+ x = torch.randn(M, K, dtype=torch.bfloat16)
+ return [x]
+
+
+def get_init_inputs():
+ return [M, N, K]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<4408 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Quick test the Triton kernel compiles",
"command": "<757 chars \u2014 see below>",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
# Test shape 0: M=1, N=12288, K=4096
reference.M = 1
reference.N = 12288
reference.K = 4096
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f}')
print(f'PASS: {diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "<string>", line 24, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 112, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
benchmark()
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 164, in _bench
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^
... (2210 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -29,6 +29,7 @@
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
+ GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
BLOCK_K = GROUP_SIZE
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<1055 chars \u2014 see below>",
"newString": "<1085 chars \u2014 see below>"
}@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4}),
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2}),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8}),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4}),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8}),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4}),
],
key=['M', 'N'],
)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
stride_x_m, stride_x_k,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
BLOCK_K = GROUP_SIZE
BLOCK_KH = BLOCK_K // 2@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4}),
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2}),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4}),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8}),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4}),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8}),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4}),
],
key=['M', 'N'],
)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
stride_x_m, stride_x_k,
stride_wq_kh, stride_wq_n,
stride_s_g, stride_s_n,
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
BLOCK_K = GROUP_SIZE
BLOCK_KH = BLOCK_K // 2Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -73,7 +73,7 @@
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
- g_idx = k_start // GROUP_SIZE
+ g_idx = k_start // BLOCK_K
s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
mask=n_mask, other=0.0)
z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": " g_idx = k_start // GROUP_SIZE",
"newString": " g_idx = k_start // BLOCK_K"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -118,6 +118,7 @@
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
M, N, K,
+ GROUP_SIZE,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<364 chars \u2014 see below>",
"newString": "<388 chars \u2014 see below>"
} w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
M, N, K,
) w4a16_gemm_kernel[grid](
x, self.w_q, self.scales, self.zeros, out,
x.stride(0), x.stride(1),
self.w_q.stride(0), self.w_q.stride(1),
self.scales.stride(0), self.scales.stride(1),
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
M, N, K,
GROUP_SIZE,
)Edit applied successfully.{
"description": "Test the updated kernel compiles and runs",
"command": "<757 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
# Test shape 0: M=1, N=12288, K=4096
reference.M = 1
reference.N = 12288
reference.K = 4096
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f}')
print(f'PASS: {diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 24, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 113, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, *
... (2556 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -32,8 +32,8 @@
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
- BLOCK_K = GROUP_SIZE
- BLOCK_KH = BLOCK_K // 2
+ BLOCK_K = 128
+ BLOCK_KH = 64
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<52 chars \u2014 see below>",
"newString": "<35 chars \u2014 see below>"
} BLOCK_K = GROUP_SIZE
BLOCK_KH = BLOCK_K // 2 BLOCK_K = 128
BLOCK_KH = 64Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: import triton
4: import triton.language as tl
5:
6: GROUP_SIZE = 128
7:
8:
9: @triton.autotune(
10: configs=[
11: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4}),
12: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4}),
13: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2}),
14: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4}),
15: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4}),
16: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8}),
17: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4}),
18: triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8}),
19: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4}),
20: ],
21: key=['M', 'N'],
22: )
23: @triton.jit
24: def w4a16_gemm_kernel(
25: x_ptr, w_q_ptr, scales_ptr, zeros_ptr, out_ptr,
26: stride_x_m, stride_x_k,
27: stride_wq_kh, stride_wq_n,
28: stride_s_g, stride_s_n,
29: stride_z_g, stride_z_n,
30: stride_out_m, stride_out_n,
31: M, N, K,
32: GROUP_SIZE: tl.constexpr,
33: BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
34: ):
35: BLOCK_K = 128
36: BLOCK_KH = 64
37:
38: pid_m = tl.program_id(0)
39: pid_n = tl.program_id(1)
40:
41:
... (3676 more chars){
"description": "Test updated kernel",
"command": "<720 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f}')
print(f'PASS: {diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 23, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 113, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, *
... (2556 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -32,9 +32,6 @@
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
- BLOCK_K = 128
- BLOCK_KH = 64
-
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -46,10 +43,10 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K):
- offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
- offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
- offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
+ for k_start in range(0, K, 128):
+ offs_k_even = k_start + tl.arange(0, 64) * 2
+ offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
+ offs_kh = k_start // 2 + tl.arange(0, 64)
k_mask_even = offs_k_even < K
k_mask_odd = offs_k_odd < K
@@ -73,7 +70,7 @@
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
- g_idx = k_start // BLOCK_K
+ g_idx = k_start // 128
s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
mask=n_mask, other=0.0)
z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<1838 chars \u2014 see below>",
"newString": "<1775 chars \u2014 see below>"
} BLOCK_K = 128
BLOCK_KH = 64
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = offs_m < M
n_mask = offs_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
k_mask_even = offs_k_even < K
k_mask_odd = offs_k_odd < K
kh_mask = offs_kh < K // 2
x_even = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
)
x_odd = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
)
w_packed = tl.load(
w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
mask=kh_mask[:, None] & n_mask[None, :],
other=0
)
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
g_idx = k_start // BLOCK_K
s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
mask=n_mask, other=0.0)
z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
mask=n_mask, other=0.0)
w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
acc += tl.dot(x_even, w_even_deq)
acc += tl.dot(x_odd, w_odd_deq) pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = offs_m < M
n_mask = offs_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, 128):
offs_k_even = k_start + tl.arange(0, 64) * 2
offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
offs_kh = k_start // 2 + tl.arange(0, 64)
k_mask_even = offs_k_even < K
k_mask_odd = offs_k_odd < K
kh_mask = offs_kh < K // 2
x_even = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
)
x_odd = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
)
w_packed = tl.load(
w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
mask=kh_mask[:, None] & n_mask[None, :],
other=0
)
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
g_idx = k_start // 128
s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
mask=n_mask, other=0.0)
z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
mask=n_mask, other=0.0)
w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
acc += tl.dot(x_even, w_even_deq)
acc += tl.dot(x_odd, w_odd_deq)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -29,7 +29,6 @@
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
- GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<92 chars \u2014 see below>",
"newString": "<62 chars \u2014 see below>"
} M, N, K,
GROUP_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -114,7 +114,6 @@
self.zeros.stride(0), self.zeros.stride(1),
out.stride(0), out.stride(1),
M, N, K,
- GROUP_SIZE,
)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<44 chars \u2014 see below>",
"newString": " M, N, K,"
} M, N, K,
GROUP_SIZE,Edit applied successfully.{
"description": "Test compiled kernel",
"command": "<741 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0 (M=1 N=12288 K=4096): Max diff = {diff:.6f}')
print(f'PASS: {diff < 0.10}')
" 2>&1Shape 0 (M=1 N=12288 K=4096): Max diff = 0.001953
PASS: True
{
"description": "Test all shapes",
"command": "<1051 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
import shapes
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
for seed in (42, 123, 456):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
inputs = [t.to(device) for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
ok = diff < 0.10
status = 'PASS' if ok else 'FAIL'
print(f'Shape {shape_idx} {shape} seed {seed}: {status} diff={diff:.6f}')
" 2>&1Shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 42: PASS diff=0.001953
Shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 123: PASS diff=0.003906
Shape 0 {'M': 1, 'N': 12288, 'K': 4096} seed 456: PASS diff=0.007812
Shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 42: PASS diff=0.031250
Shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 123: PASS diff=0.031250
Shape 1 {'M': 32, 'N': 12288, 'K': 4096} seed 456: PASS diff=0.031250
Shape 2 {'M': 256, 'N': 12288, 'K': 4096} seed 42: PASS diff=0.046875
Shape 2 {'M': 256, 'N': 12288, 'K': 4096} seed 123: PASS diff=0.031250
Shape 2 {'M': 256, 'N': 12288, 'K': 4096} seed 456: PASS diff=0.031250
Shape 3 {'M': 1, 'N': 4096, 'K': 4096} seed 42: PASS diff=0.001953
Shape 3 {'M': 1, 'N': 4096, 'K': 4096} seed 123: PASS diff=0.007812
Shape 3 {'M': 1, 'N': 4096, 'K': 4096} seed 456: PASS diff=0.007812
Shape 4 {'M': 16, 'N': 14336, 'K': 4096} seed 42: PASS diff=0.015625
Shape 4 {'M': 16, 'N': 14336, 'K': 4096} seed 123: PASS diff=0.031250
Shape 4 {'M': 16, 'N': 14336, 'K': 4096} seed 456: PASS diff=0.015625
{
"description": "Run formal correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Run benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python benchmark.py 2>&1",
"timeout": 300000
}/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.132 gbps=34.974 ms=0.765
shape=0 variant=compiled tflops=0.672 gbps=178.820 ms=0.150
shape=0 variant=sota tflops=2.349 gbps=624.801 ms=0.043
shape=0 variant=solution tflops=1.519 gbps=403.866 ms=0.066
shape=0 solution_peak_fraction=0.2244
shape=1 variant=eager tflops=4.235 gbps=36.534 ms=0.761
shape=1 variant=compiled tflops=21.132 gbps=182.293 ms=0.152
shape=1 variant=sota tflops=23.539 gbps=203.052 ms=0.137
shape=1 variant=solution tflops=30.184 gbps=260.375 ms=0.107
shape=1 solution_peak_fraction=0.1447
shape=2 variant=eager tflops=31.338 gbps=42.717 ms=0.822
shape=2 variant=compiled tflops=121.491 gbps=165.607 ms=0.212
shape=2 variant=sota tflops=134.915 gbps=183.905 ms=0.191
shape=2 variant=solution tflops=89.241 gbps=121.645 ms=0.289
shape=2 solution_peak_fraction=0.0676
shape=3 variant=eager tflops=0.180 gbps=47.962 ms=0.186
shape=3 variant=compiled tflops=0.506 gbps=134.639 ms=0.066
shape=3 variant=sota tflops=1.239 gbp
... (423 more chars){
"description": "Profile with torch profiler",
"command": "<901 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
import shapes
device = torch.device('cuda:0')
shape_idx = 2
shape = shapes.SHAPES[shape_idx]
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(5):
sol_model(*inputs)
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
for _ in range(10):
sol_model(*inputs)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=15))
" 2>&1/home/infatoshi/.local/lib/python3.12/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
w4a16_gemm_kernel 0.00% 0.000us 0.00% 0.000us 0.000us 2.531ms 100.00% 2.531ms 253.136us 10
Activity Buffer Request 25.55% 855.743us 25.55% 855.743us 855.743us 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernelEx 6.56% 219.832us 6.56% 219.832us 21.983us 0.000us 0.00% 0.000us 0.000us 10
cudaDeviceSynchronize 67.88% 2.273ms 67.88% 2.273ms 2.273ms 0.000us 0.00% 0.000us 0.000us 1
--------------------------- ---------
... (192 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -78,8 +78,9 @@
w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
- acc += tl.dot(x_even, w_even_deq)
- acc += tl.dot(x_odd, w_odd_deq)
+ w_deq = tl.cat(w_even_deq, w_odd_deq, axis=0)
+ x_full = tl.cat(x_even, x_odd, axis=1)
+ acc += tl.dot(x_full, w_deq)
c = acc.to(tl.bfloat16)
tl.store(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<236 chars \u2014 see below>",
"newString": "<292 chars \u2014 see below>"
} w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
acc += tl.dot(x_even, w_even_deq)
acc += tl.dot(x_odd, w_odd_deq) w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
w_deq = tl.cat(w_even_deq, w_odd_deq, axis=0)
x_full = tl.cat(x_even, x_odd, axis=1)
acc += tl.dot(x_full, w_deq)Edit applied successfully.{
"description": "Check if tl.cat exists",
"command": "<140 chars \u2014 see below>"
}python -c "
import triton.language as tl
print(hasattr(tl, 'cat'))
print(tl.cat.__doc__[:200] if hasattr(tl, 'cat') else 'not found')
" 2>&1True
Concatenate the given blocks
:param input: The first input tensor.
:type input: Tensor
:param other: The second input tensor.
:type other: Tensor
:param reorder: Compiler hin
{
"description": "Check tl.trans and reshape API details",
"command": "<73 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.trans)
" 2>&1 | head -30Help on function trans in module triton.language.core:
trans(input: 'tensor', *dims, _semantic=None)
Permutes the dimensions of a tensor.
If the parameter :code:`dims` is not specified, the function defaults to
swapping the last two axes, thereby performing an (optionally batched)
2D transpose.
:param input: The input tensor.
:param dims: The desired ordering of dimensions. For example,
:code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
:code:`dims` can be passed as a tuple or as individual parameters: ::
# These are equivalent
trans(x, (2, 1, 0))
trans(x, 2, 1, 0)
:py:func:`permute` is equivalent to this function, except it doesn't
have the special case when no permutation is specified.
This function can also be called as a member function on :py:class:`tensor`,
as :code:`x.trans(...)` instead of
:code:`trans(x, ...)`.
{
"description": "Check autotune selection for each shape",
"command": "<1068 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import solution
import reference
import shapes
# Force recompilation by clearing cache
import triton
triton.runtime.cache.clear()
device = torch.device('cuda:0')
for shape_idx, shape in enumerate(shapes.SHAPES):
reference.M = shape['M']
reference.N = shape['N']
reference.K = shape['K']
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup with first call (triggers autotune)
sol_model(*inputs)
# Check what config was used
from triton.runtime.cache import get_cache
print(f'Shape {shape_idx} M={shape[\"M\"]} N={shape[\"N\"]} K={shape[\"K\"]}')
print(f' sol time: {solution.forward_time:.3f}ms' if hasattr(solution, 'forward_time') else '')
" 2>&1Traceback (most recent call last):
File "<string>", line 9, in <module>
AttributeError: module 'triton.runtime.cache' has no attribute 'clear'
{
"description": "Check Triton autotune cache location",
"command": "<198 chars \u2014 see below>"
}python -c "
from triton.runtime.cache import get_cache_manager
import pathlib
cache_dir = pathlib.Path(get_cache_manager('/tmp').cache_dir)
# Find the parent cache dir
print(cache_dir.parent)
" 2>&1Traceback (most recent call last):
File "<string>", line 4, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/cache.py", line 248, in get_cache_manager
return cls(_base32(key))
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/cache.py", line 243, in _base32
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
^^^^^^^^^^^^^^^^^^
ValueError: non-hexadecimal number found in fromhex() arg at position 0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -8,17 +8,21 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4}),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4}),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2}),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4}),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4}),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8}),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4}),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8}),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 512, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 2, 'num_stages': 2}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
],
- key=['M', 'N'],
+ key=['M', 'N', 'K'],
)
@triton.jit
def w4a16_gemm_kernel(
@@ -29,8 +33,11 @@
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
+ KH = K // 2
+ BLOCK_KH = BLOCK_K // 2
+
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -42,23 +49,17 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, 128):
- offs_k_even = k_start + tl.arange(0, 64) * 2
- offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
- offs_kh = k_start // 2 + tl.arange(0, 64)
+ for k_start in range(0, K, BLOCK_K):
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
- k_mask_even = offs_k_even < K
- k_mask_odd = offs_k_odd < K
- kh_mask = offs_kh < K // 2
+ k_mask = offs_k < K
+ kh_mask = offs_kh < KH
- x_even = tl.load(
- x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
- mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
- )
- x_odd = tl.load(
- x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
- mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
- )
+ x_block = tl.load(
+ x_ptr + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k,
+ mask=m_mask[:, None] & k_mask[None, :], other=0.0
+ ).to(tl.bfloat16)
w_packed = tl.load(
w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
@@ -66,21 +67,40 @@
other=0
)
- w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
- w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
+ w_even_u8 = w_packed & 0xF
+ w_odd_u8 = (w_packed >> 4) & 0xF
- g_idx = k_start // 128
- s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
- mask=n_mask, other=0.0)
- z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
- mask=n_mask, other=0.0)
+ w_even = tl.cast(w_even_u8, tl.bfloat16)
+ w_odd = tl.cast(w_odd_u8, tl.bfloat16)
- w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
- w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
+ n_groups = BLOCK_K // GROUP_SIZE
+ for g in range(n_groups):
+ g_idx = (k_start // GROUP_SIZE) + g
+ start_kh = g * (GROUP_SIZE // 2)
- w_deq = tl.cat(w_even_deq, w_odd_deq, axis=0)
- x_full = tl.cat(x_even, x_odd, axis=1)
- acc += tl.dot(x_full, w_deq)
+ s = tl.load(
+ scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
+ mask=n_mask, other=0.0
+ )
+ z = tl.load(
+ zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
+ mask=n_mask, other=0.0
+ )
+
+ we_g = (w_even[start_kh:start_kh + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
+ wo_g = (w_odd[start_kh:start_kh + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
+
+ we_g = tl.cast(we_g, tl.bfloat16)
+ wo_g = tl.cast(wo_g, tl.bfloat16)
+
+ wg = tl.cat(we_g, wo_g, axis=0)
+
+ if g == 0:
+ w_block = wg
+ else:
+ w_block = tl.cat(w_block, wg, axis=0)
+
+ acc += tl.dot(x_block, w_block)
c = acc.to(tl.bfloat16)
tl.store(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<5340 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test new kernel compiles and runs",
"command": "<756 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_cache_test'
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).cuda().eval()
sol_model = solution.Model(*init_args).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [torch.randn(1, 4096, dtype=torch.bfloat16).cuda()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f} PASS={diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 23, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 130, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, *
... (2574 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -35,9 +35,6 @@
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
- KH = K // 2
- BLOCK_KH = BLOCK_K // 2
-
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -49,9 +46,12 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ KH = K // 2
+ NGROUPS = BLOCK_K // GROUP_SIZE
+
for k_start in range(0, K, BLOCK_K):
offs_k = k_start + tl.arange(0, BLOCK_K)
- offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
+ offs_kh = k_start // 2 + tl.arange(0, BLOCK_K // 2)
k_mask = offs_k < K
kh_mask = offs_kh < KH
@@ -59,7 +59,7 @@
x_block = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k,
mask=m_mask[:, None] & k_mask[None, :], other=0.0
- ).to(tl.bfloat16)
+ )
w_packed = tl.load(
w_q_ptr + offs_kh[:, None] * stride_wq_kh + offs_n[None, :] * stride_wq_n,
@@ -67,16 +67,12 @@
other=0
)
- w_even_u8 = w_packed & 0xF
- w_odd_u8 = (w_packed >> 4) & 0xF
+ w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
+ w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
- w_even = tl.cast(w_even_u8, tl.bfloat16)
- w_odd = tl.cast(w_odd_u8, tl.bfloat16)
-
- n_groups = BLOCK_K // GROUP_SIZE
- for g in range(n_groups):
+ for g in range(NGROUPS):
g_idx = (k_start // GROUP_SIZE) + g
- start_kh = g * (GROUP_SIZE // 2)
+ start = g * (GROUP_SIZE // 2)
s = tl.load(
scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
@@ -87,8 +83,8 @@
mask=n_mask, other=0.0
)
- we_g = (w_even[start_kh:start_kh + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
- wo_g = (w_odd[start_kh:start_kh + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
+ we_g = (w_even[start:start + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
+ wo_g = (w_odd[start:start + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
we_g = tl.cast(we_g, tl.bfloat16)
wo_g = tl.cast(wo_g, tl.bfloat16)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<5220 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test the single-dot kernel",
"command": "<745 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_test'
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
ref_model = reference.Model(*reference.get_init_inputs()).cuda().eval()
sol_model = solution.Model(*reference.get_init_inputs()).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [torch.randn(1, 4096, dtype=torch.bfloat16).cuda()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f} PASS={diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 126, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
benchmark()
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 164, in _bench
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^
... (2204 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -46,15 +46,12 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- KH = K // 2
- NGROUPS = BLOCK_K // GROUP_SIZE
-
for k_start in range(0, K, BLOCK_K):
offs_k = k_start + tl.arange(0, BLOCK_K)
offs_kh = k_start // 2 + tl.arange(0, BLOCK_K // 2)
k_mask = offs_k < K
- kh_mask = offs_kh < KH
+ kh_mask = offs_kh < K // 2
x_block = tl.load(
x_ptr + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k,
@@ -70,9 +67,9 @@
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
- for g in range(NGROUPS):
- g_idx = (k_start // GROUP_SIZE) + g
- start = g * (GROUP_SIZE // 2)
+ for g in range(BLOCK_K // 128):
+ g_idx = (k_start // 128) + g
+ start = g * 64
s = tl.load(
scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
@@ -83,8 +80,8 @@
mask=n_mask, other=0.0
)
- we_g = (w_even[start:start + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
- wo_g = (w_odd[start:start + GROUP_SIZE // 2] - z[None, :]) * s[None, :]
+ we_g = (w_even[start:start + 64] - z[None, :]) * s[None, :]
+ wo_g = (w_odd[start:start + 64] - z[None, :]) * s[None, :]
we_g = tl.cast(we_g, tl.bfloat16)
wo_g = tl.cast(wo_g, tl.bfloat16)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<5130 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test single-dot kernel",
"command": "<746 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_test2'
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
ref_model = reference.Model(*reference.get_init_inputs()).cuda().eval()
sol_model = solution.Model(*reference.get_init_inputs()).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [torch.randn(1, 4096, dtype=torch.bfloat16).cuda()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f} PASS={diff < 0.10}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1101, in __getitem__
raise ValueError(f"unsupported tensor index: {sl}")
ValueError: unsupported tensor index: <triton.language.core.slice object at 0x702f9db60950>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 123, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime
... (2461 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -8,21 +8,21 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 512, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 2, 'num_stages': 2}),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 4, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 512, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2, 'num_stages': 2}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 3}),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
],
- key=['M', 'N', 'K'],
+ key=['M', 'N'],
)
@triton.jit
def w4a16_gemm_kernel(
@@ -33,8 +33,11 @@
stride_z_g, stride_z_n,
stride_out_m, stride_out_n,
M, N, K,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
+ BLOCK_K = 128
+ BLOCK_KH = 64
+
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -47,15 +50,21 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
- offs_k = k_start + tl.arange(0, BLOCK_K)
- offs_kh = k_start // 2 + tl.arange(0, BLOCK_K // 2)
+ offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
+ offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
+ offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
- k_mask = offs_k < K
+ k_mask_even = offs_k_even < K
+ k_mask_odd = offs_k_odd < K
kh_mask = offs_kh < K // 2
- x_block = tl.load(
- x_ptr + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k,
- mask=m_mask[:, None] & k_mask[None, :], other=0.0
+ x_even = tl.load(
+ x_ptr + offs_m[:, None] * stride_x_m + offs_k_even[None, :] * stride_x_k,
+ mask=m_mask[:, None] & k_mask_even[None, :], other=0.0
+ )
+ x_odd = tl.load(
+ x_ptr + offs_m[:, None] * stride_x_m + offs_k_odd[None, :] * stride_x_k,
+ mask=m_mask[:, None] & k_mask_odd[None, :], other=0.0
)
w_packed = tl.load(
@@ -67,33 +76,17 @@
w_even = tl.cast(w_packed & 0xF, tl.bfloat16)
w_odd = tl.cast((w_packed >> 4) & 0xF, tl.bfloat16)
- for g in range(BLOCK_K // 128):
- g_idx = (k_start // 128) + g
- start = g * 64
+ g_idx = k_start // 128
+ s = tl.load(scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
+ mask=n_mask, other=0.0)
+ z = tl.load(zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
+ mask=n_mask, other=0.0)
- s = tl.load(
- scales_ptr + g_idx * stride_s_g + offs_n * stride_s_n,
- mask=n_mask, other=0.0
- )
- z = tl.load(
- zeros_ptr + g_idx * stride_z_g + offs_n * stride_z_n,
- mask=n_mask, other=0.0
- )
+ w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
+ w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
- we_g = (w_even[start:start + 64] - z[None, :]) * s[None, :]
- wo_g = (w_odd[start:start + 64] - z[None, :]) * s[None, :]
-
- we_g = tl.cast(we_g, tl.bfloat16)
- wo_g = tl.cast(wo_g, tl.bfloat16)
-
- wg = tl.cat(we_g, wo_g, axis=0)
-
- if g == 0:
- w_block = wg
- else:
- w_block = tl.cat(w_block, wg, axis=0)
-
- acc += tl.dot(x_block, w_block)
+ acc += tl.dot(x_even, w_even_deq)
+ acc += tl.dot(x_odd, w_odd_deq)
c = acc.to(tl.bfloat16)
tl.store(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"content": "<4892 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test correctness and benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 112, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/check.py", line 73, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
... (2792 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -35,9 +35,6 @@
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
- BLOCK_K = 128
- BLOCK_KH = 64
-
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -49,10 +46,10 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K):
- offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
- offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
- offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH)
+ for k_start in range(0, K, 128):
+ offs_k_even = k_start + tl.arange(0, 64) * 2
+ offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
+ offs_kh = k_start // 2 + tl.arange(0, 64)
k_mask_even = offs_k_even < K
k_mask_odd = offs_k_odd < K
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<527 chars \u2014 see below>",
"newString": "<468 chars \u2014 see below>"
} BLOCK_K = 128
BLOCK_KH = 64
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = offs_m < M
n_mask = offs_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
offs_k_even = k_start + tl.arange(0, BLOCK_KH) * 2
offs_k_odd = k_start + tl.arange(0, BLOCK_KH) * 2 + 1
offs_kh = k_start // 2 + tl.arange(0, BLOCK_KH) pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = offs_m < M
n_mask = offs_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, 128):
offs_k_even = k_start + tl.arange(0, 64) * 2
offs_k_odd = k_start + tl.arange(0, 64) * 2 + 1
offs_kh = k_start // 2 + tl.arange(0, 64)Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: import triton
4: import triton.language as tl
5:
6: GROUP_SIZE = 128
7:
8:
9: @triton.autotune(
10: configs=[
11: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 512, 'num_warps': 4, 'num_stages': 3}),
12: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
13: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
14: triton.Config({'BLOCK_M': 1, 'BLOCK_N': 64, 'num_warps': 2, 'num_stages': 2}),
15: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'num_warps': 4, 'num_stages': 3}),
16: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
17: triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
18: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}),
19: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
20: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
21: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 3}),
22: triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 3}),
23: triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'num_warps': 8, 'num_stages': 3}),
24: ],
25: key=
... (4062 more chars){
"description": "Test the kernel compiles and runs",
"command": "<746 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_test3'
import torch
import solution
import reference
reference.M = 1
reference.N = 12288
reference.K = 4096
ref_model = reference.Model(*reference.get_init_inputs()).cuda().eval()
sol_model = solution.Model(*reference.get_init_inputs()).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [torch.randn(1, 4096, dtype=torch.bfloat16).cuda()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'Shape 0: Max diff = {diff:.6f} PASS={diff < 0.10}')
" 2>&1Shape 0: Max diff = 0.001953 PASS=True
{
"description": "Run full check and benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python check.py 2>&1 && echo \"---\" && python benchmark.py 2>&1",
"timeout": 600000
}PASS
---
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
/home/infatoshi/.local/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py:468: FutureWarning: _check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. Use _check(i >= 0) instead.
torch._check_is_size(blocksize)
shape=0 variant=eager tflops=0.131 gbps=34.956 ms=0.766
shape=0 variant=compiled tflops=0.674 gbps=179.318 ms=0.149
shape=0 variant=sota tflops=2.321 gbps=617.195 ms=0.043
shape=0 variant=solution tflops=1.524 gbps=405.432 ms=0.066
shape=0 solution_peak_fraction=0.2252
shape=1 variant=eager tflops=4.166 gbps=35.937 ms=0.773
shape=1 variant=compiled tflops=20.649 gbps=178.123 ms=0.156
shape=1 variant=sota tflops=23.369 gbps=201.591 ms=0.138
shape=1 variant=solution tflops=33.973 gbps=293.065 ms=0.095
shape=1 solution_peak_fraction=0.1628
shape=2 variant=eager tflops=31.242 gbps=42.586 ms=0.825
shape=2 variant=compiled tflops=121.528 gbps=165.657 ms=0.212
shape=2 variant=sota tflops=134.543 gbps=183.398 ms=0.192
shape=2 variant=solution tflops=91.237 gbps=124.367 ms=0.282
shape=2 solution_peak_fraction=0.0691
shape=3 variant=eager tflops=0.177 gbps=47.171 ms=0.189
shape=3 variant=compiled tflops=0.496 gbps=131.965 ms=0.068
shape=3 variant=sota tflops=
... (432 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -82,8 +82,9 @@
w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
- acc += tl.dot(x_even, w_even_deq)
- acc += tl.dot(x_odd, w_odd_deq)
+ w_block = tl.cat(w_even_deq, w_odd_deq, axis=0)
+ x_block = tl.cat(x_even, x_odd, axis=1)
+ acc += tl.dot(x_block, w_block)
c = acc.to(tl.bfloat16)
tl.store(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<236 chars \u2014 see below>",
"newString": "<298 chars \u2014 see below>"
} w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
acc += tl.dot(x_even, w_even_deq)
acc += tl.dot(x_odd, w_odd_deq) w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
w_block = tl.cat(w_even_deq, w_odd_deq, axis=0)
x_block = tl.cat(x_even, x_odd, axis=1)
acc += tl.dot(x_block, w_block)Edit applied successfully.{
"description": "Test correctness of single-dot approach",
"command": "<1222 chars \u2014 see below>",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_test4'
import torch
import solution
import reference
for shape_info in [
{'M': 1, 'N': 12288, 'K': 4096},
{'M': 32, 'N': 12288, 'K': 4096},
{'M': 256, 'N': 12288, 'K': 4096},
{'M': 1, 'N': 4096, 'K': 4096},
{'M': 16, 'N': 14336, 'K': 4096},
]:
reference.M = shape_info['M']
reference.N = shape_info['N']
reference.K = shape_info['K']
ref_model = reference.Model(*reference.get_init_inputs()).cuda().eval()
sol_model = solution.Model(*reference.get_init_inputs()).cuda().eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
for seed in (42, 123, 456):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
inputs = [t.cuda() for t in reference.get_inputs()]
with torch.no_grad():
ref_out = ref_model(*inputs)
sol_out = sol_model(*inputs)
diff = (ref_out - sol_out).abs().max().item()
print(f'M={shape_info[\"M\"]} N={shape_info[\"N\"]} K={shape_info[\"K\"]} seed={seed} diff={diff:.6f} {"PASS" if diff < 0.10 else "FAIL"}')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
TypeError: cat() got an unexpected keyword argument 'axis'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 30, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py", line 114, in forward
w4a16_gemm_kernel[grid](
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
benchmark()
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
timings = {config: self._bench(*arg
... (2369 more chars){
"description": "Check tl.cat API signature",
"command": "<71 chars \u2014 see below>"
}python -c "
import triton.language as tl
help(tl.cat)
" 2>&1 | head -30Help on function cat in module triton.language.core:
cat(input, other, can_reorder=False, _semantic=None)
Concatenate the given blocks
:param input: The first input tensor.
:type input: Tensor
:param other: The second input tensor.
:type other: Tensor
:param reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs. Only use if the
order does not matter (e.g., result is only used in reduction ops).
Current implementation of `cat` supports only can_reorder=True.
{
"description": "Debug autotune selection and try shapes",
"command": "<226 chars \u2014 see below>",
"timeout": 10000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import os; os.environ['TRITON_CACHE_DIR'] = '/tmp/triton_test5'
import torch
import solution
solution
# Revert to two-dot approach (no cat)
" 2>&1(no output)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py
@@ -82,9 +82,8 @@
w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
- w_block = tl.cat(w_even_deq, w_odd_deq, axis=0)
- x_block = tl.cat(x_even, x_odd, axis=1)
- acc += tl.dot(x_block, w_block)
+ acc += tl.dot(x_even, w_even_deq)
+ acc += tl.dot(x_odd, w_odd_deq)
c = acc.to(tl.bfloat16)
tl.store(
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm/solution.py",
"oldString": "<298 chars \u2014 see below>",
"newString": "<236 chars \u2014 see below>"
} w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
w_block = tl.cat(w_even_deq, w_odd_deq, axis=0)
x_block = tl.cat(x_even, x_odd, axis=1)
acc += tl.dot(x_block, w_block) w_even_deq = tl.cast((w_even - z[None, :]) * s[None, :], tl.bfloat16)
w_odd_deq = tl.cast((w_odd - z[None, :]) * s[None, :], tl.bfloat16)
acc += tl.dot(x_even, w_even_deq)
acc += tl.dot(x_odd, w_odd_deq)Edit applied successfully.{
"description": "Test if tl.cat preserves order",
"command": "<887 chars \u2014 see below>",
"timeout": 30000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/07_w4a16_gemm && python -c "
import torch
import triton
import triton.language as tl
@triton.jit
def test_cat_kernel(a_ptr, b_ptr, out_ptr, N: tl.constexpr):
offs = tl.arange(0, N)
a = tl.load(a_ptr + offs)
b = tl.load(b_ptr + offs)
# Cat along axis 0
c = tl.cat(a, b, can_reorder=True)
# Store result
tl.store(out_ptr + tl.arange(0, 2 * N), c)
N = 8
a = torch.arange(N, dtype=torch.float32, device='cuda')
b = torch.arange(N, dtype=torch.float32, device='cuda') + 100
out = torch.zeros(2 * N, dtype=torch.float32, device='cuda')
test_cat_kernel[(1,)](a, b, out, N)
print('a:', a.tolist())
print('b:', b.tolist())
print('cat:', out.tolist())
print('Expected: [0,1,2,3,4,5,6,7,100,101,102,103,104,105,106,107]')
print('Match:', out.tolist() == list(range(N)) + list(range(100, 100+N)))
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
{
"description": "Test tl.cat order preservation",
"command": "<1607 chars \u2014 see below>",
"timeout": 30000
}cat > /tmp/test_cat.py << 'EOF'
import torch
import triton
import triton.language as tl
@triton.jit
def test_cat_kernel(a_ptr, b_ptr, out_ptr, N: tl.constexpr):
offs = tl.arange(0, N)
a = tl.load(a_ptr + offs)
b = tl.load(b_ptr + offs)
c = tl.cat(a, b, can_reorder=True)
tl.store(out_ptr + tl.arange(0, 2 * N), c)
@triton.jit
def test_cat_2d_kernel(a_ptr, b_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
offs_m = tl.arange(0, M)
offs_n = tl.arange(0, N)
a = tl.load(a_ptr + offs_m[:, None] * N + offs_n[None, :])
b = tl.load(b_ptr + offs_m[:, None] * N + offs_n[None, :])
c = tl.cat(a, b, can_reorder=True)
tl.store(out_ptr + tl.arange(0, 2 * M)[:, None] * N + offs_n[None, :], c)
N = 8
a = torch.arange(N, dtype=torch.float32, device='cuda')
b = torch.arange(N, dtype=torch.float32, device='cuda') + 100
out = torch.zeros(2 * N, dtype=torch.float32, device='cuda')
test_cat_kernel[(1,)](a, b, out, N)
print('1D cat:', out.tolist())
print('Match:', out.tolist() == list(range(N)) + list(range(100, 100+N)))
M, N = 4, 4
a = torch.arange(M * N, dtype=torch.float32, device='cuda').reshape(M, N)
b = torch.arange(M * N, dtype=torch.float32, device='cuda').reshape(M, N) + 100
out = torch.zeros(2 * M, N, dtype=torch.float32, device='cuda')
test_cat_2d_kernel[(1,)](a, b, out, M, N)
print('2D cat:')
print('a:', a.cpu().tolist())
print('b:', b.cpu().tolist())
print('cat:', out.cpu().tolist())
expected = torch.cat([a, b], dim=0).cpu().tolist()
print('Expected:', expected)
print('Match:', out.cpu().tolist() == expected)
EOF
cd /tmp && python test_cat.py 2>&1