"""Fused patch-gather + tensor-core GEMM for 3D patch embedding.
Two-phase approach:
1) im2col via torch.as_strided (zero-copy view then contiguous reshape)
2) GEMM via Triton tl.dot kernel with L2-cache-tiling + autotune
Also includes a fused implicit-GEMM kernel used as fallback for small problems
where the im2col materialization overhead dominates.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
class _Conv3dWeight(nn.Module):
    """Holder for a single bf16 projection weight laid out like a 3D conv kernel.

    The module registers exactly one parameter, ``weight``, with shape
    ``(embed_dim, C, kT, kH, kW)``.  The tensor is allocated with
    ``torch.empty`` and no initialisation scheme is applied, so its contents
    are arbitrary until something overwrites them (presumably a state-dict
    load by the caller -- confirm).
    """

    def __init__(self, C: int, embed_dim: int, kT: int, kH: int, kW: int):
        super().__init__()
        weight_shape = (embed_dim, C, kT, kH, kW)
        # Allocation only; values are whatever happens to be in memory.
        uninitialised = torch.empty(weight_shape, dtype=torch.bfloat16)
        self.weight = nn.Parameter(uninitialised)
# ---------- Fused implicit-GEMM kernel ----------
# NOTE(review): the autotune key below does not include ``total_patches``,
# so a configuration tuned for one patch count is reused for every other
# patch count that shares the remaining key values.  ``gemm_kernel`` keys on
# its M dimension; confirm whether the omission here is intentional.
@triton.autotune(
    configs=[
        triton.Config({'BM': 64, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 64, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 32, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 32, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 64, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 128, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
    ],
    key=['C', 'kT', 'kH', 'kW', 'embed_dim', 'K'],
)
@triton.jit
def fused_patch_embed_kernel(
    x_ptr, w_ptr, out_ptr,
    C: tl.constexpr, kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
    num_t, num_h, num_w,
    total_patches, embed_dim, K,
    stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
    stride_ob, stride_on, stride_ot, stride_oh, stride_ow,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    """Implicit-GEMM patch embedding: gather patches inside the K loop.

    Each program produces one (BM, BN) output tile.  Row ``m`` of the GEMM is
    a patch, decoded as ``(b, pt, ph, pw)``; reduction index ``k`` is an
    element inside a patch, decoded as ``(c, kt, kh, kw)``.

    Args:
        x_ptr: input, addressed through the five ``stride_x*`` strides
            (batch, channel, t, h, w).
        w_ptr: weight, addressed as ``w_ptr + n * K + k``, i.e. rows of
            exactly ``K`` consecutive elements.
        out_ptr: output, addressed through the five ``stride_o*`` strides
            (batch, embed, patch-t, patch-h, patch-w).
        C, kT, kH, kW: channel count and patch extents (compile-time).
        num_t, num_h, num_w: number of patches along t, h and w.
        total_patches: number of GEMM rows.
        embed_dim: number of GEMM columns.
        K: reduction length.
        BM, BN, BK: tile sizes chosen by the autotuner.
    """
    pid = tl.program_id(0)
    num_n_tiles = tl.cdiv(embed_dim, BN)
    num_m_tiles = tl.cdiv(total_patches, BM)

    # Grouped tile order: consecutive program ids cover up to two M-tiles
    # for each N-tile before moving on to the next pair of M-tiles.
    num_pid_in_group = 2 * num_n_tiles
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * 2
    # The last group may hold a single M-tile.
    group_size = min(num_m_tiles - first_pid_m, 2)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size)
    pid_n = (pid % num_pid_in_group) // group_size

    m_off = pid_m * BM + tl.arange(0, BM)
    n_off = pid_n * BN + tl.arange(0, BN)
    m_mask = m_off < total_patches
    n_mask = n_off < embed_dim

    # Decode the flat patch index into (batch, patch-t, patch-h, patch-w).
    num_spatial = num_t * num_h * num_w
    b_idx = m_off // num_spatial
    rem = m_off % num_spatial
    pt_idx = rem // (num_h * num_w)
    rem2 = rem % (num_h * num_w)
    ph_idx = rem2 // num_w
    pw_idx = rem2 % num_w

    # Offset of the first element of each patch (channel 0, kt = kh = kw = 0).
    base = (b_idx * stride_xb
            + pt_idx * kT * stride_xt
            + ph_idx * kH * stride_xh
            + pw_idx * kW * stride_xw)

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    patch_vol: tl.constexpr = kT * kH * kW
    kh_kw: tl.constexpr = kH * kW

    for k_start in range(0, K, BK):
        k_off = k_start + tl.arange(0, BK)
        k_mask = k_off < K

        # Decode the reduction index into (channel, kt, kh, kw).
        c_idx = k_off // patch_vol
        rem_k = k_off % patch_vol
        kt_idx = rem_k // kh_kw
        rem_k2 = rem_k % kh_kw
        kh_idx = rem_k2 // kW
        kw_idx = rem_k2 % kW

        # Offset of each reduction element relative to its patch origin.
        delta = (c_idx * stride_xc
                 + kt_idx * stride_xt
                 + kh_idx * stride_xh
                 + kw_idx * stride_xw)

        x_off = base[:, None] + delta[None, :]
        x_block = tl.load(x_ptr + x_off, mask=m_mask[:, None] & k_mask[None, :], other=0.0)

        # Weight tile is read as (BK, BN): element (k, n) lives at n * K + k.
        w_off = n_off[None, :] * K + k_off[:, None]
        w_block = tl.load(w_ptr + w_off, mask=k_mask[:, None] & n_mask[None, :], other=0.0)

        acc += tl.dot(x_block, w_block)

    # fp32 accumulator is narrowed to bf16 for the store.
    out_vals = acc.to(tl.bfloat16)

    out_off = (b_idx[:, None] * stride_ob
               + n_off[None, :] * stride_on
               + pt_idx[:, None] * stride_ot
               + ph_idx[:, None] * stride_oh
               + pw_idx[:, None] * stride_ow)
    tl.store(out_ptr + out_off, out_vals, mask=m_mask[:, None] & n_mask[None, :])
# ---------- GEMM kernel (im2col + matmul) ----------
@triton.autotune(
    configs=[
        triton.Config({'BM': bm, 'BN': bn, 'BK': bk}, num_warps=warps, num_stages=stages)
        # (BM, BN, BK, num_warps, num_stages) -- order is significant only
        # in that it is the order the autotuner receives the candidates.
        for bm, bn, bk, warps, stages in (
            (128, 128, 32, 8, 3),
            (128, 128, 64, 8, 3),
            (128, 64, 32, 4, 3),
            (128, 64, 64, 4, 3),
            (64, 128, 32, 4, 3),
            (64, 128, 64, 4, 3),
            (64, 64, 32, 4, 3),
            (64, 64, 64, 4, 3),
            (32, 128, 32, 4, 3),
            (32, 64, 32, 4, 3),
            (32, 64, 64, 4, 3),
            (16, 128, 32, 4, 3),
            (16, 64, 32, 2, 3),
            (128, 128, 32, 8, 4),
            (128, 128, 64, 8, 4),
            (64, 64, 128, 4, 2),
            (64, 128, 128, 4, 2),
            (32, 128, 64, 4, 3),
            (64, 128, 32, 8, 3),
            (64, 128, 64, 8, 3),
        )
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    """Tiled GEMM: one (BM, BN) tile of C per program, reduced over K.

    Element ``(m, k)`` of A is read at ``a_ptr + m * stride_am + k * stride_ak``
    and element ``(n, k)`` of B at ``b_ptr + n * stride_bn + k * stride_bk``;
    the product tile is accumulated in fp32 and stored as bf16 at
    ``c_ptr + m * stride_cm + n * stride_cn``.  Out-of-range rows, columns
    and reduction indices are masked.
    """
    tile_id = tl.program_id(0)
    tiles_m = tl.cdiv(M, BM)
    tiles_n = tl.cdiv(N, BN)

    # Grouped tile order: each group spans up to two M-tiles across all
    # N-tiles; the final group may contain only one M-tile.
    tiles_per_group = 2 * tiles_n
    group = tile_id // tiles_per_group
    within_group = tile_id % tiles_per_group
    group_first_m = group * 2
    rows_in_group = min(tiles_m - group_first_m, 2)
    pid_m = group_first_m + (within_group % rows_in_group)
    pid_n = within_group // rows_in_group

    rows = pid_m * BM + tl.arange(0, BM)
    cols = pid_n * BN + tl.arange(0, BN)
    row_ok = rows < M
    col_ok = cols < N

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k_start in range(0, K, BK):
        ks = k_start + tl.arange(0, BK)
        k_ok = ks < K

        # A tile is (BM, BK).
        a_idx = rows[:, None] * stride_am + ks[None, :] * stride_ak
        a_tile = tl.load(a_ptr + a_idx, mask=row_ok[:, None] & k_ok[None, :], other=0.0)

        # B tile is read directly in (BK, BN) orientation.
        b_idx = ks[:, None] * stride_bk + cols[None, :] * stride_bn
        b_tile = tl.load(b_ptr + b_idx, mask=k_ok[:, None] & col_ok[None, :], other=0.0)

        acc += tl.dot(a_tile, b_tile)

    c_idx = rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(c_ptr + c_idx, acc.to(tl.bfloat16), mask=row_ok[:, None] & col_ok[None, :])
# ---------- Model ----------
class Model(nn.Module):
    """Non-overlapping 3D patch embedding built on the Triton kernels above.

    An input of shape ``(B, C, T, H, W)`` is cut into ``(kT, kH, kW)`` patches
    and every patch is projected to ``embed_dim`` values using
    ``self.conv.weight``.  ``forward`` picks one of two implementations based
    on the GEMM size ``M * N * K``:

    * below ``_FUSED_MAX_MNK``: ``fused_patch_embed_kernel`` gathers patches
      on the fly and writes straight into a ``(B, embed_dim, T // kT,
      H // kH, W // kW)`` tensor;
    * otherwise: patches are materialised as an ``(M, K)`` matrix and
      multiplied by ``gemm_kernel``; the result is returned as a permuted
      view with the same logical shape.
    """

    # Dispatch bound on M * N * K: strictly smaller problems use the fused
    # kernel, everything else uses the im2col + GEMM path.
    _FUSED_MAX_MNK = 800_000_000

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        """Store the problem configuration and allocate the projection weight.

        ``B``, ``T``, ``H`` and ``W`` are only recorded; ``forward`` takes
        the actual sizes from its input tensor.
        """
        super().__init__()
        self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
        self.kT, self.kH, self.kW = kT, kH, kW
        self.embed_dim = embed_dim
        self.conv = _Conv3dWeight(C, embed_dim, kT, kH, kW)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every patch of ``x``.

        Args:
            x: 5-D tensor ``(B, C, T, H, W)``.  Trailing elements that do not
                fill a whole patch are ignored because the patch counts use
                floor division.

        Returns:
            bf16 tensor of logical shape
            ``(B, embed_dim, T // kT, H // kH, W // kW)``.
        """
        B, C, T, H, W = x.shape
        kT, kH, kW = self.kT, self.kH, self.kW
        num_t, num_h, num_w = T // kT, H // kH, W // kW

        M = B * num_t * num_h * num_w      # GEMM rows: one per patch
        N = self.embed_dim                 # GEMM columns
        K = C * kT * kH * kW               # reduction: elements per patch

        # Both paths consume the weight as an (N, K) matrix.
        weight = self.conv.weight.reshape(N, K)

        if M * N * K < self._FUSED_MAX_MNK:
            return self._forward_fused(x, weight, num_t, num_h, num_w)
        return self._forward_im2col(x, weight, num_t, num_h, num_w)

    def _forward_fused(self, x: torch.Tensor, weight: torch.Tensor,
                       num_t: int, num_h: int, num_w: int) -> torch.Tensor:
        """Run the implicit-GEMM kernel, writing directly into the 5-D output."""
        B, C = x.shape[0], x.shape[1]
        N, K = weight.shape
        total_patches = B * num_t * num_h * num_w

        out = torch.empty(B, N, num_t, num_h, num_w,
                          dtype=torch.bfloat16, device=x.device)

        def grid(meta):
            # 1-D launch: one program per (M-tile, N-tile) pair.
            return (triton.cdiv(total_patches, meta['BM'])
                    * triton.cdiv(N, meta['BN']),)

        fused_patch_embed_kernel[grid](
            x, weight, out,
            C, self.kT, self.kH, self.kW,
            num_t, num_h, num_w,
            total_patches, N, K,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
        )
        return out

    def _forward_im2col(self, x: torch.Tensor, weight: torch.Tensor,
                        num_t: int, num_h: int, num_w: int) -> torch.Tensor:
        """Materialise patches as an (M, K) matrix and multiply with gemm_kernel."""
        B, C = x.shape[0], x.shape[1]
        kT, kH, kW = self.kT, self.kH, self.kW
        N, K = weight.shape
        M = B * num_t * num_h * num_w

        s_b, s_c, s_t, s_h, s_w = x.stride()
        # 8-D view (b, pt, ph, pw, c, kt, kh, kw): patch axes step by a whole
        # patch, intra-patch axes step by a single element.
        patches = torch.as_strided(
            x,
            size=(B, num_t, num_h, num_w, C, kT, kH, kW),
            stride=(s_b, kT * s_t, kH * s_h, kW * s_w, s_c, s_t, s_h, s_w),
        )
        im2col = patches.reshape(M, K)

        out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

        def grid(meta):
            # 1-D launch: one program per (M-tile, N-tile) pair.
            return (triton.cdiv(M, meta['BM']) * triton.cdiv(N, meta['BN']),)

        gemm_kernel[grid](
            im2col, weight, out_2d,
            M, N, K,
            im2col.stride(0), im2col.stride(1),
            weight.stride(0), weight.stride(1),
            out_2d.stride(0), out_2d.stride(1),
        )
        # Returned as a permuted view of the (M, N) result; no extra copy.
        return out_2d.reshape(B, num_t, num_h, num_w, N).permute(0, 4, 1, 2, 3)
# Default problem size used by get_inputs() / get_init_inputs().
# Input extents: batch, channels, frames, height, width.
B, C, T, H, W = 1, 3, 2, 224, 224
# Patch extents along frames, height, width.
kT, kH, kW = 2, 14, 14
# Projection width.
embed_dim = 1280
def get_inputs():
    """Return ``[x]`` where ``x`` is a random bf16 tensor of shape (B, C, T, H, W).

    Values are drawn with ``torch.randn`` and scaled by 0.5.
    """
    video_shape = (B, C, T, H, W)
    sample = torch.randn(video_shape, dtype=torch.bfloat16)
    return [sample * 0.5]
def get_init_inputs():
    """Return the module-level defaults in the positional order of ``Model.__init__``."""
    input_extents = [B, C, T, H, W]
    patch_extents = [kT, kH, kW]
    return input_extents + patch_extents + [embed_dim]
shape=0 variant=eager tflops=3.276 gbps=18.140 ms=0.235
shape=0 variant=compiled tflops=2.859 gbps=15.831 ms=0.270
shape=0 variant=sota tflops=3.304 gbps=18.300 ms=0.233
shape=0 variant=solution tflops=11.629 gbps=64.402 ms=0.066
shape=0 solution_peak_fraction=0.0581
shape=1 variant=eager tflops=12.396 gbps=32.331 ms=0.249
shape=1 variant=compiled tflops=10.711 gbps=27.935 ms=0.288
shape=1 variant=sota tflops=12.569 gbps=32.783 ms=0.245
shape=1 variant=solution tflops=52.315 gbps=136.445 ms=0.059
shape=1 solution_peak_fraction=0.2616
shape=2 variant=eager tflops=24.262 gbps=50.117 ms=0.286
shape=2 variant=compiled tflops=21.544 gbps=44.501 ms=0.322
shape=2 variant=sota tflops=24.670 gbps=50.958 ms=0.281
shape=2 variant=solution tflops=93.451 gbps=193.035 ms=0.074
shape=2 solution_peak_fraction=0.4673
shape=3 variant=eager tflops=5.792 gbps=22.471 ms=0.160
shape=3 variant=compiled tflops=4.691 gbps=18.201 ms=0.197
shape=3 variant=sota tflops=5.854 gbps=22.712 ms=0.158
shape=3 variant=solution tflops=25.909 gbps=100.518 ms=0.036
shape=3 solution_peak_fraction=0.1295
peak_fraction: 0.1742
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1742,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 376094, "output_tokens": 10962, "cache_read_tokens": 1464896, "cache_creation_tokens": 0, "reasoning_tokens": 113911, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand the problem",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding (Triton or CUDA C++)",
"status": "pending",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize for RTX PRO 6000 (SM120)",
"status": "pending",
"priority": "high"
},
{
... (114 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand the problem'
status: 'in_progress'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding (Triton or CUDA C++)'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize for RTX PRO 6000 (SM120)'
status: 'pending'
priority: 'high'
content: 'Benchmark all shapes and iterate'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py 2>/dev/null; echo \"---\"; cat /home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py 2>/dev/null | head -60",
"description": "Check if solution.py exists and read correctness module"
}---
"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
torch.int64: {"atol": 0, "rtol": 0},
}
def _coerce_float(x):
"""Accept int/float/str scientific-notation as float; pass through dicts."""
if isinstance(x, (int, float)):
return float(x)
if isinstance(x, str):
try:
return float(x)
except ValueError:
return x
return x
def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
"""Lookup atol/rtol for a given dtype, with optional per-problem override.
The override dict is keyed by str(dtype). Values may be:
- a single number (or a string parseable as float, e.g. "5e-2") -> used
as both atol and rtol
- a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
"""
# Accept several spellings of the same dtype key (PyYAML quirks):
# "torch.bfloat16", "bfloat16", torch.bfloat16
if override is not None:
t
... (688 more chars){
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null | head -3",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>/dev/null",
"description": "Check Triton version",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__, torch.cuda.get_device_capability())\"",
"description": "Check PyTorch version and CUDA cap",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}2.11.0+cu130 (12, 0)
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, shapes.py to understand the problem",
"priority": "high",
"status": "completed"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding (Triton with tl.dot)",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Profile and optimize for RTX PRO 6000 (SM120)",
"priority": "high",
"status": "pending"
},
{
... (116 more chars)content: 'Read reference.py, check.py, benchmark.py, shapes.py to understand the problem'
status: 'completed'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding (Triton with tl.dot)'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize for RTX PRO 6000 (SM120)'
status: 'pending'
priority: 'high'
content: 'Benchmark all shapes and iterate'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,159 @@
+"""Fused patch-gather + tensor-core GEMM for 3D patch embedding.
+
+Implicit GEMM: the strided patch gather from the input tensor is fused
+into the K-loop of the GEMM, avoiding materialization of an im2col buffer.
+Uses Triton tl.dot for Blackwell SM120 tensor core bf16 MMAs.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+class _Conv3dWeight(nn.Module):
+ def __init__(self, C: int, embed_dim: int, kT: int, kH: int, kW: int):
+ super().__init__()
+ self.weight = nn.Parameter(
+ torch.empty(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16)
+ )
+
+
+@triton.jit
+def patch_embed_gemm_kernel(
+ x_ptr, w_ptr, out_ptr,
+ C: tl.constexpr, kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
+ num_t, num_h, num_w,
+ total_patches,
+ embed_dim,
+ K,
+ stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
+ stride_ob, stride_on, stride_ot, stride_oh, stride_ow,
+ BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ m_off = pid_m * BM + tl.arange(0, BM)
+ n_off = pid_n * BN + tl.arange(0, BN)
+
+ m_mask = m_off < total_patches
+ n_mask = n_off < embed_dim
+
+ num_spatial = num_t * num_h * num_w
+ b_idx = m_off // num_spatial
+ rem = m_off % num_spatial
+ pt_idx = rem // (num_h * num_w)
+ rem2 = rem % (num_h * num_w)
+ ph_idx = rem2 // num_w
+ pw_idx = rem2 % num_w
+
+ base = (b_idx * stride_xb
+ + pt_idx * kT * stride_xt
+ + ph_idx * kH * stride_xh
+ + pw_idx * kW * stride_xw)
+
+ acc = tl.zeros((BM, BN), dtype=tl.float32)
+ patch_vol: tl.constexpr = kT * kH * kW
+ kh_kw: tl.constexpr = kH * kW
+
+ for k_start in range(0, K, BK):
+ k_off = k_start + tl.arange(0, BK)
+ k_mask = k_off < K
+
+ c_idx = k_off // patch_vol
+ rem_k = k_off % patch_vol
+ kt_idx = rem_k // kh_kw
+ rem_k2 = rem_k % kh_kw
+ kh_idx = rem_k2 // kW
+ kw_idx = rem_k2 % kW
+
+ delta = (c_idx * stride_xc
+ + kt_idx * stride_xt
+ + kh_idx * stride_xh
+ + kw_idx * stride_xw)
+
+ x_off = base[:, None] + delta[None, :]
+ x_block = tl.load(x_ptr + x_off, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ w_off = n_off[None, :] * K + k_off[:, None]
+ w_block = tl.load(w_ptr + w_off, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+ acc += tl.dot(x_block, w_block)
+
+ out_vals = acc.to(tl.bfloat16)
+
+ out_off = (b_idx[:, None] * stride_ob
+ + n_off[None, :] * stride_on
+ + pt_idx[:, None] * stride_ot
+ + ph_idx[:, None] * stride_oh
+ + pw_idx[:, None] * stride_ow)
+
+ tl.store(out_ptr + out_off, out_vals, mask=m_mask[:, None] & n_mask[None, :])
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+ self.conv = _Conv3dWeight(C, embed_dim, kT, kH, kW)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T, H, W = x.shape
+ kT, kH, kW = self.kT, self.kH, self.kW
+ embed_dim = self.embed_dim
+
+ num_t = T // kT
+ num_h = H // kH
+ num_w = W // kW
+ total_patches = B * num_t * num_h * num_w
+ K = C * kT * kH * kW
+
+ weight = self.conv.weight.reshape(embed_dim, K)
+
+ out = torch.empty(B, embed_dim, num_t, num_h, num_w,
+ dtype=torch.bfloat16, device=x.device)
+
+ BM: tl.constexpr = 64
+ BN: tl.constexpr = 128
+ BK: tl.constexpr = 32
+
+ grid = (triton.cdiv(total_patches, BM), triton.cdiv(embed_dim, BN))
+
+ patch_embed_gemm_kernel[grid](
+ x, weight, out,
+ C, kT, kH, kW,
+ num_t, num_h, num_w,
+ total_patches,
+ embed_dim,
+ K,
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
+ BM=BM, BN=BN, BK=BK,
+ num_warps=4,
+ num_stages=3,
+ )
+
+ return out
+
+
+B = 1
+C = 3
+T = 2
+H = 224
+W = 224
+kT = 2
+kH = 14
+kW = 14
+embed_dim = 1280
+
+
+def get_inputs():
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
+ return [x]
+
+
+def get_init_inputs():
+ return [B, C, T, H, W, kT, kH, kW, embed_dim]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<4603 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"timeout": 180000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.280 gbps=18.164 ms=0.235
shape=0 variant=compiled tflops=2.866 gbps=15.871 ms=0.269
shape=0 variant=sota tflops=3.308 gbps=18.317 ms=0.233
shape=0 variant=solution tflops=9.786 gbps=54.196 ms=0.079
shape=0 solution_peak_fraction=0.0489
shape=1 variant=eager tflops=12.420 gbps=32.394 ms=0.248
shape=1 variant=compiled tflops=10.760 gbps=28.065 ms=0.286
shape=1 variant=sota tflops=12.537 gbps=32.698 ms=0.246
shape=1 variant=solution tflops=39.011 gbps=101.747 ms=0.079
shape=1 solution_peak_fraction=0.1951
shape=2 variant=eager tflops=24.290 gbps=50.173 ms=0.286
shape=2 variant=compiled tflops=21.686 gbps=44.795 ms=0.320
shape=2 variant=sota tflops=24.604 gbps=50.822 ms=0.282
shape=2 variant=solution tflops=71.882 gbps=148.481 ms=0.096
shape=2 solution_peak_fraction=0.3594
shape=3 variant=eager tflops=5.766 gbps=22.370 ms=0.160
shape=3 variant=compiled tflops=4.679 gbps=18.153 ms=0.198
shape=3 variant=sota tflops=5.869 gbps=22.772 ms=0.158
shape=3 variant=solution tflops=31.988 gbps=124.104 ms=0.029
shape=3 solution_peak_fraction=0.1599
peak_fraction: 0.1530
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,8 +1,11 @@
"""Fused patch-gather + tensor-core GEMM for 3D patch embedding.
-Implicit GEMM: the strided patch gather from the input tensor is fused
-into the K-loop of the GEMM, avoiding materialization of an im2col buffer.
-Uses Triton tl.dot for Blackwell SM120 tensor core bf16 MMAs.
+Two-phase approach:
+ 1) im2col via torch.as_strided (zero-copy view then contiguous reshape)
+ 2) GEMM via Triton tl.dot kernel with L2-cache-tiling
+
+The GEMM kernel is a standard tiled matmul operating on contiguous (M,K) and (N,K)
+inputs, achieving high tensor-core utilization on SM120 Blackwell.
"""
import torch
import torch.nn as nn
@@ -19,76 +22,46 @@
@triton.jit
-def patch_embed_gemm_kernel(
- x_ptr, w_ptr, out_ptr,
- C: tl.constexpr, kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
- num_t, num_h, num_w,
- total_patches,
- embed_dim,
- K,
- stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
- stride_ob, stride_on, stride_ot, stride_oh, stride_ow,
+def gemm_kernel(
+ a_ptr, b_ptr, c_ptr,
+ M, N, K,
+ stride_am, stride_ak,
+ stride_bn, stride_bk,
+ stride_cm, stride_cn,
BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
- pid_m = tl.program_id(0)
- pid_n = tl.program_id(1)
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BM)
+ num_pid_n = tl.cdiv(N, BN)
+
+ num_pid_in_group = 2 * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * 2
+ group_size = min(num_pid_m - first_pid_m, 2)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size)
+ pid_n = (pid % num_pid_in_group) // group_size
m_off = pid_m * BM + tl.arange(0, BM)
n_off = pid_n * BN + tl.arange(0, BN)
-
- m_mask = m_off < total_patches
- n_mask = n_off < embed_dim
-
- num_spatial = num_t * num_h * num_w
- b_idx = m_off // num_spatial
- rem = m_off % num_spatial
- pt_idx = rem // (num_h * num_w)
- rem2 = rem % (num_h * num_w)
- ph_idx = rem2 // num_w
- pw_idx = rem2 % num_w
-
- base = (b_idx * stride_xb
- + pt_idx * kT * stride_xt
- + ph_idx * kH * stride_xh
- + pw_idx * kW * stride_xw)
+ m_mask = m_off < M
+ n_mask = n_off < N
acc = tl.zeros((BM, BN), dtype=tl.float32)
- patch_vol: tl.constexpr = kT * kH * kW
- kh_kw: tl.constexpr = kH * kW
for k_start in range(0, K, BK):
k_off = k_start + tl.arange(0, BK)
k_mask = k_off < K
- c_idx = k_off // patch_vol
- rem_k = k_off % patch_vol
- kt_idx = rem_k // kh_kw
- rem_k2 = rem_k % kh_kw
- kh_idx = rem_k2 // kW
- kw_idx = rem_k2 % kW
+ a_off = m_off[:, None] * stride_am + k_off[None, :] * stride_ak
+ a_block = tl.load(a_ptr + a_off, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
- delta = (c_idx * stride_xc
- + kt_idx * stride_xt
- + kh_idx * stride_xh
- + kw_idx * stride_xw)
+ b_off = k_off[:, None] * stride_bk + n_off[None, :] * stride_bn
+ b_block = tl.load(b_ptr + b_off, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
- x_off = base[:, None] + delta[None, :]
- x_block = tl.load(x_ptr + x_off, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+ acc += tl.dot(a_block, b_block)
- w_off = n_off[None, :] * K + k_off[:, None]
- w_block = tl.load(w_ptr + w_off, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
-
- acc += tl.dot(x_block, w_block)
-
- out_vals = acc.to(tl.bfloat16)
-
- out_off = (b_idx[:, None] * stride_ob
- + n_off[None, :] * stride_on
- + pt_idx[:, None] * stride_ot
- + ph_idx[:, None] * stride_oh
- + pw_idx[:, None] * stride_ow)
-
- tl.store(out_ptr + out_off, out_vals, mask=m_mask[:, None] & n_mask[None, :])
+ c_off = m_off[:, None] * stride_cm + n_off[None, :] * stride_cn
+ tl.store(c_ptr + c_off, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
class Model(nn.Module):
@@ -108,34 +81,40 @@
num_t = T // kT
num_h = H // kH
num_w = W // kW
- total_patches = B * num_t * num_h * num_w
+ M = B * num_t * num_h * num_w
K = C * kT * kH * kW
+ N = embed_dim
- weight = self.conv.weight.reshape(embed_dim, K)
+ patches = torch.as_strided(
+ x,
+ size=(B, num_t, num_h, num_w, C, kT, kH, kW),
+ stride=(x.stride(0), kT * x.stride(2), kH * x.stride(3),
+ kW * x.stride(4), x.stride(1), x.stride(2),
+ x.stride(3), x.stride(4)),
+ )
+ im2col = patches.reshape(M, K)
- out = torch.empty(B, embed_dim, num_t, num_h, num_w,
- dtype=torch.bfloat16, device=x.device)
+ weight = self.conv.weight.reshape(N, K)
+ out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- BM: tl.constexpr = 64
+ BM: tl.constexpr = 128
BN: tl.constexpr = 128
- BK: tl.constexpr = 32
+ BK: tl.constexpr = 64
- grid = (triton.cdiv(total_patches, BM), triton.cdiv(embed_dim, BN))
+ grid = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
- patch_embed_gemm_kernel[grid](
- x, weight, out,
- C, kT, kH, kW,
- num_t, num_h, num_w,
- total_patches,
- embed_dim,
- K,
- x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
- out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
+ gemm_kernel[grid](
+ im2col, weight, out_2d,
+ M, N, K,
+ im2col.stride(0), im2col.stride(1),
+ weight.stride(0), weight.stride
... (284 more chars){
"content": "<4089 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 180000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}Uninstalled 1 package in 0.17ms
Installed 1 package in 1ms
PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.281 gbps=18.171 ms=0.235
shape=0 variant=compiled tflops=2.857 gbps=15.823 ms=0.270
shape=0 variant=sota tflops=3.310 gbps=18.331 ms=0.233
shape=0 variant=solution tflops=11.815 gbps=65.429 ms=0.065
shape=0 solution_peak_fraction=0.0591
shape=1 variant=eager tflops=12.469 gbps=32.520 ms=0.247
shape=1 variant=compiled tflops=10.697 gbps=27.900 ms=0.288
shape=1 variant=sota tflops=12.592 gbps=32.843 ms=0.245
shape=1 variant=solution tflops=47.563 gbps=124.050 ms=0.065
shape=1 solution_peak_fraction=0.2378
shape=2 variant=eager tflops=24.258 gbps=50.108 ms=0.286
shape=2 variant=compiled tflops=21.492 gbps=44.395 ms=0.323
shape=2 variant=sota tflops=24.678 gbps=50.976 ms=0.281
shape=2 variant=solution tflops=99.158 gbps=204.823 ms=0.070
shape=2 solution_peak_fraction=0.4958
shape=3 variant=eager tflops=5.790 gbps=22.464 ms=0.160
shape=3 variant=compiled tflops=4.585 gbps=17.787 ms=0.202
shape=3 variant=sota tflops=5.818 gbps=22.572 ms=0.159
shape=3 variant=solution tflops=23.421 gbps=90.865 ms=0.039
shape=3 solution_peak_fraction=0.1171
peak_fraction: 0.1690
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,7 +2,7 @@
Two-phase approach:
1) im2col via torch.as_strided (zero-copy view then contiguous reshape)
- 2) GEMM via Triton tl.dot kernel with L2-cache-tiling
+ 2) GEMM via Triton tl.dot kernel with L2-cache-tiling + autotune
The GEMM kernel is a standard tiled matmul operating on contiguous (M,K) and (N,K)
inputs, achieving high tensor-core utilization on SM120 Blackwell.
@@ -21,6 +21,22 @@
)
+@triton.autotune(
+ configs=[
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 32, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=4),
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=4),
+ ],
+ key=['M', 'N', 'K'],
+)
@triton.jit
def gemm_kernel(
a_ptr, b_ptr, c_ptr,
@@ -97,11 +113,7 @@
weight = self.conv.weight.reshape(N, K)
out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- BM: tl.constexpr = 128
- BN: tl.constexpr = 128
- BK: tl.constexpr = 64
-
- grid = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
+ grid = lambda meta: (triton.cdiv(M, meta['BM']) * triton.cdiv(N, meta['BN']),)
gemm_kernel[grid](
im2col, weight, out_2d,
@@ -109,9 +121,6 @@
im2col.stride(0), im2col.stride(1),
weight.stride(0), weight.stride(1),
out_2d.stride(0), out_2d.stride(1),
- BM=BM, BN=BN, BK=BK,
- num_warps=8,
- num_stages=3,
)
out = out_2d.reshape(B, num_t, num_h, num_w, N).permute(0, 4, 1, 2, 3)
{
"content": "<4933 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark with autotune",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.279 gbps=18.159 ms=0.235
shape=0 variant=compiled tflops=2.832 gbps=15.685 ms=0.272
shape=0 variant=sota tflops=3.307 gbps=18.312 ms=0.233
shape=0 variant=solution tflops=13.493 gbps=74.720 ms=0.057
shape=0 solution_peak_fraction=0.0675
shape=1 variant=eager tflops=12.429 gbps=32.417 ms=0.248
shape=1 variant=compiled tflops=10.611 gbps=27.675 ms=0.291
shape=1 variant=sota tflops=12.526 gbps=32.670 ms=0.246
shape=1 variant=solution tflops=51.094 gbps=133.261 ms=0.060
shape=1 solution_peak_fraction=0.2555
shape=2 variant=eager tflops=24.203 gbps=49.994 ms=0.287
shape=2 variant=compiled tflops=21.507 gbps=44.426 ms=0.323
shape=2 variant=sota tflops=24.644 gbps=50.906 ms=0.281
shape=2 variant=solution tflops=94.925 gbps=196.078 ms=0.073
shape=2 solution_peak_fraction=0.4746
shape=3 variant=eager tflops=5.765 gbps=22.365 ms=0.160
shape=3 variant=compiled tflops=4.603 gbps=17.859 ms=0.201
shape=3 variant=sota tflops=5.841 gbps=22.661 ms=0.158
shape=3 variant=solution tflops=25.921 gbps=100.563 ms=0.036
shape=3 solution_peak_fraction=0.1296
peak_fraction: 0.1804
RESULT: OK
{
"command": "<1650 chars \u2014 see below>",
"description": "Benchmark im2col and permute steps",
"timeout": 60000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
import torch, time
for name, B, C, T, H, W, kT, kH, kW, ed in [
('s0', 1,3,2,224,224,2,14,14,1280),
('s1', 2,3,4,224,224,2,14,14,1280),
('s2', 1,3,8,336,336,2,14,14,1280),
('s3', 4,3,1,224,224,1,16,16,768),
]:
x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device='cuda')
nt,nh,nw = T//kT, H//kH, W//kW
M,K,N = B*nt*nh*nw, C*kT*kH*kW, ed
# warmup
for _ in range(10):
p = torch.as_strided(x,(B,nt,nh,nw,C,kT,kH,kW),
(x.stride(0),kT*x.stride(2),kH*x.stride(3),kW*x.stride(4),
x.stride(1),x.stride(2),x.stride(3),x.stride(4)))
im = p.reshape(M,K)
torch.cuda.synchronize()
# bench im2col
t0 = time.perf_counter()
for _ in range(200):
p = torch.as_strided(x,(B,nt,nh,nw,C,kT,kH,kW),
(x.stride(0),kT*x.stride(2),kH*x.stride(3),kW*x.stride(4),
x.stride(1),x.stride(2),x.stride(3),x.stride(4)))
im = p.reshape(M,K)
torch.cuda.synchronize()
im2col_ms = (time.perf_counter()-t0)/200*1000
# bench permute
out2d = torch.randn(M,N,dtype=torch.bfloat16,device='cuda')
for _ in range(10):
o = out2d.reshape(B,nt,nh,nw,N).permute(0,4,1,2,3)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(200):
o = out2d.reshape(B,nt,nh,nw,N).permute(0,4,1,2,3)
torch.cuda.synchronize()
perm_ms = (time.perf_counter()-t0)/200*1000
flops = 2*B*ed*C*T*H*W
gemm_ideal_ms = flops/200e12*1000
print(f'{name}: im2col={im2col_ms:.3f}ms perm={perm_ms:.3f}ms gemm_ideal={gemm_ideal_ms:.3f}ms M={M} N={N} K={K}')
" 2>&1s0: im2col=0.005ms perm=0.001ms gemm_ideal=0.004ms M=256 N=1280 K=1176
s1: im2col=0.005ms perm=0.001ms gemm_ideal=0.015ms M=1024 N=1280 K=1176
s2: im2col=0.007ms perm=0.001ms gemm_ideal=0.035ms M=2304 N=1280 K=1176
s3: im2col=0.005ms perm=0.001ms gemm_ideal=0.005ms M=784 N=768 K=768
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -4,8 +4,8 @@
1) im2col via torch.as_strided (zero-copy view then contiguous reshape)
2) GEMM via Triton tl.dot kernel with L2-cache-tiling + autotune
-The GEMM kernel is a standard tiled matmul operating on contiguous (M,K) and (N,K)
-inputs, achieving high tensor-core utilization on SM120 Blackwell.
+Also includes a fused implicit-GEMM kernel used as fallback for small problems
+where the im2col materialization overhead dominates.
"""
import torch
import torch.nn as nn
@@ -21,19 +21,121 @@
)
+# ---------- Fused implicit-GEMM kernel ----------
+
@triton.autotune(
configs=[
triton.Config({'BM': 64, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 32, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 32, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=3),
triton.Config({'BM': 128, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ ],
+ key=['C', 'kT', 'kH', 'kW', 'embed_dim', 'K'],
+)
+@triton.jit
+def fused_patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ C: tl.constexpr, kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
+ num_t, num_h, num_w,
+ total_patches, embed_dim, K,
+ stride_xb, stride_xc, stride_xt, stride_xh, stride_xw,
+ stride_ob, stride_on, stride_ot, stride_oh, stride_ow,
+ BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_tiles = tl.cdiv(embed_dim, BN)
+ num_m_tiles = tl.cdiv(total_patches, BM)
+ total_tiles = num_m_tiles * num_n_tiles
+
+ num_pid_in_group = 2 * num_n_tiles
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * 2
+ group_size = min(num_m_tiles - first_pid_m, 2)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size)
+ pid_n = (pid % num_pid_in_group) // group_size
+
+ m_off = pid_m * BM + tl.arange(0, BM)
+ n_off = pid_n * BN + tl.arange(0, BN)
+ m_mask = m_off < total_patches
+ n_mask = n_off < embed_dim
+
+ num_spatial = num_t * num_h * num_w
+ b_idx = m_off // num_spatial
+ rem = m_off % num_spatial
+ pt_idx = rem // (num_h * num_w)
+ rem2 = rem % (num_h * num_w)
+ ph_idx = rem2 // num_w
+ pw_idx = rem2 % num_w
+
+ base = (b_idx * stride_xb
+ + pt_idx * kT * stride_xt
+ + ph_idx * kH * stride_xh
+ + pw_idx * kW * stride_xw)
+
+ acc = tl.zeros((BM, BN), dtype=tl.float32)
+ patch_vol: tl.constexpr = kT * kH * kW
+ kh_kw: tl.constexpr = kH * kW
+
+ for k_start in range(0, K, BK):
+ k_off = k_start + tl.arange(0, BK)
+ k_mask = k_off < K
+
+ c_idx = k_off // patch_vol
+ rem_k = k_off % patch_vol
+ kt_idx = rem_k // kh_kw
+ rem_k2 = rem_k % kh_kw
+ kh_idx = rem_k2 // kW
+ kw_idx = rem_k2 % kW
+
+ delta = (c_idx * stride_xc
+ + kt_idx * stride_xt
+ + kh_idx * stride_xh
+ + kw_idx * stride_xw)
+
+ x_off = base[:, None] + delta[None, :]
+ x_block = tl.load(x_ptr + x_off, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ w_off = n_off[None, :] * K + k_off[:, None]
+ w_block = tl.load(w_ptr + w_off, mask=k_mask[:, None] & n_mask[None, :], other=0.0)
+
+ acc += tl.dot(x_block, w_block)
+
+ out_vals = acc.to(tl.bfloat16)
+ out_off = (b_idx[:, None] * stride_ob
+ + n_off[None, :] * stride_on
+ + pt_idx[:, None] * stride_ot
+ + ph_idx[:, None] * stride_oh
+ + pw_idx[:, None] * stride_ow)
+ tl.store(out_ptr + out_off, out_vals, mask=m_mask[:, None] & n_mask[None, :])
+
+
+# ---------- GEMM kernel (im2col + matmul) ----------
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 128, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
triton.Config({'BM': 64, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
- triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=3),
- triton.Config({'BM': 64, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
- triton.Config({'BM': 128, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
triton.Config({'BM': 64, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
triton.Config({'BM': 32, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 32, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 32, 'BN': 64, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 16, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 16, 'BN': 64, 'BK': 32}, num_warps=2, num_stages=3),
triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=8, num_stages=4),
triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BM': 64, 'BN': 64, 'BK': 128}, num_warps=4, num_stages=2),
+ triton.Config({'BM': 64, 'BN': 128, 'BK': 128}, num_warps=4, num_stages=2),
+ triton.Config({'BM': 32, 'BN': 128, 'BK': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BM': 64, 'BN'
... (3280 more chars){
"content": "<9940 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark with hybrid approach",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.277 gbps=18.150 ms=0.235
shape=0 variant=compiled tflops=2.832 gbps=15.684 ms=0.272
shape=0 variant=sota tflops=3.303 gbps=18.293 ms=0.233
shape=0 variant=solution tflops=11.835 gbps=65.541 ms=0.065
shape=0 solution_peak_fraction=0.0592
shape=1 variant=eager tflops=12.469 gbps=32.522 ms=0.247
shape=1 variant=compiled tflops=10.739 gbps=28.010 ms=0.287
shape=1 variant=sota tflops=12.614 gbps=32.899 ms=0.244
shape=1 variant=solution tflops=48.680 gbps=126.965 ms=0.063
shape=1 solution_peak_fraction=0.2434
shape=2 variant=eager tflops=24.177 gbps=49.941 ms=0.287
shape=2 variant=compiled tflops=21.442 gbps=44.292 ms=0.323
shape=2 variant=sota tflops=24.664 gbps=50.947 ms=0.281
shape=2 variant=solution tflops=92.238 gbps=190.529 ms=0.075
shape=2 solution_peak_fraction=0.4612
shape=3 variant=eager tflops=5.770 gbps=22.388 ms=0.160
shape=3 variant=compiled tflops=4.683 gbps=18.167 ms=0.198
shape=3 variant=sota tflops=5.862 gbps=22.744 ms=0.158
shape=3 variant=solution tflops=25.554 gbps=99.141 ms=0.036
shape=3 solution_peak_fraction=0.1278
peak_fraction: 0.1707
RESULT: OK
{
"command": "nvcc --version 2>&1 | head -5; nvcc --list-gpu-arch 2>&1 | tail -5",
"description": "Check nvcc version and supported archs"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
compute_100
compute_110
compute_103
compute_120
compute_121
{
"command": "<534 chars \u2014 see below>",
"description": "Test CUDA C++ WMMA compilation",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
import torch
# Test that CUDA C++ extensions work with bf16 wmma
from torch.utils.cpp_extension import load_inline
code = '''
#include <cuda.h>
#include <cuda_bf16.h>
#include <mma.h>
using namespace nvcuda;
__global__ void test_wmma() {
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;
wmma::fill_fragment(acc, 0.0f);
}
'''
mod = load_inline(name='test_wmma', cpp_sources=['void test();'], cuda_sources=[code], functions=['test_wmma'], verbose=False)
print('WMMA compilation: OK')
" 2>&1 | tail -3[2/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_wmma -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/test_wmma/cuda.cu -o cuda.cuda.o
ninja: build stopped: subcommand failed.
{
"command": "<557 chars \u2014 see below>",
"description": "Test basic CUDA bf16 compilation on SM120",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
from torch.utils.cpp_extension import load_inline
code = '''
#include <cuda.h>
#include <cuda_bf16.h>
__global__ void test_bf16(__nv_bfloat16* out, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) out[idx] = __float_to_bfloat16_rn(float(idx));
}
'''
mod = load_inline(name='test_bf16', cpp_sources=[], cuda_sources=[code],
functions=['test_bf16'], verbose=True,
extra_cuda_cflags=['-gencode=arch=compute_120,code=sm_120'])
print('BF16 compilation: OK')
" 2>&1 | tail -20File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 11, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2268, in _jit_compile
_write_ninja_file_and_build_library(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2429, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2791, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'test_bf16'
{
"command": "<502 chars \u2014 see below>",
"description": "Debug CUDA compilation error",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
from torch.utils.cpp_extension import load_inline
code = '''
#include <cuda.h>
#include <cuda_bf16.h>
__global__ void test_bf16(__nv_bfloat16* out, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) out[idx] = __float_to_bfloat16_rn(float(idx));
}
'''
mod = load_inline(name='test_bf16', cpp_sources=[], cuda_sources=[code],
functions=['test_bf16'], verbose=True)
print('OK')
" 2>&1 | grep -E "(error|Error|sm_|arch|gencode)" | head -20[1/3] /usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_bf16 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 --compiler-options '-fPIC' -std=c++17 -c /home/infatoshi/.cache/torch_extensions/py311_cu130/test_bf16/cuda.cu -o cuda.cuda.o
/usr/local/cuda-13/bin/nvcc -MD -MF cuda.cuda.o.d -DTORCH_EXTENSION_NAME=test_bf16 -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include -isystem /home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13/include -isystem /home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/include/python3.11 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_120,code=compute_120 -gencode=ar
... (755 more chars)