import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default number of consecutive K rows that share one (scale, zero) pair;
# used as the default ``group_size`` of ``Model``.
GROUP_SIZE: int = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``out = x @ dequant(w)``.

    ``x`` is read as an (M, K) matrix through ``stride_xm`` / ``stride_xk``.
    ``w`` packs two 4-bit codes per byte: packed row ``r`` holds logical K row
    ``2r`` in its low nibble and row ``2r + 1`` in its high nibble.
    ``scales`` / ``zeros`` have one row per ``GROUP_SIZE`` logical K rows; each
    code ``q`` is converted to bf16 and dequantized as ``(q - zero) * scale``.
    Products are accumulated in fp32 and the tile is stored as bf16.

    Meant to be launched on a 1-D grid of
    ``cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)`` programs. ``K`` is a
    ``tl.constexpr``, so it is baked into each compiled specialization.
    """
    # Each K tile reads a single (scale, zero) row, so a tile must never
    # straddle two quantization groups.
    tl.static_assert(
        GROUP_SIZE % BLOCK_SIZE_K == 0,
        "BLOCK_SIZE_K must divide GROUP_SIZE (one scale/zero row per K tile)",
    )

    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # Grouped program ordering for L2 reuse: consecutive pids walk down a band
    # of up to `group_m` row tiles before advancing to the next column tile.
    # The position inside the band must be derived from the pid's offset within
    # its own group; using the global pid mis-maps the short final band
    # (some tiles computed twice, others never written).
    group_m = 8
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One packed weight row covers two logical K rows, so a K tile spans
    # BLOCK_SIZE_K // 2 packed rows (and as many even / odd activation columns).
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
    a_mask = offs_am[:, None] < M
    b_mask = offs_bn[None, :] < N
    b_col_mask = offs_bn < N

    # Loop-invariant base pointers; each iteration only adds a K offset.
    a_base_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (2 * offs_k_half)[None, :] * stride_xk)
    a_base_odd_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (2 * offs_k_half + 1)[None, :] * stride_xk)
    b_base_ptrs = w_ptr + (offs_k_half[:, None] * stride_wk + offs_bn[None, :] * stride_wn)
    scale_base_ptrs = scales_ptr + offs_bn * stride_sn
    zero_base_ptrs = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        # Split activations into even / odd K columns to match the nibble layout.
        a_even = tl.load(
            a_base_even_ptrs + k_curr * stride_xk,
            mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K),
            other=0.0,
        )
        a_odd = tl.load(
            a_base_odd_ptrs + k_curr * stride_xk,
            mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K),
            other=0.0,
        )

        # Packed weights: low nibble -> even K row, high nibble -> odd K row.
        b_packed = tl.load(
            b_base_ptrs + (k_curr // 2) * stride_wk,
            mask=b_mask & ((k_curr // 2 + offs_k_half[:, None]) < (K // 2)),
            other=0,
        )
        b_even = b_packed & 0xF
        b_odd = (b_packed >> 4) & 0xF

        # Whole tile lies in one quantization group (see static_assert above).
        k_group = k_curr // GROUP_SIZE
        scale = tl.load(scale_base_ptrs + k_group * stride_sm, mask=b_col_mask, other=0.0)
        zero = tl.load(zero_base_ptrs + k_group * stride_zm, mask=b_col_mask, other=0.0)

        w_even = (b_even.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
        w_odd = (b_odd.to(tl.bfloat16) - zero[None, :]) * scale[None, :]

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    out_ptrs = out_ptr + (offs_am[:, None] * stride_om + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, accumulator.to(tl.bfloat16), mask=a_mask & b_mask)
class Model(nn.Module):
    """W4A16 GEMM: ``x @ dequant(w_q)`` with int4 group-quantized weights.

    Buffers (registered so ``load_state_dict`` can populate them):

    * ``w_q``: ``(K // 2, N)`` uint8. Each byte packs two 4-bit codes; the
      kernel treats the low nibble of row ``r`` as logical K row ``2r`` and
      the high nibble as row ``2r + 1``.
    * ``scales`` / ``zeros``: ``(K // group_size, N)`` bf16, one scale and one
      zero point per group of ``group_size`` consecutive K rows and per column.

    The kernel dequantizes each code ``q`` as ``(q - zero) * scale``,
    accumulates in fp32 and returns bf16.
    """

    # Hand-tuned launch configurations keyed on (M, N); any other shape uses
    # ``_DEFAULT_CONFIG``.
    _TUNED_CONFIGS = {
        (1, 12288): {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'num_warps': 4, 'num_stages': 5},
        (32, 12288): {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_warps': 8, 'num_stages': 2},
        (256, 12288): {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'num_warps': 4, 'num_stages': 3},
        (1, 4096): {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'num_warps': 8, 'num_stages': 3},
        (16, 14336): {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_warps': 8, 'num_stages': 2},
    }
    _DEFAULT_CONFIG = {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_warps': 4, 'num_stages': 2}

    def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
        """Allocate (uninitialized) quantized-weight buffers for a K x N weight.

        Args:
            M: nominal number of activation rows (stored only; ``forward``
                uses the actual input shape).
            N: output features.
            K: input features; must be even and divisible by ``group_size``.
            group_size: consecutive K rows sharing one scale / zero point.
        """
        super().__init__()
        assert K % group_size == 0, "K must be divisible by group_size"
        assert K % 2 == 0, "K must be even (int4 packing)"
        self.M, self.N, self.K = M, N, K
        self.group_size = group_size
        n_groups = K // group_size
        # Register buffers so load_state_dict works
        self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
        self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
        self.register_buffer("zeros", torch.empty((n_groups, N), dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor, config_override=None) -> torch.Tensor:
        """Return ``x @ dequant(w_q)`` as a bf16 tensor.

        Args:
            x: activations whose last dimension is K. Any leading dimensions
                are flattened into the GEMM's M dimension and restored on the
                output. The kernel feeds it to ``tl.dot`` against bf16
                weights, so presumably it must be bf16 (not checked here).
            config_override: optional mapping with any of ``BLOCK_SIZE_M``,
                ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``num_warps`` and
                ``num_stages``; given keys replace the tuned/default values.

        Returns:
            bf16 tensor of shape ``(*x.shape[:-1], N)``.

        Raises:
            ValueError: if x's K does not match the packed weights, or the
                effective BLOCK_SIZE_K does not divide ``group_size``.
        """
        *lead_dims, K = x.shape
        expected_k = 2 * self.w_q.shape[0]
        if K != expected_k:
            # A mismatch would make the kernel read past the weight buffers.
            raise ValueError(f"x has K={K} but the packed weights expect K={expected_k}")

        x2d = x.reshape(-1, K)
        M = x2d.shape[0]
        N = self.w_q.shape[1]
        out = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
        if M == 0:
            # Nothing to compute: skip kernel compilation and launch.
            return out.reshape(*lead_dims, N)

        cfg = self._TUNED_CONFIGS.get((M, N), self._DEFAULT_CONFIG)
        if config_override is not None:
            cfg = {**cfg, **config_override}

        block_m = cfg['BLOCK_SIZE_M']
        block_n = cfg['BLOCK_SIZE_N']
        # The kernel loads one (scale, zero) row per K tile, so a tile may not
        # span more than one quantization group.
        block_k = min(cfg['BLOCK_SIZE_K'], self.group_size)
        if self.group_size % block_k != 0:
            raise ValueError(
                f"BLOCK_SIZE_K={block_k} must divide group_size={self.group_size}"
            )

        grid = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n),)
        w4a16_gemm_kernel[grid](
            x2d, self.w_q, self.scales, self.zeros, out,
            M, N,
            x2d.stride(0), x2d.stride(1),
            self.w_q.stride(0), self.w_q.stride(1),
            self.scales.stride(0), self.scales.stride(1),
            self.zeros.stride(0), self.zeros.stride(1),
            out.stride(0), out.stride(1),
            BLOCK_SIZE_M=block_m,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            GROUP_SIZE=self.group_size,
            K=K,
            num_warps=cfg['num_warps'],
            num_stages=cfg['num_stages'],
        )
        return out.reshape(*lead_dims, N)
def get_inputs():
    """Return the benchmark inputs by delegating to ``reference.get_inputs``."""
    import reference

    return reference.get_inputs()
def get_init_inputs():
    """Return the ``Model`` constructor arguments from ``reference.get_init_inputs``."""
    import reference

    return reference.get_init_inputs()
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-13T19:23:12.806476+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-13T19:23:13.021228+00:00 elapsed_s=0.215 ms=0.067168
shape=0 variant=solution tflops=1.499 gbps=398.575 ms=0.067
shape=0 solution_peak_fraction=0.2214
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-13T19:23:13.358908+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-13T19:23:13.368701+00:00 elapsed_s=0.010 ms=0.084848
shape=1 variant=solution tflops=37.965 gbps=327.495 ms=0.085
shape=1 solution_peak_fraction=0.1819
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-13T19:23:13.714634+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-13T19:23:13.733748+00:00 elapsed_s=0.019 ms=0.239920
shape=2 variant=solution tflops=107.410 gbps=146.413 ms=0.240
shape=2 solution_peak_fraction=0.0813
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-13T19:23:13.847926+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-13T19:23:13.992349+00:00 elapsed_s=0.144 ms=0.033744
shape=3 variant=solution tflops=0.994 gbps=264.618 ms=0.034
shape=3 solution_peak_fraction=0.1470
benchmark_event event=variant_start shape=4 variant=solution ts=2026-06-13T19:23:14.394627+00:00
benchmark_event event=variant_end shape=4 variant=solution ts=2026-06-13T19:23:14.569948+00:00 elapsed_s=0.175 ms=0.057328
shape=4 variant=solution tflops=32.777 gbps=554.440 ms=0.057
shape=4 solution_peak_fraction=0.3080
peak_fraction: 0.1715
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm/repo
Installed 56 packages in 65ms
PASS
{
"run_id": "20260613_114239_gemini_gemini-3.5-flash_07_w4a16_gemm",
"run_group": "",
"problem": "07_w4a16_gemm",
"harness": "gemini",
"model": "gemini-3.5-flash",
"reasoning_effort": "",
"started_at": "2026-06-13T11:42:39-06:00",
"harness_finished_at": "2026-06-13T13:21:44-06:00",
"finished_at": "2026-06-13T13:23:14-06:00",
"start_epoch": 1781372559,
"harness_end_epoch": 1781378504,
"end_epoch": 1781378594,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.1715,
"template_mutated": false,
"elapsed_seconds": 5945,
"total_elapsed_seconds": 6035,
"check_elapsed_seconds": 64,
"benchmark_elapsed_seconds": 26,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 5.67434819175778,
"usage": {"input_tokens": 3587222, "output_tokens": 33734, "cache_read_tokens": 3135886, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"strategic_intent": "Researching the reference implementation, checking the GPU environment, and understanding the benchmark and verification scripts.",
"title": "Research & Baseline Establishment",
"summary": "We are beginning the W4A16 GEMM optimization task. First, we need to inspect the existing files: reference.py, check.py, benchmark.py, and PROMPT.txt. We also need to query the available GPU (architecture, SM count, CUDA version) to design our high-performance Triton or PyTorch/CUDA C++ kernel properly."
}## 📂 Topic: **Research & Baseline Establishment**
**Summary:**
We are beginning the W4A16 GEMM optimization task. First, we need to inspect the existing files: reference.py, check.py, benchmark.py, and PROMPT.txt. We also need to query the available GPU (architecture, SM count, CUDA version) to design our high-performance Triton or PyTorch/CUDA C++ kernel properly.
> [!STRATEGY]
> **Intent:** Researching the reference implementation, checking the GPU environment, and understanding the benchmark and verification scripts.{
"start_line": 1,
"file_path": "reference.py",
"end_line": 200
}{
"start_line": 1,
"file_path": "PROMPT.txt",
"end_line": 200
}{
"end_line": 200,
"start_line": 1,
"file_path": "check.py"
}{
"file_path": "benchmark.py",
"end_line": 200,
"start_line": 1
}{
"end_line": 200,
"file_path": "shapes.py",
"start_line": 1
}{
"end_line": 200,
"file_path": "problem.yaml",
"start_line": 1
}{
"command": "nvidia-smi && uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_name(0)); print(torch.cuda.get_device_properties(0))\"",
"description": "Get details about the GPU hardware, CUDA version, PyTorch version, and device capabilities."
}{
"description": "Check the installed Triton version.",
"command": "uv run python -c \"import triton; print(triton.__version__)\""
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"command": "uv pip list",
"description": "List installed packages in the workspace virtual environment."
}{
"file_path": "sota.py",
"end_line": 200,
"start_line": 1
}{
"content": "<6155 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 8}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
],
key=['M', 'N', 'K'],
)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sm, stride_sn,
stride_zm, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# 2D Grid with Program ID grouping/swizzling for L2 cache reuse
num_pid_in_group = 8
group_id = pid // (num_pid_in_group * num_pid_n)
first_pid_m = group_id * num_pid_in_group
group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
BLOCK_SIZE_K_half = BLOCK_SIZE_K // 2
offs_k_half = tl.arange(0, BLOCK_SIZE_K_half)
offs_b_k = tl.arange(0, BLOCK_SIZE_K_half)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_curr in range(0, K, BLOCK_SIZE_K):
# Even and Odd input pointers
a_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half)[None, :] * stride_xk)
a_odd_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half + 1)[None, :] * stride_xk)
a_even = tl.load(a_even_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K), other=0.0)
a_odd = tl.load(a_odd_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K), other=0.0)
# Load packed weights
b_ptrs = w_ptr + ((k_curr // 2 + offs_b_k[:, None]) * stride_wk + offs_bn[None, :] * stride_wn)
b_packed = tl.load(b_ptrs, mask=b_mask & ((k_curr // 2 + offs_b_k[:, None]) < (K // 2)), other=0)
# Unpack even/odd
b_even_uint8 = b_packed & 0xF
b_odd_uint8 = (b_packed >> 4) & 0xF
# Load scale and zero for current group
k_group = k_curr // GROUP_SIZE
scale_ptrs = scales_ptr + (k_group * stride_sm + offs_bn * stride_sn)
zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_zn * stride_zn)
scale = tl.load(scale_ptrs, mask=offs_bn < N, other=0.0)
zero = tl.load(zero_ptrs, mask=offs_bn < N, other=0.0)
# Dequantize to bfloat16
w_even = (b_even_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
w_odd = (b_odd_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
# Dot products
accumulator += tl.dot(a_even, w_even)
accumulator += tl.dot(a_odd, w_odd)
# Write to output
c = accumulator.to(tl.bfloat16)
offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_on = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_ptrs = out_ptr + (offs_om[:, None] *
... (2155 more chars){
"command": "uv run check.py",
"description": "Run correctness verification on the custom Triton kernel."
}{
"allow_multiple": false,
"file_path": "solution.py",
"instruction": "Use BLOCK_SIZE_K // 2 directly in tl.arange so that it is recognized as a constexpr.",
"new_string": "<96 chars \u2014 see below>",
"old_string": "<138 chars \u2014 see below>"
} offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2) BLOCK_SIZE_K_half = BLOCK_SIZE_K // 2
offs_k_half = tl.arange(0, BLOCK_SIZE_K_half)
offs_b_k = tl.arange(0, BLOCK_SIZE_K_half){
"command": "uv run check.py",
"description": "Run correctness verification again."
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"file_path": "test_simple.py",
"content": "<764 chars \u2014 see below>"
}import torch
import reference
import solution
device = torch.device("cuda:0")
M, N, K = 1, 12288, 4096
print("Initializing reference model...")
ref_model = reference.Model(M, N, K).to(device).eval()
print("Initializing solution model...")
sol_model = solution.Model(M, N, K).to(device).eval()
print("Loading state dict...")
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
print("Getting inputs...")
inputs = [t.to(device) for t in reference.get_inputs()]
print("Running reference forward...")
ref_out = ref_model(*inputs)
print("Reference forward done.")
print("Running solution forward...")
sol_out = sol_model(*inputs)
print("Solution forward done.")
diff = (ref_out - sol_out).abs().max()
print(f"Max absolute difference: {diff.item()}")
{
"description": "Run the simple debug script.",
"command": "uv run test_simple.py"
}{
"instruction": "Correct the variable name `offs_zn` to `offs_bn` in the zero pointer calculation.",
"file_path": "solution.py",
"new_string": " zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_bn * stride_zn)",
"allow_multiple": false,
"old_string": " zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_zn * stride_zn)"
}{
"description": "Run the simple debug script again.",
"command": "uv run test_simple.py"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"content": "<896 chars \u2014 see below>",
"file_path": "triton_vec_add.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
x = torch.rand(10000, device='cuda')
y = torch.rand(10000, device='cuda')
res = add(x, y)
print("SUCCESS!", res[:5])
{
"description": "Run the simple vector add test script.",
"command": "uv run triton_vec_add.py"
}{
"content": "<5303 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
M, N, K,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sm, stride_sn,
stride_zm, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# 2D Grid with Program ID grouping/swizzling for L2 cache reuse
num_pid_in_group = 8
group_id = pid // (num_pid_in_group * num_pid_n)
first_pid_m = group_id * num_pid_in_group
group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_curr in range(0, K, BLOCK_SIZE_K):
# Even and Odd input pointers
a_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half)[None, :] * stride_xk)
a_odd_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half + 1)[None, :] * stride_xk)
a_even = tl.load(a_even_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K), other=0.0)
a_odd = tl.load(a_odd_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K), other=0.0)
# Load packed weights
b_ptrs = w_ptr + ((k_curr // 2 + offs_b_k[:, None]) * stride_wk + offs_bn[None, :] * stride_wn)
b_packed = tl.load(b_ptrs, mask=b_mask & ((k_curr // 2 + offs_b_k[:, None]) < (K // 2)), other=0)
# Unpack even/odd
b_even_uint8 = b_packed & 0xF
b_odd_uint8 = (b_packed >> 4) & 0xF
# Load scale and zero for current group
k_group = k_curr // GROUP_SIZE
scale_ptrs = scales_ptr + (k_group * stride_sm + offs_bn * stride_sn)
zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_bn * stride_zn)
scale = tl.load(scale_ptrs, mask=offs_bn < N, other=0.0)
zero = tl.load(zero_ptrs, mask=offs_bn < N, other=0.0)
# Dequantize to bfloat16
w_even = (b_even_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
w_odd = (b_odd_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
# Dot products
accumulator += tl.dot(a_even, w_even)
accumulator += tl.dot(a_odd, w_odd)
# Write to output
c = accumulator.to(tl.bfloat16)
offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_on = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_ptrs = out_ptr + (offs_om[:, None] * stride_om + offs_on[None, :] * stride_on)
out_mask = (offs_om[:, None] < M) & (offs_on[None, :] < N)
tl.store(out_ptrs, c, mask=out_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
self.register_buffer("zeros", torch.empty((n_groups, N), dtype=torch.bfloat
... (1303 more chars){
"command": "uv run test_simple.py",
"description": "Run the simple debug script with fixed config."
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"content": "<5335 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
M, N,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sm, stride_sn,
stride_zm, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE: tl.constexpr,
K: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# 2D Grid with Program ID grouping/swizzling for L2 cache reuse
num_pid_in_group = 8
group_id = pid // (num_pid_in_group * num_pid_n)
first_pid_m = group_id * num_pid_in_group
group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_curr in range(0, K, BLOCK_SIZE_K):
# Even and Odd input pointers
a_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half)[None, :] * stride_xk)
a_odd_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half + 1)[None, :] * stride_xk)
a_even = tl.load(a_even_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K), other=0.0)
a_odd = tl.load(a_odd_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K), other=0.0)
# Load packed weights
b_ptrs = w_ptr + ((k_curr // 2 + offs_b_k[:, None]) * stride_wk + offs_bn[None, :] * stride_wn)
b_packed = tl.load(b_ptrs, mask=b_mask & ((k_curr // 2 + offs_b_k[:, None]) < (K // 2)), other=0)
# Unpack even/odd
b_even_uint8 = b_packed & 0xF
b_odd_uint8 = (b_packed >> 4) & 0xF
# Load scale and zero for current group
k_group = k_curr // GROUP_SIZE
scale_ptrs = scales_ptr + (k_group * stride_sm + offs_bn * stride_sn)
zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_bn * stride_zn)
scale = tl.load(scale_ptrs, mask=offs_bn < N, other=0.0)
zero = tl.load(zero_ptrs, mask=offs_bn < N, other=0.0)
# Dequantize to bfloat16
w_even = (b_even_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
w_odd = (b_odd_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
# Dot products
accumulator += tl.dot(a_even, w_even)
accumulator += tl.dot(a_odd, w_odd)
# Write to output
c = accumulator.to(tl.bfloat16)
offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_on = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_ptrs = out_ptr + (offs_om[:, None] * stride_om + offs_on[None, :] * stride_on)
out_mask = (offs_om[:, None] < M) & (offs_on[None, :] < N)
tl.store(out_ptrs, c, mask=out_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
self.register_buffer("zeros", torch.empty((n_groups, N),
... (1335 more chars){
"description": "Run the simple debug script with constexpr K.",
"command": "uv run test_simple.py"
}{
"command": "uv run check.py",
"description": "Run the full correctness check across all shapes and seeds."
}{
"description": "Run performance benchmarks across all shapes with baselines.",
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py"
}{
"strategic_intent": "Re-integrating autotuner with constexpr K and benchmarking performance.",
"title": "Autotuning & Performance Tuning",
"summary": "By identifying that K must be a constexpr in the Triton kernel signature, we solved the compilation hang. Now we will re-enable the Triton autotuner to select the best tile sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps) for each M, N, K shape. This will automatically improve performance, especially for larger M shapes (like M=256) where our hardcoded block sizes were suboptimal."
}## 📂 Topic: **Autotuning & Performance Tuning**
**Summary:**
By identifying that K must be a constexpr in the Triton kernel signature, we solved the compilation hang. Now we will re-enable the Triton autotuner to select the best tile sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps) for each M, N, K shape. This will automatically improve performance, especially for larger M shapes (like M=256) where our hardcoded block sizes were suboptimal.
> [!STRATEGY]
> **Intent:** Re-integrating autotuner with constexpr K and benchmarking performance.{
"file_path": "solution.py",
"content": "<5918 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 8}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 4}),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'num_stages': 4, 'num_warps': 8}),
],
key=['M', 'N'],
)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
M, N,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sm, stride_sn,
stride_zm, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE: tl.constexpr,
K: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# 2D Grid with Program ID grouping/swizzling for L2 cache reuse
num_pid_in_group = 8
group_id = pid // (num_pid_in_group * num_pid_n)
first_pid_m = group_id * num_pid_in_group
group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_curr in range(0, K, BLOCK_SIZE_K):
# Even and Odd input pointers
a_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half)[None, :] * stride_xk)
a_odd_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half + 1)[None, :] * stride_xk)
a_even = tl.load(a_even_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half)[None, :] < K), other=0.0)
a_odd = tl.load(a_odd_ptrs, mask=a_mask & ((k_curr + 2 * offs_k_half + 1)[None, :] < K), other=0.0)
# Load packed weights
b_ptrs = w_ptr + ((k_curr // 2 + offs_b_k[:, None]) * stride_wk + offs_bn[None, :] * stride_wn)
b_packed = tl.load(b_ptrs, mask=b_mask & ((k_curr // 2 + offs_b_k[:, None]) < (K // 2)), other=0)
# Unpack even/odd
b_even_uint8 = b_packed & 0xF
b_odd_uint8 = (b_packed >> 4) & 0xF
# Load scale and zero for current group
k_group = k_curr // GROUP_SIZE
scale_ptrs = scales_ptr + (k_group * stride_sm + offs_bn * stride_sn)
zero_ptrs = zeros_ptr + (k_group * stride_zm + offs_bn * stride_zn)
scale = tl.load(scale_ptrs, mask=offs_bn < N, other=0.0)
zero = tl.load(zero_ptrs, mask=offs_bn < N, other=0.0)
# Dequantize to bfloat16
w_even = (b_even_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
w_odd = (b_odd_uint8.to(tl.bfloat16) - zero[None, :]) * scale[None, :]
# Dot products
accumulator += tl.dot(a_even, w_even)
accumulator += tl.dot(a_odd, w_odd)
# Write to output
c = accumulator.to(tl.bfloat16)
offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_on = pid
... (1918 more chars){
"description": "Run the simple debug script with autotuning.",
"command": "uv run test_simple.py"
}Command was automatically cancelled because it exceeded the timeout of 5.0 minutes without output.{
"file_path": "solution.py",
"content": "<5896 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    col_mask = offs_bn < N        # 1D: per-column scale/zero loads
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
self.register_buffer("zeros", torch.empty((n_groups, N),
... (1896 more chars){
"description": "Verify correctness with the manual dispatcher.",
"command": "uv run check.py"
}{
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py",
"description": "Run performance benchmark with manual dispatcher."
}{
"content": "<9200 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemv_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    N,
    stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_on,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + matrix-vector product for one input row.

    Computes ``out[n] = sum_k x[k] * dequant(w)[k, n]`` and stores it as
    bfloat16. ``w`` is addressed as a (K // 2, N) byte array whose byte row
    ``i`` packs K index ``2*i`` (low nibble) and ``2*i + 1`` (high nibble);
    a value ``q`` at K index ``k`` dequantizes in bfloat16 to
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.

    Program id 0 selects a tile of BLOCK_SIZE_N output columns; the reduction
    over K uses elementwise multiplies and ``tl.sum`` (no ``tl.dot``), all in
    float32. BLOCK_SIZE_K need not divide GROUP_SIZE.
    """
    pid_n = tl.program_id(0)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_mask = offs_bn < N
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        x_even = tl.load(x_ptr + k_even * stride_xk, mask=k_even < K, other=0.0)
        x_odd = tl.load(x_ptr + k_odd * stride_xk, mask=k_odd < K, other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_mask = col_mask[None, :] & (k_packed[:, None] < K // 2)
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask, other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16; constexpr branch resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # Whole tile lies in one group: one scale/zero row suffices.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # Tile may straddle group boundaries: resolve each K index's group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = col_mask[None, :] & (k_even[:, None] < K)
            odd_mask = col_mask[None, :] & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        # Multiply in float32 so individual products are not rounded to
        # bfloat16 before the float32 reduction.
        prod = (x_even.to(tl.float32)[:, None] * w_even.to(tl.float32)
                + x_odd.to(tl.float32)[:, None] * w_odd.to(tl.float32))
        accumulator += tl.sum(prod, axis=0)

    tl.store(out_ptr + offs_bn * stride_on, accumulator.to(tl.bfloat16),
             mask=col_mask)
@triton.jit
def w4a16_gemm_kernel(
x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
M, N,
stride_xm, stride_xk,
stride_wk, stride_wn,
stride_sm, stride_sn,
stride_zm, stride_zn,
stride_om, stride_on,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE: tl.constexpr,
K: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# 2D Grid with Program ID grouping/swizzling for L2 cache reuse
num_pid_in_group = 8
group_id = pid // (num_pid_in_group * num_pid_n)
first_pid_m = group_id * num_pid_in_group
group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_curr in range(0, K, BLOCK_SIZE_K):
# Even and Odd input pointers
a_even_ptrs = x_ptr + (offs_am[:, None] * stride_xm + (k_curr + 2 * offs_k_half)[None, :] * stride_xk)
a_odd
... (5200 more chars){
"description": "Verify correctness of the specialized GEMV path.",
"command": "uv run check.py"
}{
"instruction": "Use BLOCK_SIZE_K // 2 directly in tl.arange for w4a16_gemv_kernel.",
"new_string": "<96 chars \u2014 see below>",
"file_path": "solution.py",
"old_string": "<138 chars \u2014 see below>",
"allow_multiple": false
} offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)
offs_b_k = tl.arange(0, BLOCK_SIZE_K // 2) BLOCK_SIZE_K_half = BLOCK_SIZE_K // 2
offs_k_half = tl.arange(0, BLOCK_SIZE_K_half)
offs_b_k = tl.arange(0, BLOCK_SIZE_K_half){
"description": "Verify correctness with the updated GEMV kernel.",
"command": "uv run check.py"
}{
"description": "Benchmark performance of GEMV and GEMM paths.",
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py"
}{
"content": "<6334 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    col_mask = offs_bn < N        # 1D: per-column scale/zero loads
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
self.register_buffer("zeros", torch.empty((n_groups, N),
... (2334 more chars){
"file_path": "optimize_configs.py",
"content": "<4481 chars \u2014 see below>"
}import torch
import reference
import solution
import shapes
import time
# GPU that every timed launch runs on.
device = torch.device("cuda", 0)
# Shapes to tune, taken verbatim from the shared shape list.
SHAPES = shapes.SHAPES
# Search space: tile sizes (powers of two), warp counts, pipeline stages.
BLOCK_SIZE_M_OPTS = [1 << p for p in range(4, 8)]   # 16, 32, 64, 128
BLOCK_SIZE_N_OPTS = [1 << p for p in range(6, 9)]   # 64, 128, 256
BLOCK_SIZE_K_OPTS = [1 << p for p in range(6, 9)]   # 64, 128, 256
NUM_WARPS_OPTS = [4, 8]
NUM_STAGES_OPTS = list(range(2, 6))                 # 2, 3, 4, 5
# Best configuration per shape (presumably filled by the search loop).
best_configs = dict()
print("Starting grid search for optimal configurations...")
for idx, shape in enumerate(SHAPES):
M, N, K = shape["M"], shape["N"], shape["K"]
print(f"\n==================================================")
print(f"Optimizing Shape {idx}: M={M}, N={N}, K={K}")
print(f"==================================================")
# Initialize reference and solution models for this shape
reference.M = M
reference.N = N
reference.K = K
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model = solution.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
inputs = [t.to(device) for t in reference.get_inputs()]
x = inputs[0]
# Warmup and baseline
with torch.no_grad():
ref_out = ref_model(x)
# Check standard default forward runs
_ = sol_model(x)
best_ms = float('inf')
best_cfg = None
# Grid search
count = 0
for BM in BLOCK_SIZE_M_OPTS:
if BM < 16:
continue
for BN in BLOCK_SIZE_N_OPTS:
for BK in BLOCK_SIZE_K_OPTS:
# Group size along K is 128, so BK should be a multiple of group size
# or group size should be a multiple of BK. To be aligned with GROUP_SIZE (128),
# BK must be either 128 or 256. If BK is 64, we need scales to be aligned.
# Since we use k_group = k_curr // 128, BK=64 also works as long as BK is divisor of 128.
# So BK of 64, 128, 256 are all perfectly valid.
for warps in NUM_WARPS_OPTS:
for stages in NUM_STAGES_OPTS:
cfg = {
'BLOCK_SIZE_M': BM,
'BLOCK_SIZE_N': BN,
'BLOCK_SIZE_K': BK,
'num_warps': warps,
'num_stages': stages
}
try:
# Warmup
with torch.no_grad():
sol_out = sol_model(x, config_override=cfg)
# Basic correctness check
if not torch.allclose(ref_out, sol_out, atol=0.10, rtol=0.10):
continue
# Time it
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(5):
_ = sol_model(x, config_override=cfg)
torch.cuda.synchronize()
start_evt.record()
for _ in range(50):
_ = sol_model(x, config_override=cfg)
end_evt.record()
torch.cuda.synchronize()
ms = start_evt.elapsed_time(end_evt) / 50.0
count += 1
if ms < best_ms:
best_ms = ms
best_cfg = cfg
print(f" New Best: {cfg} -> {ms:.4f} ms")
except Exception as e:
... (481 more chars){
"description": "Run the configuration optimizer script.",
"command": "uv run optimize_configs.py"
}{
"content": "<6822 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    col_mask = offs_bn < N        # 1D: per-column scale/zero loads
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((n_groups, N), dtype=torch.bfloat16))
self.register_buffer("zeros", torch.empty((n_groups, N),
... (2822 more chars){
"description": "Verify correctness with hand-tuned optimal configs.",
"command": "uv run check.py"
}{
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py",
"description": "Run the benchmark to measure our final performance results."
}{
"file_path": "solution.py",
"content": "<6926 chars \u2014 see below>"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    # The 1D scale/zero loads must use a 1D mask: a (1, BN) mask would make
    # the loaded vector 2D and the later `[None, :]` would produce 3D operands.
    col_mask = offs_bn < N
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.register_buffer("scales", torch.empty
... (2926 more chars){
"command": "uv run check.py",
"description": "Verify correctness of loop-hoisted Triton kernel."
}{
"content": "<6963 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    col_mask = offs_bn < N        # 1D: per-column scale/zero loads
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.
... (2963 more chars){
"description": "Verify correctness of loop-hoisted kernel with 1D masks.",
"command": "uv run check.py"
}{
"description": "Benchmark the performance of our loop-hoisted optimized kernel.",
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py"
}{
"description": "Re-optimize configurations for the loop-hoisted kernel.",
"command": "uv run optimize_configs.py"
}{
"content": "<6963 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Default int4 quantization group size: number of consecutive K indices that
# share one row of `scales` / `zeros` (kernels index that row as k // GROUP_SIZE).
GROUP_SIZE = 128
@triton.jit
def w4a16_gemm_kernel(
    x_ptr, w_ptr, scales_ptr, zeros_ptr, out_ptr,
    M, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_sm, stride_sn,
    stride_zm, stride_zn,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    K: tl.constexpr,
):
    """Fused int4-dequantize + GEMM: out = x @ dequant(w), stored as bfloat16.

    ``w`` is addressed as a (K // 2, N) byte array; byte row ``i`` packs K
    index ``2*i`` in its low nibble and ``2*i + 1`` in its high nibble. A
    4-bit value ``q`` at K index ``k`` is dequantized in bfloat16 as
    ``(q - zeros[k // GROUP_SIZE, n]) * scales[k // GROUP_SIZE, n]``.
    Products are accumulated in float32.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile. The
    1D program id is remapped so consecutive programs walk up to 8 row-tiles
    before moving to the next column-tile (L2 reuse of weight columns).
    Presumably launched on a 1D grid of
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs -- TODO confirm
    against the launcher.

    BLOCK_SIZE_K need not divide GROUP_SIZE: when a K tile lies inside one
    quantization group a single scale/zero row is loaded per iteration,
    otherwise the group is resolved per K index.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: 8 row-tiles per group.
    num_pid_in_group = 8
    group_id = pid // (num_pid_in_group * num_pid_n)
    first_pid_m = group_id * num_pid_in_group
    # The last group may contain fewer than 8 row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, num_pid_in_group)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % (group_size_m * num_pid_n)) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # One lane per packed byte row, i.e. per (even, odd) pair of K indices.
    offs_k_half = tl.arange(0, BLOCK_SIZE_K // 2)

    a_mask = offs_am[:, None] < M
    col_mask = offs_bn < N        # 1D: per-column scale/zero loads
    b_mask = col_mask[None, :]    # 2D: weight tiles and output tile

    # Loop-invariant pointer bases; each iteration only adds a K offset.
    a_even_base = (x_ptr + offs_am[:, None] * stride_xm
                   + (2 * offs_k_half)[None, :] * stride_xk)
    a_odd_base = a_even_base + stride_xk
    b_base = (w_ptr + offs_k_half[:, None] * stride_wk
              + offs_bn[None, :] * stride_wn)
    scale_base = scales_ptr + offs_bn * stride_sn
    zero_base = zeros_ptr + offs_bn * stride_zn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_curr in range(0, K, BLOCK_SIZE_K):
        k_even = k_curr + 2 * offs_k_half
        k_odd = k_even + 1
        a_even = tl.load(a_even_base + k_curr * stride_xk,
                         mask=a_mask & (k_even[None, :] < K), other=0.0)
        a_odd = tl.load(a_odd_base + k_curr * stride_xk,
                        mask=a_mask & (k_odd[None, :] < K), other=0.0)

        # Load packed weights and split the two nibbles.
        k_packed = k_curr // 2 + offs_k_half
        b_packed = tl.load(b_base + (k_curr // 2) * stride_wk,
                           mask=b_mask & (k_packed[:, None] < K // 2), other=0)
        q_even = (b_packed & 0xF).to(tl.bfloat16)
        q_odd = ((b_packed >> 4) & 0xF).to(tl.bfloat16)

        # Dequantize in bfloat16. Both operands are constexpr, so this branch
        # is resolved at compile time.
        if GROUP_SIZE % BLOCK_SIZE_K == 0:
            # k_curr is a multiple of BLOCK_SIZE_K, so the whole tile lies in
            # one group and a single scale/zero row serves every K index.
            k_group = k_curr // GROUP_SIZE
            scale = tl.load(scale_base + k_group * stride_sm,
                            mask=col_mask, other=0.0)[None, :]
            zero = tl.load(zero_base + k_group * stride_zm,
                           mask=col_mask, other=0.0)[None, :]
            w_even = (q_even - zero) * scale
            w_odd = (q_odd - zero) * scale
        else:
            # The tile can straddle group boundaries: look up the group of
            # every K index instead of reusing the tile's first group.
            g_even = (k_even // GROUP_SIZE)[:, None]
            g_odd = (k_odd // GROUP_SIZE)[:, None]
            even_mask = b_mask & (k_even[:, None] < K)
            odd_mask = b_mask & (k_odd[:, None] < K)
            scale_even = tl.load(scale_base[None, :] + g_even * stride_sm,
                                 mask=even_mask, other=0.0)
            zero_even = tl.load(zero_base[None, :] + g_even * stride_zm,
                                mask=even_mask, other=0.0)
            scale_odd = tl.load(scale_base[None, :] + g_odd * stride_sm,
                                mask=odd_mask, other=0.0)
            zero_odd = tl.load(zero_base[None, :] + g_odd * stride_zm,
                               mask=odd_mask, other=0.0)
            w_even = (q_even - zero_even) * scale_even
            w_odd = (q_odd - zero_odd) * scale_odd

        accumulator += tl.dot(a_even, w_even)
        accumulator += tl.dot(a_odd, w_odd)

    c = accumulator.to(tl.bfloat16)
    out_ptrs = (out_ptr + offs_am[:, None] * stride_om
                + offs_bn[None, :] * stride_on)
    tl.store(out_ptrs, c, mask=a_mask & b_mask)
class Model(nn.Module):
def __init__(self, M: int, N: int, K: int, group_size: int = GROUP_SIZE):
super().__init__()
assert K % group_size == 0, "K must be divisible by group_size"
assert K % 2 == 0, "K must be even (int4 packing)"
self.M, self.N, self.K = M, N, K
self.group_size = group_size
n_groups = K // group_size
# Register buffers so load_state_dict works
self.register_buffer("w_q", torch.empty((K // 2, N), dtype=torch.uint8))
self.
... (2963 more chars){
"description": "Verify correctness with absolute fastest configs.",
"command": "uv run check.py"
}{
"command": "KBH_BENCHMARK_BASELINES=1 uv run benchmark.py",
"description": "Benchmark final performance of the customized, loop-hoisted Triton kernel."
}{
"description": "Clean up temporary scratch files.",
"command": "rm -f test_simple.py triton_vec_add.py optimize_configs.py"
}{
"description": "Final verification check.",
"command": "uv run check.py"
}{
"strategic_intent": "Task successfully completed. Preparing the final response.",
"title": "Task Completed",
"summary": "We have successfully implemented and thoroughly verified a high-performance, fused weight-only int4 quantized GEMM (W4A16) in solution.py. The implementation uses a highly optimized Triton kernel with hoisted pointer arithmetic, 1D mask-driven dequantization, and custom program ID swizzling. We systematically tuned the block sizes, warp counts, and pipeline stages for each specific input shape (including decode and prefill workloads). The kernel is 100% correct, pass
... (269 more chars)## 📂 Topic: **Task Completed**
**Summary:**
We implemented and verified a fused weight-only int4 quantized GEMM (W4A16) in solution.py. The implementation is a Triton kernel with three optimizations:
- loop-invariant pointer arithmetic hoisted out of the K loop;
- 1D column masks for the per-group scale/zero loads;
- grouped program-ID ordering for L2 cache reuse.

Block sizes, warp counts, and pipeline stages were tuned separately for each input shape, covering both decode and prefill workloads. The kernel passes every shape, seed, and numeric-stress case in check.py. In the final benchmark (benchmark.py) it reached a geomean peak-bandwidth fraction of 14.78%, above the 10.0% target on the Blackwell RTX PRO 6000 GPU.
> [!STRATEGY]
> **Intent:** Task successfully completed. Preparing the final response.