{
"problem": "09_fmha_preattn_mrope",
"harness": "droid",
"model": "custom:GLM-5.1-[Z.AI-Coding-Plan]-0",
"reasoning_effort": "",
"has_solution": false,
"correct": false,
"peak_fraction": null,
"elapsed_seconds": 2701,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 10092, "output_tokens": 74135, "cache_read_tokens": 78976, "cache_creation_tokens": 0, "reasoning_tokens": 0, "total_cost_usd": null}
}
user
I need you to optimize the multi-axis RoPE pre-attention prep step for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
The op is the pre-attention prep used in Qwen2-VL-style vision-language models: take q (B, T, Hq, D) and k (B, T, Hkv, D) bf16 plus three int64 position-id tensors pos_t, pos_h, pos_w of shape (B, T) — temporal, height, width axes — gather per-axis cos/sin from a precomputed (max_pos, D) bf16 cache, mix the per-axis values across the head_dim according to mrope_section (a tuple [s0, s1, s2] summing to D/2 — the first s0 pair-indices use the temporal axis, next s1 use height, last s2 use width, mirrored into the second half of D), then apply standard rotate-half rotation (out = x*cos + rotate_half(x)*sin) on q and k. Return q_rot and k_rot in (B, H, T, D) layout — the transpose from (B, T, H, D) to (B, H, T, D) is part of the prep. Correctness tolerance on the bf16 outputs is 1e-2 abs/rel. Both q_rot and k_rot must match the reference within that.
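The per-axis mix and the rotation described above can be sketched in plain Python for a single head vector. This is only an illustration, not reference.py: the helper names (`mrope_axis_map`, `rotate_half`, `apply_mrope_one_vector`) are hypothetical, and the rotate-half convention (negated second half first, then the first half) is assumed to match the Hugging Face one.

```python
def mrope_axis_map(mrope_section, head_dim):
    """For each head_dim index, return which axis supplies cos/sin (0=t, 1=h, 2=w)."""
    half = head_dim // 2
    assert sum(mrope_section) == half, "mrope_section must sum to D/2"
    first_half = []
    for axis, width in enumerate(mrope_section):
        first_half.extend([axis] * width)
    # The same per-pair axis assignment is mirrored into the second half of D.
    return first_half + first_half


def rotate_half(x):
    """Assumed convention: (x1, x2) -> (-x2, x1), with x1/x2 the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_mrope_one_vector(x, cos_rows, sin_rows, mrope_section):
    """Rotate one head vector x of length D.

    cos_rows / sin_rows hold three length-D rows, already gathered from the
    cache at pos_t, pos_h and pos_w for this token (in that order).
    """
    d = len(x)
    axis = mrope_axis_map(mrope_section, d)
    cos = [cos_rows[axis[i]][i] for i in range(d)]
    sin = [sin_rows[axis[i]][i] for i in range(d)]
    rot = rotate_half(x)
    return [x[i] * cos[i] + rot[i] * sin[i] for i in range(d)]
```

A real kernel would do the same selection per element, but on bf16 data and for every (b, t, h) at once; the sketch is only meant to pin down which cache row each index reads.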
The shapes you have to handle are B=1 T=4096 Hq=32 Hkv=8 D=128 mrope_section=(16,24,24) max_pos=32768 (Qwen2-VL base), B=1 T=8192 Hq=28 Hkv=4 D=128 mrope_section=(16,24,24) max_pos=32768 (long-context GQA), B=2 T=2048 Hq=16 Hkv=2 D=64 mrope_section=(8,12,12) max_pos=16384 (smaller head_dim, batch 2), and B=1 T=16384 Hq=32 Hkv=8 D=128 mrope_section=(16,24,24) max_pos=65536 (very long context).
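For a rough sense of scale, the minimum memory traffic and the corresponding time floor for each of these shapes can be estimated as below. The function names are hypothetical, and the model is an assumption: q and k are read once and written once, the three position-id tensors are read once, and each token reads only D cos and D sin values in total, with no cache reuse counted. The 1.8 TB/s figure is the peak quoted above, so the result is an idealized bound and not a prediction.

```python
def mrope_min_traffic_bytes(B, T, Hq, Hkv, D, elem_bytes=2, pos_bytes=8):
    """Estimate the minimum bytes moved under the assumptions stated above."""
    qk_elems = B * T * (Hq + Hkv) * D
    qk_bytes = 2 * qk_elems * elem_bytes          # read once + write once (bf16)
    pos_total = 3 * B * T * pos_bytes             # pos_t, pos_h, pos_w (int64)
    cos_sin_bytes = B * T * 2 * D * elem_bytes    # D cos + D sin per token (assumed)
    return qk_bytes + pos_total + cos_sin_bytes


def lower_bound_seconds(total_bytes, peak_bytes_per_s=1.8e12):
    """Idealized time floor at the quoted peak bandwidth."""
    return total_bytes / peak_bytes_per_s


# The four shapes listed above, as (B, T, Hq, Hkv, D).
SHAPES = [
    (1, 4096, 32, 8, 128),
    (1, 8192, 28, 4, 128),
    (2, 2048, 16, 2, 64),
    (1, 16384, 32, 8, 128),
]

for shape in SHAPES:
    total = mrope_min_traffic_bytes(*shape)
    micros = lower_bound_seconds(total) * 1e6
    print(f"{shape}: {total / 1e6:.1f} MB, >= {micros:.1f} us")
```

Comparing a measured kernel time against this floor gives a quick estimate of the fraction of peak bandwidth achieved.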
This needs to be a real custom kernel. Don't reach for transformers' apply_multimodal_rotary_pos_emb, flash_attn.layers.rotary, flashinfer.rope, or flashinfer.apply_rope — they're off-limits and using them fails correctness. Try CUDA C++ via torch.utils.cpp_extension.load_inline, Triton, inline PTX, or whatever fits. The work is bandwidth-bound, so wins come from fusing the gather + per-axis mix + rotate + transpose into one kernel that streams q/k once and emits the rotated (B,H,T,D) layout directly. Anything you're uncertain about — RoPE math, mRoPE per-axis slicing, layout conventions — look up PTX docs, read the transformers Qwen2-VL implementation, browse FlashInfer / flash-attn rotary headers, and investigate.
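Emitting the (B,H,T,D) layout directly comes down to index arithmetic: each element is read at its (B,T,H,D) offset and stored at the transposed offset, so no separate transpose pass is needed. A minimal sketch of that mapping, assuming contiguous row-major tensors and using hypothetical helper names:

```python
def in_offset(b, t, h, d, T, H, D):
    """Flat offset of element (b, t, h, d) in a contiguous (B, T, H, D) tensor."""
    return ((b * T + t) * H + h) * D + d


def out_offset(b, t, h, d, T, H, D):
    """Flat offset of the same element in a contiguous (B, H, T, D) tensor."""
    return ((b * H + h) * T + t) * D + d


def fused_transpose_copy(src, B, T, H, D):
    """Copy src (flat, BTHD order) into a new flat list in BHTD order.

    A fused kernel would apply the rotation between the load and the store;
    here the value is copied unchanged to isolate the layout change.
    """
    dst = [None] * (B * T * H * D)
    for b in range(B):
        for t in range(T):
            for h in range(H):
                for d in range(D):
                    dst[out_offset(b, t, h, d, T, H, D)] = src[in_offset(b, t, h, d, T, H, D)]
    return dst
```

Because the innermost D index is unchanged by the transpose, contiguous runs of D elements stay contiguous on both the load and the store side, which is what makes coalesced access possible in both layouts.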
Your flywheel is: implement; profile (ncu, nsys, torch.profiler, whatever's useful) and time it with benchmark.py; verify correctness by running `python check.py` and reading the output; then iterate. Don't substitute your own one-off correctness snippets for check.py: it iterates over every shape across multiple seeds, and your spot-check almost certainly won't. If `python check.py` hasn't printed PASS, you're not done. Take as long as you need to actually push the number up.