import torch
import torch.nn as nn
import triton
import triton.language as tl
# Problem metadata. None of these three names is referenced elsewhere in this
# file; presumably the benchmark harness reads them (TODO confirm).
OP_TYPE = "gemm"
SUPPORTED_PRECISIONS = ["fp8_e4m3"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Largest finite float8_e4m3fn magnitude (torch.finfo(torch.float8_e4m3fn).max).
# Model.__init__ scales each weight row so its largest |value| maps onto it.
E4M3_MAX = 448.0
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr = 8,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as an (M, K) matrix and B as an (N, K) matrix through the
    given strides. Each program accumulates one BLOCK_M x BLOCK_N tile of C in
    fp32. It then multiplies column j by the value at scale_ptr + j * stride_scale,
    casts to bfloat16, and stores the tile with out-of-range rows/columns
    masked off.

    GROUP_M only changes which tile a program id maps to. GROUP_M consecutive
    rows of tiles are walked column-by-column together (the grouped ordering
    from the Triton matmul tutorial), so programs launched close together touch
    the same A and B tiles and L2 reuse improves. The numerical result does
    not depend on GROUP_M. It defaults to 8, so existing launches that do not
    pass it keep working.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # Grouped ("swizzled") mapping of the 1-D launch grid onto (pid_m, pid_n).
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group holds fewer than GROUP_M tile rows when num_pid_m is not
    # a multiple of GROUP_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m
    # Row/column indices are wrapped with % M / % N so every load stays in
    # bounds; the duplicated rows/columns are discarded by the masked store.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    # B is (N, K). Indexing it with k along rows and n along columns yields a
    # (BLOCK_K, BLOCK_N) tile of B^T, the layout tl.dot expects.
    b_ptrs = b_ptr + offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill past the end of K so a partial last block adds nothing.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # Per-output-column scale. offs_bn is already wrapped into [0, N), so the
    # mask never fires; it is kept only as a cheap guard.
    scale_ptrs = scale_ptr + offs_bn * stride_scale
    scale = tl.load(scale_ptrs, mask=offs_bn < N, other=1.0)
    accumulator = accumulator * scale[None, :]
    # Unwrapped indices for the store, so edge tiles only write valid cells.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.bfloat16), mask=c_mask)
class Model(nn.Module):
    """Linear layer ``y = (x @ weight^T) * weight_scale`` backed by a Triton FP8 GEMM.

    At construction a random bf16 (N, K) weight is quantized row-by-row to
    float8_e4m3fn. The ``weight`` buffer holds the FP8 values and
    ``weight_scale`` holds the per-row fp32 dequantization factors, shape (N,).
    """

    def __init__(self, M: int, N: int, K: int):
        """Build and quantize the weight.

        Args:
            M: Expected number of input rows. It is only stored; ``forward``
               reads the actual row count from ``x``.
            N: Output features (weight rows).
            K: Input features (weight columns).
        """
        super().__init__()
        self.M, self.N, self.K = M, N, K
        # Random bf16 weight, presumably mirroring the reference model's
        # initialization (the reference is not visible here).
        w = torch.empty(N, K, dtype=torch.bfloat16)
        nn.init.normal_(w, std=0.02)
        # One scale per row maps the row's max |w| onto E4M3_MAX. The clamp
        # avoids dividing by zero for an all-zero row.
        s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
        w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
        self.register_buffer("weight", w_fp8)
        self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x @ self.weight.T) * self.weight_scale`` as a bf16 (M, N) tensor.

        Args:
            x: 2-D (M, K) activation (``get_inputs`` supplies float8_e4m3fn).
               Its strides are forwarded to the kernel, so it need not be
               contiguous.

        Raises:
            ValueError: if ``x`` is not 2-D, or if ``x.shape[1]`` differs from
                the K the weight was built with.
        """
        M, K = x.shape
        N, weight_k = self.weight.shape
        # matmul_kernel uses a single K for both operands. A mismatch would
        # silently read across weight rows (or past the end of the weight
        # buffer) instead of failing, so reject it up front.
        if K != weight_k:
            raise ValueError(
                f"input has K={K} but weight was built for K={weight_k}"
            )
        c = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
        # Fixed 64x64x64 tiles; no per-shape tuning is applied here.
        BLOCK_M = 64
        BLOCK_N = 64
        BLOCK_K = 64
        # One program per BLOCK_M x BLOCK_N output tile.
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
        matmul_kernel[grid](
            x, self.weight, c, self.weight_scale,
            M, N, K,
            x.stride(0), x.stride(1),
            self.weight.stride(0), self.weight.stride(1),
            c.stride(0), c.stride(1),
            self.weight_scale.stride(0),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        )
        return c
M = 4096
N = 4096
K = 4096
def get_inputs():
x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
return [x]
def get_init_inputs():
return [M, N, K]
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
benchmark_event event=variant_start shape=0 variant=solution ts=2026-06-14T23:18:03.212258+00:00
benchmark_event event=variant_end shape=0 variant=solution ts=2026-06-14T23:18:03.432224+00:00 elapsed_s=0.220 ms=0.329888
shape=0 variant=solution tflops=416.623 gbps=203.429 ms=0.330
shape=0 solution_peak_fraction=0.4166
benchmark_event event=variant_start shape=1 variant=solution ts=2026-06-14T23:18:03.967603+00:00
benchmark_event event=variant_end shape=1 variant=solution ts=2026-06-14T23:18:04.001937+00:00 elapsed_s=0.034 ms=0.780672
shape=1 variant=solution tflops=177.385 gbps=86.288 ms=0.781
shape=1 solution_peak_fraction=0.1774
benchmark_event event=variant_start shape=2 variant=solution ts=2026-06-14T23:18:05.908728+00:00
benchmark_event event=variant_end shape=2 variant=solution ts=2026-06-14T23:18:05.914285+00:00 elapsed_s=0.006 ms=0.095840
shape=2 variant=solution tflops=44.814 gbps=708.423 ms=0.096
shape=2 solution_peak_fraction=0.0448
benchmark_event event=variant_start shape=3 variant=solution ts=2026-06-14T23:18:07.636303+00:00
benchmark_event event=variant_end shape=3 variant=solution ts=2026-06-14T23:18:07.681664+00:00 elapsed_s=0.045 ms=1.074864
shape=3 variant=solution tflops=447.532 gbps=179.500 ms=1.075
shape=3 solution_peak_fraction=0.4475
peak_fraction: 0.1962
RESULT: OK
warning: `VIRTUAL_ENV=/home/infatoshi/kernelbench.com/benchmarks/hard/.venv` does not match the project environment path `/home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo/.venv` and will be ignored; use `--active` to target the active environment instead
warning: Ignoring existing virtual environment linked to non-existent Python interpreter: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo/.venv/bin/python3 -> python
Using CPython 3.11.15
Removed virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo/.venv
Creating virtual environment at: /home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo/.venv
Building kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo
Built kernelbench-hard @ file:///home/infatoshi/kernelbench.com/benchmarks/hard/outputs/runs/20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm/repo
Installed 56 packages in 67ms
PASS
{
"run_id": "20260614_170713_gemini_gemini-3.5-flash_01_fp8_gemm",
"run_group": "",
"problem": "01_fp8_gemm",
"harness": "gemini",
"model": "gemini-3.5-flash",
"reasoning_effort": "",
"started_at": "2026-06-14T17:07:14-06:00",
"harness_finished_at": "2026-06-14T17:17:53-06:00",
"finished_at": "2026-06-14T17:18:07-06:00",
"start_epoch": 1781478434,
"harness_end_epoch": 1781479073,
"end_epoch": 1781479087,
"has_solution": true,
"correct": true,
"failure_reason": "pass",
"retryable_infra_failure": false,
"minimum_useful_output_tokens": 5000,
"peak_fraction": 0.1962,
"template_mutated": false,
"elapsed_seconds": 639,
"total_elapsed_seconds": 653,
"check_elapsed_seconds": 9,
"benchmark_elapsed_seconds": 5,
"check_timeout_seconds": 180,
"benchmark_timeout_seconds": 1800,
"check_exit_code": 0,
"benchmark_exit_code": 0,
"harness_exit_code": 0,
"session_complete": true,
"agent_cuda_disabled": false,
"agent_container": true,
"agent_container_image": "nvcr.io/nvidia/tensorrt-llm/release:latest",
"agent_container_network": "bridge",
"gpu_queue_mode": "agent_container_native_profiling_path_wrapper_gpu_lock",
"output_tokens_per_second": 15.522691705790297,
"usage": {"input_tokens": 412996, "output_tokens": 9919, "cache_read_tokens": 243721, "cache_creation_tokens": null, "reasoning_tokens": null, "total_cost_usd": null}
}
{
"strategic_intent": "Initiating optimization of FP8 GEMM on RTX PRO 6000 by studying the files in the workspace.",
"summary": "Starting the optimization of the FP8 GEMM. We will first read reference.py, check.py, and other files in the workspace to understand the constraints, input formats, model interface, and testing framework. then we will build our strategy around Triton or inline PTX / CUTLASS.",
"title": "Researching Workspace and Reference Implementation"
}{
"file_path": "reference.py",
"start_line": 1,
"end_line": 100
}{
"file_path": "check.py",
"start_line": 1,
"end_line": 100
}{
"start_line": 1,
"end_line": 100,
"file_path": "benchmark.py"
}## 📂 Topic: **Researching Workspace and Reference Implementation**
**Summary:**
Starting the optimization of the FP8 GEMM. We will first read reference.py, check.py, and other files in the workspace to understand the constraints, input formats, model interface, and testing framework. Then we will build our strategy around Triton or inline PTX / CUTLASS.
> [!STRATEGY]
> **Intent:** Initiating optimization of FP8 GEMM on RTX PRO 6000 by studying the files in the workspace.

Read lines 1-100 of 120 from check.py
Read lines 1-100 of 148 from benchmark.py
{
"file_path": "problem.yaml",
"start_line": 1,
"end_line": 100
}{
"file_path": "shapes.py",
"end_line": 100,
"start_line": 1
}{
"pattern": "*.py"
}Found 5 matching file(s){
"end_line": 100,
"file_path": "sota.py",
"start_line": 1
}{
"description": "Check PyTorch, Triton, and CUDA device information.",
"command": "uv run python -c \"import torch; import triton; print('PyTorch:', torch.__version__); print('Triton:', triton.__version__); print('Device:', torch.cuda.get_device_name(0))\""
}{
"content": "<2600 chars \u2014 see below>",
"file_path": "test_triton.py"
}import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B^T with an fp32 accumulator.

    A is addressed as (M, K) and B as (N, K) through their strides, so the B
    tile is gathered as a (BLOCK_K, BLOCK_N) block of B^T. The accumulator is
    stored without a dtype cast.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # The store uses this tile's global indices as-is...
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ...while the loads use copies wrapped into range, so edge tiles re-read
    # valid rows/columns whose results the masked store then drops.
    rows_in = rows % M
    cols_in = cols % N
    ks = tl.arange(0, BLOCK_K)

    a_tile = a_ptr + rows_in[:, None] * stride_am + ks[None, :] * stride_ak
    b_tile = b_ptr + cols_in[None, :] * stride_bn + ks[:, None] * stride_bk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - step * BLOCK_K
        # Zero-fill past the end of K so the final partial block adds nothing.
        a = tl.load(a_tile, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile += BLOCK_K * stride_ak
        b_tile += BLOCK_K * stride_bk

    out = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    inside = (rows[:, None] < M) & (cols[None, :] < N)
    tl.store(out, acc, mask=inside)
def triton_gemm(a, b, M, N, K):
    """Return the fp32 (M, N) product ``a @ b.T`` computed by ``matmul_kernel``.

    ``a`` is read as (M, K) and ``b`` as (N, K) through their strides.
    Tiles are fixed at 64x64x64.
    """
    out = torch.empty((M, N), device=a.device, dtype=torch.float32)
    a_row, a_col = a.stride(0), a.stride(1)
    b_row, b_col = b.stride(0), b.stride(1)
    out_row, out_col = out.stride(0), out.stride(1)

    def grid(meta):
        # One program per BLOCK_M x BLOCK_N output tile.
        return (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)

    matmul_kernel[grid](
        a, b, out,
        M, N, K,
        a_row, a_col,
        b_row, b_col,
        out_row, out_col,
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=64,
    )
    return out
# Quick test: compare the Triton kernel against an fp32 PyTorch reference on
# a small 128 x 128 x 128 problem and print the largest absolute difference.
device = torch.device("cuda:0")
M = N = K = 128

x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)

y = triton_gemm(x, w, M, N, K)
ref_y = torch.matmul(x.float(), w.float().t())
diff = torch.max(torch.abs(y - ref_y))
print("Diff:", diff.item())
{
"command": "uv run python test_triton.py",
"description": "Run test_triton.py to verify FP8 GEMM with Triton."
}{
"file_path": "test_triton.py",
"content": "<2665 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as (M, K) and B as (N, K) through their strides. Program
    ``pid`` owns tile row ``pid // cdiv(N, BLOCK_N)`` and tile column
    ``pid % cdiv(N, BLOCK_N)``. The tile is accumulated in fp32, scaled per
    output column, cast to bfloat16, and stored with a bounds mask.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # The store uses this tile's global indices as-is...
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ...while the loads use copies wrapped into range, so edge tiles re-read
    # valid rows/columns whose results the masked store then drops.
    rows_in = rows % M
    cols_in = cols % N
    ks = tl.arange(0, BLOCK_K)

    a_tile = a_ptr + rows_in[:, None] * stride_am + ks[None, :] * stride_ak
    # (BLOCK_K, BLOCK_N) block of B^T taken straight from the (N, K) layout.
    b_tile = b_ptr + cols_in[None, :] * stride_bn + ks[:, None] * stride_bk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - step * BLOCK_K
        # Zero-fill past the end of K so the final partial block adds nothing.
        a = tl.load(a_tile, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile += BLOCK_K * stride_ak
        b_tile += BLOCK_K * stride_bk

    # Per-output-column scale (cols_in is already < N, so the mask is only a guard).
    col_scale = tl.load(scale_ptr + cols_in * stride_scale, mask=cols_in < N, other=1.0)
    acc = acc * col_scale[None, :]

    out = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    inside = (rows[:, None] < M) & (cols[None, :] < N)
    tl.store(out, acc.to(tl.bfloat16), mask=inside)
def triton_gemm(a, b, scale, M, N, K):
    """Return the bf16 (M, N) tensor ``(a @ b.T) * scale`` computed by ``matmul_kernel``.

    ``a`` is read as (M, K) and ``b`` as (N, K) through their strides.
    ``scale`` supplies one factor per output column, read at ``scale.stride(0)``
    spacing, so a contiguous (N,) or (N, 1) tensor both work. Tiles are fixed
    at 64x64x64.
    """
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    matmul_kernel[grid](
        # The kernel takes four pointers (a, b, c, scale). Omitting the output
        # buffer c would shift every later argument by one position.
        a, b, c, scale,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        scale.stride(0),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=64,
    )
    return c
# Quick test with a K that is not a multiple of BLOCK_K (130) and a
# per-column weight scale; prints the largest absolute difference against a
# PyTorch reference rounded to bf16.
device = torch.device("cuda:0")
M, N, K = 128, 128, 130
x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
# Kept as an (N, 1) column: the kernel walks it via scale.stride(0), and
# scale.T broadcasts as a (1, N) row over the reference product.
scale = torch.randn(N, 1, device=device)
y = triton_gemm(x, w, scale, M, N, K)
ref_y = torch.matmul(x.float(), w.float().t()).mul(scale.T).to(torch.bfloat16)
diff = torch.max(torch.abs(y - ref_y))
print("Diff:", diff.item())
{
"description": "Verify unaligned K and scale correctness in Triton.",
"command": "uv run python test_triton.py"
}{
"file_path": "test_triton.py",
"content": "<2679 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as (M, K) and B as (N, K) through their strides. Program
    ``pid`` owns tile row ``pid // cdiv(N, BLOCK_N)`` and tile column
    ``pid % cdiv(N, BLOCK_N)``. The tile is accumulated in fp32, scaled per
    output column, cast to bfloat16, and stored with a bounds mask.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # The store uses this tile's global indices as-is...
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ...while the loads use copies wrapped into range, so edge tiles re-read
    # valid rows/columns whose results the masked store then drops.
    rows_in = rows % M
    cols_in = cols % N
    ks = tl.arange(0, BLOCK_K)

    a_tile = a_ptr + rows_in[:, None] * stride_am + ks[None, :] * stride_ak
    # (BLOCK_K, BLOCK_N) block of B^T taken straight from the (N, K) layout.
    b_tile = b_ptr + cols_in[None, :] * stride_bn + ks[:, None] * stride_bk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - step * BLOCK_K
        # Zero-fill past the end of K so the final partial block adds nothing.
        a = tl.load(a_tile, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile += BLOCK_K * stride_ak
        b_tile += BLOCK_K * stride_bk

    # Per-output-column scale (cols_in is already < N, so the mask is only a guard).
    col_scale = tl.load(scale_ptr + cols_in * stride_scale, mask=cols_in < N, other=1.0)
    acc = acc * col_scale[None, :]

    out = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    inside = (rows[:, None] < M) & (cols[None, :] < N)
    tl.store(out, acc.to(tl.bfloat16), mask=inside)
def triton_gemm(a, b, scale, M, N, K):
    """Return the bf16 (M, N) tensor ``(a @ b.T) * scale`` computed by ``matmul_kernel``.

    ``a`` is read as (M, K) and ``b`` as (N, K) through their strides.
    ``scale`` holds one factor per output column, read at ``scale.stride(0)``
    spacing. Tiles are fixed at 64x64x64.
    """
    out = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    a_row, a_col = a.stride(0), a.stride(1)
    b_row, b_col = b.stride(0), b.stride(1)
    out_row, out_col = out.stride(0), out.stride(1)

    def grid(meta):
        # One program per BLOCK_M x BLOCK_N output tile.
        return (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)

    matmul_kernel[grid](
        a, b, out, scale,
        M, N, K,
        a_row, a_col,
        b_row, b_col,
        out_row, out_col,
        scale.stride(0),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=64,
    )
    return out
# Quick test with a K that is not a multiple of BLOCK_K (130) and a
# per-column weight scale; prints the largest absolute difference against a
# PyTorch reference rounded to bf16.
device = torch.device("cuda:0")
M, N, K = 128, 128, 130
x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
w = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
# Generated as an (N, 1) column: squeezed to (N,) for the kernel and
# transposed to a (1, N) row for the reference.
scale = torch.randn(N, 1, device=device)
y = triton_gemm(x, w, scale.squeeze(1), M, N, K)
ref_y = torch.matmul(x.float(), w.float().t()).mul(scale.T).to(torch.bfloat16)
diff = torch.max(torch.abs(y - ref_y))
print("Diff:", diff.item())
{
"description": "Verify the corrected Triton FP8 GEMM test.",
"command": "uv run python test_triton.py"
}{
"content": "<3305 chars \u2014 see below>",
"file_path": "solution.py"
}import torch
import torch.nn as nn
import triton
import triton.language as tl
# Problem metadata. None of these three names is referenced elsewhere in this
# file; presumably the benchmark harness reads them (TODO confirm).
OP_TYPE = "gemm"
SUPPORTED_PRECISIONS = ["fp8_e4m3"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]
# Largest finite float8_e4m3fn magnitude (torch.finfo(torch.float8_e4m3fn).max).
# Model.__init__ scales each weight row so its largest |value| maps onto it.
E4M3_MAX = 448.0
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr = 8,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as an (M, K) matrix and B as an (N, K) matrix through the
    given strides. Each program accumulates one BLOCK_M x BLOCK_N tile of C in
    fp32. It then multiplies column j by the value at scale_ptr + j * stride_scale,
    casts to bfloat16, and stores the tile with out-of-range rows/columns
    masked off.

    GROUP_M only changes which tile a program id maps to. GROUP_M consecutive
    rows of tiles are walked column-by-column together (the grouped ordering
    from the Triton matmul tutorial), so programs launched close together touch
    the same A and B tiles and L2 reuse improves. The numerical result does
    not depend on GROUP_M. It defaults to 8, so existing launches that do not
    pass it keep working.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # Grouped ("swizzled") mapping of the 1-D launch grid onto (pid_m, pid_n).
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group holds fewer than GROUP_M tile rows when num_pid_m is not
    # a multiple of GROUP_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m
    # Row/column indices are wrapped with % M / % N so every load stays in
    # bounds; the duplicated rows/columns are discarded by the masked store.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    # B is (N, K). Indexing it with k along rows and n along columns yields a
    # (BLOCK_K, BLOCK_N) tile of B^T, the layout tl.dot expects.
    b_ptrs = b_ptr + offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill past the end of K so a partial last block adds nothing.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # Per-output-column scale. offs_bn is already wrapped into [0, N), so the
    # mask never fires; it is kept only as a cheap guard.
    scale_ptrs = scale_ptr + offs_bn * stride_scale
    scale = tl.load(scale_ptrs, mask=offs_bn < N, other=1.0)
    accumulator = accumulator * scale[None, :]
    # Unwrapped indices for the store, so edge tiles only write valid cells.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator.to(tl.bfloat16), mask=c_mask)
class Model(nn.Module):
    """Linear layer ``y = (x @ weight^T) * weight_scale`` backed by a Triton FP8 GEMM.

    At construction a random bf16 (N, K) weight is quantized row-by-row to
    float8_e4m3fn. The ``weight`` buffer holds the FP8 values and
    ``weight_scale`` holds the per-row fp32 dequantization factors, shape (N,).
    """

    def __init__(self, M: int, N: int, K: int):
        """Build and quantize the weight.

        Args:
            M: Expected number of input rows. It is only stored; ``forward``
               reads the actual row count from ``x``.
            N: Output features (weight rows).
            K: Input features (weight columns).
        """
        super().__init__()
        self.M, self.N, self.K = M, N, K
        # Random bf16 weight, presumably mirroring the reference model's
        # initialization (the reference is not visible here).
        w = torch.empty(N, K, dtype=torch.bfloat16)
        nn.init.normal_(w, std=0.02)
        # One scale per row maps the row's max |w| onto E4M3_MAX. The clamp
        # avoids dividing by zero for an all-zero row.
        s = (w.float().abs().amax(dim=1, keepdim=True) / E4M3_MAX).clamp(min=1e-12)
        w_fp8 = (w.float() / s).to(torch.float8_e4m3fn)
        self.register_buffer("weight", w_fp8)
        self.register_buffer("weight_scale", s.squeeze(1).to(torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x @ self.weight.T) * self.weight_scale`` as a bf16 (M, N) tensor.

        Args:
            x: 2-D (M, K) activation (``get_inputs`` supplies float8_e4m3fn).
               Its strides are forwarded to the kernel, so it need not be
               contiguous.

        Raises:
            ValueError: if ``x`` is not 2-D, or if ``x.shape[1]`` differs from
                the K the weight was built with.
        """
        M, K = x.shape
        N, weight_k = self.weight.shape
        # matmul_kernel uses a single K for both operands. A mismatch would
        # silently read across weight rows (or past the end of the weight
        # buffer) instead of failing, so reject it up front.
        if K != weight_k:
            raise ValueError(
                f"input has K={K} but weight was built for K={weight_k}"
            )
        c = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
        # Fixed 64x64x64 tiles; no per-shape tuning is applied here.
        BLOCK_M = 64
        BLOCK_N = 64
        BLOCK_K = 64
        # One program per BLOCK_M x BLOCK_N output tile.
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
        matmul_kernel[grid](
            x, self.weight, c, self.weight_scale,
            M, N, K,
            x.stride(0), x.stride(1),
            self.weight.stride(0), self.weight.stride(1),
            c.stride(0), c.stride(1),
            self.weight_scale.stride(0),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        )
        return c
