import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
    ],
    key=['num_patches', 'K'],
)
@triton.jit
def _im2col_3d(
    x_ptr, buf_ptr,
    stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
    T_out, H_out, W_out, num_patches, K,
    kT, kH, kW,
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Gather one (BLOCK_M, BLOCK_K) tile of the patch matrix from x into buf.

    Row ``m`` of the patch matrix is decoded as (batch, out_t, out_h, out_w)
    and column ``k`` as (channel, dt, dh, dw). The element is read from ``x``
    at (batch, channel, out_t*kT + dt, out_h*kH + dh, out_w*kW + dw) through
    the five ``stride_x_*`` values and written, cast to bfloat16, to
    ``buf_ptr + m * K + k``. Tile elements with ``m >= num_patches`` or
    ``k >= K`` are neither read nor written.
    """
    # 1-D launch grid: the row-tile index varies fastest across program ids.
    pid = tl.program_id(0)
    tiles_along_m = tl.cdiv(num_patches, BLOCK_M)
    tile_m = pid % tiles_along_m
    tile_k = pid // tiles_along_m

    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_k * BLOCK_K + tl.arange(0, BLOCK_K)
    row_ok = rows < num_patches
    col_ok = cols < K
    # The same mask guards both the gather and the scatter.
    tile_mask = row_ok[:, None] & col_ok[None, :]

    # Row index -> (batch, out_t, out_h, out_w).
    row = rows[:, None]
    patches_per_batch = T_out * H_out * W_out
    patches_per_frame = H_out * W_out
    batch = row // patches_per_batch
    rem_batch = row - batch * patches_per_batch
    t_idx = rem_batch // patches_per_frame
    rem_frame = rem_batch - t_idx * patches_per_frame
    h_idx = rem_frame // W_out
    w_idx = rem_frame - h_idx * W_out

    # Column index -> (channel, dt, dh, dw).
    col = cols[None, :]
    taps_per_plane = kH * kW
    taps_per_channel = kT * taps_per_plane
    chan = col // taps_per_channel
    rem_chan = col - chan * taps_per_channel
    dt = rem_chan // taps_per_plane
    rem_plane = rem_chan - dt * taps_per_plane
    dh = rem_plane // kW
    dw = rem_plane - dh * kW

    src_ptrs = (x_ptr + batch * stride_x_b + chan * stride_x_c
                + (t_idx * kT + dt) * stride_x_t
                + (h_idx * kH + dh) * stride_x_h
                + (w_idx * kW + dw) * stride_x_w)
    vals = tl.load(src_ptrs, mask=tile_mask, other=0.0)

    dst_ptrs = buf_ptr + row * K + col
    tl.store(dst_ptrs, vals.to(tl.bfloat16), mask=tile_mask)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
    ],
    key=['num_patches', 'N', 'K'],
)
@triton.jit
def _gemm_patch_5d(
    a_ptr, w_ptr, out_ptr,
    stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
    T_out, H_out, W_out, num_patches, K, N,
    kT, kH, kW,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Multiply the patch matrix by the weight matrix and scatter into a 5-D output.

    ``a`` is addressed as ``a_ptr + m * K + k`` and ``w`` as
    ``w_ptr + k * N + n``. Each program accumulates one (BLOCK_M, BLOCK_N)
    tile in float32 over K in steps of BLOCK_K, then stores it as bfloat16 at
    (batch, n, out_t, out_h, out_w) through the ``stride_out_*`` values, where
    (batch, out_t, out_h, out_w) is decoded from the row index ``m``.
    ``kT``, ``kH`` and ``kW`` are accepted but not used in the body.
    """
    # 1-D launch grid: the row-tile index varies fastest across program ids.
    pid = tl.program_id(0)
    tiles_along_m = tl.cdiv(num_patches, BLOCK_M)
    tile_m = pid % tiles_along_m
    tile_n = pid // tiles_along_m

    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    row_ok = rows < num_patches
    col_ok = out_cols < N

    row = rows[:, None]
    col = out_cols[None, :]

    # Row index -> (batch, out_t, out_h, out_w) for the output scatter.
    patches_per_batch = T_out * H_out * W_out
    patches_per_frame = H_out * W_out
    batch = row // patches_per_batch
    rem_batch = row - batch * patches_per_batch
    t_idx = rem_batch // patches_per_frame
    rem_frame = rem_batch - t_idx * patches_per_frame
    h_idx = rem_frame // W_out
    w_idx = rem_frame - h_idx * W_out

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        ks = k0 + tl.arange(0, BLOCK_K)
        k_ok = ks < K
        a_mask = row_ok[:, None] & k_ok[None, :]
        w_mask = col_ok[None, :] & k_ok[:, None]
        # Masked-out elements load as zero, so a partial last K-step adds nothing.
        a_tile = tl.load(a_ptr + row * K + ks[None, :], mask=a_mask, other=0.0)
        w_tile = tl.load(w_ptr + ks[:, None] * N + col, mask=w_mask, other=0.0)
        acc += tl.dot(a_tile.to(tl.bfloat16), w_tile.to(tl.bfloat16))

    dst_ptrs = (out_ptr + batch * stride_out_b + col * stride_out_d
                + t_idx * stride_out_t + h_idx * stride_out_h + w_idx * stride_out_w)
    tl.store(dst_ptrs, acc.to(tl.bfloat16), mask=row_ok[:, None] & col_ok[None, :])
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
    ],
    key=['num_patches', 'N', 'K'],
)
@triton.jit
def _fused_patch_embed(
    x_ptr, w_ptr, out_ptr,
    stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
    stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
    T_out, H_out, W_out, num_patches, K, N,
    kT, kH, kW,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Gather patches from x and project them in a single pass.

    For each K-step the kernel loads a (BLOCK_K, BLOCK_N) weight tile from
    ``w_ptr + k + n * K`` and gathers a (BLOCK_M, BLOCK_K) input tile from
    ``x`` at (batch, channel, out_t*kT + dt, out_h*kH + dh, out_w*kW + dw)
    through the ``stride_x_*`` values. Products are accumulated in float32
    and the tile is stored as bfloat16 at (batch, n, out_t, out_h, out_w)
    through the ``stride_out_*`` values.
    """
    # 1-D launch grid: the row-tile index varies fastest across program ids.
    pid = tl.program_id(0)
    tiles_along_m = tl.cdiv(num_patches, BLOCK_M)
    tile_m = pid % tiles_along_m
    tile_n = pid // tiles_along_m

    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    row_ok = rows < num_patches
    col_ok = out_cols < N

    row = rows[:, None]
    col = out_cols[None, :]

    # Row index -> (batch, out_t, out_h, out_w); shared by the gather and the scatter.
    patches_per_batch = T_out * H_out * W_out
    patches_per_frame = H_out * W_out
    batch = row // patches_per_batch
    rem_batch = row - batch * patches_per_batch
    t_idx = rem_batch // patches_per_frame
    rem_frame = rem_batch - t_idx * patches_per_frame
    h_idx = rem_frame // W_out
    w_idx = rem_frame - h_idx * W_out

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    taps_per_plane = kH * kW
    taps_per_channel = kT * taps_per_plane
    for k0 in range(0, K, BLOCK_K):
        ks = k0 + tl.arange(0, BLOCK_K)
        k_ok = ks < K
        x_mask = row_ok[:, None] & k_ok[None, :]
        w_mask = col_ok[None, :] & k_ok[:, None]

        w_tile = tl.load(w_ptr + ks[:, None] + col * K, mask=w_mask, other=0.0)

        # Reduction index -> (channel, dt, dh, dw) inside one patch.
        k_row = ks[None, :]
        chan = k_row // taps_per_channel
        rem_chan = k_row - chan * taps_per_channel
        dt = rem_chan // taps_per_plane
        rem_plane = rem_chan - dt * taps_per_plane
        dh = rem_plane // kW
        dw = rem_plane - dh * kW

        x_tile = tl.load(x_ptr + batch * stride_x_b + chan * stride_x_c
                         + (t_idx * kT + dt) * stride_x_t
                         + (h_idx * kH + dh) * stride_x_h
                         + (w_idx * kW + dw) * stride_x_w, mask=x_mask, other=0.0)
        acc += tl.dot(x_tile.to(tl.bfloat16), w_tile.to(tl.bfloat16))

    dst_ptrs = (out_ptr + batch * stride_out_b + col * stride_out_d
                + t_idx * stride_out_t + h_idx * stride_out_h + w_idx * stride_out_w)
    tl.store(dst_ptrs, acc.to(tl.bfloat16), mask=row_ok[:, None] & col_ok[None, :])
class Model(nn.Module):
    """Patch embedding for (B, C, T, H, W) inputs using Triton kernels.

    The input is cut into non-overlapping (kT, kH, kW) patches and each patch
    is projected to ``embed_dim`` values with the weight stored at
    ``self.conv.weight`` (shape ``(embed_dim, C, kT, kH, kW)``, bfloat16).
    The output has shape ``(B, embed_dim, T // kT, H // kH, W // kW)``.
    """

    # Patch lengths K at or above this value take the two-pass path
    # (im2col + GEMM on a transposed weight); shorter ones use the fused kernel.
    # In the benchmarked shapes this separates K=1176 from K=768.
    _IM2COL_MIN_K = 1000

    def __init__(self, B: int, C: int, T: int, H: int, W: int,
                 kT: int, kH: int, kW: int, embed_dim: int):
        """Create the bfloat16 weight and prime the transposed-weight cache.

        Raises:
            AssertionError: if T, H or W is not divisible by its patch size.
        """
        super().__init__()
        assert T % kT == 0 and H % kH == 0 and W % kW == 0, (
            f"Input dims must be divisible by patch size: "
            f"T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
        )
        self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
        self.kT, self.kH, self.kW = kT, kH, kW
        self.embed_dim = embed_dim

        w = nn.Parameter(torch.empty(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16))
        nn.init.normal_(w, std=0.02)
        # The weight lives on a bare child module so it is reachable as
        # ``self.conv.weight``.
        conv = nn.Module()
        conv.register_parameter('weight', w)
        self.conv = conv

        # Cache of the (K, embed_dim) transposed weight used by the GEMM path,
        # plus the facts about the source weight it was built from.
        self._w_t = None
        self._weight_version = -1
        self._weight_ptr = -1
        self._transposed_weight(w, embed_dim, C * kT * kH * kW)

    def _transposed_weight(self, weight: torch.Tensor, N: int, K: int) -> torch.Tensor:
        """Return a contiguous (K, N) copy of ``weight``, rebuilding it when stale.

        The cache is treated as stale when the weight was modified in place
        (``_version`` changed), when its storage was replaced (``data_ptr``
        changed, e.g. by assigning ``weight.data``; module device/dtype
        conversions are expected to do this without changing ``_version``),
        or when the cached copy sits on another device or has another dtype.
        The copy is built from ``weight.detach()`` so it carries no autograd
        graph.
        """
        cached = self._w_t
        fresh = (
            cached is not None
            and weight._version == self._weight_version
            and weight.data_ptr() == self._weight_ptr
            and cached.device == weight.device
            and cached.dtype == weight.dtype
        )
        if not fresh:
            cached = weight.detach().reshape(N, K).t().contiguous()
            self._w_t = cached
            self._weight_version = weight._version
            self._weight_ptr = weight.data_ptr()
        return cached

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project every patch of ``x`` and return a bfloat16 tensor on ``x.device``.

        Args:
            x: 5-D tensor unpacked as (B, C, T, H, W); it is addressed through
                its strides, so it does not have to be contiguous.

        Returns:
            Tensor of shape ``(B, embed_dim, T // kT, H // kH, W // kW)``.
        """
        B, C, T, H, W = x.shape
        kT, kH, kW = self.kT, self.kH, self.kW
        N = self.embed_dim
        weight = self.conv.weight

        T_out, H_out, W_out = T // kT, H // kH, W // kW
        K = C * kT * kH * kW
        M = B * T_out * H_out * W_out

        out = torch.empty(B, N, T_out, H_out, W_out, dtype=torch.bfloat16, device=x.device)
        # Both projection kernels tile the (M, N) output the same way.
        grid_mn = lambda META: (
            triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
        )

        if K >= self._IM2COL_MIN_K:
            w_t = self._transposed_weight(weight, N, K)
            buf = torch.empty(M, K, dtype=torch.bfloat16, device=x.device)
            grid_mk = lambda META: (
                triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(K, META['BLOCK_K']),
            )
            _im2col_3d[grid_mk](x, buf, *x.stride(),
                                T_out, H_out, W_out, M, K, kT, kH, kW)
            _gemm_patch_5d[grid_mn](buf, w_t, out, *out.stride(),
                                    T_out, H_out, W_out, M, K, N, kT, kH, kW)
        else:
            # The fused kernel indexes the weight as flat row-major (N, K)
            # memory; ``contiguous()`` is a no-op when that already holds.
            _fused_patch_embed[grid_mn](x, weight.contiguous(), out,
                                        *x.stride(), *out.stride(),
                                        T_out, H_out, W_out, M, K, N, kT, kH, kW)
        return out
# Default problem dimensions read by get_inputs() and get_init_inputs().
B, C, T, H, W = 1, 3, 2, 224, 224
kT, kH, kW = 2, 14, 14
embed_dim = 1280
def get_inputs():
    """Return a one-element list with a random bfloat16 (B, C, T, H, W) tensor scaled by 0.5."""
    sample = torch.randn(B, C, T, H, W, dtype=torch.bfloat16)
    return [sample * 0.5]
def get_init_inputs():
    """Return the module-level dimensions as a list, in the order Model.__init__ takes them."""
    dims = (B, C, T, H, W, kT, kH, kW, embed_dim)
    return list(dims)
shape=0 variant=eager tflops=3.688 gbps=20.422 ms=0.209
shape=0 variant=compiled tflops=3.190 gbps=17.665 ms=0.242
shape=0 variant=sota tflops=3.727 gbps=20.640 ms=0.207
shape=0 variant=solution tflops=15.365 gbps=85.088 ms=0.050
shape=0 solution_peak_fraction=0.0768
shape=1 variant=eager tflops=13.919 gbps=36.302 ms=0.221
shape=1 variant=compiled tflops=12.041 gbps=31.404 ms=0.256
shape=1 variant=sota tflops=14.125 gbps=36.840 ms=0.218
shape=1 variant=solution tflops=52.400 gbps=136.668 ms=0.059
shape=1 solution_peak_fraction=0.2620
shape=2 variant=eager tflops=26.124 gbps=53.961 ms=0.266
shape=2 variant=compiled tflops=23.415 gbps=48.366 ms=0.296
shape=2 variant=sota tflops=26.442 gbps=54.620 ms=0.262
shape=2 variant=solution tflops=92.652 gbps=191.384 ms=0.075
shape=2 solution_peak_fraction=0.4633
shape=3 variant=eager tflops=6.495 gbps=25.197 ms=0.142
shape=3 variant=compiled tflops=5.248 gbps=20.361 ms=0.176
shape=3 variant=sota tflops=6.639 gbps=25.759 ms=0.139
shape=3 variant=solution tflops=31.655 gbps=122.813 ms=0.029
shape=3 solution_peak_fraction=0.1583
peak_fraction: 0.1960
RESULT: OK
PASS
{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "opencode",
"model": "deepseek/deepseek-v4-flash",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.1960,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 405557, "output_tokens": 59342, "cache_read_tokens": 6112128, "cache_creation_tokens": 0, "reasoning_tokens": 80800, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/reference.py</path>
<type>file</type>
<content>
1: """Naive Conv3d-as-GEMM patch embedding reference (correctness only).
2:
3: Vision-Transformer / Qwen2-VL style patch embedding: a video tensor
4: (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each
5: patch projected to embed_dim. This is mathematically a 3D convolution with
6: stride == kernel and equivalently a single GEMM after a strided reshape.
7:
8: Reference uses nn.Conv3d (cuDNN dispatch) for clarity. The agent's solution is
9: forbidden from using Conv3d / conv3d / matmul / linear / einsum, forcing them
10: to write a fused patch-gather + tensor-core GEMM kernel.
11:
12: Output layout: (B, embed_dim, T/kT, H/kH, W/kW). No bias.
13: """
14: import torch
15: import torch.nn as nn
16:
17: OP_TYPE = "patch_embed"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21:
22: class Model(nn.Module):
23: def __init__(self, B: int, C: int, T: int, H: int, W: int,
24: kT: int, kH: int, kW: int, embed_dim: int):
25: super().__init__()
26: assert T % kT == 0 and H % kH == 0 and W % kW == 0, \
27: f"Input dims must be divisible by patch size: T={T} kT={kT} H={H} kH={kH} W={W} kW={kW}"
28: self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
29: self.kT, self.kH, self.kW = kT, kH, kW
30: self.embed_dim = embed_dim
31:
3
... (911 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/problem.yaml</path>
<type>file</type>
<content>
1: name: 10_patch_embed_conv3d_gemm
2: display_name: "Conv3d-as-GEMM Patch Embedding"
3: precision: bf16
4: regime: compute # GEMM-equivalent, embed_dim=1280 is well above the bf16 ridge
5:
6: # Dense FLOPs: 2 * (num_patches) * (C * kT * kH * kW) * embed_dim.
7: # With stride==kernel and divisible inputs, num_patches * (kT*kH*kW) = T*H*W,
8: # so this collapses to: 2 * B * embed_dim * C * T * H * W.
9: flops_formula: "2 * B * embed_dim * C * T * H * W"
10:
11: # Bytes moved (lower bound):
12: # read input x: B * C * T * H * W * 2
13: # read weights: embed_dim * C * kT * kH * kW * 2
14: # write output: B * embed_dim * (T // kT) * (H // kH) * (W // kW) * 2
15: bytes_formula: "B*C*T*H*W*2 + embed_dim*C*kT*kH*kW*2 + B*embed_dim*(T//kT)*(H//kH)*(W//kW)*2"
16:
17: hardware: [RTX_PRO_6000]
18: peak_tflops_key: bf16
19: peak_bandwidth_key: dram
20:
21: tolerance:
22: bfloat16: 0.02
23:
24: # Forbidden ops -- agent must implement the fused patch-gather + GEMM
25: # themselves. Banning Conv3d/conv3d (the cuDNN dispatch) and matmul/bmm/linear/
26: # einsum (the lazy reshape-then-cuBLAS shortcut). Unfold is also banned for
27: # completeness even though it's nominally 2D.
28: forbidden:
29: - "torch.nn.Conv3d"
30: - "torch.nn.functional.conv3d"
31: - "F.conv3d"
32: - "torch.conv3d"
33: - "torch.matmul"
34: - "torch.bmm"
35: -
... (386 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for Conv3d-as-GEMM patch embedding.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") o
... (3130 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for Conv3d-as-GEMM patch embedding.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regim
... (3879 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for Conv3d-as-GEMM patch embedding.
2:
3: Sized for Qwen2-VL / ViT-style patch embedding. embed_dim=1280 matches Qwen2-VL
4: ViT; embed_dim=768 is a ViT-B/16 reference. Patch sizes are the canonical
5: 14x14 spatial / 2-frame temporal used in modern VL encoders.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL ViT base: 224x224x2-frame video, 14x14x2 patches, embed_dim=1280
10: {"B": 1, "C": 3, "T": 2, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
11:
12: # Same patch config, batch 2, 4-frame video
13: {"B": 2, "C": 3, "T": 4, "H": 224, "W": 224, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
14:
15: # Larger spatial (336x336), 8-frame — exercises larger num_patches
16: {"B": 1, "C": 3, "T": 8, "H": 336, "W": 336, "kT": 2, "kH": 14, "kW": 14, "embed_dim": 1280},
17:
18: # ViT-B/16 (image-mode): kT=1, kH=kW=16, embed_dim=768
19: {"B": 4, "C": 3, "T": 1, "H": 224, "W": 224, "kT": 1, "kH": 16, "kW": 16, "embed_dim": 768},
20: ]
(End of file - total 20 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for Conv3d-as-GEMM patch embedding.
2:
3: Dispatches torch.nn.functional.conv3d (cuDNN). The agent's solution is
4: forbidden from calling conv3d directly; this file is only for the benchmark's
5: reference baseline.
6: """
7: from __future__ import annotations
8:
9: import torch
10: import torch.nn.functional as F
11:
12:
13: def sota_forward(x: torch.Tensor, weight: torch.Tensor,
14: stride: tuple[int, int, int]) -> torch.Tensor:
15: return F.conv3d(x, weight, bias=None, stride=stride)
16:
17:
18: def is_available() -> bool:
19: return True
(End of file - total 19 lines)
</content>{
"command": "nvidia-smi",
"description": "Check GPU info"
}Thu May 7 06:11:25 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 Off | Off |
| 30% 33C P0 87W / 600W | 271MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (281 more chars){
"description": "Check Python and Torch/Triton versions",
"command": "uv run python --version && uv run python -c \"import torch; print('torch', torch.__version__); print('cuda', torch.version.cuda)\" && uv run python -c \"import triton; print('triton', triton.__version__)\" 2>&1"
}Python 3.11.15
torch 2.11.0+cu130
cuda 13.0
triton 3.6.0
{
"description": "Check GPU arch and compute capability",
"command": "uv run python -c \"import torch; print(torch.cuda.get_device_capability(0)); print(torch.cuda.get_device_name(0))\" 2>&1"
}(12, 0)
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
{
"description": "Check memory bandwidth and GPU specs",
"command": "nvidia-smi --query-gpu=memory.bandwidth,memory.total --format=csv 2>&1"
}Field "memory.bandwidth" is not a valid field to query.
{
"command": "<353 chars \u2014 see below>",
"description": "Check Triton SM support details"
}uv run python -c "
import torch
import triton
print('Triton support:', triton.runtime.driver.active.get_current_target())
print('SM:', torch.cuda.get_device_capability(0))
print('Registers per SM:', torch.cuda.get_device_properties(0).multi_processor_count)
print('Shared mem per block:', torch.cuda.get_device_properties(0).shared_mem_per_block)
" 2>&1Traceback (most recent call last):
File "<string>", line 7, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_mem_per_block'. Did you mean: 'shared_memory_per_block'?
Triton support: GPUTarget(backend='cuda', arch=120, warp_size=32)
SM: (12, 0)
Registers per SM: 188
{
"command": "<149 chars \u2014 see below>",
"description": "Quick Triton import test"
}uv run python -c "
import torch
# Quick test: can triton dot work on SM120?
import triton.language as tl
print('Triton imported successfully')
" 2>&1Triton imported successfully
{
"command": "<284 chars \u2014 see below>",
"description": "Get GPU properties detail"
}uv run python -c "
import torch
p = torch.cuda.get_device_properties(0)
print('SM count:', p.multi_processor_count)
print('Shared mem/block:', p.shared_memory_per_block)
print('Max threads/block:', p.max_threads_per_block)
print('Total global mem:', p.total_memory / 1e9, 'GB')
" 2>&1SM count: 188
Shared mem/block: 49152
Max threads/block: 1024
Total global mem: 101.949898752 GB
{
"todos": [
{
"content": "Write initial Triton fused patch-embed + GEMM kernel in solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py to verify correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize \u2014 iterate on block sizes and launch parameters",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py and compare to reference",
"status": "pending",
"priority": "medium"
},
{
"content"
... (136 more chars)content: 'Write initial Triton fused patch-embed + GEMM kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py to verify correctness'
status: 'pending'
priority: 'high'
content: 'Profile and optimize — iterate on block sizes and launch parameters'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py and compare to reference'
status: 'pending'
priority: 'medium'
content: 'Push performance — tune autotune configs for SM120 Blackwell'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ ],
+ key=['num_patches', 'N', 'K'],
+)
+@triton.jit
+def _patch_embed_fwd(
+ x_ptr, w_ptr, out_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+ T_out, H_out, W_out,
+ num_patches, K, N,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_start = pid_n * BLOCK_N
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+
+ m_mask = m_offsets < num_patches
+ n_mask = n_offsets < N
+ both_mask = m_mask[:, None] & n_mask[None, :]
+
+ _m = m_offsets[:, None]
+ _n_2d = n_offsets[None, :]
+ _k_1d = tl.arange(0, BLOCK_K)
+
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
+
+ for k_start in range(0, K, BLOCK_K):
+ k_offsets = k_start + _k_1d
+ k_mask = k_offsets < K
+ mk_mask = m_mask[:, None] & k_mask[None, :]
+ nk_mask = n_mask[None, :] & k_mask[:, None]
+
+ w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
+ w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
+
+ _k2 = k_offsets[None, :]
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
+
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
+
+ x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
+
+ acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
+
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, C: int, T: int, H: int, W: int,
+ kT: int, kH: int, kW: int, embed_dim: int):
+ super().__init__()
+ assert T % kT == 0 and H % kH == 0 and W % kW == 0
+ self.B, self.C, self.T, self.H, self.W = B, C, T, H, W
+ self.kT, self.kH, self.kW = kT, kH, kW
+ self.embed_dim = embed_dim
+
+ w = nn.Parameter(torch.empty(embed_dim, C, kT, kH, kW, dtype=torch.bfloat16))
+ nn.init.normal_(w, std=0.02)
+ conv = nn.Module()
+ conv.register_parameter('weight', w)
+ self.conv = conv
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T, H, W = x.shape
+ kT, kH, kW = self.kT, self.kH, self.kW
+ embed_dim = self.embed_dim
+ weight = self.conv.weight
+
+ T_out = T // kT
+ H_out = H // kH
+ W_out = W // kW
+ K = C * kT * kH * kW
+ num_patches = B * T_out * H_out * W_out
+
+ out = torch.empty(B, embed_dim, T_out, H_out, W_out,
+ dtype=torch.bfloat16, device=x.device)
+
+ stride_x = x.stride()
+ stride_out = out.stride()
+
+ grid = lambda META: (
+ triton.cdiv(num_patches, META['BLOCK_M']) *
+ triton.cdiv(embed_dim, META['BLOCK_N']),
+ )
+
+ _patch_embed_fwd[grid](
+ x, weight, out,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+ T_out, H_out, W_out,
+ num_patches, K, embed_dim,
+ kT, kH, kW,
+ )
+
+ return out
+
+
+# Module-level shape shims
+B = 1
+C = 3
+T = 2
+H = 224
+W = 224
+kT = 2
+kH = 14
+kW = 14
+embed_dim = 1280
+
+
+def get_inputs():
+ x = torch.randn(B, C, T, H, W, dtype=torch.bfloat16) * 0.5
+ return [x]
+
+
+def get_init_inputs():
+ return [B, C, T, H, W, kT, kH, kW, embed_dim]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5146 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Quick sanity test of Triton kernel",
"command": "<915 chars \u2014 see below>",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
import torch
import solution
import reference
# Test shape 0
reference.B=1; reference.C=3; reference.T=2; reference.H=224; reference.W=224
reference.kT=2; reference.kH=14; reference.kW=14; reference.embed_dim=1280
init = reference.get_init_inputs()
ref = reference.Model(*init).cuda().eval()
sol = solution.Model(*init).cuda().eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = reference.get_inputs()[0].cuda()
with torch.no_grad():
r = ref(x)
s = sol(x)
print('Output shapes:', r.shape, s.shape)
print('Max diff:', (r - s).abs().max().item())
print('Mean diff:', (r - s).abs().mean().item())
print('Ref min/max:', r.min().item(), r.max().item())
print('Sol min/max:', s.min().item(), s.max().item())
print('PASS' if (r - s).abs().max().item() < 0.02 else 'FAIL')
" 2>&1Output shapes: torch.Size([1, 1280, 1, 16, 16]) torch.Size([1, 1280, 1, 16, 16])
Max diff: 0.0078125
Mean diff: 3.7439167499542236e-07
Ref min/max: -1.53125 1.484375
Sol min/max: -1.53125 1.484375
PASS
{
"command": "<1009 chars \u2014 see below>",
"description": "Test all shapes for correctness",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
# Correctness sweep: for every entry in shapes.SHAPES and three seeds,
# compare solution.Model against reference.Model on the GPU.
import torch
import solution
import reference
import shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
# Push this shape's values onto the reference module's globals; presumably
# get_init_inputs() / get_inputs() read them (confirm in reference.py).
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
# Both models must share identical weights; strict=True requires the
# state-dict keys to match exactly.
sol.load_state_dict(ref.state_dict(), strict=True)
for seed in (42, 123, 456):
# Seed both the CPU and CUDA generators before drawing the input.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
r = ref(x)
s = sol(x)
# Pass criterion: maximum absolute element-wise difference below 0.02.
diff = (r - s).abs().max().item()
ok = diff < 0.02
print(f'Shape {idx} seed {seed}: max_diff={diff:.6f} {\"PASS\" if ok else \"FAIL\"}')
# Stop with a non-zero exit status on the first mismatch.
if not ok:
raise SystemExit(1)
print('ALL PASS')
" 2>&1Shape 0 seed 42: max_diff=0.003906 PASS
Shape 0 seed 123: max_diff=0.003906 PASS
Shape 0 seed 456: max_diff=0.007812 PASS
Shape 1 seed 42: max_diff=0.007812 PASS
Shape 1 seed 123: max_diff=0.007812 PASS
Shape 1 seed 456: max_diff=0.007812 PASS
Shape 2 seed 42: max_diff=0.007812 PASS
Shape 2 seed 123: max_diff=0.007812 PASS
Shape 2 seed 456: max_diff=0.007812 PASS
Shape 3 seed 42: max_diff=0.003906 PASS
Shape 3 seed 123: max_diff=0.003906 PASS
Shape 3 seed 456: max_diff=0.003906 PASS
ALL PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1",
"description": "Run official check.py",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python benchmark.py 2>&1",
"description": "Run benchmark for performance baseline",
"timeout": 300000
}shape=0 variant=eager tflops=3.827 gbps=21.194 ms=0.201
shape=0 variant=compiled tflops=3.575 gbps=19.798 ms=0.216
shape=0 variant=sota tflops=3.827 gbps=21.191 ms=0.201
shape=0 variant=solution tflops=19.693 gbps=109.056 ms=0.039
shape=0 solution_peak_fraction=0.0985
shape=1 variant=eager tflops=13.936 gbps=36.347 ms=0.221
shape=1 variant=compiled tflops=12.018 gbps=31.345 ms=0.257
shape=1 variant=sota tflops=14.113 gbps=36.810 ms=0.218
shape=1 variant=solution tflops=53.850 gbps=140.449 ms=0.057
shape=1 solution_peak_fraction=0.2693
shape=2 variant=eager tflops=26.191 gbps=54.101 ms=0.265
shape=2 variant=compiled tflops=23.637 gbps=48.824 ms=0.293
shape=2 variant=sota tflops=26.333 gbps=54.394 ms=0.263
shape=2 variant=solution tflops=64.996 gbps=134.256 ms=0.107
shape=2 solution_peak_fraction=0.3250
shape=3 variant=eager tflops=6.511 gbps=25.260 ms=0.142
shape=3 variant=compiled tflops=5.261 gbps=20.411 ms=0.176
shape=3 variant=sota tflops=6.713 gbps=26.046 ms=0.138
shape=3 variant=solution tflops=29.446 gbps=114.241 ms=0.031
shape=3 solution_peak_fraction=0.1472
peak_fraction: 0.1887
RESULT: OK
{
"todos": [
{
"priority": "high",
"content": "Write initial Triton fused patch-embed + GEMM kernel in solution.py",
"status": "completed"
},
{
"priority": "high",
"content": "Run check.py to verify correctness",
"status": "completed"
},
{
"priority": "high",
"content": "Profile and optimize \u2014 try im2col+GEMM two-kernel approach",
"status": "in_progress"
},
{
"priority": "medium",
"content": "Run benchmark.py and compare to reference",
"status": "completed"
},
{
"priority":
... (133 more chars)content: 'Write initial Triton fused patch-embed + GEMM kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness'
status: 'completed'
priority: 'high'
content: 'Profile and optimize — try im2col+GEMM two-kernel approach'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py and compare to reference'
status: 'completed'
priority: 'medium'
content: 'Push performance — tune autotune configs for SM120 Blackwell'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,6 +2,71 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ ],
+ key=['num_patches', 'K'],
+)
+@triton.jit
+def _im2col_3d(
+ x_ptr, buf_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ T_out, H_out, W_out, num_patches, K,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_k = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ k_start = pid_k * BLOCK_K
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+
+ m_mask = m_offsets < num_patches
+ k_mask = k_offsets < K
+
+ _m = m_offsets[:, None]
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ _k2 = k_offsets[None, :]
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
+
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
+
+ x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ buf_ptrs = buf_ptr + m_offsets[:, None] * K + k_offsets[None, :]
+ tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
@triton.autotune(
@@ -13,22 +78,20 @@
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
],
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _patch_embed_fwd(
- x_ptr, w_ptr, out_ptr,
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out,
- num_patches, K, N,
- kT, kH, kW,
+def _gemm_nt(
+ a_ptr, b_ptr, c_ptr,
+ M, N, K,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
- num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
@@ -37,64 +100,31 @@
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
- m_mask = m_offsets < num_patches
+ m_mask = m_offsets < M
n_mask = n_offsets < N
- both_mask = m_mask[:, None] & n_mask[None, :]
-
- _m = m_offsets[:, None]
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ _k_1d = tl.arange(0, BLOCK_K)
+ _m_2d = m_offsets[:, None]
_n_2d = n_offsets[None, :]
- _k_1d = tl.arange(0, BLOCK_K)
-
- tom = T_out * H_out * W_out
- how = H_out * W_out
-
- _b = _m // tom
- r1 = _m - _b * tom
- _out_t = r1 // how
- r2 = r1 - _out_t * how
- _out_h = r2 // W_out
- _out_w = r2 - _out_h * W_out
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- prod_khw = kH * kW
- prod_tkhw = kT * prod_khw
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + _k_1d
k_mask = k_offsets < K
+ nk_mask = n_mask[None, :] & k_mask[:, None]
mk_mask = m_mask[:, None] & k_mask[None, :]
- nk_mask = n_mask[None, :] & k_mask[:, None]
-
- w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
- w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
-
- _k2 = k_offsets[None, :]
- _c = _k2 // prod_tkhw
- rk1 = _k2 - _c * prod_tkhw
- _dt = rk1 // prod_khw
- rk2 = rk1 - _dt * prod_khw
- _dh = rk2 // kW
- _dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
- + (_out_t * kT + _dt) * stride_x_t
- + (_out_h * kH + _dh) * stride_x_h
- + (_out_w * kW + _dw) * stride_x_w)
-
- x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
-
- acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
-
- out_ptrs = (out_ptr
- + _b * stride_out_b
- + _n_2d * stride_out_d
- + _out_t * stride_out_t
- + _out_h * stride_out_h
- + _out_w * stride_out_w)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
+
+ a_ptrs = a_ptr + _m_2d * stride_am + k_offsets[None, :] * stride_ak
+ a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
+
+ b_ptrs = b_ptr + k_offsets[:, None] * stride_bk + _n_2d * stride_bn
+ b = tl.load(b_ptrs, mask=nk_mask, other=0.0)
+
+ acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
+
+ c_ptrs = c_ptr + _m_2d * stride_cm + _n_2d * stride_cn
+ tl.store(c_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
class Model(nn.Module):
@@ -122,28 +152,42 @@
... (2110 more chars){
"content": "<6668 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"description": "Quick test of two-kernel solution",
"command": "<961 chars \u2014 see below>",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
# Correctness sweep for the current solution.py: every shape in
# shapes.SHAPES, three seeds each, compared against reference.Model.
import torch, solution, reference, shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
# Push this shape's values onto the reference module's globals; presumably
# get_init_inputs() / get_inputs() read them (confirm in reference.py).
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
# Share the reference weights; strict=True requires matching state-dict keys.
sol.load_state_dict(ref.state_dict(), strict=True)
for seed in (42, 123, 456):
# Seed both the CPU and CUDA generators before drawing the input.
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
r = ref(x); s = sol(x)
# Pass criterion: maximum absolute element-wise difference below 0.02.
diff = (r - s).abs().max().item()
ok = diff < 0.02
print(f'Shape {idx} seed {seed}: max_diff={diff:.6f} {\"PASS\" if ok else \"FAIL\"}')
# Stop with a non-zero exit status on the first mismatch.
if not ok: raise SystemExit(1)
print('ALL PASS')
" 2>&1Shape 0 seed 42: max_diff=0.003906 PASS
Shape 0 seed 123: max_diff=0.003906 PASS
Shape 0 seed 456: max_diff=0.003906 PASS
Shape 1 seed 42: max_diff=0.007812 PASS
Shape 1 seed 123: max_diff=0.007812 PASS
Shape 1 seed 456: max_diff=0.007812 PASS
Shape 2 seed 42: max_diff=0.007812 PASS
Shape 2 seed 123: max_diff=0.007812 PASS
Shape 2 seed 456: max_diff=0.007812 PASS
Shape 3 seed 42: max_diff=0.003906 PASS
Shape 3 seed 123: max_diff=0.003906 PASS
Shape 3 seed 456: max_diff=0.003906 PASS
ALL PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python benchmark.py 2>&1",
"description": "Benchmark two-kernel solution",
"timeout": 300000
}shape=0 variant=eager tflops=3.681 gbps=20.386 ms=0.209
shape=0 variant=compiled tflops=3.199 gbps=17.715 ms=0.241
shape=0 variant=sota tflops=3.729 gbps=20.648 ms=0.207
shape=0 variant=solution tflops=12.453 gbps=68.964 ms=0.062
shape=0 solution_peak_fraction=0.0623
shape=1 variant=eager tflops=13.913 gbps=36.286 ms=0.222
shape=1 variant=compiled tflops=12.040 gbps=31.402 ms=0.256
shape=1 variant=sota tflops=14.118 gbps=36.821 ms=0.218
shape=1 variant=solution tflops=39.515 gbps=103.062 ms=0.078
shape=1 solution_peak_fraction=0.1976
shape=2 variant=eager tflops=26.250 gbps=54.223 ms=0.264
shape=2 variant=compiled tflops=23.511 gbps=48.565 ms=0.295
shape=2 variant=sota tflops=26.862 gbps=55.486 ms=0.258
shape=2 variant=solution tflops=60.658 gbps=125.296 ms=0.114
shape=2 solution_peak_fraction=0.3033
shape=3 variant=eager tflops=6.498 gbps=25.209 ms=0.142
shape=3 variant=compiled tflops=5.245 gbps=20.348 ms=0.176
shape=3 variant=sota tflops=6.621 gbps=25.688 ms=0.140
shape=3 variant=solution tflops=17.912 gbps=69.494 ms=0.052
shape=3 solution_peak_fraction=0.0896
peak_fraction: 0.1352
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,71 +2,6 @@
import torch.nn as nn
import triton
import triton.language as tl
-
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- ],
- key=['num_patches', 'K'],
-)
-@triton.jit
-def _im2col_3d(
- x_ptr, buf_ptr,
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- T_out, H_out, W_out, num_patches, K,
- kT, kH, kW,
- BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(num_patches, BLOCK_M)
- pid_m = pid % num_pid_m
- pid_k = pid // num_pid_m
-
- m_start = pid_m * BLOCK_M
- m_offsets = m_start + tl.arange(0, BLOCK_M)
- k_start = pid_k * BLOCK_K
- k_offsets = k_start + tl.arange(0, BLOCK_K)
-
- m_mask = m_offsets < num_patches
- k_mask = k_offsets < K
-
- _m = m_offsets[:, None]
- tom = T_out * H_out * W_out
- how = H_out * W_out
-
- _b = _m // tom
- r1 = _m - _b * tom
- _out_t = r1 // how
- r2 = r1 - _out_t * how
- _out_h = r2 // W_out
- _out_w = r2 - _out_h * W_out
-
- _k2 = k_offsets[None, :]
- prod_khw = kH * kW
- prod_tkhw = kT * prod_khw
- _c = _k2 // prod_tkhw
- rk1 = _k2 - _c * prod_tkhw
- _dt = rk1 // prod_khw
- rk2 = rk1 - _dt * prod_khw
- _dh = rk2 // kW
- _dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
- + (_out_t * kT + _dt) * stride_x_t
- + (_out_h * kH + _dh) * stride_x_h
- + (_out_w * kW + _dw) * stride_x_w)
-
- x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
-
- buf_ptrs = buf_ptr + m_offsets[:, None] * K + k_offsets[None, :]
- tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
@triton.autotune(
@@ -78,20 +13,23 @@
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
],
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _gemm_nt(
- a_ptr, b_ptr, c_ptr,
- M, N, K,
- stride_am, stride_ak,
- stride_bk, stride_bn,
- stride_cm, stride_cn,
+def _patch_embed_fwd(
+ x_ptr, w_ptr, out_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+ T_out, H_out, W_out,
+ num_patches, K, N,
+ kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
@@ -100,31 +38,64 @@
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
- m_mask = m_offsets < M
+ m_mask = m_offsets < num_patches
n_mask = n_offsets < N
+ both_mask = m_mask[:, None] & n_mask[None, :]
+
+ _m = m_offsets[:, None]
+ _n_2d = n_offsets[None, :]
+ _k_1d = tl.arange(0, BLOCK_K)
+
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- _k_1d = tl.arange(0, BLOCK_K)
- _m_2d = m_offsets[:, None]
- _n_2d = n_offsets[None, :]
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + _k_1d
k_mask = k_offsets < K
+ mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
- mk_mask = m_mask[:, None] & k_mask[None, :]
- a_ptrs = a_ptr + _m_2d * stride_am + k_offsets[None, :] * stride_ak
- a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
+ w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
+ w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
- b_ptrs = b_ptr + k_offsets[:, None] * stride_bk + _n_2d * stride_bn
- b = tl.load(b_ptrs, mask=nk_mask, other=0.0)
+ _k2 = k_offsets[None, :]
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
- acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
- c_ptrs = c_ptr + _m_2d * stride_cm + _n_2d * stride_cn
- tl.store(c_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
+
+ acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
+
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_ma
... (2306 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5412 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Quick test pre-transposed weight approach",
"command": "<961 chars \u2014 see below>",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
# Correctness sweep for the current solution.py: every shape in
# shapes.SHAPES, three seeds each, compared against reference.Model.
import torch, solution, reference, shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
# Push this shape's values onto the reference module's globals; presumably
# get_init_inputs() / get_inputs() read them (confirm in reference.py).
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
# Loading the reference weights after construction means any weight-derived
# state built in solution.Model.__init__ must be refreshed by the model itself.
sol.load_state_dict(ref.state_dict(), strict=True)
for seed in (42, 123, 456):
# Seed both the CPU and CUDA generators before drawing the input.
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
r = ref(x); s = sol(x)
# Pass criterion: maximum absolute element-wise difference below 0.02.
diff = (r - s).abs().max().item()
ok = diff < 0.02
print(f'Shape {idx} seed {seed}: max_diff={diff:.6f} {\"PASS\" if ok else \"FAIL\"}')
# Stop with a non-zero exit status on the first mismatch.
if not ok: raise SystemExit(1)
print('ALL PASS')
" 2>&1Shape 0 seed 42: max_diff=0.003906 PASS
Shape 0 seed 123: max_diff=0.007812 PASS
Shape 0 seed 456: max_diff=0.003906 PASS
Shape 1 seed 42: max_diff=0.007812 PASS
Shape 1 seed 123: max_diff=0.007812 PASS
Shape 1 seed 456: max_diff=0.007812 PASS
Shape 2 seed 42: max_diff=0.007812 PASS
Shape 2 seed 123: max_diff=0.007812 PASS
Shape 2 seed 456: max_diff=0.007812 PASS
Shape 3 seed 42: max_diff=0.003906 PASS
Shape 3 seed 123: max_diff=0.003906 PASS
Shape 3 seed 456: max_diff=0.003906 PASS
ALL PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python benchmark.py 2>&1",
"description": "Benchmark pre-transposed weight approach",
"timeout": 300000
}shape=0 variant=eager tflops=3.683 gbps=20.394 ms=0.209
shape=0 variant=compiled tflops=3.197 gbps=17.703 ms=0.241
shape=0 variant=sota tflops=3.735 gbps=20.685 ms=0.206
shape=0 variant=solution tflops=11.864 gbps=65.702 ms=0.065
shape=0 solution_peak_fraction=0.0593
shape=1 variant=eager tflops=13.980 gbps=36.463 ms=0.221
shape=1 variant=compiled tflops=12.026 gbps=31.365 ms=0.256
shape=1 variant=sota tflops=14.122 gbps=36.831 ms=0.218
shape=1 variant=solution tflops=50.598 gbps=131.966 ms=0.061
shape=1 solution_peak_fraction=0.2530
shape=2 variant=eager tflops=26.354 gbps=54.437 ms=0.263
shape=2 variant=compiled tflops=23.428 gbps=48.394 ms=0.296
shape=2 variant=sota tflops=26.887 gbps=55.538 ms=0.258
shape=2 variant=solution tflops=73.891 gbps=152.631 ms=0.094
shape=2 solution_peak_fraction=0.3695
shape=3 variant=eager tflops=6.516 gbps=25.280 ms=0.142
shape=3 variant=compiled tflops=5.216 gbps=20.238 ms=0.177
shape=3 variant=sota tflops=6.615 gbps=25.664 ms=0.140
shape=3 variant=solution tflops=24.744 gbps=96.000 ms=0.037
shape=3 solution_peak_fraction=0.1237
peak_fraction: 0.1618
RESULT: OK
{
"command": "<977 chars \u2014 see below>",
"description": "Check autotune cache",
"timeout": 60000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
# Run the solution once on shape 0, then report where the Triton cache
# lives and how many .ttgir files it contains.
import torch, solution, reference, shapes
# Re-execute the solution module from disk. NOTE(review): this only reloads
# the module; it does not revert solution.py to an earlier version.
import importlib
importlib.reload(solution)
# Run once to trigger autotune
device = torch.device('cuda:0')
shape = shapes.SHAPES[0]
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
sol = solution.Model(*init).to(device).eval()
x = reference.get_inputs()[0].to(device)
_ = sol(x)
# Inspect the on-disk Triton cache. NOTE(review): this counts cached TTGIR
# files only; it does not show which autotune config was selected.
# NOTE(review): tc and glob are imported but never used below.
import triton.compiler as tc
from pathlib import Path
import glob
# Try a cache directory next to the torch package first, then fall back to
# the one under the home directory.
cache_dir = Path(torch.__file__).parent / '.triton' / 'cache'
if not cache_dir.exists():
cache_dir = Path.home() / '.triton' / 'cache'
print('Cache dir:', cache_dir)
print('Cache exists:', cache_dir.exists())
files = list(cache_dir.rglob('*.ttgir')) if cache_dir.exists() else []
print('TTGIR files:', len(files))
" 2>&1Cache dir: /home/infatoshi/.triton/cache
Cache exists: True
TTGIR files: 9008
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -44,7 +44,6 @@
_m = m_offsets[:, None]
_n_2d = n_offsets[None, :]
- _k_1d = tl.arange(0, BLOCK_K)
tom = T_out * H_out * W_out
how = H_out * W_out
@@ -62,7 +61,7 @@
prod_tkhw = kT * prod_khw
for k_start in range(0, K, BLOCK_K):
- k_offsets = k_start + _k_1d
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
@@ -113,6 +112,10 @@
conv.register_parameter('weight', w)
self.conv = conv
+ K = C * kT * kH * kW
+ self._w_t = w.reshape(embed_dim, K).t().contiguous()
+ self._weight_version = w._version
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, T, H, W = x.shape
kT, kH, kW = self.kT, self.kH, self.kW
@@ -126,12 +129,14 @@
num_patches = B * T_out * H_out * W_out
N = embed_dim
+ # Update cached transpose if weight changed (e.g. after load_state_dict)
+ if weight._version != self._weight_version:
+ self._w_t = weight.reshape(N, K).t().contiguous()
+ self._weight_version = weight._version
+ w_t = self._w_t
+
out = torch.empty(B, embed_dim, T_out, H_out, W_out,
dtype=torch.bfloat16, device=x.device)
-
- # Pre-transpose weight from (N, K) to (K, N) for contiguous GEMM access
- w_reshaped = weight.reshape(N, K)
- w_t = w_reshaped.t().contiguous()
stride_x = x.stride()
stride_out = out.stride()
{
"content": "<5633 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Check and benchmark cached transpose",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=3.684 gbps=20.403 ms=0.209
shape=0 variant=compiled tflops=3.189 gbps=17.659 ms=0.242
shape=0 variant=sota tflops=3.724 gbps=20.624 ms=0.207
shape=0 variant=solution tflops=12.952 gbps=71.727 ms=0.060
shape=0 solution_peak_fraction=0.0648
shape=1 variant=eager tflops=13.899 gbps=36.250 ms=0.222
shape=1 variant=compiled tflops=12.057 gbps=31.445 ms=0.256
shape=1 variant=sota tflops=14.134 gbps=36.864 ms=0.218
shape=1 variant=solution tflops=57.670 gbps=150.412 ms=0.053
shape=1 solution_peak_fraction=0.2884
shape=2 variant=eager tflops=26.327 gbps=54.381 ms=0.263
shape=2 variant=compiled tflops=23.442 gbps=48.423 ms=0.296
shape=2 variant=sota tflops=26.468 gbps=54.673 ms=0.262
shape=2 variant=solution tflops=77.414 gbps=159.909 ms=0.090
shape=2 solution_peak_fraction=0.3871
shape=3 variant=eager tflops=6.487 gbps=25.169 ms=0.143
shape=3 variant=compiled tflops=5.224 gbps=20.269 ms=0.177
shape=3 variant=sota tflops=6.619 gbps=25.679 ms=0.140
shape=3 variant=solution tflops=29.061 gbps=112.748 ms=0.032
shape=3 solution_peak_fraction=0.1453
peak_fraction: 0.1800
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -22,7 +22,6 @@
def _patch_embed_fwd(
x_ptr, w_ptr, out_ptr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
T_out, H_out, W_out,
num_patches, K, N,
kT, kH, kW,
@@ -88,12 +87,7 @@
acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
- out_ptrs = (out_ptr
- + _b * stride_out_b
- + _n_2d * stride_out_d
- + _out_t * stride_out_t
- + _out_h * stride_out_h
- + _out_w * stride_out_w)
+ out_ptrs = out_ptr + _m * N + _n_2d
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
@@ -126,35 +120,31 @@
H_out = H // kH
W_out = W // kW
K = C * kT * kH * kW
- num_patches = B * T_out * H_out * W_out
+ M = B * T_out * H_out * W_out
N = embed_dim
- # Update cached transpose if weight changed (e.g. after load_state_dict)
if weight._version != self._weight_version:
self._w_t = weight.reshape(N, K).t().contiguous()
self._weight_version = weight._version
w_t = self._w_t
- out = torch.empty(B, embed_dim, T_out, H_out, W_out,
- dtype=torch.bfloat16, device=x.device)
+ out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
stride_x = x.stride()
- stride_out = out.stride()
grid = lambda META: (
- triton.cdiv(num_patches, META['BLOCK_M']) *
- triton.cdiv(N, META['BLOCK_N']),
+ triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
)
_patch_embed_fwd[grid](
- x, w_t, out,
+ x, w_t, out_2d,
stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
- stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
T_out, H_out, W_out,
- num_patches, K, N,
+ M, K, N,
kT, kH, kW,
)
+ out = out_2d.reshape(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<5218 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Check and benchmark contiguous output",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=3.687 gbps=20.420 ms=0.209
shape=0 variant=compiled tflops=3.190 gbps=17.666 ms=0.242
shape=0 variant=sota tflops=3.727 gbps=20.642 ms=0.207
shape=0 variant=solution tflops=11.789 gbps=65.284 ms=0.065
shape=0 solution_peak_fraction=0.0589
shape=1 variant=eager tflops=13.940 gbps=36.357 ms=0.221
shape=1 variant=compiled tflops=12.021 gbps=31.353 ms=0.256
shape=1 variant=sota tflops=14.112 gbps=36.807 ms=0.218
shape=1 variant=solution tflops=40.995 gbps=106.921 ms=0.075
shape=1 solution_peak_fraction=0.2050
shape=2 variant=eager tflops=26.201 gbps=54.121 ms=0.265
shape=2 variant=compiled tflops=23.387 gbps=48.308 ms=0.297
shape=2 variant=sota tflops=26.492 gbps=54.723 ms=0.262
shape=2 variant=solution tflops=62.171 gbps=128.422 ms=0.112
shape=2 solution_peak_fraction=0.3109
shape=3 variant=eager tflops=6.487 gbps=25.166 ms=0.143
shape=3 variant=compiled tflops=5.229 gbps=20.285 ms=0.177
shape=3 variant=sota tflops=6.606 gbps=25.629 ms=0.140
shape=3 variant=solution tflops=26.310 gbps=102.074 ms=0.035
shape=3 solution_peak_fraction=0.1315
peak_fraction: 0.1491
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -6,15 +6,16 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
],
key=['num_patches', 'N', 'K'],
)
@@ -22,15 +23,22 @@
def _patch_embed_fwd(
x_ptr, w_ptr, out_ptr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
T_out, H_out, W_out,
num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
- pid_m = pid % num_pid_m
- pid_n = pid // num_pid_m
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
@@ -87,7 +95,12 @@
acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
- out_ptrs = out_ptr + _m * N + _n_2d
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
@@ -120,7 +133,7 @@
H_out = H // kH
W_out = W // kW
K = C * kT * kH * kW
- M = B * T_out * H_out * W_out
+ num_patches = B * T_out * H_out * W_out
N = embed_dim
if weight._version != self._weight_version:
@@ -128,23 +141,26 @@
self._weight_version = weight._version
w_t = self._w_t
- out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+ out = torch.empty(B, embed_dim, T_out, H_out, W_out,
+ dtype=torch.bfloat16, device=x.device)
stride_x = x.stride()
+ stride_out = out.stride()
grid = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
+ triton.cdiv(num_patches, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']),
)
_patch_embed_fwd[grid](
- x, w_t, out_2d,
+ x, w_t, out,
stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
T_out, H_out, W_out,
- M, K, N,
+ num_patches, K, N,
kT, kH, kW,
)
- out = out_2d.reshape(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"content": "<6072 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Check autotune selections for each shape",
"command": "<901 chars \u2014 see below>",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
import torch
import solution
import reference
import shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
s = sol(x)
print(f'Shape {idx}: M={sol.forward.__globals__}', flush=True)
# Now let's also time it
for _ in range(10):
s = sol(x)
print(f'Shape {idx}: Done')
" 2>&1 | head -80Autotuning kernel _patch_embed_fwd with config BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 64, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 64, BLOCK_N: 128, BLOCK_K: 64, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 128, BLOCK_N: 64, BLOCK_K: 64, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 256, BLOCK_N: 128, BLOCK_K: 64, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 64, BLOCK_N: 64, BLOCK_K: 64, GROUP_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 64, BLOCK_N: 128, BLOCK_K: 32, GROUP_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 128, BLOCK_N: 64, BLOCK_K: 32, GROUP_M: 8, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 256, BLOCK_N: 64, BLOCK_K: 32, GROUP_M: 8, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _patch_embed_fwd with config BLOCK_M: 64, BLOCK_N: 256, BLOCK_K: 64, GROUP_M: 8, num_
... (29260 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -6,16 +6,20 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
],
key=['num_patches', 'N', 'K'],
)
@@ -28,17 +32,11 @@
num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
- num_pid_n = tl.cdiv(N, BLOCK_N)
- num_pid_in_group = GROUP_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
@@ -73,7 +71,7 @@
mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
- w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
+ w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
_k2 = k_offsets[None, :]
@@ -119,10 +117,6 @@
conv.register_parameter('weight', w)
self.conv = conv
- K = C * kT * kH * kW
- self._w_t = w.reshape(embed_dim, K).t().contiguous()
- self._weight_version = w._version
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, T, H, W = x.shape
kT, kH, kW = self.kT, self.kH, self.kW
@@ -136,11 +130,6 @@
num_patches = B * T_out * H_out * W_out
N = embed_dim
- if weight._version != self._weight_version:
- self._w_t = weight.reshape(N, K).t().contiguous()
- self._weight_version = weight._version
- w_t = self._w_t
-
out = torch.empty(B, embed_dim, T_out, H_out, W_out,
dtype=torch.bfloat16, device=x.device)
@@ -153,7 +142,7 @@
)
_patch_embed_fwd[grid](
- x, w_t, out,
+ x, weight, out,
stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
T_out, H_out, W_out,
{
"content": "<5724 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"description": "Clear cache and benchmark original approach with more configs",
"command": "rm -rf /home/infatoshi/.triton/cache && cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 600000
}{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache uv run python benchmark.py 2>&1",
"description": "Check and benchmark with fresh cache",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.684 gbps=20.403 ms=0.209
shape=0 variant=compiled tflops=3.121 gbps=17.286 ms=0.247
shape=0 variant=sota tflops=3.727 gbps=20.637 ms=0.207
shape=0 variant=solution tflops=13.354 gbps=73.954 ms=0.058
shape=0 solution_peak_fraction=0.0668
shape=1 variant=eager tflops=13.934 gbps=36.341 ms=0.221
shape=1 variant=compiled tflops=12.015 gbps=31.337 ms=0.257
shape=1 variant=sota tflops=14.110 gbps=36.802 ms=0.218
shape=1 variant=solution tflops=51.711 gbps=134.871 ms=0.060
shape=1 solution_peak_fraction=0.2586
shape=2 variant=eager tflops=26.341 gbps=54.411 ms=0.263
shape=2 variant=compiled tflops=23.484 gbps=48.510 ms=0.295
shape=2 variant=sota tflops=26.865 gbps=55.493 ms=0.258
shape=2 variant=solution tflops=65.181 gbps=134.640 ms=0.106
shape=2 solution_peak_fraction=0.3259
shape=3 variant=eager tflops=6.458 gbps=25.057 ms=0.143
shape=3 variant=compiled tflops=5.247 gbps=20.357 ms=0.176
shape=3 variant=sota tflops=6.606 gbps=25.629 ms=0.140
shape=3 variant=solution tflops=31.312 gbps=121.482 ms=0.030
shape=3 solution_peak_fraction=0.1566
peak_fraction: 0.1723
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,6 +2,69 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ ],
+ key=['num_patches', 'K'],
+)
+@triton.jit
+def _im2col_3d(x_ptr, buf_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ T_out, H_out, W_out, num_patches, K,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_k = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ k_start = pid_k * BLOCK_K
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+
+ m_mask = m_offsets < num_patches
+ k_mask = k_offsets < K
+
+ _m = m_offsets[:, None]
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ _k2 = k_offsets[None, :]
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
+
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
+ x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ buf_ptrs = buf_ptr + _m * K + _k2
+ tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
@triton.autotune(
@@ -13,8 +76,6 @@
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
@@ -24,12 +85,9 @@
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _patch_embed_fwd(
- x_ptr, w_ptr, out_ptr,
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+def _gemm_patch(a_ptr, w_ptr, out_ptr,
stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out,
- num_patches, K, N,
+ T_out, H_out, W_out, num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
@@ -45,7 +103,6 @@
m_mask = m_offsets < num_patches
n_mask = n_offsets < N
- both_mask = m_mask[:, None] & n_mask[None, :]
_m = m_offsets[:, None]
_n_2d = n_offsets[None, :]
@@ -62,36 +119,19 @@
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
- prod_khw = kH * kW
- prod_tkhw = kT * prod_khw
-
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
- w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
+ a_ptrs = a_ptr + _m * K + k_offsets[None, :]
+ a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
+
+ w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
- _k2 = k_offsets[None, :]
- _c = _k2 // prod_tkhw
- rk1 = _k2 - _c * prod_tkhw
- _dt = rk1 // prod_khw
- rk2 = rk1 - _dt * prod_khw
- _dh = rk2 // kW
- _dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
- + (_out_t * kT + _dt) * stride_x_t
- + (_out_h * kH + _dh) * stride_x_h
- + (_out_w * kW + _dw) * stride_x_w)
-
- x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
-
- acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
+ acc += tl.dot(a.to(tl.bfloat16), w.to(tl.bfloat16))
out_ptrs = (out_ptr
+ _b * stride_out_b
@@ -99,7 +139,7 @@
+ _out_t * stride_out_t
+ _out_h * stride_out_h
+ _out_w * stride_out_w)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
class Model(nn.Module):
@@ -117,6 +157,10 @@
conv.register_parameter('weight', w)
self.conv = conv
+ K = C * kT * kH * kW
+ self._w_t = w.reshape(embed_dim, K).t().contiguous()
+ self._weight_version = w._version
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, T, H, W = x.shape
kT, kH, kW = self.kT, self.kH, self.kW
@@ -127,8 +171,13 @@
H_out = H // kH
W_out = W // kW
K = C * kT * kH * kW
- num_patches = B * T_out * H_out * W_out
+ M = B * T_out * H_out * W_out
... (1525 more chars){
"content": "<7533 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"description": "Check and benchmark two-kernel approach with direct 5D output",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache2 uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache2 uv run python benchmark.py 2>&1",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.680 gbps=20.381 ms=0.209
shape=0 variant=compiled tflops=3.193 gbps=17.683 ms=0.241
shape=0 variant=sota tflops=3.695 gbps=20.460 ms=0.209
shape=0 variant=solution tflops=15.350 gbps=85.007 ms=0.050
shape=0 solution_peak_fraction=0.0768
shape=1 variant=eager tflops=13.944 gbps=36.368 ms=0.221
shape=1 variant=compiled tflops=12.078 gbps=31.501 ms=0.255
shape=1 variant=sota tflops=14.147 gbps=36.896 ms=0.218
shape=1 variant=solution tflops=53.240 gbps=138.858 ms=0.058
shape=1 solution_peak_fraction=0.2662
shape=2 variant=eager tflops=26.125 gbps=53.965 ms=0.266
shape=2 variant=compiled tflops=23.528 gbps=48.599 ms=0.295
shape=2 variant=sota tflops=26.375 gbps=54.480 ms=0.263
shape=2 variant=solution tflops=91.229 gbps=188.444 ms=0.076
shape=2 solution_peak_fraction=0.4561
shape=3 variant=eager tflops=6.474 gbps=25.115 ms=0.143
shape=3 variant=compiled tflops=5.148 gbps=19.973 ms=0.180
shape=3 variant=sota tflops=6.589 gbps=25.565 ms=0.140
shape=3 variant=solution tflops=22.535 gbps=87.429 ms=0.041
shape=3 solution_peak_fraction=0.1127
peak_fraction: 0.1800
RESULT: OK
{
"command": "<787 chars \u2014 see below>",
"description": "Check autotune configs selected for both kernels",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache2 TRITON_PRINT_AUTOTUNING=1 uv run python -c "
import torch, solution, reference, shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
sol = solution.Model(*init).to(device).eval()
ref = reference.Model(*init).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = reference.get_inputs()[0].to(device)
_ = sol(x)
print(f'Shape {idx} done')
" 2>&1 | grep -E 'best config|key as|Autotuning.*_im2col|Autotuning.*_gemm'Autotuning kernel _im2col_3d with config BLOCK_M: 128, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _im2col_3d with config BLOCK_M: 64, BLOCK_K: 64, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _im2col_3d with config BLOCK_M: 256, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _im2col_3d with config BLOCK_M: 128, BLOCK_K: 32, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _im2col_3d with config BLOCK_M: 256, BLOCK_K: 32, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
with key as (256, 1176, 'torch.bfloat16', 'torch.bfloat16'),
best config selected: BLOCK_M: 64, BLOCK_K: 64, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None;
Autotuning kernel _gemm_patch with config BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _gemm_patch with config BLOCK_M: 64, BLOCK_N: 128, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _gemm_patch with config BLOCK_M: 128, BLOCK_N: 64, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _gemm_patch with config BLOCK_M: 256, BLOCK_N: 128, BLOCK_K: 64, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel _gemm_patch with config BLOCK_M: 64, BLOCK_N: 64, BLOCK_K: 64, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel _gemm_
... (9097 more chars){
"todos": [
{
"priority": "high",
"content": "Write initial Triton fused patch-embed + GEMM kernel in solution.py",
"status": "completed"
},
{
"priority": "high",
"content": "Run check.py to verify correctness",
"status": "completed"
},
{
"priority": "high",
"content": "Try im2col+GEMM two-kernel approach",
"status": "completed"
},
{
"priority": "medium",
"content": "Run benchmark.py and compare to reference",
"status": "completed"
},
{
"priority": "high",
"content": "Shap
... (218 more chars)content: 'Write initial Triton fused patch-embed + GEMM kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py to verify correctness'
status: 'completed'
priority: 'high'
content: 'Try im2col+GEMM two-kernel approach'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py and compare to reference'
status: 'completed'
priority: 'medium'
content: 'Shape 3 bottleneck: use hybrid fused/two-kernel dispatch'
status: 'in_progress'
priority: 'high'
content: 'Profile and finalize best approach'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -142,6 +142,95 @@
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+# Fused kernel for smaller shapes
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ ],
+ key=['num_patches', 'N', 'K'],
+)
+@triton.jit
+def _fused_patch_embed(x_ptr, w_ptr, out_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+ T_out, H_out, W_out,
+ num_patches, K, N,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_start = pid_n * BLOCK_N
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+
+ m_mask = m_offsets < num_patches
+ n_mask = n_offsets < N
+
+ _m = m_offsets[:, None]
+ _n_2d = n_offsets[None, :]
+
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
+
+ for k_start in range(0, K, BLOCK_K):
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
+ mk_mask = m_mask[:, None] & k_mask[None, :]
+ nk_mask = n_mask[None, :] & k_mask[:, None]
+
+ w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
+ w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
+
+ _k2 = k_offsets[None, :]
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
+
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
+ x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
+
+ acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
+
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+
class Model(nn.Module):
def __init__(self, B: int, C: int, T: int, H: int, W: int,
kT: int, kH: int, kW: int, embed_dim: int):
@@ -174,40 +263,55 @@
M = B * T_out * H_out * W_out
N = embed_dim
- if weight._version != self._weight_version:
- self._w_t = weight.reshape(N, K).t().contiguous()
- self._weight_version = weight._version
- w_t = self._w_t
-
out = torch.empty(B, embed_dim, T_out, H_out, W_out,
dtype=torch.bfloat16, device=x.device)
stride_x = x.stride()
stride_out = out.stride()
- buf = torch.empty(M, K, dtype=torch.bfloat16, device=x.device)
-
- grid_im2col = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) *
- triton.cdiv(K, META['BLOCK_K']),
- )
- _im2col_3d[grid_im2col](
- x, buf,
- stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
- T_out, H_out, W_out, M, K,
- kT, kH, kW,
- )
-
- grid_gemm = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) *
- triton.cdiv(N, META['BLOCK_N']),
- )
- _gemm_patch[grid_gemm](
- buf, w_t, out,
- stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
- T_out, H_out, W_out, M, K, N,
- kT, kH, kW,
- )
+ # Hybrid dispatch: use im2col+GEMM for large M (better GEMM perf),
+ # fused kernel for small M (avoids im2col overhead)
+ if M * K > 1000000:
+ if weight._version != self._weight_version:
+ self._w_t = weight.reshape(N, K).t().contiguous()
+ self._weight_version = weight._version
+ w_t = self._w_t
+
+ buf = torch.empty(M, K, dtype=torch.bfloat16, device=x.device)
+
+ grid_i = lambda META: (
+ triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(K, META['BLOCK_K']),
+ )
+ _im2col_3d[grid_i](
+ x, buf,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ T_out, H_out, W_out, M, K,
+ kT, kH, kW,
+ )
+
+ grid_g = lambda META: (
+ triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']),
+ )
+ _gemm_patch[grid_g](
+ buf, w_t, out,
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+
... (608 more chars){
"content": "<11304 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache3 uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache3 uv run python benchmark.py 2>&1",
"description": "Check and benchmark hybrid approach",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.691 gbps=20.438 ms=0.209
shape=0 variant=compiled tflops=3.176 gbps=17.588 ms=0.243
shape=0 variant=sota tflops=3.718 gbps=20.592 ms=0.207
shape=0 variant=solution tflops=13.136 gbps=72.744 ms=0.059
shape=0 solution_peak_fraction=0.0657
shape=1 variant=eager tflops=13.964 gbps=36.420 ms=0.221
shape=1 variant=compiled tflops=11.992 gbps=31.277 ms=0.257
shape=1 variant=sota tflops=14.097 gbps=36.767 ms=0.219
shape=1 variant=solution tflops=52.287 gbps=136.371 ms=0.059
shape=1 solution_peak_fraction=0.2614
shape=2 variant=eager tflops=26.215 gbps=54.151 ms=0.265
shape=2 variant=compiled tflops=23.501 gbps=48.544 ms=0.295
shape=2 variant=sota tflops=26.450 gbps=54.636 ms=0.262
shape=2 variant=solution tflops=92.455 gbps=190.976 ms=0.075
shape=2 solution_peak_fraction=0.4623
shape=3 variant=eager tflops=6.476 gbps=25.124 ms=0.143
shape=3 variant=compiled tflops=5.237 gbps=20.317 ms=0.177
shape=3 variant=sota tflops=6.586 gbps=25.553 ms=0.140
shape=3 variant=solution tflops=31.432 gbps=121.945 ms=0.029
shape=3 solution_peak_fraction=0.1572
peak_fraction: 0.1879
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -10,7 +10,6 @@
triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
],
key=['num_patches', 'K'],
)
@@ -72,7 +71,6 @@
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
@@ -80,12 +78,11 @@
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
],
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _gemm_patch(a_ptr, w_ptr, out_ptr,
+def _gemm_patch_5d(a_ptr, w_ptr, out_ptr,
stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
T_out, H_out, W_out, num_patches, K, N,
kT, kH, kW,
@@ -142,16 +139,15 @@
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
-# Fused kernel for smaller shapes
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
],
key=['num_patches', 'N', 'K'],
)
@@ -269,9 +265,7 @@
stride_x = x.stride()
stride_out = out.stride()
- # Hybrid dispatch: use im2col+GEMM for large M (better GEMM perf),
- # fused kernel for small M (avoids im2col overhead)
- if M * K > 1000000:
+ if M >= 1024 and K >= 1000:
if weight._version != self._weight_version:
self._w_t = weight.reshape(N, K).t().contiguous()
self._weight_version = weight._version
@@ -294,7 +288,7 @@
triton.cdiv(M, META['BLOCK_M']) *
triton.cdiv(N, META['BLOCK_N']),
)
- _gemm_patch[grid_g](
+ _gemm_patch_5d[grid_g](
buf, w_t, out,
stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
T_out, H_out, W_out, M, K, N,
@@ -316,7 +310,6 @@
return out
-# Module-level shape shims
B = 1
C = 3
T = 2
{
"content": "<10842 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -6,72 +6,11 @@
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- ],
- key=['num_patches', 'K'],
-)
-@triton.jit
-def _im2col_3d(x_ptr, buf_ptr,
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- T_out, H_out, W_out, num_patches, K,
- kT, kH, kW,
- BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(num_patches, BLOCK_M)
- pid_m = pid % num_pid_m
- pid_k = pid // num_pid_m
-
- m_start = pid_m * BLOCK_M
- m_offsets = m_start + tl.arange(0, BLOCK_M)
- k_start = pid_k * BLOCK_K
- k_offsets = k_start + tl.arange(0, BLOCK_K)
-
- m_mask = m_offsets < num_patches
- k_mask = k_offsets < K
-
- _m = m_offsets[:, None]
- tom = T_out * H_out * W_out
- how = H_out * W_out
-
- _b = _m // tom
- r1 = _m - _b * tom
- _out_t = r1 // how
- r2 = r1 - _out_t * how
- _out_h = r2 // W_out
- _out_w = r2 - _out_h * W_out
-
- _k2 = k_offsets[None, :]
- prod_khw = kH * kW
- prod_tkhw = kT * prod_khw
- _c = _k2 // prod_tkhw
- rk1 = _k2 - _c * prod_tkhw
- _dt = rk1 // prod_khw
- rk2 = rk1 - _dt * prod_khw
- _dh = rk2 // kW
- _dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
- + (_out_t * kT + _dt) * stride_x_t
- + (_out_h * kH + _dh) * stride_x_h
- + (_out_w * kW + _dw) * stride_x_w)
- x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
-
- buf_ptrs = buf_ptr + _m * K + _k2
- tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
-
-
-@triton.autotune(
- configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
@@ -82,81 +21,9 @@
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _gemm_patch_5d(a_ptr, w_ptr, out_ptr,
- stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+def _patch_embed_fwd(x_ptr, w_ptr, out_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
T_out, H_out, W_out, num_patches, K, N,
- kT, kH, kW,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-):
- pid = tl.program_id(0)
- num_pid_m = tl.cdiv(num_patches, BLOCK_M)
- pid_m = pid % num_pid_m
- pid_n = pid // num_pid_m
-
- m_start = pid_m * BLOCK_M
- m_offsets = m_start + tl.arange(0, BLOCK_M)
- n_start = pid_n * BLOCK_N
- n_offsets = n_start + tl.arange(0, BLOCK_N)
-
- m_mask = m_offsets < num_patches
- n_mask = n_offsets < N
-
- _m = m_offsets[:, None]
- _n_2d = n_offsets[None, :]
-
- tom = T_out * H_out * W_out
- how = H_out * W_out
-
- _b = _m // tom
- r1 = _m - _b * tom
- _out_t = r1 // how
- r2 = r1 - _out_t * how
- _out_h = r2 // W_out
- _out_w = r2 - _out_h * W_out
-
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
- for k_start in range(0, K, BLOCK_K):
- k_offsets = k_start + tl.arange(0, BLOCK_K)
- k_mask = k_offsets < K
- mk_mask = m_mask[:, None] & k_mask[None, :]
- nk_mask = n_mask[None, :] & k_mask[:, None]
-
- a_ptrs = a_ptr + _m * K + k_offsets[None, :]
- a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
-
- w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
- w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
-
- acc += tl.dot(a.to(tl.bfloat16), w.to(tl.bfloat16))
-
- out_ptrs = (out_ptr
- + _b * stride_out_b
- + _n_2d * stride_out_d
- + _out_t * stride_out_t
- + _out_h * stride_out_h
- + _out_w * stride_out_w)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
-
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- ],
- key=['num_patches', 'N', 'K'],
-)
-@triton.jit
-def _fused_patch_embed(x_ptr, w_ptr, out_ptr,
- stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
- stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out,
- num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
... (3260 more chars){
"content": "<5033 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache4 uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache4 uv run python benchmark.py 2>&1",
"description": "Benchmark fused kernel with contiguous output",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.683 gbps=20.394 ms=0.209
shape=0 variant=compiled tflops=3.197 gbps=17.707 ms=0.241
shape=0 variant=sota tflops=3.732 gbps=20.667 ms=0.207
shape=0 variant=solution tflops=11.071 gbps=61.308 ms=0.070
shape=0 solution_peak_fraction=0.0554
shape=1 variant=eager tflops=13.954 gbps=36.394 ms=0.221
shape=1 variant=compiled tflops=12.072 gbps=31.487 ms=0.255
shape=1 variant=sota tflops=14.102 gbps=36.780 ms=0.219
shape=1 variant=solution tflops=37.913 gbps=98.884 ms=0.081
shape=1 solution_peak_fraction=0.1896
shape=2 variant=eager tflops=26.276 gbps=54.275 ms=0.264
shape=2 variant=compiled tflops=23.408 gbps=48.352 ms=0.296
shape=2 variant=sota tflops=26.897 gbps=55.558 ms=0.258
shape=2 variant=solution tflops=53.508 gbps=110.527 ms=0.130
shape=2 solution_peak_fraction=0.2675
shape=3 variant=eager tflops=6.481 gbps=25.144 ms=0.143
shape=3 variant=compiled tflops=5.239 gbps=20.324 ms=0.177
shape=3 variant=sota tflops=6.615 gbps=25.664 ms=0.140
shape=3 variant=solution tflops=26.037 gbps=101.016 ms=0.036
shape=3 solution_peak_fraction=0.1302
peak_fraction: 0.1383
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -4,28 +4,14 @@
import triton.language as tl
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=2, num_warps=8),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
- ],
- key=['num_patches', 'N', 'K'],
-)
@triton.jit
def _patch_embed_fwd(x_ptr, w_ptr, out_ptr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
T_out, H_out, W_out, num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ num_warps: tl.constexpr, num_stages: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
@@ -85,7 +71,12 @@
acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
- out_ptrs = out_ptr + _m * N + _n_2d
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@@ -117,24 +108,33 @@
M = B * T_out * H_out * W_out
N = embed_dim
- out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+ out = torch.empty(B, embed_dim, T_out, H_out, W_out,
+ dtype=torch.bfloat16, device=x.device)
+ stride_x = x.stride()
+ stride_out = out.stride()
- stride_x = x.stride()
+ if M >= 1536:
+ B_M, B_N, B_K = 128, 128, 64
+ elif M >= 512:
+ B_M, B_N, B_K = 128, 128, 64
+ else:
+ B_M, B_N, B_K = 128, 128, 64
- grid = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) *
- triton.cdiv(N, META['BLOCK_N']),
+ grid = (triton.cdiv(M, B_M) * triton.cdiv(N, B_N), 1, 1)
+ num_stages = 3
+ num_warps = 8
+
+ _patch_embed_fwd[grid](
+ x, weight, out,
+ stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
+ T_out, H_out, W_out, M, K, N,
+ kT, kH, kW,
+ BLOCK_M=B_M, BLOCK_N=B_N, BLOCK_K=B_K,
+ num_warps=num_warps, num_stages=num_stages,
)
- _patch_embed_fwd[grid](
- x, weight, out_2d,
- stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
- T_out, H_out, W_out, M, K, N,
- kT, kH, kW,
- )
-
- out = out_2d.reshape(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3)
- return out.contiguous()
+ return out
B = 1
{
"content": "<4490 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "<1006 chars \u2014 see below>",
"description": "Test manual config approach",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache5 uv run python -c "
import torch, solution, reference, shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
for seed in (42, 123, 456):
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
r = ref(x); s = sol(x)
diff = (r - s).abs().max().item()
ok = diff < 0.02
print(f'Shape {idx} seed {seed}: max_diff={diff:.6f} {\"PASS\" if ok else \"FAIL\"}')
if not ok: raise SystemExit(1)
print('ALL PASS')
" 2>&1Shape 0 seed 42: max_diff=0.007812 PASS
Shape 0 seed 123: max_diff=0.003906 PASS
Shape 0 seed 456: max_diff=0.003906 PASS
Shape 1 seed 42: max_diff=0.007812 PASS
Shape 1 seed 123: max_diff=0.007812 PASS
Shape 1 seed 456: max_diff=0.007812 PASS
Shape 2 seed 42: max_diff=0.007812 PASS
Shape 2 seed 123: max_diff=0.007812 PASS
Shape 2 seed 456: max_diff=0.007812 PASS
Shape 3 seed 42: max_diff=0.003906 PASS
Shape 3 seed 123: max_diff=0.003906 PASS
Shape 3 seed 456: max_diff=0.003906 PASS
ALL PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache5 uv run python benchmark.py 2>&1",
"description": "Benchmark manual config fused kernel",
"timeout": 300000
}shape=0 variant=eager tflops=3.691 gbps=20.441 ms=0.209
shape=0 variant=compiled tflops=3.188 gbps=17.656 ms=0.242
shape=0 variant=sota tflops=3.724 gbps=20.624 ms=0.207
shape=0 variant=solution tflops=7.622 gbps=42.208 ms=0.101
shape=0 solution_peak_fraction=0.0381
shape=1 variant=eager tflops=13.926 gbps=36.320 ms=0.221
shape=1 variant=compiled tflops=11.799 gbps=30.773 ms=0.261
shape=1 variant=sota tflops=14.028 gbps=36.587 ms=0.220
shape=1 variant=solution tflops=28.822 gbps=75.172 ms=0.107
shape=1 solution_peak_fraction=0.1441
shape=2 variant=eager tflops=26.280 gbps=54.285 ms=0.264
shape=2 variant=compiled tflops=23.008 gbps=47.526 ms=0.301
shape=2 variant=sota tflops=26.868 gbps=55.500 ms=0.258
shape=2 variant=solution tflops=56.967 gbps=117.673 ms=0.122
shape=2 solution_peak_fraction=0.2848
shape=3 variant=eager tflops=6.506 gbps=25.240 ms=0.142
shape=3 variant=compiled tflops=5.215 gbps=20.234 ms=0.177
shape=3 variant=sota tflops=6.597 gbps=25.594 ms=0.140
shape=3 variant=solution tflops=29.223 gbps=113.375 ms=0.032
shape=3 solution_peak_fraction=0.1461
peak_fraction: 0.1230
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -4,14 +4,30 @@
import triton.language as tl
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ ],
+ key=['num_patches', 'N', 'K'],
+)
@triton.jit
-def _patch_embed_fwd(x_ptr, w_ptr, out_ptr,
+def _patch_embed_fwd(
+ x_ptr, w_ptr, out_ptr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out, num_patches, K, N,
+ T_out, H_out, W_out,
+ num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
- num_warps: tl.constexpr, num_stages: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
@@ -25,6 +41,7 @@
m_mask = m_offsets < num_patches
n_mask = n_offsets < N
+ both_mask = m_mask[:, None] & n_mask[None, :]
_m = m_offsets[:, None]
_n_2d = n_offsets[None, :]
@@ -77,7 +94,7 @@
+ _out_t * stride_out_t
+ _out_h * stride_out_h
+ _out_w * stride_out_w)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=both_mask)
class Model(nn.Module):
@@ -105,33 +122,27 @@
H_out = H // kH
W_out = W // kW
K = C * kT * kH * kW
- M = B * T_out * H_out * W_out
+ num_patches = B * T_out * H_out * W_out
N = embed_dim
out = torch.empty(B, embed_dim, T_out, H_out, W_out,
dtype=torch.bfloat16, device=x.device)
+
stride_x = x.stride()
stride_out = out.stride()
- if M >= 1536:
- B_M, B_N, B_K = 128, 128, 64
- elif M >= 512:
- B_M, B_N, B_K = 128, 128, 64
- else:
- B_M, B_N, B_K = 128, 128, 64
-
- grid = (triton.cdiv(M, B_M) * triton.cdiv(N, B_N), 1, 1)
- num_stages = 3
- num_warps = 8
+ grid = lambda META: (
+ triton.cdiv(num_patches, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']),
+ )
_patch_embed_fwd[grid](
x, weight, out,
stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
stride_out[0], stride_out[1], stride_out[2], stride_out[3], stride_out[4],
- T_out, H_out, W_out, M, K, N,
+ T_out, H_out, W_out,
+ num_patches, K, N,
kT, kH, kW,
- BLOCK_M=B_M, BLOCK_N=B_N, BLOCK_K=B_K,
- num_warps=num_warps, num_stages=num_stages,
)
return out
{
"content": "<5304 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache6 uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache6 uv run python benchmark.py 2>&1",
"description": "Benchmark improved autotune configs",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.674 gbps=20.346 ms=0.210
shape=0 variant=compiled tflops=3.168 gbps=17.541 ms=0.243
shape=0 variant=sota tflops=3.725 gbps=20.627 ms=0.207
shape=0 variant=solution tflops=11.366 gbps=62.943 ms=0.068
shape=0 solution_peak_fraction=0.0568
shape=1 variant=eager tflops=13.952 gbps=36.389 ms=0.221
shape=1 variant=compiled tflops=12.104 gbps=31.568 ms=0.255
shape=1 variant=sota tflops=14.127 gbps=36.845 ms=0.218
shape=1 variant=solution tflops=41.364 gbps=107.885 ms=0.075
shape=1 solution_peak_fraction=0.2068
shape=2 variant=eager tflops=26.336 gbps=54.401 ms=0.263
shape=2 variant=compiled tflops=22.984 gbps=47.476 ms=0.302
shape=2 variant=sota tflops=26.681 gbps=55.114 ms=0.260
shape=2 variant=solution tflops=64.879 gbps=134.015 ms=0.107
shape=2 solution_peak_fraction=0.3244
shape=3 variant=eager tflops=6.487 gbps=25.166 ms=0.143
shape=3 variant=compiled tflops=5.237 gbps=20.317 ms=0.177
shape=3 variant=sota tflops=6.616 gbps=25.667 ms=0.140
shape=3 variant=solution tflops=29.749 gbps=115.417 ms=0.031
shape=3 solution_peak_fraction=0.1487
peak_fraction: 0.1543
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -2,6 +2,68 @@
import torch.nn as nn
import triton
import triton.language as tl
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ ],
+ key=['num_patches', 'K'],
+)
+@triton.jit
+def _im2col_3d(x_ptr, buf_ptr,
+ stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
+ T_out, H_out, W_out, num_patches, K,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_k = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ k_start = pid_k * BLOCK_K
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+
+ m_mask = m_offsets < num_patches
+ k_mask = k_offsets < K
+
+ _m = m_offsets[:, None]
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ _k2 = k_offsets[None, :]
+ prod_khw = kH * kW
+ prod_tkhw = kT * prod_khw
+ _c = _k2 // prod_tkhw
+ rk1 = _k2 - _c * prod_tkhw
+ _dt = rk1 // prod_khw
+ rk2 = rk1 - _dt * prod_khw
+ _dh = rk2 // kW
+ _dw = rk2 - _dh * kW
+
+ x_ptrs = (x_ptr
+ + _b * stride_x_b
+ + _c * stride_x_c
+ + (_out_t * kT + _dt) * stride_x_t
+ + (_out_h * kH + _dh) * stride_x_h
+ + (_out_w * kW + _dw) * stride_x_w)
+ x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+ buf_ptrs = buf_ptr + _m * K + _k2
+ tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
@triton.autotune(
@@ -9,19 +71,85 @@
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
],
key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _patch_embed_fwd(
- x_ptr, w_ptr, out_ptr,
+def _gemm_patch_5d(a_ptr, w_ptr, out_ptr,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+ T_out, H_out, W_out, num_patches, K, N,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_start = pid_n * BLOCK_N
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+
+ m_mask = m_offsets < num_patches
+ n_mask = n_offsets < N
+
+ _m = m_offsets[:, None]
+ _n_2d = n_offsets[None, :]
+
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, K, BLOCK_K):
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
+ mk_mask = m_mask[:, None] & k_mask[None, :]
+ nk_mask = n_mask[None, :] & k_mask[:, None]
+
+ a_ptrs = a_ptr + _m * K + k_offsets[None, :]
+ a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
+
+ w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
+ w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
+
+ acc += tl.dot(a.to(tl.bfloat16), w.to(tl.bfloat16))
+
+ out_ptrs = (out_ptr
+ + _b * stride_out_b
+ + _n_2d * stride_out_d
+ + _out_t * stride_out_t
+ + _out_h * stride_out_h
+ + _out_w * stride_out_w)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=
... (3933 more chars){
"content": "<10690 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "<1024 chars \u2014 see below>",
"description": "Manual timing of shape 0 with chosen config",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache7 uv run python -c "
import torch, solution, reference, shapes, time
device = torch.device('cuda:0')
# Test shape 0 (M=256, N=1280, K=1176) with different configs
shape = shapes.SHAPES[0]
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
torch.manual_seed(42)
x = reference.get_inputs()[0].to(device)
# Warmup
for _ in range(5):
_ = sol(x)
torch.cuda.synchronize()
# Time many iterations
N_ITERS = 200
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(N_ITERS):
_ = sol(x)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / N_ITERS
print(f'Shape 0: {ms:.4f} ms')
" 2>&1Shape 0: 0.0234 ms
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache7 uv run python benchmark.py 2>&1",
"description": "Benchmark with existing cache",
"timeout": 300000
}shape=0 variant=eager tflops=3.701 gbps=20.497 ms=0.208
shape=0 variant=compiled tflops=3.139 gbps=17.381 ms=0.246
shape=0 variant=sota tflops=3.722 gbps=20.613 ms=0.207
shape=0 variant=solution tflops=13.054 gbps=72.291 ms=0.059
shape=0 solution_peak_fraction=0.0653
shape=1 variant=eager tflops=13.931 gbps=36.333 ms=0.221
shape=1 variant=compiled tflops=12.029 gbps=31.373 ms=0.256
shape=1 variant=sota tflops=14.110 gbps=36.802 ms=0.218
shape=1 variant=solution tflops=51.545 gbps=134.438 ms=0.060
shape=1 solution_peak_fraction=0.2577
shape=2 variant=eager tflops=26.320 gbps=54.368 ms=0.264
shape=2 variant=compiled tflops=23.447 gbps=48.434 ms=0.296
shape=2 variant=sota tflops=26.890 gbps=55.544 ms=0.258
shape=2 variant=solution tflops=91.537 gbps=189.081 ms=0.076
shape=2 solution_peak_fraction=0.4577
shape=3 variant=eager tflops=6.507 gbps=25.246 ms=0.142
shape=3 variant=compiled tflops=5.215 gbps=20.234 ms=0.177
shape=3 variant=sota tflops=6.670 gbps=25.878 ms=0.139
shape=3 variant=solution tflops=31.621 gbps=122.678 ms=0.029
shape=3 solution_peak_fraction=0.1581
peak_fraction: 0.1868
RESULT: OK
{
"command": "<3058 chars \u2014 see below>",
"description": "Benchmark all configs for shape 0 GEMM",
"timeout": 300000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python -c "
import torch, triton, triton.language as tl
# Write a quick test harness to benchmark different configs for shape 0
device = torch.device('cuda:0')
M, N, K = 256, 1280, 1176
kT, kH, kW = 2, 14, 14
T_out, H_out, W_out = 1, 16, 16
x = torch.randn(1, 3, 2, 224, 224, dtype=torch.bfloat16, device=device)
w = torch.randn(N, K, dtype=torch.bfloat16, device=device)
out = torch.empty(M, N, dtype=torch.bfloat16, device=device)
stride_x = x.stride()
# I need to use a simplified kernel for this. Let me just create a test GEMM
@triton.jit
def test_gemm(a_ptr, b_ptr, c_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
m_mask = m_offsets < M
n_mask = n_offsets < N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
a = tl.load(a_ptr + m_offsets[:, None] * K + k_offsets[None, :],
mask=m_mask[:, None] & k_mask[None, :], other=0.0)
b = tl.load(b_ptr + k_offsets[:, None] + n_offsets[None, :] * K,
mask=n_mask[None, :] & k_mask[:, None], other=0.0)
acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
c = c_ptr + m_offsets[:, None] * N + n_offsets[None, :]
tl.store(c, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
c = torch.empty(M, N, dtype=torch.bfloat16, device=device)
configs = [
(128, 128, 64, 8, 3),
(64, 128, 64, 8, 4),
(128, 64, 64, 8, 4),
(256, 128, 64, 8, 3),
(64, 64, 64, 4, 4),
(128, 128, 32, 8, 4),
(64, 128, 32, 4, 4),
(128, 64, 32, 4, 4),
(256, 64, 32, 8, 4),
(64, 256, 64, 8, 3),
(128, 256, 64, 8, 3),
(256, 256, 64, 8, 2),
(32, 128, 64, 4, 4),
]
for BM, BN, BK, nw, ns in configs:
grid = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
# Warmup
for _ in range(5):
test_gemm[grid](a, b, c, M, N, K, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw, num_stages=ns)
torch.cuda.synchronize()
# Time
N_ITERS = 100
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(N_ITERS):
test_gemm[grid](a, b, c, M, N, K, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw, num_stages=ns)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / N_ITERS
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
print(f'BM={BM:3d} BN={BN:3d} BK={BK:2d} nw={nw} ns={ns}: {ms:.4f}ms {tflops:.2f} TFLOPS')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 18, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __
... (150 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py
@@ -0,0 +1,79 @@
+import torch
+import triton
+import triton.language as tl
+import time
+
+
+@triton.jit
+def test_gemm(a_ptr, b_ptr, c_ptr, M, N, K,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+ pid = tl.program_id(0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+ m_start = pid_m * BLOCK_M
+ m_offsets = m_start + tl.arange(0, BLOCK_M)
+ n_start = pid_n * BLOCK_N
+ n_offsets = n_start + tl.arange(0, BLOCK_N)
+ m_mask = m_offsets < M
+ n_mask = n_offsets < N
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k_start in range(0, K, BLOCK_K):
+ k_offsets = k_start + tl.arange(0, BLOCK_K)
+ k_mask = k_offsets < K
+ a = tl.load(a_ptr + m_offsets[:, None] * K + k_offsets[None, :],
+ mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+ b = tl.load(b_ptr + k_offsets[:, None] + n_offsets[None, :] * K,
+ mask=n_mask[None, :] & k_mask[:, None], other=0.0)
+ acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
+ c = c_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+ tl.store(c, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+
+device = torch.device('cuda:0')
+
+shapes_to_test = [
+ ("M=256,N=1280,K=1176", 256, 1280, 1176),
+ ("M=784,N=768,K=768", 784, 768, 768),
+]
+
+for name, M, N, K in shapes_to_test:
+ print(f"\n=== {name} ===")
+ a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+ b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
+ c = torch.empty(M, N, dtype=torch.bfloat16, device=device)
+
+ configs = [
+ (128, 128, 64, 8, 3),
+ (64, 128, 64, 8, 4),
+ (128, 64, 64, 8, 4),
+ (256, 128, 64, 8, 3),
+ (64, 64, 64, 4, 4),
+ (128, 128, 32, 8, 4),
+ (64, 128, 32, 4, 4),
+ (128, 64, 32, 4, 4),
+ (256, 64, 32, 8, 4),
+ (64, 256, 64, 8, 3),
+ (128, 256, 64, 8, 3),
+ (256, 256, 64, 8, 2),
+ (32, 128, 64, 4, 4),
+ ]
+
+ for BM, BN, BK, nw, ns in configs:
+ grid = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
+ for _ in range(5):
+ test_gemm[grid](a, b, c, M, N, K, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw, num_stages=ns)
+ torch.cuda.synchronize()
+
+ N_ITERS = 100
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(N_ITERS):
+ test_gemm[grid](a, b, c, M, N, K, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=nw, num_stages=ns)
+ end.record()
+ torch.cuda.synchronize()
+ ms = start.elapsed_time(end) / N_ITERS
+ tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
+ blocks = triton.cdiv(M, BM) * triton.cdiv(N, BN)
+ print(f'BM={BM:3d} BN={BN:3d} BK={BK:2d} nw={nw} ns={ns} blk={blocks:4d}: {ms:.4f}ms {tflops:.2f} TFLOPS')
{
"content": "<2894 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/bench_configs.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && uv run python bench_configs.py 2>&1",
"description": "Run config benchmarks",
"timeout": 300000
}=== M=256,N=1280,K=1176 ===
BM=128 BN=128 BK=64 nw=8 ns=3 blk= 20: 0.0323ms 23.89 TFLOPS
BM= 64 BN=128 BK=64 nw=8 ns=4 blk= 40: 0.0211ms 36.52 TFLOPS
BM=128 BN= 64 BK=64 nw=8 ns=4 blk= 40: 0.0242ms 31.90 TFLOPS
BM=256 BN=128 BK=64 nw=8 ns=3 blk= 10: 0.1271ms 6.07 TFLOPS
BM= 64 BN= 64 BK=64 nw=4 ns=4 blk= 80: 0.0197ms 39.22 TFLOPS
BM=128 BN=128 BK=32 nw=8 ns=4 blk= 20: 0.0329ms 23.39 TFLOPS
BM= 64 BN=128 BK=32 nw=4 ns=4 blk= 40: 0.0258ms 29.90 TFLOPS
BM=128 BN= 64 BK=32 nw=4 ns=4 blk= 40: 0.0261ms 29.49 TFLOPS
BM=256 BN= 64 BK=32 nw=8 ns=4 blk= 20: 0.0377ms 20.46 TFLOPS
BM= 64 BN=256 BK=64 nw=8 ns=3 blk= 20: 0.0484ms 15.91 TFLOPS
BM=128 BN=256 BK=64 nw=8 ns=3 blk= 10: 0.0891ms 8.65 TFLOPS
BM=256 BN=256 BK=64 nw=8 ns=2 blk= 5: 1.4203ms 0.54 TFLOPS
BM= 32 BN=128 BK=64 nw=4 ns=4 blk= 80: 0.0228ms 33.87 TFLOPS
=== M=784,N=768,K=768 ===
BM=128 BN=128 BK=64 nw=8 ns=3 blk= 42: 0.0130ms 71.24 TFLOPS
BM= 64 BN=128 BK=64 nw=8 ns=4 blk= 78: 0.0075ms 123.83 TFLOPS
BM=128 BN= 64 BK=64 nw=8 ns=4 blk= 84: 0.0075ms 123.96 TFLOPS
BM=256 BN=128 BK=64 nw=8 ns=3 blk= 24: 0.0244ms 37.85 TFLOPS
BM= 64 BN= 64 BK=64 nw=4 ns=4 blk= 156: 0.0060ms 153.10 TFLOPS
BM=128 BN=128 BK=32 nw=8 ns=4 blk= 42: 0.0127ms 72.60 TFLOPS
BM= 64 BN=128 BK=32 nw=4 ns=4 blk= 78: 0.0078ms 118.89 TFLOPS
BM=128 BN= 64 BK=32 nw=4 ns=4 blk= 84: 0.0079ms 116.55 TFLOPS
BM=256 BN= 64 BK=32 nw=8 ns=4 blk= 48: 0.0136ms 68.05 TFLOPS
BM= 64 BN=256 BK=64 nw=8 ns=3 blk= 39: 0.0134ms 69.25 TFLOPS
BM=128 BN=256 BK=
... (169 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"offset": 214,
"limit": 10
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py</path>
<type>file</type>
<content>
214: x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
215:
216: acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
217:
218: out_ptrs = (out_ptr
219: + _b * stride_out_b
220: + _n_2d * stride_out_d
221: + _out_t * stride_out_t
222: + _out_h * stride_out_h
223: + _out_w * stride_out_w)
(Showing lines 214-223 of 328. Use offset=224 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -263,7 +263,7 @@
# Large shapes (M >= 1024, K >= 1000): im2col + GEMM with transposed weight
# Small shapes: fused kernel with original weight
- if M >= 1024 and K >= 1000:
+ if K >= 1000:
if weight._version != self._weight_version:
self._w_t = weight.reshape(N, K).t().contiguous()
self._weight_version = weight._version
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": " if M >= 1024 and K >= 1000:",
"newString": " if K >= 1000:"
}Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache7 uv run python benchmark.py 2>&1",
"description": "Benchmark with K-only threshold",
"timeout": 300000
}shape=0 variant=eager tflops=3.701 gbps=20.497 ms=0.208
shape=0 variant=compiled tflops=3.193 gbps=17.682 ms=0.241
shape=0 variant=sota tflops=3.743 gbps=20.727 ms=0.206
shape=0 variant=solution tflops=13.898 gbps=76.962 ms=0.055
shape=0 solution_peak_fraction=0.0695
shape=1 variant=eager tflops=13.945 gbps=36.370 ms=0.221
shape=1 variant=compiled tflops=12.005 gbps=31.312 ms=0.257
shape=1 variant=sota tflops=14.069 gbps=36.694 ms=0.219
shape=1 variant=solution tflops=51.353 gbps=133.936 ms=0.060
shape=1 solution_peak_fraction=0.2568
shape=2 variant=eager tflops=26.229 gbps=54.180 ms=0.264
shape=2 variant=compiled tflops=23.436 gbps=48.410 ms=0.296
shape=2 variant=sota tflops=26.487 gbps=54.713 ms=0.262
shape=2 variant=solution tflops=91.731 gbps=189.481 ms=0.076
shape=2 solution_peak_fraction=0.4587
shape=3 variant=eager tflops=6.545 gbps=25.391 ms=0.141
shape=3 variant=compiled tflops=5.175 gbps=20.078 ms=0.179
shape=3 variant=sota tflops=6.674 gbps=25.893 ms=0.139
shape=3 variant=solution tflops=31.517 gbps=122.277 ms=0.029
shape=3 solution_peak_fraction=0.1576
peak_fraction: 0.1895
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -66,27 +66,11 @@
tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- ],
- key=['num_patches', 'N', 'K'],
-)
@triton.jit
-def _gemm_patch_5d(a_ptr, w_ptr, out_ptr,
- stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out, num_patches, K, N,
- kT, kH, kW,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-):
+def _gemm_nt(a_ptr, b_ptr, c_ptr, M, N, K,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
- num_pid_m = tl.cdiv(num_patches, BLOCK_M)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
@@ -95,21 +79,8 @@
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
- m_mask = m_offsets < num_patches
+ m_mask = m_offsets < M
n_mask = n_offsets < N
-
- _m = m_offsets[:, None]
- _n_2d = n_offsets[None, :]
-
- tom = T_out * H_out * W_out
- how = H_out * W_out
-
- _b = _m // tom
- r1 = _m - _b * tom
- _out_t = r1 // how
- r2 = r1 - _out_t * how
- _out_h = r2 // W_out
- _out_w = r2 - _out_h * W_out
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -119,21 +90,15 @@
mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
- a_ptrs = a_ptr + _m * K + k_offsets[None, :]
- a = tl.load(a_ptrs, mask=mk_mask, other=0.0)
-
- w_ptrs = w_ptr + k_offsets[:, None] * N + _n_2d
- w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
-
- acc += tl.dot(a.to(tl.bfloat16), w.to(tl.bfloat16))
-
- out_ptrs = (out_ptr
- + _b * stride_out_b
- + _n_2d * stride_out_d
- + _out_t * stride_out_t
- + _out_h * stride_out_h
- + _out_w * stride_out_w)
- tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ a = tl.load(a_ptr + m_offsets[:, None] * K + k_offsets[None, :],
+ mask=mk_mask, other=0.0)
+ b = tl.load(b_ptr + k_offsets[:, None] * N + n_offsets[None, :],
+ mask=nk_mask, other=0.0)
+
+ acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
+
+ c = c_ptr + m_offsets[:, None] * N + n_offsets[None, :]
+ tl.store(c, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@triton.autotune(
@@ -143,8 +108,9 @@
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
],
key=['num_patches', 'N', 'K'],
)
@@ -256,13 +222,6 @@
M = B * T_out * H_out * W_out
N = embed_dim
- out = torch.empty(B, embed_dim, T_out, H_out, W_out,
- dtype=torch.bfloat16, device=x.device)
- stride_x = x.stride()
- stride_out = out.stride()
-
- # Large shapes (M >= 1024, K >= 1000): im2col + GEMM with transposed weight
- # Small shapes: fused kernel with original weight
if K >= 1000:
if weight._version != self._weight_version:
self._w_t = weight.reshape(N, K).t().contiguous()
@@ -277,35 +236,38 @@
)
_im2col_3d[grid_i](
x, buf,
- stride_x[0], stride_x[1], stride_x[2], stride_x[3], stride_x[4],
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
T_out, H_out, W_out, M, K,
kT, kH, kW,
)
+
+ out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
grid_g = lambda META: (
triton.cdiv(M, META['BLOCK_M']) *
triton.cdiv(N, META['BLOCK_N']),
)
- _gemm_patch_5d[grid_g](
- buf, w_t, out,
+ _gemm_nt[grid_g](buf, w_t, out_2d, M, N, K)
+
+ out = out_2d.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
+ return out
+ else:
+ out = torch.empty(B, embed_dim, T_out, H_out, W_out,
+ dtype=torch.bfloat16, device=x.device)
+ stride_out = out.stride()
+
+ grid_f = lambda META: (
+ triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']),
+ )
+ _fused_patch_embed[grid_f](
+ x, weight, out,
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
... (732 more chars){
"content": "<9328 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "<1005 chars \u2014 see below>",
"description": "Check correctness of new GEMM",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache8 uv run python -c "
import torch, solution, reference, shapes
device = torch.device('cuda:0')
for idx, shape in enumerate(shapes.SHAPES):
for attr in ('B','C','T','H','W','kT','kH','kW','embed_dim'):
setattr(reference, attr, shape[attr])
init = reference.get_init_inputs()
ref = reference.Model(*init).to(device).eval()
sol = solution.Model(*init).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
for seed in (42, 123, 456):
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
x = reference.get_inputs()[0].to(device)
with torch.no_grad():
r = ref(x); s = sol(x)
diff = (r - s).abs().max().item()
ok = diff < 0.02
print(f'Shape {idx} seed {seed}: max_diff={diff:.6f} {\"PASS\" if ok else \"FAIL\"}')
if not ok: raise SystemExit(1)
print('ALL PASS')
" 2>&1Traceback (most recent call last):
File "<string>", line 15, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py", line 250, in forward
_gemm_nt[grid_g](buf, w_t, out_2d, M, N, K)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 710, in run
bound_args, specialization, options = binder(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: dynamic_func() missing 3 required positional arguments: 'BLOCK_M', 'BLOCK_N', and 'BLOCK_K'
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -243,11 +243,10 @@
out_2d = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
- grid_g = lambda META: (
- triton.cdiv(M, META['BLOCK_M']) *
- triton.cdiv(N, META['BLOCK_N']),
- )
- _gemm_nt[grid_g](buf, w_t, out_2d, M, N, K)
+ BM, BN, BK = (64, 256, 64) if M >= 1500 else (64, 128, 64) if M >= 500 else (64, 64, 64)
+ grid_g = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
+ _gemm_nt[grid_g](buf, w_t, out_2d, M, N, K,
+ BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=4, num_stages=4)
out = out_2d.view(B, T_out, H_out, W_out, N).permute(0, 4, 1, 2, 3).contiguous()
return out
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py",
"oldString": "<204 chars \u2014 see below>",
"newString": "<299 chars \u2014 see below>"
} grid_g = lambda META: (
triton.cdiv(M, META['BLOCK_M']) *
triton.cdiv(N, META['BLOCK_N']),
)
_gemm_nt[grid_g](buf, w_t, out_2d, M, N, K) BM, BN, BK = (64, 256, 64) if M >= 1500 else (64, 128, 64) if M >= 500 else (64, 64, 64)
grid_g = (triton.cdiv(M, BM) * triton.cdiv(N, BN),)
_gemm_nt[grid_g](buf, w_t, out_2d, M, N, K,
BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, num_warps=4, num_stages=4)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -66,6 +66,19 @@
tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ ],
+ key=['M', 'N', 'K'],
+)
@triton.jit
def _gemm_nt(a_ptr, b_ptr, c_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
@@ -73,32 +86,23 @@
num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
-
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
-
m_mask = m_offsets < M
n_mask = n_offsets < N
-
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
- mk_mask = m_mask[:, None] & k_mask[None, :]
- nk_mask = n_mask[None, :] & k_mask[:, None]
-
a = tl.load(a_ptr + m_offsets[:, None] * K + k_offsets[None, :],
- mask=mk_mask, other=0.0)
+ mask=m_mask[:, None] & k_mask[None, :], other=0.0)
b = tl.load(b_ptr + k_offsets[:, None] * N + n_offsets[None, :],
- mask=nk_mask, other=0.0)
-
+ mask=n_mask[None, :] & k_mask[:, None], other=0.0)
acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
-
- c = c_ptr + m_offsets[:, None] * N + n_offsets[None, :]
- tl.store(c, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ tl.store(c_ptr + m_offsets[:, None] * N + n_offsets[None, :],
+ acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@triton.autotune(
@@ -118,8 +122,7 @@
def _fused_patch_embed(x_ptr, w_ptr, out_ptr,
stride_x_b, stride_x_c, stride_x_t, stride_x_h, stride_x_w,
stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
- T_out, H_out, W_out,
- num_patches, K, N,
+ T_out, H_out, W_out, num_patches, K, N,
kT, kH, kW,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
@@ -127,42 +130,32 @@
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
-
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
-
m_mask = m_offsets < num_patches
n_mask = n_offsets < N
-
_m = m_offsets[:, None]
_n_2d = n_offsets[None, :]
-
tom = T_out * H_out * W_out
how = H_out * W_out
-
_b = _m // tom
r1 = _m - _b * tom
_out_t = r1 // how
r2 = r1 - _out_t * how
_out_h = r2 // W_out
_out_w = r2 - _out_h * W_out
-
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
prod_khw = kH * kW
prod_tkhw = kT * prod_khw
-
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
mk_mask = m_mask[:, None] & k_mask[None, :]
nk_mask = n_mask[None, :] & k_mask[:, None]
-
w_ptrs = w_ptr + k_offsets[:, None] + _n_2d * K
w = tl.load(w_ptrs, mask=nk_mask, other=0.0)
-
_k2 = k_offsets[None, :]
_c = _k2 // prod_tkhw
rk1 = _k2 - _c * prod_tkhw
@@ -170,23 +163,14 @@
rk2 = rk1 - _dt * prod_khw
_dh = rk2 // kW
_dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
+ x_ptrs = (x_ptr + _b * stride_x_b + _c * stride_x_c
+ (_out_t * kT + _dt) * stride_x_t
+ (_out_h * kH + _dh) * stride_x_h
+ (_out_w * kW + _dw) * stride_x_w)
x = tl.load(x_ptrs, mask=mk_mask, other=0.0)
-
acc += tl.dot(x.to(tl.bfloat16), w.to(tl.bfloat16))
-
- out_ptrs = (out_ptr
- + _b * stride_out_b
- + _n_2d * stride_out_d
- + _out_t * stride_out_t
- + _out_h * stride_out_h
- + _out_w * stride_out_w)
+ out_ptrs = (out_ptr + _b * stride_out_b + _n_2d * stride_out_d
+ + _out_t * stride_out_t + _out_h * stride_out_h + _out_w * stride_out_w)
tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@@ -204,7 +188,6 @@
conv = nn.Module()
conv.register_parameter('weight', w)
self.conv = conv
-
K = C * kT * kH * kW
self._w_t = w.reshape(embed_dim, K).t().contiguous()
self._weight_version = w._version
@@ -214,7 +197,6 @@
kT, kH, kW = self.kT, self.kH, self.kW
embed_dim = self.embed_dim
weight = self.conv.weight
-
T_out = T // kT
H_out = H // kH
W_out = W // kW
@@ -227,44 +209,34 @@
self._w_t = weight.reshape(N, K).t().contiguous()
self._weight_version = weight._version
w_t = self._w_t
-
buf = torch.empty(M, K, dtype=torch.bfloat16, device=x.device)
-
grid_i = lambda META: (
-
... (2044 more chars){
"content": "<9962 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm && TRITON_CACHE_DIR=/tmp/opencode/triton_cache9 uv run python check.py 2>&1 && TRITON_CACHE_DIR=/tmp/opencode/triton_cache9 uv run python benchmark.py 2>&1",
"description": "Check and benchmark final version",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=3.683 gbps=20.395 ms=0.209
shape=0 variant=compiled tflops=3.117 gbps=17.262 ms=0.247
shape=0 variant=sota tflops=3.722 gbps=20.611 ms=0.207
shape=0 variant=solution tflops=14.945 gbps=82.765 ms=0.052
shape=0 solution_peak_fraction=0.0747
shape=1 variant=eager tflops=13.950 gbps=36.383 ms=0.221
shape=1 variant=compiled tflops=12.009 gbps=31.322 ms=0.257
shape=1 variant=sota tflops=14.129 gbps=36.850 ms=0.218
shape=1 variant=solution tflops=44.232 gbps=115.365 ms=0.070
shape=1 solution_peak_fraction=0.2212
shape=2 variant=eager tflops=26.319 gbps=54.364 ms=0.264
shape=2 variant=compiled tflops=23.528 gbps=48.599 ms=0.295
shape=2 variant=sota tflops=26.437 gbps=54.610 ms=0.262
shape=2 variant=solution tflops=72.556 gbps=149.872 ms=0.096
shape=2 solution_peak_fraction=0.3628
shape=3 variant=eager tflops=6.468 gbps=25.093 ms=0.143
shape=3 variant=compiled tflops=5.214 gbps=20.227 ms=0.177
shape=3 variant=sota tflops=6.609 gbps=25.641 ms=0.140
shape=3 variant=solution tflops=29.297 gbps=113.662 ms=0.032
shape=3 solution_peak_fraction=0.1465
peak_fraction: 0.1721
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py
@@ -24,26 +24,21 @@
num_pid_m = tl.cdiv(num_patches, BLOCK_M)
pid_m = pid % num_pid_m
pid_k = pid // num_pid_m
-
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
k_start = pid_k * BLOCK_K
k_offsets = k_start + tl.arange(0, BLOCK_K)
-
m_mask = m_offsets < num_patches
k_mask = k_offsets < K
-
_m = m_offsets[:, None]
tom = T_out * H_out * W_out
how = H_out * W_out
-
_b = _m // tom
r1 = _m - _b * tom
_out_t = r1 // how
r2 = r1 - _out_t * how
_out_h = r2 // W_out
_out_w = r2 - _out_h * W_out
-
_k2 = k_offsets[None, :]
prod_khw = kH * kW
prod_tkhw = kT * prod_khw
@@ -53,68 +48,79 @@
rk2 = rk1 - _dt * prod_khw
_dh = rk2 // kW
_dw = rk2 - _dh * kW
-
- x_ptrs = (x_ptr
- + _b * stride_x_b
- + _c * stride_x_c
+ x_ptrs = (x_ptr + _b * stride_x_b + _c * stride_x_c
+ (_out_t * kT + _dt) * stride_x_t
+ (_out_h * kH + _dh) * stride_x_h
+ (_out_w * kW + _dw) * stride_x_w)
x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
-
buf_ptrs = buf_ptr + _m * K + _k2
tl.store(buf_ptrs, x.to(tl.bfloat16), mask=m_mask[:, None] & k_mask[None, :])
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
],
- key=['M', 'N', 'K'],
+ key=['num_patches', 'N', 'K'],
)
@triton.jit
-def _gemm_nt(a_ptr, b_ptr, c_ptr, M, N, K,
- BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
+def _gemm_patch_5d(a_ptr, w_ptr, out_ptr,
+ stride_out_b, stride_out_d, stride_out_t, stride_out_h, stride_out_w,
+ T_out, H_out, W_out, num_patches, K, N,
+ kT, kH, kW,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
pid = tl.program_id(0)
- num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_m = tl.cdiv(num_patches, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
m_start = pid_m * BLOCK_M
m_offsets = m_start + tl.arange(0, BLOCK_M)
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
- m_mask = m_offsets < M
+ m_mask = m_offsets < num_patches
n_mask = n_offsets < N
+ _m = m_offsets[:, None]
+ _n_2d = n_offsets[None, :]
+ tom = T_out * H_out * W_out
+ how = H_out * W_out
+ _b = _m // tom
+ r1 = _m - _b * tom
+ _out_t = r1 // how
+ r2 = r1 - _out_t * how
+ _out_h = r2 // W_out
+ _out_w = r2 - _out_h * W_out
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
- a = tl.load(a_ptr + m_offsets[:, None] * K + k_offsets[None, :],
- mask=m_mask[:, None] & k_mask[None, :], other=0.0)
- b = tl.load(b_ptr + k_offsets[:, None] * N + n_offsets[None, :],
- mask=n_mask[None, :] & k_mask[:, None], other=0.0)
- acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
- tl.store(c_ptr + m_offsets[:, None] * N + n_offsets[None, :],
- acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+ mk_mask = m_mask[:, None] & k_mask[None, :]
+ nk_mask = n_mask[None, :] & k_mask[:, None]
+ a = tl.load(a_ptr + _m * K + k_offsets[None, :], mask=mk_mask, other=0.0)
+ w = tl.load(w_ptr + k_offsets[:, None] * N + _n_2d, mask=nk_mask, other=0.0)
+ acc += tl.dot(a.to(tl.bfloat16), w.to(tl.bfloat16))
+ out_ptrs = (out_ptr + _b * stride_out_b + _n_2d * stride_out_d
+ + _out_t * stride_out_t + _out_h * stride_out_h + _out_w * stride_out_w)
+ tl.store(out_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
@triton.autotune(
configs=[
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, nu
... (4542 more chars){
"content": "<10499 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/10_patch_embed_conv3d_gemm/solution.py"
}Wrote file successfully.