"""Fused im2col + GEMM patch-embedding kernel (Triton).
Key optimization: BLOCK_K = PADDED_KW * UNROLL so that:
- Each k-tile of PADDED_KW elements maps to one (c, dt, dh) row
- x reads within each k-tile are contiguous along dw (coalesced)
- UNROLL reduces loop iterations while maintaining coalescing
Weight is pre-transposed to (K, N) for contiguous GEMM loads.
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Problem metadata; the same three values are declared in reference.py
# (OP_TYPE / SUPPORTED_PRECISIONS / HARDWARE_REQUIRED).
OP_TYPE: str = "patch_embed"
SUPPORTED_PRECISIONS: list[str] = ["bf16"]
HARDWARE_REQUIRED: list[str] = ["RTX_PRO_6000"]
def _next_pow2(n: int) -> int:
p = 1
while p < n:
p <<= 1
return p
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=3),
    ],
    key=["M", "N", "NUM_OUTER"],
)
@triton.jit
def _patch_embed_kernel(
    x_ptr, w_ptr, out_ptr,
    M, N: tl.constexpr, K: tl.constexpr,
    NUM_OUTER: tl.constexpr,
    PADDED_KW: tl.constexpr,
    UNROLL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
    stride_wk,
    stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
    B: tl.constexpr, C: tl.constexpr,
    T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
    kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
    To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Fused im2col + GEMM: out[b, n, to, ho, wo] = sum_k patch[m, k] * w[k, n].

    Rows m enumerate (b, to, ho, wo) patches; k enumerates (c, dt, dh, dw)
    in the order c * kT*kH*kW + dt * kH*kW + dh * kW + dw. ``w_ptr`` points
    at a (K, N) matrix with row stride ``stride_wk`` and unit column stride.

    Each K tile holds UNROLL consecutive (c, dt, dh) rows. Each row is padded
    from kW to PADDED_KW lanes along dw, so BLOCK_K == PADDED_KW * UNROLL.
    Padding lanes (dw >= kW) are masked to zero. So are rows past
    C * kT * kH, which occur when UNROLL does not divide it.

    All x addressing goes through the ``stride_x_*`` arguments, so
    non-contiguous inputs are gathered correctly. The fp32 accumulator is
    cast to the output element type on store.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    kHkW = kH * kW
    kTkHkW = kT * kHkW
    HoWo = Ho * Wo
    num_patches = To * HoWo
    # Number of (c, dt, dh) rows; each contributes kW elements of K.
    num_rows = C * kT * kH

    # Decompose each output row m into (b, to, ho, wo).
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < M
    b_idx = offs_m // num_patches
    rem = offs_m % num_patches
    to_idx = rem // HoWo
    rem2 = rem % HoWo
    ho_idx = rem2 // Wo
    wo_idx = rem2 % Wo

    # Address of each patch's (c=0, dt=0, dh=0, dw=0) element.
    base_x = (b_idx * stride_x_b
              + to_idx * (kT * stride_x_t)
              + ho_idx * (kH * stride_x_h)
              + wo_idx * (kW * stride_x_w))

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offs_n < N

    # Split the K tile into (row-within-tile u, column dw) coordinates.
    offs_k = tl.arange(0, BLOCK_K)
    u = offs_k // PADDED_KW
    dw = offs_k % PADDED_KW
    mask_dw = dw < kW

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for outer in range(NUM_OUTER):
        row_idx = outer * UNROLL + u
        # Guard the final tile when UNROLL does not divide num_rows.
        # Without this guard, c would run past C and read out of bounds.
        mask_k = mask_dw & (row_idx < num_rows)
        c = row_idx // (kT * kH)
        rem_r = row_idx % (kT * kH)
        dt = rem_r // kH
        dh = rem_r % kH

        x_addr = (base_x[:, None]
                  + c[None, :] * stride_x_c
                  + dt[None, :] * stride_x_t
                  + dh[None, :] * stride_x_h
                  + dw[None, :] * stride_x_w)
        x_vals = tl.load(x_ptr + x_addr,
                         mask=mask_m[:, None] & mask_k[None, :],
                         other=0.0)

        # Flattened K index matching weight.reshape(N, K).
        k_idx = c * kTkHkW + dt * kHkW + dh * kW + dw
        w_addr = k_idx[:, None] * stride_wk + offs_n[None, :]
        w_vals = tl.load(w_ptr + w_addr,
                         mask=mask_k[:, None] & mask_n[None, :],
                         other=0.0)
        acc += tl.dot(x_vals, w_vals)

    # Scatter straight into the (B, N, To, Ho, Wo) output layout.
    out_addr = (b_idx[:, None] * stride_out_b
                + offs_n[None, :] * stride_out_n
                + to_idx[:, None] * stride_out_to
                + ho_idx[:, None] * stride_out_ho
                + wo_idx[:, None] * stride_out_wo)
    tl.store(out_ptr + out_addr, acc.to(out_ptr.dtype.element_ty),
             mask=mask_m[:, None] & mask_n[None, :])
class Model(nn.Module):
    """Non-overlapping 3D patch embedding (Conv3d, stride == kernel, no bias).

    The projection runs through the fused Triton kernel
    ``_patch_embed_kernel``. The parameter is registered as ``conv.weight``
    with shape ``(embed_dim, C, kT, kH, kW)`` in bf16. ``forward`` returns
    a tensor of shape ``(B, embed_dim, T // kT, H // kH, W // kW)``.
    """

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        super().__init__()
        assert T % kT == 0 and H % kH == 0 and W % kW == 0
        self.B = B
        self.C = C
        self.T = T
        self.H = H
        self.W = W
        self.kT = kT
        self.kH = kH
        self.kW = kW
        self.embed_dim = embed_dim
        # Bare container so the parameter is registered as ``conv.weight``.
        # Presumably this lines up with the reference's state-dict keys;
        # confirm.
        self.conv = nn.Module()
        self.conv.weight = nn.Parameter(
            torch.empty(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16)
        )
        nn.init.normal_(self.conv.weight, std=0.02)
        # Cached (K, N) transposed weight plus the fingerprint it was built
        # from. See _transposed_weight.
        self._w_t = None
        self._w_t_key = None

    def _transposed_weight(self) -> torch.Tensor:
        """Return ``conv.weight`` as a contiguous (K, N) matrix.

        The copy is cached and rebuilt whenever any of the following change:
        the parameter's storage pointer, its version counter (bumped by
        in-place writes such as ``load_state_dict`` or ``nn.init``), its
        device, dtype or shape. This keeps the kernel from reading stale
        weights.
        """
        w = self.conv.weight
        key = (w.data_ptr(), w._version, w.device, w.dtype, tuple(w.shape))
        if self._w_t is None or self._w_t_key != key:
            # detach(): the cache is a kernel input, not part of autograd.
            self._w_t = w.detach().reshape(w.shape[0], -1).t().contiguous()
            self._w_t_key = key
        return self._w_t

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape (B, C, T, H, W) into (B, embed_dim, To, Ho, Wo).

        Spatial/temporal extents are divided with floor division, so trailing
        elements that do not fill a whole patch are ignored.

        Raises:
            ValueError: if the channel count of ``x`` differs from ``self.C``.
        """
        B, C, T, H, W = x.shape
        if C != self.C:
            raise ValueError(f"expected {self.C} input channels, got {C}")
        kT, kH, kW = self.kT, self.kH, self.kW
        N = self.embed_dim
        To, Ho, Wo = T // kT, H // kH, W // kW
        K = C * kT * kH * kW
        M = B * To * Ho * Wo

        out = torch.empty((B, N, To, Ho, Wo), dtype=x.dtype, device=x.device)
        if M == 0 or N == 0:
            # Nothing to compute; avoid launching a zero-sized grid.
            return out

        # Guarantee a dense NCTHW layout for the gather (no-op if already so).
        x = x.contiguous()

        # tl.dot needs a K tile of at least 16, so dw is padded up to >= 16.
        PADDED_KW = max(_next_pow2(kW), 16)
        UNROLL = 4
        BLOCK_K = PADDED_KW * UNROLL
        NUM_OUTER = triton.cdiv(C * kT * kH, UNROLL)
        w_t = self._transposed_weight()

        def grid(META):
            return (triton.cdiv(M, META["BLOCK_M"]),
                    triton.cdiv(N, META["BLOCK_N"]))

        _patch_embed_kernel[grid](
            x, w_t, out,
            M, N, K,
            NUM_OUTER, PADDED_KW, UNROLL, BLOCK_K,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
            w_t.stride(0),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
            B, C, T, H, W,
            kT, kH, kW,
            To, Ho, Wo,
        )
        return out
