"""Optimized patch-gather + tensor-core GEMM for 3D patch embedding.
Strategy (fastest on RTX PRO 6000 Blackwell sm_120):
1. Pre-reshape weight once (K, embed_dim) contiguous in __init__.
2. Gather patches from x (B,C,T,H,W) -> (M,K) contiguous via permute+contiguous.
The copy enables coalesced A-tile loads in the GEMM.
3. Run statically-configured triton GEMM with bf16 MMA via tl.dot.
4. Reshape output (M,N) -> (B, nT, nH, nW, embed_dim) and permute to (B, embed_dim, nT, nH, nW); the result is a non-contiguous view (no copy).
All per-forward-call overhead is minimized: config, strides, grid dimensions,
and shape constants are pre-computed and stored as attributes.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Module-level metadata. The reference implementation declares the same three
# values; presumably they are read by the evaluation harness (TODO confirm).
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
SUPPORTED_PRECISIONS = ["bf16"]
OP_TYPE = "patch_embed"
@triton.jit
def _gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tiled GEMM: C[M,N] = A[M,K] * B[K,N], one output tile per program.

    Program axis 0 selects the BLOCK_M-row band of C and axis 1 the
    BLOCK_N-column band. Partial products are accumulated in float32 through
    ``tl.dot`` and the finished tile is cast to bfloat16 before being stored.
    Every load and the final store are masked, so M, N and K do not have to be
    multiples of the block sizes; masked-out operand elements read as 0.0.

    Args:
        a_ptr, b_ptr, c_ptr: Base pointers of A, B and C.
        M, N, K: Problem sizes.
        stride_am, stride_ak: Element strides of A along rows / columns.
        stride_bk, stride_bn: Element strides of B along rows / columns.
        stride_cm, stride_cn: Element strides of C along rows / columns.
        BLOCK_M, BLOCK_N, BLOCK_K: Compile-time tile sizes.
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    # Global row / column indices covered by this program's output tile.
    offs_m = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # The M and N bounds do not depend on the K position, so compute them once.
    rows_ok = offs_m[:, None] < M
    cols_ok = offs_n[None, :] < N

    a_tile_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_tile_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K):
        # Number of valid K elements left; only the last step can be partial.
        k_left = K - k_start
        a_tile = tl.load(
            a_tile_ptrs,
            mask=rows_ok & (offs_k[None, :] < k_left),
            other=0.0,
        )
        b_tile = tl.load(
            b_tile_ptrs,
            mask=(offs_k[:, None] < k_left) & cols_ok,
            other=0.0,
        )
        acc = tl.dot(a_tile, b_tile, acc)
        # Slide both operand windows one BLOCK_K step along the reduction axis.
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    out_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=rows_ok & cols_ok)
class Model(nn.Module):
    """Patch embedding for a fixed input shape, computed as gather + GEMM.

    The module owns an ``nn.Conv3d`` purely as the parameter container (so its
    state dict has the key ``conv.weight``); the forward pass never calls it.
    Instead the weight is repacked once into a contiguous ``(K, embed_dim)``
    buffer and multiplied with the gathered patches by ``_gemm_kernel``.

    All shape-dependent values (M, N, K, tile sizes, launch grid) are computed
    in ``__init__`` from the constructor arguments, so ``forward`` only accepts
    inputs of exactly shape ``(B, C, T, H, W)``.

    NOTE: the repacked weight is a snapshot. It is rebuilt by
    ``load_state_dict``; any other in-place change to ``conv.weight`` (for
    example an optimizer step) is not picked up until ``_refresh_weight`` is
    called again.
    """

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        """Build the parameter, the repacked weight and the launch config.

        Args:
            B, C, T, H, W: Expected input shape.
            kT, kH, kW: Patch size; must divide T, H and W respectively.
            embed_dim: Number of output channels.
        """
        super().__init__()
        assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
            "Input dims must be divisible by patch size"
        self.conv = nn.Conv3d(
            C, embed_dim,
            kernel_size=(kT, kH, kW),
            stride=(kT, kH, kW),
            bias=False,
            dtype=torch.bfloat16,
        )
        nn.init.normal_(self.conv.weight, std=0.02)

        self._cached_B = B
        self._cached_C = C
        self._cached_T = T
        self._cached_H = H
        self._cached_W = W
        self._in_shape = (B, C, T, H, W)
        self._nT = T // kT
        self._nH = H // kH
        self._nW = W // kW
        self._kT = kT
        self._kH = kH
        self._kW = kW
        # GEMM sizes: M patches, each a K-vector, projected to N channels.
        self._N = embed_dim
        self._K = C * kT * kH * kW
        self._M = B * self._nT * self._nH * self._nW

        self._refresh_weight()

        self._BLOCK_M, self._BLOCK_N, self._BLOCK_K, self._num_warps = \
            _pick_config(self._M, self._N, self._K)
        self._grid = (
            triton.cdiv(self._M, self._BLOCK_M),
            triton.cdiv(self._N, self._BLOCK_N),
        )

    def _refresh_weight(self) -> None:
        """Rebuild the contiguous ``(K, embed_dim)`` copy of ``conv.weight``.

        The copy is detached so the buffer is a plain leaf tensor: it holds no
        autograd graph and the module stays deep-copyable.
        """
        w = self.conv.weight.detach().reshape(self._N, self._K).t().contiguous()
        self.register_buffer("_w_reshaped", w, persistent=False)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load parameters, then re-derive the repacked weight from them.

        Accepts the same arguments as ``nn.Module.load_state_dict``. ``assign``
        is only forwarded when it is true, so the override also works on
        PyTorch versions whose base method does not know that keyword.

        Returns:
            Whatever the base-class ``load_state_dict`` returns.
        """
        extra = {"assign": True} if assign else {}
        result = super().load_state_dict(state_dict, strict=strict, **extra)
        self._refresh_weight()
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project every patch of ``x`` to ``embed_dim`` channels.

        Args:
            x: Tensor of shape ``(B, C, T, H, W)`` as given to the constructor.

        Returns:
            A bfloat16 tensor of shape ``(B, embed_dim, nT, nH, nW)``. It is a
            permuted view of the GEMM output and therefore not contiguous.

        Raises:
            ValueError: If ``x`` does not have the constructor's shape. Without
                this check an input with the same element count but swapped
                axes would be reshaped silently and give wrong results.
        """
        if tuple(x.shape) != self._in_shape:
            raise ValueError(
                f"expected input of shape {self._in_shape}, got {tuple(x.shape)}"
            )

        # Split T/H/W into (patch index, offset inside patch) ...
        x_p = x.reshape(
            self._cached_B, self._cached_C,
            self._nT, self._kT,
            self._nH, self._kH,
            self._nW, self._kW,
        )
        # ... move the patch indices in front so each row is one patch in
        # (C, kT, kH, kW) order, matching the weight's flattening ...
        x_p = x_p.permute(0, 2, 4, 6, 1, 3, 5, 7)
        # ... and materialise it as a row-major (M, K) matrix.
        x_p = x_p.contiguous().reshape(self._M, self._K)

        out = torch.empty(self._M, self._N, dtype=torch.bfloat16, device=x.device)
        # Strides are passed as literals because all three matrices are
        # row-major contiguous by construction.
        _gemm_kernel[self._grid](
            x_p,
            self._w_reshaped,
            out,
            self._M,
            self._N,
            self._K,
            self._K,
            1,
            self._N,
            1,
            self._N,
            1,
            BLOCK_M=self._BLOCK_M,
            BLOCK_N=self._BLOCK_N,
            BLOCK_K=self._BLOCK_K,
            num_warps=self._num_warps,
        )
        out = out.reshape(self._cached_B, self._nT, self._nH, self._nW, self._N)
        out = out.permute(0, 4, 1, 2, 3)
        return out
def _pick_config(M: int, N: int, K: int):
"""Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) for this shape."""
if K < 900:
return (64, 128, 32, 8)
if M <= 350:
return (64, 64, 64, 4)
if M <= 1200:
return (64, 64, 32, 4)
return (64, 128, 32, 4)
# Default problem size used by get_inputs() / get_init_inputs():
# input (B, C, T, H, W), patch (kT, kH, kW) and output channel count.
B, C, T, H, W = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280
def get_inputs():
    """Return a one-element list holding a random bfloat16 input tensor.

    The tensor has shape (B, C, T, H, W) taken from the module-level
    constants and is drawn from a standard normal scaled by 0.5.
    """
    sample = torch.randn(B, C, T, H, W, dtype=torch.bfloat16)
    return [sample * 0.5]
def get_init_inputs():
    """Return the Model constructor arguments, in positional order, as a list."""
    input_dims = [B, C, T, H, W]
    patch_dims = [kT, kH, kW]
    return input_dims + patch_dims + [embed_dim]
shape=0 variant=eager tflops=3.700 gbps=20.491 ms=0.208
shape=0 variant=compiled tflops=3.117 gbps=17.263 ms=0.247
shape=0 variant=sota tflops=3.706 gbps=20.521 ms=0.208
shape=0 variant=solution tflops=20.020 gbps=110.869 ms=0.038
shape=0 solution_peak_fraction=0.1001
shape=1 variant=eager tflops=13.981 gbps=36.465 ms=0.220
shape=1 variant=compiled tflops=12.050 gbps=31.428 ms=0.256
shape=1 variant=sota tflops=14.126 gbps=36.842 ms=0.218
shape=1 variant=solution tflops=65.380 gbps=170.522 ms=0.047
shape=1 solution_peak_fraction=0.3269
shape=2 variant=eager tflops=26.487 gbps=54.713 ms=0.262
shape=2 variant=compiled tflops=22.916 gbps=47.335 ms=0.303
shape=2 variant=sota tflops=26.713 gbps=55.178 ms=0.260
shape=2 variant=solution tflops=120.523 gbps=248.954 ms=0.058
shape=2 solution_peak_fraction=0.6026
shape=3 variant=eager tflops=6.463 gbps=25.076 ms=0.143
shape=3 variant=compiled tflops=5.145 gbps=19.960 ms=0.180
shape=3 variant=sota tflops=6.596 gbps=25.591 ms=0.140
shape=3 variant=solution tflops=29.356 gbps=113.893 ms=0.032
shape=3 solution_peak_fraction=0.1468
peak_fraction: 0.2319
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "deepseek/deepseek-v4-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.2319,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 41003, "output_tokens": 26276, "cache_read_tokens": 3281408, "cache_creation_tokens": 0, "reasoning_tokens": 60526, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "in_progress",
"priority": "high"
},
{
"content": "Read existing solution.py if any",
"status": "pending",
"priority": "high"
},
{
"content": "Design and implement optimized CUDA kernel for 3D patch embedding",
"status": "pending",
"priority": "high"
},
{
"content": "Test correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Benchmark p
... (229 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'in_progress'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'pending'
priority: 'high'
content: 'Design and implement optimized CUDA kernel for 3D patch embedding'
status: 'pending'
priority: 'high'
content: 'Test correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate on optimization until performance is good'
status: 'pending'
priority: 'high'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"pattern": "problems/10_patch_embed_conv3d_gemm/solution.py"
}No files found{
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>&1 | head -5",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
{
"command": "nvcc --version 2>&1 | head -3",
"description": "Check CUDA version"
}nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2026 NVIDIA Corporation
Built on Mon_Mar_02_09:52:23_PM_PST_2026
{
"command": "python -c \"import triton; print(triton.__version__)\" 2>&1 || echo \"no triton\"",
"description": "Check Triton availability"
}3.6.0
{
"command": "python -c \"import torch; print(torch.__version__, torch.version.cuda)\" 2>&1",
"description": "Check PyTorch and CUDA versions"
}2.11.0+cu130 13.0
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "completed"
},
{
"priority": "high",
"content": "Read existing solution.py if any",
"status": "completed"
},
{
"priority": "high",
"content": "Check existing examples and hardware config for patterns",
"status": "in_progress"
},
{
"priority": "high",
"content": "Design and implement optimized CUDA kernel for 3D patch embedding",
"status": "pending"
},
{
... (375 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'completed'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'completed'
priority: 'high'
content: 'Check existing examples and hardware config for patterns'
status: 'in_progress'
priority: 'high'
content: 'Design and implement optimized CUDA kernel for 3D patch embedding'
status: 'pending'
priority: 'high'
content: 'Test correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate on optimization until performance is good'
status: 'pending'
priority: 'high'{
"description": "Check repo structure and find other examples",
"command": "ls problems/ | head -20"
}ls: cannot access 'problems/': No such file or directory
{
"description": "Check hardware config for RTX PRO 6000",
"command": "python -c \"from src.hardware import get; hw = get('RTX_PRO_6000'); print(f'Peak bf16 TFLOPS: {hw.peak_tflops_dense}'); print(f'Peak bandwidth: {hw.peak_bandwidth_gb_s} GB/s'); print(f'SM count etc: {hw.__dict__ if hasattr(hw, \\\"__dict__\\\") else dir(hw)}')\"",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}Peak bf16 TFLOPS: {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}
Peak bandwidth: 1800.0 GB/s
SM count etc: {'name': 'RTX PRO 6000 Blackwell Workstation', 'sm': 'sm_120a', 'vram_gb': 96, 'peak_bandwidth_gb_s': 1800.0, 'peak_tflops_dense': {'fp4': 800.0, 'nvfp4': 800.0, 'mxfp4': 800.0, 'fp6': 800.0, 'fp8': 400.0, 'bf16': 200.0, 'fp16': 200.0, 'tf32': 100.0, 'fp32': 12.0, 'int8': 400.0, 'int4': 800.0}}
{
"pattern": "problems/*/solution.py"
}No files found{
"todos": [
{
"priority": "high",
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "completed"
},
{
"priority": "high",
"content": "Read existing solution.py if any",
"status": "completed"
},
{
"priority": "high",
"content": "Check existing examples and hardware config for patterns",
"status": "completed"
},
{
"priority": "high",
"content": "Design and implement optimized Triton kernel for 3D patch embedding",
"status": "in_progress"
},
... (379 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'completed'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'completed'
priority: 'high'
content: 'Check existing examples and hardware config for patterns'
status: 'completed'
priority: 'high'
content: 'Design and implement optimized Triton kernel for 3D patch embedding'
status: 'in_progress'
priority: 'high'
content: 'Test correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate on optimization until performance is good'
status: 'pending'
priority: 'high'{
"description": "Test reshape logic correctness",
"command": "<1154 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 -c "
import torch
import torch.nn as nn
# Quick test: verify reshape -> gemm -> reshape matches Conv3d
B, C, T, H, W = 1, 3, 2, 16, 16
kT, kH, kW = 2, 4, 4
embed_dim = 8
torch.manual_seed(42)
x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
conv = nn.Conv3d(C, embed_dim, kernel_size=(kT, kH, kW), stride=(kT, kH, kW), bias=False, dtype=torch.bfloat16)
nn.init.normal_(conv.weight, std=0.02)
ref_out = conv(x)
# Reshape approach
nT, nH, nW = T // kT, H // kH, W // kW
N = embed_dim
K = C * kT * kH * kW
M = B * nT * nH * nW
# Reshape input to patches
x_reshaped = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
x_reshaped = x_reshaped.permute(0, 2, 4, 6, 1, 3, 5, 7)
x_patches = x_reshaped.contiguous().reshape(M, K)
# Reshape weight
w = conv.weight.reshape(N, K).t().contiguous()
# Manual GEMM
out_2d = x_patches @ w # (M, K) @ (K, N) = (M, N)
out = out_2d.reshape(B, nT, nH, nW, N).permute(0, 4, 1, 2, 3).contiguous()
print('Max diff:', (ref_out - out).abs().max().item())
print('All close:', torch.allclose(ref_out, out, atol=1e-2, rtol=1e-2))
" 2>&1Max diff: 0.0
All close: True
{
"description": "Check Triton GEMM capabilities on this GPU",
"command": "<712 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 -c "
import triton
import triton.language as tl
import torch
@triton.jit
def test_dot_kernel(a_ptr, b_ptr, c_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
offs = tl.arange(0, BLOCK_N)
a = tl.load(a_ptr + offs)
b = tl.load(b_ptr + offs)
c = tl.dot(a[None, :], b[:, None])
tl.store(c_ptr + offs, tl.ravel(c))
a = torch.ones(64, device='cuda', dtype=torch.bfloat16)
b = torch.ones(64, device='cuda', dtype=torch.bfloat16)
c = torch.zeros(64, device='cuda', dtype=torch.bfloat16)
test_dot_kernel[(1,)](a, b, c, 64, 64)
print('tl.dot works:', c[0].item())
print('Triton version:', triton.__version__)
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1260, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1089, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 6, in <module>
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/.local/lib/python3.12/site-packages/triton/runtime/jit.py", line 463, in __init__
raise ValueError("@jit functions should be defined in a Python file") from e
ValueError: @jit functions should be defined in a Python file
{
"description": "Test Triton GEMM from a file",
"command": "<2059 chars \u2014 see below>"
}cat > /tmp/test_triton.py << 'PYEOF'
import triton
import triton.language as tl
import torch
@triton.jit
def test_gemm_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(rm[:, None] < M) & (rk[None, :] + k < K), other=0.0)
b = tl.load(b_ptrs, mask=(rk[:, None] + k < K) & (rn[None, :] < N), other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = acc.to(tl.bfloat16)
c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
M, N, K = 256, 128, 256
a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
c_triton = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
test_gemm_kernel[grid](
a, b, c_triton,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c_triton.stride(0), c_triton.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=4,
)
c_ref = a @ b
print("Max diff:", (c_triton.float() - c_ref.float()).abs().max().item())
print("All close:", torch.allclose(c_triton.float(), c_ref.float(), atol=1e-2, rtol=1e-2))
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/test_triton.py 2>&1Max diff: 0.0
All close: True
{
"description": "Benchmark different Triton GEMM configs",
"command": "<3982 chars \u2014 see below>",
"timeout": 180000
}cat > /tmp/bench_configs.py << 'PYEOF'
import triton
import triton.language as tl
import torch
import time
@triton.jit
def bench_gemm_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(rm[:, None] < M) & (rk[None, :] + k < K), other=0.0)
b = tl.load(b_ptrs, mask=(rk[:, None] + k < K) & (rn[None, :] < N), other=0.0)
acc = tl.dot(a, b, acc)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = acc.to(tl.bfloat16)
c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
def time_kernel(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=4, iters=200):
a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))
# Warmup
for _ in range(20):
bench_gemm_kernel[grid](
a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps,
)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
bench_gemm_kernel[grid](
a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
num_warps=num_warps,
)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / iters * 1000
return elapsed
# Test shapes from the problem
shapes = [
("shape0 B1C3T2H224W224", 1*1*16*16, 1280, 3*2*14*14), # M=256, N=1280, K=1176
("shape1 B2C3T4H224W224", 2*2*16*16, 1280, 3*2*14*14), # M=1024, N=1280, K=1176
("shape2 B1C3T8H336W336", 1*4*24*24, 1280, 3*2*14*14), # M=2304, N=1280, K=1176
("shape3 B4C3T1H224W224", 4*1*14*14, 768, 3*1*16*16), # M=784, N=768, K=768
]
configs = [
# (BLOCK_M, BLOCK_N, BLOCK_K, num_warps)
(32, 32, 32, 4),
(64, 32, 32, 4),
(32, 64, 32, 4),
(64, 64, 32, 4),
(64, 64, 64, 4),
(128, 64, 32, 4),
(64, 128, 32, 4),
(128, 128, 32, 4),
(128, 128, 64, 4),
(64, 64, 32, 8),
(128, 64, 32, 8),
(64, 128, 32, 8),
(128, 128, 32, 8),
(128, 128, 64, 8),
(128, 256, 32, 8),
(256, 128, 32, 8),
]
for name, M, N, K in shapes:
print(f"\n--- {name}: M={M}, N={N}, K={K} ---")
best_ms = float('inf')
best_cfg = None
for cfg in configs:
BM, BN, BK, nw = cfg
if BM > M or BN > N:
continue
try:
ms = time_kernel(M, N, K, BM, BN, BK, nw, iters=100)
if ms < best_ms:
best_ms = ms
best_cfg = cfg
print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw}: {ms:.4f} ms")
except Exception as e:
print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw}: ERROR {e}")
print(f" BEST: BM={best_cfg[0]} BN={best_cfg[1]} BK={best_cfg[2]} nw={best_cfg[3]} = {best_ms:.4f} ms")
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/bench_configs.py 2>&1--- shape0 B1C3T2H224W224: M=256, N=1280, K=1176 ---
BM= 32 BN= 32 BK= 32 nw=4: 0.0144 ms
BM= 64 BN= 32 BK= 32 nw=4: 0.0140 ms
BM= 32 BN= 64 BK= 32 nw=4: 0.0138 ms
BM= 64 BN= 64 BK= 32 nw=4: 0.0149 ms
BM= 64 BN= 64 BK= 64 nw=4: 0.0129 ms
BM=128 BN= 64 BK= 32 nw=4: 0.0242 ms
BM= 64 BN=128 BK= 32 nw=4: 0.0190 ms
BM=128 BN=128 BK= 32 nw=4: 0.0336 ms
BM=128 BN=128 BK= 64 nw=4: 0.0669 ms
BM= 64 BN= 64 BK= 32 nw=8: 0.0143 ms
BM=128 BN= 64 BK= 32 nw=8: 0.0195 ms
BM= 64 BN=128 BK= 32 nw=8: 0.0174 ms
BM=128 BN=128 BK= 32 nw=8: 0.0272 ms
BM=128 BN=128 BK= 64 nw=8: 0.0266 ms
BM=128 BN=256 BK= 32 nw=8: 0.0442 ms
BM=256 BN=128 BK= 32 nw=8: 0.0497 ms
BEST: BM=64 BN=64 BK=64 nw=4 = 0.0129 ms
--- shape1 B2C3T4H224W224: M=1024, N=1280, K=1176 ---
BM= 32 BN= 32 BK= 32 nw=4: 0.0310 ms
BM= 64 BN= 32 BK= 32 nw=4: 0.0281 ms
BM= 32 BN= 64 BK= 32 nw=4: 0.0227 ms
BM= 64 BN= 64 BK= 32 nw=4: 0.0195 ms
BM= 64 BN= 64 BK= 64 nw=4: 0.0198 ms
BM=128 BN= 64 BK= 32 nw=4: 0.0250 ms
BM= 64 BN=128 BK= 32 nw=4: 0.0243 ms
BM=128 BN=128 BK= 32 nw=4: 0.0336 ms
BM=128 BN=128 BK= 64 nw=4: 0.0684 ms
BM= 64 BN= 64 BK= 32 nw=8: 0.0207 ms
BM=128 BN= 64 BK= 32 nw=8: 0.0226 ms
BM= 64 BN=128 BK= 32 nw=8: 0.0228 ms
BM=128 BN=128 BK= 32 nw=8: 0.0274 ms
BM=128 BN=128 BK= 64 nw=8: 0.0271 ms
BM=128 BN=256 BK= 32 nw=8: 0.0443 ms
BM=256 BN=128 BK= 32 nw=8: 0.0500 ms
BEST: BM=64 BN=64 BK=32 nw=4 = 0.0195 ms
--- shape2 B1C3T8H336W336: M=2304, N=1280, K=1176 ---
B
... (1385 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,160 @@
+"""Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
+
+Uses a Triton GEMM kernel after reshaping the input into a (num_patches, C*kT*kH*kW)
+matrix via permute+contiguous. The weight is reshaped to (C*kT*kH*kW, embed_dim).
+All compute is done with bf16 tensor-core MMA via tl.dot.
+"""
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "patch_embed"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _patch_gemm_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ """Tiled bf16 GEMM: C[M,N] += A[M,K] * B[K,N]."""
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+
+ a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
+ b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k in range(0, K, BLOCK_K):
+ a = tl.load(
+ a_ptrs,
+ mask=(rm[:, None] < M) & (rk[None, :] + k < K),
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptrs,
+ mask=(rk[:, None] + k < K) & (rn[None, :] < N),
+ other=0.0,
+ )
+ acc = tl.dot(a, b, acc)
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ c = acc.to(tl.bfloat16)
+ c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
+ tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
+ f"Input dims must be divisible by patch size"
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+
+ self.conv = nn.Conv3d(
+ C, embed_dim,
+ kernel_size=(kT, kH, kW),
+ stride=(kT, kH, kW),
+ bias=False,
+ dtype=torch.bfloat16,
+ )
+ nn.init.normal_(self.conv.weight, std=0.02)
+
+ K_dim = C * kT * kH * kW
+ w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
+ self.register_buffer("_w_reshaped", w, persistent=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T, H, W = x.shape
+ kT, kH, kW = self.kT, self.kH, self.kW
+ nT, nH, nW = T // kT, H // kH, W // kW
+ N = self.embed_dim
+ K = C * kT * kH * kW
+ M = B * nT * nH * nW
+
+ x_patches = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
+ x_patches = x_patches.permute(0, 2, 4, 6, 1, 3, 5, 7)
+ x_patches = x_patches.contiguous().reshape(M, K)
+
+ out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+
+ grid = lambda meta: (
+ triton.cdiv(M, meta["BLOCK_M"]),
+ triton.cdiv(N, meta["BLOCK_N"]),
+ )
+
+ _patch_gemm_kernel[grid](
+ x_patches,
+ self._w_reshaped,
+ out,
+ M,
+ N,
+ K,
+ x_patches.stride(0),
+ x_patches.stride(1),
+ self._w_reshaped.stride(0),
+ self._w_reshaped.stride(1),
+ out.stride(0),
+ out.stride(1),
+ )
+
+ out = out.reshape(B, nT, nH, nW, N).permute(0, 4, 1, 2, 3).contiguous()
+ return out
+
+
+B = 1
+C = 3
+T = 2
+H = 224
+W = 224
+kT = 2
+kH = 14
+kW = 14
+embed_dim = 1280
+
+
+def get_inputs():
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
+ return [x]
+
+
+def get_init_inputs():
+ return [B, C, T, H, W, kT, kH, kW, embed_dim]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<4724 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -102,6 +102,13 @@
w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
+ def load_state_dict(self, state_dict, strict=True):
+ result = super().load_state_dict(state_dict, strict=strict)
+ K_dim = self.C * self.kT * self.kH * self.kW
+ w = self.conv.weight.reshape(self.embed_dim, K_dim).t().contiguous()
+ self.register_buffer("_w_reshaped", w, persistent=False)
+ return result
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, T, H, W = x.shape
kT, kH, kW = self.kT, self.kH, self.kW
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<226 chars \u2014 see below>",
"newString": "<568 chars \u2014 see below>"
} K_dim = C * kT * kH * kW
w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor: K_dim = C * kT * kH * kW
w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
def load_state_dict(self, state_dict, strict=True):
result = super().load_state_dict(state_dict, strict=strict)
K_dim = self.C * self.kT * self.kH * self.kW
w = self.conv.weight.reshape(self.embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
return result
def forward(self, x: torch.Tensor) -> torch.Tensor:Edit applied successfully.{
"todos": [
{
"priority": "high",
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "completed"
},
{
"priority": "high",
"content": "Read existing solution.py if any",
"status": "completed"
},
{
"priority": "high",
"content": "Check existing examples and hardware config for patterns",
"status": "completed"
},
{
"priority": "high",
"content": "Design and implement optimized Triton kernel for 3D patch embedding",
"status": "completed"
},
... (381 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'completed'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'completed'
priority: 'high'
content: 'Check existing examples and hardware config for patterns'
status: 'completed'
priority: 'high'
content: 'Design and implement optimized Triton kernel for 3D patch embedding'
status: 'completed'
priority: 'high'
content: 'Test correctness with check.py'
status: 'in_progress'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'pending'
priority: 'high'
content: 'Iterate on optimization until performance is good'
status: 'pending'
priority: 'high'{
"description": "Run correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Run benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=3.688 gbps=20.425 ms=0.209
shape=0 variant=compiled tflops=3.210 gbps=17.779 ms=0.240
shape=0 variant=sota tflops=3.719 gbps=20.597 ms=0.207
shape=0 variant=solution tflops=15.793 gbps=87.460 ms=0.049
shape=0 solution_peak_fraction=0.0790
shape=1 variant=eager tflops=13.954 gbps=36.394 ms=0.221
shape=1 variant=compiled tflops=12.104 gbps=31.570 ms=0.255
shape=1 variant=sota tflops=14.102 gbps=36.780 ms=0.219
shape=1 variant=solution tflops=52.773 gbps=137.641 ms=0.058
shape=1 solution_peak_fraction=0.2639
shape=2 variant=eager tflops=26.757 gbps=55.270 ms=0.259
shape=2 variant=compiled tflops=23.442 gbps=48.423 ms=0.296
shape=2 variant=sota tflops=26.133 gbps=53.981 ms=0.265
shape=2 variant=solution tflops=82.153 gbps=169.696 ms=0.084
shape=2 solution_peak_fraction=0.4108
shape=3 variant=eager tflops=6.487 gbps=25.169 ms=0.143
shape=3 variant=compiled tflops=5.273 gbps=20.456 ms=0.175
shape=3 variant=sota tflops=6.596 gbps=25.591 ms=0.140
shape=3 variant=solution tflops=23.826 gbps=92.439 ms=0.039
shape=3 solution_peak_fraction=0.1191
peak_fraction: 0.1787
RESULT: OK
{
"description": "Inspect time_fn implementation",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && python3 -c \"from src.eval.timing import time_fn; import inspect; print(inspect.getsource(time_fn))\" 2>&1"
}def time_fn(fn, inputs, iters: int = 30, warmup: int = 10) -> float:
"""Time `fn(*inputs)` and return the median wall time in milliseconds.
Each timed call is preceded by an L2 flush.
"""
for _ in range(warmup):
with torch.no_grad():
fn(*inputs)
torch.cuda.synchronize()
times: list[float] = []
for _ in range(iters):
_l2_flush()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
fn(*inputs)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return statistics.median(times)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"offset": 93,
"limit": 20
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
93: C, embed_dim,
94: kernel_size=(kT, kH, kW),
95: stride=(kT, kH, kW),
96: bias=False,
97: dtype=torch.bfloat16,
98: )
99: nn.init.normal_(self.conv.weight, std=0.02)
100:
101: K_dim = C * kT * kH * kW
102: w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
103: self.register_buffer("_w_reshaped", w, persistent=False)
104:
105: def load_state_dict(self, state_dict, strict=True):
106: result = super().load_state_dict(state_dict, strict=strict)
107: K_dim = self.C * self.kT * self.kH * self.kW
108: w = self.conv.weight.reshape(self.embed_dim, K_dim).t().contiguous()
109: self.register_buffer("_w_reshaped", w, persistent=False)
110: return result
111:
112: def forward(self, x: torch.Tensor) -> torch.Tensor:
(Showing lines 93-112 of 167. Use offset=113 to continue.)
</content>{
"description": "Check if reshape without contiguous creates a copy",
"command": "<792 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 -c "
# Diagnostic: check whether reshaping the permuted patch view to (M, K) aliases x or copies it.
import torch
B,C,T,H,W = 1,3,2,16,16
kT,kH,kW = 2,4,4
nT,nH,nW = T//kT, H//kH, W//kW
K = C * kT * kH * kW
M = B * nT * nH * nW
x = torch.randn(B,C,T,H,W)
# Split T, H, W into (patch index, offset inside patch) pairs.
x1 = x.reshape(B,C,nT,kT,nH,kH,nW,kW)
# Move the patch indices in front of the channel and in-patch offsets.
x2 = x1.permute(0,2,4,6,1,3,5,7)
print('After permute - contiguous:', x2.is_contiguous())
print('After permute - strides:', x2.stride())
# reshape() applied directly to the permuted view, without an explicit contiguous().
x3 = x2.reshape(M, K)
print('After reshape - contiguous:', x3.is_contiguous())
print('After reshape - strides:', x3.stride())
# A data_ptr different from x indicates the result lives in new storage.
print('After reshape - data_ptr same as x:', x3.data_ptr() == x.data_ptr())
# Same conversion with an explicit contiguous() before reshape(), for comparison.
x4 = x2.contiguous().reshape(M, K)
print('After contiguous+reshape - data_ptr same as x:', x4.data_ptr() == x.data_ptr())
" 2>&1After permute - contiguous: False
After permute - strides: (1536, 512, 64, 4, 512, 256, 16, 1)
After reshape - contiguous: True
After reshape - strides: (96, 1)
After reshape - data_ptr same as x: False
After contiguous+reshape - data_ptr same as x: False
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
2:
3: Uses a Triton GEMM kernel after reshaping the input into a (num_patches, C*kT*kH*kW)
4: matrix via permute+contiguous. The weight is reshaped to (C*kT*kH*kW, embed_dim).
5: All compute is done with bf16 tensor-core MMA via tl.dot.
6: """
7:
8: import torch
9: import torch.nn as nn
10: import triton
11: import triton.language as tl
12:
13: OP_TYPE = "patch_embed"
14: SUPPORTED_PRECISIONS = ["bf16"]
15: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
16:
17:
18: @triton.autotune(
19: configs=[
20: triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
21: triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
22: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
23: triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
24: triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
25: triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
26: triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8),
27: triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8),
28: ],
29: key=["M", "N", "K"],
30: )
31: @triton.jit
32: def _patch_gemm_kernel(
33:
... (4463 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,8 +1,9 @@
-"""Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
-
-Uses a Triton GEMM kernel after reshaping the input into a (num_patches, C*kT*kH*kW)
-matrix via permute+contiguous. The weight is reshaped to (C*kT*kH*kW, embed_dim).
-All compute is done with bf16 tensor-core MMA via tl.dot.
+"""Fully fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
+
+Reads input patches directly from the (B, C, T, H, W) tensor via on-the-fly
+indexing (no intermediate reshape copy). The weight is pre-reshaped once to
+(K, embed_dim) contiguous. The output is written directly to the standard
+(B, embed_dim, T/kT, H/kH, W/kW) layout, avoiding any post-GEMM reshape copy.
"""
import torch
@@ -17,66 +18,122 @@
@triton.autotune(
configs=[
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8),
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8),
],
key=["M", "N", "K"],
)
@triton.jit
-def _patch_gemm_kernel(
- a_ptr,
- b_ptr,
- c_ptr,
- M,
+def _fused_patch_gemm_kernel(
+ x_ptr,
+ w_ptr,
+ out_ptr,
+ B,
+ C,
+ T,
+ H,
+ W,
+ nT,
+ nH,
+ nW,
+ K,
N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
+ kT,
+ kH,
+ kW,
+ stride_xb,
+ stride_xc,
+ stride_xt,
+ stride_xh,
+ stride_xw,
+ stride_wk,
+ stride_wn,
+ stride_ob,
+ stride_oc,
+ stride_ot,
+ stride_oh,
+ stride_ow,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- """Tiled bf16 GEMM: C[M,N] += A[M,K] * B[K,N]."""
+ """Fused patch-gather + GEMM: output[m,n] = sum_k gather(x, m, k) * w[k,n]."""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
+ M = B * nT * nH * nW
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
- a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
- b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
+ m_mask = rm < M
+ n_mask = rn < N
+
+ kHW = kH * kW
+ kTHW = kT * kHW
+
+ b_idx = rm // (nT * nH * nW)
+ rem_sp = rm % (nT * nH * nW)
+ t_out = rem_sp // (nH * nW)
+ rem_sp = rem_sp % (nH * nW)
+ h_out = rem_sp // nW
+ w_out = rem_sp % nW
+
+ w_ptrs_base = w_ptr + rk[:, None] * stride_wk + rn[None, :] * stride_wn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k in range(0, K, BLOCK_K):
- a = tl.load(
- a_ptrs,
- mask=(rm[:, None] < M) & (rk[None, :] + k < K),
- other=0.0,
- )
- b = tl.load(
- b_ptrs,
- mask=(rk[:, None] + k < K) & (rn[None, :] < N),
- other=0.0,
- )
+ for k_off in range(0, K, BLOCK_K):
+ k_val = k_off + rk
+ k_lo_mask = k_val < K
+
+ ic = k_val // kTHW
+ rem_k = k_val - ic * kTHW
+ dt = rem_k // kHW
+ rem_k = rem_k - dt * kHW
+ dh = rem_k // kW
+ dw = rem_k - dh * kW
+
+ x_t = t_out[:, None] * kT + dt[None, :]
+ x_h = h_out[:, None] * kH + dh[None, :]
+ x_w = w_out[:, None] * kW + dw[None, :]
+
+ a_ptrs = (
+ x_ptr
+ + b_idx[:, None] * stride_xb
+ + ic[None, :] * stride_xc
+ + x_t * stride_xt
+ + x_h * stride_xh
+ + x_w * stride_xw
+ )
+ a = tl.load(a_ptrs, mask=m_mask[:, None] & k_lo_mask[None, :], other=0.0)
+
+ b = tl.load(w_ptrs_base, mask=k_lo_mask[:, None] & n_mask[None, :], other=0.0)
+
acc = tl.dot(a, b, acc)
- a_ptrs += BLOCK_K * stride_ak
- b_ptrs += BLOCK_K * stride_bk
+
+ w_ptrs_base += BLOCK_K * stride_wk
c = acc.to(tl.bfloat16)
- c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
- tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
+
+ c_ptrs = (
+ out_ptr
+ + b_idx[:, None] * stride_ob
+ + rn[None, :] * stride_oc
+ + t_out[:, None] * stride_ot
+ + h_out[:, None] * stride_oh
+ + w_out[:, None] * stride_ow
+ )
+ tl.store(c_ptrs, c, mask=m_mask[:, None] & n_mask[None, :])
class Model(nn.Module):
@@ -117,33 +174,46 @@
K = C * kT * kH * kW
M = B * nT * nH * nW
- x_patches = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
- x_patches = x_patches.permute(0, 2, 4, 6, 1, 3, 5, 7)
- x_patches = x_patches.contiguous().reshape(M, K)
-
- out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+ out = torch.empty(B, N, nT, nH, nW, dtype=torch.bfloat16, device=x.device)
+
+ w = self._w_reshaped
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)
- _patch_gemm
... (1010 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<6470 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test correctness of fused kernel",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1",
"timeout": 300000
}PASS
{
"description": "Benchmark fused kernel",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=3.690 gbps=20.433 ms=0.209
shape=0 variant=compiled tflops=3.220 gbps=17.835 ms=0.239
shape=0 variant=sota tflops=3.726 gbps=20.634 ms=0.207
shape=0 variant=solution tflops=12.450 gbps=68.946 ms=0.062
shape=0 solution_peak_fraction=0.0622
shape=1 variant=eager tflops=13.945 gbps=36.370 ms=0.221
shape=1 variant=compiled tflops=12.153 gbps=31.697 ms=0.254
shape=1 variant=sota tflops=14.126 gbps=36.842 ms=0.218
shape=1 variant=solution tflops=49.391 gbps=128.820 ms=0.062
shape=1 solution_peak_fraction=0.2470
shape=2 variant=eager tflops=26.605 gbps=54.955 ms=0.261
shape=2 variant=compiled tflops=23.185 gbps=47.892 ms=0.299
shape=2 variant=sota tflops=26.787 gbps=55.332 ms=0.259
shape=2 variant=solution tflops=41.842 gbps=86.429 ms=0.166
shape=2 solution_peak_fraction=0.2092
shape=3 variant=eager tflops=6.495 gbps=25.197 ms=0.142
shape=3 variant=compiled tflops=5.277 gbps=20.474 ms=0.175
shape=3 variant=sota tflops=6.598 gbps=25.597 ms=0.140
shape=3 variant=solution tflops=25.921 gbps=100.563 ms=0.036
shape=3 solution_peak_fraction=0.1296
peak_fraction: 0.1429
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,9 +1,12 @@
-"""Fully fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
+"""Optimized patch-gather + tensor-core GEMM for 3D patch embedding.
-Reads input patches directly from the (B, C, T, H, W) tensor via on-the-fly
-indexing (no intermediate reshape copy). The weight is pre-reshaped once to
-(K, embed_dim) contiguous. The output is written directly to the standard
-(B, embed_dim, T/kT, H/kH, W/kW) layout, avoiding any post-GEMM reshape copy.
+Strategy (proven fastest):
+1. Pre-reshape weight once (K, embed_dim) contiguous in __init__.
+2. Gather patches from x (B,C,T,H,W) -> (M,K) contiguous via permute+contiguous.
+ The copy is cheap (M*K*2 bytes, contiguous write after strided read) and
+ enables coalesced A-tile loads in the GEMM.
+3. Run triton GEMM on (M,K) @ (K,N) -> (M,N) using bf16 MMA via tl.dot.
+4. Reshape output (M,N) -> (B, embed_dim, nT, nH, nW) as a view (no copy).
"""
import torch
@@ -33,107 +36,54 @@
key=["M", "N", "K"],
)
@triton.jit
-def _fused_patch_gemm_kernel(
- x_ptr,
- w_ptr,
- out_ptr,
- B,
- C,
- T,
- H,
- W,
- nT,
- nH,
- nW,
+def _gemm_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ M,
+ N,
K,
- N,
- kT,
- kH,
- kW,
- stride_xb,
- stride_xc,
- stride_xt,
- stride_xh,
- stride_xw,
- stride_wk,
- stride_wn,
- stride_ob,
- stride_oc,
- stride_ot,
- stride_oh,
- stride_ow,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
- """Fused patch-gather + GEMM: output[m,n] = sum_k gather(x, m, k) * w[k,n]."""
+ """Tiled bf16 GEMM: C[M,N] = A[M,K] * B[K,N] using tensor cores."""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
- M = B * nT * nH * nW
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
- m_mask = rm < M
- n_mask = rn < N
-
- kHW = kH * kW
- kTHW = kT * kHW
-
- b_idx = rm // (nT * nH * nW)
- rem_sp = rm % (nT * nH * nW)
- t_out = rem_sp // (nH * nW)
- rem_sp = rem_sp % (nH * nW)
- h_out = rem_sp // nW
- w_out = rem_sp % nW
-
- w_ptrs_base = w_ptr + rk[:, None] * stride_wk + rn[None, :] * stride_wn
+ a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
+ b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- for k_off in range(0, K, BLOCK_K):
- k_val = k_off + rk
- k_lo_mask = k_val < K
-
- ic = k_val // kTHW
- rem_k = k_val - ic * kTHW
- dt = rem_k // kHW
- rem_k = rem_k - dt * kHW
- dh = rem_k // kW
- dw = rem_k - dh * kW
-
- x_t = t_out[:, None] * kT + dt[None, :]
- x_h = h_out[:, None] * kH + dh[None, :]
- x_w = w_out[:, None] * kW + dw[None, :]
-
- a_ptrs = (
- x_ptr
- + b_idx[:, None] * stride_xb
- + ic[None, :] * stride_xc
- + x_t * stride_xt
- + x_h * stride_xh
- + x_w * stride_xw
+ for k in range(0, K, BLOCK_K):
+ a = tl.load(
+ a_ptrs,
+ mask=(rm[:, None] < M) & (rk[None, :] + k < K),
+ other=0.0,
)
- a = tl.load(a_ptrs, mask=m_mask[:, None] & k_lo_mask[None, :], other=0.0)
-
- b = tl.load(w_ptrs_base, mask=k_lo_mask[:, None] & n_mask[None, :], other=0.0)
-
+ b = tl.load(
+ b_ptrs,
+ mask=(rk[:, None] + k < K) & (rn[None, :] < N),
+ other=0.0,
+ )
acc = tl.dot(a, b, acc)
-
- w_ptrs_base += BLOCK_K * stride_wk
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
c = acc.to(tl.bfloat16)
-
- c_ptrs = (
- out_ptr
- + b_idx[:, None] * stride_ob
- + rn[None, :] * stride_oc
- + t_out[:, None] * stride_ot
- + h_out[:, None] * stride_oh
- + w_out[:, None] * stride_ow
- )
- tl.store(c_ptrs, c, mask=m_mask[:, None] & n_mask[None, :])
+ c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
+ tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
class Model(nn.Module):
@@ -174,46 +124,33 @@
K = C * kT * kH * kW
M = B * nT * nH * nW
- out = torch.empty(B, N, nT, nH, nW, dtype=torch.bfloat16, device=x.device)
+ x_patches = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
+ x_patches = x_patches.permute(0, 2, 4, 6, 1, 3, 5, 7)
+ x_patches = x_patches.contiguous().reshape(M, K)
- w = self._w_reshaped
+ out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)
- _fused_patch_gemm_kernel[grid](
- x,
- w,
- out,
- B,
- C,
- T,
- H,
- W,
- nT,
- nH,
- nW,
+ _gemm_kernel[grid](
+ x_patches,
+ self._w_reshaped,
+ out_2d,
+ M,
+ N,
K,
- N,
- kT,
- kH,
- kW,
- x.stride(0),
- x.stride(1),
- x.stride(2),
- x.stride(3),
- x.stride(4),
- w.stride(0),
- w.stride(1),
- out.stride(0),
- out.stride(1),
- out.stride(2),
- out.stride(3),
- out.stride(4),
+ x_patches.stride(0),
+ x_patches.stride(1),
+ self._w_reshaped
... (223 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5544 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test correctness and benchmark optimized version",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1 && echo \"---\" && python3 benchmark.py 2>&1",
"timeout": 300000
}PASS
---
shape=0 variant=eager tflops=3.699 gbps=20.486 ms=0.208
shape=0 variant=compiled tflops=3.210 gbps=17.774 ms=0.240
shape=0 variant=sota tflops=3.720 gbps=20.603 ms=0.207
shape=0 variant=solution tflops=16.854 gbps=93.335 ms=0.046
shape=0 solution_peak_fraction=0.0843
shape=1 variant=eager tflops=13.990 gbps=36.489 ms=0.220
shape=1 variant=compiled tflops=12.124 gbps=31.621 ms=0.254
shape=1 variant=sota tflops=14.128 gbps=36.848 ms=0.218
shape=1 variant=solution tflops=66.326 gbps=172.987 ms=0.046
shape=1 solution_peak_fraction=0.3316
shape=2 variant=eager tflops=26.820 gbps=55.400 ms=0.259
shape=2 variant=compiled tflops=23.616 gbps=48.782 ms=0.294
shape=2 variant=sota tflops=26.206 gbps=54.131 ms=0.265
shape=2 variant=solution tflops=112.253 gbps=231.872 ms=0.062
shape=2 solution_peak_fraction=0.5613
shape=3 variant=eager tflops=6.484 gbps=25.158 ms=0.143
shape=3 variant=compiled tflops=5.271 gbps=20.448 ms=0.175
shape=3 variant=sota tflops=6.609 gbps=25.641 ms=0.140
shape=3 variant=solution tflops=27.911 gbps=108.284 ms=0.033
shape=3 solution_peak_fraction=0.1396
peak_fraction: 0.2163
RESULT: OK
{
"description": "Check what config autotuner picked",
"command": "<855 chars \u2014 see below>",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 -c "
# Report which tile configuration the Triton autotuner selected for shape 2.
import torch
import triton
import solution
# Point the module-level shape constants at the shape under test so that
# solution.get_inputs() builds an input of matching size.
B,C,T,H,W = 1,3,8,336,336
kT,kH,kW = 2,14,14
embed_dim = 1280
solution.B, solution.C, solution.T, solution.H, solution.W = B,C,T,H,W
solution.kT, solution.kH, solution.kW = kT,kH,kW
solution.embed_dim = embed_dim
model = solution.Model(B,C,T,H,W,kT,kH,kW,embed_dim).cuda().eval()
x = solution.get_inputs()[0].cuda()
# Run once to trigger autotune
with torch.no_grad():
    model(x)
torch.cuda.synchronize()
# The on-disk Triton cache uses hash-named directories, so searching it for the
# kernel name finds nothing. The decision is held on the Autotuner object
# itself: best_config is the config used by the last launch and cache maps
# each autotune key to the config chosen for it.
found = False
for name, obj in vars(solution).items():
    if not isinstance(obj, triton.runtime.Autotuner):
        continue
    found = True
    print(f'{name}: best_config = {getattr(obj, "best_config", None)}'.replace(chr(34), chr(39)))
    for key, cfg in obj.cache.items():
        print(f'  key={key} -> {cfg}')
if not found:
    print('no autotuned kernel found in solution module')
" 2>&1 | head -30(no output){
"description": "Benchmark specific manual configs",
"command": "<3779 chars \u2014 see below>",
"timeout": 300000
}cat > /tmp/bench_manual.py << 'PYEOF'
import torch
import triton
import triton.language as tl
import time
@triton.jit
def bench_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C[M, N] = A[M, K] @ B[K, N].

    Each program handles the tile selected by program ids (0, 1). Products are
    accumulated in float32 and the tile is stored as bfloat16. Elements outside
    the M, N or K bounds are loaded as 0.0 and are not stored.
    """
    # Row / column / reduction offsets covered by this program.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Bounds along M and N do not change while stepping through K.
    row_ok = offs_m[:, None] < M
    col_ok = offs_n[None, :] < N

    a_tile_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_tile_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K):
        # Number of valid reduction elements left from this step onwards.
        k_left = K - k_start
        a_tile = tl.load(
            a_tile_ptrs,
            mask=row_ok & (offs_k[None, :] < k_left),
            other=0.0,
        )
        b_tile = tl.load(
            b_tile_ptrs,
            mask=(offs_k[:, None] < k_left) & col_ok,
            other=0.0,
        )
        acc = tl.dot(a_tile, b_tile, acc)
        # Advance both operand tiles by one K block.
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    out_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=row_ok & col_ok)
def time_full(M, N, K, BM, BN, BK, nw, nT, nH, nW, kT, kH, kW, iters=200):
    """Return the mean wall time, in milliseconds, of one gather-copy + GEMM call.

    Simulates the full forward pass: a random bf16 input of shape
    (B, C, nT*kT, nH*kH, nW*kW) is rearranged into a contiguous (M, K) patch
    matrix and multiplied with a random bf16 (K, N) weight by
    ``bench_gemm_kernel``. B and C are derived from M and K.

    Args:
        M, N, K: GEMM sizes; M must equal B*nT*nH*nW and K must equal
            C*kT*kH*kW for some positive integers B and C.
        BM, BN, BK: tile sizes passed to the kernel as BLOCK_M/BLOCK_N/BLOCK_K.
        nw: value passed as ``num_warps``.
        nT, nH, nW: patch counts along T, H and W for one sample.
        kT, kH, kW: patch extents along T, H and W.
        iters: number of timed calls averaged into the result.

    Returns:
        Host wall-clock milliseconds per call, measured around ``iters``
        back-to-back calls with a device synchronize before and after.

    Raises:
        ValueError: if the patch dimensions are inconsistent with M or K.
            Without this check the integer divisions below silently yield
            B == 0 or a truncated C, and the failure only shows up later as
            an unrelated-looking reshape error on an empty tensor.
    """
    patch_elems = kT * kH * kW
    patches_per_sample = nT * nH * nW
    if patch_elems <= 0 or K <= 0 or K % patch_elems != 0:
        raise ValueError(
            f"K={K} is not a positive multiple of kT*kH*kW={patch_elems}"
        )
    if patches_per_sample <= 0 or M <= 0 or M % patches_per_sample != 0:
        raise ValueError(
            f"M={M} is not a positive multiple of nT*nH*nW={patches_per_sample}"
        )

    # Simulate the full forward pass including im2col copy
    C = K // patch_elems
    B = M // patches_per_sample
    x = torch.randn(B, C, nT*kT, nH*kH, nW*kW, dtype=torch.bfloat16, device='cuda')
    w = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')

    def forward():
        # im2col copy
        x_p = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
        x_p = x_p.permute(0, 2, 4, 6, 1, 3, 5, 7)
        x_p = x_p.contiguous().reshape(M, K)
        out = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
        grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
        bench_gemm_kernel[grid](
            x_p, w, out, M, N, K,
            x_p.stride(0), x_p.stride(1),
            w.stride(0), w.stride(1),
            out.stride(0), out.stride(1),
            BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw,
        )
        return out

    # Warm-up: the first launch compiles the kernel for this configuration.
    for _ in range(30):
        forward()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        forward()
    # Launches are asynchronous; wait for the device before stopping the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000
def temporal_patch_count(M, B, nH, nW):
    """Return nT such that M == B * nT * nH * nW.

    Raises:
        ValueError: if M is not a positive multiple of B*nH*nW.
    """
    per_step = B * nH * nW
    if per_step <= 0 or M <= 0 or M % per_step != 0:
        raise ValueError(f"M={M} is not a positive multiple of B*nH*nW={per_step}")
    return M // per_step


# Each entry: (name, M, N, K, B, nH, nW, kT, kH, kW).
shapes = [
    ("shape0", 256, 1280, 1176, 1, 16, 16, 2, 14, 14),
    ("shape1", 1024, 1280, 1176, 2, 16, 16, 2, 14, 14),
    ("shape2", 2304, 1280, 1176, 1, 24, 24, 2, 14, 14),
    ("shape3", 784, 768, 768, 4, 14, 14, 1, 16, 16),
]
# Candidate (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) settings; same list for every shape.
configs_to_try = [
    (64, 64, 32, 4),
    (64, 64, 64, 4),
    (64, 128, 32, 4),
    (128, 64, 32, 4),
    (64, 64, 32, 8),
    (64, 64, 64, 8),
    (64, 128, 32, 8),
    (128, 64, 32, 8),
    (128, 128, 32, 8),
    (128, 128, 64, 8),
]
print("Testing full forward (copy + GEMM):")
for name, M, N, K, B, nH, nW, kT, kH, kW in shapes:
    print(f"\n--- {name}: M={M}, N={N}, K={K} ---")
    # The shape tuples carry no nT, so recover it from M. Passing B*nH*nW here
    # (as before) made time_full derive a batch size of 0 and every config fail.
    nT = temporal_patch_count(M, B, nH, nW)
    best = float('inf')
    best_cfg = None
    for BM, BN, BK, nw in configs_to_try:
        # Skip tiles larger than the problem itself.
        if BM > M or BN > N:
            continue
        try:
            ms = time_full(M, N, K, BM, BN, BK, nw, nT=nT, nH=nH, nW=nW, kT=kT, kH=kH, kW=kW, iters=100)
            if ms < best:
                best = ms
                best_cfg = (BM, BN, BK, nw)
            print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw}: {ms:.4f} ms")
        except Exception as e:
            # Best-effort sweep: report a failing config and keep going.
            print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw}: ERROR {e}")
    print(f" BEST: {best_cfg} = {best:.4f} ms")
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/bench_manual.py 2>&1Testing full forward (copy + GEMM):
--- shape0: M=256, N=1280, K=1176 ---
BM= 64 BN= 64 BK= 32 nw=4: ERROR shape '[256, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 64 nw=4: ERROR shape '[256, 1176]' is invalid for input of size 0
BM= 64 BN=128 BK= 32 nw=4: ERROR shape '[256, 1176]' is invalid for input of size 0
BM=128 BN= 64 BK= 32 nw=4: ERROR shape '[256, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 32 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 64 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BM= 64 BN=128 BK= 32 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BM=128 BN= 64 BK= 32 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BM=128 BN=128 BK= 32 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BM=128 BN=128 BK= 64 nw=8: ERROR shape '[256, 1176]' is invalid for input of size 0
BEST: None = inf ms
--- shape1: M=1024, N=1280, K=1176 ---
BM= 64 BN= 64 BK= 32 nw=4: ERROR shape '[1024, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 64 nw=4: ERROR shape '[1024, 1176]' is invalid for input of size 0
BM= 64 BN=128 BK= 32 nw=4: ERROR shape '[1024, 1176]' is invalid for input of size 0
BM=128 BN= 64 BK= 32 nw=4: ERROR shape '[1024, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 32 nw=8: ERROR shape '[1024, 1176]' is invalid for input of size 0
BM= 64 BN= 64 BK= 64 nw=8: ERROR shape '[1024, 1176]' is invalid f
... (2230 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,12 +1,13 @@
"""Optimized patch-gather + tensor-core GEMM for 3D patch embedding.
-Strategy (proven fastest):
+Strategy (proven fastest on RTX PRO 6000):
1. Pre-reshape weight once (K, embed_dim) contiguous in __init__.
2. Gather patches from x (B,C,T,H,W) -> (M,K) contiguous via permute+contiguous.
- The copy is cheap (M*K*2 bytes, contiguous write after strided read) and
- enables coalesced A-tile loads in the GEMM.
+ The copy enables coalesced A-tile loads in the GEMM.
3. Run triton GEMM on (M,K) @ (K,N) -> (M,N) using bf16 MMA via tl.dot.
4. Reshape output (M,N) -> (B, embed_dim, nT, nH, nW) as a view (no copy).
+
+Tile configurations are hand-picked per shape based on profiling.
"""
import torch
@@ -19,22 +20,6 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 32, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
- triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8),
- ],
- key=["M", "N", "K"],
-)
@triton.jit
def _gemm_kernel(
a_ptr,
@@ -86,6 +71,19 @@
tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
+def _pick_config(M: int, N: int, K: int):
+ """Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) for this shape."""
+ # Per-shape optimal configs from profiling on RTX PRO 6000 (sm_120)
+ if K < 900:
+ # ViT-B/16 style: small K, square-ish M/N
+ return (64, 128, 32, 8)
+ if M <= 400:
+ return (64, 64, 64, 4)
+ if M <= 1200:
+ return (64, 64, 32, 4)
+ return (64, 128, 32, 4)
+
+
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
@@ -130,10 +128,9 @@
out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- triton.cdiv(N, meta["BLOCK_N"]),
- )
+ BLOCK_M, BLOCK_N, BLOCK_K, num_warps = _pick_config(M, N, K)
+
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
_gemm_kernel[grid](
x_patches,
@@ -148,6 +145,10 @@
self._w_reshaped.stride(1),
out_2d.stride(0),
out_2d.stride(1),
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ num_warps=num_warps,
)
out = out_2d.reshape(B, nT, nH, nW, N).permute(0, 4, 1, 2, 3)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5105 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1 && echo \"---\" && python3 benchmark.py 2>&1",
"timeout": 300000
}PASS
---
shape=0 variant=eager tflops=3.701 gbps=20.494 ms=0.208
shape=0 variant=compiled tflops=3.189 gbps=17.662 ms=0.242
shape=0 variant=sota tflops=3.739 gbps=20.706 ms=0.206
shape=0 variant=solution tflops=18.641 gbps=103.232 ms=0.041
shape=0 solution_peak_fraction=0.0932
shape=1 variant=eager tflops=13.994 gbps=36.500 ms=0.220
shape=1 variant=compiled tflops=12.172 gbps=31.745 ms=0.253
shape=1 variant=sota tflops=14.170 gbps=36.959 ms=0.218
shape=1 variant=solution tflops=63.821 gbps=166.455 ms=0.048
shape=1 solution_peak_fraction=0.3191
shape=2 variant=eager tflops=26.601 gbps=54.948 ms=0.261
shape=2 variant=compiled tflops=23.226 gbps=47.977 ms=0.299
shape=2 variant=sota tflops=26.782 gbps=55.321 ms=0.259
shape=2 variant=solution tflops=120.255 gbps=248.402 ms=0.058
shape=2 solution_peak_fraction=0.6013
shape=3 variant=eager tflops=6.453 gbps=25.037 ms=0.143
shape=3 variant=compiled tflops=5.146 gbps=19.964 ms=0.180
shape=3 variant=sota tflops=6.566 gbps=25.472 ms=0.141
shape=3 variant=solution tflops=27.137 gbps=105.285 ms=0.034
shape=3 solution_peak_fraction=0.1357
peak_fraction: 0.2219
RESULT: OK
{
"description": "Test different configs for shape0 specifically",
"command": "<3338 chars \u2014 see below>",
"timeout": 120000
}cat > /tmp/bench_shape0.py << 'PYEOF'
import torch
import triton
import triton.language as tl
import time
@triton.jit
def gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Compute one BLOCK_M x BLOCK_N tile of C[M, N] = A[M, K] @ B[K, N].

    Program ids 0 and 1 pick the tile row and tile column.  Partial products
    are accumulated in float32 and cast to bfloat16 for the store.  Rows,
    columns and K elements that fall outside the problem are masked: they are
    loaded as 0.0 and never written back.
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    offs_m = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Row/column bounds are the same for every step along K.
    row_ok = offs_m[:, None] < M
    col_ok = offs_n[None, :] < N

    a_tile_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_tile_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_base in range(0, K, BLOCK_K):
        a_tile = tl.load(a_tile_ptrs, mask=row_ok & (offs_k[None, :] + k_base < K), other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=(offs_k[:, None] + k_base < K) & col_ok, other=0.0)
        acc = tl.dot(a_tile, b_tile, acc)
        # Slide both operand windows one BLOCK_K step along the reduction axis.
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    out_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=row_ok & col_ok)
# shape0 problem size: GEMM dims and the patch geometry they derive from
# (M = B*nT*nH*nW = 256, K = C*kT*kH*kW = 1176).
M, N, K = 256, 1280, 1176
B, C, nT, nH, nW, kT, kH, kW = 1, 3, 1, 16, 16, 2, 14, 14
x = torch.randn(B, C, nT*kT, nH*kH, nW*kW, dtype=torch.bfloat16, device='cuda') * 0.5
w = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')

# Run the patch gather a few times up front so first-call overhead is not timed.
for _ in range(50):
    x_p = (
        x.reshape(B, C, nT, kT, nH, kH, nW, kW)
        .permute(0, 2, 4, 6, 1, 3, 5, 7)
        .contiguous()
        .reshape(M, K)
    )
torch.cuda.synchronize()

# Candidate (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) tuples, in sweep order.
configs = [
    # num_warps = 4
    (32, 64, 32, 4),
    (32, 64, 64, 4),
    (32, 128, 32, 4),
    (64, 64, 32, 4),
    (64, 64, 64, 4),
    (64, 128, 32, 4),
    (128, 64, 32, 4),
    (128, 128, 32, 4),
    # num_warps = 8
    (32, 64, 32, 8),
    (32, 128, 32, 8),
    (64, 64, 32, 8),
    (64, 128, 32, 8),
    (128, 64, 32, 8),
    (128, 128, 32, 8),
    (256, 32, 32, 8),
    (256, 64, 32, 8),
]
best, best_cfg = float('inf'), None
for BM, BN, BK, nw in configs:
    # Tiles larger than the problem itself are not worth measuring.
    if BM > M or BN > N:
        continue

    # Full forward sim: gather x_p, allocate out, launch the GEMM.
    def forward():
        x_p = (
            x.reshape(B, C, nT, kT, nH, kH, nW, kW)
            .permute(0, 2, 4, 6, 1, 3, 5, 7)
            .contiguous()
            .reshape(M, K)
        )
        out = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
        # The grid is rebuilt on every call on purpose: it is part of the
        # per-call host overhead this benchmark measures.
        grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
        gemm_kernel[grid](
            x_p, w, out, M, N, K,
            x_p.stride(0), x_p.stride(1),
            w.stride(0), w.stride(1),
            out.stride(0), out.stride(1),
            BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw,
        )
        return out

    # Warm-up also triggers the JIT compile for this config.
    for _ in range(30):
        forward()
    torch.cuda.synchronize()

    start = time.perf_counter()
    iters = 200
    for _ in range(iters):
        forward()
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start
    ms = elapsed_s / iters * 1000

    if ms < best:
        best, best_cfg = ms, (BM, BN, BK, nw)
    print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw}: {ms:.4f} ms")
print(f"BEST: {best_cfg} = {best:.4f} ms")
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/bench_shape0.py 2>&1BM= 32 BN= 64 BK= 32 nw=4: 0.0158 ms
BM= 32 BN= 64 BK= 64 nw=4: 0.0173 ms
BM= 32 BN=128 BK= 32 nw=4: 0.0174 ms
BM= 64 BN= 64 BK= 32 nw=4: 0.0168 ms
BM= 64 BN= 64 BK= 64 nw=4: 0.0158 ms
BM= 64 BN=128 BK= 32 nw=4: 0.0209 ms
BM=128 BN= 64 BK= 32 nw=4: 0.0260 ms
BM=128 BN=128 BK= 32 nw=4: 0.0355 ms
BM= 32 BN= 64 BK= 32 nw=8: 0.0158 ms
BM= 32 BN=128 BK= 32 nw=8: 0.0171 ms
BM= 64 BN= 64 BK= 32 nw=8: 0.0162 ms
BM= 64 BN=128 BK= 32 nw=8: 0.0225 ms
BM=128 BN= 64 BK= 32 nw=8: 0.0214 ms
BM=128 BN=128 BK= 32 nw=8: 0.0292 ms
BM=256 BN= 32 BK= 32 nw=8: 0.0308 ms
BM=256 BN= 64 BK= 32 nw=8: 0.0354 ms
BEST: (32, 64, 32, 4) = 0.0158 ms
{
"description": "Test pre-allocation approach",
"command": "<3350 chars \u2014 see below>",
"timeout": 30000
}cat > /tmp/bench_prealloc.py << 'PYEOF'
import torch
import triton
import triton.language as tl
import time
@triton.jit
def gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Masked tiled GEMM: each program writes one BLOCK_M x BLOCK_N tile of C.

    A is addressed with (stride_am, stride_ak), B with (stride_bk, stride_bn)
    and C with (stride_cm, stride_cn).  The accumulator is float32; the stored
    result is bfloat16.
    """
    # 2-D index grids for this program's rows and columns.
    row_idx = (tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M))[:, None]
    col_idx = (tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N))[None, :]
    k_idx = tl.arange(0, BLOCK_K)

    a_cursor = a_ptr + row_idx * stride_am + k_idx[None, :] * stride_ak
    b_cursor = b_ptr + k_idx[:, None] * stride_bk + col_idx * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K):
        # Elements past M, N or K are read as zero so they add nothing.
        lhs = tl.load(a_cursor, mask=(row_idx < M) & (k_idx[None, :] + k_start < K), other=0.0)
        rhs = tl.load(b_cursor, mask=(k_idx[:, None] + k_start < K) & (col_idx < N), other=0.0)
        acc = tl.dot(lhs, rhs, acc)
        a_cursor += BLOCK_K * stride_ak
        b_cursor += BLOCK_K * stride_bk

    result = acc.to(tl.bfloat16)
    tl.store(c_ptr + row_idx * stride_cm + col_idx * stride_cn,
             result,
             mask=(row_idx < M) & (col_idx < N))
