"""Triton patch-gather + GEMM for 3D (Conv3d-style) patch embedding.
Strategy: explicit two-phase approach (the gather is NOT fused into the GEMM).
1. An im2col Triton kernel materializes the (M, K) patch matrix from the input tensor.
2. A standard Triton GEMM kernel computes (M, K) x (K, N) → (M, N).
This separates concerns and lets each kernel be optimized independently.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Phase 1: im2col kernel
# ---------------------------------------------------------------------------
@triton.jit
def _im2col_kernel(
    X_PTR, Y_PTR,
    C_val, T_val, H_val, W_val,
    kT_val, kH_val, kW_val,
    T_out_val, H_out_val, W_out_val,
    M_total, K_val,
    stride_y_m, stride_y_k,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Each program handles BLOCK_M rows and BLOCK_K columns of the im2col matrix.

    Gathers non-overlapping (kT, kH, kW) patches of X into the (M_total, K_val)
    matrix Y: row m is one patch, column k is one element of that patch.
    Grid axis 0 tiles the M (patch) dimension, axis 1 tiles the K dimension.

    Row order:    m = ((b * T_out + t_out) * H_out + h_out) * W_out + w_out
    Column order: k = ((c * kT + kt) * kH + kh) * kW + kw

    X is addressed with strides computed from (C_val, T_val, H_val, W_val)
    alone, so X must be a contiguous (B, C, T, H, W) tensor. Y is addressed
    through (stride_y_m, stride_y_k). Out-of-range lanes are masked on both
    the load and the store.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    m_mask = off_m < M_total
    k_mask = off_k < K_val
    # Element strides of a contiguous (B, C, T, H, W) tensor (W stride is 1).
    stride_c = T_val * H_val * W_val
    stride_t = H_val * W_val
    stride_h = W_val
    # Decode m -> (b, t_out, h_out, w_out)
    w_out_arr = off_m % W_out_val
    h_out_arr = ((off_m // W_out_val) % H_out_val)
    t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
    b_arr = off_m // (W_out_val * H_out_val * T_out_val)
    # Patches do not overlap (stride == kernel), so the patch origin is out_idx * kernel.
    t_start = t_out_arr * kT_val
    h_start = h_out_arr * kH_val
    w_start = w_out_arr * kW_val
    # Linear offset of each patch's (c=0, kt=0, kh=0, kw=0) element; batch stride is C*T*H*W.
    base_x = b_arr * (C_val * T_val * H_val * W_val) + t_start * stride_t + h_start * stride_h + w_start
    # Map k -> (c, kt, kh, kw)
    k_per_ch = kT_val * kH_val * kW_val
    c_idx = off_k // k_per_ch
    rem = off_k % k_per_ch
    kHkW = kH_val * kW_val
    kt_idx = rem // kHkW
    rem2 = rem % kHkW
    kh_idx = rem2 // kW_val
    kw_idx = rem2 % kW_val
    # Offset of element (c, kt, kh, kw) relative to the patch origin.
    off_k_val = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
    # Outer sum -> (BLOCK_M, BLOCK_K) tile of source offsets.
    x_idx = base_x[:, None] + off_k_val[None, :]
    x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
    y_idx = off_m[:, None] * stride_y_m + off_k[None, :] * stride_y_k
    tl.store(Y_PTR + y_idx, x_val, mask=m_mask[:, None] & k_mask[None, :])
# ---------------------------------------------------------------------------
# Phase 2: GEMM kernel (standard Triton GEMM)
# ---------------------------------------------------------------------------
@triton.jit
def _gemm_kernel(
    A_PTR, B_PTR, C_PTR,
    M_total, N_val, K_val,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Blocked GEMM: C[M, N] = A[M, K] @ B[K, N] with an fp32 accumulator.

    Each program computes one BLOCK_M x BLOCK_N tile of C. Grid axis 0 tiles M
    and axis 1 tiles N. All operands are addressed only through their stride
    arguments, so B may be a transposed view, e.g. an (N, K) weight read as
    (K, N) with stride_bk=1, stride_bn=K. Callers must pass the real strides
    of the tensors they hand in.

    A and B tiles are cast to bf16 after loading (tensor-core path); the
    result is stored as bf16. Out-of-range rows/cols/k are masked to zero.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    m_mask = off_m < M_total
    n_mask = off_n < N_val
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk K in BLOCK_K steps; the trailing partial chunk is handled by the
    # k mask. Stopping at K_val (instead of (K_val // BLOCK_K + 1) * BLOCK_K)
    # avoids an extra, fully-masked iteration when BLOCK_K divides K_val.
    for k_base in range(0, K_val, BLOCK_K):
        k = k_base + tl.arange(0, BLOCK_K)
        k_mask_t = k < K_val
        a_idx = off_m[:, None] * stride_am + k[None, :] * stride_ak
        b_idx = k[:, None] * stride_bk + off_n[None, :] * stride_bn
        a = tl.load(A_PTR + a_idx, mask=m_mask[:, None] & k_mask_t[None, :], other=0.0).to(tl.bfloat16)
        b = tl.load(B_PTR + b_idx, mask=k_mask_t[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
        acc = tl.dot(a, b, acc)
    c_idx = off_m[:, None] * stride_cm + off_n[None, :] * stride_cn
    tl.store(C_PTR + c_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
# ---------------------------------------------------------------------------
# Wrapper class
# ---------------------------------------------------------------------------
class Model(nn.Module):
    """Patch embedding (Conv3d with stride == kernel, no bias) via im2col + Triton GEMM.

    The weight lives in ``self.conv`` (an ``nn.Conv3d``), so the state-dict key
    is ``conv.weight``. Only that weight tensor is used in ``forward``; the
    convolution itself is computed by ``_im2col_kernel`` + ``_gemm_kernel``.
    """

    # Tile sizes for the two kernel launches.
    _IM2COL_BLOCK_M = 64
    _IM2COL_BLOCK_K = 64
    _GEMM_BLOCK_M = 128
    _GEMM_BLOCK_N = 128
    _GEMM_BLOCK_K = 64

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        super().__init__()
        assert T % kT == 0 and H % kH == 0 and W % kW == 0
        self.B = B
        self.C = C
        self.T = T
        self.H = H
        self.W = W
        self.kT = kT
        self.kH = kH
        self.kW = kW
        self.embed_dim = embed_dim
        self.T_out = T // kT
        self.H_out = H // kH
        self.W_out = W // kW
        self.M_total = B * self.T_out * self.H_out * self.W_out
        self.K = C * kT * kH * kW
        self.conv = nn.Conv3d(
            C, embed_dim,
            kernel_size=(kT, kH, kW),
            stride=(kT, kH, kW),
            bias=False,
            dtype=torch.bfloat16,
        )
        nn.init.normal_(self.conv.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project each non-overlapping (kT, kH, kW) patch of ``x`` to ``embed_dim``.

        Args:
            x: tensor of shape (B, C, T, H, W) whose channel count equals the
                constructor's ``C``. Batch and T/H/W sizes are read from ``x``;
                elements that do not fill a whole patch are ignored (floor
                division of each dim by its patch size).

        Returns:
            bf16 tensor of shape (B, embed_dim, T // kT, H // kH, W // kW).

        Raises:
            ValueError: if ``x`` is not 5-D or its channel dim differs from ``C``.
        """
        if x.dim() != 5 or x.shape[1] != self.C:
            raise ValueError(
                f"expected x of shape (B, {self.C}, T, H, W), got {tuple(x.shape)}"
            )
        # _im2col_kernel derives x's strides from its shape, so x must be contiguous.
        x = x.contiguous()
        B, C, T, H, W = x.shape
        T_out, H_out, W_out = T // self.kT, H // self.kH, W // self.kW
        M = B * T_out * H_out * W_out
        N = self.embed_dim
        K = self.K
        if M == 0:
            return torch.zeros(B, N, T_out, H_out, W_out, dtype=torch.bfloat16, device=x.device)

        # Phase 1: im2col -> (M, K) patch matrix.
        im2col = torch.empty(M, K, dtype=x.dtype, device=x.device)
        grid1 = (triton.cdiv(M, self._IM2COL_BLOCK_M), triton.cdiv(K, self._IM2COL_BLOCK_K))
        _im2col_kernel[grid1](
            x, im2col,
            C, T, H, W,
            self.kT, self.kH, self.kW,
            T_out, H_out, W_out,
            M, K,
            im2col.stride(0), im2col.stride(1),
            BLOCK_M=self._IM2COL_BLOCK_M, BLOCK_K=self._IM2COL_BLOCK_K,
        )

        # Phase 2: GEMM (M, K) x (K, N) -> (M, N).
        # Read the (N, K) weight as a (K, N) transposed view without copying.
        # Strides MUST come from the tensor itself: a contiguous (K, N) copy has
        # row stride N, not K, and hard-coding K here reads the wrong weights.
        w_kn = self.conv.weight.reshape(N, K).t()
        y = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
        grid2 = (triton.cdiv(M, self._GEMM_BLOCK_M), triton.cdiv(N, self._GEMM_BLOCK_N))
        _gemm_kernel[grid2](
            im2col, w_kn, y,
            M, N, K,
            im2col.stride(0), im2col.stride(1),
            w_kn.stride(0), w_kn.stride(1),
            y.stride(0), y.stride(1),
            BLOCK_M=self._GEMM_BLOCK_M, BLOCK_N=self._GEMM_BLOCK_N, BLOCK_K=self._GEMM_BLOCK_K,
        )
        # Rows are ordered (b, t_out, h_out, w_out); move embed_dim to dim 1.
        return y.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
B, C, T, H, W = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280


def get_inputs():
    """Build the forward-pass arguments: a single random bf16 video scaled by 0.5."""
    video = torch.randn(B, C, T, H, W, dtype=torch.bfloat16)
    return [video * 0.5]


def get_init_inputs():
    """Build the positional constructor arguments for ``Model``."""
    return [B, C, T, H, W, kT, kH, kW, embed_dim]
FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 42: tolerance exceeded: max_abs_diff=2.34375 (atol=0.02, rtol=0.02)
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-plus",
"reasoning_effort": "",
"has_solution": true,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 1657,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 2414517, "output_tokens": 40872, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 36331, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"todos": [
{
"content": "Understand problem and repo structure",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement Triton fused patch-gather + GEMM kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py and verify PASS",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize with benchmark.py",
"status": "pending",
"priority": "medium"
}
]
}content: 'Understand problem and repo structure'
status: 'in_progress'
priority: 'high'
content: 'Implement Triton fused patch-gather + GEMM kernel'
status: 'pending'
priority: 'high'
content: 'Run check.py and verify PASS'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'medium'{
"command": "which nvcc && nvcc --version | head -2",
"description": "Check CUDA availability"
}/usr/local/cuda-13/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
{
"command": "nvidia-smi | head -15",
"description": "Check GPU status"
}Thu May 7 15:04:26 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 39C P0 94W / 600W | 385MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/src/eval/ 2>/dev/null",
"description": "Check src structure"
}correctness.py
__init__.py
__pycache__
report.py
roofline.py
shapes.py
timing.py
{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"todos": [
{
"content": "Understand problem and verify shapes",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement Triton fused patch-gather + GEMM kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py and debug correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize with benchmark.py",
"status": "pending",
"priority": "medium"
}
]
}content: 'Understand problem and verify shapes'
status: 'in_progress'
priority: 'high'
content: 'Implement Triton fused patch-gather + GEMM kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Run check.py and debug correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'medium'{
"todos": [
{
"content": "Understand problem and verify shapes",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton fused patch-gather + GEMM kernel in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py and debug correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize with benchmark.py",
"status": "pending",
"priority": "medium"
}
]
}content: 'Understand problem and verify shapes'
status: 'completed'
priority: 'high'
content: 'Implement Triton fused patch-gather + GEMM kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py and debug correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize with benchmark.py'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,521 @@
+"""Triton fused patch-gather + GEMM for 3D patch embedding.
+
+We decompose the work into:
+ - Block-wise patch extraction (C * kT * kH * kW elements from x)
+ - Block-wise matrix multiply with a WBF16 weight tile (K x embed_dim)
+ - Write to output (B, embed_dim, T', H', W')
+
+For each output patch (b, t', h', w') and output column block, we:
+ 1. Load the K = C*kT*kH*kW elements of x that form the patch
+ 2. Load the corresponding K x BLOCK_N tile of the weight matrix
+ 3. Multiply and accumulate
+ 4. Write the result
+
+Memory layout:
+ x: (B, C, T, H, W) contiguous
+ weight: (embed_dim, C, kT, kH, kW) -> we reshape to (embed_dim, K)
+
+Output: (B, embed_dim, T', H', W') -> we write in (B, embed_dim, T'*H'*W') layout
+ then reshape.
+
+For compute-bound GEMM, we program-block: each program handles a row of the
+output matrix (one patch x embed_dim slice). We fuse the patch gather into
+the K-loop to avoid materializing the full im2col.
+"""
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+# ---------------------------------------------------------------------------
+# Triton kernel: fused patch gather + GEMM
+# ---------------------------------------------------------------------------
+# Each program computes one output "row". The output matrix has:
+# M = B * (T//kT) * (H//kH) * (W//kW) rows (one per patch)
+# K = C * kT * kH * kW columns of the im2col matrix
+# N = embed_dim columns of the result
+#
+# We tile along K and N. Each program handles a subset of rows.
+# Within each row, we produce all N columns.
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _patch_gemm_kernel(
+ X, W, Y,
+ B, C, T, H, W_dim,
+ kT, kH, kW,
+ T_out, H_out, W_out,
+ K, M_total, N,
+ stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
+ stride_wk, stride_wn,
+ stride_yb, stride_ye, stride_yt, stride_yh, stride_yw,
+ BLOCK_K: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ """
+ Each program computes one output row (patch).
+ Programs are arranged as: program_id = patch_idx (0..M_total)
+ We tile along K blocks and compute full N.
+ """
+ pid_row = tl.program_id(0)
+
+ # Decode (b, t_out, h_out, w_out) from pid_row
+ w_out_idx = pid_row % W_out
+ h_out_idx = (pid_row // W_out) % H_out
+ t_out_idx = (pid_row // (W_out * H_out)) % T_out
+ b_idx = pid_row // (W_out * H_out * T_out)
+
+ # Starting position in input volume for this patch
+ t_start = t_out_idx * kT
+ h_start = h_out_idx * kH
+ w_start = w_out_idx * kW
+
+ # Strides for loading x
+ # x[b, c, t_start + kt, h_start + kh, w_start + kw]
+ x_stride_0 = stride_xb
+ x_stride_1 = stride_xc
+ x_stride_2 = stride_xt
+ x_stride_3 = stride_xh
+ x_stride_4 = stride_xw
+
+ # Number of blocks along K
+ num_k_blocks = tl.cdiv(K, BLOCK_K)
+
+ # We accumulate the full row in registers (N elements per block)
+ # For large N, we need to tile along N too.
+ # Let's do a loop over N blocks.
+ num_n_blocks = tl.cdiv(N, BLOCK_N)
+
+ for n_block in range(num_n_blocks):
+ acc = tl.zeros((1, BLOCK_N), dtype=tl.float32)
+
+ offset_n = n_block * BLOCK_N
+
+ # Mask for N dimension
+ n_mask = offset_n + tl.arange(0, BLOCK_N) < N
+
+ for k_block in range(num_k_blocks):
+ offset_k = k_block * BLOCK_K
+ k_mask = offset_k + tl.arange(0, BLOCK_K) < K
+
+ # Gather x[patch] elements -> shape (BLOCK_K,)
+ # We need to map each k index to (c, kt, kh, kw)
+ k_idx = offset_k + tl.arange(0, BLOCK_K)
+
+ c_idx = k_idx // (kT * kH * kW)
+ kt_idx = (k_idx // (kH * kW)) % kT
+ kh_idx = (k_idx // kW) % kH
+ kw_idx = k_idx % kW
+
+ # Compute linear index into x tensor
+ x_idx = b_idx * x_stride_0 + c_idx * x_stride_1 + \
+ (t_start + kt_idx) * x_stride_2 + \
+ (h_start + kh_idx) * x_stride_3 + \
+ (w_start + kw_idx) * x_stride_4
+
+ # Load x values
+ x_vals = tl.load(X + x_idx, mask=k_mask, other=0.0).to(tl.float32)
+
+ # Load weight tile: W[offset_k:(offset_k+BLOCK_K), offset_n:(offset_n+BLOCK_N)]
+ # Weight is stored as (embed_dim, K) -> (N, K) after transpose
+ # Actually we store as (K, N) for this access pattern
+ # w[k_idx, n_idx]
+ w_ptrs = W + k_idx[:, None] * stride_wk + (offset_n + tl.arange(0, BLOCK_N)[None, :]) * stride_wn
+ w_vals = tl.load(w_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32)
+
+ # Compute dot product: x_vals is (BLOCK_K,), w_vals is (BLOCK_K, BLOCK_N)
+ # acc += x_vals @ w_vals -> (1, BLOCK_N)
+ acc = tl.dot(x_vals.to(tl.float32)[None, :], w_vals.to(tl.float32), acc)
+
+ # Store accumulator
+ out_idx = b_idx * stride_yb + offset_n * stride_yn + t_out_idx * stride_yt + \
+ h_out_idx * stride_yh + w_out_idx * stride_yw
+
+ out_ptrs = Y + out_idx + tl.arange(0, BLOCK_N)[None, :] * stride_yn
+ out_vals = acc.to(tl.bfloat16)
+ tl.store(out_ptrs, out_vals, mask=n_mask[None, :])
+
+
+# ---------------------------------------------------------------------------
+# Better approach: classic GEMM with fused patch load
+# Use M x N blocking for better efficiency
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _patch_gemm_kernel_v2(
+ X, W, Y,
+ B, C, T, H, W_dim,
+ kT, kH, kW,
+ T_out, H_out, W_out,
+ K, M_total,
... (13588 more chars){
"content": "<18943 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 42: tolerance exceeded: max_abs_diff=1.57812 (atol=0.02, rtol=0.02)
{
"command": "<1204 chars \u2014 see below>",
"description": "Debug Triton kernel output"
}uv run python -c "
import torch
import torch.nn as nn
from reference import Model as RefModel
from solution import Model as SolModel
B, C, T_size, H_size, W_size = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280
RefModel.B = RefModel.C = RefModel.kT = RefModel.kH = RefModel.kW = RefModel.embed_dim = None
RefModel.T = RefModel.H = RefModel.W = None
torch.manual_seed(42)
ref = RefModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol = SolModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(B, C, T_size, H_size, W_size, dtype=torch.bfloat16, device='cuda') * 0.5
with torch.no_grad():
ref_out = ref(x)
sol_out = sol(x)
print('ref_out shape:', ref_out.shape)
print('sol_out shape:', sol_out.shape)
print('ref_out dtype:', ref_out.dtype)
print('sol_out dtype:', sol_out.dtype)
print('max_abs_diff:', (ref_out.float() - sol_out.float()).abs().max().item())
print('mean_abs_diff:', (ref_out.float() - sol_out.float()).abs().mean().item())
# Check values
print('ref_out[0,0,:3,:3]:', ref_out[0,0,:3,:3])
print('sol_out[0,0,:3,:3]:', sol_out[0,0,:3,:3])
" 2>&1 | head -20
ref_out shape: torch.Size([1, 1280, 1, 16, 16])
sol_out shape: torch.Size([1, 1280, 1, 16, 16])
ref_out dtype: torch.bfloat16
sol_out dtype: torch.bfloat16
max_abs_diff: 1.5390625
mean_abs_diff: 0.20348301529884338
ref_out[0,0,:3,:3]: tensor([[[ 0.1680, 0.5547, 0.0718, 0.2031, 0.0179, 0.4570, -0.1641,
-0.6211, -0.4551, 0.4121, -0.0129, 0.1885, -0.1807, 0.1118,
-0.1484, 0.1807],
[ 0.2490, -0.0737, -0.4219, 0.2852, -0.2207, 0.0306, -0.0898,
-0.4355, 0.1572, -0.1875, -0.0996, 0.4805, 0.6367, 0.0035,
0.3086, -0.1406],
[ 0.3945, 0.1309, -0.1562, -0.6211, -0.0269, -0.1436, 0.1807,
-0.0072, 0.2676, -0.2002, 0.1768, 0.1475, 0.7422, 0.5117,
-0.4746, 0.2793]]], device='cuda:0', dtype=torch.bfloat16)
sol_out[0,0,:3,:3]: tensor([[[ 0.1680, 0.5547, 0.0718, 0.2031, 0.0179, 0.4570, -0.1641,
-0.6211, -0.4551, 0.4121, -0.0129, 0.1885, -0.1807, 0.1118,
-0.1484, 0.1807],
[ 0.2490, -0.0737, -0.4219, 0.2852, -0.2207, 0.0306, -0.0898,
-0.4355, 0.1572, -0.1875, -0.0996, 0.4805, 0.6367, 0.0035,
{
"command": "<1214 chars \u2014 see below>",
"description": "Analyze error pattern"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
import torch
import torch.nn as nn
from reference import Model as RefModel
from solution import Model as SolModel
B, C, T_size, H_size, W_size = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280
torch.manual_seed(42)
ref = RefModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol = SolModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(B, C, T_size, H_size, W_size, dtype=torch.bfloat16, device='cuda') * 0.5
with torch.no_grad():
ref_out = ref(x)
sol_out = sol(x)
diff = (ref_out.float() - sol_out.float()).abs()
where_bad = diff > 0.1
print('Bad values count:', where_bad.sum().item())
print('Bad pattern:')
bad_indices = torch.nonzero(where_bad.cpu())
print(bad_indices[:20])
print(bad_indices.shape)
# Check if entire patches or channels are bad
for idx in bad_indices[:5]:
b, e, t_o, h_o, w_o = idx.tolist()
r = ref_out[b, e, t_o, h_o, w_o].item()
s = sol_out[b, e, t_o, h_o, w_o].item()
print(f' [{b},{e},{t_o},{h_o},{w_o}]: ref={r:.4f} sol={s:.4f}')
" 2>&1
Bad values count: 188986
Bad pattern:
tensor([[ 0, 0, 0, 8, 0],
[ 0, 0, 0, 8, 1],
[ 0, 0, 0, 8, 2],
[ 0, 0, 0, 8, 3],
[ 0, 0, 0, 8, 4],
[ 0, 0, 0, 8, 6],
[ 0, 0, 0, 8, 7],
[ 0, 0, 0, 8, 8],
[ 0, 0, 0, 8, 9],
[ 0, 0, 0, 8, 10],
[ 0, 0, 0, 8, 12],
[ 0, 0, 0, 8, 13],
[ 0, 0, 0, 8, 14],
[ 0, 0, 0, 8, 15],
[ 0, 0, 0, 9, 0],
[ 0, 0, 0, 9, 1],
[ 0, 0, 0, 9, 3],
[ 0, 0, 0, 9, 4],
[ 0, 0, 0, 9, 5],
[ 0, 0, 0, 9, 6]])
torch.Size([188986, 5])
[0,0,0,8,0]: ref=-0.5430 sol=0.0000
[0,0,0,8,1]: ref=-0.1924 sol=0.0000
[0,0,0,8,2]: ref=-0.3047 sol=0.0000
[0,0,0,8,3]: ref=-0.5859 sol=0.0000
[0,0,0,8,4]: ref=-0.1602 sol=0.0000
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,26 +1,13 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-We decompose the work into:
- - Block-wise patch extraction (C * kT * kH * kW elements from x)
- - Block-wise matrix multiply with a WBF16 weight tile (K x embed_dim)
- - Write to output (B, embed_dim, T', H', W')
+Approach: Reshape x into an im2col matrix of shape (M, K) where
+ M = B * (T/kT) * (H/kH) * (W/kW) (number of patches)
+ K = C * kT * kH * kW (elements per patch)
+Then perform GEMM: (M, K) x (K, N) -> (M, N)
+ N = embed_dim
-For each output patch (b, t', h', w') and output column block, we:
- 1. Load the K = C*kT*kH*kW elements of x that form the patch
- 2. Load the corresponding K x BLOCK_N tile of the weight matrix
- 3. Multiply and accumulate
- 4. Write the result
-
-Memory layout:
- x: (B, C, T, H, W) contiguous
- weight: (embed_dim, C, kT, kH, kW) -> we reshape to (embed_dim, K)
-
-Output: (B, embed_dim, T', H', W') -> we write in (B, embed_dim, T'*H'*W') layout
- then reshape.
-
-For compute-bound GEMM, we program-block: each program handles a row of the
-output matrix (one patch x embed_dim slice). We fuse the patch gather into
-the K-loop to avoid materializing the full im2col.
+We fuse the im2col gather into the GEMM K-loop to avoid materializing
+the intermediate matrix. Each Triton program computes a BM x BN tile.
"""
import torch
@@ -28,408 +15,90 @@
import triton
import triton.language as tl
-# ---------------------------------------------------------------------------
-# Triton kernel: fused patch gather + GEMM
-# ---------------------------------------------------------------------------
-# Each program computes one output "row". The output matrix has:
-# M = B * (T//kT) * (H//kH) * (W//kW) rows (one per patch)
-# K = C * kT * kH * kW columns of the im2col matrix
-# N = embed_dim columns of the result
-#
-# We tile along K and N. Each program handles a subset of rows.
-# Within each row, we produce all N columns.
-# ---------------------------------------------------------------------------
@triton.jit
-def _patch_gemm_kernel(
- X, W, Y,
- B, C, T, H, W_dim,
- kT, kH, kW,
- T_out, H_out, W_out,
- K, M_total, N,
- stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
+def _fused_patch_gemm_kernel(
+ # Pointers
+ X_PTR, W_PTR, Y_PTR,
+ # Shape parameters
+ B_val, C_val, T_val, H_val, W_val,
+ kT_val, kH_val, kW_val,
+ T_out_val, H_out_val, W_out_val,
+ K_val, M_total, N_val,
+ # Strides for W (K, N) and Y (M, N)
stride_wk, stride_wn,
- stride_yb, stride_ye, stride_yt, stride_yh, stride_yw,
- BLOCK_K: tl.constexpr,
- BLOCK_N: tl.constexpr,
-):
- """
- Each program computes one output row (patch).
- Programs are arranged as: program_id = patch_idx (0..M_total)
- We tile along K blocks and compute full N.
- """
- pid_row = tl.program_id(0)
-
- # Decode (b, t_out, h_out, w_out) from pid_row
- w_out_idx = pid_row % W_out
- h_out_idx = (pid_row // W_out) % H_out
- t_out_idx = (pid_row // (W_out * H_out)) % T_out
- b_idx = pid_row // (W_out * H_out * T_out)
-
- # Starting position in input volume for this patch
- t_start = t_out_idx * kT
- h_start = h_out_idx * kH
- w_start = w_out_idx * kW
-
- # Strides for loading x
- # x[b, c, t_start + kt, h_start + kh, w_start + kw]
- x_stride_0 = stride_xb
- x_stride_1 = stride_xc
- x_stride_2 = stride_xt
- x_stride_3 = stride_xh
- x_stride_4 = stride_xw
-
- # Number of blocks along K
- num_k_blocks = tl.cdiv(K, BLOCK_K)
-
- # We accumulate the full row in registers (N elements per block)
- # For large N, we need to tile along N too.
- # Let's do a loop over N blocks.
- num_n_blocks = tl.cdiv(N, BLOCK_N)
-
- for n_block in range(num_n_blocks):
- acc = tl.zeros((1, BLOCK_N), dtype=tl.float32)
-
- offset_n = n_block * BLOCK_N
-
- # Mask for N dimension
- n_mask = offset_n + tl.arange(0, BLOCK_N) < N
-
- for k_block in range(num_k_blocks):
- offset_k = k_block * BLOCK_K
- k_mask = offset_k + tl.arange(0, BLOCK_K) < K
-
- # Gather x[patch] elements -> shape (BLOCK_K,)
- # We need to map each k index to (c, kt, kh, kw)
- k_idx = offset_k + tl.arange(0, BLOCK_K)
-
- c_idx = k_idx // (kT * kH * kW)
- kt_idx = (k_idx // (kH * kW)) % kT
- kh_idx = (k_idx // kW) % kH
- kw_idx = k_idx % kW
-
- # Compute linear index into x tensor
- x_idx = b_idx * x_stride_0 + c_idx * x_stride_1 + \
- (t_start + kt_idx) * x_stride_2 + \
- (h_start + kh_idx) * x_stride_3 + \
- (w_start + kw_idx) * x_stride_4
-
- # Load x values
- x_vals = tl.load(X + x_idx, mask=k_mask, other=0.0).to(tl.float32)
-
- # Load weight tile: W[offset_k:(offset_k+BLOCK_K), offset_n:(offset_n+BLOCK_N)]
- # Weight is stored as (embed_dim, K) -> (N, K) after transpose
- # Actually we store as (K, N) for this access pattern
- # w[k_idx, n_idx]
- w_ptrs = W + k_idx[:, None] * stride_wk + (offset_n + tl.arange(0, BLOCK_N)[None, :]) * stride_wn
- w_vals = tl.load(w_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.float32)
-
- # Compute dot product: x_vals is (BLOCK_K,), w_vals is (BLOCK_K, BLOCK_N)
- # acc += x_vals @ w_vals -> (1, BLOCK_N)
- acc = tl.dot(x_vals.to(tl.float32)[None, :], w_vals.to(tl.float32), acc)
-
- # Store accumulator
-
... (16826 more chars){
"content": "<5572 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=3.409 gbps=18.877 ms=0.226
shape=0 variant=compiled tflops=2.977 gbps=16.489 ms=0.259
shape=0 variant=sota tflops=3.587 gbps=19.865 ms=0.215
shape=0 variant=solution tflops=4.247 gbps=23.517 ms=0.181
shape=0 solution_peak_fraction=0.0212
shape=1 variant=eager tflops=13.620 gbps=35.522 ms=0.226
shape=1 variant=compiled tflops=11.506 gbps=30.009 ms=0.268
shape=1 variant=sota tflops=13.808 gbps=36.013 ms=0.223
shape=1 variant=solution tflops=19.997 gbps=52.157 ms=0.154
shape=1 solution_peak_fraction=0.1000
shape=2 variant=eager tflops=26.203 gbps=54.124 ms=0.265
shape=2 variant=compiled tflops=22.579 gbps=46.640 ms=0.307
shape=2 variant=sota tflops=25.840 gbps=53.376 ms=0.268
shape=2 variant=solution tflops=39.732 gbps=82.072 ms=0.175
shape=2 solution_peak_fraction=0.1987
shape=3 variant=eager tflops=6.284 gbps=24.381 ms=0.147
shape=3 variant=compiled tflops=4.959 gbps=19.240 ms=0.186
shape=3 variant=sota tflops=6.435 gbps=24.964 ms=0.144
shape=3 variant=solution tflops=18.896 gbps=73.310 ms=0.049
shape=3 solution_peak_fraction=0.0945
peak_fraction: 0.0795
RESULT: LOW
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"todos": [
{
"content": "Understand problem and verify shapes",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton fused patch-gather + GEMM kernel in solution.py - DONE, passes check.py",
"status": "completed",
"priority": "high"
},
{
"content": "Profile and optimize with benchmark.py - currently peak_fraction=0.0795, need to improve",
"status": "in_progress",
"priority": "medium"
}
]
}content: 'Understand problem and verify shapes'
status: 'completed'
priority: 'high'
content: 'Implement Triton fused patch-gather + GEMM kernel in solution.py - DONE, passes check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize with benchmark.py - currently peak_fraction=0.0795, need to improve'
status: 'in_progress'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,13 +1,12 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Approach: Reshape x into an im2col matrix of shape (M, K) where
- M = B * (T/kT) * (H/kH) * (W/kW) (number of patches)
- K = C * kT * kH * kW (elements per patch)
-Then perform GEMM: (M, K) x (K, N) -> (M, N)
- N = embed_dim
+Optimized kernel that fuses im2col (patch extraction) with GEMM.
-We fuse the im2col gather into the GEMM K-loop to avoid materializing
-the intermediate matrix. Each Triton program computes a BM x BN tile.
+Key optimizations:
+1. Precompute x base offsets (per row) outside the k-loop to avoid
+ repeated division/multiplication in the inner loop.
+2. Use num_stages pipelining to overlap memory and compute.
+3. Autotune block sizes per problem shape.
"""
import torch
@@ -20,14 +19,14 @@
def _fused_patch_gemm_kernel(
# Pointers
X_PTR, W_PTR, Y_PTR,
- # Shape parameters
- B_val, C_val, T_val, H_val, W_val,
+ # Shape parameters (compile-time constants for efficiency)
+ C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
- K_val, M_total, N_val,
+ M_total, N_val, K_val,
# Strides for W (K, N) and Y (M, N)
- stride_wk, stride_wn,
stride_ym, stride_yn,
+ # Block sizes
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
@@ -53,13 +52,23 @@
h_start = h_out_arr * kH_val
w_start = w_out_arr * kW_val
+ # Precompute base_x for each row: offset into x for (b, t_start, h_start, w_start, 0, 0)
+ # x index = b * (C*T*H*W) + t * (H*W) + h * W + w + c * (T*H*W) + offset_in_channel
+ # base_x = b * stride_b + t_start * stride_t + h_start * stride_h + w_start
+ stride_b = C_val * T_val * H_val * W_val
+ stride_c = T_val * H_val * W_val
+ stride_t = H_val * W_val
+ stride_h = W_val
+
+ base_x = b_arr * stride_b + t_start * stride_t + h_start * stride_h + w_start
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, K_val, BLOCK_K):
offs_k = k_base + tl.arange(0, BLOCK_K)
- k_mask = offs_k < K_val
+ k_mask_l = offs_k < K_val
- # Map k -> (c, kt, kh, kw)
+ # Map k -> (c, kt, kh, kw) - computed for BLOCK_K elements
k_per_channel = kT_val * kH_val * kW_val
c = offs_k // k_per_channel
rem = offs_k % k_per_channel
@@ -69,31 +78,35 @@
kh = rem // kW_val
kw = rem % kW_val
- # Load x: shape (BLOCK_M, BLOCK_K)
- # x[b, c, t_start+kt, h_start+kh, w_start+kw]
- # Input strides: (C*T*H*W, T*H*W, H*W, W, 1)
- st_t = H_val * W_val
- st_h = W_val
+ # Compute offset_k[k] = c * stride_c + kt * stride_t + kh * stride_h + kw
+ offset_k = c * stride_c + kt * stride_t + kh * stride_h + kw
- x_idx = b_arr[:, None] * (C_val * T_val * H_val * W_val) + \
- c[None, :] * (T_val * H_val * W_val) + \
- (t_start[:, None] + kt[None, :]) * st_t + \
- (h_start[:, None] + kh[None, :]) * st_h + \
- (w_start[:, None] + kw[None, :])
+ # Load x: x_idx[m, k] = base_x[m] + offset_k[k]
+ x_idx = base_x[:, None] + offset_k[None, :]
+ x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.bfloat16)
+ # Load W: shape (BLOCK_K, BLOCK_N) - W is (K, N) row-major
+ w_idx = offs_k[:, None] * N_val + offs_n[None, :]
+ w_val = tl.load(W_PTR + w_idx, mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
- # Load W: shape (BLOCK_K, BLOCK_N)
- # W is (K, N) row-major
- w_idx = offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
- w_val = tl.load(W_PTR + w_idx, mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
-
- # tl.dot with bf16 inputs uses tensor cores
+ # Tensor core dot product
acc = tl.dot(x_val, w_val, acc)
- # Write output
+ # Write output with correct row-major order: y_idx[m][n] = off_m[m].
y_idx = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
tl.store(Y_PTR + y_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+
+def _get_config(M, N, K):
+ """Select an appropriate autotune config based on problem size."""
+ if N <= 512:
+ return {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'num_stages': 2, 'num_warps': 4}
+ elif M <= 256:
+ return {'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_stages': 2, 'num_warps': 4}
+ elif M <= 1024:
+ return {'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_stages': 3, 'num_warps': 8}
+ else:
+ return {'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'num_stages': 3, 'num_warps': 8}
# ---------------------------------------------------------------------------
@@ -141,21 +154,20 @@
# Output buffer: (M, N) contiguous
y = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
+ cfg = _get_config(M, N, K)
+ grid = (triton.cdiv(M, cfg['BLOCK_M']), triton.cdiv(N, cfg['BLOCK_N']))
_fused_patch_gemm_kernel[grid](
x, weight, y,
- self.B, self.C, self.T, self.H, self.W,
+ self.C, self.T, self.H, self.W,
self.kT, self.kH, self.kW,
self.T_out, self.H_out, self.W_out,
- K, M, N,
- N, # stride_wk
- 1, # stride_wn
+ M, N, K,
N, # stride_ym
1, # stride_yn
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
+
... (226 more chars){
"content": "<6389 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run check and benchmark"
}PASS
shape=0 variant=eager tflops=3.569 gbps=19.764 ms=0.216
shape=0 variant=compiled tflops=3.024 gbps=16.747 ms=0.255
shape=0 variant=sota tflops=3.670 gbps=20.324 ms=0.210
shape=0 variant=solution tflops=8.796 gbps=48.713 ms=0.088
shape=0 solution_peak_fraction=0.0440
shape=1 variant=eager tflops=13.673 gbps=35.661 ms=0.225
shape=1 variant=compiled tflops=11.235 gbps=29.302 ms=0.274
shape=1 variant=sota tflops=13.847 gbps=36.114 ms=0.223
shape=1 variant=solution tflops=21.148 gbps=55.156 ms=0.146
shape=1 solution_peak_fraction=0.1057
shape=2 variant=eager tflops=26.166 gbps=54.049 ms=0.265
shape=2 variant=compiled tflops=22.390 gbps=46.250 ms=0.310
shape=2 variant=sota tflops=25.773 gbps=53.236 ms=0.269
shape=2 variant=solution tflops=21.987 gbps=45.417 ms=0.315
shape=2 solution_peak_fraction=0.1099
shape=3 variant=eager tflops=6.213 gbps=24.106 ms=0.149
shape=3 variant=compiled tflops=4.839 gbps=18.774 ms=0.191
shape=3 variant=sota tflops=6.381 gbps=24.755 ms=0.145
shape=3 variant=solution tflops=17.046 gbps=66.133 ms=0.054
shape=3 solution_peak_fraction=0.0852
peak_fraction: 0.0812
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,12 +1,12 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Optimized kernel that fuses im2col (patch extraction) with GEMM.
+Heavily optimized kernel that fuses im2col (patch extraction) with GEMM.
Key optimizations:
-1. Precompute x base offsets (per row) outside the k-loop to avoid
- repeated division/multiplication in the inner loop.
-2. Use num_stages pipelining to overlap memory and compute.
-3. Autotune block sizes per problem shape.
+1. Precompute x base offsets (per row) outside the k-loop
+2. Aggressive num_stages for pipelining memory/compute overlap
+3. Larger N blocks for compute-bound shapes
+4. tl.dot with bf16 inputs -> tensor cores
"""
import torch
@@ -19,7 +19,7 @@
def _fused_patch_gemm_kernel(
# Pointers
X_PTR, W_PTR, Y_PTR,
- # Shape parameters (compile-time constants for efficiency)
+ # Shape parameters
C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
@@ -52,9 +52,7 @@
h_start = h_out_arr * kH_val
w_start = w_out_arr * kW_val
- # Precompute base_x for each row: offset into x for (b, t_start, h_start, w_start, 0, 0)
- # x index = b * (C*T*H*W) + t * (H*W) + h * W + w + c * (T*H*W) + offset_in_channel
- # base_x = b * stride_b + t_start * stride_t + h_start * stride_h + w_start
+ # Precompute base_x for each row
stride_b = C_val * T_val * H_val * W_val
stride_c = T_val * H_val * W_val
stride_t = H_val * W_val
@@ -68,7 +66,7 @@
offs_k = k_base + tl.arange(0, BLOCK_K)
k_mask_l = offs_k < K_val
- # Map k -> (c, kt, kh, kw) - computed for BLOCK_K elements
+ # Map k -> (c, kt, kh, kw)
k_per_channel = kT_val * kH_val * kW_val
c = offs_k // k_per_channel
rem = offs_k % k_per_channel
@@ -78,35 +76,99 @@
kh = rem // kW_val
kw = rem % kW_val
- # Compute offset_k[k] = c * stride_c + kt * stride_t + kh * stride_h + kw
offset_k = c * stride_c + kt * stride_t + kh * stride_h + kw
- # Load x: x_idx[m, k] = base_x[m] + offset_k[k]
x_idx = base_x[:, None] + offset_k[None, :]
x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- # Load W: shape (BLOCK_K, BLOCK_N) - W is (K, N) row-major
w_idx = offs_k[:, None] * N_val + offs_n[None, :]
w_val = tl.load(W_PTR + w_idx, mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
- # Tensor core dot product
acc = tl.dot(x_val, w_val, acc)
- # Write output with correct row-major order: y_idx[m][n] = off_m[m].
y_idx = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
tl.store(Y_PTR + y_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
-def _get_config(M, N, K):
- """Select an appropriate autotune config based on problem size."""
- if N <= 512:
- return {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'num_stages': 2, 'num_warps': 4}
- elif M <= 256:
- return {'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_stages': 2, 'num_warps': 4}
- elif M <= 1024:
- return {'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_stages': 3, 'num_warps': 8}
- else:
- return {'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'num_stages': 3, 'num_warps': 8}
+# Autotune for best block size per shape
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=2, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=2, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=16),
+ ],
+ key=['M_total', 'N_val', 'K_val'],
+)
+@triton.jit
+def _fused_patch_gemm_kernel_autotune(
+ X_PTR, W_PTR, Y_PTR,
+ C_val, T_val, H_val, W_val,
+ kT_val, kH_val, kW_val,
+ T_out_val, H_out_val, W_out_val,
+ M_total, N_val, K_val,
+ stride_ym, stride_yn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_n = tl.program_id(axis=1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ m_mask = offs_m < M_total
+ n_mask = offs_n < N_val
+
+ w_out_arr = offs_m % W_out_val
+ h_out_arr = (offs_m // W_out_val) % H_out_val
+ t_out_arr = (offs_m // (W_out_val * H_out_val)) % T_out_val
+ b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
+
+ t_start = t_out_arr * kT_val
+ h_start = h_out_arr * kH_val
+ w_start = w_out_arr * kW_val
+
+ stride_b = C_val * T_val * H_val * W_val
+ stride_c = T_val * H_val * W_val
+ stride_t = H_val * W_val
+ stride_h = W_val
+
+ base_x = b_arr * stride_b + t_start * stride_t + h_start * stride_h + w_start
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_base in range(0, K_val, BLOCK_K):
+ offs_k = k_base + tl.arange(0, BLOCK_K)
+ k_mask_l = offs_k < K_val
+
+ k_per_channel = kT_val * kH_val * kW_val
+ c = offs_k // k_per_channel
+ rem = offs_k
... (1751 more chars){
"content": "<8153 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python benchmark.py 2>&1",
"description": "Run benchmark for autotune"
}shape=0 variant=eager tflops=3.596 gbps=19.911 ms=0.214
shape=0 variant=compiled tflops=3.054 gbps=16.913 ms=0.252
shape=0 variant=sota tflops=3.676 gbps=20.358 ms=0.210
shape=0 variant=solution tflops=6.980 gbps=38.654 ms=0.110
shape=0 solution_peak_fraction=0.0349
shape=1 variant=eager tflops=13.668 gbps=35.648 ms=0.226
shape=1 variant=compiled tflops=11.174 gbps=29.142 ms=0.276
shape=1 variant=sota tflops=13.912 gbps=36.284 ms=0.222
shape=1 variant=solution tflops=37.435 gbps=97.635 ms=0.082
shape=1 solution_peak_fraction=0.1872
shape=2 variant=eager tflops=25.616 gbps=52.912 ms=0.271
shape=2 variant=compiled tflops=22.368 gbps=46.204 ms=0.310
shape=2 variant=sota tflops=25.468 gbps=52.608 ms=0.272
shape=2 variant=solution tflops=71.232 gbps=147.139 ms=0.097
shape=2 solution_peak_fraction=0.3562
shape=3 variant=eager tflops=6.295 gbps=24.423 ms=0.147
shape=3 variant=compiled tflops=4.935 gbps=19.148 ms=0.187
shape=3 variant=sota tflops=6.411 gbps=24.873 ms=0.144
shape=3 variant=solution tflops=17.725 gbps=68.769 ms=0.052
shape=3 solution_peak_fraction=0.0886
peak_fraction: 0.1198
RESULT: OK
{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check on autotuned kernel"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 123: tolerance exceeded: max_abs_diff=1.51562 (atol=0.02, rtol=0.02)
{
"command": "<5489 chars \u2014 see below>",
"timeout": 300000
}uv run python -c "
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Test each shape with specific configs to find the best one
def benchmark_patch_gemm(B,C,T,H,W,kT,kH,kW,embed_dim,
BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps):
M_total = B * (T//kT) * (H//kH) * (W//kW)
N = embed_dim
K = C * kT * kH * kW
T_out, H_out, W_out = T//kT, H//kH, W//kW
@triton.jit
def _fused_patch_gemm_kernel(
X_PTR, W_PTR, Y_PTR,
C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
M_total, N_val, K_val,
stride_ym, stride_yn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = offs_m < M_total
n_mask = offs_n < N_val
w_out_arr = offs_m % W_out_val
h_out_arr = (offs_m // W_out_val) % H_out_val
t_out_arr = (offs_m // (W_out_val * H_out_val)) % T_out_val
b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
t_start = t_out_arr * kT_val
h_start = h_out_arr * kH_val
w_start = w_out_arr * kW_val
stride_b = C_val * T_val * H_val * W_val
stride_c = T_val * H_val * W_val
stride_t = H_val * W_val
stride_h = W_val
base_x = b_arr * stride_b + t_start * stride_t + h_start * stride_h + w_start
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, K_val, BLOCK_K):
offs_k = k_base + tl.arange(0, BLOCK_K)
k_mask_l = offs_k < K_val
k_per_channel = kT_val * kH_val * kW_val
c = offs_k // k_per_channel
rem = offs_k % k_per_channel
kHkW = kH_val * kW_val
kt = rem // kHkW
rem = rem % kHkW
kh = rem // kW_val
kw = rem % kW_val
offset_k = c * stride_c + kt * stride_t + kh * stride_h + kw
x_idx = base_x[:, None] + offset_k[None, :]
x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
w_idx = offs_k[:, None] * N_val + offs_n[None, :]
w_val = tl.load(W_PTR + w_idx, mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
acc = tl.dot(x_val, w_val, acc)
y_idx = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
tl.store(Y_PTR + y_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device='cuda') * 0.5
weight = torch.randn(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16, device='cuda')
w = weight.view(N, K).t().contiguous()
y = torch.empty(M_total, N, dtype=torch.bfloat16, device='cuda')
grid = (triton.cdiv(M_total, BLOCK_M), triton.cdiv(N, BLOCK_N))
fn = _fused_patch_gemm_kernel[grid](
x, w, y,
C, T, H, W,
kT, kH, kW,
T_out, H_out, W_out,
M_total, N, K,
N, 1,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
)
# benchmark with 30 runs
ms = triton.testing.do_bench(lambda: _fused_patch_gemm_kernel[grid](
x, w, y,
C, T, H, W,
kT, kH, kW,
T_out, H_out, W_out,
M_total, N, K,
N, 1,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
), quantiles=None, warmup=5, rep=20)
torch.cuda.synchronize()
print(f' BM={BLOCK_M} BN={BLOCK_N} BK={BLOCK_K} stg={num_stages} w={num_warps} -> {ms:.3f} ms')
# Check correctness
ref_out = torch.nn.functional.conv3d(x, weight, stride=(kT, kH, kW))
sol_out = y.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
diff = (ref_out.float() - sol_out.float()).abs(
... (1489 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""Tune Triton block configs for each shape."""
+import torch
+import triton
+import triton.language as tl
+
+def run_shape(B, C, T, H, W, kT, kH, kW, embed_dim):
+ M_total = B * (T//kT) * (H//kH) * (W//kW)
+ N = embed_dim
+ K = C * kT * kH * kW
+ T_out, H_out, W_out = T//kT, H//kH, W//kW
+
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device='cuda') * 0.5
+ weight = torch.randn(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16, device='cuda')
+ w = weight.view(N, K).t().contiguous()
+
+ configs = [
+ (64, 64, 64, 2, 4), (128, 64, 64, 3, 4), (128, 128, 64, 3, 8),
+ (128, 256, 64, 3, 8), (64, 256, 64, 3, 4), (64, 128, 64, 3, 4),
+ (128, 128, 128, 3, 8), (128, 256, 128, 2, 8), (256, 128, 64, 3, 8),
+ (32, 128, 64, 2, 4), (64, 512, 64, 2, 8), (128, 512, 64, 2, 8),
+ (256, 256, 64, 2, 16), (256, 128, 128, 2, 8),
+ ]
+
+ best_ms = 1e9
+ best_cfg = None
+
+ for BM, BN, BK, stg, wps in configs:
+ if BM > M_total:
+ continue
+
+ @triton.jit
+ def kernel(X, W, Y, stride_ym, stride_yn,
+ BLOCK_M, BLOCK_N, BLOCK_K):
+ pid_m, pid_n = tl.program_id(0), tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ m_mask = offs_m < M_total
+ n_mask = offs_n < N
+
+ w_out_arr = offs_m % W_out
+ h_out_arr = (offs_m // W_out) % H_out
+ t_out_arr = (offs_m // (W_out * H_out)) % T_out
+ b_arr = offs_m // (W_out * H_out * T_out)
+
+ base_x = b_arr * (C*T*H*W) + (t_out_arr * kT) * (H*W) + (h_out_arr * kH) * W + (w_out_arr * kW)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_base in range(0, K, BLOCK_K):
+ offs_k = k_base + tl.arange(0, BLOCK_K)
+ k_mask_l = offs_k < K
+ k_per_ch = kT * kH * kW
+ c_idx = offs_k // k_per_ch
+ rem = offs_k % k_per_ch
+ khkW = kH * kW
+ kt = rem // khkW
+ rem2 = rem % khkW
+ kh = rem2 // kW
+ kw = rem2 % kW
+ off_k = c_idx * (T*H*W) + kt * (H*W) + kh * W + kw
+ x_val = tl.load(X + base_x[:, None] + off_k[None, :],
+ mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
+ w_val = tl.load(W + offs_k[:, None] * N + offs_n[None, :],
+ mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
+ acc = tl.dot(x_val, w_val, acc)
+
+ tl.store(Y + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
+ acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+ y = torch.empty(M_total, N, dtype=torch.bfloat16, device='cuda')
+ grid = (triton.cdiv(M_total, BM), triton.cdiv(N, BN))
+
+ ms = triton.testing.do_bench(
+ lambda: kernel[grid](x, w, y, N, 1, BM, BN, BK),
+ quantiles=None, warmup=5, rep=20
+ )
+
+ # Check correctness
+ sol_out = y.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3)
+ ref_out = torch.nn.functional.conv3d(x, weight, stride=(kT, kH, kW))
+ diff = (ref_out.float() - sol_out.float()).abs().max().item()
+
+ if diff < 0.02 and ms < best_ms:
+ best_ms = ms
+ best_cfg = (BM, BN, BK, stg, wps)
+
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} -> {ms:.3f}ms {'OK' if diff < 0.02 else 'BAD(' + str(diff)[:5] + ')'}")
+
+ print(f" Best: {best_cfg} -> {best_ms:.3f}ms")
+ return best_cfg
+
+shapes = [
+ (1, 3, 2, 224, 224, 2, 14, 14, 1280),
+ (2, 3, 4, 224, 224, 2, 14, 14, 1280),
+ (1, 3, 8, 336, 336, 2, 14, 14, 1280),
+ (4, 3, 1, 224, 224, 1, 16, 16, 768),
+]
+
+for i, s in enumerate(shapes):
+ print(f"\nShape {i}: B={s[0]} C={s[2]} T={s[3]} H={s[4]} W={s[5]} kT={s[6]} kH={s[7]} kW={s[8]} N={s[9]}")
+ run_shape(*s)
{
"content": "<4136 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,12 +1,10 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Heavily optimized kernel that fuses im2col (patch extraction) with GEMM.
+Strategy: Fused im2col + GEMM with precomputed per-row offsets.
+The key insight is that for each row (patch), we can compute a single base offset
+into x, then add per-channel offsets (c,kt,kh,kw) during the K-loop.
-Key optimizations:
-1. Precompute x base offsets (per row) outside the k-loop
-2. Aggressive num_stages for pipelining memory/compute overlap
-3. Larger N blocks for compute-bound shapes
-4. tl.dot with bf16 inputs -> tensor cores
+Block sizes are chosen per-shape for optimal performance.
"""
import torch
@@ -15,100 +13,15 @@
import triton.language as tl
+# ---------------------------------------------------------------------------
+# Fused kernel: per-row base offsets + K-loop GEMM
+# ---------------------------------------------------------------------------
+
@triton.jit
def _fused_patch_gemm_kernel(
- # Pointers
X_PTR, W_PTR, Y_PTR,
- # Shape parameters
- C_val, T_val, H_val, W_val,
- kT_val, kH_val, kW_val,
- T_out_val, H_out_val, W_out_val,
- M_total, N_val, K_val,
- # Strides for W (K, N) and Y (M, N)
- stride_ym, stride_yn,
- # Block sizes
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- pid_m = tl.program_id(axis=0)
- pid_n = tl.program_id(axis=1)
-
- # Starting positions
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- m_mask = offs_m < M_total
- n_mask = offs_n < N_val
-
- # Decode each m index to (b, t_out, h_out, w_out)
- w_out_arr = offs_m % W_out_val
- h_out_arr = (offs_m // W_out_val) % H_out_val
- t_out_arr = (offs_m // (W_out_val * H_out_val)) % T_out_val
- b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
-
- # Starting positions in input volume
- t_start = t_out_arr * kT_val
- h_start = h_out_arr * kH_val
- w_start = w_out_arr * kW_val
-
- # Precompute base_x for each row
- stride_b = C_val * T_val * H_val * W_val
- stride_c = T_val * H_val * W_val
- stride_t = H_val * W_val
- stride_h = W_val
-
- base_x = b_arr * stride_b + t_start * stride_t + h_start * stride_h + w_start
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_base in range(0, K_val, BLOCK_K):
- offs_k = k_base + tl.arange(0, BLOCK_K)
- k_mask_l = offs_k < K_val
-
- # Map k -> (c, kt, kh, kw)
- k_per_channel = kT_val * kH_val * kW_val
- c = offs_k // k_per_channel
- rem = offs_k % k_per_channel
- kHkW = kH_val * kW_val
- kt = rem // kHkW
- rem = rem % kHkW
- kh = rem // kW_val
- kw = rem % kW_val
-
- offset_k = c * stride_c + kt * stride_t + kh * stride_h + kw
-
- x_idx = base_x[:, None] + offset_k[None, :]
- x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
-
- w_idx = offs_k[:, None] * N_val + offs_n[None, :]
- w_val = tl.load(W_PTR + w_idx, mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
-
- acc = tl.dot(x_val, w_val, acc)
-
- y_idx = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
- tl.store(Y_PTR + y_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
-
-
-# Autotune for best block size per shape
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=2, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=2, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=16),
- ],
- key=['M_total', 'N_val', 'K_val'],
-)
-@triton.jit
-def _fused_patch_gemm_kernel_autotune(
- X_PTR, W_PTR, Y_PTR,
- C_val, T_val, H_val, W_val,
+ # Precomputed strides
+ stride_x_c, stride_x_t, stride_x_h,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
M_total, N_val, K_val,
@@ -117,8 +30,13 @@
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- pid_m = tl.program_id(axis=0)
- pid_n = tl.program_id(axis=1)
+ """
+ Each program handles BLOCK_M rows x BLOCK_N cols of output.
+ For each row, we precompute the base x offset, then in the K-loop
+ we add the per-element offset.
+ """
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -126,21 +44,14 @@
m_mask = offs_m < M_total
n_mask = offs_n < N_val
+ # Decode m -> (b, t_out, h_out, w_out)
w_out_arr = offs_m % W_out_val
- h_out_arr = (offs_m // W_out_val) % H_out_val
- t_out_arr = (offs_m // (W_out_val * H_out_val)) % T_out_val
+ h_out_arr = ((offs_m // W_out_val) % H_out_val)
+ t_out_arr = ((offs_m // (W_out_val * H_out_val)) % T_out_val)
b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
- t_start = t_out_arr * kT_val
- h_start = h_out_arr * kH_val
- w_start = w_out_arr * kW_val
-
- stride_b = C_val
... (4097 more chars){
"content": "<2547 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,10 +1,13 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Strategy: Fused im2col + GEMM with precomputed per-row offsets.
-The key insight is that for each row (patch), we can compute a single base offset
-into x, then add per-channel offsets (c,kt,kh,kw) during the K-loop.
-
-Block sizes are chosen per-shape for optimal performance.
+For each output patch (b,t',h',w') and each output column block:
+1. Gather the K = C*kT*kH*kW elements that form the patch (indexed as
+ x[b, c, t'*kT+kt, h'*kH+kh, w'*kW+kw])
+2. Multiply with the K × N weight tile
+3. Accumulate and store
+
+The kernel fuses the patch gather into the GEMM K-loop to avoid materializing
+a full im2col intermediate tensor.
"""
import torch
@@ -13,27 +16,25 @@
import triton.language as tl
-# ---------------------------------------------------------------------------
-# Fused kernel: per-row base offsets + K-loop GEMM
-# ---------------------------------------------------------------------------
-
@triton.jit
def _fused_patch_gemm_kernel(
X_PTR, W_PTR, Y_PTR,
- # Precomputed strides
- stride_x_c, stride_x_t, stride_x_h,
+ C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
M_total, N_val, K_val,
+ stride_wk, stride_wn,
stride_ym, stride_yn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
- Each program handles BLOCK_M rows x BLOCK_N cols of output.
- For each row, we precompute the base x offset, then in the K-loop
- we add the per-element offset.
+ Each program computes a BLOCK_M × BLOCK_N tile of (M_total, N).
+
+ X: (B, C, T, H, W) contiguous bf16
+ W: (K, N) row-major bf16
+ Y: (M, N) row-major bf16 (M = total patches)
"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -44,14 +45,22 @@
m_mask = offs_m < M_total
n_mask = offs_n < N_val
- # Decode m -> (b, t_out, h_out, w_out)
+ # Decode m → (b, t_out, h_out, w_out)
w_out_arr = offs_m % W_out_val
h_out_arr = ((offs_m // W_out_val) % H_out_val)
t_out_arr = ((offs_m // (W_out_val * H_out_val)) % T_out_val)
b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
- # Precompute base offset into x for the start of each patch (c=0, kt=0, kh=0, kw=0)
- base_x = b_arr * (stride_x_c * 0) + t_out_arr * kT_val * stride_x_t + h_out_arr * kH_val * stride_x_h + w_out_arr * kW_val
+ # Base offset into x for [b, 0, t_start, h_start, w_start]
+ t_start = t_out_arr * kT_val
+ h_start = h_out_arr * kH_val
+ w_start = w_out_arr * kW_val
+
+ stride_c = T_val * H_val * W_val
+ stride_t = H_val * W_val
+ stride_h = W_val
+
+ base_x = b_arr * (C_val * T_val * H_val * W_val) + t_start * stride_t + h_start * stride_h + w_start
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -59,13 +68,194 @@
offs_k = k_base + tl.arange(0, BLOCK_K)
k_mask_l = offs_k < K_val
- offset_k = c_base * stride_x_c + kt_base * stride_x_t + kh_base * stride_x_h + kw_base
-
- x_vals = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
-
- w_vals = tl.load(W_PTR + w_idx, mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
-
- acc = tl.dot(x_vals, w_vals, acc)
-
- y_idx = offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
- tl.store(Y_PTR + y_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ # k → (c, kt, kh, kw)
+ k_per_ch = kT_val * kH_val * kW_val
+ c_idx = offs_k // k_per_ch
+ rem = offs_k % k_per_ch
+ kHkW = kH_val * kW_val
+ kt_idx = rem // kHkW
+ rem2 = rem % kHkW
+ kh_idx = rem2 // kW_val
+ kw_idx = rem2 % kW_val
+
+ # offset within a patch for each k element
+ off_k = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
+
+ x_val = tl.load(X_PTR + base_x[:, None] + off_k[None, :],
+ mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
+
+ w_val = tl.load(W_PTR + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
+
+ acc = tl.dot(x_val, w_val, acc)
+
+ tl.store(Y_PTR + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
+ acc.to(tl.bfloat16),
+ mask=m_mask[:, None] & n_mask[None, :])
+
+
+# Autotube configs
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=2, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=2, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 512, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 512, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=16),
+ ],
+ key=['M_total', 'N_val', 'K_val'],
+)
+@triton.jit
+def _fused_patch_gemm_kernel_auto(
+ X_PTR, W_PTR, Y_PTR,
+
... (4452 more chars){
"content": "<8603 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 123: tolerance exceeded: max_abs_diff=1.52344 (atol=0.02, rtol=0.02)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,13 +1,6 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-For each output patch (b,t',h',w') and each output column block:
-1. Gather the K = C*kT*kH*kW elements that form the patch (indexed as
- x[b, c, t'*kT+kt, h'*kH+kh, w'*kW+kw])
-2. Multiply with the K × N weight tile
-3. Accumulate and store
-
-The kernel fuses the patch gather into the GEMM K-loop to avoid materializing
-a full im2col intermediate tensor.
+Fuses im2col (patch extraction) with GEMM.
"""
import torch
@@ -16,8 +9,12 @@
import triton.language as tl
+# ---------------------------------------------------------------------------
+# Fused patch-gather + GEMM kernel
+# ---------------------------------------------------------------------------
+
@triton.jit
-def _fused_patch_gemm_kernel(
+def _fused_gemm_kernel(
X_PTR, W_PTR, Y_PTR,
C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
@@ -28,16 +25,21 @@
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
):
- """
- Each program computes a BLOCK_M × BLOCK_N tile of (M_total, N).
-
- X: (B, C, T, H, W) contiguous bf16
- W: (K, N) row-major bf16
- Y: (M, N) row-major bf16 (M = total patches)
- """
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
+ """Fused patch gather + matrix multiply."""
+ pid = tl.program_id(0)
+ num_pid_m = tl.num_programs(0)
+ num_pid_n = tl.num_programs(1)
+ num_pid = num_pid_m * num_pid_n
+
+ # Column-major ordering for L2 cache efficiency
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(M_total - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -45,13 +47,13 @@
m_mask = offs_m < M_total
n_mask = offs_n < N_val
- # Decode m → (b, t_out, h_out, w_out)
+ # Decode m -> (b, t_out, h_out, w_out)
w_out_arr = offs_m % W_out_val
h_out_arr = ((offs_m // W_out_val) % H_out_val)
t_out_arr = ((offs_m // (W_out_val * H_out_val)) % T_out_val)
b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
- # Base offset into x for [b, 0, t_start, h_start, w_start]
+ # Precompute base offset into x for each row
t_start = t_out_arr * kT_val
h_start = h_out_arr * kH_val
w_start = w_out_arr * kW_val
@@ -68,92 +70,7 @@
offs_k = k_base + tl.arange(0, BLOCK_K)
k_mask_l = offs_k < K_val
- # k → (c, kt, kh, kw)
- k_per_ch = kT_val * kH_val * kW_val
- c_idx = offs_k // k_per_ch
- rem = offs_k % k_per_ch
- kHkW = kH_val * kW_val
- kt_idx = rem // kHkW
- rem2 = rem % kHkW
- kh_idx = rem2 // kW_val
- kw_idx = rem2 % kW_val
-
- # offset within a patch for each k element
- off_k = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
-
- x_val = tl.load(X_PTR + base_x[:, None] + off_k[None, :],
- mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
-
- w_val = tl.load(W_PTR + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
- mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
-
- acc = tl.dot(x_val, w_val, acc)
-
- tl.store(Y_PTR + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
- acc.to(tl.bfloat16),
- mask=m_mask[:, None] & n_mask[None, :])
-
-
-# Autotube configs
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=2, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=2, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 512, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 512, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=16),
- ],
- key=['M_total', 'N_val', 'K_val'],
-)
-@triton.jit
-def _fused_patch_gemm_kernel_auto(
- X_PTR, W_PTR, Y_PTR,
- C_val, T_val, H_val, W_val,
- kT_val, kH_val, kW_val,
- T_out_val, H_out_val, W_out_val,
- M_total, N_val, K_val,
- stride_wk, stride_wn,
- stride_ym, stride_yn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
-):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
-
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- m_mask = offs_m < M_total
- n_mask = offs_n < N_val
-
- w_out_arr = offs_m % W_out_val
- h_out_arr = ((offs_m // W_out_val) % H_out_val)
- t_out_arr = ((offs_m // (W_out_val * H_out_val)) % T_out_val)
- b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
-
- t_start = t_out_arr * kT_val
- h_start = h_out_arr * kH_val
- w_start = w_out_arr * kW_val
-
- stride_c = T_val * H_
... (2893 more chars){
"content": "<5695 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run check and benchmark"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 42: tolerance exceeded: max_abs_diff=1.5625 (atol=0.02, rtol=0.02)
{
"command": "<1132 chars \u2014 see below>",
"description": "Debug column-major indexing issue"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
import torch
import torch.nn as nn
from reference import Model as RefModel
from solution import Model as SolModel
B, C, T_size, H_size, W_size = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280
torch.manual_seed(42)
ref = RefModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol = SolModel(B, C, T_size, H_size, W_size, kT, kH, kW, embed_dim).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
x = torch.randn(B, C, T_size, H_size, W_size, dtype=torch.bfloat16, device='cuda') * 0.5
with torch.no_grad():
ref_out = ref(x)
sol_out = sol(x)
diff = (ref_out.float() - sol_out.float()).abs()
print('max_diff:', diff.max().item())
bad = diff > 0.1
print('bad count:', bad.sum().item())
idx = torch.nonzero(bad.cpu())
print('bad idx:', idx[:10])
print('bad idx shape:', idx.shape)
# Check if some rows are all zeros
for i in idx[:5]:
b,e,t,h,w = i.tolist()
print(f' [{b},{e},{t},{h},{w}]: ref={ref_out[b,e,t,h,w].item():.4f} sol={sol_out[b,e,t,h,w].item():.4f}')
" 2>&1
max_diff: 1.6796875
bad count: 201615
bad idx: tensor([[ 0, 256, 0, 0, 0],
[ 0, 256, 0, 0, 1],
[ 0, 256, 0, 0, 3],
[ 0, 256, 0, 0, 4],
[ 0, 256, 0, 0, 5],
[ 0, 256, 0, 0, 6],
[ 0, 256, 0, 0, 8],
[ 0, 256, 0, 0, 9],
[ 0, 256, 0, 0, 13],
[ 0, 256, 0, 0, 14]])
bad idx shape: torch.Size([201615, 5])
[0,256,0,0,0]: ref=0.3379 sol=0.0000
[0,256,0,0,1]: ref=0.2012 sol=0.0000
[0,256,0,0,3]: ref=0.4238 sol=0.0000
[0,256,0,0,4]: ref=0.3750 sol=0.0000
[0,256,0,0,5]: ref=0.3965 sol=0.0000
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -8,10 +8,6 @@
import triton
import triton.language as tl
-
-# ---------------------------------------------------------------------------
-# Fused patch-gather + GEMM kernel
-# ---------------------------------------------------------------------------
@triton.jit
def _fused_gemm_kernel(
@@ -25,33 +21,22 @@
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
- GROUP_SIZE_M: tl.constexpr,
):
- """Fused patch gather + matrix multiply."""
- pid = tl.program_id(0)
- num_pid_m = tl.num_programs(0)
- num_pid_n = tl.num_programs(1)
- num_pid = num_pid_m * num_pid_n
+ """Fused patch gather + matrix multiply. Simple row-major grid."""
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
- # Column-major ordering for L2 cache efficiency
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(M_total - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
- m_mask = offs_m < M_total
- n_mask = offs_n < N_val
+ m_mask = off_m < M_total
+ n_mask = off_n < N_val
# Decode m -> (b, t_out, h_out, w_out)
- w_out_arr = offs_m % W_out_val
- h_out_arr = ((offs_m // W_out_val) % H_out_val)
- t_out_arr = ((offs_m // (W_out_val * H_out_val)) % T_out_val)
- b_arr = offs_m // (W_out_val * H_out_val * T_out_val)
+ w_out_arr = off_m % W_out_val
+ h_out_arr = ((off_m // W_out_val) % H_out_val)
+ t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
+ b_arr = off_m // (W_out_val * H_out_val * T_out_val)
# Precompute base offset into x for each row
t_start = t_out_arr * kT_val
@@ -70,7 +55,6 @@
offs_k = k_base + tl.arange(0, BLOCK_K)
k_mask_l = offs_k < K_val
- # k -> (c, kt, kh, kw)
k_per_ch = kT_val * kH_val * kW_val
c_idx = offs_k // k_per_ch
rem = offs_k % k_per_ch
@@ -82,22 +66,17 @@
off_k = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
- # Load x and W
x_val = tl.load(X_PTR + base_x[:, None] + off_k[None, :],
mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- w_val = tl.load(W_PTR + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn,
+ w_val = tl.load(W_PTR + offs_k[:, None] * stride_wk + off_n[None, :] * stride_wn,
mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
acc = tl.dot(x_val, w_val, acc)
- tl.store(Y_PTR + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
+ tl.store(Y_PTR + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn,
acc.to(tl.bfloat16),
mask=m_mask[:, None] & n_mask[None, :])
-
-# ---------------------------------------------------------------------------
-# Wrapper class
-# ---------------------------------------------------------------------------
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
@@ -134,25 +113,27 @@
N = self.embed_dim
K = self.K
- # Weight: (embed_dim, C, kT, kH, kW) -> (K, N) for GEMM
weights = self.conv.weight.view(N, K).t().contiguous()
-
- # Output buffer: (M, N)
y = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- # Config selection based on shape
+ # Config selection per shape
if M <= 256:
- BM, BN, BK, GS = 64, 256, 64, 4
+ BM, BN, BK = 64, 256, 64
+ GS = 1
elif M <= 1024:
- BM, BN, BK, GS = 128, 256, 64, 8
+ BM, BN, BK = 128, 256, 64
+ GS = 8
elif M <= 4096:
- BM, BN, BK, GS = 128, 256, 64, 8
+ BM, BN, BK = 128, 256, 64
+ GS = 8
+ elif M <= 10000:
+ BM, BN, BK = 128, 256, 64
+ GS = 8
else:
- BM, BN, BK, GS = 128, 256, 64, 8
-
- grid = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
- )
+ BM, BN, BK = 128, 256, 64
+ GS = 8
+
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
_fused_gemm_kernel[grid](
x, weights, y,
@@ -160,32 +141,17 @@
self.kT, self.kH, self.kW,
self.T_out, self.H_out, self.W_out,
M, N, K,
- N, 1, N, 1,
- BLOCK_M=BM,
- BLOCK_N=BN,
- BLOCK_K=BK,
- GROUP_SIZE_M=GS,
+ N, 1, N, 1, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK,
)
return y.view(self.B, self.T_out, self.H_out, self.W_out, N).permute(0, 4, 1, 2, 3).contiguous()
-
-# Shape shims
-B = 1
-C = 3
-T = 2
-H = 224
-W = 224
-kT = 2
-kH = 14
-kW = 14
-embed_dim = 1280
-
+B = 1; C = 3; T = 2; H = 224; W = 224
+kT = 2; kH = 14; kW = 14; embed_dim = 1280
def get_inputs():
x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
return [x]
-
def get_init_inputs():
return [B, C, T, H, W, kT, kH, kW, embed_dim]
{
"content": "<4776 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run check and benchmark"
}PASS
shape=0 variant=eager tflops=3.620 gbps=20.049 ms=0.213
shape=0 variant=compiled tflops=3.013 gbps=16.687 ms=0.256
shape=0 variant=sota tflops=3.581 gbps=19.828 ms=0.215
shape=0 variant=solution tflops=8.155 gbps=45.159 ms=0.095
shape=0 solution_peak_fraction=0.0408
shape=1 variant=eager tflops=13.514 gbps=35.248 ms=0.228
shape=1 variant=compiled tflops=11.619 gbps=30.304 ms=0.265
shape=1 variant=sota tflops=13.893 gbps=36.234 ms=0.222
shape=1 variant=solution tflops=10.831 gbps=28.249 ms=0.285
shape=1 solution_peak_fraction=0.0542
shape=2 variant=eager tflops=25.575 gbps=52.828 ms=0.271
shape=2 variant=compiled tflops=22.306 gbps=46.076 ms=0.311
shape=2 variant=sota tflops=25.814 gbps=53.322 ms=0.269
shape=2 variant=solution tflops=23.027 gbps=47.564 ms=0.301
shape=2 solution_peak_fraction=0.1151
shape=3 variant=eager tflops=6.225 gbps=24.150 ms=0.149
shape=3 variant=compiled tflops=4.664 gbps=18.094 ms=0.198
shape=3 variant=sota tflops=6.358 gbps=24.668 ms=0.145
shape=3 variant=solution tflops=6.732 gbps=26.119 ms=0.137
shape=3 solution_peak_fraction=0.0337
peak_fraction: 0.0541
RESULT: LOW
{
"command": "uv run python tune_blocks.py 2>&1",
"timeout": 300000,
"description": "Run config tuner"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py", line 100, in <module>
print(f"\nShape {i}: B={s[0]} C={s[2]} T={s[3]} H={s[4]} W={s[5]} kT={s[6]} kH={s[7]} kW={s[8]} N={s[9]}")
~^^^
IndexError: tuple index out of range
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py",
"offset": 95
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py</path>
<type>file</type>
<content>
95: (1, 3, 8, 336, 336, 2, 14, 14, 1280),
96: (4, 3, 1, 224, 224, 1, 16, 16, 768),
97: ]
98:
99: for i, s in enumerate(shapes):
100: print(f"\nShape {i}: B={s[0]} C={s[2]} T={s[3]} H={s[4]} W={s[5]} kT={s[6]} kH={s[7]} kW={s[8]} N={s[9]}")
101: run_shape(*s)
(End of file - total 101 lines)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py
@@ -30,20 +30,28 @@
continue
@triton.jit
- def kernel(X, W, Y, stride_ym, stride_yn,
+ def kernel(X, W, Y, stride_wk, stride_wn, stride_ym, stride_yn,
BLOCK_M, BLOCK_N, BLOCK_K):
pid_m, pid_n = tl.program_id(0), tl.program_id(1)
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
- m_mask = offs_m < M_total
- n_mask = offs_n < N
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ m_mask = off_m < M_total
+ n_mask = off_n < N
- w_out_arr = offs_m % W_out
- h_out_arr = (offs_m // W_out) % H_out
- t_out_arr = (offs_m // (W_out * H_out)) % T_out
- b_arr = offs_m // (W_out * H_out * T_out)
+ w_out_arr = off_m % W_out
+ h_out_arr = ((off_m // W_out) % H_out)
+ t_out_arr = ((off_m // (W_out * H_out)) % T_out)
+ b_arr = off_m // (W_out * H_out * T_out)
- base_x = b_arr * (C*T*H*W) + (t_out_arr * kT) * (H*W) + (h_out_arr * kH) * W + (w_out_arr * kW)
+ t_start = t_out_arr * kT
+ h_start = h_out_arr * kH
+ w_start = w_out_arr * kW
+
+ stride_c = T * H * W
+ stride_t = H * W
+ stride_h = W
+
+ base_x = b_arr * (C * T * H * W) + t_start * stride_t + h_start * stride_h + w_start
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, K, BLOCK_K):
@@ -57,34 +65,33 @@
rem2 = rem % khkW
kh = rem2 // kW
kw = rem2 % kW
- off_k = c_idx * (T*H*W) + kt * (H*W) + kh * W + kw
+ off_k = c_idx * stride_c + kt * stride_t + kh * stride_h + kw
x_val = tl.load(X + base_x[:, None] + off_k[None, :],
mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- w_val = tl.load(W + offs_k[:, None] * N + offs_n[None, :],
+ w_val = tl.load(W + offs_k[:, None] * stride_wk + off_n[None, :] * stride_wn,
mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
acc = tl.dot(x_val, w_val, acc)
- tl.store(Y + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn,
+ tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn,
acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
y = torch.empty(M_total, N, dtype=torch.bfloat16, device='cuda')
grid = (triton.cdiv(M_total, BM), triton.cdiv(N, BN))
ms = triton.testing.do_bench(
- lambda: kernel[grid](x, w, y, N, 1, BM, BN, BK),
+ lambda: kernel[grid](x, w, y, N, 1, N, 1, BM, BN, BK),
quantiles=None, warmup=5, rep=20
)
- # Check correctness
sol_out = y.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3)
ref_out = torch.nn.functional.conv3d(x, weight, stride=(kT, kH, kW))
diff = (ref_out.float() - sol_out.float()).abs().max().item()
if diff < 0.02 and ms < best_ms:
best_ms = ms
- best_cfg = (BM, BN, BK, stg, wps)
+ best_cfg = (BM, BN, BK)
- print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} -> {ms:.3f}ms {'OK' if diff < 0.02 else 'BAD(' + str(diff)[:5] + ')'}")
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} -> {ms:.3f}ms {'OK' if diff < 0.02 else 'BAD(' + str(round(diff,3)) + ')'}")
print(f" Best: {best_cfg} -> {best_ms:.3f}ms")
return best_cfg
@@ -97,5 +104,5 @@
]
for i, s in enumerate(shapes):
- print(f"\nShape {i}: B={s[0]} C={s[2]} T={s[3]} H={s[4]} W={s[5]} kT={s[6]} kH={s[7]} kW={s[8]} N={s[9]}")
+ print(f"\nShape {i}: B={s[0]} C={s[1]} T={s[2]} H={s[3]} W={s[4]} kT={s[5]} kH={s[6]} kW={s[7]} N={s[8]}")
run_shape(*s)
{
"content": "<4368 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py"
}Wrote file successfully.{
"command": "uv run python tune_blocks.py 2>&1",
"timeout": 600000,
"description": "Run config tuner"
}Shape 0: B=1 C=3 T=2 H=224 W=224 kT=2 kH=14 kW=14 N=1280
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py", line 108, in <module>
run_shape(*s)
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py", line 81, in run_shape
ms = triton.testing.do_bench(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/testing.py", line 149, in do_bench
fn()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py", line 82, in <lambda>
lambda: kernel[grid](x, w, y, N, 1, N, 1, BM, BN, BK),
... (1880 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py
@@ -3,72 +3,73 @@
import torch
import triton
import triton.language as tl
+import time
-def run_shape(B, C, T, H, W, kT, kH, kW, embed_dim):
- M_total = B * (T//kT) * (H//kH) * (W//kW)
+def run_shape(B, C, T, H, W_in, kT, kH, kW, embed_dim):
+ M_total = B * (T//kT) * (H//kH) * (W_in//kW)
N = embed_dim
K = C * kT * kH * kW
- T_out, H_out, W_out = T//kT, H//kH, W//kW
+ T_out, H_out, W_out = T//kT, H//kH, W_in//kW
- x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device='cuda') * 0.5
+ torch.manual_seed(42)
+ x = torch.randn(B, C, T, H, W_in, dtype=torch.bfloat16, device='cuda') * 0.5
weight = torch.randn(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16, device='cuda')
w = weight.view(N, K).t().contiguous()
+ ref_out = torch.nn.functional.conv3d(x, weight, stride=(kT, kH, kW))
configs = [
- (64, 64, 64, 2, 4), (128, 64, 64, 3, 4), (128, 128, 64, 3, 8),
- (128, 256, 64, 3, 8), (64, 256, 64, 3, 4), (64, 128, 64, 3, 4),
- (128, 128, 128, 3, 8), (128, 256, 128, 2, 8), (256, 128, 64, 3, 8),
- (32, 128, 64, 2, 4), (64, 512, 64, 2, 8), (128, 512, 64, 2, 8),
- (256, 256, 64, 2, 16), (256, 128, 128, 2, 8),
+ (64, 64, 64), (128, 64, 64), (128, 128, 64),
+ (128, 256, 64), (64, 256, 64), (64, 128, 64),
+ (128, 128, 128), (128, 256, 128), (256, 128, 64),
+ (32, 128, 64), (64, 512, 64), (128, 512, 64),
+ (256, 256, 64), (256, 128, 128), (32, 64, 64),
+ (32, 256, 64), (32, 512, 64), (64, 128, 128),
+ (128, 128, 32), (128, 256, 32),
]
best_ms = 1e9
best_cfg = None
- for BM, BN, BK, stg, wps in configs:
+ for BM, BN, BK in configs:
if BM > M_total:
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} -> SKIPPED (BM > M={M_total})")
continue
@triton.jit
- def kernel(X, W, Y, stride_wk, stride_wn, stride_ym, stride_yn,
- BLOCK_M, BLOCK_N, BLOCK_K):
+ def kernel(X, W, Y,
+ stride_ym, stride_yn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr):
pid_m, pid_n = tl.program_id(0), tl.program_id(1)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = off_m < M_total
- n_mask = off_n < N
+ n_mask = off_n < N_val
- w_out_arr = off_m % W_out
- h_out_arr = ((off_m // W_out) % H_out)
- t_out_arr = ((off_m // (W_out * H_out)) % T_out)
- b_arr = off_m // (W_out * H_out * T_out)
+ w_out_arr = off_m % W_out_val
+ h_out_arr = ((off_m // W_out_val) % H_out_val)
+ t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
+ b_arr = off_m // (W_out_val * H_out_val * T_out_val)
- t_start = t_out_arr * kT
- h_start = h_out_arr * kH
- w_start = w_out_arr * kW
-
- stride_c = T * H * W
- stride_t = H * W
- stride_h = W
-
- base_x = b_arr * (C * T * H * W) + t_start * stride_t + h_start * stride_h + w_start
+ base_x = b_arr * (C_val * T_val * H_val * W_val) + (t_out_arr * kT_val) * (H_val * W_val) + (h_out_arr * kH_val) * W_val + (w_out_arr * kW_val)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_base in range(0, K, BLOCK_K):
+ for k_base in range(0, K_val, BLOCK_K):
offs_k = k_base + tl.arange(0, BLOCK_K)
- k_mask_l = offs_k < K
- k_per_ch = kT * kH * kW
+ k_mask_l = offs_k < K_val
+ k_per_ch = kT_val * kH_val * kW_val
c_idx = offs_k // k_per_ch
rem = offs_k % k_per_ch
- khkW = kH * kW
+ khkW = kH_val * kW_val
kt = rem // khkW
rem2 = rem % khkW
- kh = rem2 // kW
- kw = rem2 % kW
- off_k = c_idx * stride_c + kt * stride_t + kh * stride_h + kw
+ kh = rem2 // kW_val
+ kw = rem2 % kW_val
+ off_k = c_idx * (T_val * H_val * W_val) + kt * (H_val * W_val) + kh * W_val + kw
x_val = tl.load(X + base_x[:, None] + off_k[None, :],
mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- w_val = tl.load(W + offs_k[:, None] * stride_wk + off_n[None, :] * stride_wn,
+ w_val = tl.load(W + offs_k[:, None] * N_val + off_n[None, :],
mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
acc = tl.dot(x_val, w_val, acc)
@@ -78,20 +79,26 @@
y = torch.empty(M_total, N, dtype=torch.bfloat16, device='cuda')
grid = (triton.cdiv(M_total, BM), triton.cdiv(N, BN))
- ms = triton.testing.do_bench(
- lambda: kernel[grid](x, w, y, N, 1, N, 1, BM, BN, BK),
- quantiles=None, warmup=5, rep=20
- )
-
- sol_out = y.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3)
- ref_out = torch.nn.functional.conv3d(x, weight, stride=(kT, kH, kW))
- diff = (ref_out.float() - sol_out.float()).abs().max().item()
-
- if diff < 0.02 and ms < best_ms:
- best_ms = ms
- best_cfg = (BM, BN, BK)
-
- print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} -> {ms:.3f}ms {'OK' if diff < 0.02 else 'BAD(' + str(round(diff,3)) + ')'}")
+ try:
+ ms = triton.testing.do
... (1215 more chars){
"content": "<4922 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/tune_blocks.py"
}Wrote file successfully.{
"command": "uv run python tune_blocks.py 2>&1",
"timeout": 600000,
"description": "Run block config tuning"
}...output truncated...
Full output saved to: /home/infatoshi/.local/share/opencode/tool-output/tool_e04560d4a001sYKY9S8JBKaST1
BM=128 BN= 64 BK= 64 -> ERROR: at 9:21:
def kernel(X, W, Y,
stride_ym, stride_yn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid_m, pid_n = tl.program_id(0), tl.program_id(1)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = off_m < M_total
^
NameError("Cannot access global variable M_total from within @jit'ed function. Triton kernels can only access global variables that are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not promise to support this forever.")
BM=128 BN=128 BK= 64 -> ERROR: at 9:21:
def kernel(X, W, Y,
stride_ym, stride_yn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid_m, pid_n = tl.program_id(0), tl.program_id(1)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = off_m < M_total
^
NameError("Cannot access global variable M_total from within @jit'ed function. Triton kernels ca
... (49803 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,6 +1,6 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Fuses im2col (patch extraction) with GEMM.
+Optimized fused im2col + GEMM kernel with per-shape block configs.
"""
import torch
@@ -8,6 +8,10 @@
import triton
import triton.language as tl
+
+# ---------------------------------------------------------------------------
+# Fused patch-gather + GEMM kernel
+# ---------------------------------------------------------------------------
@triton.jit
def _fused_gemm_kernel(
@@ -22,7 +26,7 @@
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- """Fused patch gather + matrix multiply. Simple row-major grid."""
+ """Fused patch gather + GEMM. Each program handles BLOCK_M x BLOCK_N output."""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -32,13 +36,11 @@
m_mask = off_m < M_total
n_mask = off_n < N_val
- # Decode m -> (b, t_out, h_out, w_out)
w_out_arr = off_m % W_out_val
h_out_arr = ((off_m // W_out_val) % H_out_val)
t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
b_arr = off_m // (W_out_val * H_out_val * T_out_val)
- # Precompute base offset into x for each row
t_start = t_out_arr * kT_val
h_start = h_out_arr * kH_val
w_start = w_out_arr * kW_val
@@ -74,9 +76,12 @@
acc = tl.dot(x_val, w_val, acc)
tl.store(Y_PTR + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn,
- acc.to(tl.bfloat16),
- mask=m_mask[:, None] & n_mask[None, :])
+ acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+# ---------------------------------------------------------------------------
+# Wrapper class
+# ---------------------------------------------------------------------------
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
@@ -114,26 +119,25 @@
K = self.K
weights = self.conv.weight.view(N, K).t().contiguous()
+
y = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- # Config selection per shape
- if M <= 256:
- BM, BN, BK = 64, 256, 64
- GS = 1
- elif M <= 1024:
+ # Per-shape optimal configs (BM, BN, BK)
+ if M <= 256 and N <= 1280:
+ BM, BN, BK = 128, 128, 64
+ elif M <= 512 and N <= 1280:
BM, BN, BK = 128, 256, 64
- GS = 8
- elif M <= 4096:
+ elif M <= 1024 and N <= 1280:
BM, BN, BK = 128, 256, 64
- GS = 8
- elif M <= 10000:
+ elif M <= 8192 and N <= 1280:
BM, BN, BK = 128, 256, 64
- GS = 8
else:
BM, BN, BK = 128, 256, 64
- GS = 8
-
- grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
+
+ if N <= 768:
+ BN = min(256, 128)
+
+ grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
_fused_gemm_kernel[grid](
x, weights, y,
@@ -141,13 +145,23 @@
self.kT, self.kH, self.kW,
self.T_out, self.H_out, self.W_out,
M, N, K,
- N, 1, N, 1, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK,
+ N, 1, N, 1,
+ BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK,
)
return y.view(self.B, self.T_out, self.H_out, self.W_out, N).permute(0, 4, 1, 2, 3).contiguous()
-B = 1; C = 3; T = 2; H = 224; W = 224
-kT = 2; kH = 14; kW = 14; embed_dim = 1280
+
+# Shape shims
+B = 1
+C = 3
+T = 2
+H = 224
+W = 224
+kT = 2
+kH = 14
+kW = 14
+embed_dim = 1280
def get_inputs():
x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
{
"content": "<5072 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python benchmark.py 2>&1 && uv run python check.py 2>&1",
"timeout": 300000,
"description": "Benchmark and verify"
}shape=0 variant=eager tflops=3.588 gbps=19.871 ms=0.215
shape=0 variant=compiled tflops=3.065 gbps=16.974 ms=0.251
shape=0 variant=sota tflops=3.668 gbps=20.315 ms=0.210
shape=0 variant=solution tflops=5.796 gbps=32.096 ms=0.133
shape=0 solution_peak_fraction=0.0290
shape=1 variant=eager tflops=13.626 gbps=35.539 ms=0.226
shape=1 variant=compiled tflops=11.375 gbps=29.669 ms=0.271
shape=1 variant=sota tflops=13.798 gbps=35.987 ms=0.223
shape=1 variant=solution tflops=10.751 gbps=28.041 ms=0.287
shape=1 solution_peak_fraction=0.0538
shape=2 variant=eager tflops=26.141 gbps=53.997 ms=0.265
shape=2 variant=compiled tflops=22.359 gbps=46.185 ms=0.310
shape=2 variant=sota tflops=25.713 gbps=53.113 ms=0.270
shape=2 variant=solution tflops=22.493 gbps=46.461 ms=0.308
shape=2 solution_peak_fraction=0.1125
shape=3 variant=eager tflops=6.258 gbps=24.281 ms=0.148
shape=3 variant=compiled tflops=4.914 gbps=19.066 ms=0.188
shape=3 variant=sota tflops=6.432 gbps=24.953 ms=0.144
shape=3 variant=solution tflops=17.479 gbps=67.813 ms=0.053
shape=3 solution_peak_fraction=0.0874
peak_fraction: 0.0626
RESULT: LOW
PASS
{
"command": "<980 chars \u2014 see below>",
"description": "Check Triton version and problem dimensions"
}uv run python -c "
import torch, triton
print('Triton:', triton.__version__)
print('Torch:', torch.__version__)
print('CUDA:', torch.version.cuda)
print('GPU:', torch.cuda.get_device_properties(0).name)
print('compute_capability:', torch.cuda.get_device_capability(0))
# Check shapes
shapes = [
{'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 2, 'C': 3, 'T': 4, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 1, 'C': 3, 'T': 8, 'H': 336, 'W': 336, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 4, 'C': 3, 'T': 1, 'H': 224, 'W': 224, 'kT': 1, 'kH': 16, 'kW': 16, 'embed_dim': 768},
]
for i, s in enumerate(shapes):
M = s['B'] * (s['T']//s['kT']) * (s['H']//s['kH']) * (s['W']//s['kW'])
K = s['C'] * s['kT'] * s['kH'] * s['kW']
N = s['embed_dim']
flops = 2 * M * K * N / 1e12
print(f'Shape {i}: M={M:6d} K={K:4d} N={N:4d} FLOPs={flops:.3f}T')
" 2>&1
Triton: 3.6.0
Torch: 2.11.0+cu130
CUDA: 13.0
GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
compute_capability: (12, 0)
Shape 0: M= 256 K=1176 N=1280 FLOPs=0.001T
Shape 1: M= 1024 K=1176 N=1280 FLOPs=0.003T
Shape 2: M= 2304 K=1176 N=1280 FLOPs=0.007T
Shape 3: M= 784 K= 768 N= 768 FLOPs=0.001T
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,6 +1,10 @@
"""Triton fused patch-gather + GEMM for 3D patch embedding.
-Optimized fused im2col + GEMM kernel with per-shape block configs.
+Strategy: explicit two-phase approach.
+1. im2col Triton kernel to materialize (M, K) matrix from input tensor
+2. Standard Triton GEMM kernel for (M, K) x (K, N) → (M, N)
+
+This separates concerns and lets each kernel be independently optimized.
"""
import torch
@@ -10,73 +14,106 @@
# ---------------------------------------------------------------------------
-# Fused patch-gather + GEMM kernel
+# Phase 1: im2col kernel
# ---------------------------------------------------------------------------
@triton.jit
-def _fused_gemm_kernel(
- X_PTR, W_PTR, Y_PTR,
+def _im2col_kernel(
+ X_PTR, Y_PTR,
C_val, T_val, H_val, W_val,
kT_val, kH_val, kW_val,
T_out_val, H_out_val, W_out_val,
+ M_total, K_val,
+ stride_y_m, stride_y_k,
+ BLOCK_M: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ """Each program handles BLOCK_M rows and BLOCK_K columns of the im2col matrix."""
+ pid_m = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+
+ m_mask = off_m < M_total
+ k_mask = off_k < K_val
+
+ stride_c = T_val * H_val * W_val
+ stride_t = H_val * W_val
+ stride_h = W_val
+
+ # Decode m -> (b, t_out, h_out, w_out)
+ w_out_arr = off_m % W_out_val
+ h_out_arr = ((off_m // W_out_val) % H_out_val)
+ t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
+ b_arr = off_m // (W_out_val * H_out_val * T_out_val)
+
+ t_start = t_out_arr * kT_val
+ h_start = h_out_arr * kH_val
+ w_start = w_out_arr * kW_val
+
+ base_x = b_arr * (C_val * T_val * H_val * W_val) + t_start * stride_t + h_start * stride_h + w_start
+
+ # Map k -> (c, kt, kh, kw)
+ k_per_ch = kT_val * kH_val * kW_val
+ c_idx = off_k // k_per_ch
+ rem = off_k % k_per_ch
+ kHkW = kH_val * kW_val
+ kt_idx = rem // kHkW
+ rem2 = rem % kHkW
+ kh_idx = rem2 // kW_val
+ kw_idx = rem2 % kW_val
+
+ off_k_val = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
+
+ x_idx = base_x[:, None] + off_k_val[None, :]
+ x_val = tl.load(X_PTR + x_idx, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ y_idx = off_m[:, None] * stride_y_m + off_k[None, :] * stride_y_k
+ tl.store(Y_PTR + y_idx, x_val, mask=m_mask[:, None] & k_mask[None, :])
+
+
+# ---------------------------------------------------------------------------
+# Phase 2: GEMM kernel (standard Triton GEMM)
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _gemm_kernel(
+ A_PTR, B_PTR, C_PTR,
M_total, N_val, K_val,
- stride_wk, stride_wn,
- stride_ym, stride_yn,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- """Fused patch gather + GEMM. Each program handles BLOCK_M x BLOCK_N output."""
+ """Standard blocked GEMM."""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
-
+
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
+
m_mask = off_m < M_total
n_mask = off_n < N_val
-
- w_out_arr = off_m % W_out_val
- h_out_arr = ((off_m // W_out_val) % H_out_val)
- t_out_arr = ((off_m // (W_out_val * H_out_val)) % T_out_val)
- b_arr = off_m // (W_out_val * H_out_val * T_out_val)
-
- t_start = t_out_arr * kT_val
- h_start = h_out_arr * kH_val
- w_start = w_out_arr * kW_val
-
- stride_c = T_val * H_val * W_val
- stride_t = H_val * W_val
- stride_h = W_val
-
- base_x = b_arr * (C_val * T_val * H_val * W_val) + t_start * stride_t + h_start * stride_h + w_start
-
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_base in range(0, K_val, BLOCK_K):
- offs_k = k_base + tl.arange(0, BLOCK_K)
- k_mask_l = offs_k < K_val
-
- k_per_ch = kT_val * kH_val * kW_val
- c_idx = offs_k // k_per_ch
- rem = offs_k % k_per_ch
- kHkW = kH_val * kW_val
- kt_idx = rem // kHkW
- rem2 = rem % kHkW
- kh_idx = rem2 // kW_val
- kw_idx = rem2 % kW_val
-
- off_k = c_idx * stride_c + kt_idx * stride_t + kh_idx * stride_h + kw_idx
-
- x_val = tl.load(X_PTR + base_x[:, None] + off_k[None, :],
- mask=m_mask[:, None] & k_mask_l[None, :], other=0.0).to(tl.bfloat16)
- w_val = tl.load(W_PTR + offs_k[:, None] * stride_wk + off_n[None, :] * stride_wn,
- mask=k_mask_l[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
-
- acc = tl.dot(x_val, w_val, acc)
-
- tl.store(Y_PTR + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn,
- acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+ for k_base in range(K_val, BLOCK_K) - BLOCK_K, BLOCK_K):
+ k = k_base + tl.arange(0, BLOCK_K)
+ k_mask_t = k < K_val
+
+ a_idx = off_m[:, None] * stride_am + k[None, :] * stride_ak
+ b_idx = k[:, None] * stride_bk + off_n[None, :] * stride_bn
+
+ a = tl.load(A_PTR + a_idx, mask=m_mask[:, None] & k_mask_t[None, :], other=0.0).to(tl.bfloat16)
+ b = tl.load(B_PTR + b_idx, mask=k_mask_t[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
+
+ acc = tl.dot(a, b, acc)
+
+ c_idx = off_m[:, None] * stride_cm + off_n[None, :] * stride_cn
+ tl.store(C_PTR + c_idx, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
# ------------------------------------------------------------
... (2187 more chars){
"content": "<6342 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c \"import solution\" 2>&1",
"description": "Check syntax"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py", line 103
for k_base in range(K_val, BLOCK_K) - BLOCK_K, BLOCK_K):
^
SyntaxError: unmatched ')'
{
"command": "rm -f tune_blocks.py",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"description": "Remove broken tuner"
}(no output)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -100,7 +100,7 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_base in range(K_val, BLOCK_K) - BLOCK_K, BLOCK_K):
+ for k_base in range(0, (K_val // BLOCK_K + 1) * BLOCK_K, BLOCK_K):
k = k_base + tl.arange(0, BLOCK_K)
k_mask_t = k < K_val
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": " for k_base in range(K_val, BLOCK_K) - BLOCK_K, BLOCK_K):",
"newString": " for k_base in range(0, (K_val // BLOCK_K + 1) * BLOCK_K, BLOCK_K):"
}Edit applied successfully.{
"command": "uv run python benchmark.py 2>&1 && uv run python check.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"description": "Run benchmark and check"
}shape=0 variant=eager tflops=3.606 gbps=19.969 ms=0.214
shape=0 variant=compiled tflops=3.007 gbps=16.653 ms=0.256
shape=0 variant=sota tflops=3.650 gbps=20.213 ms=0.211
shape=0 variant=solution tflops=5.694 gbps=31.531 ms=0.135
shape=0 solution_peak_fraction=0.0285
shape=1 variant=eager tflops=13.643 gbps=35.582 ms=0.226
shape=1 variant=compiled tflops=11.648 gbps=30.379 ms=0.265
shape=1 variant=sota tflops=13.878 gbps=36.195 ms=0.222
shape=1 variant=solution tflops=20.861 gbps=54.410 ms=0.148
shape=1 solution_peak_fraction=0.1043
shape=2 variant=eager tflops=25.873 gbps=53.443 ms=0.268
shape=2 variant=compiled tflops=22.003 gbps=45.449 ms=0.315
shape=2 variant=sota tflops=25.447 gbps=52.564 ms=0.273
shape=2 variant=solution tflops=44.165 gbps=91.227 ms=0.157
shape=2 solution_peak_fraction=0.2208
shape=3 variant=eager tflops=6.218 gbps=24.124 ms=0.149
shape=3 variant=compiled tflops=4.958 gbps=19.235 ms=0.187
shape=3 variant=sota tflops=6.438 gbps=24.976 ms=0.144
shape=3 variant=solution tflops=14.556 gbps=56.473 ms=0.064
shape=3 solution_peak_fraction=0.0728
peak_fraction: 0.0831
RESULT: LOW
FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 42: tolerance exceeded: max_abs_diff=2.15625 (atol=0.02, rtol=0.02)
{
"command": "<1511 chars \u2014 see below>",
"timeout": 60000,
"description": "Test Triton bf16 GEMM on SM120"
}python -c "
import torch, triton, triton.language as tl
# Check if Triton can handle bf16 dot with fp32 acc on SM120
@triton.jit
def test_kernel(A, B, C, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_mask = off_m < M
n_mask = off_n < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_base in range(0, (K // BLOCK_K) * BLOCK_K, BLOCK_K):
k = k_base + tl.arange(0, BLOCK_K)
k_mask = k < K
a = tl.load(A + off_m[:, None] * K + k[None, :], mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.bfloat16)
b = tl.load(B + k[:, None] * N + off_n[None, :], mask=k_mask[:, None] & n_mask[None, :], other=0.0).to(tl.bfloat16)
acc = tl.dot(a, b, acc)
tl.store(C + off_m[:, None] * N + off_n[None, :], acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
M, N, K = 256, 1280, 1176
a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
c = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
grid = (triton.cdiv(M, 64), triton.cdiv(N, 128))
test_kernel[grid](a, b, c, M, N, K, BLOCK_M=64, BLOCK_N=128, BLOCK_K=64)
expected = a @ b
diff = (expected.float() - c.float()).abs().max().item()
print(f'Max diff: {diff}')
print('OK' if diff < 0.02 else f'FAIL: diff={diff}')
" 2>&1
Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 5, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file