M = 4096
N = 4096
K = 4096
def get_inputs():
x = (torch.rand(M, K) * 8 - 4).to(torch.float8_e4m3fn)
return [x]
def get_init_inputs():
return [M, N, K]
{
"description": "Run correctness checks.",
"command": "uv run python check.py"
}{
"command": "uv run python benchmark.py",
"description": "Run performance benchmarks on the baseline."
}{
"file_path": "test_tuning.py",
"content": "<4052 chars \u2014 see below>"
}import torch
import triton
import triton.language as tl
import time
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as (M, K) and B as (N, K) through their strides. Program
    ``pid`` owns tile row ``pid // cdiv(N, BLOCK_N)`` and tile column
    ``pid % cdiv(N, BLOCK_N)``. The tile is accumulated in fp32, scaled per
    output column, cast to bfloat16, and stored with a bounds mask.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # The store uses this tile's global indices as-is...
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ...while the loads use copies wrapped into range, so edge tiles re-read
    # valid rows/columns whose results the masked store then drops.
    rows_in = rows % M
    cols_in = cols % N
    ks = tl.arange(0, BLOCK_K)

    a_tile = a_ptr + rows_in[:, None] * stride_am + ks[None, :] * stride_ak
    # (BLOCK_K, BLOCK_N) block of B^T taken straight from the (N, K) layout.
    b_tile = b_ptr + cols_in[None, :] * stride_bn + ks[:, None] * stride_bk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - step * BLOCK_K
        # Zero-fill past the end of K so the final partial block adds nothing.
        a = tl.load(a_tile, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile += BLOCK_K * stride_ak
        b_tile += BLOCK_K * stride_bk

    # Per-output-column scale (cols_in is already < N, so the mask is only a guard).
    col_scale = tl.load(scale_ptr + cols_in * stride_scale, mask=cols_in < N, other=1.0)
    acc = acc * col_scale[None, :]

    out = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    inside = (rows[:, None] < M) & (cols[None, :] < N)
    tl.store(out, acc.to(tl.bfloat16), mask=inside)
def benchmark_config(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
    """Time ``matmul_kernel`` for one launch configuration on random FP8 inputs.

    Runs 5 untimed warm-up launches, then 50 timed launches bracketed by
    ``torch.cuda.synchronize()``. Returns ``(ms, tflops)``: the mean
    wall-clock milliseconds per launch and the throughput derived from
    2*M*N*K floating-point operations.
    """
    device = torch.device("cuda:0")
    a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
    scale = torch.randn(N, device=device)
    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)

    def launch():
        # The same launch is used for both warm-up and timing.
        matmul_kernel[grid](
            a, b, c, scale,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            scale.stride(0),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
            num_warps=num_warps, num_stages=num_stages,
        )

    # Warm-up; the first launch also JIT-compiles this configuration.
    for _ in range(5):
        launch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    iters = 50
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    ms = (time.perf_counter() - t0) * 1000 / iters

    tflops = (2 * M * N * K * 1e-12) / (ms * 1e-3)
    return ms, tflops
# Sweep a single square 4096 x 4096 x 4096 problem.
M, N, K = 4096, 4096, 4096
# Candidate launch configurations, each unpacked positionally into
# benchmark_config after M, N, K.
configs = [
    # (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)
    (128, 128, 64, 4, 3),
    (128, 128, 64, 8, 4),
    (128, 128, 64, 8, 5),
    (128, 256, 64, 8, 4),
    (256, 128, 64, 8, 4),
    (128, 128, 128, 8, 4),
    (128, 128, 128, 8, 5),
    (128, 128, 128, 4, 3),
    (128, 128, 128, 8, 3),
    (256, 256, 64, 8, 4),
    (256, 256, 128, 8, 4),
    (64, 128, 64, 4, 4),
    (128, 64, 64, 4, 4),
    (128, 128, 32, 4, 4),
]
best_tflops = 0
best_config = None
for cfg in configs:
    # Any exception (e.g. a configuration that does not compile or fit) is
    # reported and that configuration is skipped.
    try:
        ms, tflops = benchmark_config(M, N, K, *cfg)
    except Exception as e:
        print(f"Config {cfg} failed: {e}")
        continue
    print(f"Config {cfg}: {ms:.3f} ms, {tflops:.3f} TFLOPS")
    if tflops > best_tflops:
        best_tflops, best_config = tflops, cfg
print(f"
... (52 more chars){
"command": "uv run python test_tuning.py",
"description": "Run the tuning sweep on Shape 0."
}{
"content": "<5169 chars \u2014 see below>",
"file_path": "test_tuning_all.py"
}import torch
import triton
import triton.language as tl
import time
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, scale_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    stride_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute one bf16 tile of C = (A @ B^T) * scale[None, :].

    A is addressed as (M, K) and B as (N, K) through their strides. Program
    ``pid`` owns tile row ``pid // cdiv(N, BLOCK_N)`` and tile column
    ``pid % cdiv(N, BLOCK_N)``. The tile is accumulated in fp32, scaled per
    output column, cast to bfloat16, and stored with a bounds mask.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # The store uses this tile's global indices as-is...
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # ...while the loads use copies wrapped into range, so edge tiles re-read
    # valid rows/columns whose results the masked store then drops.
    rows_in = rows % M
    cols_in = cols % N
    ks = tl.arange(0, BLOCK_K)

    a_tile = a_ptr + rows_in[:, None] * stride_am + ks[None, :] * stride_ak
    # (BLOCK_K, BLOCK_N) block of B^T taken straight from the (N, K) layout.
    b_tile = b_ptr + cols_in[None, :] * stride_bn + ks[:, None] * stride_bk

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - step * BLOCK_K
        # Zero-fill past the end of K so the final partial block adds nothing.
        a = tl.load(a_tile, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile += BLOCK_K * stride_ak
        b_tile += BLOCK_K * stride_bk

    # Per-output-column scale (cols_in is already < N, so the mask is only a guard).
    col_scale = tl.load(scale_ptr + cols_in * stride_scale, mask=cols_in < N, other=1.0)
    acc = acc * col_scale[None, :]

    out = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    inside = (rows[:, None] < M) & (cols[None, :] < N)
    tl.store(out, acc.to(tl.bfloat16), mask=inside)
def benchmark_config(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
    """Time ``matmul_kernel`` for one launch configuration on random FP8 inputs.

    Runs 5 untimed warm-up launches, then 30 timed launches bracketed by
    ``torch.cuda.synchronize()``. Returns ``(ms, tflops, gbps)``:
    - ms: mean wall-clock milliseconds per launch;
    - tflops: throughput derived from 2*M*N*K FLOPs;
    - gbps: bandwidth derived from the M*K + K*N bytes of FP8 inputs plus the
      M*N*2 bytes of bf16 output.
    """
    device = torch.device("cuda:0")
    a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
    scale = torch.randn(N, device=device)
    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)

    def launch():
        # The same launch is used for both warm-up and timing.
        matmul_kernel[grid](
            a, b, c, scale,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            scale.stride(0),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
            num_warps=num_warps, num_stages=num_stages,
        )

    # Warm-up; the first launch also JIT-compiles this configuration.
    for _ in range(5):
        launch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    iters = 30
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    ms = (time.perf_counter() - t0) * 1000 / iters

    tflops = (2 * M * N * K * 1e-12) / (ms * 1e-3)
    bytes_moved = M * K + K * N + M * N * 2
    gbps = (bytes_moved * 1e-9) / (ms * 1e-3)
    return ms, tflops, gbps
# Problem shapes to tune for, as dicts of GEMM dimensions (C is M x N, reduction over K).
SHAPES = [
    {"M": 4096, "N": 4096, "K": 4096},  # square, aligned
    {"M": 4096, "N": 4096, "K": 4127},  # K not a multiple of 128 -> predicated (masked) K tail
    {"M": 32, "N": 8192, "K": 8192},  # skinny M (decode)
    {"M": 4096, "N": 14336, "K": 4096},  # Llama3 up-proj shape
]
# Standard configs for large M.
# Each tuple is (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages), matching
# benchmark_config's parameter order after M, N, K.
configs_large = [
    (128, 256, 64, 8, 4),
    (256, 128, 64, 8, 4),
    (128, 128, 64, 4, 3),
    (128, 128, 64, 8, 4),
    (128, 128, 128, 8, 3),
    (128, 128, 128, 8, 4),
    (64, 128, 64, 4, 4),
    (128, 64, 64, 4, 4),
    (256, 128, 128, 8, 3),
    (128, 256, 128, 8, 3),
]
# Configs for M small (skinny)
configs_skinny = [
(16, 128, 64, 4, 3),
(16, 256, 64, 4, 3),
(32, 128, 64, 4,
... (1169 more chars)