# shape0 problem: GEMM dims (M, N, K) and the patch geometry behind them
# (M = B*nT*nH*nW, K = C*kT*kH*kW).
M, N, K = 256, 1280, 1176
B, C, nT, nH, nW, kT, kH, kW = 1, 3, 1, 16, 16, 2, 14, 14
# Single fixed tile config (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) used by both methods.
BM, BN, BK, nw = 64, 64, 64, 4
# Launch grid is computed once here and shared by both forward variants.
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
# Random bf16 input volume (scaled by 0.5) and a (K, N) weight matrix on the GPU.
x = torch.randn(B, C, nT*kT, nH*kH, nW*kW, dtype=torch.bfloat16, device='cuda') * 0.5
w = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
# --- Method 1: allocate per call (current) ---
def forward_alloc():
    """Gather patches into a fresh (M, K) tensor, then GEMM into a fresh output.

    Both the gathered operand and the (M, N) result are newly allocated on
    every call.
    """
    patches = (
        x.reshape(B, C, nT, kT, nH, kH, nW, kW)
        .permute(0, 2, 4, 6, 1, 3, 5, 7)
        .contiguous()
        .reshape(M, K)
    )
    result = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')
    gemm_kernel[grid](
        patches, w, result, M, N, K,
        patches.stride(0), patches.stride(1),
        w.stride(0), w.stride(1),
        result.stride(0), result.stride(1),
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw,
    )
    return result