shape=0 variant=eager tflops=3.600 gbps=19.937 ms=0.214
shape=0 variant=compiled tflops=3.043 gbps=16.850 ms=0.253
shape=0 variant=sota tflops=3.630 gbps=20.102 ms=0.212
shape=0 variant=solution tflops=16.202 gbps=89.725 ms=0.048
shape=0 solution_peak_fraction=0.0810
shape=1 variant=eager tflops=13.556 gbps=35.357 ms=0.227
shape=1 variant=compiled tflops=11.294 gbps=29.457 ms=0.273
shape=1 variant=sota tflops=13.749 gbps=35.859 ms=0.224
shape=1 variant=solution tflops=54.199 gbps=141.358 ms=0.057
shape=1 solution_peak_fraction=0.2710
shape=2 variant=eager tflops=25.507 gbps=52.688 ms=0.272
shape=2 variant=compiled tflops=22.177 gbps=45.810 ms=0.313
shape=2 variant=sota tflops=26.109 gbps=53.932 ms=0.266
shape=2 variant=solution tflops=111.876 gbps=231.094 ms=0.062
shape=2 solution_peak_fraction=0.5594
shape=3 variant=eager tflops=6.308 gbps=24.474 ms=0.147
shape=3 variant=compiled tflops=4.866 gbps=18.877 ms=0.190
shape=3 variant=sota tflops=6.443 gbps=24.995 ms=0.144
shape=3 variant=solution tflops=21.228 gbps=82.356 ms=0.044
shape=3 solution_peak_fraction=0.1061
peak_fraction: 0.1900
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "openrouter-pinned/xiaomi/mimo-v2.5-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1900,
"elapsed_seconds": 2261,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 202583, "output_tokens": 28502, "cache_read_tokens": 6538048, "cache_creation_tokens": 0, "reasoning_tokens": 112435, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, benchmark.py, check.py, problem.yaml to understand the task",
"status": "in_progress",
"priority": "high"
},
{
"content": "Read CLAUDE.md and any other project context files",
"status": "pending",
"priority": "high"
},
{
"content": "Check GPU availability and CUDA environment",
"status": "pending",
"priority": "high"
},
{
"content": "Implement initial CUDA kernel for 3D patch embedding",
"status": "pending",
"priority": "high"
},
{
"c
... (350 more chars)content: 'Read reference.py, benchmark.py, check.py, problem.yaml to understand the task'
status: 'in_progress'
priority: 'high'
content: 'Read CLAUDE.md and any other project context files'
status: 'pending'
priority: 'high'
content: 'Check GPU availability and CUDA environment'
status: 'pending'
priority: 'high'
content: 'Implement initial CUDA kernel for 3D patch embedding'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Benchmark against reference'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md"
}<path>/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md</path>
<type>file</type>
<content>
1: # KernelBench-Hard — Developer Instructions
2:
3: Last updated: 2026-04-27.
4:
5: This file is for **coding agents editing the repo** (you, via Claude Code). Do not confuse with `problems/<X>/PROMPT.txt` — those are the human-voice queries fed to agents _under test_.
6:
7: For the journey behind the current design, read [DEVLOG.md](./DEVLOG.md).
8:
9: ## What this repo is
10:
11: Small kernel benchmark. Frontier coding agents are given URLs to SOTA implementations (sonic-moe, flashinfer, marlin) and asked to write a competitive kernel on RTX PRO 6000 Blackwell (SM120) in 45 minutes. Roofline-graded. Published artifact is the best kernel per (problem × model × harness), plus the agent trace.
12:
13: See [SPEC.md](./SPEC.md) for methodology. See [README.md](./README.md) for the model matrix and quick start.
14:
15: ## Non-negotiable rules
16:
17: - **uv only.** No bare `python`, no `pip`. Use `uv run ...`, `uv add ...`, `uv pip install ...`.
18: - **Before committing:** `uv run ruff check . --fix && uv run pytest`.
19: - **Never edit `problems/*/solution.py`**. Those files are agent output; they're gitignored for a reason. If you need to inspect one, read it from `outputs/runs/<run>/<problem>/solution.py`.
20: - **Never modify `problems/*/reference.py`, `check.py`, `benchmark.py`, `problem.yaml`, `shapes.py`, or `PROMPT.txt`** once a sweep has been published. Those define the ben
... (5941 more chars){
"command": "nvidia-smi 2>&1 | head -20",
"description": "Check GPU availability"
}Thu May 7 16:17:02 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 On | Off |
| 30% 41C P0 95W / 600W | 385MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (189 more chars){
"command": "which triton 2>/dev/null; python3 -c \"import triton; print(triton.__version__)\" 2>&1; uv run python -c \"import torch; print(torch.__version__); print(torch.cuda.get_device_capability())\" 2>&1",
"description": "Check Triton and PyTorch availability"
}triton not found
3.6.0
2.11.0+cu130
(12, 0)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"command": "uv run python -c \"import sys; sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard'); from src.hardware import get; hw = get('RTX_PRO_6000'); print('peak_tflops:', hw.peak_tflops_dense); print('peak_bw:', hw.peak_bandwidth_gb_s)\" 2>&1",
"description": "Get GPU peak specs"
}peak_tflops: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
peak_bw: 1800.0
{
"command": "uv run python -c \"import torch; print('bf16 support:', torch.cuda.is_bf16_supported()); print('device:', torch.cuda.get_device_name())\" 2>&1",
"description": "Check bf16 support"
}bf16 support: True
device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"command": "<770 chars \u2014 see below>",
"description": "Calculate GEMM dimensions for each shape"
}uv run python -c "
shapes = [
{'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 2, 'C': 3, 'T': 4, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 1, 'C': 3, 'T': 8, 'H': 336, 'W': 336, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280},
{'B': 4, 'C': 3, 'T': 1, 'H': 224, 'W': 224, 'kT': 1, 'kH': 16, 'kW': 16, 'embed_dim': 768},
]
for i, s in enumerate(shapes):
B,C,T,H,W,kT,kH,kW,E = s['B'],s['C'],s['T'],s['H'],s['W'],s['kT'],s['kH'],s['kW'],s['embed_dim']
np = (T//kT)*(H//kH)*(W//kW)
M = B * np
K = C * kT * kH * kW
N = E
flops = 2 * M * K * N
print(f'Shape {i}: M={M}, K={K}, N={N}, FLOPS={flops/1e9:.2f}G, TFLOPS_at_1ms={flops/1e9:.2f}')
" 2>&1Shape 0: M=256, K=1176, N=1280, FLOPS=0.77G, TFLOPS_at_1ms=0.77
Shape 1: M=1024, K=1176, N=1280, FLOPS=3.08G, TFLOPS_at_1ms=3.08
Shape 2: M=2304, K=1176, N=1280, FLOPS=6.94G, TFLOPS_at_1ms=6.94
Shape 3: M=784, K=768, N=768, FLOPS=0.92G, TFLOPS_at_1ms=0.92
{
"command": "uv run python -c \"import triton\" 2>&1; uv pip list 2>&1 | grep -i triton",
"description": "Check if Triton is available"
}triton 3.6.0
{
"todos": [
{
"content": "Read reference.py, benchmark.py, check.py, problem.yaml to understand the task",
"priority": "high",
"status": "completed"
},
{
"content": "Read CLAUDE.md and any other project context files",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU availability and CUDA environment",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)",
"priority": "high",
"status": "in_p
... (380 more chars)content: 'Read reference.py, benchmark.py, check.py, problem.yaml to understand the task'
status: 'completed'
priority: 'high'
content: 'Read CLAUDE.md and any other project context files'
status: 'completed'
priority: 'high'
content: 'Check GPU availability and CUDA environment'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Benchmark against reference'
status: 'pending'
priority: 'medium'{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"command": "ls /home/infatoshi/.claude/skills/perf-optim/",
"description": "List perf-optim skill files"
}SKILL.md
{
"command": "wc -l /home/infatoshi/.claude/skills/perf-optim/SKILL.md",
"description": "Check perf-optim file size"
}1002 /home/infatoshi/.claude/skills/perf-optim/SKILL.md
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,199 @@
+"""Fused im2col + GEMM patch-embedding kernel (Triton).
+
+Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
+patch embedding) using a fused Triton kernel that gathers patches on-the-fly
+in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "patch_embed"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel: fused im2col + GEMM
+# ---------------------------------------------------------------------------
+# output[m, n] = sum_k x_im2col[m, k] * weight[n, k]
+#
+# where:
+# m indexes (batch, patch_position) -> M = B * To * Ho * Wo
+# n indexes embed_dim -> N = embed_dim
+# k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
+# ---------------------------------------------------------------------------
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _patch_embed_kernel(
+ # Pointers
+ x_ptr, w_ptr, out_ptr,
+ # GEMM dims (N and K are constexpr so the compiler can optimise)
+ M, N: tl.constexpr, K: tl.constexpr,
+ # Tensor strides (contiguous NCDHW for x, (N,K) row-major for w, NCDHW-like for out)
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_w_n, # weight is (N, K) row-major -> stride_w_n = K
+ stride_out_m, # output is (M, N) row-major -> stride_out_m = N
+ # Problem constants (constexpr for fast integer math)
+ B: tl.constexpr, C: tl.constexpr,
+ T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
+ kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
+ To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
+ # Tile sizes
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ # -- programme ids -------------------------------------------------------
+ pid_m = tl.program_id(0) # tile along M (patches)
+ pid_n = tl.program_id(1) # tile along N (embed_dim)
+
+ num_patches = To * Ho * Wo
+ kHkW = kH * kW
+ kTkHkW = kT * kHkW
+ HW = H * W
+ THW = T * HW
+ CTHW = C * THW
+
+ # -- M tile: decode into (b, to, ho, wo) ---------------------------------
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < M
+
+ b_idx = offs_m // num_patches
+ rem = offs_m % num_patches
+ to_idx = rem // (Ho * Wo)
+ rem2 = rem % (Ho * Wo)
+ ho_idx = rem2 // Wo
+ wo_idx = rem2 % Wo
+
+ # base x-offset per row (the c/dt/dh/dw part is added in the K loop)
+ base_x = (b_idx * CTHW
+ + to_idx * kT * HW
+ + ho_idx * kH * kW # wait, should be ho_idx * kH * W
+ + wo_idx * kW)
+
+ # -- N tile --------------------------------------------------------------
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < N
+
+ # -- accumulators --------------------------------------------------------
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # -- K loop --------------------------------------------------------------
+ for k_start in range(0, K, BLOCK_K):
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ mask_k = offs_k < K
+
+ # decode k -> (c, dt, dh, dw)
+ c_idx = offs_k // kTkHkW
+ rem_k = offs_k % kTkHkW
+ dt_idx = rem_k // kHkW
+ rem_k2 = rem_k % kHkW
+ dh_idx = rem_k2 // kW
+ dw_idx = rem_k2 % kW
+
+ # x addresses (BLOCK_M, BLOCK_K)
+ x_addr = (base_x[:, None]
+ + c_idx[None, :] * THW
+ + dt_idx[None, :] * HW
+ + dh_idx[None, :] * W
+ + dw_idx[None, :])
+
+ x_vals = tl.load(x_ptr + x_addr,
+ mask=mask_m[:, None] & mask_k[None, :],
+ other=0.0)
+
+ # w addresses (BLOCK_K, BLOCK_N) -- w is (N, K) row-major
+ w_addr = offs_n[None, :] * K + offs_k[:, None]
+ w_vals = tl.load(w_ptr + w_addr,
+ mask=mask_n[None, :] & mask_k[:, None],
+ other=0.0)
+
+ # tensor-core dot (BLOCK_M, K_tile) @ (K_tile, BLOCK_N)
+ acc += tl.dot(x_vals, w_vals)
+
+ # -- store output --------------------------------------------------------
+ out_addr = offs_m[:, None] * N + offs_n[None, :]
+ tl.store(out_ptr + out_addr, acc,
+ mask=mask_m[:, None] & mask_n[None, :])
+
+
+# ---------------------------------------------------------------------------
+# Wrapper model
+# ---------------------------------------------------------------------------
+
+class Model(nn.Module):
+ """Patch-embedding via fused im2col+GEMM Triton kernel.
+
+ Declares ``conv.weight`` of shape ``(embed_dim, C, kT, kH, kW)`` so that
+ ``state_dict`` from a reference ``nn.Conv3d`` loads with ``strict=True``.
+ """
+
+ def __init__(self, B: int, C: int, T: int, H: int,
... (2065 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<7742 chars \u2014 see diff>"
}Wrote file successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Fused im2col + GEMM patch-embedding kernel (Triton).
2:
3: Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
4: patch embedding) using a fused Triton kernel that gathers patches on-the-fly
5: in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
6: """
7: import torch
8: import torch.nn as nn
9: import triton
10: import triton.language as tl
11:
12: OP_TYPE = "patch_embed"
13: SUPPORTED_PRECISIONS = ["bf16"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
15:
16:
17: # ---------------------------------------------------------------------------
18: # Triton kernel: fused im2col + GEMM
19: # ---------------------------------------------------------------------------
20: # output[m, n] = sum_k x_im2col[m, k] * weight[n, k]
21: #
22: # where:
23: # m indexes (batch, patch_position) -> M = B * To * Ho * Wo
24: # n indexes embed_dim -> N = embed_dim
25: # k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
26: # ---------------------------------------------------------------------------
27:
28: @triton.autotune(
29: configs=[
30: triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
31: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
32: triton.Config({"BLOCK_M": 128,
... (7299 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -81,7 +81,7 @@
# base x-offset per row (the c/dt/dh/dw part is added in the K loop)
base_x = (b_idx * CTHW
+ to_idx * kT * HW
- + ho_idx * kH * kW # wait, should be ho_idx * kH * W
+ + ho_idx * kH * W
+ wo_idx * kW)
# -- N tile --------------------------------------------------------------
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<229 chars \u2014 see below>",
"newString": "<193 chars \u2014 see below>"
} # base x-offset per row (the c/dt/dh/dw part is added in the K loop)
base_x = (b_idx * CTHW
+ to_idx * kT * HW
+ ho_idx * kH * kW # wait, should be ho_idx * kH * W
+ wo_idx * kW) # base x-offset per row (the c/dt/dh/dw part is added in the K loop)
base_x = (b_idx * CTHW
+ to_idx * kT * HW
+ ho_idx * kH * W
+ wo_idx * kW)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"timeout": 300000
}PASS
{
"todos": [
{
"content": "Read reference.py, benchmark.py, check.py, problem.yaml to understand the task",
"priority": "high",
"status": "completed"
},
{
"content": "Read CLAUDE.md and any other project context files",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU availability and CUDA environment",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)",
"priority": "high",
"status": "comp
... (382 more chars)content: 'Read reference.py, benchmark.py, check.py, problem.yaml to understand the task'
status: 'completed'
priority: 'high'
content: 'Read CLAUDE.md and any other project context files'
status: 'completed'
priority: 'high'
content: 'Check GPU availability and CUDA environment'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'in_progress'
priority: 'high'
content: 'Benchmark against reference'
status: 'pending'
priority: 'medium'{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.617 gbps=20.031 ms=0.213
shape=0 variant=compiled tflops=3.038 gbps=16.822 ms=0.254
shape=0 variant=sota tflops=3.659 gbps=20.265 ms=0.211
shape=0 variant=solution tflops=10.503 gbps=58.167 ms=0.073
shape=0 solution_peak_fraction=0.0525
shape=1 variant=eager tflops=12.704 gbps=33.133 ms=0.243
shape=1 variant=compiled tflops=11.363 gbps=29.637 ms=0.271
shape=1 variant=sota tflops=13.799 gbps=35.990 ms=0.223
shape=1 variant=solution tflops=37.010 gbps=96.529 ms=0.083
shape=1 solution_peak_fraction=0.1851
shape=2 variant=eager tflops=25.851 gbps=53.398 ms=0.268
shape=2 variant=compiled tflops=22.206 gbps=45.868 ms=0.312
shape=2 variant=sota tflops=25.557 gbps=52.791 ms=0.271
shape=2 variant=solution tflops=63.473 gbps=131.111 ms=0.109
shape=2 solution_peak_fraction=0.3174
shape=3 variant=eager tflops=6.284 gbps=24.381 ms=0.147
shape=3 variant=compiled tflops=4.870 gbps=18.896 ms=0.190
shape=3 variant=sota tflops=6.337 gbps=24.584 ms=0.146
shape=3 variant=solution tflops=18.562 gbps=72.015 ms=0.050
shape=3 solution_peak_fraction=0.0928
peak_fraction: 0.1301
RESULT: OK
{
"command": "<1084 chars \u2014 see below>",
"description": "Profile solution with torch.profiler",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"timeout": 120000
}uv run python -c "
import torch
# Quick check: what autotuning configs get selected, and what are the bottleneck
# Let's profile the solution with torch.profiler
import solution, reference, shapes
device = torch.device('cuda:0')
# Shape 2 is the largest and most interesting
shape = shapes.SHAPES[2]
for k,v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
ref_model = reference.Model(*init_args).to(device).eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(5):
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
# Profile
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
) as prof:
with torch.no_grad():
sol_model(*inputs)
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=20))
" 2>&1/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
_patch_embed_kernel 0.00% 0.000us 0.00% 0.000us 0.000us 40.127us 65.07% 40.127us 40.127us 1
void at::native::elementwise_kernel<128, 4, at::nati... 0.00% 0.000us 0.00% 0.000us 0.000us 21.536us 34.93% 21.536us 21.536us 1
Activity Buffer Request 94.33% 1.056ms 94.33% 1.056ms 1.056ms 0.000us 0.00% 0.000us 0.000us 1
... (834 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Fused im2col + GEMM patch-embedding kernel (Triton).
2:
3: Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
4: patch embedding) using a fused Triton kernel that gathers patches on-the-fly
5: in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
6: """
7: import torch
8: import torch.nn as nn
9: import triton
10: import triton.language as tl
11:
12: OP_TYPE = "patch_embed"
13: SUPPORTED_PRECISIONS = ["bf16"]
14: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
15:
16:
17: # ---------------------------------------------------------------------------
18: # Triton kernel: fused im2col + GEMM
19: # ---------------------------------------------------------------------------
20: # output[m, n] = sum_k x_im2col[m, k] * weight[n, k]
21: #
22: # where:
23: # m indexes (batch, patch_position) -> M = B * To * Ho * Wo
24: # n indexes embed_dim -> N = embed_dim
25: # k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
26: # ---------------------------------------------------------------------------
27:
28: @triton.autotune(
29: configs=[
30: triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
31: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
32: triton.Config({"BLOCK_M": 128,
... (7263 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -3,6 +3,9 @@
Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
patch embedding) using a fused Triton kernel that gathers patches on-the-fly
in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
+
+Output is written directly in (B, embed_dim, To, Ho, Wo) layout so no
+reshape/permute/contiguous is needed after the kernel.
"""
import torch
import torch.nn as nn
@@ -17,12 +20,12 @@
# ---------------------------------------------------------------------------
# Triton kernel: fused im2col + GEMM
# ---------------------------------------------------------------------------
-# output[m, n] = sum_k x_im2col[m, k] * weight[n, k]
+# output[b, n, to, ho, wo] = sum_k x[b, c, to*kT+dt, ho*kH+dh, wo*kW+dw]
+# * weight[n, c, dt, dh, dw]
#
-# where:
-# m indexes (batch, patch_position) -> M = B * To * Ho * Wo
-# n indexes embed_dim -> N = embed_dim
-# k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
+# m indexes (batch, patch_position) -> M = B * To * Ho * Wo
+# n indexes embed_dim -> N = embed_dim
+# k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
# ---------------------------------------------------------------------------
@triton.autotune(
@@ -35,6 +38,8 @@
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
],
key=["M", "N", "K"],
)
@@ -42,13 +47,13 @@
def _patch_embed_kernel(
# Pointers
x_ptr, w_ptr, out_ptr,
- # GEMM dims (N and K are constexpr so the compiler can optimise)
+ # GEMM dims
M, N: tl.constexpr, K: tl.constexpr,
- # Tensor strides (contiguous NCDHW for x, (N,K) row-major for w, NCDHW-like for out)
+ # Tensor strides
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_w_n, # weight is (N, K) row-major -> stride_w_n = K
- stride_out_m, # output is (M, N) row-major -> stride_out_m = N
- # Problem constants (constexpr for fast integer math)
+ stride_w_n,
+ stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
+ # Problem constants
B: tl.constexpr, C: tl.constexpr,
T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
@@ -56,9 +61,8 @@
# Tile sizes
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
- # -- programme ids -------------------------------------------------------
- pid_m = tl.program_id(0) # tile along M (patches)
- pid_n = tl.program_id(1) # tile along N (embed_dim)
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
num_patches = To * Ho * Wo
kHkW = kH * kW
@@ -78,7 +82,6 @@
ho_idx = rem2 // Wo
wo_idx = rem2 % Wo
- # base x-offset per row (the c/dt/dh/dw part is added in the K loop)
base_x = (b_idx * CTHW
+ to_idx * kT * HW
+ ho_idx * kH * W
@@ -96,7 +99,6 @@
offs_k = k_start + tl.arange(0, BLOCK_K)
mask_k = offs_k < K
- # decode k -> (c, dt, dh, dw)
c_idx = offs_k // kTkHkW
rem_k = offs_k % kTkHkW
dt_idx = rem_k // kHkW
@@ -104,7 +106,6 @@
dh_idx = rem_k2 // kW
dw_idx = rem_k2 % kW
- # x addresses (BLOCK_M, BLOCK_K)
x_addr = (base_x[:, None]
+ c_idx[None, :] * THW
+ dt_idx[None, :] * HW
@@ -115,17 +116,19 @@
mask=mask_m[:, None] & mask_k[None, :],
other=0.0)
- # w addresses (BLOCK_K, BLOCK_N) -- w is (N, K) row-major
w_addr = offs_n[None, :] * K + offs_k[:, None]
w_vals = tl.load(w_ptr + w_addr,
mask=mask_n[None, :] & mask_k[:, None],
other=0.0)
- # tensor-core dot (BLOCK_M, K_tile) @ (K_tile, BLOCK_N)
acc += tl.dot(x_vals, w_vals)
- # -- store output --------------------------------------------------------
- out_addr = offs_m[:, None] * N + offs_n[None, :]
+ # -- store output in (B, N, To, Ho, Wo) layout ---------------------------
+ out_addr = (b_idx[:, None] * stride_out_b
+ + offs_n[None, :] * stride_out_n
+ + to_idx[:, None] * stride_out_to
+ + ho_idx[:, None] * stride_out_ho
+ + wo_idx[:, None] * stride_out_wo)
tl.store(out_ptr + out_addr, acc,
mask=mask_m[:, None] & mask_n[None, :])
@@ -135,11 +138,7 @@
# ---------------------------------------------------------------------------
class Model(nn.Module):
- """Patch-embedding via fused im2col+GEMM Triton kernel.
-
- Declares ``conv.weight`` of shape ``(embed_dim, C, kT, kH, kW)`` so that
- ``state_dict`` from a reference ``nn.Conv3d`` loads with ``strict=True``.
- """
+ """Patch-embedding via fused im2col+GEMM Triton kernel."""
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
@@ -155,7 +154,6 @@
self.kW = kW
self.embed_dim = embed_dim
- # Register weight under "conv" submodule to match reference state_dict
self.conv = nn.Module()
self.conv.weight = nn.Parameter(
torch.empty(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16)
@@ -169,31 +167,28
... (1637 more chars){
"content": "<7078 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify correctness after output layout change",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark after output layout optimization",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.568 gbps=19.761 ms=0.216
shape=0 variant=compiled tflops=3.045 gbps=16.863 ms=0.253
shape=0 variant=sota tflops=3.671 gbps=20.332 ms=0.210
shape=0 variant=solution tflops=11.873 gbps=65.751 ms=0.065
shape=0 solution_peak_fraction=0.0594
shape=1 variant=eager tflops=13.668 gbps=35.648 ms=0.226
shape=1 variant=compiled tflops=11.542 gbps=30.102 ms=0.267
shape=1 variant=sota tflops=13.859 gbps=36.145 ms=0.222
shape=1 variant=solution tflops=42.979 gbps=112.096 ms=0.072
shape=1 solution_peak_fraction=0.2149
shape=2 variant=eager tflops=25.891 gbps=53.481 ms=0.268
shape=2 variant=compiled tflops=22.736 gbps=46.963 ms=0.305
shape=2 variant=sota tflops=25.819 gbps=53.331 ms=0.269
shape=2 variant=solution tflops=78.822 gbps=162.816 ms=0.088
shape=2 solution_peak_fraction=0.3941
shape=3 variant=eager tflops=6.290 gbps=24.402 ms=0.147
shape=3 variant=compiled tflops=5.011 gbps=19.441 ms=0.185
shape=3 variant=sota tflops=6.362 gbps=24.684 ms=0.145
shape=3 variant=solution tflops=20.845 gbps=80.871 ms=0.044
shape=3 solution_peak_fraction=0.1042
peak_fraction: 0.1513
RESULT: OK
{
"command": "<1055 chars \u2014 see below>",
"description": "Profile shape 2 in detail",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}uv run python -c "
import torch, solution, reference, shapes
device = torch.device('cuda:0')
shape = shapes.SHAPES[2]
for k,v in shape.items():
setattr(reference, k, v)
init_args = reference.get_init_inputs()
sol_model = solution.Model(*init_args).to(device).eval()
ref_model = reference.Model(*init_args).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
torch.manual_seed(42)
inputs = [t.to(device) for t in reference.get_inputs()]
# Warmup
for _ in range(10):
with torch.no_grad(): sol_model(*inputs)
torch.cuda.synchronize()
# Time just the kernel (no .contiguous)
import time
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
with torch.no_grad(): sol_model(*inputs)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100 * 1000
print(f'Shape 2 solution: {elapsed:.3f} ms')
# Compute theoretical TFLOPS
flops = 2 * 1 * 1280 * 3 * 8 * 336 * 336
tflops = flops / (elapsed * 1e-3) / 1e12
print(f'Shape 2 TFLOPS: {tflops:.1f} (peak_fraction={tflops/200:.4f})')
" 2>&1Shape 2 solution: 0.043 ms
Shape 2 TFLOPS: 160.1 (peak_fraction=0.8003)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,11 +1,10 @@
-"""Fused im2col + GEMM patch-embedding kernel (Triton).
-
-Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
-patch embedding) using a fused Triton kernel that gathers patches on-the-fly
-in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
-
-Output is written directly in (B, embed_dim, To, Ho, Wo) layout so no
-reshape/permute/contiguous is needed after the kernel.
+"""Two-phase patch-embedding: fast im2col + tensor-core GEMM (Triton).
+
+Phase 1: gather non-overlapping (kT, kH, kW) patches from x into a
+ contiguous (M, K) im2col buffer, loading kW-wide vectors along W.
+Phase 2: standard tiled GEMM out = im2col @ weight^T using tl.dot.
+
+Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
"""
import torch
import torch.nn as nn
@@ -17,16 +16,85 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-# ---------------------------------------------------------------------------
-# Triton kernel: fused im2col + GEMM
-# ---------------------------------------------------------------------------
-# output[b, n, to, ho, wo] = sum_k x[b, c, to*kT+dt, ho*kH+dh, wo*kW+dw]
-# * weight[n, c, dt, dh, dw]
-#
-# m indexes (batch, patch_position) -> M = B * To * Ho * Wo
-# n indexes embed_dim -> N = embed_dim
-# k indexes (c, dt, dh, dw) -> K = C * kT * kH * kW
-# ---------------------------------------------------------------------------
+# ===== Phase 1: im2col =====================================================
+# Each programme handles one (m, c*dt*dh) row, loading kW contiguous elements
+# from x along the W dimension and storing them contiguously in the col buffer.
+# ============================================================================
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 64}, num_warps=4),
+ triton.Config({"BLOCK_M": 128}, num_warps=4),
+ triton.Config({"BLOCK_M": 256}, num_warps=8),
+ triton.Config({"BLOCK_M": 512}, num_warps=8),
+ ],
+ key=["M", "INNER"],
+)
+@triton.jit
+def _im2col_kernel(
+ x_ptr, col_ptr,
+ M, # total patches
+ INNER: tl.constexpr, # C * kT * kH
+ kW: tl.constexpr,
+ K: tl.constexpr, # INNER * kW
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_col_m,
+ B: tl.constexpr, C: tl.constexpr,
+ T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
+ kT: tl.constexpr, kH: tl.constexpr,
+ To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < M
+
+ num_patches = To * Ho * Wo
+ HW = H * W
+ THW = T * HW
+ CTHW = C * THW
+ kHkW = kH * kW
+ kTkHkW = kT * kHkW
+
+ b_idx = offs_m // num_patches
+ rem = offs_m % num_patches
+ to_idx = rem // (Ho * Wo)
+ rem2 = rem % (Ho * Wo)
+ ho_idx = rem2 // Wo
+ wo_idx = rem2 % Wo
+
+ base_x = (b_idx * CTHW
+ + to_idx * kT * HW
+ + ho_idx * kH * W
+ + wo_idx * kW)
+
+ dw = tl.arange(0, kW) # contiguous along W
+
+ for inner in range(INNER):
+ c = inner // (kT * kH)
+ rem_i = inner % (kT * kH)
+ dt = rem_i // kH
+ dh = rem_i % kH
+
+ x_addr = (base_x[:, None]
+ + c * THW
+ + dt * HW
+ + dh * W
+ + dw[None, :])
+
+ x_vals = tl.load(x_ptr + x_addr,
+ mask=mask_m[:, None],
+ other=0.0)
+
+ col_k_start = c * kTkHkW + dt * kHkW + dh * kW
+ col_addr = offs_m[:, None] * stride_col_m + (col_k_start + dw[None, :])
+ tl.store(col_ptr + col_addr, x_vals,
+ mask=mask_m[:, None])
+
+
+# ===== Phase 2: GEMM =======================================================
+# C[M, N] = A[M, K] @ B[K, N] (A = im2col, B = weight transposed)
+# ============================================================================
@triton.autotune(
configs=[
@@ -44,37 +112,44 @@
key=["M", "N", "K"],
)
@triton.jit
-def _patch_embed_kernel(
- # Pointers
- x_ptr, w_ptr, out_ptr,
- # GEMM dims
+def _gemm_kernel(
+ a_ptr, b_ptr, out_ptr,
M, N: tl.constexpr, K: tl.constexpr,
- # Tensor strides
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_w_n,
+ stride_am, stride_bk,
stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
- # Problem constants
- B: tl.constexpr, C: tl.constexpr,
- T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
- kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
+ B_batch: tl.constexpr,
To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
- # Tile sizes
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_patches = To * Ho * Wo
- kHkW = kH * kW
- kTkHkW = kT * kHkW
- HW = H * W
- THW = T * HW
- CTHW = C * THW
-
- # -- M tile: decode into (b, to, ho, wo) ---------------------------------
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_m = offs_m < M
-
+ mask_n = offs_n < N
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, K, BLOCK_K):
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ mask_k = offs_k < K
+
+ a_addr = offs_m[:, None] * stride_am + offs_k[None, :]
+ a_vals = tl.load(a_ptr + a_addr,
+ mask=mask_m[:, None] & mask_k[None, :],
+ other=0.0)
+
+ b_addr = offs_k[:, None] * stride_bk + offs_n[
... (4234 more chars){
"content": "<8399 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify two-phase approach correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 583, in arange
raise ValueError("arange's range must be a power of 2")
ValueError: arange's range must be a power of 2
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 117, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py", line 78, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^
... (3381 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Two-phase patch-embedding: fast im2col + tensor-core GEMM (Triton).
2:
3: Phase 1: gather non-overlapping (kT, kH, kW) patches from x into a
4: contiguous (M, K) im2col buffer, loading kW-wide vectors along W.
5: Phase 2: standard tiled GEMM out = im2col @ weight^T using tl.dot.
6:
7: Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
8: """
9: import torch
10: import torch.nn as nn
11: import triton
12: import triton.language as tl
13:
14: OP_TYPE = "patch_embed"
15: SUPPORTED_PRECISIONS = ["bf16"]
16: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
17:
18:
19: # ===== Phase 1: im2col =====================================================
20: # Each programme handles one (m, c*dt*dh) row, loading kW contiguous elements
21: # from x along the W dimension and storing them contiguously in the col buffer.
22: # ============================================================================
23:
24: @triton.autotune(
25: configs=[
26: triton.Config({"BLOCK_M": 64}, num_warps=4),
27: triton.Config({"BLOCK_M": 128}, num_warps=4),
28: triton.Config({"BLOCK_M": 256}, num_warps=8),
29: triton.Config({"BLOCK_M": 512}, num_warps=8),
30: ],
31: key=["M", "INNER"],
32: )
33: @triton.jit
34: def _im2col_kernel(
35: x_ptr, col_ptr,
36: M, # total patches
37: INNER: tl.constexpr, # C * kT
... (8181 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -16,10 +16,14 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+def _next_pow2(n: int) -> int:
+ p = 1
+ while p < n:
+ p <<= 1
+ return p
+
+
# ===== Phase 1: im2col =====================================================
-# Each programme handles one (m, c*dt*dh) row, loading kW contiguous elements
-# from x along the W dimension and storing them contiguously in the col buffer.
-# ============================================================================
@triton.autotune(
configs=[
@@ -33,10 +37,11 @@
@triton.jit
def _im2col_kernel(
x_ptr, col_ptr,
- M, # total patches
- INNER: tl.constexpr, # C * kT * kH
+ M,
+ INNER: tl.constexpr,
kW: tl.constexpr,
- K: tl.constexpr, # INNER * kW
+ PADDED_KW: tl.constexpr,
+ K: tl.constexpr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_col_m,
B: tl.constexpr, C: tl.constexpr,
@@ -68,7 +73,8 @@
+ ho_idx * kH * W
+ wo_idx * kW)
- dw = tl.arange(0, kW) # contiguous along W
+ dw = tl.arange(0, PADDED_KW)
+ dw_mask = dw < kW
for inner in range(INNER):
c = inner // (kT * kH)
@@ -83,18 +89,16 @@
+ dw[None, :])
x_vals = tl.load(x_ptr + x_addr,
- mask=mask_m[:, None],
+ mask=mask_m[:, None] & dw_mask[None, :],
other=0.0)
col_k_start = c * kTkHkW + dt * kHkW + dh * kW
col_addr = offs_m[:, None] * stride_col_m + (col_k_start + dw[None, :])
tl.store(col_ptr + col_addr, x_vals,
- mask=mask_m[:, None])
+ mask=mask_m[:, None] & dw_mask[None, :])
# ===== Phase 2: GEMM =======================================================
-# C[M, N] = A[M, K] @ B[K, N] (A = im2col, B = weight transposed)
-# ============================================================================
@triton.autotune(
configs=[
@@ -149,7 +153,6 @@
acc += tl.dot(a_vals, b_vals)
- # Decode m -> (b, to, ho, wo) for output store
b_idx = offs_m // num_patches
rem = offs_m % num_patches
to_idx = rem // (Ho * Wo)
@@ -202,6 +205,7 @@
M = B * num_patches
N = embed_dim
INNER = C * kT * kH
+ PADDED_KW = _next_pow2(kW)
# Phase 1: im2col
col = torch.empty((M, K), dtype=x.dtype, device=x.device)
@@ -211,18 +215,15 @@
_im2col_kernel[grid_im2col](
x, col,
- M, INNER, kW, K,
+ M, INNER, kW, PADDED_KW, K,
x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
- K, # stride_col_m = K (row-major)
+ K,
B, C, T, H, W,
kT, kH,
To, Ho, Wo,
)
# Phase 2: GEMM out = col @ w^T
- # w is (N, K) row-major; for B = w^T we need (K, N) -> load w[n, k] as b[k, n]
- # b[k, n] = w_ptr[n * K + k] -> stride_bk = 1, but that's column-major on w.
- # Instead, pre-transpose w to (K, N) for contiguous column loads.
w = self.conv.weight.reshape(N, K)
w_t = w.T.contiguous() # (K, N) contiguous
{
"content": "<7848 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Verify two-phase approach correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark two-phase approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.501 gbps=19.389 ms=0.220
shape=0 variant=compiled tflops=2.953 gbps=16.351 ms=0.261
shape=0 variant=sota tflops=3.669 gbps=20.316 ms=0.210
shape=0 variant=solution tflops=6.907 gbps=38.249 ms=0.112
shape=0 solution_peak_fraction=0.0345
shape=1 variant=eager tflops=13.526 gbps=35.278 ms=0.228
shape=1 variant=compiled tflops=11.009 gbps=28.714 ms=0.280
shape=1 variant=sota tflops=13.858 gbps=36.143 ms=0.222
shape=1 variant=solution tflops=26.554 gbps=69.257 ms=0.116
shape=1 solution_peak_fraction=0.1328
shape=2 variant=eager tflops=26.177 gbps=54.072 ms=0.265
shape=2 variant=compiled tflops=22.350 gbps=46.166 ms=0.310
shape=2 variant=sota tflops=25.365 gbps=52.395 ms=0.273
shape=2 variant=solution tflops=48.952 gbps=101.117 ms=0.142
shape=2 solution_peak_fraction=0.2448
shape=3 variant=eager tflops=6.282 gbps=24.373 ms=0.147
shape=3 variant=compiled tflops=4.877 gbps=18.923 ms=0.190
shape=3 variant=sota tflops=6.433 gbps=24.956 ms=0.144
shape=3 variant=solution tflops=13.362 gbps=51.839 ms=0.069
shape=3 solution_peak_fraction=0.0668
peak_fraction: 0.0931
RESULT: LOW
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,9 +1,10 @@
-"""Two-phase patch-embedding: fast im2col + tensor-core GEMM (Triton).
+"""Fused im2col + GEMM patch-embedding kernel (Triton).
-Phase 1: gather non-overlapping (kT, kH, kW) patches from x into a
- contiguous (M, K) im2col buffer, loading kW-wide vectors along W.
-Phase 2: standard tiled GEMM out = im2col @ weight^T using tl.dot.
+Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
+patch embedding) using a fused Triton kernel that gathers patches on-the-fly
+in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
+Weight is pre-transposed to (K, N) for contiguous GEMM loads.
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
"""
import torch
@@ -16,50 +17,48 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-def _next_pow2(n: int) -> int:
- p = 1
- while p < n:
- p <<= 1
- return p
-
-
-# ===== Phase 1: im2col =====================================================
-
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 64}, num_warps=4),
- triton.Config({"BLOCK_M": 128}, num_warps=4),
- triton.Config({"BLOCK_M": 256}, num_warps=8),
- triton.Config({"BLOCK_M": 512}, num_warps=8),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=3),
],
- key=["M", "INNER"],
+ key=["M", "N", "K"],
)
@triton.jit
-def _im2col_kernel(
- x_ptr, col_ptr,
- M,
- INNER: tl.constexpr,
- kW: tl.constexpr,
- PADDED_KW: tl.constexpr,
- K: tl.constexpr,
+def _patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ M, N: tl.constexpr, K: tl.constexpr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_col_m,
+ stride_wk, # weight is (K, N) row-major -> stride_wk = N
+ stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
B: tl.constexpr, C: tl.constexpr,
T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
- kT: tl.constexpr, kH: tl.constexpr,
+ kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
- BLOCK_M: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
- pid = tl.program_id(0)
- offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
- mask_m = offs_m < M
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
num_patches = To * Ho * Wo
+ kHkW = kH * kW
+ kTkHkW = kT * kHkW
HW = H * W
THW = T * HW
CTHW = C * THW
- kHkW = kH * kW
- kTkHkW = kT * kHkW
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < M
b_idx = offs_m // num_patches
rem = offs_m % num_patches
@@ -73,66 +72,7 @@
+ ho_idx * kH * W
+ wo_idx * kW)
- dw = tl.arange(0, PADDED_KW)
- dw_mask = dw < kW
-
- for inner in range(INNER):
- c = inner // (kT * kH)
- rem_i = inner % (kT * kH)
- dt = rem_i // kH
- dh = rem_i % kH
-
- x_addr = (base_x[:, None]
- + c * THW
- + dt * HW
- + dh * W
- + dw[None, :])
-
- x_vals = tl.load(x_ptr + x_addr,
- mask=mask_m[:, None] & dw_mask[None, :],
- other=0.0)
-
- col_k_start = c * kTkHkW + dt * kHkW + dh * kW
- col_addr = offs_m[:, None] * stride_col_m + (col_k_start + dw[None, :])
- tl.store(col_ptr + col_addr, x_vals,
- mask=mask_m[:, None] & dw_mask[None, :])
-
-
-# ===== Phase 2: GEMM =======================================================
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- ],
- key=["M", "N", "K"],
-
... (4675 more chars){
"content": "<6111 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Check correctness and benchmark fused kernel with transposed weight",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.612 gbps=20.001 ms=0.213
shape=0 variant=compiled tflops=2.983 gbps=16.521 ms=0.258
shape=0 variant=sota tflops=3.629 gbps=20.099 ms=0.212
shape=0 variant=solution tflops=11.326 gbps=62.721 ms=0.068
shape=0 solution_peak_fraction=0.0566
shape=1 variant=eager tflops=13.589 gbps=35.442 ms=0.227
shape=1 variant=compiled tflops=11.377 gbps=29.672 ms=0.271
shape=1 variant=sota tflops=13.797 gbps=35.985 ms=0.223
shape=1 variant=solution tflops=44.944 gbps=117.221 ms=0.069
shape=1 solution_peak_fraction=0.2247
shape=2 variant=eager tflops=25.952 gbps=53.606 ms=0.267
shape=2 variant=compiled tflops=22.676 gbps=46.840 ms=0.306
shape=2 variant=sota tflops=26.471 gbps=54.680 ms=0.262
shape=2 variant=solution tflops=87.757 gbps=181.273 ms=0.079
shape=2 solution_peak_fraction=0.4388
shape=3 variant=eager tflops=6.237 gbps=24.197 ms=0.148
shape=3 variant=compiled tflops=4.917 gbps=19.077 ms=0.188
shape=3 variant=sota tflops=6.395 gbps=24.810 ms=0.145
shape=3 variant=solution tflops=22.784 gbps=88.394 ms=0.041
shape=3 solution_peak_fraction=0.1139
peak_fraction: 0.1588
RESULT: OK
{
"todos": [
{
"content": "Read reference.py, benchmark.py, check.py, problem.yaml to understand the task",
"priority": "high",
"status": "completed"
},
{
"content": "Read CLAUDE.md and any other project context files",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU availability and CUDA environment",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)",
"priority": "high",
"status": "comp
... (555 more chars)content: 'Read reference.py, benchmark.py, check.py, problem.yaml to understand the task'
status: 'completed'
priority: 'high'
content: 'Read CLAUDE.md and any other project context files'
status: 'completed'
priority: 'high'
content: 'Check GPU availability and CUDA environment'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Implement CUDA C++ WMMA kernel for better memory access control'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'medium'
content: 'Final verification with check.py and benchmark.py'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,117 +1,249 @@
-"""Fused im2col + GEMM patch-embedding kernel (Triton).
-
-Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
-patch embedding) using a fused Triton kernel that gathers patches on-the-fly
-in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
-
-Weight is pre-transposed to (K, N) for contiguous GEMM loads.
+"""Fused im2col + GEMM patch-embedding kernel (CUDA C++ with WMMA).
+
+Uses shared memory to cache the im2col data and weight, then computes the
+GEMM using WMMA (Warp Matrix Multiply Accumulate) tensor core instructions.
+
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
"""
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
+from torch.utils.cpp_extension import load_inline
OP_TYPE = "patch_embed"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=3),
- ],
- key=["M", "N", "K"],
+_cuda_source = r"""
+#include <torch/extension.h>
+#include <cuda_bf16.h>
+#include <mma.h>
+
+using namespace nvcuda;
+
+// WMMA tile sizes
+#define WMMA_M 16
+#define WMMA_N 16
+#define WMMA_K 16
+
+// Block tile sizes
+#define BLOCK_M 64
+#define BLOCK_N 64
+#define BLOCK_K 32
+
+// Number of WMMA tiles per block
+#define WARP_M (BLOCK_M / WMMA_M) // 4
+#define WARP_N (BLOCK_N / WMMA_N) // 4
+
+// Number of warps per block
+#define NUM_WARPS (WARP_M * WARP_N) // 16
+
+__global__ void __launch_bounds__(NUM_WARPS * 32)
+patch_embed_kernel(
+ const __nv_bfloat16* __restrict__ x_ptr,
+ const __nv_bfloat16* __restrict__ w_ptr, // (K, N) row-major
+ __nv_bfloat16* __restrict__ out_ptr,
+ int M, int N, int K,
+ int stride_x_b, int stride_x_c, int stride_x_t, int stride_x_h, int stride_x_w,
+ int stride_wk, // stride for (K, N) weight = N
+ int stride_out_b, int stride_out_n, int stride_out_to, int stride_out_ho, int stride_out_wo,
+ int B, int C, int T, int H, int W,
+ int kT, int kH, int kW,
+ int To, int Ho, int Wo
+) {
+ // Block indices
+ int block_m = blockIdx.x;
+ int block_n = blockIdx.y;
+
+ // Thread indices
+ int tid = threadIdx.x;
+ int warp_id = tid / 32;
+ int warp_m = warp_id / WARP_N; // which warp row
+ int warp_n = warp_id % WARP_N; // which warp col
+
+ // Shared memory for A tile (im2col) and B tile (weight)
+ __shared__ __nv_bfloat16 smem_a[BLOCK_M][BLOCK_K];
+ __shared__ __nv_bfloat16 smem_b[BLOCK_K][BLOCK_N];
+
+ // WMMA fragments
+ wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, wmma::row_major> a_frag;
+ wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, wmma::row_major> b_frag;
+ wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
+ wmma::fill_fragment(acc_frag, 0.0f);
+
+ // Precompute constants
+ int num_patches = To * Ho * Wo;
+ int HW = H * W;
+ int THW = T * HW;
+ int CTHW = C * THW;
+ int kHkW = kH * kW;
+ int kTkHkW = kT * kHkW;
+
+ // Global M offsets for this block
+ int m_base = block_m * BLOCK_M;
+
+ // Decode M offsets into (b, to, ho, wo)
+ int m_offs[BLOCK_M];
+ int b_idx[BLOCK_M], to_idx[BLOCK_M], ho_idx[BLOCK_M], wo_idx[BLOCK_M];
+ int base_x[BLOCK_M];
+
+ for (int i = 0; i < BLOCK_M; i++) {
+ int m = m_base + i;
+ m_offs[i] = m;
+ if (m < M) {
+ int b = m / num_patches;
+ int rem = m % num_patches;
+ int to = rem / (Ho * Wo);
+ int rem2 = rem % (Ho * Wo);
+ int ho = rem2 / Wo;
+ int wo = rem2 % Wo;
+ b_idx[i] = b;
+ to_idx[i] = to;
+ ho_idx[i] = ho;
+ wo_idx[i] = wo;
+ base_x[i] = b * CTHW + to * kT * HW + ho * kH * W + wo * kW;
+ }
+ }
+
+ // K loop
+ for (int k_start = 0; k_start < K; k_start += BLOCK_K) {
+ // Load A tile (im2col) into shared memory
+ // Each thread loads multiple elements
+ int tid_m = tid / BLOCK_K; // which row
+ int tid_k = tid % BLOCK_K; // which col
+
+ // We have NUM_WARPS * 32 = 512 threads
+ // BLOCK_M * BLOCK_K = 64 * 32 = 2048 elements
+ // So each thread loads 2048 / 512 = 4 elements
+ for (int i = tid; i < BLOCK_M * BLOCK_K; i += NUM_WARPS * 32) {
+ int row = i / BLOCK_K;
+ int col = i % BLOCK_K;
+ int m = m_base + row;
+ int k = k_start + col;
+
+ __nv_bfloat16 val = __float2bfloat16(0.0f);
+ if (m < M && k < K) {
+ int c = k / k
... (7751 more chars){
"content": "<9102 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,249 +1,116 @@
-"""Fused im2col + GEMM patch-embedding kernel (CUDA C++ with WMMA).
+"""Fused im2col + GEMM patch-embedding kernel (Triton).
-Uses shared memory to cache the im2col data and weight, then computes the
-GEMM using WMMA (Warp Matrix Multiply Accumulate) tensor core instructions.
+Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
+patch embedding) using a fused Triton kernel that gathers patches on-the-fly
+in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
+Weight is pre-transposed to (K, N) for contiguous GEMM loads.
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
"""
import torch
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
+import triton
+import triton.language as tl
OP_TYPE = "patch_embed"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-_cuda_source = r"""
-#include <torch/extension.h>
-#include <cuda_bf16.h>
-#include <mma.h>
-using namespace nvcuda;
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=3),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ M, N: tl.constexpr, K: tl.constexpr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_wk,
+ stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
+ B: tl.constexpr, C: tl.constexpr,
+ T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
+ kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
+ To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
-// WMMA tile sizes
-#define WMMA_M 16
-#define WMMA_N 16
-#define WMMA_K 16
+ num_patches = To * Ho * Wo
+ kHkW = kH * kW
+ kTkHkW = kT * kHkW
+ HW = H * W
+ THW = T * HW
+ CTHW = C * THW
-// Block tile sizes
-#define BLOCK_M 64
-#define BLOCK_N 64
-#define BLOCK_K 32
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < M
-// Number of WMMA tiles per block
-#define WARP_M (BLOCK_M / WMMA_M) // 4
-#define WARP_N (BLOCK_N / WMMA_N) // 4
+ b_idx = offs_m // num_patches
+ rem = offs_m % num_patches
+ to_idx = rem // (Ho * Wo)
+ rem2 = rem % (Ho * Wo)
+ ho_idx = rem2 // Wo
+ wo_idx = rem2 % Wo
-// Number of warps per block
-#define NUM_WARPS (WARP_M * WARP_N) // 16
+ base_x = (b_idx * CTHW
+ + to_idx * kT * HW
+ + ho_idx * kH * W
+ + wo_idx * kW)
-__global__ void __launch_bounds__(NUM_WARPS * 32)
-patch_embed_kernel(
- const __nv_bfloat16* __restrict__ x_ptr,
- const __nv_bfloat16* __restrict__ w_ptr, // (K, N) row-major
- __nv_bfloat16* __restrict__ out_ptr,
- int M, int N, int K,
- int stride_x_b, int stride_x_c, int stride_x_t, int stride_x_h, int stride_x_w,
- int stride_wk, // stride for (K, N) weight = N
- int stride_out_b, int stride_out_n, int stride_out_to, int stride_out_ho, int stride_out_wo,
- int B, int C, int T, int H, int W,
- int kT, int kH, int kW,
- int To, int Ho, int Wo
-) {
- // Block indices
- int block_m = blockIdx.x;
- int block_n = blockIdx.y;
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < N
- // Thread indices
- int tid = threadIdx.x;
- int warp_id = tid / 32;
- int warp_m = warp_id / WARP_N; // which warp row
- int warp_n = warp_id % WARP_N; // which warp col
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- // Shared memory for A tile (im2col) and B tile (weight)
- __shared__ __nv_bfloat16 smem_a[BLOCK_M][BLOCK_K];
- __shared__ __nv_bfloat16 smem_b[BLOCK_K][BLOCK_N];
+ for k_start in range(0, K, BLOCK_K):
+ offs_k = k_start + tl.arange(0, BLOCK_K)
+ mask_k = offs_k < K
- // WMMA fragments
- wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, wmma::row_major> a_frag;
- wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, wmma::row_major> b_frag;
- wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
- wmma::fill_fragment(acc_frag, 0.0f);
+ c_idx = offs_k // kTkHkW
+ rem_k = offs_k % kTkHkW
+ dt_idx = rem_k // kHkW
+ rem_k2 = rem_k % kHkW
+ dh_idx = rem_k2 // kW
+ dw_idx = rem_k2 % kW
- // Precompute constants
- int num_patches = To * Ho * Wo;
- int HW = H * W;
- int THW = T * HW;
- int CTHW = C * THW;
- int kHkW = kH * kW;
- int kTkHkW = kT * kHkW;
+ x_a
... (7585 more chars){
"content": "<5973 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Check correctness and benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.615 gbps=20.019 ms=0.213
shape=0 variant=compiled tflops=2.655 gbps=14.702 ms=0.290
shape=0 variant=sota tflops=3.289 gbps=18.215 ms=0.234
shape=0 variant=solution tflops=11.135 gbps=61.663 ms=0.069
shape=0 solution_peak_fraction=0.0557
shape=1 variant=eager tflops=13.593 gbps=35.452 ms=0.227
shape=1 variant=compiled tflops=11.544 gbps=30.108 ms=0.267
shape=1 variant=sota tflops=13.804 gbps=36.003 ms=0.223
shape=1 variant=solution tflops=41.100 gbps=107.195 ms=0.075
shape=1 solution_peak_fraction=0.2055
shape=2 variant=eager tflops=26.061 gbps=53.832 ms=0.266
shape=2 variant=compiled tflops=22.642 gbps=46.769 ms=0.306
shape=2 variant=sota tflops=26.437 gbps=54.610 ms=0.262
shape=2 variant=solution tflops=84.771 gbps=175.105 ms=0.082
shape=2 solution_peak_fraction=0.4239
shape=3 variant=eager tflops=6.130 gbps=23.781 ms=0.151
shape=3 variant=compiled tflops=4.833 gbps=18.752 ms=0.191
shape=3 variant=sota tflops=6.421 gbps=24.912 ms=0.144
shape=3 variant=solution tflops=21.962 gbps=85.204 ms=0.042
shape=3 solution_peak_fraction=0.1098
peak_fraction: 0.1519
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,8 +1,8 @@
"""Fused im2col + GEMM patch-embedding kernel (Triton).
-Implements the same operation as nn.Conv3d with stride==kernel (non-overlapping
-patch embedding) using a fused Triton kernel that gathers patches on-the-fly
-in the K-loop of a tensor-core GEMM, avoiding materialising the im2col buffer.
+Key optimization: BLOCK_K = padded kW so that adjacent threads along K read
+contiguous dw elements from x (coalesced), instead of scattered elements
+across different (c, dt, dh) boundaries.
Weight is pre-transposed to (K, N) for contiguous GEMM loads.
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
@@ -17,27 +17,34 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+def _next_pow2(n: int) -> int:
+ p = 1
+ while p < n:
+ p <<= 1
+ return p
+
+
@triton.autotune(
configs=[
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_warps=4, num_stages=3),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=3),
],
- key=["M", "N", "K"],
+ key=["M", "N", "NUM_OUTER"],
)
@triton.jit
def _patch_embed_kernel(
x_ptr, w_ptr, out_ptr,
M, N: tl.constexpr, K: tl.constexpr,
+ NUM_OUTER: tl.constexpr, # C * kT * kH
+ PADDED_KW: tl.constexpr, # next_power_of_2(kW), must be >= 16
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_wk,
stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
@@ -45,7 +52,7 @@
T: tl.constexpr, H: tl.constexpr, W: tl.constexpr,
kT: tl.constexpr, kH: tl.constexpr, kW: tl.constexpr,
To: tl.constexpr, Ho: tl.constexpr, Wo: tl.constexpr,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
@@ -75,31 +82,31 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
+ offs_dw = tl.arange(0, PADDED_KW)
+ mask_dw = offs_dw < kW
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_start in range(0, K, BLOCK_K):
- offs_k = k_start + tl.arange(0, BLOCK_K)
- mask_k = offs_k < K
+ for outer in range(NUM_OUTER):
+ c = outer // (kT * kH)
+ rem_o = outer % (kT * kH)
+ dt = rem_o // kH
+ dh = rem_o % kH
- c_idx = offs_k // kTkHkW
- rem_k = offs_k % kTkHkW
- dt_idx = rem_k // kHkW
- rem_k2 = rem_k % kHkW
- dh_idx = rem_k2 // kW
- dw_idx = rem_k2 % kW
+ k_base = c * kTkHkW + dt * kHkW + dh * kW
x_addr = (base_x[:, None]
- + c_idx[None, :] * THW
- + dt_idx[None, :] * HW
- + dh_idx[None, :] * W
- + dw_idx[None, :])
+ + c * THW
+ + dt * HW
+ + dh * W
+ + offs_dw[None, :])
x_vals = tl.load(x_ptr + x_addr,
- mask=mask_m[:, None] & mask_k[None, :],
+ mask=mask_m[:, None] & mask_dw[None, :],
other=0.0)
- w_addr = offs_k[:, None] * stride_wk + offs_n[None, :]
+ w_addr = (k_base + offs_dw[:, None]) * stride_wk + offs_n[None, :]
w_vals = tl.load(w_ptr + w_addr,
- mask=mask_k[:, None] & mask_n[None, :],
+ mask=mask_dw[:, None] & mask_n[None, :],
other=0.0)
acc += tl.dot(x_vals, w_vals)
@@ -144,6 +151,11 @@
K = C * kT * kH * kW
M = B * To * Ho * Wo
N = embed_dim
+ NUM_OUTER = C * kT * kH
+ PADDED_KW = _next_pow2(kW)
+
+ if PADDED_KW < 16:
+ PADDED_KW = 16
if self._w_t is None or self._w_t.shape != (K, N):
self._w_t = self.conv.weight.reshape(N, K).T.contiguous()
@@ -157,6 +169,7 @@
_patch_embed_kernel[grid](
x, self._w
... (248 more chars){
"content": "<5861 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check correctness of BLOCK_K=padded kW approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark BLOCK_K=padded kW approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.539 gbps=19.598 ms=0.218
shape=0 variant=compiled tflops=3.071 gbps=17.007 ms=0.251
shape=0 variant=sota tflops=3.679 gbps=20.375 ms=0.209
shape=0 variant=solution tflops=13.946 gbps=77.230 ms=0.055
shape=0 solution_peak_fraction=0.0697
shape=1 variant=eager tflops=13.448 gbps=35.076 ms=0.229
shape=1 variant=compiled tflops=11.540 gbps=30.097 ms=0.267
shape=1 variant=sota tflops=13.190 gbps=34.401 ms=0.234
shape=1 variant=solution tflops=50.651 gbps=132.105 ms=0.061
shape=1 solution_peak_fraction=0.2533
shape=2 variant=eager tflops=26.124 gbps=53.961 ms=0.266
shape=2 variant=compiled tflops=22.191 gbps=45.838 ms=0.313
shape=2 variant=sota tflops=26.290 gbps=54.305 ms=0.264
shape=2 variant=solution tflops=104.012 gbps=214.848 ms=0.067
shape=2 solution_peak_fraction=0.5201
shape=3 variant=eager tflops=6.220 gbps=24.132 ms=0.149
shape=3 variant=compiled tflops=4.868 gbps=18.888 ms=0.190
shape=3 variant=sota tflops=6.458 gbps=25.057 ms=0.143
shape=3 variant=solution tflops=14.917 gbps=57.873 ms=0.062
shape=3 solution_peak_fraction=0.0746
peak_fraction: 0.1618
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,8 +1,9 @@
"""Fused im2col + GEMM patch-embedding kernel (Triton).
-Key optimization: BLOCK_K = padded kW so that adjacent threads along K read
-contiguous dw elements from x (coalesced), instead of scattered elements
-across different (c, dt, dh) boundaries.
+Key optimization: BLOCK_K = PADDED_KW * UNROLL so that:
+- Each k-tile of PADDED_KW elements maps to one (c, dt, dh) row
+- x reads within each k-tile are contiguous along dw (coalesced)
+- UNROLL reduces loop iterations while maintaining coalescing
Weight is pre-transposed to (K, N) for contiguous GEMM loads.
Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
@@ -43,8 +44,10 @@
def _patch_embed_kernel(
x_ptr, w_ptr, out_ptr,
M, N: tl.constexpr, K: tl.constexpr,
- NUM_OUTER: tl.constexpr, # C * kT * kH
- PADDED_KW: tl.constexpr, # next_power_of_2(kW), must be >= 16
+ NUM_OUTER: tl.constexpr,
+ PADDED_KW: tl.constexpr,
+ UNROLL: tl.constexpr,
+ BLOCK_K: tl.constexpr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_wk,
stride_out_b, stride_out_n, stride_out_to, stride_out_ho, stride_out_wo,
@@ -82,29 +85,33 @@
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
- offs_dw = tl.arange(0, PADDED_KW)
- mask_dw = offs_dw < kW
+ offs_k = tl.arange(0, BLOCK_K)
+ u = offs_k // PADDED_KW
+ dw = offs_k % PADDED_KW
+ mask_dw = dw < kW
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for outer in range(NUM_OUTER):
- c = outer // (kT * kH)
- rem_o = outer % (kT * kH)
- dt = rem_o // kH
- dh = rem_o % kH
+ row_base = outer * UNROLL
+ row_idx = row_base + u
- k_base = c * kTkHkW + dt * kHkW + dh * kW
+ c = row_idx // (kT * kH)
+ rem_r = row_idx % (kT * kH)
+ dt = rem_r // kH
+ dh = rem_r % kH
x_addr = (base_x[:, None]
- + c * THW
- + dt * HW
- + dh * W
- + offs_dw[None, :])
+ + c[None, :] * THW
+ + dt[None, :] * HW
+ + dh[None, :] * W
+ + dw[None, :])
x_vals = tl.load(x_ptr + x_addr,
mask=mask_m[:, None] & mask_dw[None, :],
other=0.0)
- w_addr = (k_base + offs_dw[:, None]) * stride_wk + offs_n[None, :]
+ k_base = c * kTkHkW + dt * kHkW + dh * kW
+ w_addr = (k_base[:, None]) * stride_wk + dw[:, None] * stride_wk + offs_n[None, :]
w_vals = tl.load(w_ptr + w_addr,
mask=mask_dw[:, None] & mask_n[None, :],
other=0.0)
@@ -151,11 +158,12 @@
K = C * kT * kH * kW
M = B * To * Ho * Wo
N = embed_dim
- NUM_OUTER = C * kT * kH
PADDED_KW = _next_pow2(kW)
-
if PADDED_KW < 16:
PADDED_KW = 16
+ UNROLL = 2
+ BLOCK_K = PADDED_KW * UNROLL
+ NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
if self._w_t is None or self._w_t.shape != (K, N):
self._w_t = self.conv.weight.reshape(N, K).T.contiguous()
@@ -169,7 +177,7 @@
_patch_embed_kernel[grid](
x, self._w_t, out,
M, N, K,
- NUM_OUTER, PADDED_KW,
+ NUM_OUTER, PADDED_KW, UNROLL, BLOCK_K,
x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
N,
out.stride(0), out.stride(1), out.stride(2), out.stride(3), out.stride(4),
{
"content": "<6170 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check correctness of UNROLL=2 approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark UNROLL=2 approach",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.619 gbps=20.043 ms=0.213
shape=0 variant=compiled tflops=3.044 gbps=16.860 ms=0.253
shape=0 variant=sota tflops=3.668 gbps=20.312 ms=0.210
shape=0 variant=solution tflops=15.128 gbps=83.779 ms=0.051
shape=0 solution_peak_fraction=0.0756
shape=1 variant=eager tflops=13.788 gbps=35.962 ms=0.224
shape=1 variant=compiled tflops=11.464 gbps=29.900 ms=0.269
shape=1 variant=sota tflops=13.872 gbps=36.179 ms=0.222
shape=1 variant=solution tflops=54.122 gbps=141.160 ms=0.057
shape=1 solution_peak_fraction=0.2706
shape=2 variant=eager tflops=26.196 gbps=54.111 ms=0.265
shape=2 variant=compiled tflops=22.431 gbps=46.334 ms=0.309
shape=2 variant=sota tflops=26.509 gbps=54.757 ms=0.262
shape=2 variant=solution tflops=111.159 gbps=229.612 ms=0.062
shape=2 solution_peak_fraction=0.5558
shape=3 variant=eager tflops=6.321 gbps=24.522 ms=0.146
shape=3 variant=compiled tflops=4.895 gbps=18.990 ms=0.189
shape=3 variant=sota tflops=6.428 gbps=24.937 ms=0.144
shape=3 variant=solution tflops=18.896 gbps=73.310 ms=0.049
shape=3 solution_peak_fraction=0.0945
peak_fraction: 0.1811
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -161,7 +161,7 @@
PADDED_KW = _next_pow2(kW)
if PADDED_KW < 16:
PADDED_KW = 16
- UNROLL = 2
+ UNROLL = 4
BLOCK_K = PADDED_KW * UNROLL
NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": " UNROLL = 2",
"newString": " UNROLL = 4"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Test UNROLL=4",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.614 gbps=20.011 ms=0.213
shape=0 variant=compiled tflops=2.994 gbps=16.583 ms=0.257
shape=0 variant=sota tflops=3.630 gbps=20.100 ms=0.212
shape=0 variant=solution tflops=15.819 gbps=87.603 ms=0.049
shape=0 solution_peak_fraction=0.0791
shape=1 variant=eager tflops=13.661 gbps=35.630 ms=0.226
shape=1 variant=compiled tflops=11.392 gbps=29.713 ms=0.271
shape=1 variant=sota tflops=13.766 gbps=35.903 ms=0.224
shape=1 variant=solution tflops=53.506 gbps=139.552 ms=0.058
shape=1 solution_peak_fraction=0.2675
shape=2 variant=eager tflops=25.897 gbps=53.494 ms=0.268
shape=2 variant=compiled tflops=22.159 gbps=45.772 ms=0.313
shape=2 variant=sota tflops=26.210 gbps=54.141 ms=0.265
shape=2 variant=solution tflops=108.407 gbps=223.928 ms=0.064
shape=2 solution_peak_fraction=0.5420
shape=3 variant=eager tflops=6.269 gbps=24.323 ms=0.148
shape=3 variant=compiled tflops=4.875 gbps=18.915 ms=0.190
shape=3 variant=sota tflops=6.423 gbps=24.920 ms=0.144
shape=3 variant=solution tflops=20.792 gbps=80.668 ms=0.044
shape=3 solution_peak_fraction=0.1040
peak_fraction: 0.1858
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -161,7 +161,7 @@
PADDED_KW = _next_pow2(kW)
if PADDED_KW < 16:
PADDED_KW = 16
- UNROLL = 4
+ UNROLL = 8
BLOCK_K = PADDED_KW * UNROLL
NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"newString": " UNROLL = 8",
"oldString": " UNROLL = 4"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Test UNROLL=8",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 42: tolerance exceeded: max_abs_diff=0.246094 (atol=0.02, rtol=0.02)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -161,7 +161,7 @@
PADDED_KW = _next_pow2(kW)
if PADDED_KW < 16:
PADDED_KW = 16
- UNROLL = 8
+ UNROLL = 4
BLOCK_K = PADDED_KW * UNROLL
NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"newString": " UNROLL = 4",
"oldString": " UNROLL = 8"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Verify UNROLL=4 still passes",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.585 gbps=19.854 ms=0.215
shape=0 variant=compiled tflops=3.061 gbps=16.953 ms=0.252
shape=0 variant=sota tflops=3.664 gbps=20.290 ms=0.210
shape=0 variant=solution tflops=14.548 gbps=80.565 ms=0.053
shape=0 solution_peak_fraction=0.0727
shape=1 variant=eager tflops=13.594 gbps=35.454 ms=0.227
shape=1 variant=compiled tflops=11.394 gbps=29.716 ms=0.271
shape=1 variant=sota tflops=13.700 gbps=35.732 ms=0.225
shape=1 variant=solution tflops=57.481 gbps=149.919 ms=0.054
shape=1 solution_peak_fraction=0.2874
shape=2 variant=eager tflops=26.150 gbps=54.017 ms=0.265
shape=2 variant=compiled tflops=22.275 gbps=46.012 ms=0.311
shape=2 variant=sota tflops=26.355 gbps=54.440 ms=0.263
shape=2 variant=solution tflops=106.125 gbps=219.214 ms=0.065
shape=2 solution_peak_fraction=0.5306
shape=3 variant=eager tflops=6.312 gbps=24.487 ms=0.147
shape=3 variant=compiled tflops=4.774 gbps=18.523 ms=0.194
shape=3 variant=sota tflops=6.332 gbps=24.565 ms=0.146
shape=3 variant=solution tflops=21.739 gbps=84.338 ms=0.043
shape=3 solution_peak_fraction=0.1087
peak_fraction: 0.1863
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -161,7 +161,7 @@
PADDED_KW = _next_pow2(kW)
if PADDED_KW < 16:
PADDED_KW = 16
- UNROLL = 4
+ UNROLL = 2
BLOCK_K = PADDED_KW * UNROLL
NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"newString": " UNROLL = 2",
"oldString": " UNROLL = 4"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Final verification",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.531 gbps=19.555 ms=0.218
shape=0 variant=compiled tflops=2.928 gbps=16.216 ms=0.263
shape=0 variant=sota tflops=3.658 gbps=20.259 ms=0.211
shape=0 variant=solution tflops=14.762 gbps=81.751 ms=0.052
shape=0 solution_peak_fraction=0.0738
shape=1 variant=eager tflops=13.686 gbps=35.696 ms=0.225
shape=1 variant=compiled tflops=11.511 gbps=30.023 ms=0.268
shape=1 variant=sota tflops=13.902 gbps=36.257 ms=0.222
shape=1 variant=solution tflops=53.956 gbps=140.725 ms=0.057
shape=1 solution_peak_fraction=0.2698
shape=2 variant=eager tflops=26.091 gbps=53.893 ms=0.266
shape=2 variant=compiled tflops=22.549 gbps=46.577 ms=0.308
shape=2 variant=sota tflops=26.441 gbps=54.616 ms=0.262
shape=2 variant=solution tflops=102.633 gbps=212.000 ms=0.068
shape=2 solution_peak_fraction=0.5132
shape=3 variant=eager tflops=6.282 gbps=24.370 ms=0.147
shape=3 variant=compiled tflops=4.791 gbps=18.589 ms=0.193
shape=3 variant=sota tflops=6.391 gbps=24.796 ms=0.145
shape=3 variant=solution tflops=20.267 gbps=78.631 ms=0.046
shape=3 solution_peak_fraction=0.1013
peak_fraction: 0.1794
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -161,7 +161,7 @@
PADDED_KW = _next_pow2(kW)
if PADDED_KW < 16:
PADDED_KW = 16
- UNROLL = 2
+ UNROLL = 4
BLOCK_K = PADDED_KW * UNROLL
NUM_OUTER = (C * kT * kH + UNROLL - 1) // UNROLL
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"newString": " UNROLL = 4",
"oldString": " UNROLL = 2"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Final verification with UNROLL=4",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
shape=0 variant=eager tflops=3.485 gbps=19.302 ms=0.221
shape=0 variant=compiled tflops=3.072 gbps=17.014 ms=0.251
shape=0 variant=sota tflops=3.643 gbps=20.176 ms=0.212
shape=0 variant=solution tflops=16.401 gbps=90.825 ms=0.047
shape=0 solution_peak_fraction=0.0820
shape=1 variant=eager tflops=13.606 gbps=35.487 ms=0.227
shape=1 variant=compiled tflops=11.598 gbps=30.249 ms=0.266
shape=1 variant=sota tflops=13.909 gbps=36.276 ms=0.222
shape=1 variant=solution tflops=53.536 gbps=139.630 ms=0.058
shape=1 solution_peak_fraction=0.2677
shape=2 variant=eager tflops=26.069 gbps=53.848 ms=0.266
shape=2 variant=compiled tflops=22.584 gbps=46.650 ms=0.307
shape=2 variant=sota tflops=26.397 gbps=54.526 ms=0.263
shape=2 variant=solution tflops=111.245 gbps=229.789 ms=0.062
shape=2 solution_peak_fraction=0.5562
shape=3 variant=eager tflops=6.250 gbps=24.247 ms=0.148
shape=3 variant=compiled tflops=4.800 gbps=18.621 ms=0.193
shape=3 variant=sota tflops=6.372 gbps=24.722 ms=0.145
shape=3 variant=solution tflops=21.173 gbps=82.145 ms=0.044
shape=3 solution_peak_fraction=0.1059
peak_fraction: 0.1896
RESULT: OK
{
"todos": [
{
"content": "Read reference.py, benchmark.py, check.py, problem.yaml to understand the task",
"priority": "high",
"status": "completed"
},
{
"content": "Read CLAUDE.md and any other project context files",
"priority": "high",
"status": "completed"
},
{
"content": "Check GPU availability and CUDA environment",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)",
"priority": "high",
"status": "comp
... (443 more chars)content: 'Read reference.py, benchmark.py, check.py, problem.yaml to understand the task'
status: 'completed'
priority: 'high'
content: 'Read CLAUDE.md and any other project context files'
status: 'completed'
priority: 'high'
content: 'Check GPU availability and CUDA environment'
status: 'completed'
priority: 'high'
content: 'Implement initial Triton kernel for 3D patch embedding (fused im2col + GEMM)'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Optimize: direct output layout, weight pre-transpose, BLOCK_K=padded_kW*UNROLL'
status: 'completed'
priority: 'high'
content: 'Final verification with check.py and benchmark.py'
status: 'completed'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Fused im2col + GEMM patch-embedding kernel (Triton).
2:
3: Key optimization: BLOCK_K = PADDED_KW * UNROLL so that:
4: - Each k-tile of PADDED_KW elements maps to one (c, dt, dh) row
5: - x reads within each k-tile are contiguous along dw (coalesced)
6: - UNROLL reduces loop iterations while maintaining coalescing
7:
8: Weight is pre-transposed to (K, N) for contiguous GEMM loads.
9: Output is written directly in (B, embed_dim, To, Ho, Wo) layout.
10: """
11: import torch
12: import torch.nn as nn
13: import triton
14: import triton.language as tl
15:
16: OP_TYPE = "patch_embed"
17: SUPPORTED_PRECISIONS = ["bf16"]
18: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
19:
20:
21: def _next_pow2(n: int) -> int:
22: p = 1
23: while p < n:
24: p <<= 1
25: return p
26:
27:
28: @triton.autotune(
29: configs=[
30: triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=3),
31: triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=3),
32: triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=3),
33: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
34: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=4),
35: triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
36: tri
... (5677 more chars){
"command": "uv run python check.py 2>&1 && echo \"---\" && uv run python benchmark.py 2>&1",
"description": "Final correctness and benchmark run",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
---
shape=0 variant=eager tflops=3.615 gbps=20.017 ms=0.213
shape=0 variant=compiled tflops=3.065 gbps=16.974 ms=0.251
shape=0 variant=sota tflops=3.649 gbps=20.205 ms=0.211
shape=0 variant=solution tflops=16.401 gbps=90.825 ms=0.047
shape=0 solution_peak_fraction=0.0820
shape=1 variant=eager tflops=13.673 gbps=35.661 ms=0.225
shape=1 variant=compiled tflops=11.610 gbps=30.282 ms=0.266
shape=1 variant=sota tflops=13.921 gbps=36.307 ms=0.221
shape=1 variant=solution tflops=54.290 gbps=141.597 ms=0.057
shape=1 solution_peak_fraction=0.2715
shape=2 variant=eager tflops=26.133 gbps=53.981 ms=0.265
shape=2 variant=compiled tflops=22.323 gbps=46.112 ms=0.311
shape=2 variant=sota tflops=26.047 gbps=53.802 ms=0.266
shape=2 variant=solution tflops=106.464 gbps=219.914 ms=0.065
shape=2 solution_peak_fraction=0.5323
shape=3 variant=eager tflops=6.353 gbps=24.646 ms=0.146
shape=3 variant=compiled tflops=4.863 gbps=18.867 ms=0.190
shape=3 variant=sota tflops=6.451 gbps=25.029 ms=0.143
shape=3 variant=solution tflops=19.898 gbps=77.197 ms=0.046
shape=3 solution_peak_fraction=0.0995
peak_fraction: 0.1853
RESULT: OK