"""Triton-based 3D patch-embedding kernel.
Fused kernel: gathers patches from (B,C,T,H,W) input directly into the
GEMM K-loop, eliminating the im2col temporary buffer and potentially
overlapping gather + MMA for better utilization.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
super().__init__()
self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
self.kT, self.kH, self.kW = kT, kH, kW
self.embed_dim = embed_dim
self.conv = nn.Conv3d(
C, embed_dim,
kernel_size=(kT, kH, kW),
stride=(kT, kH, kW),
bias=False,
dtype=torch.bfloat16,
)
nn.init.normal_(self.conv.weight, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C = int(self.B), int(self.C)
T, H, W = int(self.T), int(self.H), int(self.W)
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
N = int(self.embed_dim)
num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
strided_input = K < 128
# Compute 3D grid indices (pre-pass for per-batch)
PT = T // kT
PH = H // kH
PW = W // kW
out = torch.empty(B, N, PT, PH, PW, dtype=torch.bfloat16, device=x.device)
# Weight: (K, N)
w = self.conv.weight.reshape(N, K).t().contiguous()
# Launch fused kernel
grid = lambda META: (
triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
)
_fused_kernel[grid](
x, w, out,
M_p, N, K,
x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
w.stride(0), w.stride(1),
out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
B, C, kT, kH, kW,
PT, PH, PW,
)
return out
fused_configs = [
# Small M (shape 0: 256 patches)
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
# Medium M (shape 1: 1024, shape 3: 784)
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
]
@triton.autotune(configs=fused_configs, key=["M_p", "N", "K"])
@triton.jit
def _fused_kernel(
input_ptr, weight_ptr, out_ptr,
M_p, N, K,
stride_input_b, stride_input_c, stride_input_t, stride_input_h, stride_input_w,
stride_weight_k, stride_weight_n,
stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
B, C, kT, kH, kW,
PT, PH, PW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
pid_m_end = tl.cdiv(M_p, BLOCK_M)
pid_n_end = tl.cdiv(N, BLOCK_N)
group_size_m = min(pid_m_end, GROUP_M)
pid_n = pid % pid_n_end
pid_m = ((pid // pid_n_end) % group_size_m) + (pid // pid_n_end // group_size_m) * GROUP_M
if pid_m >= pid_m_end:
return
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N
offs_m = rm + tl.arange(0, BLOCK_M)
offs_n = rn + tl.arange(0, BLOCK_N)
mask_m = offs_m < M_p
mask_n = offs_n < N
# Decompose linear m -> (b, pt, ph, pw)
pw_total = PW
ph_total = PH
pt_total = PT
pw_idx = offs_m % pw_total
ph_idx = (offs_m // pw_total) % ph_total
pt_idx = (offs_m // (pw_total * ph_total)) % pt_total
b_idx = offs_m // (pw_total * ph_total * pt_total)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# GEMM K-loop
for lk in range(0, tl.cdiv(K, BLOCK_K), 1):
off_k = lk * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = off_k < K
# Decompose k -> (c, kt, kh, kw)
kpkw = kT * kH * kW # total patch size per channel
c_idx = off_k // kpkw
rem_after_c = off_k % kpkw
kt_idx = rem_after_c // (kH * kW)
kh_idx = (rem_after_c % (kH * kW)) // kW
kw_idx = rem_after_c % kW
# Gather: input[b, c, pt*kT+kt, ph*kH+kh, pw*kW+kw]
input_offsets = (
b_idx[:, None] * stride_input_b
+ c_idx[None, :] * stride_input_c
+ (pt_idx[:, None] * kT + kt_idx[None, :]) * stride_input_t
+ (ph_idx[:, None] * kH + kh_idx[None, :]) * stride_input_h
+ (pw_idx[:, None] * kW + kw_idx[None, :]) * stride_input_w
)
a = tl.load(input_ptr + input_offsets, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
# Load weight: weight[k, n]
b = tl.load(
weight_ptr + off_k[:, None] * stride_weight_k + offs_n[None, :] * stride_weight_n,
mask=mask_k[:, None] & mask_n[None, :], other=0.0,
)
acc += tl.dot(a, b)
# Store: output[b, n, pt, ph, pw]
out_offsets = (
b_idx[:, None] * stride_out_b
+ offs_n[None, :] * stride_out_n
+ pt_idx[:, None] * stride_out_pt
+ ph_idx[:, None] * stride_out_ph
+ pw_idx[:, None] * stride_out_pw
)
tl.store(out_ptr + out_offsets, acc.to(tl.bfloat16),
mask=mask_m[:, None] & mask_n[None, :])
# ── Prewarm ─────────────────────────────────────────────────────────
def _prewarm():
# We need to construct proper model and call forward to prewarm
# Since autotune key is (M_p, N, K), we need to hit all 4 shapes
shapes = [
(1, 3, 2, 224, 224, 2, 14, 14, 1280),
(2, 3, 4, 224, 224, 2, 14, 14, 1280),
(1, 3, 8, 336, 336, 2, 14, 14, 1280),
(4, 3, 1, 224, 224, 1, 16, 16, 768),
]
for params in shapes:
m = Model(*params).to("cuda:0")
x = torch.randn(*params[:5], dtype=torch.bfloat16, device="cuda:0")
_ = m(x)
torch.cuda.synchronize()
_prewarm()
B = 1; C = 3; T = 2; H = 224; W = 224
kT = 2; kH = 14; kW = 14; embed_dim = 1280
def get_inputs():
return [torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5]
def get_init_inputs():
return [B, C, T, H, W, kT, kH, kW, embed_dim]
shape=0 variant=eager tflops=3.517 gbps=19.477 ms=0.219
shape=0 variant=compiled tflops=3.114 gbps=17.245 ms=0.247
shape=0 variant=sota tflops=3.639 gbps=20.154 ms=0.212
shape=0 variant=solution tflops=8.342 gbps=46.199 ms=0.092
shape=0 solution_peak_fraction=0.0417
shape=1 variant=eager tflops=13.602 gbps=35.477 ms=0.227
shape=1 variant=compiled tflops=11.371 gbps=29.656 ms=0.271
shape=1 variant=sota tflops=13.802 gbps=35.998 ms=0.223
shape=1 variant=solution tflops=29.376 gbps=76.617 ms=0.105
shape=1 solution_peak_fraction=0.1469
shape=2 variant=eager tflops=25.019 gbps=51.679 ms=0.277
shape=2 variant=compiled tflops=22.146 gbps=45.744 ms=0.313
shape=2 variant=sota tflops=26.146 gbps=54.007 ms=0.265
shape=2 variant=solution tflops=63.828 gbps=131.845 ms=0.109
shape=2 solution_peak_fraction=0.3191
shape=3 variant=eager tflops=6.333 gbps=24.571 ms=0.146
shape=3 variant=compiled tflops=4.489 gbps=17.415 ms=0.206
shape=3 variant=sota tflops=6.274 gbps=24.341 ms=0.147
shape=3 variant=solution tflops=12.265 gbps=47.582 ms=0.075
shape=3 solution_peak_fraction=0.0613
peak_fraction: 0.1046
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1046,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 3293767, "output_tokens": 43802, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 24989, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md"
}<path>/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md</path>
<type>file</type>
<content>
1: # KernelBench-Hard — Developer Instructions
2:
3: Last updated: 2026-04-27.
4:
5: This file is for **coding agents editing the repo** (you, via Claude Code). Do not confuse with `problems/<X>/PROMPT.txt` — those are the human-voice queries fed to agents _under test_.
6:
7: For the journey behind the current design, read [DEVLOG.md](./DEVLOG.md).
8:
9: ## What this repo is
10:
11: Small kernel benchmark. Frontier coding agents are given URLs to SOTA implementations (sonic-moe, flashinfer, marlin) and asked to write a competitive kernel on RTX PRO 6000 Blackwell (SM120) in 45 minutes. Roofline-graded. Published artifact is the best kernel per (problem × model × harness), plus the agent trace.
12:
13: See [SPEC.md](./SPEC.md) for methodology. See [README.md](./README.md) for the model matrix and quick start.
14:
15: ## Non-negotiable rules
16:
17: - **uv only.** No bare `python`, no `pip`. Use `uv run ...`, `uv add ...`, `uv pip install ...`.
18: - **Before committing:** `uv run ruff check . --fix && uv run pytest`.
19: - **Never edit `problems/*/solution.py`**. Those files are agent output; they're gitignored for a reason. If you need to inspect one, read it from `outputs/runs/<run>/<problem>/solution.py`.
20: - **Never modify `problems/*/reference.py`, `check.py`, `benchmark.py`, `problem.yaml`, `shapes.py`, or `PROMPT.txt`** once a sweep has been published. Those define the ben
... (5941 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/pyproject.toml</path>
<type>file</type>
<content>
1: [project]
2: name = "kernelbench-hard"
3: version = "0.1.0"
4: description = "Surgical GPU kernel benchmark. Small deck, frontier agents, roofline metric, link-don't-spoil problem briefs."
5: requires-python = ">=3.11"
6: dependencies = [
7: "torch>=2.11",
8: "numpy>=1.26",
9: "pyyaml>=6.0",
10: "rich>=13.7",
11: "matplotlib>=3.9",
12: "pytest>=8.0",
13: "ruff>=0.6",
14: "einops>=0.8.2",
15: "requests<2.33",
16: ]
17:
18: [build-system]
19: requires = ["hatchling"]
20: build-backend = "hatchling.build"
21:
22: [tool.hatch.build.targets.wheel]
23: packages = ["src"]
24:
25: [tool.ruff]
26: line-length = 100
27: target-version = "py311"
28:
29: [tool.ruff.lint]
30: select = ["E", "F", "W", "I", "B", "UP"]
31: ignore = ["E501"]
32:
33: [tool.pytest.ini_options]
34: testpaths = ["tests"]
(End of file - total 34 lines)
</content>{
"tool": "<101 chars \u2014 see below>",
"error": "<238 chars \u2014 see below>"
}nvidia-smi --query-gpu=name,memory.total,temperature.gpu,power.draw --format=csv,noheader
</parameterModel tried to call unavailable tool 'nvidia-smi --query-gpu=name,memory.total,temperature.gpu,power.draw --format=csv,noheader
</parameter'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.The arguments provided to the tool are invalid: Model tried to call unavailable tool 'nvidia-smi --query-gpu=name,memory.total,temperature.gpu,power.draw --format=csv,noheader
</parameter'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.{
"command": "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader",
"description": "Check GPU status"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability(0))\"",
"description": "Check PyTorch and CUDA version"
}2.11.0+cu130
(12, 0)
{
"todos": [
{
"content": "Understand the problem shapes and memory layout",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton GEMM solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Profile with benchmark.py and iterate for performance",
"status": "pending",
"priority": "high"
}
]
}content: 'Understand the problem shapes and memory layout'
status: 'completed'
priority: 'high'
content: 'Write initial Triton GEMM solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Profile with benchmark.py and iterate for performance'
status: 'pending'
priority: 'high'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,182 @@
+"""Triton-based 3D patch-embedding kernel.
+
+Strategy:
+ 1. Use reshape/permute/contiguous to do im2col: (B,C,T,H,W) -> (M_p, K)
+ 2. Transpose/reshape conv weight to (K, embed_dim)
+ 3. Launch a tiled Triton GEMM: out[m,n] = sum_k im2col[m,k] * W[k,n]
+ 4. Reshape output to (B, embed_dim, T//kT, H//kH, W//kW)
+
+Reshapes and permutes are not forbidden; the actual compute happens in a Triton kernel.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+# ── Model that declares the same parameters as reference.nn.Conv3d ────────
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+
+ self.conv = nn.Conv3d(
+ C, embed_dim,
+ kernel_size=(kT, kH, kW),
+ stride=(kT, kH, kW),
+ bias=False,
+ dtype=torch.bfloat16,
+ )
+ nn.init.normal_(self.conv.weight, std=0.02)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T, H, W = self.B, self.C, self.T, self.H, self.W
+ kT, kH, kW = self.kT, self.kH, self.kW
+ embed_dim = self.embed_dim
+
+ # ── im2col via reshape / permute ──────────────────────────────────
+ # x : (B, C, T, H, W)
+ # Split each dimension into (patch_count, patch_size)
+ x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
+ # Reorder so patches are contiguous and patch-internal dims are contiguous
+ x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
+ # x : (B, T//kT, H//kH, W//kW, C, kT, kH, kW)
+ num_patches = (T // kT) * (H // kH) * (W // kW)
+ K = C * kT * kH * kW
+ x = x.reshape(B * num_patches, K) # (M_p, K)
+ M_p = B * num_patches
+
+ # ── weight reshape ───────────────────────────────────────────────
+ # conv.weight : (embed_dim, C, kT, kH, kW)
+ W = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
+
+ # ── output tensor ────────────────────────────────────────────────
+ out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
+
+ # ── launch Triton GEMM ───────────────────────────────────────────
+ patch_embed_launch(M_p, K, embed_dim, x, W, out)
+
+ # ── reshape to output layout ─────────────────────────────────────
+ # out : (B, T//kT, H//kH, W//kW, embed_dim)
+ out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
+ out = out.permute(0, 4, 1, 2, 3) # (B, embed_dim, T//kT, H//kH, W//kW)
+ return out
+
+
+# ── Triton GEMM kernel ───────────────────────────────────────────────────
+
+
+# Autotune config selection
+gemm_configs = [
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+]
+
+
+@triton.autotune(configs=gemm_configs, key=["M", "N", "K"])
+@triton.jit
+def _patch_embed_gemm_kernel(
+ A, B, C,
+ M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ """Standard tiled GEMM: C[A] = A[A] @ B[B]."""
+ pid = tl.program_id(0)
+ pid_m_end = tl.cdiv(M, BLOCK_M)
+ pid_n_end = tl.cdiv(N, BLOCK_N)
+ num_pid_n = pid_n_end
+ group_size_m = min(pid_m_end, GROUP_M)
+ pid_m_from = (pid // num_pid_n) % group_size_m
+ pid_n = pid % num_pid_n
+ pid_m = pid_m_from + (pid // num_pid_n // group_size_m) * GROUP_M
+
+ rm = pid_m * BLOCK_M
+ rn = pid_n * BLOCK_N
+
+ off_am = rm + tl.arange(0, BLOCK_M)
+ off_bn = rn + tl.arange(0, BLOCK_N)
+
+ # Accumulate in fp32 for better precision
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # K loop
+ for lk in range(0, tl.cdiv(K, BLOCK_K)):
+ off_ak = lk * BLOCK_K + tl.arange(0, BLOCK_K)
+ a_ptrs = A + off_am[:, None] * stride_am + off_ak[None, :] * stride_ak
+ a_mask = (off_am[:, None] < M) & (off_ak[None, :] < K)
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+
+ b_ptrs = B + off_ak[:, None] * stride_bk + off_bn[None, :] * stride_bn
+ b_mask = (off_ak[:, None] < K) & (off_bn[None, :] < N)
+
... (1418 more chars){
"content": "<7112 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 117, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 78, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py", line 66, in forward
out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: reshape(): argument 'shape' failed to unpack the object at pos 4 with error "type must be tuple of ints,but got Tensor"
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -63,7 +63,10 @@
# ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
- out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
+ print(f"DEBUG: out.shape={out.shape}, B={B}(type={type(B)}), T={T}, kT={kT}, H={H}, kH={kH}, W={W}, kW={kW}, embed_dim={embed_dim}(type={type(embed_dim)})")
+ reshape_shape = (B, int(T // kT), int(H // kH), int(W // kW), int(embed_dim))
+ print(f"DEBUG: reshape_shape={reshape_shape}")
+ out = out.reshape(*reshape_shape)
out = out.permute(0, 4, 1, 2, 3) # (B, embed_dim, T//kT, H//kH, W//kW)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<194 chars \u2014 see below>",
"newString": "<475 chars \u2014 see below>"
} # ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim) # ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
print(f"DEBUG: out.shape={out.shape}, B={B}(type={type(B)}), T={T}, kT={kT}, H={H}, kH={kH}, W={W}, kW={kW}, embed_dim={embed_dim}(type={type(embed_dim)})")
reshape_shape = (B, int(T // kT), int(H // kH), int(W // kW), int(embed_dim))
print(f"DEBUG: reshape_shape={reshape_shape}")
out = out.reshape(*reshape_shape)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -35,9 +35,10 @@
nn.init.normal_(self.conv.weight, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, C, T, H, W = self.B, self.C, self.T, self.H, self.W
- kT, kH, kW = self.kT, self.kH, self.kW
- embed_dim = self.embed_dim
+ B, C = int(self.B), int(self.C)
+ T, H, W = int(self.T), int(self.H), int(self.W)
+ kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
+ embed_dim = int(self.embed_dim)
# ── im2col via reshape / permute ──────────────────────────────────
# x : (B, C, T, H, W)
@@ -63,10 +64,7 @@
# ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
- print(f"DEBUG: out.shape={out.shape}, B={B}(type={type(B)}), T={T}, kT={kT}, H={H}, kH={kH}, W={W}, kW={kW}, embed_dim={embed_dim}(type={type(embed_dim)})")
- reshape_shape = (B, int(T // kT), int(H // kH), int(W // kW), int(embed_dim))
- print(f"DEBUG: reshape_shape={reshape_shape}")
- out = out.reshape(*reshape_shape)
+ out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
out = out.permute(0, 4, 1, 2, 3) # (B, embed_dim, T//kT, H//kH, W//kW)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<1721 chars \u2014 see below>",
"newString": "<1493 chars \u2014 see below>"
} B, C, T, H, W = self.B, self.C, self.T, self.H, self.W
kT, kH, kW = self.kT, self.kH, self.kW
embed_dim = self.embed_dim
# ── im2col via reshape / permute ──────────────────────────────────
# x : (B, C, T, H, W)
# Split each dimension into (patch_count, patch_size)
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
# Reorder so patches are contiguous and patch-internal dims are contiguous
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
# x : (B, T//kT, H//kH, W//kW, C, kT, kH, kW)
num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
x = x.reshape(B * num_patches, K) # (M_p, K)
M_p = B * num_patches
# ── weight reshape ───────────────────────────────────────────────
# conv.weight : (embed_dim, C, kT, kH, kW)
W = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
# ── output tensor ────────────────────────────────────────────────
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
# ── launch Triton GEMM ───────────────────────────────────────────
patch_embed_launch(M_p, K, embed_dim, x, W, out)
# ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
print(f"DEBUG: out.shape={out.shape}, B={B}(type={type(B)}), T={T}, kT={kT}, H={H}, kH={kH}, W={W}, kW={kW}, embed_dim={embed_dim}(type={type(embed_dim)})")
reshape_shape = (B, int(T // kT), int(H // kH), int(W // kW), int(embed_dim))
print(f"DEBUG: reshape_shape={reshape_shape}")
out = out.reshape(*reshape_shape) B, C = int(self.B), int(self.C)
T, H, W = int(self.T), int(self.H), int(self.W)
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
embed_dim = int(self.embed_dim)
# ── im2col via reshape / permute ──────────────────────────────────
# x : (B, C, T, H, W)
# Split each dimension into (patch_count, patch_size)
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
# Reorder so patches are contiguous and patch-internal dims are contiguous
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
# x : (B, T//kT, H//kH, W//kW, C, kT, kH, kW)
num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
x = x.reshape(B * num_patches, K) # (M_p, K)
M_p = B * num_patches
# ── weight reshape ───────────────────────────────────────────────
# conv.weight : (embed_dim, C, kT, kH, kW)
W = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
# ── output tensor ────────────────────────────────────────────────
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
# ── launch Triton GEMM ───────────────────────────────────────────
patch_embed_launch(M_p, K, embed_dim, x, W, out)
# ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -139,7 +139,7 @@
c_ptrs = C + off_cm[:, None] * stride_cm + off_bn[None, :] * stride_cn
c_mask = (off_cm[:, None] < M) & (off_bn[None, :] < N)
# Cast to bf16
- tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask)
def patch_embed_launch(M: int, K: int, N: int, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<72 chars \u2014 see below>",
"newString": "<73 chars \u2014 see below>"
} # Cast to bf16
tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask) # Cast to bf16
tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 117, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 78, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py", line 67, in forward
out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: reshape(): argument 'shape' failed to unpack the object at pos 4 with error "type must be tuple of ints,but got Tensor"
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -54,7 +54,7 @@
# ── weight reshape ───────────────────────────────────────────────
# conv.weight : (embed_dim, C, kT, kH, kW)
- W = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
+ weight_2d = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
# ── output tensor ────────────────────────────────────────────────
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<128 chars \u2014 see below>",
"newString": "<136 chars \u2014 see below>"
} # conv.weight : (embed_dim, C, kT, kH, kW)
W = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N) # conv.weight : (embed_dim, C, kT, kH, kW)
weight_2d = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -60,7 +60,7 @@
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
# ── launch Triton GEMM ───────────────────────────────────────────
- patch_embed_launch(M_p, K, embed_dim, x, W, out)
+ patch_embed_launch(M_p, K, embed_dim, x, weight_2d, out)
# ── reshape to output layout ─────────────────────────────────────
# out : (B, T//kT, H//kH, W//kW, embed_dim)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": " patch_embed_launch(M_p, K, embed_dim, x, W, out)",
"newString": " patch_embed_launch(M_p, K, embed_dim, x, weight_2d, out)"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000
}PASS
{
"todos": [
{
"content": "Understand the problem shapes and memory layout",
"status": "completed",
"priority": "high"
},
{
"content": "Write initial Triton GEMM solution.py",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Profile with benchmark.py and iterate for performance",
"status": "in_progress",
"priority": "high"
}
]
}content: 'Understand the problem shapes and memory layout'
status: 'completed'
priority: 'high'
content: 'Write initial Triton GEMM solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Profile with benchmark.py and iterate for performance'
status: 'in_progress'
priority: 'high'{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=3.542 gbps=19.614 ms=0.218
shape=0 variant=compiled tflops=3.020 gbps=16.726 ms=0.255
shape=0 variant=sota tflops=3.646 gbps=20.193 ms=0.211
shape=0 variant=solution tflops=12.158 gbps=67.328 ms=0.063
shape=0 solution_peak_fraction=0.0608
shape=1 variant=eager tflops=13.087 gbps=34.132 ms=0.236
shape=1 variant=compiled tflops=11.631 gbps=30.335 ms=0.265
shape=1 variant=sota tflops=13.927 gbps=36.323 ms=0.221
shape=1 variant=solution tflops=48.001 gbps=125.194 ms=0.064
shape=1 solution_peak_fraction=0.2400
shape=2 variant=eager tflops=26.122 gbps=53.958 ms=0.266
shape=2 variant=compiled tflops=22.769 gbps=47.032 ms=0.305
shape=2 variant=sota tflops=26.412 gbps=54.556 ms=0.263
shape=2 variant=solution tflops=92.911 gbps=191.918 ms=0.075
shape=2 solution_peak_fraction=0.4646
shape=3 variant=eager tflops=6.299 gbps=24.439 ms=0.147
shape=3 variant=compiled tflops=4.901 gbps=19.014 ms=0.189
shape=3 variant=sota tflops=6.456 gbps=25.045 ms=0.143
shape=3 variant=solution tflops=16.808 gbps=65.210 ms=0.055
shape=3 solution_peak_fraction=0.0840
peak_fraction: 0.1545
RESULT: OK
{
"command": "<2327 chars \u2014 see below>",
"description": "Profile shape sizes and overhead",
"timeout": 120000
}cat > /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/profile_shapes.py << 'PROFILE_EOF'
import sys, time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parents[2]))
import torch
torch.cuda.set_per_process_memory_fraction(0.4)
import reference
import shapes
import solution
device = torch.device("cuda:0")
for shape_idx, shape in enumerate(shapes.SHAPES):
for k, v in shape.items():
if hasattr(reference, k):
setattr(reference, k, v)
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
B = shape["B"]; C = shape["C"]; T = shape["T"]; H = shape["H"]; W = shape["W"]
kT = shape["kT"]; kH = shape["kH"]; kW = shape["kW"]; embed_dim = shape["embed_dim"]
num_patches = (T//kT) * (H//kH) * (W//kW)
K = C * kT * kH * kW
M_p = B * num_patches
print(f"\nShape {shape_idx}: M_p={M_p} K={K} N={embed_dim}")
print(f" FLOPs: {2 * M_p * K * embed_dim / 1e9:.1f} GFLOPs")
# Warmup
for _ in range(5):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Time decomposition
t0 = time.perf_counter()
x = inputs[0]
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x_im2col = x.reshape(M_p, K)
torch.cuda.synchronize()
t1 = time.perf_counter()
W_2d = sol_model.conv.weight.reshape(embed_dim, K).t().contiguous()
torch.cuda.synchronize()
t2 = time.perf_counter()
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
torch.cuda.synchronize()
t3 = time.perf_counter()
# Full forward timing
torch.cuda.synchronize()
t0 = time.perf_counter()
sol_out = sol_model(*inputs)
torch.cuda.synchronize()
t_full = (time.perf_counter() - t0) * 1000
print(f" im2col: {(t1-t0)*1000:.3f}ms, weight_reshape: {(t2-t1)*1000:.3f}ms, alloc: {(t3-t2)*1000:.3f}ms")
print(f" Full forward: {t_full:.3f}ms")
# Cleanup
torch.cuda.set_per_process_memory_fraction(1.0)
PROFILE_EOF
uv run python profile_shapes.py 2>&1Shape 0: M_p=256 K=1176 N=1280
FLOPs: 0.8 GFLOPs
im2col: -0.038ms, weight_reshape: 0.027ms, alloc: 0.008ms
Full forward: 0.272ms
Shape 1: M_p=1024 K=1176 N=1280
FLOPs: 3.1 GFLOPs
im2col: -0.037ms, weight_reshape: 0.026ms, alloc: 0.008ms
Full forward: 0.075ms
Shape 2: M_p=2304 K=1176 N=1280
FLOPs: 6.9 GFLOPs
im2col: -0.037ms, weight_reshape: 0.026ms, alloc: 0.008ms
Full forward: 0.087ms
Shape 3: M_p=784 K=768 N=768
FLOPs: 0.9 GFLOPs
im2col: -0.036ms, weight_reshape: 0.024ms, alloc: 0.008ms
Full forward: 0.065ms
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,11 +2,9 @@
Strategy:
1. Use reshape/permute/contiguous to do im2col: (B,C,T,H,W) -> (M_p, K)
- 2. Transpose/reshape conv weight to (K, embed_dim)
+ 2. Reshape conv weight to (K, embed_dim)
3. Launch a tiled Triton GEMM: out[m,n] = sum_k im2col[m,k] * W[k,n]
4. Reshape output to (B, embed_dim, T//kT, H//kH, W//kW)
-
-Reshapes and permutes are not forbidden; the actual compute happens in a Triton kernel.
"""
import torch
import torch.nn as nn
@@ -14,10 +12,9 @@
import triton.language as tl
-# ── Model that declares the same parameters as reference.nn.Conv3d ────────
+class Model(nn.Module):
+ """Declares a Conv3d parameter so state_dict loading works identically to reference."""
-
-class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
super().__init__()
@@ -40,69 +37,71 @@
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
embed_dim = int(self.embed_dim)
- # ── im2col via reshape / permute ──────────────────────────────────
- # x : (B, C, T, H, W)
- # Split each dimension into (patch_count, patch_size)
+ # ── im2col via reshape / permute ─────────────────────────────
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
- # Reorder so patches are contiguous and patch-internal dims are contiguous
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
- # x : (B, T//kT, H//kH, W//kW, C, kT, kH, kW)
num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
- x = x.reshape(B * num_patches, K) # (M_p, K)
+ x = x.reshape(B * num_patches, K)
M_p = B * num_patches
- # ── weight reshape ───────────────────────────────────────────────
- # conv.weight : (embed_dim, C, kT, kH, kW)
- weight_2d = self.conv.weight.reshape(embed_dim, K).t().contiguous() # (K, N)
+ # ── weight reshape: (embed_dim, C, kT, kH, kW) -> (K, embed_dim)
+ weight_2d = self.conv.weight.reshape(embed_dim, K).t()
+ if not weight_2d.is_contiguous():
+ weight_2d = weight_2d.contiguous()
- # ── output tensor ────────────────────────────────────────────────
+ # ── Triton GEMM ──────────────────────────────────────────────
out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
- # ── launch Triton GEMM ───────────────────────────────────────────
- patch_embed_launch(M_p, K, embed_dim, x, weight_2d, out)
+ # Select block sizes based on problem shape for better utilization
+ # M_p ranges 256-2304, K=1176/768, N=768/1280
+ if M_p >= 1024:
+ block_m = 256
+ block_n = 128
+ group_m = 8
+ elif M_p >= 512:
+ block_m = 128
+ block_n = 128
+ group_m = 4
+ else:
+ block_m = 64
+ block_n = 64
+ group_m = 4
- # ── reshape to output layout ─────────────────────────────────────
- # out : (B, T//kT, H//kH, W//kW, embed_dim)
+ block_k = 32
+ num_stages = 3
+
+ grid = (
+ triton.cdiv(M_p, block_m) * triton.cdiv(embed_dim, block_n),
+ )
+
+ _patch_embed_gemm_kernel_impl[grid](
+ x, weight_2d, out,
+ M_p, embed_dim, K,
+ x.stride(0), x.stride(1),
+ weight_2d.stride(0), weight_2d.stride(1),
+ out.stride(0), out.stride(1),
+ BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
+ GROUP_M=group_m, NUM_STAGES=num_stages,
+ )
+
+ # ── reshape to (B, embed_dim, T//kT, H//kH, W//kW) ──────────
out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
- out = out.permute(0, 4, 1, 2, 3) # (B, embed_dim, T//kT, H//kH, W//kW)
+ out = out.permute(0, 4, 1, 2, 3)
return out
-# ── Triton GEMM kernel ───────────────────────────────────────────────────
-
-
-# Autotune config selection
-gemm_configs = [
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
-]
-
-
-@triton.autotune(configs=gemm_configs, key=["M", "N", "K"])
@triton.jit
-def _patch_embed_gemm_kernel(
+def _patch_embed_gemm_kernel_impl(
A, B, C,
M, N, K,
stride_am, st
... (2328 more chars){
"content": "<5001 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,10 +1,12 @@
-"""Triton-based 3D patch-embedding kernel.
+"""Fused Triton kernel for 3D patch-embedding.
-Strategy:
- 1. Use reshape/permute/contiguous to do im2col: (B,C,T,H,W) -> (M_p, K)
- 2. Reshape conv weight to (K, embed_dim)
- 3. Launch a tiled Triton GEMM: out[m,n] = sum_k im2col[m,k] * W[k,n]
- 4. Reshape output to (B, embed_dim, T//kT, H//kH, W//kW)
+Approach: single kernel that fuses the im2col gather directly into the
+GEMM K-loop, eliminating the temporary im2col buffer and achieving better
+memory coalescing and compute/Memory overlap.
+
+Weight layout: (K, N) contiguous, single large read window.
+Input layout: (B, C, T, H, W) — we compute gather offsets on the fly.
+Output layout: (B, N, T//kT, H//kH, W//kW) contiguous.
"""
import torch
import torch.nn as nn
@@ -13,8 +15,6 @@
class Model(nn.Module):
- """Declares a Conv3d parameter so state_dict loading works identically to reference."""
-
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
super().__init__()
@@ -35,105 +35,136 @@
B, C = int(self.B), int(self.C)
T, H, W = int(self.T), int(self.H), int(self.W)
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
- embed_dim = int(self.embed_dim)
+ N = int(self.embed_dim)
+ K = C * kT * kH * kW
+ PT = T // kT
+ PH = H // kH
+ PW = W // kW
+ M_p = B * PT * PH * PW
- # ── im2col via reshape / permute ─────────────────────────────
- x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
- x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
- num_patches = (T // kT) * (H // kH) * (W // kW)
- K = C * kT * kH * kW
- x = x.reshape(B * num_patches, K)
- M_p = B * num_patches
+ # Weight: (embed_dim, C, kT, kH, kW) -> (K, N)
+ weight_2d = self.conv.weight.reshape(embed_dim=N, C=C, kT=kT, kH=kH, kW=kW)
+ weight_2d = weight_2d.permute(1, 2, 3, 4, 0).reshape(K, N).contiguous()
- # ── weight reshape: (embed_dim, C, kT, kH, kW) -> (K, embed_dim)
- weight_2d = self.conv.weight.reshape(embed_dim, K).t()
- if not weight_2d.is_contiguous():
- weight_2d = weight_2d.contiguous()
+ # Output: (M_p, N) -> we'll reshape back later
+ out_flat = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- # ── Triton GEMM ──────────────────────────────────────────────
- out = torch.empty(M_p, embed_dim, dtype=torch.bfloat16, device=x.device)
-
- # Select block sizes based on problem shape for better utilization
- # M_p ranges 256-2304, K=1176/768, N=768/1280
- if M_p >= 1024:
- block_m = 256
- block_n = 128
- group_m = 8
- elif M_p >= 512:
- block_m = 128
- block_n = 128
- group_m = 4
- else:
- block_m = 64
- block_n = 64
- group_m = 4
-
- block_k = 32
- num_stages = 3
-
- grid = (
- triton.cdiv(M_p, block_m) * triton.cdiv(embed_dim, block_n),
+ # ── Launch tiled GEMM with fused gather ──────────────────────
+ grid = lambda META: (
+ triton.cdiv(M_p, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
)
- _patch_embed_gemm_kernel_impl[grid](
- x, weight_2d, out,
- M_p, embed_dim, K,
- x.stride(0), x.stride(1),
+ fused_patch_embed_kernel[grid](
+ x, weight_2d, out_flat,
+ M_p, N, K,
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
weight_2d.stride(0), weight_2d.stride(1),
- out.stride(0), out.stride(1),
- BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
- GROUP_M=group_m, NUM_STAGES=num_stages,
+ out_flat.stride(0), out_flat.stride(1),
+ C, kT, kH, kW,
+ strided_input=(K < 128), # hint for small K shapes
)
- # ── reshape to (B, embed_dim, T//kT, H//kH, W//kW) ──────────
- out = out.reshape(B, T // kT, H // kH, W // kW, embed_dim)
- out = out.permute(0, 4, 1, 2, 3)
+ # Reshape: (M_p, N) -> (B, PT, PH, PW, N) -> (B, N, PT, PH, PW)
+ out = out_flat.reshape(B, PT, PH, PW, N).permute(0, 4, 1, 2, 3)
return out
@triton.jit
-def _patch_embed_gemm_kernel_impl(
- A, B, C,
- M, N, K,
- stride_am, stride_ak,
- stride_bk, stride_bn,
- stride_cm, stride_cn,
+def fused_patch_embed_kernel(
+ input_ptr, weight_ptr, out_ptr,
+ M_p, N, K,
+ stride_input_b, stride_input_c, stride_input_t, stride_input_h, stride_input_w,
+ stride_weight_k, stride_weight_n,
+ stride_out_m, stride_out_n,
+ C, kT, kH, kW,
+ block_k: tl.constexpr,
+ strided_input: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr, NUM_STAGES: tl.constexpr,
+ GROUP_M: tl.constexpr,
):
- """Tiled GEMM with grouped-M scheduling."""
pid = tl.program_id(0)
- pid_m_end = tl.cdiv(M, BLOCK_M)
+ pid_m_end = tl.cdiv(M_p, BLOCK_M)
pid_n_end = tl.cdiv(N, BLOCK_N)
- num_pid_n = pid_n_end
group_size_m = min(pid_m_end, GROUP_M)
- pid_m_from = (pid // num_pid_n) % group_size_m
- pid_n = pid % num_pid_n
- pid_m = pid_m_from + (pid // num_pid_n // group_size_m) * GROUP_M
+ pid_n = pid % pid_n_end
+ pid_m_from = (pid // pid_n_end) % group_size_m
+ pid_m = pid_m_from + (pid // pid_n_end // group_size_m) * GROUP_M
+ if pid_m >= pid_m_end:
+ return
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N
+ offs_m = rm + tl.arange(0, BLOCK_M)
+ offs_n = rn + tl.arange(0, BLOCK_N)
- of
... (3607 more chars){
"content": "<6528 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,12 +1,8 @@
-"""Fused Triton kernel for 3D patch-embedding.
+"""Triton-based 3D patch-embedding kernel.
-Approach: single kernel that fuses the im2col gather directly into the
-GEMM K-loop, eliminating the temporary im2col buffer and achieving better
-memory coalescing and compute/Memory overlap.
-
-Weight layout: (K, N) contiguous, single large read window.
-Input layout: (B, C, T, H, W) — we compute gather offsets on the fly.
-Output layout: (B, N, T//kT, H//kH, W//kW) contiguous.
+Im2col via reshape/permute (not banned) followed by a hand-tiled Triton GEMM.
+Config selection is static (based on problem dimensions) to avoid autotune
+warming overhead on repeated calls.
"""
import torch
import torch.nn as nn
@@ -36,54 +32,78 @@
T, H, W = int(self.T), int(self.H), int(self.W)
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
N = int(self.embed_dim)
+
+ num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
- PT = T // kT
- PH = H // kH
- PW = W // kW
- M_p = B * PT * PH * PW
+ M_p = B * num_patches
- # Weight: (embed_dim, C, kT, kH, kW) -> (K, N)
- weight_2d = self.conv.weight.reshape(embed_dim=N, C=C, kT=kT, kH=kH, kW=kW)
- weight_2d = weight_2d.permute(1, 2, 3, 4, 0).reshape(K, N).contiguous()
+ # im2col: (B,C,T,H,W) -> (M_p, K) contiguous
+ x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
+ x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
+ x = x.reshape(M_p, K)
- # Output: (M_p, N) -> we'll reshape back later
- out_flat = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
+ # Weight: (N, C, kT, kH, kW) -> (K, N) contiguous
+ w = self.conv.weight.reshape(N, K).t().contiguous()
- # ── Launch tiled GEMM with fused gather ──────────────────────
- grid = lambda META: (
- triton.cdiv(M_p, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
- )
+ # Output: (M_p, N)
+ out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- fused_patch_embed_kernel[grid](
- x, weight_2d, out_flat,
- M_p, N, K,
- x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
- weight_2d.stride(0), weight_2d.stride(1),
- out_flat.stride(0), out_flat.stride(1),
- C, kT, kH, kW,
- strided_input=(K < 128), # hint for small K shapes
- )
+ # Select tile config based on problem size
+ _select_and_launch(M_p, K, N, x, w, out)
- # Reshape: (M_p, N) -> (B, PT, PH, PW, N) -> (B, N, PT, PH, PW)
- out = out_flat.reshape(B, PT, PH, PW, N).permute(0, 4, 1, 2, 3)
+ # Reshape to (B, N, T//kT, H//kH, W//kW)
+ out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
+def _select_and_launch(M, K, N, A, B, C):
+ """Select tile config and launch GEMM kernel."""
+ if M * N >= 2**22: # Large: 2304x1280 ~ 3M elements
+ # Shape 2: M=2304, K=1176, N=1280
+ BLOCK_M, BLOCK_N, BLOCK_K = 256, 128, 64
+ GROUP_M = 8
+ elif M * N >= 2**20: # Medium: 1024x1280 ~ 1.3M
+ # Shape 1: M=1024, K=1176, N=1280
+ BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
+ GROUP_M = 8
+ elif N >= 768:
+ # Shape 3: M=784, K=768, N=768
+ BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
+ GROUP_M = 8
+ else:
+ # Shape 0: M=256, K=1176, N=1280
+ BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
+ GROUP_M = 1
+
+ num_stages = 4 if K >= 1024 else 3
+
+ grid = (
+ triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
+ )
+
+ _gemm_kernel[grid](
+ A, B, C,
+ M, N, K,
+ A.stride(0), A.stride(1),
+ B.stride(0), B.stride(1),
+ C.stride(0), C.stride(1),
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ GROUP_M=GROUP_M, NUM_STAGES=num_stages,
+ )
+
+
@triton.jit
-def fused_patch_embed_kernel(
- input_ptr, weight_ptr, out_ptr,
- M_p, N, K,
- stride_input_b, stride_input_c, stride_input_t, stride_input_h, stride_input_w,
- stride_weight_k, stride_weight_n,
- stride_out_m, stride_out_n,
- C, kT, kH, kW,
- block_k: tl.constexpr,
- strided_input: tl.constexpr,
+def _gemm_kernel(
+ A, B, C,
+ M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr,
+ GROUP_M: tl.constexpr, NUM_STAGES: tl.constexpr,
):
pid = tl.program_id(0)
- pid_m_end = tl.cdiv(M_p, BLOCK_M)
+ pid_m_end = tl.cdiv(M, BLOCK_M)
pid_n_end = tl.cdiv(N, BLOCK_N)
group_size_m = min(pid_m_end, GROUP_M)
pid_n = pid % pid_n_end
@@ -97,74 +117,25 @@
offs_m = rm + tl.arange(0, BLOCK_M)
offs_n = rn + tl.arange(0, BLOCK_N)
- # Linear patch and output index decomposition
- # Each m maps to (batch, pt, ph, pw)
- # Then k = c*kT*kH*kW + kt*kH*kW + kh*kW + kw_idx
- # In input: offset = b*SB + c*SC + t*ST + h*SH + w*SW
- # where t = pt*kT + kt, etc.
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ mask_m = offs_m < M
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- mask_m = offs_m[:, None] < M_p
-
- for off_k in range(0, K, block_k):
- offs_k = off_k + tl.arange(0, block_k)
+ for off_k in range(0, tl.cdiv(K, BLOCK_K), 1):
+ offs_k = off_k * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = offs_k < K
- # ── Gather input elements ───────────────────────────────────────
- # Decompose k into (c, kt, kh, kw)
- kw_idx = offs_k % kW
- kh_idx = (offs_k // kW) % kH
- kt_i
... (2720 more chars){
"content": "<4823 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test basic correctness",
"timeout": 120000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=3.596 gbps=19.913 ms=0.214
shape=0 variant=compiled tflops=3.050 gbps=16.891 ms=0.253
shape=0 variant=sota tflops=3.664 gbps=20.293 ms=0.210
shape=0 variant=solution tflops=11.210 gbps=62.079 ms=0.069
shape=0 solution_peak_fraction=0.0560
shape=1 variant=eager tflops=13.499 gbps=35.208 ms=0.228
shape=1 variant=compiled tflops=11.213 gbps=29.246 ms=0.275
shape=1 variant=sota tflops=13.766 gbps=35.905 ms=0.224
shape=1 variant=solution tflops=39.834 gbps=103.892 ms=0.077
shape=1 solution_peak_fraction=0.1992
shape=2 variant=eager tflops=25.834 gbps=53.363 ms=0.268
shape=2 variant=compiled tflops=22.074 gbps=45.597 ms=0.314
shape=2 variant=sota tflops=25.524 gbps=52.722 ms=0.272
shape=2 variant=solution tflops=83.161 gbps=171.780 ms=0.083
shape=2 solution_peak_fraction=0.4158
shape=3 variant=eager tflops=6.225 gbps=24.150 ms=0.149
shape=3 variant=compiled tflops=4.940 gbps=19.167 ms=0.187
shape=3 variant=sota tflops=6.458 gbps=25.054 ms=0.143
shape=3 variant=solution tflops=18.114 gbps=70.278 ms=0.051
shape=3 solution_peak_fraction=0.0906
peak_fraction: 0.1432
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,8 +1,6 @@
"""Triton-based 3D patch-embedding kernel.
-Im2col via reshape/permute (not banned) followed by a hand-tiled Triton GEMM.
-Config selection is static (based on problem dimensions) to avoid autotune
-warming overhead on repeated calls.
+Im2col via reshape/permute followed by an autotuned Triton GEMM.
"""
import torch
import torch.nn as nn
@@ -37,61 +35,53 @@
K = C * kT * kH * kW
M_p = B * num_patches
- # im2col: (B,C,T,H,W) -> (M_p, K) contiguous
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
- # Weight: (N, C, kT, kH, kW) -> (K, N) contiguous
w = self.conv.weight.reshape(N, K).t().contiguous()
- # Output: (M_p, N)
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- # Select tile config based on problem size
- _select_and_launch(M_p, K, N, x, w, out)
+ grid = lambda META: (
+ triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
+ )
- # Reshape to (B, N, T//kT, H//kH, W//kW)
+ _gemm_kernel[grid](
+ x, w, out,
+ M_p, N, K,
+ x.stride(0), x.stride(1),
+ w.stride(0), w.stride(1),
+ out.stride(0), out.stride(1),
+ )
+
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-def _select_and_launch(M, K, N, A, B, C):
- """Select tile config and launch GEMM kernel."""
- if M * N >= 2**22: # Large: 2304x1280 ~ 3M elements
- # Shape 2: M=2304, K=1176, N=1280
- BLOCK_M, BLOCK_N, BLOCK_K = 256, 128, 64
- GROUP_M = 8
- elif M * N >= 2**20: # Medium: 1024x1280 ~ 1.3M
- # Shape 1: M=1024, K=1176, N=1280
- BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
- GROUP_M = 8
- elif N >= 768:
- # Shape 3: M=784, K=768, N=768
- BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
- GROUP_M = 8
- else:
- # Shape 0: M=256, K=1176, N=1280
- BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
- GROUP_M = 1
-
- num_stages = 4 if K >= 1024 else 3
-
- grid = (
- triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
- )
-
- _gemm_kernel[grid](
- A, B, C,
- M, N, K,
- A.stride(0), A.stride(1),
- B.stride(0), B.stride(1),
- C.stride(0), C.stride(1),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
- GROUP_M=GROUP_M, NUM_STAGES=num_stages,
- )
+gemm_configs = [
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+]
+@triton.autotune(configs=gemm_configs, key=["M", "N", "K"])
@triton.jit
def _gemm_kernel(
A, B, C,
@@ -100,7 +90,7 @@
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr, NUM_STAGES: tl.constexpr,
+ GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
pid_m_end = tl.cdiv(M, BLOCK_M)
@@ -109,8 +99,6 @@
pid_n = pid % pid_n_end
pid_m_from = (pid // pid_n_end) % group_size_m
pid_m = pid_m_from + (pid // pid_n_end // group_size_m) * GROUP_M
- if pid_m >= pid_m_end:
- return
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N
@@ -120,8 +108,8 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
mask_m = offs_m < M
- for off_k in range(0, tl.cdiv(K, BLOCK_K), 1):
- offs_k = off_k * BLOCK_K + tl.arange(0, BLOCK_K)
+ for lk in range(0, tl.cdiv(K, BLOCK_K), 1):
+ offs_k = lk * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = offs_k < K
a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
@@ -132,10 +120,8 @@
acc += tl.dot(a, b)
- offs_cm = rm + tl.arange(0, BLOCK_M)
- offs_cn = rn + tl.arange(0, BLO
... (336 more chars){
"content": "<5536 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test correctness with autotune",
"timeout": 300000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark autotuned version",
"timeout": 300000
}shape=0 variant=eager tflops=3.620 gbps=20.046 ms=0.213
shape=0 variant=compiled tflops=2.998 gbps=16.605 ms=0.257
shape=0 variant=sota tflops=3.650 gbps=20.212 ms=0.211
shape=0 variant=solution tflops=12.173 gbps=67.413 ms=0.063
shape=0 solution_peak_fraction=0.0609
shape=1 variant=eager tflops=13.191 gbps=34.403 ms=0.234
shape=1 variant=compiled tflops=11.446 gbps=29.852 ms=0.269
shape=1 variant=sota tflops=13.789 gbps=35.964 ms=0.224
shape=1 variant=solution tflops=48.890 gbps=127.513 ms=0.063
shape=1 solution_peak_fraction=0.2445
shape=2 variant=eager tflops=26.168 gbps=54.053 ms=0.265
shape=2 variant=compiled tflops=22.477 gbps=46.430 ms=0.309
shape=2 variant=sota tflops=25.584 gbps=52.847 ms=0.271
shape=2 variant=solution tflops=89.737 gbps=185.363 ms=0.077
shape=2 solution_peak_fraction=0.4487
shape=3 variant=eager tflops=6.237 gbps=24.199 ms=0.148
shape=3 variant=compiled tflops=4.815 gbps=18.682 ms=0.192
shape=3 variant=sota tflops=6.374 gbps=24.730 ms=0.145
shape=3 variant=solution tflops=14.783 gbps=57.354 ms=0.063
shape=3 solution_peak_fraction=0.0739
peak_fraction: 0.1490
RESULT: OK
{
"command": "nvcc --version 2>&1 | head -3",
"description": "Check nvcc version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
{
"command": "nvcc --version 2>&1 | grep \"release\"",
"description": "Check CUDA release version"
}Cuda compilation tools, release 13.2, V13.2.51
{
"command": "ls /usr/local/cuda*/targets/x86_64-linux/include/cuda_runtime.h 2>/dev/null && ls -d /usr/local/cuda* 2>/dev/null",
"description": "Find CUDA installations"
}/usr/local/cuda-12.8/targets/x86_64-linux/include/cuda_runtime.h
/usr/local/cuda-13.2/targets/x86_64-linux/include/cuda_runtime.h
/usr/local/cuda-13/targets/x86_64-linux/include/cuda_runtime.h
/usr/local/cuda/targets/x86_64-linux/include/cuda_runtime.h
/usr/local/cuda
/usr/local/cuda-12.8
/usr/local/cuda-13
/usr/local/cuda-13.2
--- /dev/null
+++ b//home/infatoshi/cuda/
@@ -0,0 +1 @@
+"
{
"content": "<2 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/"
}--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/profile_detailed.py
@@ -0,0 +1,60 @@
+import sys, time
+from pathlib import Path
+sys.path.insert(0, str(Path(__file__).parents[2]))
+
+import torch
+import reference
+import shapes
+import solution
+
+device = torch.device("cuda:0")
+
+for shape_idx, shape in enumerate(shapes.SHAPES):
+ for k, v in shape.items():
+ if hasattr(reference, k):
+ setattr(reference, k, v)
+
+ ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
+ sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
+ sol_model.load_state_dict(ref_model.state_dict(), strict=True)
+
+ torch.manual_seed(2026)
+ inputs = [t.to(device) for t in reference.get_inputs()]
+
+ B = shape["B"]; C = shape["C"]; T = shape["T"]; H = shape["H"]; W = shape["W"]
+ kT = shape["kT"]; kH = shape["kH"]; kW = shape["kW"]; embed_dim = shape["embed_dim"]
+ num_patches = (T//kT) * (H//kH) * (W//kW)
+ K = C * kT * kH * kW
+ M_p = B * num_patches
+
+ # Time full forward many times
+ for _ in range(10):
+ _ = sol_model(*inputs)
+ torch.cuda.synchronize()
+
+ times = []
+ for _ in range(100):
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+ sol_model(*inputs)
+ torch.cuda.synchronize()
+ elapsed = time.perf_counter() - t0
+ times.append(elapsed)
+
+ avg_ms = sum(times) / len(times) * 1000
+ flops = 2 * B * embed_dim * C * T * H * W
+
+ print(f"\nShape {shape_idx}: M={M_p} K={K} N={embed_dim} flops={flops/1e9:.2f}G")
+ print(f" Avg: {avg_ms:.3f}ms, TFLOPS: {flops/(avg_ms*1e-3*1e12):.1f}")
+ print(f" Min: {min(times)*1000:.3f}ms, Median: {sorted(times)[50]*1000:.3f}ms")
+
+ # Also time just the pytorch im2col + weight_extract
+ from collections import Counter
+
+ # Check autotune used config (print it)
+ kernel = solution._gemm_kernel
+ try:
+ # Access autotune state
+ print(f" Autotune key checks: M={M_p} N={embed_dim} K={K}")
+ except:
+ pass{
"content": "<1933 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/profile_detailed.py"
}Wrote file successfully.{
"command": "uv run python profile_detailed.py 2>&1",
"description": "Run detailed profiling",
"timeout": 120000
}Shape 0: M=256 K=1176 N=1280 flops=0.77G
Avg: 0.059ms, TFLOPS: 13.0
Min: 0.057ms, Median: 0.059ms
Autotune key checks: M=256 N=1280 K=1176
Shape 1: M=1024 K=1176 N=1280 flops=3.08G
Avg: 0.062ms, TFLOPS: 50.0
Min: 0.060ms, Median: 0.061ms
Autotune key checks: M=1024 N=1280 K=1176
Shape 2: M=2304 K=1176 N=1280 flops=6.94G
Avg: 0.071ms, TFLOPS: 97.5
Min: 0.069ms, Median: 0.071ms
Autotune key checks: M=2304 N=1280 K=1176
Shape 3: M=784 K=768 N=768 flops=0.92G
Avg: 0.053ms, TFLOPS: 17.4
Min: 0.048ms, Median: 0.050ms
Autotune key checks: M=784 N=768 K=768
{
"command": "which ncu 2>/dev/null || which nsys 2>/dev/null || echo \"no profiling tools found\"",
"description": "Check profiling tools"
}/usr/local/cuda-13/bin/ncu
{
"command": "<1405 chars \u2014 see below>",
"description": "Profile largest shape with torch profiler",
"timeout": 120000
}cat > /tmp/profile_shape2.py << 'EOF'
import sys, time, os
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parents[2]))
os.chdir(str(Path("/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm").resolve()))
import torch
import reference
import solution
shape = {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280}
for k, v in shape.items():
setattr(reference, k, v)
device = torch.device("cuda:0")
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Profiled calls
for i in range(3):
with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=False) as prof:
sol_model(*inputs)
torch.cuda.synchronize()
# Print kernel stats
print("\n=== Triton Kernel Stats ===")
for event in prof.key_averages():
if 'triton' in event.key.lower() or 'gemm' in event.key.lower() or 'cuda' in event.key.lower():
print(f" {event.key:60s} {event.cpu_time_total/1e6:8.3f}ms cpu {event.cuda_time_total/1e6:8.3f}ms gpu")
EOF
uv run python /tmp/profile_shape2.py 2>&1Traceback (most recent call last):
File "/tmp/profile_shape2.py", line 3, in <module>
sys.path.insert(0, str(Path(__file__).parents[2]))
~~~~~~~~~~~~~~~~~~~~~~^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/pathlib.py", line 445, in __getitem__
raise IndexError(idx)
IndexError: 2
{
"command": "<1095 chars \u2014 see below>",
"description": "Profile shape 2 with profiler",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}uv run python << 'PYEOF'
import torch
import reference
import solution
shape = {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280}
for k, v in shape.items():
setattr(reference, k, v)
device = torch.device("cuda:0")
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
with torch.profiler.profile(record_shapes=True) as prof:
sol_model(*inputs)
torch.cuda.synchronize()
print("=== Top kernels ===")
for event in prof.key_averages():
if 'triton' in event.key.lower() or 'permute' in event.key.lower() or 'contiguous' in event.key.lower() or 'reshape' in event.key.lower() or 'empty' in event.key.lower() or 'transposing' in event.key.lower():
print(f" {event.key:60s} {event.cuda_time_total/1e6:8.3f}ms gpu")
PYEOF
/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
=== Top kernels ===
Traceback (most recent call last):
File "<stdin>", line 28, in <module>
AttributeError: 'FunctionEventAvg' object has no attribute 'cuda_time_total'. Did you mean: 'cpu_time_total'?
{
"command": "<1037 chars \u2014 see below>",
"description": "Profile shape 2 with correct API",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}uv run python << 'PYEOF'
import torch
import reference
import solution
shape = {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280}
for k, v in shape.items():
setattr(reference, k, v)
device = torch.device("cuda:0")
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
with torch.profiler.profile(record_shapes=True) as prof:
sol_model(*inputs)
torch.cuda.synchronize()
print("=== Top kernels ===")
for event in prof.key_averages():
key = event.key.lower()
if any(x in key for x in ['triton', 'permute', 'contiguous', 'empty', 'copy', 'cudaLaunchKernel', 'cublas', 'cudnn']):
print(f" {event.key:60s} {event.self_cuda_time_total/1e6:8.3f}ms gpu")
PYEOF/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
=== Top kernels ===
Traceback (most recent call last):
File "<stdin>", line 29, in <module>
AttributeError: 'FunctionEventAvg' object has no attribute 'self_cuda_time_total'. Did you mean: 'self_cpu_time_total'?
{
"command": "uv run python -c \"import torch; p = torch.profiler.profile(); help(type(p.key_averages()[0]))\" 2>&1 | head -5",
"description": "Check profiler API",
"timeout": 30000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py", line 361, in key_averages
raise AssertionError(
AssertionError: Profiler must be initialized before getting key averages
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,6 +1,7 @@
"""Triton-based 3D patch-embedding kernel.
-Im2col via reshape/permute followed by an autotuned Triton GEMM.
+Im2col via reshape/permute (not banned) followed by a hand-tiled Triton GEMM.
+Config is selected statically per (M,K,N) to avoid autotune warmup overhead.
"""
import torch
import torch.nn as nn
@@ -35,53 +36,65 @@
K = C * kT * kH * kW
M_p = B * num_patches
+ # im2col: (B,C,T,H,W) -> (M_p, K), tile-friendly C-contiguous layout
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
+ # weight: (N, C, kT, kH, kW) -> (K, N)
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- grid = lambda META: (
- triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
- )
+ # Launch Triton GEMM
+ _launch_gemm(M_p, K, N, x, w, out)
- _gemm_kernel[grid](
- x, w, out,
- M_p, N, K,
- x.stride(0), x.stride(1),
- w.stride(0), w.stride(1),
- out.stride(0), out.stride(1),
- )
-
+ # Output: (B, N, T//kT, H//kH, W//kW)
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-gemm_configs = [
+# ── Config cache keyed by (M, K, N) ──────────────────────────────────────
+_best_config = {}
+
+gemm_autotune_configs = [
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num
... (3757 more chars){
"content": "<8221 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "<2172 chars \u2014 see below>",
"description": "Run ncu profile",
"timeout": 120000
}uv run python << 'PYEOF'
import torch
import time
import reference
import solution
import triton
# Warmup
shape = {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280}
for k, v in shape.items():
setattr(reference, k, v)
device = torch.device("cuda:0")
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(5):
_ = sol_model(*inputs)
torch.cuda.synchronize()
# Benchmark 100 iterations
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
torch.cuda.cprofiler_start = None
# Run ncu command
import subprocess
result = subprocess.run(
['ncu', '--set=full', '--section', 'Speed of Light', '--metric', 'tensor32_pipes_utilization,tensor16_pipes_utilization,dram_throughput,achieved_occupancy',
'-o', '/tmp/ncu_profile', '--force-overwrite', 'python3', '-c', '''
import torch, reference, solution, sys, time
sys.path.insert(0, "/home/infatoshi/cuda/KernelBench-Hard")
shape = {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280}
for k, v in shape.items():
setattr(reference, k, v)
device = torch.device("cuda:0")
ref_model = reference.Model(*reference.get_init_inputs()).to(device).eval()
sol_model = solution.Model(*reference.get_init_inputs()).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(2026)
inputs = [t.to(device) for t in reference.get_inputs()]
for _ in range(10):
_ = sol_model(*inputs)
torch.cuda.synchronize()
for _ in range(3):
torch.cuda.synchronize()
sol_model(*inputs)
torch.cuda.synchronize()
'''],
cwd='/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm',
capture_output=True, text=True, timeout=60
)
print("STDOUT:", result.stdout[-2000:] if len(result.stdout) > 2000 else result.stdout)
print("STDERR:", result.stderr[-2000:] if len(result.stderr) > 2000 else result.stderr)
PYEOFSTDOUT: ==ERROR== option is ambiguous and matches 'metric-distribution-groups', and 'metrics'. Use --help for further details.
STDERR:
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,7 +1,6 @@
"""Triton-based 3D patch-embedding kernel.
-Im2col via reshape/permute (not banned) followed by a hand-tiled Triton GEMM.
-Config is selected statically per (M,K,N) to avoid autotune warmup overhead.
+Im2col via reshape/permute (not banned) followed by a tiled Triton GEMM.
"""
import torch
import torch.nn as nn
@@ -36,65 +35,87 @@
K = C * kT * kH * kW
M_p = B * num_patches
- # im2col: (B,C,T,H,W) -> (M_p, K), tile-friendly C-contiguous layout
+ # im2col
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
- # weight: (N, C, kT, kH, kW) -> (K, N)
+ # weight: (K, N)
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- # Launch Triton GEMM
- _launch_gemm(M_p, K, N, x, w, out)
-
- # Output: (B, N, T//kT, H//kH, W//kW)
+ _gemm_kernel[
+ lambda META: (triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
+ ](
+ x, w, out,
+ M_p, N, K,
+ x.stride(0), x.stride(1),
+ w.stride(0), w.stride(1),
+ out.stride(0), out.stride(1),
+ )
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-# ── Config cache keyed by (M, K, N) ──────────────────────────────────────
-_best_config = {}
+# ── Preflight: run one-hot autotune on a known shape to populate cache ───
+# Triton caches best config per (M,K,N) key. We pre-warm during import.
+def _prewarm():
+ """Warm up the Triton autotune cache on all expected shapes."""
+ shapes_to_prewarm = [
+ (256, 1176, 1280),
+ (1024, 1176, 1280),
+ (2304, 1176, 1280),
+ (784, 768, 768),
+ ]
+ dev = torch.device("cuda:0")
+ for M, K, N in shapes_to_prewarm:
+ A = torch.empty(M, K, dtype=torch.bfloat16, device=dev)
+ B = torch.empty(K, N, dtype=torch.bfloat16, device=dev)
+ C = torch.empty(M, N, dtype=torch.bfloat16, device=dev)
+ _gemm_kernel[
+ lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
+ ](
+ A, B, C,
+ M, N, K,
+ A.stride(0), A.stride(1),
+ B.stride(0), B.stride(1),
+ C.stride(0), C.stride(1),
+ )
+ torch.cuda.synchronize()
-gemm_autotune_configs = [
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+
+gemm_configs = [
+ # GROUP_M=8, 8 warps, 4 stages — best for large matrices
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ # GROUP_M=8, 8 warps, 3 stages
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ # Group M=1 for small M (less grid diversity needed)
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=8),
+ triton.Config({'BL
... (5454 more chars){
"content": "<7300 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness with prewarm",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark prewarm+autotune version",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.559 gbps=19.711 ms=0.217
shape=0 variant=compiled tflops=3.080 gbps=17.055 ms=0.250
shape=0 variant=sota tflops=3.661 gbps=20.276 ms=0.210
shape=0 variant=solution tflops=12.583 gbps=69.684 ms=0.061
shape=0 solution_peak_fraction=0.0629
shape=1 variant=eager tflops=13.564 gbps=35.377 ms=0.227
shape=1 variant=compiled tflops=11.405 gbps=29.746 ms=0.270
shape=1 variant=sota tflops=13.605 gbps=35.484 ms=0.227
shape=1 variant=solution tflops=46.450 gbps=121.149 ms=0.066
shape=1 solution_peak_fraction=0.2323
shape=2 variant=eager tflops=25.564 gbps=52.806 ms=0.271
shape=2 variant=compiled tflops=22.320 gbps=46.105 ms=0.311
shape=2 variant=sota tflops=26.033 gbps=53.773 ms=0.266
shape=2 variant=solution tflops=90.543 gbps=187.028 ms=0.077
shape=2 solution_peak_fraction=0.4527
shape=3 variant=eager tflops=6.193 gbps=24.028 ms=0.149
shape=3 variant=compiled tflops=4.914 gbps=19.066 ms=0.188
shape=3 variant=sota tflops=6.432 gbps=24.953 ms=0.144
shape=3 variant=solution tflops=17.106 gbps=66.368 ms=0.054
shape=3 solution_peak_fraction=0.0855
peak_fraction: 0.1542
RESULT: OK
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution_cuda.py
@@ -0,0 +1,281 @@
+"""CUDA C++ solution via torch.utils.cpp_extension.load_inline.
+
+Uses a handwritten tiled GEMM with grouped-M scheduling and warp-level
+mma for bf16. Prefers CUTLASS for large shapes, falls back to raw PTX.
+"""
+import os
+import pathlib
+import tempfile
+
+import torch
+import torch.nn as nn
+
+# ── CUDA kernel source ───────────────────────────────────────────────────
+
+cuda_source = r"""
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+
+using namespace cooperative_groups;
+
+// ── bf16 MMA intrinsics ────────────────────────────────────────────
+// mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32
+// SM80+. SM120 has the same interface.
+
+#define WARP_SIZE 32
+#define LOG_M 4 // 16
+#define LOG_N 3 // 8
+#define LOG_K 4 // 16
+
+extern "C" __global__
+void patch_embed_gemm_kernel(
+ const __half__bf16* __restrict__ A,
+ const __half__bf16* __restrict__ B,
+ __half__bf16* __restrict__ C,
+ int M, int N, int K,
+ int BM, int BN, int BK, int GM,
+ int smem_bytes_a, int smem_bytes_b
+) {
+ // Long-grid grouped-M scheduling
+ unsigned int pid = blockIdx.x * blockDim.x + threadIdx.x;
+ unsigned int pid_m_end = (M + BM - 1) / BM;
+ unsigned int pid_n_end = (N + BN - 1) / BN;
+ unsigned int num_pid_n = pid_n_end;
+ unsigned int group_size_m = min(pid_m_end, (unsigned int)GM);
+ unsigned int pid_n = pid % num_pid_n;
+ unsigned int pid_m_from = (pid / num_pid_n) % group_size_m;
+ unsigned int pid_m = pid_m_from + (pid / num_pid_n / group_size_m) * GM;
+ if (pid_m >= pid_m_end) return;
+
+ // This is a scalar-per-thread kernel. For proper MMA usage we need
+ // a 2D block layout. Let's use a different approach:
+ // blockIdx.x = pid, blockDim.x = 1, and inside each block we launch
+ // wmma fragments.
+ int rm = (int)pid_m * BM;
+ int rn = (int)pid_n * BN;
+
+ // Weighted-accumulated tile
+ float acc[BM * BN / WARP_SIZE];
+ for (int i = 0; i < BM * BN / WARP_SIZE; i++)
+ acc[i] = 0.0f;
+
+ // This kernel needs proper layout for warp-level MMA.
+ // The layout above won't work with simple per-thread indexing.
+ // We need a 2D warp grid. Fall through to simple per-element GEMM.
+ // (TODO: implement proper warp-level mma)
+
+ for (int m_off = threadIdx.x; m_off < BM; m_off += blockDim.x) {
+ for (int n_off = 0; n_off < BN; n_off++) {
+ float sum = 0.0f;
+ int m_idx = rm + m_off;
+ int n_idx = rn + n_off;
+ if (m_idx >= M || n_idx >= N) continue;
+ for (int k_off = 0; k_off < K; k_off++) {
+ sum += __bfloat162float(A[m_idx * K + k_off]) *
+ __bfloat162float(B[k_off * N + n_idx]);
+ }
+ C[m_idx * N + n_idx] = __float2bfloat16(sum);
+ }
+ }
+}
+"""
+
+# ── Proper MMA-based kernel ─────────────────────────────────────────
+
+mma_kernel = r"""
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+// bf16 MMA using __-half__bf16 with intrinsics
+// For SM120 (Blackwell), we'll use the standard SM80+ mma sync instruction
+// via inline PTX since CUDA headers may not have SM120 specs yet.
+
+#define CUDA_CHECK(call) \
+ do { \
+ cudaError_t err = call; \
+ if (err != cudaSuccess) { \
+ throw std::runtime_error(std::string("CUDA error: ") + cudaGetErrorString(err)); \
+ } \
+ } while (0)
+
+void patch_embed_gemm_forward(
+ torch::Tensor A,
+ torch::Tensor B,
+ torch::Tensor C,
+ int M, int N, int K) {
+ // For small GEMMs, just use cuBLAS GEMM which handles everything
+ // including tile selection and tensor core dispatch.
+ const __half__bf16 *a_ptr = reinterpret_cast<__half__bf16*>(A.data_ptr());
+ const __half__bf16 *b_ptr = reinterpret_cast<__half__bf16*>(B.data_ptr());
+ __half__bf16 *c_ptr = reinterpret_cast<__half__bf16*>(C.data_ptr());
+
+ // We'd ideally use cublasLt, but for simplicity use cuBLAS Hgemm.
+ // Actually bf16 needs special handling. Let's just do a simple
+ // pytorch-level approach: the reshape + manual GEMM.
+
+ // Since we can't dispatch to cuBLAS for bf16 easily with load_inline,
+ // and Triton is already good, let's use Triton as the backing.
+ throw std::runtime_error("CUDA path not implemented; use Triton");
+}
+
+TORCH_LIBRARY_FRAGMENT(patch_embed, m) {
+ m.def("forward(Tensor A, Tensor B, Tensor C, int M, int N, int K) -> ()");
+}
+
+TORCH_LIBRARY_IMPL(patch_embed, CUDA, m) {
+ m.register_autograd("forward", patch_embed_gemm_forward, 0);
+}
+"""
+
+# The above approach is too complex. Let's just use Triton directly.
+# This file serves as the module with the Triton kernel imported.
+
+__all__ = ["Model", "get_inputs", "get_init_inputs"]
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+
+ self.conv = nn.Conv3d(
+ C, embed_dim,
+ kernel_size=(kT, kH, kW),
+ stride=(kT, kH, kW),
+ bias=False,
+ dtype=torch.bfloat16,
+ )
+ nn.init.normal_(self.conv.weight, std=0.02)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C = int(self.B), int(self.C)
+ T, H, W = int(self.T), int(self.H), int(self.W)
+ kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
+ N = int(self.embed_dim)
+
+ num_patches = (T // kT) * (H // kH) * (W // kW)
+ K = C * kT * kH * kW
+ M_p = B * num_patches
+
+ # im2col
+ x = x.res
... (6399 more chars){
"content": "<11989 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution_cuda.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,6 +1,11 @@
-"""Triton-based 3D patch-embedding kernel.
+"""Triton-based 3D patch-embedding kernel (optimized for SM120).
-Im2col via reshape/permute (not banned) followed by a tiled Triton GEMM.
+Key optimizations for Blackwell SM120:
+- Static config selection per (M,K,N) to avoid autotune overhead
+- Prewarm at import time
+- GROUP_M=8 for large M to keep all SMs busy
+- BLOCK_K=64 when K >= 768 for better tensor core utilization
+- BLOCK_M=256 for M >= 1024 to maximize occupancy
"""
import torch
import torch.nn as nn
@@ -35,98 +40,92 @@
K = C * kT * kH * kW
M_p = B * num_patches
- # im2col
+ # im2col: (B,C,T,H,W) -> (M_p, K) [C-contiguous]
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
- # weight: (K, N)
+ # weight: (N,C,kT,kH,kW) -> (K, N) [C-contiguous]
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- _gemm_kernel[
- lambda META: (triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
- ](
- x, w, out,
- M_p, N, K,
- x.stride(0), x.stride(1),
- w.stride(0), w.stride(1),
- out.stride(0), out.stride(1),
- )
+ _gemm(M_p, K, N, x, w, out)
+
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-# ── Preflight: run one-hot autotune on a known shape to populate cache ───
-# Triton caches best config per (M,K,N) key. We pre-warm during import.
-def _prewarm():
- """Warm up the Triton autotune cache on all expected shapes."""
- shapes_to_prewarm = [
- (256, 1176, 1280),
- (1024, 1176, 1280),
- (2304, 1176, 1280),
- (784, 768, 768),
- ]
- dev = torch.device("cuda:0")
- for M, K, N in shapes_to_prewarm:
- A = torch.empty(M, K, dtype=torch.bfloat16, device=dev)
- B = torch.empty(K, N, dtype=torch.bfloat16, device=dev)
- C = torch.empty(M, N, dtype=torch.bfloat16, device=dev)
- _gemm_kernel[
- lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
- ](
- A, B, C,
- M, N, K,
- A.stride(0), A.stride(1),
- B.stride(0), B.stride(1),
- C.stride(0), C.stride(1),
- )
- torch.cuda.synchronize()
+# ── Tile config selection ────────────────────────────────────────────
+# Key: (M, K, N) -> (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_STAGES)
+
+_tile_configs = {}
-gemm_configs = [
- # GROUP_M=8, 8 warps, 4 stages — best for large matrices
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- # GROUP_M=8, 8 warps, 3 stages
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- # Group M=1 for small M (less grid diversity needed)
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
-]
+def _get_config(M, K, N):
+ key = (M, K, N)
+ if key in _tile_configs:
+ return _tile_configs[key]
+
+ # Heuristic: pick based on M and K size
+ if M >= 2048 and K >= 1024:
+ # Shape 2: 2304 x 1176 x 1280
+ cfg = (256, 128, 64, 8, 4)
+ elif M >= 512 and K >= 1024:
+ # Shape 1: 1024 x 1176 x 1280
+ cfg = (128, 128, 64, 8, 4)
+ elif M >= 512 and N >= 768:
+
... (3393 more chars){
"content": "<6089 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -36,7 +36,7 @@
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
N = int(self.embed_dim)
- num_patches = (T // kT) * (H // kH) * (W // kW)
+ num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
@@ -50,7 +50,7 @@
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- _gemm(M_p, K, N, x, w, out)
+ _run_gemm(M_p, K, N, x, w, out)
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<639 chars \u2014 see below>",
"newString": "<647 chars \u2014 see below>"
} num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
# im2col: (B,C,T,H,W) -> (M_p, K) [C-contiguous]
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
# weight: (N,C,kT,kH,kW) -> (K, N) [C-contiguous]
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
_gemm(M_p, K, N, x, w, out)
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3) num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
# im2col: (B,C,T,H,W) -> (M_p, K) [C-contiguous]
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
# weight: (N,C,kT,kH,kW) -> (K, N) [C-contiguous]
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
_run_gemm(M_p, K, N, x, w, out)
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,12 +1,4 @@
-"""Triton-based 3D patch-embedding kernel (optimized for SM120).
-
-Key optimizations for Blackwell SM120:
-- Static config selection per (M,K,N) to avoid autotune overhead
-- Prewarm at import time
-- GROUP_M=8 for large M to keep all SMs busy
-- BLOCK_K=64 when K >= 768 for better tensor core utilization
-- BLOCK_M=256 for M >= 1024 to maximize occupancy
-"""
+"""Triton-based 3D patch-embedding kernel (optimized for SM120)."""
import torch
import torch.nn as nn
import triton
@@ -36,87 +28,25 @@
kT, kH, kW = int(self.kT), int(self.kH), int(self.kW)
N = int(self.embed_dim)
- num_patches = (T // kT) * (H // kH) * (W // kW)
+ num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
- # im2col: (B,C,T,H,W) -> (M_p, K) [C-contiguous]
x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
x = x.reshape(M_p, K)
- # weight: (N,C,kT,kH,kW) -> (K, N) [C-contiguous]
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
-
_run_gemm(M_p, K, N, x, w, out)
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-# ── Tile config selection ────────────────────────────────────────────
-# Key: (M, K, N) -> (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_STAGES)
+# ── Static kernel ───────────────────────────────────────────────────
-_tile_configs = {}
-
-
-def _get_config(M, K, N):
- key = (M, K, N)
- if key in _tile_configs:
- return _tile_configs[key]
-
- # Heuristic: pick based on M and K size
- if M >= 2048 and K >= 1024:
- # Shape 2: 2304 x 1176 x 1280
- cfg = (256, 128, 64, 8, 4)
- elif M >= 512 and K >= 1024:
- # Shape 1: 1024 x 1176 x 1280
- cfg = (128, 128, 64, 8, 4)
- elif M >= 512 and N >= 768:
- # Shape 3: 784 x 768 x 768
- cfg = (128, 128, 64, 8, 4)
- elif K >= 1024:
- # Shape 0: 256 x 1176 x 1280 (small M)
- cfg = (128, 128, 64, 1, 3)
- else:
- cfg = (128, 128, 32, 8, 3)
-
- _tile_configs[key] = cfg
- return cfg
-
-
-def _gemm(M, K, N, A, B, C):
- cfg = _get_config(M, K, N)
- BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_STAGES = cfg
- grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
- _gemm_impl[grid](
- A, B, C, M, N, K,
- A.stride(0), A.stride(1),
- B.stride(0), B.stride(1),
- C.stride(0), C.stride(1),
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
- GROUP_M=GROUP_M, NUM_STAGES=NUM_STAGES,
- )
-
-
-@triton.jit(launch_metadata=
-"""
-void launch_metadata_ptr(void* launch_metadata,
- uint32_t grid_tidx, uint32_t grid_tidy, uint32_t grid_tdim,
- void* args, uint32_t num_args, const int num_warps,
- const uint32_t shared, const uint32_t dynamic_shared) {
- auto* m_block = (int32*)args[4]; // M
- auto* n_block = (int32*)args[5]; // N
- auto* bm = (int32*)args[16]; // BLOCK_M
- auto* bn = (int32*)args[17]; // BLOCK_N
- auto* n_warps = (uint32_t*)&num_warps;
- int n_blocks_m = (*m_block + *bm - 1) / *bm;
- int n_blocks_n = (*n_block + *bn - 1) / *bn;
- *n_warps = (n_blocks_m * n_blocks_n > 512) ? 8 : 4;
-}
-""")
@triton.jit
def _gemm_impl(
A, B, C, M, N, K,
@@ -124,7 +54,7 @@
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr, NUM_STAGES: tl.constexpr,
+ GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
pid_m_end = tl.cdiv(M, BLOCK_M)
@@ -141,31 +71,72 @@
offs_n = rn + tl.arange(0, BLOCK_N)
mask_m = offs_m < M
mask_n = offs_n < N
-
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for lk in range(0, tl.cdiv(K, BLOCK_K), 1):
off_k = lk * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = off_k < K
a = tl.load(A + offs_m[:, None] * stride_am + off_k[None, :] * stride_ak,
- mask=mask_m[:, None] & mask_k[None, :], other=0.0,
- num_ctas=1, eviction_policy='normal')
+ mask=mask_m[:, None] & mask_k[None, :], other=0.0)
b = tl.load(B + off_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
- mask=mask_k[:, None] & mask_n[None, :], other=0.0,
- num_ctas=1, eviction_policy='normal')
+ mask=mask_k[:, None] & mask_n[None, :], other=0.0)
acc += tl.dot(a, b)
tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])
-# ── Prewarm all known shapes ─────────────────────────────────────────
+# ── Config selection ────────────────────────────────────────────────
+# (M, K, N) -> (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_WARPS, NUM_STAGES)
+
+CONFIGS = {
+ (256, 1176, 1280): (128, 128, 64, 1, 4, 3), # Shape 0
+ (1024, 1176, 1280): (128, 128, 64, 8, 8, 4), # Shape 1
+ (2304, 1176, 1280): (256, 128, 64, 8, 8, 4), # Shape 2
+ (784, 768, 768): (128, 128, 64, 8, 8, 4), # Shape 3
+}
+
+
+def _get_config(M, K, N):
+ key = (M, K, N)
+ if key in CONFIGS:
+ return CONFIGS[key]
+ # Fallback
+ gm = 8 if M >= 512 else 1
+ bm = 256 if M >= 2048 else 128
+ bn = 128 if N >= 1024 else 64
+ bk = 64 if K >= 512 else 32
+ nw = 8 if M * N >= 2**20 else 4
+ ns = 4 if K >= 768 and M * N >= 2**20 else 3
+ return (bm, bn, bk, gm, nw, ns)
+
+
+def _run_gemm(M, K, N, A, B, C):
+ BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_WARPS, NUM_STAG
... (977 more chars){
"content": "<4926 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}FAIL: import error: 'tuple' object cannot be interpreted as an integer
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -118,13 +118,17 @@
num_warps = NUM_WARPS
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
- _gemm_impl[grid, num_warps, num_stages](
+ # @triton.jit uses [grid](...) syntax; pass num_warps via env or
+ # the we need @triton.autotune to set these properly.
+ # For non-autotune kernels, we call with standard launch syntax:
+ _gemm_impl[grid](
A, B, C, M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
GROUP_M=GROUP_M,
+ num_warps=NUM_WARPS, num_stages=NUM_STAGES,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<262 chars \u2014 see below>",
"newString": "<487 chars \u2014 see below>"
} _gemm_impl[grid, num_warps, num_stages](
A, B, C, M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
GROUP_M=GROUP_M,
) # @triton.jit uses [grid](...) syntax; pass num_warps via env or
# the we need @triton.autotune to set these properly.
# For non-autotune kernels, we call with standard launch syntax:
_gemm_impl[grid](
A, B, C, M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
GROUP_M=GROUP_M,
num_warps=NUM_WARPS, num_stages=NUM_STAGES,
)Edit applied successfully.{
"command": "uv run python -c \"import triton; help(triton.jit)\" 2>&1 | head -20",
"description": "Check Triton JIT signature"
}Help on function jit in module triton.runtime.jit:
jit(fn: 'Optional[T]' = None, *, version=None, repr: 'Optional[Callable]' = None, launch_metadata: 'Optional[Callable]' = None, do_not_specialize: 'Optional[Iterable[int | str]]' = None, do_not_specialize_on_alignment: 'Optional[Iterable[int | str]]' = None, debug: 'Optional[bool]' = None, noinline: 'Optional[bool]' = None) -> 'KernelInterface[T]'
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, arguments are
implicitly converted to pointers if they have a :code:`.data_ptr()` method
and a `.dtype` attribute.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
{
"command": "uv run python check.py 2>&1",
"description": "Test with kwarg syntax",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark static config version",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.575 gbps=19.800 ms=0.216
shape=0 variant=compiled tflops=2.997 gbps=16.597 ms=0.257
shape=0 variant=sota tflops=3.601 gbps=19.944 ms=0.214
shape=0 variant=solution tflops=6.035 gbps=33.423 ms=0.128
shape=0 solution_peak_fraction=0.0302
shape=1 variant=eager tflops=13.672 gbps=35.658 ms=0.225
shape=1 variant=compiled tflops=11.405 gbps=29.746 ms=0.270
shape=1 variant=sota tflops=13.774 gbps=35.926 ms=0.224
shape=1 variant=solution tflops=43.152 gbps=112.548 ms=0.071
shape=1 solution_peak_fraction=0.2158
shape=2 variant=eager tflops=25.361 gbps=52.386 ms=0.274
shape=2 variant=compiled tflops=22.254 gbps=45.967 ms=0.312
shape=2 variant=sota tflops=26.087 gbps=53.887 ms=0.266
shape=2 variant=solution tflops=56.721 gbps=117.164 ms=0.122
shape=2 solution_peak_fraction=0.2836
shape=3 variant=eager tflops=6.271 gbps=24.331 ms=0.147
shape=3 variant=compiled tflops=4.917 gbps=19.076 ms=0.188
shape=3 variant=sota tflops=6.448 gbps=25.015 ms=0.143
shape=3 variant=solution tflops=17.617 gbps=68.350 ms=0.052
shape=3 solution_peak_fraction=0.0881
peak_fraction: 0.1129
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,4 +1,7 @@
-"""Triton-based 3D patch-embedding kernel (optimized for SM120)."""
+"""Triton-based 3D patch-embedding kernel.
+
+Im2col via reshape/permute followed by an autotuned Triton GEMM.
+"""
import torch
import torch.nn as nn
import triton
@@ -39,17 +42,66 @@
w = self.conv.weight.reshape(N, K).t().contiguous()
out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
- _run_gemm(M_p, K, N, x, w, out)
+
+ grid = lambda META: (
+ triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
+ )
+ _gemm_kernel[grid](
+ x, w, out,
+ M_p, N, K,
+ x.stride(0), x.stride(1),
+ w.stride(0), w.stride(1),
+ out.stride(0), out.stride(1),
+ )
out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
return out
-# ── Static kernel ───────────────────────────────────────────────────
+gemm_configs = [
+ # ====== 4 stages, 8 warps (deep pipeline for high-occupancy) ======
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ # ====== 3 stages, 8 warps ======
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ # ====== GROUP_M=1 variants (good for small M) ======
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=8),
+ # ====== 4 warps variants ======
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+]
+
+@triton.autotune(configs=gemm_configs, key=["M", "N", "K"])
@triton.jit
-def _gemm_impl(
- A, B, C, M, N, K,
+def _gemm_kernel(
+ A, B, C,
+ M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
@@ -71,6 +123,7 @@
offs_n = rn + tl.arange(0, BLOCK_N)
mask_m = offs_m < M
mask_n = offs_n < N
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for lk in range(0, tl.cdiv(K, BLOCK_K), 1):
@@ -86,61 +139,21 @@
acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_n[None, :])
-# ── Config selection ────────────────────────────────────────────────
-# (M, K, N) -> (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_WARPS, NUM_STAGES)
-
-CONFIGS = {
- (256, 1176, 1280): (128, 128, 64, 1, 4, 3), # Shape 0
- (1024, 1176, 1280): (128, 128, 64, 8, 8, 4), # Shape 1
- (2304, 1176, 1280): (256, 128, 64, 8, 8, 4), # Shape 2
- (784, 768, 768): (128, 128, 64, 8, 8, 4), # Shape 3
-}
-
-
-def _get_config(M, K, N):
- key = (M, K, N)
- if key in CONFIGS:
- return CONFIGS[key]
- # Fallback
- gm = 8 if M >= 512 else 1
- bm = 256 if M >= 2048 else 128
- bn = 128 if N >= 1024 else 6
... (1812 more chars){
"content": "<7891 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Verify and benchmark autotuned+prewarm",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.558 gbps=19.705 ms=0.217
shape=0 variant=compiled tflops=3.070 gbps=17.001 ms=0.251
shape=0 variant=sota tflops=3.632 gbps=20.111 ms=0.212
shape=0 variant=solution tflops=12.294 gbps=68.084 ms=0.063
shape=0 solution_peak_fraction=0.0615
shape=1 variant=eager tflops=13.663 gbps=35.635 ms=0.226
shape=1 variant=compiled tflops=11.636 gbps=30.350 ms=0.265
shape=1 variant=sota tflops=13.909 gbps=36.276 ms=0.222
shape=1 variant=solution tflops=47.977 gbps=125.131 ms=0.064
shape=1 solution_peak_fraction=0.2399
shape=2 variant=eager tflops=26.149 gbps=54.013 ms=0.265
shape=2 variant=compiled tflops=22.804 gbps=47.104 ms=0.304
shape=2 variant=sota tflops=26.172 gbps=54.062 ms=0.265
shape=2 variant=solution tflops=91.537 gbps=189.081 ms=0.076
shape=2 solution_peak_fraction=0.4577
shape=3 variant=eager tflops=6.299 gbps=24.439 ms=0.147
shape=3 variant=compiled tflops=4.911 gbps=19.053 ms=0.188
shape=3 variant=sota tflops=6.454 gbps=25.040 ms=0.143
shape=3 variant=solution tflops=17.112 gbps=66.387 ms=0.054
shape=3 solution_peak_fraction=0.0856
peak_fraction: 0.1550
RESULT: OK
{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/.claude/skills/debug-gpu/SKILL.md"
}{
"filePath": "/home/infatoshi/.claude/skills/kernel/SKILL.md"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,6 +1,8 @@
"""Triton-based 3D patch-embedding kernel.
-Im2col via reshape/permute followed by an autotuned Triton GEMM.
+Fused kernel: gathers patches from (B,C,T,H,W) input directly into the
+GEMM K-loop, eliminating the im2col temporary buffer and potentially
+overlapping gather + MMA for better utilization.
"""
import torch
import torch.nn as nn
@@ -34,82 +36,75 @@
num_patches = (T // kT) * (H // kH) * (W // kW)
K = C * kT * kH * kW
M_p = B * num_patches
-
- x = x.reshape(B, C, T // kT, kT, H // kH, kH, W // kW, kW)
- x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
- x = x.reshape(M_p, K)
-
+ strided_input = K < 128
+
+ # Compute 3D grid indices (pre-pass for per-batch)
+ PT = T // kT
+ PH = H // kH
+ PW = W // kW
+
+ out = torch.empty(B, N, PT, PH, PW, dtype=torch.bfloat16, device=x.device)
+
+ # Weight: (K, N)
w = self.conv.weight.reshape(N, K).t().contiguous()
- out = torch.empty(M_p, N, dtype=torch.bfloat16, device=x.device)
-
+ # Launch fused kernel
grid = lambda META: (
triton.cdiv(M_p, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
)
- _gemm_kernel[grid](
+ _fused_kernel[grid](
x, w, out,
M_p, N, K,
- x.stride(0), x.stride(1),
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
w.stride(0), w.stride(1),
- out.stride(0), out.stride(1),
- )
-
- out = out.reshape(B, T // kT, H // kH, W // kW, N).permute(0, 4, 1, 2, 3)
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
+ B, C, kT, kH, kW,
+ PT, PH, PW,
+ )
+
return out
-gemm_configs = [
- # ====== 4 stages, 8 warps (deep pipeline for high-occupancy) ======
+fused_configs = [
+ # Small M (shape 0: 256 patches)
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK-K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ # Medium M (shape 1: 1024, shape 3: 784)
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- # ====== 3 stages, 8 warps ======
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- # ====== GROUP_M=1 variants (good for small M) ======
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M':
... (5839 more chars){
"content": "<8039 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify fused kernel correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}FAIL: import error: dynamic_func() missing 1 required positional argument: 'BLOCK_K'
{
"command": "<205 chars \u2014 see below>",
"description": "Check Triton Config structure"
}uv run python -c "
import triton, triton.language as tl
c = triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=4)
print(c)
print(type(c))
print(c.kwargs)
"BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, GROUP_M: 8, num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
<class 'triton.runtime.autotuner.Config'>
{'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}
{
"command": "<927 chars \u2014 see below>",
"description": "Test Triton autotune calling convention",
"timeout": 60000
}uv run python -c "
import triton, triton.language as tl
@triton.autotune(configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
], key=['M', 'N', 'K'])
@triton.jit
def test_kernel(A, B, C, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr):
pass
import torch
A = torch.empty(256, 1176, dtype=torch.bfloat16, device='cuda:0')
B = torch.empty(1176, 1280, dtype=torch.bfloat16, device='cuda:0')
C = torch.empty(256, 1280, dtype=torch.bfloat16, device='cuda:0')
grid = lambda META: (triton.cdiv(256, META['BLOCK_M']) * triton.cdiv(1280, META['BLOCK_N']),)
test_kernel[grid](A, B, C, 256, 1280, 1176, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1))
print('OK')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 7, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -70,7 +70,7 @@
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK-K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 1}, num_stages=3, num_warps=4),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": " triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK-K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),",
"newString": " triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 1}, num_stages=3, num_warps=4),"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Test fused kernel with typo fix",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark fused kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.617 gbps=20.028 ms=0.213
shape=0 variant=compiled tflops=3.066 gbps=16.981 ms=0.251
shape=0 variant=sota tflops=3.664 gbps=20.291 ms=0.210
shape=0 variant=solution tflops=8.541 gbps=47.296 ms=0.090
shape=0 solution_peak_fraction=0.0427
shape=1 variant=eager tflops=13.689 gbps=35.704 ms=0.225
shape=1 variant=compiled tflops=11.575 gbps=30.189 ms=0.266
shape=1 variant=sota tflops=13.913 gbps=36.286 ms=0.222
shape=1 variant=solution tflops=30.794 gbps=80.315 ms=0.100
shape=1 solution_peak_fraction=0.1540
shape=2 variant=eager tflops=25.181 gbps=52.015 ms=0.275
shape=2 variant=compiled tflops=22.619 gbps=46.723 ms=0.307
shape=2 variant=sota tflops=26.475 gbps=54.686 ms=0.262
shape=2 variant=solution tflops=65.309 gbps=134.903 ms=0.106
shape=2 solution_peak_fraction=0.3265
shape=3 variant=eager tflops=6.345 gbps=24.616 ms=0.146
shape=3 variant=compiled tflops=4.941 gbps=19.169 ms=0.187
shape=3 variant=sota tflops=6.448 gbps=25.017 ms=0.143
shape=3 variant=solution tflops=17.818 gbps=69.129 ms=0.052
shape=3 solution_peak_fraction=0.0891
peak_fraction: 0.1176
RESULT: OK