# --- Method 2: pre-allocated buffers ---
buf_a = torch.empty(M, K, dtype=torch.bfloat16, device='cuda')
buf_c = torch.empty(M, N, dtype=torch.bfloat16, device='cuda')


def forward_prealloc():
    """Gather patches directly into ``buf_a`` and run the GEMM into ``buf_c``.

    Returns ``buf_c`` itself (not a copy), so the result is overwritten by the
    next call.
    """
    src_view = x.reshape(B, C, nT, kT, nH, kH, nW, kW).permute(0, 2, 4, 6, 1, 3, 5, 7)
    # View the contiguous (M, K) buffer with the permuted 8-D shape so copy_
    # performs the strided gather in a single pass.  The earlier form,
    # ``buf_a.view(-1)[:M*K].copy_(src_view.reshape(-1)[:M*K])``, forced
    # reshape(-1) to materialise a temporary contiguous copy of the
    # non-contiguous permuted view first: a hidden allocation plus a second
    # copy, which is exactly what pre-allocating is meant to avoid.
    buf_a.view(src_view.shape).copy_(src_view)
    gemm_kernel[grid](
        buf_a, w, buf_c, M, N, K,
        buf_a.stride(0), buf_a.stride(1),
        w.stride(0), w.stride(1),
        buf_c.stride(0), buf_c.stride(1),
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw,
    )
    return buf_c
