{
"problem": "10_patch_embed_conv3d_gemm",
"harness": "droid",
"model": "custom:GLM-5.1-[Z.AI-Coding-Plan]-0",
"reasoning_effort": "",
"has_solution": false,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 26538, "output_tokens": 80156, "cache_read_tokens": 131072, "cache_creation_tokens": 0, "reasoning_tokens": 0, "total_cost_usd": null}
}
user
I need you to optimize a 3D patch-embedding kernel for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
The op is the patch-embedding step that opens every modern Vision-Transformer / video-language model: a bf16 video tensor x of shape (B, C, T, H, W) is split into non-overlapping (kT, kH, kW) patches and each patch is projected to embed_dim. Mathematically this is a 3D convolution with stride equal to the kernel, equivalently a single (num_patches, C*kT*kH*kW) by (C*kT*kH*kW, embed_dim) GEMM after a strided gather. Output is (B, embed_dim, T/kT, H/kH, W/kW) bf16. The Model has a Conv3d weight registered as a parameter — your solution must declare it identically so state_dict loading works. Correctness tolerance on the bf16 output is 1e-2 abs/rel.
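The conv-as-GEMM equivalence above comes down to an index mapping, sketched here in pure Python. The helper names (`gather_offset`, `patch_embed_gemm`, `patch_embed_direct`) are hypothetical and not part of reference.py. The sketch assumes x is contiguous in (B, C, T, H, W) order, GEMM rows enumerate patches in (b, t, h, w) order, and GEMM columns follow the flattened Conv3d weight order (c, dt, dh, dw). It checks the mapping on a tiny integer tensor and is not a performance kernel.

```python
from itertools import product

def gather_offset(m, k, shape, kernel):
    """Flat offset into contiguous x (B, C, T, H, W) for GEMM element A[m, k]."""
    B, C, T, H, W = shape
    kT, kH, kW = kernel
    nT, nH, nW = T // kT, H // kH, W // kW
    # Row index m enumerates patches in (b, pt, ph, pw) order.
    m, pw = divmod(m, nW)
    m, ph = divmod(m, nH)
    b, pt = divmod(m, nT)
    # Column index k enumerates in-patch elements in (c, dt, dh, dw) order,
    # matching one flattened row of a (embed_dim, C, kT, kH, kW) weight.
    k, dw = divmod(k, kW)
    k, dh = divmod(k, kH)
    c, dt = divmod(k, kT)
    t, h, w = pt * kT + dt, ph * kH + dh, pw * kW + dw
    return (((b * C + c) * T + t) * H + h) * W + w

def patch_embed_gemm(x, wgt, shape, kernel, embed_dim):
    """(M, N) result of the gathered GEMM; x and wgt are flat lists."""
    B, C, T, H, W = shape
    kT, kH, kW = kernel
    M = B * (T // kT) * (H // kH) * (W // kW)
    K = C * kT * kH * kW
    return [[sum(x[gather_offset(m, k, shape, kernel)] * wgt[n * K + k]
                 for k in range(K))
             for n in range(embed_dim)]
            for m in range(M)]

def patch_embed_direct(x, wgt, shape, kernel, embed_dim):
    """Same result via explicit strided-convolution loops, for comparison."""
    B, C, T, H, W = shape
    kT, kH, kW = kernel
    out = []
    for b, pt, ph, pw in product(range(B), range(T // kT),
                                 range(H // kH), range(W // kW)):
        row = []
        for n in range(embed_dim):
            acc = 0
            for c, dt, dh, dw in product(range(C), range(kT),
                                         range(kH), range(kW)):
                xi = (((b * C + c) * T + pt * kT + dt) * H
                      + ph * kH + dh) * W + pw * kW + dw
                wi = (((n * C + c) * kT + dt) * kH + dh) * kW + dw
                acc += x[xi] * wgt[wi]
            row.append(acc)
        out.append(row)
    return out

# Tiny made-up problem with integer data so the comparison is exact.
shape, kernel, embed_dim = (2, 2, 4, 4, 4), (2, 2, 2), 3
x = [(i * 7) % 11 - 5 for i in range(2 * 2 * 4 * 4 * 4)]
wgt = [(i * 5) % 7 - 3 for i in range(embed_dim * 2 * 2 * 2 * 2)]
same = (patch_embed_gemm(x, wgt, shape, kernel, embed_dim)
        == patch_embed_direct(x, wgt, shape, kernel, embed_dim))
print("gathered GEMM matches direct conv:", same)
```

The GEMM result here is laid out as (num_patches, embed_dim), so a kernel writing the stated (B, embed_dim, T/kT, H/kH, W/kW) output would still need to transpose on store.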
The shapes you have to handle are B=1 C=3 T=2 H=224 W=224 kT=2 kH=14 kW=14 embed_dim=1280 (Qwen2-VL ViT base, the canonical case), B=2 C=3 T=4 H=224 W=224 kT=2 kH=14 kW=14 embed_dim=1280 (batch 2, 4-frame video), B=1 C=3 T=8 H=336 W=336 kT=2 kH=14 kW=14 embed_dim=1280 (larger spatial, 8 frames), and B=4 C=3 T=1 H=224 W=224 kT=1 kH=16 kW=16 embed_dim=768 (image-mode ViT-B/16). All inputs are aligned to the patch size — no fractional-patch tails to predicate.
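For reference, the GEMM dimensions implied by these four shapes can be worked out with a few lines of arithmetic. The helper name `gemm_dims` is hypothetical; M is the number of patches across the batch, K the number of elements per patch, and N the embedding width.

```python
def gemm_dims(B, C, T, H, W, kT, kH, kW, embed_dim):
    """Return (M, K, N) for the patch-embedding GEMM."""
    # The prompt states inputs are aligned to the patch size.
    assert T % kT == 0 and H % kH == 0 and W % kW == 0
    M = B * (T // kT) * (H // kH) * (W // kW)  # number of patches
    K = C * kT * kH * kW                       # elements per patch
    return M, K, embed_dim

SHAPES = [
    dict(B=1, C=3, T=2, H=224, W=224, kT=2, kH=14, kW=14, embed_dim=1280),
    dict(B=2, C=3, T=4, H=224, W=224, kT=2, kH=14, kW=14, embed_dim=1280),
    dict(B=1, C=3, T=8, H=336, W=336, kT=2, kH=14, kW=14, embed_dim=1280),
    dict(B=4, C=3, T=1, H=224, W=224, kT=1, kH=16, kW=16, embed_dim=768),
]

for s in SHAPES:
    M, K, N = gemm_dims(**s)
    print(f"M={M:5d}  K={K:5d}  N={N:5d}  K % 16 = {K % 16}")
```

The three kT=2, kH=kW=14 shapes share K = 1176, which leaves a remainder of 8 when divided by 16, so a K-loop built on 16-wide MMA steps would need to pad or handle that tail; the image-mode shape has K = 768, which divides evenly.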
This needs to be a real custom kernel. Don't reach for torch.nn.Conv3d, torch.nn.functional.conv3d, F.conv3d, or torch.conv3d — they're off-limits and using them fails correctness. Don't take the lazy reshape-then-cuBLAS shortcut either: torch.matmul, torch.bmm, torch.nn.functional.linear, F.linear, torch.einsum, torch.nn.functional.unfold, and F.unfold are all banned. Try CUDA C++ via torch.utils.cpp_extension.load_inline, CUTLASS / CuTe (which has good support for strided patch loads), Triton with tl.dot, inline PTX with mma.sync, or whatever fits. The work is compute-bound at embed_dim=1280, so wins come from getting tensor cores busy with bf16 MMAs while fusing the strided patch gather into the K-loop. Anything you're uncertain about — im2col layout for 3D, MMA tile shapes for SM120, CUTLASS Conv-as-GEMM examples — look up PTX docs, browse CUTLASS, read library source, and investigate.
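As a rough sanity check on the compute-bound claim, the arithmetic intensity of each GEMM can be estimated as below. This is a back-of-the-envelope sketch that assumes every bf16 operand (input, weight, output) crosses DRAM exactly once and ignores caches; the helper name `traffic_and_intensity` is hypothetical. The 1.8 TB/s bandwidth is the figure quoted in the task; no peak-FLOP number is given there, so the sketch stops at the memory-time floor and makes no roofline comparison.

```python
BYTES_BF16 = 2
DRAM_BW = 1.8e12  # bytes/s, the 1.8 TB/s quoted in the task description

def traffic_and_intensity(M, K, N):
    """FLOPs, ideal DRAM bytes, and FLOP/byte for an (M,K) x (K,N) bf16 GEMM."""
    flops = 2 * M * K * N  # one multiply and one add per MAC
    # Non-overlapping patches cover x exactly once, so x has M*K elements.
    bytes_moved = BYTES_BF16 * (M * K + K * N + M * N)
    return flops, bytes_moved, flops / bytes_moved

# (M, K, N) derived from the four shapes listed in the task.
PROBLEMS = [(256, 1176, 1280), (1024, 1176, 1280),
            (2304, 1176, 1280), (784, 768, 768)]

for M, K, N in PROBLEMS:
    flops, nbytes, ai = traffic_and_intensity(M, K, N)
    floor_us = nbytes / DRAM_BW * 1e6  # time to stream the bytes once
    print(f"M={M:5d} K={K:5d} N={N:5d}  {flops / 1e9:6.2f} GFLOP  "
          f"{nbytes / 1e6:6.2f} MB  {ai:6.1f} FLOP/B  "
          f"memory floor {floor_us:5.2f} us")
```

Intensity grows with M, so the smallest shape (M = 256) is the one closest to being limited by memory traffic or launch overhead rather than by tensor-core throughput.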
Your flywheel is: implement, profile (ncu, nsys, torch.profiler, whatever's useful), time it with benchmark.py, verify correctness by running `python check.py` and reading the output, then iterate. Don't substitute your own one-off correctness snippets for check.py: it iterates over every shape across multiple seeds, and your spot-check almost certainly won't. If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up.