"""Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
Uses Triton with tl.dot to fuse the strided patch gather into the GEMM K-loop,
avoiding materialization of the im2col buffer. Each gather offset is the sum of
an m-dependent (patch position) term and a k-dependent (within-patch) term, so
the 2D pointer tile is built by broadcasting two 1D offset vectors.
Output layout: (B, embed_dim, T//kT, H//kH, W//kW). No bias.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
    ],
    # Re-tune whenever the GEMM problem size changes.
    key=['K', 'N', 'M_total'],
)
@triton.jit
def patch_embed_kernel(
    x_ptr, w_ptr, out_ptr,
    nPT, nPH, nPW,
    stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
    stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
    K, N,
    kT, kH, kW,
    M_total,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) output tile of the patch-embedding GEMM.

    Each program gathers its rows of the implicit patch matrix straight from
    ``x`` (no intermediate buffer), multiplies them with a (BLOCK_K, BLOCK_N)
    weight tile via ``tl.dot``, accumulates in float32 and stores the result
    cast to the element type of ``out_ptr``.

    Args:
        x_ptr: Input, addressed through ``stride_x_{b,c,t,h,w}``.
        w_ptr: Weight, read as a dense row-major (N, K) matrix: element
            (n, k) is loaded from ``w_ptr + n * K + k``.
        out_ptr: Output, addressed through ``stride_out_{b,n,pt,ph,pw}``.
        nPT, nPH, nPW: Number of patches along t / h / w.
        K: Reduction length. A reduction index k is decomposed as
            ``((c * kT + kt) * kH + kh) * kW + kw``.
        N: Number of output columns.
        kT, kH, kW: Patch extent along t / h / w; also used as the step
            between neighbouring patches.
        M_total: Number of valid rows; a row index m is decomposed as
            ``((b * nPT + pt) * nPH + ph) * nPW + pw``.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes supplied by the autotuner.
    """
    # 1D launch grid: consecutive program ids walk the N tiles of one M tile.
    pid = tl.program_id(0)
    num_n_blocks = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_n_blocks
    pid_n = pid % num_n_blocks

    m_start = pid_m * BLOCK_M
    n_start = pid_n * BLOCK_N
    m_off = m_start + tl.arange(0, BLOCK_M)
    n_off = n_start + tl.arange(0, BLOCK_N)

    # Split the flat row index into (batch, patch-t, patch-h, patch-w).
    patches_per_batch = nPT * nPH * nPW
    pw_m = m_off % nPW
    ph_m = (m_off // nPW) % nPH
    pt_m = (m_off // (nPW * nPH)) % nPT
    b_m = m_off // patches_per_batch

    # Tiles on the ragged edge of the problem are handled by masking.
    m_mask = m_off < M_total
    n_mask = n_off < N

    # Row-only parts of the input and output offsets; they do not depend on k,
    # so they are computed once outside the reduction loop.
    # NOTE(review): offsets use whatever integer width Triton infers for the
    # scalar arguments; very large tensors may need an explicit int64 cast —
    # confirm against the expected input sizes.
    m_part = (b_m * stride_x_b
              + pt_m * kT * stride_x_t
              + ph_m * kH * stride_x_h
              + pw_m * kW * stride_x_w)
    out_m_part = (b_m * stride_out_b
                  + pt_m * stride_out_pt
                  + ph_m * stride_out_ph
                  + pw_m * stride_out_pw)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    kTkHkW = kT * kH * kW
    kHkW = kH * kW

    for k_start in range(0, K, BLOCK_K):
        k_off = k_start + tl.arange(0, BLOCK_K)

        # Split the flat reduction index into (channel, kt, kh, kw).
        kw_k = k_off % kW
        kh_k = (k_off // kW) % kH
        kt_k = (k_off // kHkW) % kT
        c_k = k_off // kTkHkW
        k_mask = k_off < K

        # Column-only part of the input offset.
        k_part = (c_k * stride_x_c
                  + kt_k * stride_x_t
                  + kh_k * stride_x_h
                  + kw_k * stride_x_w)

        # Broadcast row + column offsets into a (BLOCK_M, BLOCK_K) pointer tile.
        a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
        # Masks are booleans: combine them with a logical AND.
        a_mask = m_mask[:, None] & k_mask[None, :]
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # (BLOCK_K, BLOCK_N) weight tile, i.e. the transposed view of (N, K).
        w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
        w_mask = n_mask[None, :] & k_mask[:, None]
        w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')

        acc += tl.dot(a, w)

    out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
    out_mask = m_mask[:, None] & n_mask[None, :]
    # Cast to the destination element type instead of hard-coding bfloat16.
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=out_mask)
class Model(nn.Module):
    """Patch embedding computed by the fused ``patch_embed_kernel``.

    The module owns an ``nn.Conv3d`` purely as a container for the projection
    weight; ``forward`` reads ``self.conv.weight`` and launches the Triton
    kernel instead of calling the convolution.

    Args:
        B, C, T, H, W: Nominal input dimensions, stored as attributes.
            ``forward`` takes the actual sizes from the input tensor.
        kT, kH, kW: Patch size along t / h / w (kernel size and stride of the
            weight-holding convolution).
        embed_dim: Number of output channels.
    """

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        super().__init__()
        self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
        self.kT, self.kH, self.kW = kT, kH, kW
        self.embed_dim = embed_dim
        # Parameter container only: forward() never calls this module.
        self.conv = nn.Conv3d(
            C, embed_dim,
            kernel_size=(kT, kH, kW),
            stride=(kT, kH, kW),
            bias=False,
            dtype=torch.bfloat16,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project every (kT, kH, kW) patch of ``x`` to ``embed_dim`` channels.

        Args:
            x: 5-D tensor unpacked as (B, C, T, H, W). Arbitrary strides are
                supported because they are forwarded to the kernel.

        Returns:
            A bfloat16 tensor of shape
            (B, embed_dim, T // kT, H // kH, W // kW) on ``x.device``.
        """
        B, C, T, H, W = x.shape
        kT, kH, kW = self.kT, self.kH, self.kW
        embed_dim = self.embed_dim

        nPT = T // kT
        nPH = H // kH
        nPW = W // kW

        # GEMM view of the problem: (M_total, K) @ (K, N).
        M_total = B * nPT * nPH * nPW
        K = C * kT * kH * kW
        N = embed_dim

        out = torch.empty(B, embed_dim, nPT, nPH, nPW,
                          dtype=torch.bfloat16, device=x.device)
        if out.numel() == 0:
            # Nothing to compute; skip autotuning and the kernel launch.
            return out

        # The kernel addresses the weight as w_ptr + n * K + k, so it must be
        # dense and row-major. contiguous() returns the same tensor when that
        # already holds.
        weight = self.conv.weight.contiguous()

        def grid(meta):
            # One program per (BLOCK_M, BLOCK_N) output tile, flattened to 1D.
            return (triton.cdiv(M_total, meta['BLOCK_M'])
                    * triton.cdiv(N, meta['BLOCK_N']),)

        patch_embed_kernel[grid](
            x, weight, out,
            nPT, nPH, nPW,
            *x.stride(),
            *out.stride(),
            K, N,
            kT, kH, kW,
            M_total,
        )
        return out
# Default problem configuration read by get_inputs() and get_init_inputs().
# Input dimensions: batch, channels, frames, height, width.
B, C, T, H, W = 1, 3, 2, 224, 224
# Patch extent along frames, height, width.
kT, kH, kW = 2, 14, 14
# Number of output channels of the projection.
embed_dim = 1280
def get_inputs():
    """Return the forward-pass arguments: one random bfloat16 input tensor.

    The tensor has shape (B, C, T, H, W) taken from the module-level
    defaults and is a standard-normal sample scaled by 0.5.
    """
    shape = (B, C, T, H, W)
    sample = torch.randn(*shape, dtype=torch.bfloat16)
    return [sample * 0.5]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    Order: B, C, T, H, W, kT, kH, kW, embed_dim (module-level defaults).
    """
    input_dims = [B, C, T, H, W]
    patch_dims = [kT, kH, kW]
    return input_dims + patch_dims + [embed_dim]
shape=0 variant=eager tflops=3.686 gbps=20.411 ms=0.209
shape=0 variant=compiled tflops=3.110 gbps=17.225 ms=0.248
shape=0 variant=sota tflops=3.731 gbps=20.661 ms=0.207
shape=0 variant=solution tflops=16.725 gbps=92.622 ms=0.046
shape=0 solution_peak_fraction=0.0836
shape=1 variant=eager tflops=14.218 gbps=37.081 ms=0.217
shape=1 variant=compiled tflops=12.598 gbps=32.858 ms=0.245
shape=1 variant=sota tflops=14.360 gbps=37.452 ms=0.215
shape=1 variant=solution tflops=62.314 gbps=162.525 ms=0.049
shape=1 solution_peak_fraction=0.3116
shape=2 variant=eager tflops=26.168 gbps=54.053 ms=0.265
shape=2 variant=compiled tflops=23.402 gbps=48.339 ms=0.296
shape=2 variant=sota tflops=26.858 gbps=55.479 ms=0.258
shape=2 variant=solution tflops=54.925 gbps=113.453 ms=0.126
shape=2 solution_peak_fraction=0.2746
shape=3 variant=eager tflops=6.435 gbps=24.964 ms=0.144
shape=3 variant=compiled tflops=5.082 gbps=19.717 ms=0.182
shape=3 variant=sota tflops=6.567 gbps=25.478 ms=0.141
shape=3 variant=solution tflops=30.439 gbps=118.092 ms=0.030
shape=3 solution_peak_fraction=0.1522
peak_fraction: 0.1817
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1817,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 138002, "output_tokens": 16156, "cache_read_tokens": 2327296, "cache_creation_tokens": 0, "reasoning_tokens": 74013, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem files",
"status": "in_progress",
"priority": "high"
},
{
"content": "Load kernel optimization skill for SM120/Blackwell guidance",
"status": "pending",
"priority": "high"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding",
"status": "pending",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and
... (221 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem files'
status: 'in_progress'
priority: 'high'
content: 'Load kernel optimization skill for SM120/Blackwell guidance'
status: 'pending'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Benchmark with benchmark.py and iterate'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem files",
"priority": "high",
"status": "completed"
},
{
"content": "Load kernel optimization skill for SM120/Blackwell guidance",
"priority": "high",
"status": "in_progress"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding",
"priority": "high",
"status": "pending"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Profile a
... (223 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem files'
status: 'completed'
priority: 'high'
content: 'Load kernel optimization skill for SM120/Blackwell guidance'
status: 'in_progress'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Benchmark with benchmark.py and iterate'
status: 'pending'
priority: 'medium'{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars){
"command": "nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader 2>/dev/null; echo \"---\"; uv run python -c \"import triton; print('triton', triton.__version__)\" 2>/dev/null; echo \"---\"; uv run python -c \"import torch; print('torch', torch.__version__); print('cuda', torch.version.cuda)\" 2>/dev/null; echo \"---\"; nvcc --version 2>/dev/null | tail -1",
"description": "Check GPU, Triton, PyTorch, CUDA versions",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0, 97887 MiB
---
triton 3.6.0
---
torch 2.11.0+cu130
cuda 13.0
---
Build cuda_13.2.r13.2/compiler.37434383_0
{
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md"
}{
"filePath": "/home/infatoshi/.claude/skills/perf-optim/SKILL.md",
"limit": 200
}{
"command": "head -300 /home/infatoshi/.claude/skills/perf-optim/SKILL.md",
"description": "Read perf-optim skill reference"
}---
name: perf-optim
description: "[SUB-REFERENCE of kernel skill] Performance optimization. Do NOT invoke directly -- use /kernel instead, which reads this file on demand."
---
# Performance Optimization Atlas
Comprehensive, hierarchical reference for optimizing GPU workloads. Pick combinations from this atlas based on your bottleneck profile. **Always profile before optimizing.**
---
## 0. FIRST PRINCIPLES: SPEED OF LIGHT, THEN PROFILE
### Rule
Never optimize blind. But before you even profile, compute the **speed of light** -- the theoretical physical limit for this operation. Then profile to measure the gap. The gap IS the optimization opportunity.
### Speed of Light Analysis
1. **Compute the limit**: What does physics allow? Peak FLOPS, peak bandwidth, minimum latency for this exact workload
2. **Profile current state**: Where are we now?
3. **Measure the gap**: speed_of_light - measured = optimization budget
4. **Don't continuous-improve from current state**: Strip to zero. If you built this from scratch knowing only the physical limits, what would you build? Then reason about every piece of the gap between that and current
- "It takes 74ms" -- don't optimize to 72ms. Compute that physics allows 6ms. Now reason about the 6-to-74 gap. The rest might be well-reasoned compromises, but at least you know what they are
### NVIDIA Profiling Tools
| Tool | What it tells you | When to use |
|------|------------------|-------------|
| `nvidia-smi` | GPU util, memory, te
... (17269 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem files",
"priority": "high",
"status": "completed"
},
{
"content": "Load kernel optimization skill for SM120/Blackwell guidance",
"priority": "high",
"status": "completed"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Profile
... (225 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem files'
status: 'completed'
priority: 'high'
content: 'Load kernel optimization skill for SM120/Blackwell guidance'
status: 'completed'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Benchmark with benchmark.py and iterate'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,158 @@
+"""Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
+
+Uses Triton with tl.dot to fuse the strided patch gather into the GEMM K-loop,
+avoiding materialization of the im2col buffer. The patch-gather offsets decompose
+into a rank-1 sum of m-dependent and k-dependent terms, enabling efficient
+2D pointer arithmetic.
+
+Output layout: (B, embed_dim, T//kT, H//kH, W//kW). No bias.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ nPT, nPH, nPW,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
+ K, N,
+ kT, kH, kW,
+ M_total,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_blocks = tl.cdiv(N, BLOCK_N)
+ pid_m = pid // num_n_blocks
+ pid_n = pid % num_n_blocks
+
+ m_start = pid_m * BLOCK_M
+ n_start = pid_n * BLOCK_N
+
+ m_off = m_start + tl.arange(0, BLOCK_M)
+ n_off = n_start + tl.arange(0, BLOCK_N)
+
+ patches_per_batch = nPT * nPH * nPW
+ pw_m = m_off % nPW
+ ph_m = (m_off // nPW) % nPH
+ pt_m = (m_off // (nPW * nPH)) % nPT
+ b_m = m_off // patches_per_batch
+
+ m_mask = m_off < M_total
+ n_mask = n_off < N
+
+ m_part = b_m * stride_x_b + pt_m * kT * stride_x_t + ph_m * kH * stride_x_h + pw_m * kW * stride_x_w
+
+ out_m_part = b_m * stride_out_b + pt_m * stride_out_pt + ph_m * stride_out_ph + pw_m * stride_out_pw
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ kTkHkW = kT * kH * kW
+ kHkW = kH * kW
+
+ for k_start in range(0, K, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+
+ kw_k = k_off % kW
+ kh_k = (k_off // kW) % kH
+ kt_k = (k_off // kHkW) % kT
+ c_k = k_off // kTkHkW
+
+ k_mask = k_off < K
+
+ k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
+
+ a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
+ a_mask = m_mask[:, None] & k_mask[None, :]
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+
+ w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
+ w_mask = n_mask[None, :] & k_mask[:, None]
+ w = tl.load(w_ptrs, mask=w_mask, other=0.0)
+
+ acc += tl.dot(a, w)
+
+ out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
+ out_mask = m_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+
+ self.conv = nn.Conv3d(
+ C, embed_dim,
+ kernel_size=(kT, kH, kW),
+ stride=(kT, kH, kW),
+ bias=False,
+ dtype=torch.bfloat16,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T, H, W = x.shape
+ kT, kH, kW = self.kT, self.kH, self.kW
+ embed_dim = self.embed_dim
+
+ nPT = T // kT
+ nPH = H // kH
+ nPW = W // kW
+
+ M_total = B * nPT * nPH * nPW
+ K = C * kT * kH * kW
+ N = embed_dim
+
+ out = torch.empty(B, embed_dim, nPT, nPH, nPW, dtype=torch.bfloat16, device=x.device)
+
+ stride_x = x.stride()
+ stride_out = out.stride()
+
+ BLOCK_M = 64
+ BLOCK_N = 128
+ BLOCK_K = 64
+
+ grid = (triton.cdiv(M_total, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+
+ patch_embed_kernel[grid](
+ x, self.conv.weight, out,
+ nPT, nPH, nPW,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+ K, N,
+ kT, kH, kW,
+ M_total,
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ num_warps=8,
+ num_stages=3,
+ )
+
+ return out
+
+
+B = 1
+C = 3
+T = 2
+H = 224
+W = 224
+kT = 2
+kH = 14
+kW = 14
+embed_dim = 1280
+
+
+def get_inputs():
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
+ return [x]
+
+
+def get_init_inputs():
+ return [B, C, T, H, W, kT, kH, kW, embed_dim]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<4395 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm",
"timeout": 120000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.608 gbps=19.983 ms=0.214
shape=0 variant=compiled tflops=3.563 gbps=19.730 ms=0.216
shape=0 variant=sota tflops=3.679 gbps=20.372 ms=0.210
shape=0 variant=solution tflops=11.601 gbps=64.247 ms=0.066
shape=0 solution_peak_fraction=0.0580
shape=1 variant=eager tflops=13.695 gbps=35.719 ms=0.225
shape=1 variant=compiled tflops=11.505 gbps=30.007 ms=0.268
shape=1 variant=sota tflops=13.884 gbps=36.210 ms=0.222
shape=1 variant=solution tflops=39.234 gbps=102.327 ms=0.079
shape=1 solution_peak_fraction=0.1962
shape=2 variant=eager tflops=25.272 gbps=52.203 ms=0.274
shape=2 variant=compiled tflops=22.317 gbps=46.097 ms=0.311
shape=2 variant=sota tflops=25.710 gbps=53.107 ms=0.270
shape=2 variant=solution tflops=49.477 gbps=102.201 ms=0.140
shape=2 solution_peak_fraction=0.2474
shape=3 variant=eager tflops=6.265 gbps=24.307 ms=0.148
shape=3 variant=compiled tflops=6.174 gbps=23.954 ms=0.150
shape=3 variant=sota tflops=6.756 gbps=26.210 ms=0.137
shape=3 variant=solution tflops=64.297 gbps=249.450 ms=0.014
shape=3 solution_peak_fraction=0.3215
peak_fraction: 0.1734
RESULT: OK
{
"command": "<185 chars \u2014 see below>",
"description": "Check hardware peak specs",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
from src.hardware import get
hw = get('RTX_PRO_6000')
print('peak bf16:', hw.peak_tflops_dense.get('bf16', 0))
print('peak bandwidth:', hw.peak_bandwidth_gb_s)
" 2>&1peak bf16: 200.0
peak bandwidth: 1800.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -13,6 +13,21 @@
import triton.language as tl
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
+ ],
+ key=['K', 'N', 'M_total'],
+)
@triton.jit
def patch_embed_kernel(
x_ptr, w_ptr, out_ptr,
@@ -73,7 +88,7 @@
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
w_mask = n_mask[None, :] & k_mask[:, None]
- w = tl.load(w_ptrs, mask=w_mask, other=0.0)
+ w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
acc += tl.dot(a, w)
@@ -116,11 +131,7 @@
stride_x = x.stride()
stride_out = out.stride()
- BLOCK_M = 64
- BLOCK_N = 128
- BLOCK_K = 64
-
- grid = (triton.cdiv(M_total, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+ grid = (triton.cdiv(M_total, 64) * triton.cdiv(N, 128),)
patch_embed_kernel[grid](
x, self.conv.weight, out,
@@ -130,9 +141,6 @@
K, N,
kT, kH, kW,
M_total,
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
- num_warps=8,
- num_stages=3,
)
return out
{
"content": "<5292 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check with autotune",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}FAIL: shape 0 {'B': 1, 'C': 3, 'T': 2, 'H': 224, 'W': 224, 'kT': 2, 'kH': 14, 'kW': 14, 'embed_dim': 1280} seed 123: tolerance exceeded: max_abs_diff=1.53906 (atol=0.02, rtol=0.02)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
1: """Fused patch-gather + tensor-core GEMM kernel for 3D patch embedding.
2:
3: Uses Triton with tl.dot to fuse the strided patch gather into the GEMM K-loop,
4: avoiding materialization of the im2col buffer. The patch-gather offsets decompose
5: into a rank-1 sum of m-dependent and k-dependent terms, enabling efficient
6: 2D pointer arithmetic.
7:
8: Output layout: (B, embed_dim, T//kT, H//kH, W//kW). No bias.
9: """
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14:
15:
16: @triton.autotune(
17: configs=[
18: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
19: triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
20: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
21: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
22: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
23: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
24: triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
25: triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num
... (4684 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -131,7 +131,7 @@
stride_x = x.stride()
stride_out = out.stride()
- grid = (triton.cdiv(M_total, 64) * triton.cdiv(N, 128),)
+ grid = lambda META: (triton.cdiv(M_total, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
patch_embed_kernel[grid](
x, self.conv.weight, out,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<99 chars \u2014 see below>",
"newString": "<137 chars \u2014 see below>"
} grid = (triton.cdiv(M_total, 64) * triton.cdiv(N, 128),)
patch_embed_kernel[grid]( grid = lambda META: (triton.cdiv(M_total, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
patch_embed_kernel[grid](Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check with grid lambda",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark with autotuning",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.689 gbps=20.428 ms=0.209
shape=0 variant=compiled tflops=3.188 gbps=17.653 ms=0.242
shape=0 variant=sota tflops=3.728 gbps=20.643 ms=0.207
shape=0 variant=solution tflops=17.112 gbps=94.761 ms=0.045
shape=0 solution_peak_fraction=0.0856
shape=1 variant=eager tflops=14.224 gbps=37.098 ms=0.217
shape=1 variant=compiled tflops=12.596 gbps=32.851 ms=0.245
shape=1 variant=sota tflops=14.401 gbps=37.561 ms=0.214
shape=1 variant=solution tflops=63.694 gbps=166.125 ms=0.048
shape=1 solution_peak_fraction=0.3185
shape=2 variant=eager tflops=26.261 gbps=54.246 ms=0.264
shape=2 variant=compiled tflops=22.473 gbps=46.420 ms=0.309
shape=2 variant=sota tflops=26.429 gbps=54.593 ms=0.262
shape=2 variant=solution tflops=58.902 gbps=121.670 ms=0.118
shape=2 solution_peak_fraction=0.2945
shape=3 variant=eager tflops=6.498 gbps=25.209 ms=0.142
shape=3 variant=compiled tflops=5.135 gbps=19.921 ms=0.180
shape=3 variant=sota tflops=6.598 gbps=25.600 ms=0.140
shape=3 variant=solution tflops=29.090 gbps=112.862 ms=0.032
shape=3 solution_peak_fraction=0.1455
peak_fraction: 0.1848
RESULT: OK
--- /dev/null
+++ b//tmp/opencode/bench_configs.py
@@ -0,0 +1,157 @@
+"""Quick benchmark to test different Triton configs."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+import time
+import sys
+sys.path.insert(0, '..')
+from problems.10_patch_embed_conv3d_gemm.shapes import SHAPES
+
+
+@triton.jit
+def patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ nPT, nPH, nPW,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
+ K, N,
+ kT, kH, kW,
+ M_total,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_blocks = tl.cdiv(N, BLOCK_N)
+ pid_m = pid // num_n_blocks
+ pid_n = pid % num_n_blocks
+
+ m_off = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ n_off = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ patches_per_batch = nPT * nPH * nPW
+ pw_m = m_off % nPW
+ ph_m = (m_off // nPW) % nPH
+ pt_m = (m_off // (nPW * nPH)) % nPT
+ b_m = m_off // patches_per_batch
+
+ m_mask = m_off < M_total
+ n_mask = n_off < N
+
+ m_part = b_m * stride_x_b + pt_m * kT * stride_x_t + ph_m * kH * stride_x_h + pw_m * kW * stride_x_w
+ out_m_part = b_m * stride_out_b + pt_m * stride_out_pt + ph_m * stride_out_ph + pw_m * stride_out_pw
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ kTkHkW = kT * kH * kW
+ kHkW = kH * kW
+
+ for k_start in range(0, K, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+ kw_k = k_off % kW
+ kh_k = (k_off // kW) % kH
+ kt_k = (k_off // kHkW) % kT
+ c_k = k_off // kTkHkW
+ k_mask = k_off < K
+
+ k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
+ a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
+ a_mask = m_mask[:, None] & k_mask[None, :]
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+
+ w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
+ w_mask = n_mask[None, :] & k_mask[:, None]
+ w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
+
+ acc += tl.dot(a, w)
+
+ out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
+ out_mask = m_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+def benchmark_config(shape, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, n_iter=50):
+ device = torch.device("cuda:0")
+ B, C, T, H, W = shape["B"], shape["C"], shape["T"], shape["H"], shape["W"]
+ kT, kH, kW = shape["kT"], shape["kH"], shape["kW"]
+ embed_dim = shape["embed_dim"]
+
+ nPT, nPH, nPW = T // kT, H // kH, W // kW
+ M_total = B * nPT * nPH * nPW
+ K = C * kT * kH * kW
+ N = embed_dim
+
+ conv = nn.Conv3d(C, embed_dim, kernel_size=(kT, kH, kW), stride=(kT, kH, kW), bias=False, dtype=torch.bfloat16).to(device)
+
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device=device) * 0.5
+ out = torch.empty(B, embed_dim, nPT, nPH, nPW, dtype=torch.bfloat16, device=device)
+
+ stride_x = x.stride()
+ stride_out = out.stride()
+
+ grid = (triton.cdiv(M_total, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+
+ def run():
+ patch_embed_kernel[grid](
+ x, conv.weight, out,
+ nPT, nPH, nPW,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+ K, N, kT, kH, kW, M_total,
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+
+ for _ in range(5):
+ run()
+ torch.cuda.synchronize()
+
+ t0 = time.perf_counter()
+ for _ in range(n_iter):
+ run()
+ torch.cuda.synchronize()
+ t1 = time.perf_counter()
+
+ ms = (t1 - t0) / n_iter * 1000
+ flops = 2 * B * embed_dim * C * T * H * W
+ tflops = flops / ms / 1e9
+ peak_frac = tflops / 200.0
+ return ms, tflops, peak_frac
+
+
+configs = [
+ (64, 128, 64, 8, 3),
+ (64, 128, 64, 8, 4),
+ (64, 128, 64, 8, 5),
+ (64, 128, 128, 8, 3),
+ (64, 128, 128, 8, 4),
+ (64, 256, 64, 8, 3),
+ (64, 256, 64, 8, 4),
+ (128, 128, 64, 8, 3),
+ (128, 128, 64, 8, 4),
+ (128, 256, 64, 8, 3),
+ (128, 256, 128, 8, 3),
+ (32, 128, 64, 4, 4),
+ (32, 256, 64, 8, 3),
+ (32, 256, 128, 8, 3),
+]
+
+for si, shape in enumerate(SHAPES):
+ print(f"\n=== Shape {si}: B={shape['B']} T={shape['T']} H={shape['H']} W={shape['W']} kT={shape['kT']} kH={shape['kH']} kW={shape['kW']} emb={shape['embed_dim']} ===")
+ M = shape['B'] * (shape['T']//shape['kT']) * (shape['H']//shape['kH']) * (shape['W']//shape['kW'])
+ K = shape['C'] * shape['kT'] * shape['kH'] * shape['kW']
+ N = shape['embed_dim']
+ print(f" M={M} K={K} N={N}")
+ best_tflops = 0
+ best_cfg = None
+ for BM, BN, BK, nw, ns in configs:
+ try:
+ ms, tflops, pf = benchmark_config(shape, BM, BN, BK, nw, ns, n_iter=30)
+ if tflops > best_tflops:
+ best_tflops = tflops
+ best_cfg = (BM, BN, BK, nw, ns)
+ if tflops > 30:
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {tflops:6.1f} TF ({pf:.1%}) {ms:.3f}ms")
+ except Exception as e:
+ pass
+ print(f" BEST: {best_cfg} -> {best_tflops:.1f} TF ({best_tflops/200:.1%})")
{
"content": "<5315 chars \u2014 see diff>",
"filePath": "/tmp/opencode/bench_configs.py"
}{
"command": "uv run python /tmp/opencode/bench_configs.py 2>&1",
"description": "Benchmark different Triton configs",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/bin/python3: can't open file '/tmp/opencode/bench_configs.py': [Errno 2] No such file or directory
{
"command": "ls /tmp/opencode/",
"description": "Check tmp directory"
}(no output){
"command": "mkdir -p /tmp/opencode",
"description": "Create temp directory"
}--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py
@@ -0,0 +1,164 @@
+"""Quick benchmark to test different Triton configs."""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+import time
+import sys
+
+sys.path.insert(0, '.')
+from src.eval.correctness import check_correctness
+
+
+@triton.jit
+def patch_embed_kernel(
+ x_ptr, w_ptr, out_ptr,
+ nPT, nPH, nPW,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
+ K, N,
+ kT, kH, kW,
+ M_total,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_n_blocks = tl.cdiv(N, BLOCK_N)
+ pid_m = pid // num_n_blocks
+ pid_n = pid % num_n_blocks
+
+ m_off = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ n_off = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ patches_per_batch = nPT * nPH * nPW
+ pw_m = m_off % nPW
+ ph_m = (m_off // nPW) % nPH
+ pt_m = (m_off // (nPW * nPH)) % nPT
+ b_m = m_off // patches_per_batch
+
+ m_mask = m_off < M_total
+ n_mask = n_off < N
+
+ m_part = b_m * stride_x_b + pt_m * kT * stride_x_t + ph_m * kH * stride_x_h + pw_m * kW * stride_x_w
+ out_m_part = b_m * stride_out_b + pt_m * stride_out_pt + ph_m * stride_out_ph + pw_m * stride_out_pw
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ kTkHkW = kT * kH * kW
+ kHkW = kH * kW
+
+ for k_start in range(0, K, BLOCK_K):
+ k_off = k_start + tl.arange(0, BLOCK_K)
+ kw_k = k_off % kW
+ kh_k = (k_off // kW) % kH
+ kt_k = (k_off // kHkW) % kT
+ c_k = k_off // kTkHkW
+ k_mask = k_off < K
+
+ k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
+ a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
+ a_mask = m_mask[:, None] & k_mask[None, :]
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+
+ w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
+ w_mask = n_mask[None, :] & k_mask[None, :]
+ w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
+
+ acc += tl.dot(a, w)
+
+ out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
+ out_mask = m_mask[:, None] & n_mask[None, :]
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
+
+
+SHAPES = [
+ {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
+ {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
+ {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
+ {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
+]
+
+
+def benchmark_config(shape, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, n_iter=30):
+ device = torch.device("cuda:0")
+ B, C, T, H, W = shape["B"], shape["C"], shape["T"], shape["H"], shape["W"]
+ kT, kH, kW = shape["kT"], shape["kH"], shape["kW"]
+ embed_dim = shape["embed_dim"]
+
+ nPT, nPH, nPW = T // kT, H // kH, W // kW
+ M_total = B * nPT * nPH * nPW
+ K = C * kT * kH * kW
+ N = embed_dim
+
+ conv = nn.Conv3d(C, embed_dim, kernel_size=(kT, kH, kW), stride=(kT, kH, kW), bias=False, dtype=torch.bfloat16).to(device)
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16, device=device) * 0.5
+ out = torch.empty(B, embed_dim, nPT, nPH, nPW, dtype=torch.bfloat16, device=device)
+
+ stride_x = x.stride()
+ stride_out = out.stride()
+
+ grid = (triton.cdiv(M_total, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+
+ def run():
+ patch_embed_kernel[grid](
+ x, conv.weight, out,
+ nPT, nPH, nPW,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+ K, N, kT, kH, kW, M_total,
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+ num_warps=num_warps, num_stages=num_stages,
+ )
+
+ for _ in range(5):
+ run()
+ torch.cuda.synchronize()
+
+ t0 = time.perf_counter()
+ for _ in range(n_iter):
+ run()
+ torch.cuda.synchronize()
+ t1 = time.perf_counter()
+
+ ms = (t1 - t0) / n_iter * 1000
+ flops = 2 * B * embed_dim * C * T * H * W
+ tflops = flops / ms / 1e9
+ peak_frac = tflops / 200.0
+ return ms, tflops, peak_frac
+
+
+configs = [
+ (64, 128, 64, 8, 3),
+ (64, 128, 64, 8, 4),
+ (64, 128, 64, 8, 5),
+ (64, 128, 128, 8, 3),
+ (64, 128, 128, 8, 4),
+ (64, 256, 64, 8, 3),
+ (64, 256, 64, 8, 4),
+ (128, 128, 64, 8, 3),
+ (128, 128, 64, 8, 4),
+ (128, 256, 64, 8, 3),
+ (128, 256, 128, 8, 3),
+ (32, 128, 64, 4, 4),
+ (32, 256, 64, 8, 3),
+ (32, 256, 128, 8, 3),
+]
+
+for si, shape in enumerate(SHAPES):
+ print(f"\n=== Shape {si}: B={shape['B']} T={shape['T']} H={shape['H']} W={shape['W']} kT={shape['kT']} kH={shape['kH']} kW={shape['kW']} emb={shape['embed_dim']} ===")
+ M = shape['B'] * (shape['T']//shape['kT']) * (shape['H']//shape['kH']) * (shape['W']//shape['kW'])
+ K = shape['C'] * shape['kT'] * shape['kH'] * shape['kW']
+ N = shape['embed_dim']
+ print(f" M={M} K={K} N={N}")
+ best_tflops = 0
+ best_cfg = None
+ for BM, BN, BK, nw, ns in configs:
+ try:
+ ms, tflops, pf = benchmark_config(shape, BM, BN, BK, nw, ns, n_iter=30)
+ if tflops > best_tflops:
+ best_tflops = tflops
+ best_cfg = (BM, BN, BK, nw, ns)
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {tflops:6.1f} TF ({pf:.1%}) {ms:.3f}ms")
+ except Exception as e:
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED {e}")
+ print(f" BEST: {best_cfg} -> {best
... (42 more chars){
"content": "<5749 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py"
}Wrote file successfully.{
"command": "uv run python problems/10_patch_embed_conv3d_gemm/bench_configs.py 2>&1",
"description": "Benchmark different Triton configs",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}=== Shape 0: B=1 T=2 H=224 W=224 kT=2 kH=14 kW=14 emb=1280 ===
M=256 K=1176 N=1280
BM= 64 BN=128 BK= 64 nw=8 ns=3: FAILED at 51:17:
kh_k = (k_off // kW) % kH
kt_k = (k_off // kHkW) % kT
c_k = k_off // kTkHkW
k_mask = k_off < K
k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
a_mask = m_mask[:, None] & k_mask[None, :]
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
w_mask = n_mask[None, :] & k_mask[None, :]
^
ValueError('Cannot make_shape_compatible: incompatible dimensions at index 1: 128 and 64')
BM= 64 BN=128 BK= 64 nw=8 ns=4: FAILED at 51:17:
kh_k = (k_off // kW) % kH
kt_k = (k_off // kHkW) % kT
c_k = k_off // kTkHkW
k_mask = k_off < K
k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
a_mask = m_mask[:, None] & k_mask[None, :]
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
w_mask = n_mask[None, :] & k_mask[None, :]
^
ValueError('Cannot make_shape_compatible: incompatible dimensions at index 1: 128 and 64')
BM= 64 BN=128 BK= 64 nw=8 ns=5: FAILED at 51:17:
kh_k = (k_off // kW) % kH
kt_k =
... (31209 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, and problem files",
"priority": "high",
"status": "completed"
},
{
"content": "Load kernel optimization skill for SM120/Blackwell guidance",
"priority": "high",
"status": "completed"
},
{
"content": "Implement custom CUDA kernel for 3D patch embedding",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Profile
... (225 more chars)content: 'Read reference.py, check.py, benchmark.py, and problem files'
status: 'completed'
priority: 'high'
content: 'Load kernel optimization skill for SM120/Blackwell guidance'
status: 'completed'
priority: 'high'
content: 'Implement custom CUDA kernel for 3D patch embedding'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Benchmark with benchmark.py and iterate'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -15,16 +15,16 @@
@triton.autotune(
configs=[
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=2),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_warps=8, num_stages=2),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_warps=8, num_stages=3),
],
key=['K', 'N', 'M_total'],
)
@@ -83,17 +83,17 @@
k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
- a_mask = m_mask[:, None] & k_mask[None, :]
+ a_mask = m_mask[:, None] * k_mask[None, :]
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
- w_mask = n_mask[None, :] & k_mask[:, None]
+ w_mask = n_mask[None, :] * k_mask[:, None]
w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
acc += tl.dot(a, w)
out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
- out_mask = m_mask[:, None] & n_mask[None, :]
+ out_mask = m_mask[:, None] * n_mask[None, :]
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
{
"content": "<5335 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check with * mask",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark with * mask and autotune",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}shape=0 variant=eager tflops=3.688 gbps=20.425 ms=0.209
shape=0 variant=compiled tflops=3.103 gbps=17.183 ms=0.248
shape=0 variant=sota tflops=3.687 gbps=20.419 ms=0.209
shape=0 variant=solution tflops=16.807 gbps=93.075 ms=0.046
shape=0 solution_peak_fraction=0.0840
shape=1 variant=eager tflops=14.257 gbps=37.186 ms=0.216
shape=1 variant=compiled tflops=12.510 gbps=32.627 ms=0.246
shape=1 variant=sota tflops=14.409 gbps=37.581 ms=0.214
shape=1 variant=solution tflops=58.105 gbps=151.546 ms=0.053
shape=1 solution_peak_fraction=0.2905
shape=2 variant=eager tflops=26.195 gbps=54.108 ms=0.265
shape=2 variant=compiled tflops=23.417 gbps=48.371 ms=0.296
shape=2 variant=sota tflops=26.857 gbps=55.476 ms=0.258
shape=2 variant=solution tflops=54.333 gbps=112.231 ms=0.128
shape=2 solution_peak_fraction=0.2717
shape=3 variant=eager tflops=6.499 gbps=25.214 ms=0.142
shape=3 variant=compiled tflops=5.212 gbps=20.221 ms=0.177
shape=3 variant=sota tflops=6.628 gbps=25.714 ms=0.140
shape=3 variant=solution tflops=30.090 gbps=116.739 ms=0.031
shape=3 solution_peak_fraction=0.1504
peak_fraction: 0.1777
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py
@@ -1,13 +1,9 @@
-"""Quick benchmark to test different Triton configs."""
+"""Quick benchmark to test different Triton configs with & mask."""
import torch
import torch.nn as nn
import triton
import triton.language as tl
import time
-import sys
-
-sys.path.insert(0, '.')
-from src.eval.correctness import check_correctness
@triton.jit
@@ -61,7 +57,7 @@
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
- w_mask = n_mask[None, :] & k_mask[None, :]
+ w_mask = n_mask[None, :] & k_mask[:, None]
w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
acc += tl.dot(a, w)
@@ -127,21 +123,21 @@
return ms, tflops, peak_frac
+# Only test configs that work with & mask (BLOCK_K=128 based)
configs = [
- (64, 128, 64, 8, 3),
- (64, 128, 64, 8, 4),
- (64, 128, 64, 8, 5),
(64, 128, 128, 8, 3),
+ (64, 128, 128, 8, 2),
+ (32, 128, 128, 4, 3),
+ (32, 128, 128, 4, 2),
+ (128, 128, 128, 8, 3),
+ (128, 128, 128, 8, 2),
+ (64, 256, 128, 8, 3),
+ (64, 256, 128, 8, 2),
+ (64, 64, 128, 4, 3),
+ (32, 64, 128, 4, 3),
+ (32, 256, 128, 8, 2),
+ (128, 256, 128, 8, 2),
(64, 128, 128, 8, 4),
- (64, 256, 64, 8, 3),
- (64, 256, 64, 8, 4),
- (128, 128, 64, 8, 3),
- (128, 128, 64, 8, 4),
- (128, 256, 64, 8, 3),
- (128, 256, 128, 8, 3),
- (32, 128, 64, 4, 4),
- (32, 256, 64, 8, 3),
- (32, 256, 128, 8, 3),
]
for si, shape in enumerate(SHAPES):
@@ -160,5 +156,5 @@
best_cfg = (BM, BN, BK, nw, ns)
print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: {tflops:6.1f} TF ({pf:.1%}) {ms:.3f}ms")
except Exception as e:
- print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED {e}")
+ print(f" BM={BM:3d} BN={BN:3d} BK={BK:3d} nw={nw} ns={ns}: FAILED {str(e)[:80]}")
print(f" BEST: {best_cfg} -> {best_tflops:.1f} TF ({best_tflops/200:.1%})")
{
"content": "<5727 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py"
}Wrote file successfully.{
"command": "uv run python problems/10_patch_embed_conv3d_gemm/bench_configs.py 2>&1",
"description": "Benchmark BK=128 configs with & mask",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}=== Shape 0: B=1 T=2 H=224 W=224 kT=2 kH=14 kW=14 emb=1280 ===
M=256 K=1176 N=1280
BM= 64 BN=128 BK=128 nw=8 ns=3: 27.8 TF (13.9%) 0.028ms
BM= 64 BN=128 BK=128 nw=8 ns=2: 27.8 TF (13.9%) 0.028ms
BM= 32 BN=128 BK=128 nw=4 ns=3: 17.3 TF (8.6%) 0.045ms
BM= 32 BN=128 BK=128 nw=4 ns=2: 17.2 TF (8.6%) 0.045ms
BM=128 BN=128 BK=128 nw=8 ns=3: 13.9 TF (7.0%) 0.055ms
BM=128 BN=128 BK=128 nw=8 ns=2: 13.9 TF (7.0%) 0.055ms
BM= 64 BN=256 BK=128 nw=8 ns=3: 7.8 TF (3.9%) 0.098ms
BM= 64 BN=256 BK=128 nw=8 ns=2: 7.8 TF (3.9%) 0.098ms
BM= 64 BN= 64 BK=128 nw=4 ns=3: 6.6 TF (3.3%) 0.117ms
BM= 32 BN= 64 BK=128 nw=4 ns=3: 20.3 TF (10.1%) 0.038ms
BM= 32 BN=256 BK=128 nw=8 ns=2: 11.9 TF (6.0%) 0.065ms
BM=128 BN=256 BK=128 nw=8 ns=2: 4.2 TF (2.1%) 0.184ms
BM= 64 BN=128 BK=128 nw=8 ns=4: 27.8 TF (13.9%) 0.028ms
BEST: (64, 128, 128, 8, 2) -> 27.8 TF (13.9%)
=== Shape 1: B=2 T=4 H=224 W=224 kT=2 kH=14 kW=14 emb=1280 ===
M=1024 K=1176 N=1280
BM= 64 BN=128 BK=128 nw=8 ns=3: 107.1 TF (53.6%) 0.029ms
BM= 64 BN=128 BK=128 nw=8 ns=2: 107.0 TF (53.5%) 0.029ms
BM= 32 BN=128 BK=128 nw=4 ns=3: 50.0 TF (25.0%) 0.062ms
BM= 32 BN=128 BK=128 nw=4 ns=2: 50.1 TF (25.1%) 0.061ms
BM=128 BN=128 BK=128 nw=8 ns=3: 54.8 TF (27.4%) 0.056ms
BM=128 BN=128 BK=128 nw=8 ns=2: 54.8 TF (27.4%) 0.056ms
BM= 64 BN=256 BK=128 nw=8 ns=3: 30.3 TF (15.1%) 0.102ms
BM= 64 BN=256 BK=128 nw=8 ns=2: 30.3 TF (15.1%) 0.102ms
BM= 64 BN= 64 BK=128 nw=4 n
... (2330 more chars){
"command": "<5712 chars \u2014 see below>",
"description": "Compare & vs * mask performance",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard"
}uv run python -c "
import torch, torch.nn as nn, triton, triton.language as tl, time
@triton.jit
def kernel_and(x_ptr, w_ptr, out_ptr, nPT, nPH, nPW,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
K, N, kT, kH, kW, M_total,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
num_n_blocks = tl.cdiv(N, BLOCK_N)
pid_m = pid // num_n_blocks; pid_n = pid % num_n_blocks
m_off = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n_off = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ppb = nPT * nPH * nPW
pw_m = m_off % nPW; ph_m = (m_off // nPW) % nPH
pt_m = (m_off // (nPW * nPH)) % nPT; b_m = m_off // ppb
m_mask = m_off < M_total; n_mask = n_off < N
m_part = b_m * stride_x_b + pt_m * kT * stride_x_t + ph_m * kH * stride_x_h + pw_m * kW * stride_x_w
out_m_part = b_m * stride_out_b + pt_m * stride_out_pt + ph_m * stride_out_ph + pw_m * stride_out_pw
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
kTkHkW = kT * kH * kW; kHkW = kH * kW
for k_start in range(0, K, BLOCK_K):
k_off = k_start + tl.arange(0, BLOCK_K)
kw_k = k_off % kW; kh_k = (k_off // kW) % kH
kt_k = (k_off // kHkW) % kT; c_k = k_off // kTkHkW
k_mask = k_off < K
k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
a_mask = m_mask[:, None] & k_mask[None, :]
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
w_mask = n_mask[None, :] & k_mask[:, None]
w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
acc += tl.dot(a, w)
out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
out_mask = m_mask[:, None] & n_mask[None, :]
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
@triton.jit
def kernel_mul(x_ptr, w_ptr, out_ptr, nPT, nPH, nPW,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_out_b, stride_out_n, stride_out_pt, stride_out_ph, stride_out_pw,
K, N, kT, kH, kW, M_total,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
num_n_blocks = tl.cdiv(N, BLOCK_N)
pid_m = pid // num_n_blocks; pid_n = pid % num_n_blocks
m_off = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n_off = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ppb = nPT * nPH * nPW
pw_m = m_off % nPW; ph_m = (m_off // nPW) % nPH
pt_m = (m_off // (nPW * nPH)) % nPT; b_m = m_off // ppb
m_mask = m_off < M_total; n_mask = n_off < N
m_part = b_m * stride_x_b + pt_m * kT * stride_x_t + ph_m * kH * stride_x_h + pw_m * kW * stride_x_w
out_m_part = b_m * stride_out_b + pt_m * stride_out_pt + ph_m * stride_out_ph + pw_m * stride_out_pw
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
kTkHkW = kT * kH * kW; kHkW = kH * kW
for k_start in range(0, K, BLOCK_K):
k_off = k_start + tl.arange(0, BLOCK_K)
kw_k = k_off % kW; kh_k = (k_off // kW) % kH
kt_k = (k_off // kHkW) % kT; c_k = k_off // kTkHkW
k_mask = k_off < K
k_part = c_k * stride_x_c + kt_k * stride_x_t + kh_k * stride_x_h + kw_k * stride_x_w
a_ptrs = x_ptr + m_part[:, None] + k_part[None, :]
a_mask = m_mask[:, None] * k_mask[None, :]
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
w_ptrs = w_ptr + n_off[None, :] * K + k_off[:, None]
w_mask = n_mask[None, :] * k_mask[:, None]
w = tl.load(w_ptrs, mask=w_mask, other=0.0, eviction_policy='evict_last')
acc += tl.dot(a, w)
out_ptrs = out_ptr + out_m_part[:, None] + n_off[None, :] * stride_out_n
out_mask = m_mask[:, None] * n_mask[None, :]
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask)
device = torch.device('cuda:0')
shape = {'B': 2, 'C': 3,
... (1712 more chars)Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 4, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __i
... (149 more chars)