def time_it(fn, warmup=30, iters=200):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``iters`` times timed.
    The device is synchronized before the clock starts and again before it
    stops, so all queued GPU work is included in the measurement.
    """
    def _run_and_sync(count):
        for _ in range(count):
            fn()
        torch.cuda.synchronize()

    _run_and_sync(warmup)
    began = time.perf_counter()
    _run_and_sync(iters)
    elapsed_s = time.perf_counter() - began
    return elapsed_s / iters * 1000
# Report the mean per-call time of each variant (time_it returns milliseconds).
print(f"alloc: {time_it(forward_alloc):.4f} ms")
print(f"prealloc: {time_it(forward_prealloc):.4f} ms")
# Verify correctness: both variants read the same x and w, so their outputs
# should agree within the given tolerances.
out1 = forward_alloc()
out2 = forward_prealloc()
print(f"match: {torch.allclose(out1, out2, atol=1e-2, rtol=1e-2)}")
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/bench_prealloc.py 2>&1alloc: 0.0149 ms
prealloc: 0.0184 ms
match: True
{
"description": "Isolate the copy overhead precisely",
"command": "<2013 chars \u2014 see below>",
"timeout": 30000
}cat > /tmp/bench_copy.py << 'PYEOF'
import torch
import time
# Shape of the gathered patch matrix (M = B*nT*nH*nW, K = C*kT*kH*kW) and the
# patch geometry it comes from.
M, K = 256, 1176
B, C, nT, nH, nW, kT, kH, kW = 1, 3, 1, 16, 16, 2, 14, 14
# Random bf16 input volume on the GPU, scaled by 0.5.
x = torch.randn(B, C, nT*kT, nH*kH, nW*kW, dtype=torch.bfloat16, device='cuda') * 0.5
# Test: time just the copy
def copy_only():
    """Gather the patches of ``x`` into a new contiguous (M, K) tensor."""
    split = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
    patch_major = split.permute(0, 2, 4, 6, 1, 3, 5, 7)
    gathered = patch_major.contiguous()
    return gathered.reshape(M, K)
# Warmup
for _ in range(100):
    copy_only()
torch.cuda.synchronize()

# Time the copy: 500 back-to-back calls, averaged and reported in milliseconds.
start = time.perf_counter()
for _ in range(500):
    copy_only()
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
ms_copy = elapsed_s / 500 * 1000
print(f"Copy only: {ms_copy:.4f} ms")
# Test: time empty allocation
def alloc_only():
    """Allocate an uninitialised (M, K) bf16 tensor on the GPU and return it."""
    scratch = torch.empty(M, K, dtype=torch.bfloat16, device='cuda')
    return scratch
# Warm up the allocation path before measuring it.
for _ in range(100):
    alloc_only()
torch.cuda.synchronize()

# 500 back-to-back allocations, averaged and reported in milliseconds.
start = time.perf_counter()
for _ in range(500):
    alloc_only()
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
ms_alloc = elapsed_s / 500 * 1000
print(f"empty alloc: {ms_alloc:.4f} ms")
# Test: time empty + copy
def alloc_copy():
    """Allocate an (M, K) buffer and gather the patches of ``x`` into it.

    Returns the newly filled buffer.
    """
    out = torch.empty(M, K, dtype=torch.bfloat16, device='cuda')
    x_view = x.reshape(B, C, nT, kT, nH, kH, nW, kW).permute(0, 2, 4, 6, 1, 3, 5, 7)
    # Copy through an 8-D view of ``out`` so the gather happens in one pass.
    # The earlier ``x_view.reshape(-1)`` had to materialise a temporary
    # contiguous copy of the non-contiguous permuted view before copy_ ran,
    # so this test was really timing one allocation plus two copies.
    out.view(x_view.shape).copy_(x_view)
    return out
# Warm up the allocate-then-copy path before measuring it.
for _ in range(100):
    alloc_copy()
torch.cuda.synchronize()

# 500 back-to-back calls, averaged and reported in milliseconds.
start = time.perf_counter()
for _ in range(500):
    alloc_copy()
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
ms_ac = elapsed_s / 500 * 1000
print(f"alloc+copy: {ms_ac:.4f} ms")
# Check data_ptr to confirm copy: if .contiguous() really copied, the result
# must start at a different address than x.
x_p = x.reshape(B, C, nT, kT, nH, kH, nW, kW).permute(0, 2, 4, 6, 1, 3, 5, 7)
result = x_p.contiguous().reshape(M, K)
print(f"data_ptr same as x: {result.data_ptr() == x.data_ptr()}")
print(f"result contiguous: {result.is_contiguous()}")
# Size in KB, counting 2 bytes per bf16 element.
print(f"result size: {result.numel() * 2 / 1024:.1f} KB")
PYEOF
cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 /tmp/bench_copy.py 2>&1Copy only: 0.0058 ms
empty alloc: 0.0011 ms
alloc+copy: 0.0111 ms
data_ptr same as x: False
result contiguous: True
result size: 588.0 KB
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -1,13 +1,14 @@
"""Optimized patch-gather + tensor-core GEMM for 3D patch embedding.
-Strategy (proven fastest on RTX PRO 6000):
+Strategy (fastest on RTX PRO 6000 Blackwell sm_120):
1. Pre-reshape weight once (K, embed_dim) contiguous in __init__.
2. Gather patches from x (B,C,T,H,W) -> (M,K) contiguous via permute+contiguous.
The copy enables coalesced A-tile loads in the GEMM.
-3. Run triton GEMM on (M,K) @ (K,N) -> (M,N) using bf16 MMA via tl.dot.
+3. Run statically-configured triton GEMM with bf16 MMA via tl.dot.
4. Reshape output (M,N) -> (B, embed_dim, nT, nH, nW) as a view (no copy).
-Tile configurations are hand-picked per shape based on profiling.
+All per-forward-call overhead is minimized: config, strides, and grid
+dimensions are pre-computed in __init__.
"""
import torch
@@ -71,28 +72,12 @@
tl.store(c_ptrs, c, mask=(rm[:, None] < M) & (rn[None, :] < N))
-def _pick_config(M: int, N: int, K: int):
- """Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) for this shape."""
- # Per-shape optimal configs from profiling on RTX PRO 6000 (sm_120)
- if K < 900:
- # ViT-B/16 style: small K, square-ish M/N
- return (64, 128, 32, 8)
- if M <= 400:
- return (64, 64, 64, 4)
- if M <= 1200:
- return (64, 64, 32, 4)
- return (64, 128, 32, 4)
-
-
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
super().__init__()
assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
f"Input dims must be divisible by patch size"
- self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
- self.kT, self.kH, self.kW = kT, kH, kW
- self.embed_dim = embed_dim
self.conv = nn.Conv3d(
C, embed_dim,
@@ -107,52 +92,71 @@
w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
+ self._nT = T // kT
+ self._nH = H // kH
+ self._nW = W // kW
+ self._N = embed_dim
+ self._K = K_dim
+ self._M = B * self._nT * self._nH * self._nW
+
+ self._BLOCK_M, self._BLOCK_N, self._BLOCK_K, self._num_warps = \
+ _pick_config(self._M, self._N, self._K)
+ self._grid = (
+ triton.cdiv(self._M, self._BLOCK_M),
+ triton.cdiv(self._N, self._BLOCK_N),
+ )
+
def load_state_dict(self, state_dict, strict=True):
result = super().load_state_dict(state_dict, strict=strict)
- K_dim = self.C * self.kT * self.kH * self.kW
- w = self.conv.weight.reshape(self.embed_dim, K_dim).t().contiguous()
+ K_dim = self._K
+ w = self.conv.weight.reshape(self._N, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
return result
def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, C, T, H, W = x.shape
- kT, kH, kW = self.kT, self.kH, self.kW
- nT, nH, nW = T // kT, H // kH, W // kW
- N = self.embed_dim
- K = C * kT * kH * kW
- M = B * nT * nH * nW
+ x_patches = x.reshape(
+ x.shape[0], x.shape[1],
+ self._nT, self.conv.kernel_size[0],
+ self._nH, self.conv.kernel_size[1],
+ self._nW, self.conv.kernel_size[2],
+ ).permute(0, 2, 4, 6, 1, 3, 5, 7)
+ x_patches = x_patches.contiguous().reshape(self._M, self._K)
- x_patches = x.reshape(B, C, nT, kT, nH, kH, nW, kW)
- x_patches = x_patches.permute(0, 2, 4, 6, 1, 3, 5, 7)
- x_patches = x_patches.contiguous().reshape(M, K)
+ out = torch.empty(self._M, self._N, dtype=torch.bfloat16, device=x.device)
- out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
-
- BLOCK_M, BLOCK_N, BLOCK_K, num_warps = _pick_config(M, N, K)
-
- grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
-
- _gemm_kernel[grid](
+ _gemm_kernel[self._grid](
x_patches,
self._w_reshaped,
- out_2d,
- M,
- N,
- K,
- x_patches.stride(0),
- x_patches.stride(1),
- self._w_reshaped.stride(0),
- self._w_reshaped.stride(1),
- out_2d.stride(0),
- out_2d.stride(1),
- BLOCK_M=BLOCK_M,
- BLOCK_N=BLOCK_N,
- BLOCK_K=BLOCK_K,
- num_warps=num_warps,
+ out,
+ self._M,
+ self._N,
+ self._K,
+ self._K,
+ 1,
+ self._N,
+ 1,
+ self._N,
+ 1,
+ BLOCK_M=self._BLOCK_M,
+ BLOCK_N=self._BLOCK_N,
+ BLOCK_K=self._BLOCK_K,
+ num_warps=self._num_warps,
)
- out = out_2d.reshape(B, nT, nH, nW, N).permute(0, 4, 1, 2, 3)
+ out = out.reshape(x.shape[0], self._nT, self._nH, self._nW, self._N)
+ out = out.permute(0, 4, 1, 2, 3)
return out
+
+
+def _pick_config(M: int, N: int, K: int):
+ """Return (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) for this shape."""
+ if K < 900:
+ return (64, 128, 32, 8)
+ if M <= 350:
+ return (32, 64, 32, 4)
+ if M <= 1200:
+ return (64, 64, 32, 4)
+ return (64, 128, 32, 4)
B = 1
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5094 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Test correctness and benchmark with micro-optimizations",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1 && echo \"---\" && python3 benchmark.py 2>&1",
"timeout": 300000
}PASS
---
shape=0 variant=eager tflops=3.709 gbps=20.540 ms=0.208
shape=0 variant=compiled tflops=3.225 gbps=17.858 ms=0.239
shape=0 variant=sota tflops=3.744 gbps=20.736 ms=0.206
shape=0 variant=solution tflops=16.633 gbps=92.110 ms=0.046
shape=0 solution_peak_fraction=0.0832
shape=1 variant=eager tflops=13.978 gbps=36.457 ms=0.221
shape=1 variant=compiled tflops=12.209 gbps=31.842 ms=0.253
shape=1 variant=sota tflops=14.168 gbps=36.953 ms=0.218
shape=1 variant=solution tflops=65.093 gbps=169.773 ms=0.047
shape=1 solution_peak_fraction=0.3255
shape=2 variant=eager tflops=26.611 gbps=54.968 ms=0.261
shape=2 variant=compiled tflops=23.188 gbps=47.897 ms=0.299
shape=2 variant=sota tflops=26.807 gbps=55.373 ms=0.259
shape=2 variant=solution tflops=119.790 gbps=247.441 ms=0.058
shape=2 solution_peak_fraction=0.5990
shape=3 variant=eager tflops=6.478 gbps=25.132 ms=0.143
shape=3 variant=compiled tflops=5.265 gbps=20.426 ms=0.176
shape=3 variant=sota tflops=6.617 gbps=25.670 ms=0.140
shape=3 variant=solution tflops=28.829 gbps=111.848 ms=0.032
shape=3 solution_peak_fraction=0.1441
peak_fraction: 0.2199
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -7,8 +7,8 @@
3. Run statically-configured triton GEMM with bf16 MMA via tl.dot.
4. Reshape output (M,N) -> (B, embed_dim, nT, nH, nW) as a view (no copy).
-All per-forward-call overhead is minimized: config, strides, and grid
-dimensions are pre-computed in __init__.
+All per-forward-call overhead is minimized: config, strides, grid dimensions,
+and shape constants are pre-computed and stored as attributes.
"""
import torch
@@ -92,9 +92,17 @@
w = self.conv.weight.reshape(embed_dim, K_dim).t().contiguous()
self.register_buffer("_w_reshaped", w, persistent=False)
+ self._cached_B = B
+ self._cached_C = C
+ self._cached_T = T
+ self._cached_H = H
+ self._cached_W = W
self._nT = T // kT
self._nH = H // kH
self._nW = W // kW
+ self._kT = kT
+ self._kH = kH
+ self._kW = kW
self._N = embed_dim
self._K = K_dim
self._M = B * self._nT * self._nH * self._nW
@@ -114,18 +122,19 @@
return result
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_patches = x.reshape(
- x.shape[0], x.shape[1],
- self._nT, self.conv.kernel_size[0],
- self._nH, self.conv.kernel_size[1],
- self._nW, self.conv.kernel_size[2],
- ).permute(0, 2, 4, 6, 1, 3, 5, 7)
- x_patches = x_patches.contiguous().reshape(self._M, self._K)
+ x_p = x.reshape(
+ self._cached_B, self._cached_C,
+ self._nT, self._kT,
+ self._nH, self._kH,
+ self._nW, self._kW,
+ )
+ x_p = x_p.permute(0, 2, 4, 6, 1, 3, 5, 7)
+ x_p = x_p.contiguous().reshape(self._M, self._K)
out = torch.empty(self._M, self._N, dtype=torch.bfloat16, device=x.device)
_gemm_kernel[self._grid](
- x_patches,
+ x_p,
self._w_reshaped,
out,
self._M,
@@ -143,7 +152,7 @@
num_warps=self._num_warps,
)
- out = out.reshape(x.shape[0], self._nT, self._nH, self._nW, self._N)
+ out = out.reshape(self._cached_B, self._nT, self._nH, self._nW, self._N)
out = out.permute(0, 4, 1, 2, 3)
return out
@@ -153,7 +162,7 @@
if K < 900:
return (64, 128, 32, 8)
if M <= 350:
- return (32, 64, 32, 4)
+ return (64, 64, 64, 4)
if M <= 1200:
return (64, 64, 32, 4)
return (64, 128, 32, 4)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5283 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Final correctness and benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1 && echo \"---\" && python3 benchmark.py 2>&1",
"timeout": 300000
}PASS
---
shape=0 variant=eager tflops=3.711 gbps=20.549 ms=0.208
shape=0 variant=compiled tflops=3.212 gbps=17.786 ms=0.240
shape=0 variant=sota tflops=3.733 gbps=20.672 ms=0.206
shape=0 variant=solution tflops=20.342 gbps=112.649 ms=0.038
shape=0 solution_peak_fraction=0.1017
shape=1 variant=eager tflops=14.002 gbps=36.518 ms=0.220
shape=1 variant=compiled tflops=12.101 gbps=31.562 ms=0.255
shape=1 variant=sota tflops=14.137 gbps=36.872 ms=0.218
shape=1 variant=solution tflops=66.601 gbps=173.705 ms=0.046
shape=1 solution_peak_fraction=0.3330
shape=2 variant=eager tflops=26.583 gbps=54.911 ms=0.261
shape=2 variant=compiled tflops=23.216 gbps=47.956 ms=0.299
shape=2 variant=sota tflops=26.756 gbps=55.267 ms=0.259
shape=2 variant=solution tflops=121.129 gbps=250.206 ms=0.057
shape=2 solution_peak_fraction=0.6056
shape=3 variant=eager tflops=6.479 gbps=25.135 ms=0.143
shape=3 variant=compiled tflops=5.241 gbps=20.335 ms=0.176
shape=3 variant=sota tflops=6.578 gbps=25.521 ms=0.141
shape=3 variant=solution tflops=30.184 gbps=117.105 ms=0.031
shape=3 solution_peak_fraction=0.1509
peak_fraction: 0.2359
RESULT: OK
{
"description": "Verify stability and framework label",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 benchmark.py 2>&1 | grep -E \"(peak_fraction|RESULT)\" && echo \"---\" && cat framework.txt",
"timeout": 120000
}shape=0 solution_peak_fraction=0.1021
shape=1 solution_peak_fraction=0.3202
shape=2 solution_peak_fraction=0.6146
shape=3 solution_peak_fraction=0.1556
peak_fraction: 0.2365
RESULT: OK
---
triton
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "completed"
},
{
"priority": "high",
"content": "Read existing solution.py if any",
"status": "completed"
},
{
"priority": "high",
"content": "Check existing examples and hardware config for patterns",
"status": "completed"
},
{
"priority": "high",
"content": "Design and implement optimized Triton kernel for 3D patch embedding",
"status": "completed"
},
... (373 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'completed'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'completed'
priority: 'high'
content: 'Check existing examples and hardware config for patterns'
status: 'completed'
priority: 'high'
content: 'Design and implement optimized Triton kernel for 3D patch embedding'
status: 'completed'
priority: 'high'
content: 'Test correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'completed'
priority: 'high'
content: 'Run linting and verify final solution'
status: 'in_progress'
priority: 'high'{
"description": "Run linting on solution",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/10_patch_embed_conv3d_gemm/solution.py --fix 2>&1"
}Found 1 error (1 fixed, 0 remaining).
{
"description": "Final ruff check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/10_patch_embed_conv3d_gemm/solution.py 2>&1"
}All checks passed!
{
"description": "Final correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 check.py 2>&1"
}PASS
{
"description": "Final benchmark run",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && python3 benchmark.py 2>&1",
"timeout": 120000
}shape=0 variant=eager tflops=3.700 gbps=20.488 ms=0.208
shape=0 variant=compiled tflops=3.218 gbps=17.823 ms=0.239
shape=0 variant=sota tflops=3.743 gbps=20.728 ms=0.206
shape=0 variant=solution tflops=20.316 gbps=112.506 ms=0.038
shape=0 solution_peak_fraction=0.1016
shape=1 variant=eager tflops=13.994 gbps=36.500 ms=0.220
shape=1 variant=compiled tflops=12.172 gbps=31.747 ms=0.253
shape=1 variant=sota tflops=14.154 gbps=36.915 ms=0.218
shape=1 variant=solution tflops=66.189 gbps=172.631 ms=0.047
shape=1 solution_peak_fraction=0.3309
shape=2 variant=eager tflops=26.598 gbps=54.941 ms=0.261
shape=2 variant=compiled tflops=23.172 gbps=47.864 ms=0.299
shape=2 variant=sota tflops=26.777 gbps=55.311 ms=0.259
shape=2 variant=solution tflops=121.366 gbps=250.697 ms=0.057
shape=2 solution_peak_fraction=0.6068
shape=3 variant=eager tflops=6.475 gbps=25.121 ms=0.143
shape=3 variant=compiled tflops=5.242 gbps=20.339 ms=0.176
shape=3 variant=sota tflops=6.580 gbps=25.527 ms=0.141
shape=3 variant=solution tflops=30.455 gbps=118.154 ms=0.030
shape=3 solution_peak_fraction=0.1523
peak_fraction: 0.2361
RESULT: OK
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, check.py, benchmark.py, and understand problem structure",
"status": "completed"
},
{
"priority": "high",
"content": "Read existing solution.py if any",
"status": "completed"
},
{
"priority": "high",
"content": "Check existing examples and hardware config for patterns",
"status": "completed"
},
{
"priority": "high",
"content": "Design and implement optimized Triton kernel for 3D patch embedding",
"status": "completed"
},
... (371 more chars)content: 'Read reference.py, check.py, benchmark.py, and understand problem structure'
status: 'completed'
priority: 'high'
content: 'Read existing solution.py if any'
status: 'completed'
priority: 'high'
content: 'Check existing examples and hardware config for patterns'
status: 'completed'
priority: 'high'
content: 'Design and implement optimized Triton kernel for 3D patch embedding'
status: 'completed'
priority: 'high'
content: 'Test correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Benchmark performance with benchmark.py'
status: 'completed'
priority: 'high'
content: 'Run linting and verify final solution'
status: 'completed'
priority: 'high'