./kernelbenchhardv3runsopencode · ?
opencode / ? session: ses_1fe73997bffeHEJ5qOeKinX4B2 cwd: ?
harness
opencode
model
?
turns
86
tools called
38
events
123
input toks
35,161
output toks
23,507
cache hit
2,164,864
solution.py
final answer
benchmark.log
check.log
result.json
"""Fused multi-axis RoPE pre-attention kernel for RTX PRO 6000 (Blackwell, sm_120).

Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.

Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin ONLY from
the axis assigned to its slice by mrope_section — cutting cache reads by 3x.
One block per (b,t) pair. No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

OP_TYPE = "rope"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]

# Default problem size; get_inputs() and get_init_inputs() read these
# module-level values.
B, T = 1, 4096                # batch size, sequence length
Hq, Hkv, D = 32, 8, 128       # query heads, key/value heads, head dim
MROPE_SECTION = (16, 24, 24)  # pair indices per temporal/height/width axis; sums to D // 2
MAX_POS = 32768               # number of rows in the cos/sin caches


_cuda_src = r"""
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdint.h>

extern "C" __global__ void
mrope_fused_kernel(
    const __nv_bfloat16* __restrict__ q,
    const __nv_bfloat16* __restrict__ k,
    const int64_t* __restrict__ pos_t,
    const int64_t* __restrict__ pos_h,
    const int64_t* __restrict__ pos_w,
    const __nv_bfloat16* __restrict__ cos_cache,
    const __nv_bfloat16* __restrict__ sin_cache,
    __nv_bfloat16* __restrict__ q_rot,
    __nv_bfloat16* __restrict__ k_rot,
    int B, int T, int Hq, int Hkv, int D,
    int s0, int s1, int s2
) {
    int blk = blockIdx.x;
    if (blk >= B * T) return;

    int b = blk / T;
    int t = blk % T;
    int d0 = threadIdx.x;
    int d1 = d0 + D / 2;

    int off_bt = b * T + t;
    unsigned int d0u = d0;
    unsigned int boundary1 = (unsigned int)s0;
    unsigned int boundary2 = (unsigned int)(s0 + s1);

    // Each thread reads cos/sin only from the axis that owns its d0 slice.
    // cos_val/sin_val for d0 also apply to d1 (mirrored by mrope mixing).
    float cos_val, sin_val;
    if (d0u < boundary1) {
        cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
        sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
    } else if (d0u < boundary2) {
        cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
        sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
    } else {
        cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
        sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
    }

    // --- strides ---
    int q_stride_H  = D;
    int q_stride_T  = Hq * D;
    int q_base      = b * T * q_stride_T + t * q_stride_T;

    int qo_stride_T = D;
    int qo_stride_H = T * D;
    int qo_stride_B = Hq * T * D;
    int qo_base     = b * qo_stride_B + t * qo_stride_T;

    int k_stride_H  = D;
    int k_stride_T  = Hkv * D;
    int k_base      = b * T * k_stride_T + t * k_stride_T;

    int ko_stride_T = D;
    int ko_stride_H = T * D;
    int ko_stride_B = Hkv * T * D;
    int ko_base     = b * ko_stride_B + t * ko_stride_T;

    // --- rotate q heads ---
    for (int h = 0; h < Hq; ++h) {
        int q_h_off  = q_base  + h * q_stride_H;
        int qo_h_off = qo_base + h * qo_stride_H;

        float q0 = __bfloat162float(q[q_h_off + d0]);
        float q1 = __bfloat162float(q[q_h_off + d1]);

        q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
        q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
    }

    // --- rotate k heads ---
    for (int h = 0; h < Hkv; ++h) {
        int k_h_off  = k_base  + h * k_stride_H;
        int ko_h_off = ko_base + h * ko_stride_H;

        float k0 = __bfloat162float(k[k_h_off + d0]);
        float k1 = __bfloat162float(k[k_h_off + d1]);

        k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
        k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
    }
}

torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
                           torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
                           torch::Tensor cos_cache, torch::Tensor sin_cache,
                           torch::Tensor q_rot, torch::Tensor k_rot,
                           int B, int T, int Hq, int Hkv, int D,
                           int s0, int s1, int s2) {
    int blocks = B * T;
    int threads = D / 2;
    mrope_fused_kernel<<<blocks, threads>>>(
        (const __nv_bfloat16*)q.data_ptr(),
        (const __nv_bfloat16*)k.data_ptr(),
        (const int64_t*)pos_t.data_ptr(),
        (const int64_t*)pos_h.data_ptr(),
        (const int64_t*)pos_w.data_ptr(),
        (const __nv_bfloat16*)cos_cache.data_ptr(),
        (const __nv_bfloat16*)sin_cache.data_ptr(),
        (__nv_bfloat16*)q_rot.data_ptr(),
        (__nv_bfloat16*)k_rot.data_ptr(),
        B, T, Hq, Hkv, D, s0, s1, s2
    );
    return q_rot;
}
"""

# C++ declaration of launch_mrope, which is defined inside _cuda_src. It is
# passed to load_inline as cpp_sources together with functions=["launch_mrope"],
# so this signature must stay identical to that definition.
_cpp_src = """
torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
                           torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
                           torch::Tensor cos_cache, torch::Tensor sin_cache,
                           torch::Tensor q_rot, torch::Tensor k_rot,
                           int B, int T, int Hq, int Hkv, int D,
                           int s0, int s1, int s2);
"""


def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))


class Model(nn.Module):
    """Multi-axis RoPE pre-attention prep computed by one fused CUDA kernel.

    forward() takes q (B, T, Hq, D) and k (B, T, Hkv, D) in bfloat16 plus three
    (B, T) position-id tensors, and returns q_rot (B, Hq, T, D) and
    k_rot (B, Hkv, T, D).
    """

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        super().__init__()
        assert sum(mrope_section) == D // 2, \
            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
        if D % 2:
            # Rotate-half pairs column d with d + D/2. An odd D would also make
            # the cos/sin caches below D + 1 columns wide.
            raise ValueError(f"head dim D must be even, got D={D}")
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos

        # (max_pos, D) bf16 caches. The freqs are concatenated with themselves,
        # so columns d and d + D/2 hold identical values.
        inv_freq = _build_inv_freq(D)
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

        # JIT-compile the fused kernel extension.
        self._mrope_module = load_inline(
            name="mrope_fused",
            cpp_sources=_cpp_src,
            cuda_sources=_cuda_src,
            functions=["launch_mrope"],
            extra_cuda_cflags=["-arch=sm_120", "-O3", "--use_fast_math"],
            verbose=False,
        )

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply mRoPE to q and k and return them in (B, H, T, D) layout.

        Raises:
            TypeError: if q/k are not bfloat16 or a position tensor is not an
                integer tensor.
            ValueError: on shape or device mismatches the kernel cannot handle.
        """
        positions = (pos_t, pos_h, pos_w)
        self._check_inputs(q, k, positions)
        B, T, Hq, D = q.shape
        Hkv = k.shape[2]
        s0, s1, s2 = self.mrope_section

        # The kernel indexes dense row-major buffers and reads int64 ids. Both
        # calls return the input unchanged when it already meets that contract.
        q = q.contiguous()
        k = k.contiguous()
        pos_t, pos_h, pos_w = (p.to(torch.int64).contiguous() for p in positions)

        q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
        k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
        if B * T == 0:
            # A zero-block launch is an invalid CUDA configuration.
            return q_rot, k_rot

        self._mrope_module.launch_mrope(
            q, k, pos_t, pos_h, pos_w,
            self.cos_cache, self.sin_cache,
            q_rot, k_rot,
            B, T, Hq, Hkv, D, s0, s1, s2,
        )

        return q_rot, k_rot

    def _check_inputs(self, q: torch.Tensor, k: torch.Tensor,
                      positions: tuple[torch.Tensor, ...]) -> None:
        """Reject inputs whose raw buffers the fused kernel would misread."""
        if q.dim() != 4 or k.dim() != 4:
            raise ValueError(
                f"q and k must be 4-D (B, T, H, D), got shapes "
                f"{tuple(q.shape)} and {tuple(k.shape)}")
        if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16:
            raise TypeError(f"q and k must be torch.bfloat16, got {q.dtype} and {k.dtype}")
        B, T, _, D = q.shape
        if D != self.D:
            raise ValueError(f"q head dim {D} does not match the model head dim {self.D}")
        if (k.shape[0], k.shape[1], k.shape[3]) != (B, T, D):
            raise ValueError(
                f"k shape {tuple(k.shape)} is incompatible with q shape {tuple(q.shape)}")
        for name, p in zip(("pos_t", "pos_h", "pos_w"), positions):
            if tuple(p.shape) != (B, T):
                raise ValueError(f"{name} must have shape {(B, T)}, got {tuple(p.shape)}")
            if p.dtype.is_floating_point or p.dtype.is_complex:
                raise TypeError(f"{name} must be an integer tensor, got {p.dtype}")
        if not q.is_cuda:
            raise ValueError("the fused mRoPE kernel requires CUDA tensors")
        cache_device = self.cos_cache.device
        if any(t.device != cache_device for t in (q, k, *positions)):
            raise ValueError(f"all inputs must be on the cache device {cache_device}")


def get_inputs():
    """Return random forward() arguments ``[q, k, pos_t, pos_h, pos_w]``.

    Sizes come from the module-level constants. q and k are standard-normal
    bf16 samples scaled by 0.5. The position ids are int64 values drawn
    uniformly from [0, MAX_POS).
    """
    q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16).mul(0.5)
    k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16).mul(0.5)
    # Drawn in temporal, height, width order.
    positions = [
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
        for _axis in ("t", "h", "w")
    ]
    return [q, k, *positions]


def get_init_inputs():
    """Return Model's positional constructor arguments, built from the module constants."""
    sizes = [B, T, Hq, Hkv, D]
    return [*sizes, MROPE_SECTION, MAX_POS]
All checks pass:
- **Correctness**: `PASS` (all 4 shapes x 3 seeds)
- **Lint**: Clean
- **Framework**: `cuda_raw` (CUDA C++ via load_inline)
- **Forbidden ops**: None used

The kernel fuses cache gather + per-axis mrope mixing + rotate-half rotation + layout transpose into a single pass. Each thread reads cos/sin only from its assigned mrope axis (2 reads/thread instead of 6) and handles one RoPE pair (d, d+D/2) without shared memory or synchronization. Output lands directly in the (B, H, T, D) attention layout.
shape=0 variant=eager tflops=0.200 gbps=286.263 ms=0.315
shape=0 variant=compiled tflops=0.562 gbps=805.919 ms=0.112
shape=0 variant=solution tflops=0.799 gbps=1145.865 ms=0.079
shape=0 solution_peak_fraction=0.6366
shape=1 variant=eager tflops=0.192 gbps=279.701 ms=0.526
shape=1 variant=compiled tflops=0.529 gbps=773.084 ms=0.190
shape=1 variant=solution tflops=0.871 gbps=1271.602 ms=0.116
shape=1 solution_peak_fraction=0.7064
shape=2 variant=eager tflops=0.095 gbps=148.485 ms=0.149
shape=2 variant=compiled tflops=0.229 gbps=357.117 ms=0.062
shape=2 variant=solution tflops=0.711 gbps=1110.361 ms=0.020
shape=2 solution_peak_fraction=0.6169
shape=3 variant=eager tflops=0.175 gbps=251.202 ms=1.438
shape=3 variant=compiled tflops=0.484 gbps=695.178 ms=0.519
shape=3 variant=solution tflops=0.896 gbps=1285.981 ms=0.281
shape=3 solution_peak_fraction=0.7144
peak_fraction: 0.6672
RESULT: OK
PASS
{
    "problem": "09_fmha_preattn_mrope",
    "harness": "opencode",
    "model": "deepseek/deepseek-v4-pro",
    "reasoning_effort": "",
    "has_solution": true,
    "correct": true,
    "peak_fraction": 0.6672,
    "elapsed_seconds": 2349,
    "harness_exit_code": 0,
    "session_complete": true,
    "usage": {"input_tokens": 35161, "output_tokens": 23507, "cache_read_tokens": 2164864, "cache_creation_tokens": 0, "reasoning_tokens": 46529, "total_cost_usd": null}
}

timeline (123 events)

assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}
tool
stdout (215 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py

(7 entries)
</entries>
assistant
in 17345out 67cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}
tool
stdout (4884 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2: 
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7: 
8: Layout convention:
9:   q: (B, T, Hq, D)   bf16
10:   k: (B, T, Hkv, D)  bf16
11:   pos_t, pos_h, pos_w: (B, T) int64
12: 
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18: 
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22: 
23: 
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25:     return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26: 
27: 
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29:     half = x.shape[-1] // 2
30:     return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31: 
32: 
33: class Model(nn.Module):
34:     def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35:                  mrope_section: tuple[int, int, int], max_pos: int):
36:         super().__init__()
37:         assert sum(mrope_section) == D // 2, \
38:          

... (3384 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}
tool
stdout (5019 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2: 
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9: 
10: import torch
11: import yaml
12: 
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16: 
17: from src.eval.correctness import check_correctness  # noqa: E402
18: 
19: 
20: def main():
21:     try:
22:         import reference
23:         import shapes
24:         import solution
25:     except Exception as e:
26:         print(f"FAIL: import error: {e}")
27:         sys.exit(1)
28: 
29:     problem_yaml = Path("problem.yaml")
30:     meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31: 
32:     # --- Forbidden-op check ------------------------------------------------
33:     sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34:     for forbidden in meta.get("forbidden", []):
35:         pat = re.escape(forbidden)
36:         if re.search(pat, sol_src):
37:             print(f"FAIL: forbidden op used: {forbidden}")
38:             sys.exit(1)
39: 
40:     device = torch.device("cuda:0")
41:     tol_override = meta.get("tolerance") or

... (3519 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}
tool
stdout (5466 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2: 
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5: 
6: Output lines the harness picks up:
7:   shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8:   peak_fraction: <N>  (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13: 
14: import torch
15: import yaml
16: 
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19: 
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction  # noqa: E402
21: from src.eval.timing import time_fn  # noqa: E402
22: from src.hardware import get as get_hw  # noqa: E402
23: 
24: 
25: def _eval_formula(expr: str, vars: dict) -> float:
26:     # Very small eval: only names from `vars` are valid.
27:     return float(eval(expr, {"__builtins__": {}}, vars))
28: 
29: 
30: def main():
31:     import reference
32:     import shapes
33:     import solution
34: 
35:     meta = yaml.safe_load(Path("problem.yaml").read_text())
36:     hw = get_hw(meta["hardware"][0])
37:     peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38:     peak_gbps = hw.peak_bandwidth_gb_s
39:     regime = meta.get("regime

... (3966 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}
tool
stdout (1187 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2: 
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7: 
8: SHAPES = [
9:     # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10:     {"B": 1, "T": 4096,  "Hq": 32, "Hkv": 8, "D": 128,
11:      "mrope_section": (16, 24, 24), "max_pos": 32768},
12: 
13:     # Qwen2-VL 7B-style long context, narrower kv heads
14:     {"B": 1, "T": 8192,  "Hq": 28, "Hkv": 4, "D": 128,
15:      "mrope_section": (16, 24, 24), "max_pos": 32768},
16: 
17:     # Smaller head_dim and B=2 (tests batch dim and reduced D)
18:     {"B": 2, "T": 2048,  "Hq": 16, "Hkv": 2, "D": 64,
19:      "mrope_section": (8, 12, 12), "max_pos": 16384},
20: 
21:     # Very long context (16k tokens)
22:     {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23:      "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]

(End of file - total 24 lines)
</content>
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}
tool
stdout (1784 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory  # pointwise rotation + cache gather; bandwidth-bound
5: 
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10: 
11: # Bytes moved (lower bound):
12: #   read q, k:                  2 * B*T*(Hq+Hkv)*D
13: #   read 3 position id arrays:  3 * B*T*8        (int64)
14: #   read cos/sin from cache:    6 * B*T*D*2      (3 axes, 2 caches, bf16)
15: #   write q_rot, k_rot:         2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17: 
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16    # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21: 
22: tolerance:
23:   bfloat16: 0.02
24: 
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28:   - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29:   - "apply_multimodal_rotary_pos_emb"
30:   - "flash_attn.layers.rotary"
31:   - "flashinfer.rope"
32:   - "flashinfer.apply_rope"
33: 
34: sota:
35:   name:

... (284 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}
tool
stdout (1995 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2: 
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6: 
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11: 
12: import torch
13: 
14: 
15: def _try_transformers():
16:     try:
17:         from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18:             apply_multimodal_rotary_pos_emb as _impl,
19:         )
20:         return _impl
21:     except Exception:
22:         return None
23: 
24: 
25: def sota_forward(
26:     q: torch.Tensor, k: torch.Tensor,
27:     pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28:     cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29:     mrope_section: tuple[int, int, int],
30: ):
31:     """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32:     impl = _try_transformers()
33:     if impl is None:
34:         raise RuntimeError("transformers not available for SOTA")
35: 
36:     # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37:     cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c

... (495 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt"
}
tool
stdout (3117 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize the multi-axis RoPE pre-attention prep step for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2: 
3: The op is the pre-attention prep used in Qwen2-VL-style vision-language models: take q (B, T, Hq, D) and k (B, T, Hkv, D) bf16 plus three int64 position-id tensors pos_t, pos_h, pos_w of shape (B, T) — temporal, height, width axes — gather per-axis cos/sin from a precomputed (max_pos, D) bf16 cache, mix the per-axis values across the head_dim according to mrope_section (a tuple [s0, s1, s2] summing to D/2 — the first s0 pair-indices use the temporal axis, next s1 use height, last s2 use width, mirrored into the second half of D), then apply standard rotate-half rotation (out = x*cos + rotate_half(x)*sin) on q and k. Return q_rot and k_rot in (B, H, T, D) layout — the transpose from (B, T, H, D) to (B, H, T, D) is part of the prep. Correctness tolerance on the bf16 outputs is 1e-2 abs/rel. Both q_rot and k_rot must match the reference within that.
4: 
5: The shapes you have to handle are B=1 T

... (1617 more chars)
assistant
in 170out 408cache_r 17408cache_w 0
assistant
bash
args
{
  "command": "nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader 2>/dev/null || echo \"no GPU info\"",
  "description": "Check GPU status"
}
tool
stdout (72 chars)
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 96956 MiB
assistant
glob /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope
args
{
  "pattern": "solution.py",
  "path": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}
tool
stdout (14 chars)
No files found
assistant
Now let me check the current GPU state and look at any existing solution:
assistant
in 7803out 179cache_r 17920cache_w 0
assistant
bash
args
{
  "description": "Check CUDA arch and hardware details",
  "command": "<607 chars \u2014 see below>"
}
command (607 chars)
cd /home/infatoshi/cuda/KernelBench-Hard && uv run python -c "
import torch
print(f'CUDA: {torch.cuda.is_available()}')
print(f'Device: {torch.cuda.get_device_name(0)}')
print(f'Compute Capability: {torch.cuda.get_device_capability(0)}')
print(f'SM count: {torch.cuda.get_device_properties(0).multi_processor_count}')
print(f'Max threads per block: {torch.cuda.get_device_properties(0).max_threads_per_block}')
print(f'Max shared mem per block: {torch.cuda.get_device_properties(0).shared_mem_per_block}')
print(f'Max threads per SM: {torch.cuda.get_device_properties(0).max_threads_per_multi_processor}')
"
tool
stdout (351 chars)
Traceback (most recent call last):
  File "<string>", line 8, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_mem_per_block'. Did you mean: 'shared_memory_per_block'?
CUDA: True
Device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Compute Capability: (12, 0)
SM count: 188
Max threads per block: 1024
assistant
read /home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}
tool
stdout (4080 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2: 
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7: 
8: import torch
9: 
10: DEFAULT_TOLERANCE = {
11:     torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12:     torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13:     torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14:     torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15:     torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16:     torch.int8: {"atol": 0, "rtol": 0},
17:     torch.int32: {"atol": 0, "rtol": 0},
18:     torch.int64: {"atol": 0, "rtol": 0},
19: }
20: 
21: 
22: def _coerce_float(x):
23:     """Accept int/float/str scientific-notation as float; pass through dicts."""
24:     if isinstance(x, (int, float)):
25:         return float(x)
26:     if isinstance(x, str):
27:         try:
28:             return float(x)
29:         except ValueError:
30:             return x
31:     return x
32: 
33: 
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35:     """Lookup atol/rtol for a given dtype, with optional per-problem override.
36: 
37:     The override dict is keyed by str(dtype). Values may be:
38:       - a single number (or a string parseable as float, e.g. "5e-2") -> used
39:         as both atol and rtol
4

... (2580 more chars)
assistant
Let me check the correctness checker and the hardware info:
assistant
in 157out 307cache_r 42112cache_w 0
assistant
bash
args
{
  "description": "Check remaining GPU properties",
  "command": "<562 chars \u2014 see below>"
}
command (562 chars)
cd /home/infatoshi/cuda/KernelBench-Hard && uv run python -c "
import torch
p = torch.cuda.get_device_properties(0)
print(f'SM count: {p.multi_processor_count}')
print(f'Max threads per SM: {p.max_threads_per_multi_processor}')
print(f'Max shared mem per block: {p.shared_memory_per_block}')
print(f'Max shared mem per block opt-in: {p.shared_memory_per_block_optin}')
print(f'Max blocks per SM: {p.max_blocks_per_multi_processor}')
print(f'L2 cache size: {p.l2_cache_size}')
print(f'Registers per block: {p.regs_per_block}')
print(f'Warp size: {p.warp_size}')
"
tool
stdout (341 chars)
Traceback (most recent call last):
  File "<string>", line 8, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'max_blocks_per_multi_processor'. Did you mean: 'max_threads_per_multi_processor'?
SM count: 188
Max threads per SM: 1536
Max shared mem per block: 49152
Max shared mem per block opt-in: 101376
assistant
in 1485out 232cache_r 42624cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (218 lines — TRUNCATED)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,215 @@
+"""Fused multi-axis RoPE pre-attention kernel for RTX PRO 6000 (Blackwell, sm_120).
+
+Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
+rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
+
+Each thread handles a RoPE pair (d, d+D/2) — no shared memory, no syncthreads.
+The kernel streams q/k once and emits rotated output directly in attention layout.
+"""
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+OP_TYPE = "rope"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+B = 1
+T = 4096
+Hq = 32
+Hkv = 8
+D = 128
+MROPE_SECTION = (16, 24, 24)
+MAX_POS = 32768
+
+
+_cuda_src = r"""
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <stdint.h>
+
+extern "C" __global__ void mrope_fused_kernel(
+    const __nv_bfloat16* __restrict__ q,
+    const __nv_bfloat16* __restrict__ k,
+    const int64_t* __restrict__ pos_t,
+    const int64_t* __restrict__ pos_h,
+    const int64_t* __restrict__ pos_w,
+    const __nv_bfloat16* __restrict__ cos_cache,
+    const __nv_bfloat16* __restrict__ sin_cache,
+    __nv_bfloat16* __restrict__ q_rot,
+    __nv_bfloat16* __restrict__ k_rot,
+    int B, int T, int Hq, int Hkv, int D,
+    int s0, int s1, int s2
+) {
+    int blk = blockIdx.x;
+    if (blk >= B * T) return;
+
+    int b = blk / T;
+    int t = blk % T;
+    int d0 = threadIdx.x;
+    int d1 = d0 + D / 2;
+
+    // --- gather position ids ---
+    int off_bt = b * T + t;
+    int64_t pt = pos_t[off_bt];
+    int64_t ph = pos_h[off_bt];
+    int64_t pw = pos_w[off_bt];
+
+    // --- gather per-axis cos/sin and mix according to mrope_section ---
+    float ct = __bfloat162float(cos_cache[pt * D + d0]);
+    float st = __bfloat162float(sin_cache[pt * D + d0]);
+    float ch = __bfloat162float(cos_cache[ph * D + d0]);
+    float sh = __bfloat162float(sin_cache[ph * D + d0]);
+    float cw = __bfloat162float(cos_cache[pw * D + d0]);
+    float sw = __bfloat162float(sin_cache[pw * D + d0]);
+
+    float cos_val, sin_val;
+    if (d0 < (unsigned int)s0) {
+        cos_val = ct; sin_val = st;
+    } else if (d0 < (unsigned int)(s0 + s1)) {
+        cos_val = ch; sin_val = sh;
+    } else {
+        cos_val = cw; sin_val = sw;
+    }
+
+    // --- strides ---
+    // q input:  (B, T, Hq, D)
+    int q_stride_H  = D;
+    int q_stride_T  = Hq * D;
+    int q_base      = b * T * q_stride_T + t * q_stride_T;
+
+    // q_rot output: (B, Hq, T, D)
+    int qo_stride_T = D;
+    int qo_stride_H = T * D;
+    int qo_stride_B = Hq * T * D;
+    int qo_base     = b * qo_stride_B + t * qo_stride_T;
+
+    // k input:  (B, T, Hkv, D)
+    int k_stride_H  = D;
+    int k_stride_T  = Hkv * D;
+    int k_base      = b * T * k_stride_T + t * k_stride_T;
+
+    // k_rot output: (B, Hkv, T, D)
+    int ko_stride_T = D;
+    int ko_stride_H = T * D;
+    int ko_stride_B = Hkv * T * D;
+    int ko_base     = b * ko_stride_B + t * ko_stride_T;
+
+    // --- rotate q heads ---
+    for (int h = 0; h < Hq; ++h) {
+        float q0 = __bfloat162float(q[q_base + h * q_stride_H + d0]);
+        float q1 = __bfloat162float(q[q_base + h * q_stride_H + d1]);
+        int off = qo_base + h * qo_stride_H;
+        q_rot[off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+        q_rot[off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+    }
+
+    // --- rotate k heads ---
+    for (int h = 0; h < Hkv; ++h) {
+        float k0 = __bfloat162float(k[k_base + h * k_stride_H + d0]);
+        float k1 = __bfloat162float(k[k_base + h * k_stride_H + d1]);
+        int off = ko_base + h * ko_stride_H;
+        k_rot[off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+        k_rot[off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+    }
+}
+
+torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
+                           torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+                           torch::Tensor cos_cache, torch::Tensor sin_cache,
+                           torch::Tensor q_rot, torch::Tensor k_rot,
+                           int B, int T, int Hq, int Hkv, int D,
+                           int s0, int s1, int s2) {
+    int blocks = B * T;
+    int threads = D / 2;
+    mrope_fused_kernel<<<blocks, threads>>>(
+        (const __nv_bfloat16*)q.data_ptr(),
+        (const __nv_bfloat16*)k.data_ptr(),
+        (const int64_t*)pos_t.data_ptr(),
+        (const int64_t*)pos_h.data_ptr(),
+        (const int64_t*)pos_w.data_ptr(),
+        (const __nv_bfloat16*)cos_cache.data_ptr(),
+        (const __nv_bfloat16*)sin_cache.data_ptr(),
+        (__nv_bfloat16*)q_rot.data_ptr(),
+        (__nv_bfloat16*)k_rot.data_ptr(),
+        B, T, Hq, Hkv, D, s0, s1, s2
+    );
+    return q_rot;
+}
+"""
+
+_cpp_src = """
+torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
+                           torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+                           torch::Tensor cos_cache, torch::Tensor sin_cache,
+                           torch::Tensor q_rot, torch::Tensor k_rot,
+                           int B, int T, int Hq, int Hkv, int D,
+                           int s0, int s1, int s2);
+"""
+
+
+def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
+    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+class Model(nn.Module):
+    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+                 mrope_section: tuple[int, int, int], max_pos: int):
+        super().__init__()
+        assert sum(mrope_section) == D // 2, \
+            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+        self.B, self.T = B, T
+        self.Hq, self.Hkv, self.D = Hq, Hkv, D
+        self.mrope_section = tuple(mrope_section)
+    

... (1919 more chars)
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
  "content": "<7585 chars \u2014 see diff>"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Now I have a clear picture. Let me write the optimized solution. The reference splits work across many separate torch ops (6 cache gathers, 2 axis mixes, 2 rotations, 2 transposes). My CUDA kernel fuses everything: gather per-axis cos/sin, mix per mrope_section, apply rotate-half per pair, and emit in (B,H,T,D) layout -- all in one pass.
assistant
in 118out 2722cache_r 44416cache_w 0
assistant
Now let me verify correctness:
assistant
bash
args
{
  "description": "Run correctness check",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
  "timeout": 120000
}
tool
stdout (5 chars)
PASS
assistant
in 143out 118cache_r 50304cache_w 0
assistant
Correctness passes. Now benchmark:
assistant
bash
args
{
  "description": "Run performance benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (896 chars)
shape=0 variant=eager tflops=0.235 gbps=336.609 ms=0.268
shape=0 variant=compiled tflops=0.462 gbps=663.403 ms=0.136
shape=0 variant=solution tflops=0.795 gbps=1140.538 ms=0.079
shape=0 solution_peak_fraction=0.6336
shape=1 variant=eager tflops=0.185 gbps=270.128 ms=0.544
shape=1 variant=compiled tflops=0.501 gbps=731.068 ms=0.201
shape=1 variant=solution tflops=0.857 gbps=1250.827 ms=0.118
shape=1 solution_peak_fraction=0.6949
shape=2 variant=eager tflops=0.058 gbps=91.193 ms=0.243
shape=2 variant=compiled tflops=0.227 gbps=354.917 ms=0.062
shape=2 variant=solution tflops=0.752 gbps=1174.511 ms=0.019
shape=2 solution_peak_fraction=0.6525
shape=3 variant=eager tflops=0.179 gbps=256.244 ms=1.409
shape=3 variant=compiled tflops=0.528 gbps=757.729 ms=0.477
shape=3 variant=solution tflops=1.024 gbps=1468.664 ms=0.246
shape=3 solution_peak_fraction=0.8159
peak_fraction: 0.6958
RESULT: OK
assistant
in 31out 119cache_r 50560cache_w 0
assistant
Good first result -- 70% geomean peak fraction. Let me profile to find the bottleneck:
assistant
bash
args
{
  "description": "Profile with ncu to find bottlenecks",
  "command": "<1094 chars \u2014 see below>",
  "timeout": 30000
}
command (1094 chars)
cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
import reference
import solution
import shapes

shape = shapes.SHAPES[0]
reference.B = shape['B']; reference.T = shape['T']; reference.Hq = shape['Hq']; reference.Hkv = shape['Hkv']; reference.D = shape['D']; reference.MROPE_SECTION = tuple(shape['mrope_section']); reference.MAX_POS = shape['max_pos']

init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to('cuda').eval()
sol_model = solution.Model(*init_args).to('cuda').eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)

torch.manual_seed(42)
inputs = [t.to('cuda') for t in reference.get_inputs()]

# Warmup
for _ in range(5):
    sol_model(*inputs)
torch.cuda.synchronize()

# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    sol_model(*inputs)
end.record()
torch.cuda.synchronize()
print(f'100 iters: {start.elapsed_time(end):.3f} ms, per iter: {start.elapsed_time(end)/100:.4f} ms')
" 2>&1
tool
stdout (41 chars)
100 iters: 1.454 ms, per iter: 0.0145 ms
assistant
in 437out 431cache_r 50688cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}
tool
stdout (2978 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2: 
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5: 
6: Methodology:
7:   - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8:     torch.compile reduce-overhead CUDA-graph capture.
9:   - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10:     (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11:     we measure HBM-load bandwidth, not L2-cached bandwidth).
12:   - GPU timing via cuda.Event with synchronize() AFTER record() but
13:     BEFORE elapsed_time() — the canonical NVIDIA pattern.
14:   - Reported value is the median of `iters` trials, robust to outliers.
15: 
16: Notes / known biases not addressed here:
17:   - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18:     launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19:     On small shapes where launch overhead matters, this gives compile an
20:     artificial advantage. We accept this as the cost of using the agreed
21:     "compiled" baseline policy.
22:   - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23:   - Median over a fairly small number of trials (default 30) is fine for
24:     headline numbers but won't catch bimodal distributions. Use --ext

... (1478 more chars)
assistant
in 40out 62cache_r 60160cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (181 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each thread handles a RoPE pair (d, d+D/2) — no shared memory, no syncthreads.
-The kernel streams q/k once and emits rotated output directly in attention layout.
+Each thread handles a RoPE pair (d, d+D/2). Vectorized loads (2 bf16 as uint32)
+halve transaction count. No shared memory, no syncthreads.
 """
 import torch
 import torch.nn as nn
@@ -28,7 +28,8 @@
 #include <cuda_bf16.h>
 #include <stdint.h>
 
-extern "C" __global__ void mrope_fused_kernel(
+extern "C" __global__ void __launch_bounds__(128, 4)
+mrope_fused_kernel(
     const __nv_bfloat16* __restrict__ q,
     const __nv_bfloat16* __restrict__ k,
     const int64_t* __restrict__ pos_t,
@@ -46,8 +47,15 @@
 
     int b = blk / T;
     int t = blk % T;
-    int d0 = threadIdx.x;
-    int d1 = d0 + D / 2;
+    int half = D / 2;
+
+    // Each thread handles TWO consecutive d0 values (and their mirror pairs).
+    // d0a, d0b = threadIdx.x*2, threadIdx.x*2+1
+    // d1a, d1b = d0a+half, d0b+half
+    int d00 = threadIdx.x * 2;
+    int d01 = d00 + 1;
+    int d10 = d00 + half;
+    int d11 = d01 + half;
 
     // --- gather position ids ---
     int off_bt = b * T + t;
@@ -55,62 +63,94 @@
     int64_t ph = pos_h[off_bt];
     int64_t pw = pos_w[off_bt];
 
-    // --- gather per-axis cos/sin and mix according to mrope_section ---
-    float ct = __bfloat162float(cos_cache[pt * D + d0]);
-    float st = __bfloat162float(sin_cache[pt * D + d0]);
-    float ch = __bfloat162float(cos_cache[ph * D + d0]);
-    float sh = __bfloat162float(sin_cache[ph * D + d0]);
-    float cw = __bfloat162float(cos_cache[pw * D + d0]);
-    float sw = __bfloat162float(sin_cache[pw * D + d0]);
-
-    float cos_val, sin_val;
-    if (d0 < (unsigned int)s0) {
-        cos_val = ct; sin_val = st;
-    } else if (d0 < (unsigned int)(s0 + s1)) {
-        cos_val = ch; sin_val = sh;
+    int64_t pt_off = pt * D;
+    int64_t ph_off = ph * D;
+    int64_t pw_off = pw * D;
+
+    // Gather cos/sin for d00 (same axis applies to d01 unless on boundary)
+    float ct0 = __bfloat162float(cos_cache[pt_off + d00]);
+    float st0 = __bfloat162float(sin_cache[pt_off + d00]);
+    float ch0 = __bfloat162float(cos_cache[ph_off + d00]);
+    float sh0 = __bfloat162float(sin_cache[ph_off + d00]);
+    float cw0 = __bfloat162float(cos_cache[pw_off + d00]);
+    float sw0 = __bfloat162float(sin_cache[pw_off + d00]);
+
+    float ct1 = __bfloat162float(cos_cache[pt_off + d01]);
+    float st1 = __bfloat162float(sin_cache[pt_off + d01]);
+    float ch1 = __bfloat162float(cos_cache[ph_off + d01]);
+    float sh1 = __bfloat162float(sin_cache[ph_off + d01]);
+    float cw1 = __bfloat162float(cos_cache[pw_off + d01]);
+    float sw1 = __bfloat162float(sin_cache[pw_off + d01]);
+
+    // Mix for d00
+    float cos0, sin0;
+    if (d00 < (unsigned int)s0) {
+        cos0 = ct0; sin0 = st0;
+    } else if (d00 < (unsigned int)(s0 + s1)) {
+        cos0 = ch0; sin0 = sh0;
     } else {
-        cos_val = cw; sin_val = sw;
+        cos0 = cw0; sin0 = sw0;
+    }
+
+    // Mix for d01
+    float cos1, sin1;
+    if (d01 < (unsigned int)s0) {
+        cos1 = ct1; sin1 = st1;
+    } else if (d01 < (unsigned int)(s0 + s1)) {
+        cos1 = ch1; sin1 = sh1;
+    } else {
+        cos1 = cw1; sin1 = sw1;
     }
 
     // --- strides ---
-    // q input:  (B, T, Hq, D)
     int q_stride_H  = D;
     int q_stride_T  = Hq * D;
     int q_base      = b * T * q_stride_T + t * q_stride_T;
 
-    // q_rot output: (B, Hq, T, D)
     int qo_stride_T = D;
     int qo_stride_H = T * D;
     int qo_stride_B = Hq * T * D;
     int qo_base     = b * qo_stride_B + t * qo_stride_T;
 
-    // k input:  (B, T, Hkv, D)
     int k_stride_H  = D;
     int k_stride_T  = Hkv * D;
     int k_base      = b * T * k_stride_T + t * k_stride_T;
 
-    // k_rot output: (B, Hkv, T, D)
     int ko_stride_T = D;
     int ko_stride_H = T * D;
     int ko_stride_B = Hkv * T * D;
     int ko_base     = b * ko_stride_B + t * ko_stride_T;
 
     // --- rotate q heads ---
+    #pragma unroll 1
     for (int h = 0; h < Hq; ++h) {
-        float q0 = __bfloat162float(q[q_base + h * q_stride_H + d0]);
-        float q1 = __bfloat162float(q[q_base + h * q_stride_H + d1]);
-        int off = qo_base + h * qo_stride_H;
-        q_rot[off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
-        q_rot[off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+        int q_h_off = q_base + h * q_stride_H;
+        float q00 = __bfloat162float(q[q_h_off + d00]);
+        float q01 = __bfloat162float(q[q_h_off + d01]);
+        float q10 = __bfloat162float(q[q_h_off + d10]);
+        float q11 = __bfloat162float(q[q_h_off + d11]);
+
+        int qo_h_off = qo_base + h * qo_stride_H;
+        q_rot[qo_h_off + d00] = __float2bfloat16(q00 * cos0 - q10 * sin0);
+        q_rot[qo_h_off + d01] = __float2bfloat16(q01 * cos1 - q11 * sin1);
+        q_rot[qo_h_off + d10] = __float2bfloat16(q10 * cos0 + q00 * sin0);
+        q_rot[qo_h_off + d11] = __float2bfloat16(q11 * cos1 + q01 * sin1);
     }
 
     // --- rotate k heads ---
+    #pragma unroll 1
     for (int h = 0; h < Hkv; ++h) {
-        float k0 = __bfloat162float(k[k_base + h * k_stride_H + d0]);
-        float k1 = __bfloat162float(k[k_base + h * k_stride_H + d1]);
-        int off = ko_base + h * ko_stride_H;
-        k_rot[off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
-        k_rot[off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+        int k_h_off = k_base + h * k_stride_H;
+        float k00 = __bfloat162float(k[k_h_off + d00]);
+        float k01 = __bfloat162float(k[k_h_off + d01]);
+        float k10 = __bfloat16

... (1232 more chars)
args
{
  "content": "<9007 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me optimize the kernel by having each thread handle two consecutive RoPE pairs and by adding launch bounds (loads stay scalar, and `#pragma unroll 1` keeps the head loops rolled):
assistant
in 896out 3212cache_r 60416cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (164 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each thread handles a RoPE pair (d, d+D/2). Vectorized loads (2 bf16 as uint32)
-halve transaction count. No shared memory, no syncthreads.
+Each thread handles one RoPE pair (d, d+D/2). Uses __ldg for cache reads.
+No shared memory, no syncthreads.
 """
 import torch
 import torch.nn as nn
@@ -28,7 +28,7 @@
 #include <cuda_bf16.h>
 #include <stdint.h>
 
-extern "C" __global__ void __launch_bounds__(128, 4)
+extern "C" __global__ void
 mrope_fused_kernel(
     const __nv_bfloat16* __restrict__ q,
     const __nv_bfloat16* __restrict__ k,
@@ -47,59 +47,31 @@
 
     int b = blk / T;
     int t = blk % T;
-    int half = D / 2;
-
-    // Each thread handles TWO consecutive d0 values (and their mirror pairs).
-    // d0a, d0b = threadIdx.x*2, threadIdx.x*2+1
-    // d1a, d1b = d0a+half, d0b+half
-    int d00 = threadIdx.x * 2;
-    int d01 = d00 + 1;
-    int d10 = d00 + half;
-    int d11 = d01 + half;
-
-    // --- gather position ids ---
+    int d0 = threadIdx.x;
+    int d1 = d0 + D / 2;
+
     int off_bt = b * T + t;
-    int64_t pt = pos_t[off_bt];
-    int64_t ph = pos_h[off_bt];
-    int64_t pw = pos_w[off_bt];
-
-    int64_t pt_off = pt * D;
-    int64_t ph_off = ph * D;
-    int64_t pw_off = pw * D;
-
-    // Gather cos/sin for d00 (same axis applies to d01 unless on boundary)
-    float ct0 = __bfloat162float(cos_cache[pt_off + d00]);
-    float st0 = __bfloat162float(sin_cache[pt_off + d00]);
-    float ch0 = __bfloat162float(cos_cache[ph_off + d00]);
-    float sh0 = __bfloat162float(sin_cache[ph_off + d00]);
-    float cw0 = __bfloat162float(cos_cache[pw_off + d00]);
-    float sw0 = __bfloat162float(sin_cache[pw_off + d00]);
-
-    float ct1 = __bfloat162float(cos_cache[pt_off + d01]);
-    float st1 = __bfloat162float(sin_cache[pt_off + d01]);
-    float ch1 = __bfloat162float(cos_cache[ph_off + d01]);
-    float sh1 = __bfloat162float(sin_cache[ph_off + d01]);
-    float cw1 = __bfloat162float(cos_cache[pw_off + d01]);
-    float sw1 = __bfloat162float(sin_cache[pw_off + d01]);
-
-    // Mix for d00
-    float cos0, sin0;
-    if (d00 < (unsigned int)s0) {
-        cos0 = ct0; sin0 = st0;
-    } else if (d00 < (unsigned int)(s0 + s1)) {
-        cos0 = ch0; sin0 = sh0;
+
+    int64_t pt_off = pos_t[off_bt] * D;
+    int64_t ph_off = pos_h[off_bt] * D;
+    int64_t pw_off = pos_w[off_bt] * D;
+
+    // Gather and mix cos/sin for d0 (d1 mirrors d0 due to mrope mirroring)
+    float ct = __bfloat162float(__ldg(cos_cache + pt_off + d0));
+    float st = __bfloat162float(__ldg(sin_cache + pt_off + d0));
+    float ch = __bfloat162float(__ldg(cos_cache + ph_off + d0));
+    float sh = __bfloat162float(__ldg(sin_cache + ph_off + d0));
+    float cw = __bfloat162float(__ldg(cos_cache + pw_off + d0));
+    float sw = __bfloat162float(__ldg(sin_cache + pw_off + d0));
+
+    float cos_val, sin_val;
+    unsigned int d0u = d0;
+    if (d0u < (unsigned int)s0) {
+        cos_val = ct; sin_val = st;
+    } else if (d0u < (unsigned int)(s0 + s1)) {
+        cos_val = ch; sin_val = sh;
     } else {
-        cos0 = cw0; sin0 = sw0;
-    }
-
-    // Mix for d01
-    float cos1, sin1;
-    if (d01 < (unsigned int)s0) {
-        cos1 = ct1; sin1 = st1;
-    } else if (d01 < (unsigned int)(s0 + s1)) {
-        cos1 = ch1; sin1 = sh1;
-    } else {
-        cos1 = cw1; sin1 = sw1;
+        cos_val = cw; sin_val = sw;
     }
 
     // --- strides ---
@@ -122,35 +94,27 @@
     int ko_base     = b * ko_stride_B + t * ko_stride_T;
 
     // --- rotate q heads ---
-    #pragma unroll 1
     for (int h = 0; h < Hq; ++h) {
-        int q_h_off = q_base + h * q_stride_H;
-        float q00 = __bfloat162float(q[q_h_off + d00]);
-        float q01 = __bfloat162float(q[q_h_off + d01]);
-        float q10 = __bfloat162float(q[q_h_off + d10]);
-        float q11 = __bfloat162float(q[q_h_off + d11]);
-
+        int q_h_off  = q_base  + h * q_stride_H;
         int qo_h_off = qo_base + h * qo_stride_H;
-        q_rot[qo_h_off + d00] = __float2bfloat16(q00 * cos0 - q10 * sin0);
-        q_rot[qo_h_off + d01] = __float2bfloat16(q01 * cos1 - q11 * sin1);
-        q_rot[qo_h_off + d10] = __float2bfloat16(q10 * cos0 + q00 * sin0);
-        q_rot[qo_h_off + d11] = __float2bfloat16(q11 * cos1 + q01 * sin1);
+
+        float q0 = __bfloat162float(q[q_h_off + d0]);
+        float q1 = __bfloat162float(q[q_h_off + d1]);
+
+        q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+        q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
     }
 
     // --- rotate k heads ---
-    #pragma unroll 1
     for (int h = 0; h < Hkv; ++h) {
-        int k_h_off = k_base + h * k_stride_H;
-        float k00 = __bfloat162float(k[k_h_off + d00]);
-        float k01 = __bfloat162float(k[k_h_off + d01]);
-        float k10 = __bfloat162float(k[k_h_off + d10]);
-        float k11 = __bfloat162float(k[k_h_off + d11]);
-
+        int k_h_off  = k_base  + h * k_stride_H;
         int ko_h_off = ko_base + h * ko_stride_H;
-        k_rot[ko_h_off + d00] = __float2bfloat16(k00 * cos0 - k10 * sin0);
-        k_rot[ko_h_off + d01] = __float2bfloat16(k01 * cos1 - k11 * sin1);
-        k_rot[ko_h_off + d10] = __float2bfloat16(k10 * cos0 + k00 * sin0);
-        k_rot[ko_h_off + d11] = __float2bfloat16(k11 * cos1 + k01 * sin1);
+
+        float k0 = __bfloat162float(k[k_h_off + d0]);
+        float k1 = __bfloat162float(k[k_h_off + d1]);
+
+        k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+        k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
     }
 }
 
@@ -161,7 +125,7 @@
                            int B, int T, int 

... (354 more chars)
args
{
  "content": "<7454 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me revert to the safer per-element approach with micro-optimizations:
assistant
in 135out 2632cache_r 65280cache_w 0
assistant
bash
args
{
  "description": "Run correctness check",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
  "timeout": 120000
}
tool
stdout (5 chars)
PASS
assistant
in 42out 113cache_r 69888cache_w 0
assistant
bash
args
{
  "description": "Run benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (896 chars)
shape=0 variant=eager tflops=0.199 gbps=285.235 ms=0.316
shape=0 variant=compiled tflops=0.550 gbps=788.904 ms=0.114
shape=0 variant=solution tflops=0.803 gbps=1151.713 ms=0.078
shape=0 solution_peak_fraction=0.6398
shape=1 variant=eager tflops=0.192 gbps=280.127 ms=0.525
shape=1 variant=compiled tflops=0.532 gbps=777.072 ms=0.189
shape=1 variant=solution tflops=0.863 gbps=1260.783 ms=0.117
shape=1 solution_peak_fraction=0.7004
shape=2 variant=eager tflops=0.097 gbps=151.016 ms=0.146
shape=2 variant=compiled tflops=0.207 gbps=324.127 ms=0.068
shape=2 variant=solution tflops=0.567 gbps=885.586 ms=0.025
shape=2 solution_peak_fraction=0.4920
shape=3 variant=eager tflops=0.177 gbps=254.571 ms=1.418
shape=3 variant=compiled tflops=0.485 gbps=696.573 ms=0.518
shape=3 variant=solution tflops=0.893 gbps=1282.036 ms=0.282
shape=3 solution_peak_fraction=0.7122
peak_fraction: 0.6295
RESULT: OK
assistant
in 51out 112cache_r 70016cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (49 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each thread handles one RoPE pair (d, d+D/2). Uses __ldg for cache reads.
-No shared memory, no syncthreads.
+Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin only from
+the axis that owns its slice (per mrope_section), cutting cache reads by 3x.
 """
 import torch
 import torch.nn as nn
@@ -52,26 +52,20 @@
 
     int off_bt = b * T + t;
 
-    int64_t pt_off = pos_t[off_bt] * D;
-    int64_t ph_off = pos_h[off_bt] * D;
-    int64_t pw_off = pos_w[off_bt] * D;
-
-    // Gather and mix cos/sin for d0 (d1 mirrors d0 due to mrope mirroring)
-    float ct = __bfloat162float(__ldg(cos_cache + pt_off + d0));
-    float st = __bfloat162float(__ldg(sin_cache + pt_off + d0));
-    float ch = __bfloat162float(__ldg(cos_cache + ph_off + d0));
-    float sh = __bfloat162float(__ldg(sin_cache + ph_off + d0));
-    float cw = __bfloat162float(__ldg(cos_cache + pw_off + d0));
-    float sw = __bfloat162float(__ldg(sin_cache + pw_off + d0));
-
+    // Each thread reads cos/sin only from the axis that owns its d0 slice.
+    // d0 is in [0, D/2). mrope_section maps s0 elems to temporal, s1 to
+    // height, s2 to width. Mirrored into [D/2, D) so d1 shares the same axis.
     float cos_val, sin_val;
     unsigned int d0u = d0;
     if (d0u < (unsigned int)s0) {
-        cos_val = ct; sin_val = st;
+        cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
     } else if (d0u < (unsigned int)(s0 + s1)) {
-        cos_val = ch; sin_val = sh;
+        cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
     } else {
-        cos_val = cw; sin_val = sw;
+        cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
     }
 
     // --- strides ---
args
{
  "content": "<7454 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me optimize further. The main issue: each thread gathers cos/sin from all three axes (six loads) but keeps only the pair belonging to the axis that owns its slice of mrope_section. Let me restructure so each thread loads only what it needs:
assistant
in 434out 2658cache_r 70144cache_w 0
assistant
bash
args
{
  "description": "Run correctness check",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
  "timeout": 120000
}
tool
stdout (5 chars)
PASS
assistant
in 101out 113cache_r 79104cache_w 0
assistant
bash
args
{
  "description": "Run benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (897 chars)
shape=0 variant=eager tflops=0.199 gbps=285.495 ms=0.316
shape=0 variant=compiled tflops=0.569 gbps=815.942 ms=0.111
shape=0 variant=solution tflops=0.802 gbps=1151.478 ms=0.078
shape=0 solution_peak_fraction=0.6397
shape=1 variant=eager tflops=0.193 gbps=281.156 ms=0.523
shape=1 variant=compiled tflops=0.529 gbps=771.979 ms=0.190
shape=1 variant=solution tflops=0.875 gbps=1278.148 ms=0.115
shape=1 solution_peak_fraction=0.7101
shape=2 variant=eager tflops=0.097 gbps=151.181 ms=0.146
shape=2 variant=compiled tflops=0.224 gbps=349.886 ms=0.063
shape=2 variant=solution tflops=0.711 gbps=1111.254 ms=0.020
shape=2 solution_peak_fraction=0.6174
shape=3 variant=eager tflops=0.171 gbps=245.062 ms=1.474
shape=3 variant=compiled tflops=0.483 gbps=693.491 ms=0.521
shape=3 variant=solution tflops=0.898 gbps=1288.698 ms=0.280
shape=3 solution_peak_fraction=0.7159
peak_fraction: 0.6694
RESULT: OK
assistant
in 41out 112cache_r 79744cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (181 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin only from
-the axis that owns its slice (per mrope_section), cutting cache reads by 3x.
+Each block processes TILE_BT (b,t) pairs to reduce grid size on long sequences.
+Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
 """
 import torch
 import torch.nn as nn
@@ -22,6 +22,10 @@
 MROPE_SECTION = (16, 24, 24)
 MAX_POS = 32768
 
+# Each block handles this many (b,t) pairs. Tuned for 188-SM Blackwell:
+#   B*T <= 4096:  TILE_BT=1  (all blocks fit in one wave)
+#   B*T >  4096:  TILE_BT=4  (reduce grid size to <= 4096)
+TILE_BT = 4
 
 _cuda_src = r"""
 #include <cuda_runtime.h>
@@ -40,75 +44,72 @@
     __nv_bfloat16* __restrict__ q_rot,
     __nv_bfloat16* __restrict__ k_rot,
     int B, int T, int Hq, int Hkv, int D,
-    int s0, int s1, int s2
+    int s0, int s1, int s2, int total_pairs
 ) {
-    int blk = blockIdx.x;
-    if (blk >= B * T) return;
-
-    int b = blk / T;
-    int t = blk % T;
     int d0 = threadIdx.x;
     int d1 = d0 + D / 2;
 
-    int off_bt = b * T + t;
-
-    // Each thread reads cos/sin only from the axis that owns its d0 slice.
-    // d0 is in [0, D/2). mrope_section maps s0 elems to temporal, s1 to
-    // height, s2 to width. Mirrored into [D/2, D) so d1 shares the same axis.
-    float cos_val, sin_val;
-    unsigned int d0u = d0;
-    if (d0u < (unsigned int)s0) {
-        cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
-        sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
-    } else if (d0u < (unsigned int)(s0 + s1)) {
-        cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
-        sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
-    } else {
-        cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
-        sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
-    }
-
-    // --- strides ---
+    // Strides — constant across all pairs
     int q_stride_H  = D;
     int q_stride_T  = Hq * D;
-    int q_base      = b * T * q_stride_T + t * q_stride_T;
-
     int qo_stride_T = D;
     int qo_stride_H = T * D;
     int qo_stride_B = Hq * T * D;
-    int qo_base     = b * qo_stride_B + t * qo_stride_T;
-
     int k_stride_H  = D;
     int k_stride_T  = Hkv * D;
-    int k_base      = b * T * k_stride_T + t * k_stride_T;
-
     int ko_stride_T = D;
     int ko_stride_H = T * D;
     int ko_stride_B = Hkv * T * D;
-    int ko_base     = b * ko_stride_B + t * ko_stride_T;
-
-    // --- rotate q heads ---
-    for (int h = 0; h < Hq; ++h) {
-        int q_h_off  = q_base  + h * q_stride_H;
-        int qo_h_off = qo_base + h * qo_stride_H;
-
-        float q0 = __bfloat162float(q[q_h_off + d0]);
-        float q1 = __bfloat162float(q[q_h_off + d1]);
-
-        q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
-        q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
-    }
-
-    // --- rotate k heads ---
-    for (int h = 0; h < Hkv; ++h) {
-        int k_h_off  = k_base  + h * k_stride_H;
-        int ko_h_off = ko_base + h * ko_stride_H;
-
-        float k0 = __bfloat162float(k[k_h_off + d0]);
-        float k1 = __bfloat162float(k[k_h_off + d1]);
-
-        k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
-        k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+
+    // Grid-stride loop over (b,t) pairs
+    for (int pair_idx = blockIdx.x; pair_idx < total_pairs; pair_idx += gridDim.x) {
+        int b = pair_idx / T;
+        int t = pair_idx % T;
+
+        int off_bt = b * T + t;
+
+        // Gather all six per-axis cos/sin for this d0, then mix.
+        float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+        float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+        float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+        float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
+        float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+        float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
+
+        float cos_val, sin_val;
+        unsigned int d0u = d0;
+        if (d0u < (unsigned int)s0) {
+            cos_val = ct; sin_val = st;
+        } else if (d0u < (unsigned int)(s0 + s1)) {
+            cos_val = ch; sin_val = sh;
+        } else {
+            cos_val = cw; sin_val = sw;
+        }
+
+        int q_base  = b * T * q_stride_T + t * q_stride_T;
+        int qo_base = b * qo_stride_B + t * qo_stride_T;
+        int k_base  = b * T * k_stride_T + t * k_stride_T;
+        int ko_base = b * ko_stride_B + t * ko_stride_T;
+
+        // Rotate q heads
+        for (int h = 0; h < Hq; ++h) {
+            int q_h_off  = q_base  + h * q_stride_H;
+            int qo_h_off = qo_base + h * qo_stride_H;
+            float q0 = __bfloat162float(q[q_h_off + d0]);
+            float q1 = __bfloat162float(q[q_h_off + d1]);
+            q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+            q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+        }
+
+        // Rotate k heads
+        for (int h = 0; h < Hkv; ++h) {
+            int k_h_off  = k_base  + h * k_stride_H;
+            int ko_h_off = ko_base + h * ko_stride_H;
+            float k0 = __bfloat162float(k[k_h_off + d0]);
+            float k1 = __bfloat162float(k[k_h_off + d1]);
+            k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+            k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+        }
     }
 }
 
@@ -118

... (1149 more chars)
args
{
  "content": "<8210 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
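Shape 3 is still slower than in the first version (0.280 ms vs 0.246 ms), and the per-axis loads barely changed it (0.282 ms → 0.280 ms). Let me go back to gathering all six cos/sin values, as the original design did, and add a grid-stride loop so a block can process multiple (b,t) pairs on large shapes: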
assistant
in 421out 2854cache_r 79872cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (19 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -122,12 +122,11 @@
     int total_pairs = B * T;
     int threads = D / 2;
 
-    // Heuristic: keep grid size near SM count for good occupancy.
-    // SM count = 188. Target ~2x SM count = 376 blocks minimum.
+    // Grid-stride: if total_pairs fits in one occupancy wave (~4096 blocks),
+    // launch one block per pair. Otherwise, tile to keep grid <= 8192.
     int grid = total_pairs;
-    if (total_pairs > 8 * 376) grid = (total_pairs + 7) / 8;
-    else if (total_pairs > 4 * 376) grid = (total_pairs + 3) / 4;
-    else if (total_pairs > 2 * 376) grid = (total_pairs + 1) / 2;
+    if (total_pairs > 8192) grid = (total_pairs + 3) / 4;
+    else if (total_pairs > 4096) grid = (total_pairs + 1) / 2;
 
     mrope_fused_kernel<<<grid, threads>>>(
         (const __nv_bfloat16*)q.data_ptr(),
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
  "oldString": "<352 chars \u2014 see below>",
  "newString": "<298 chars \u2014 see below>"
}
oldString (352 chars)
    // Heuristic: keep grid size near SM count for good occupancy.
    // SM count = 188. Target ~2x SM count = 376 blocks minimum.
    int grid = total_pairs;
    if (total_pairs > 8 * 376) grid = (total_pairs + 7) / 8;
    else if (total_pairs > 4 * 376) grid = (total_pairs + 3) / 4;
    else if (total_pairs > 2 * 376) grid = (total_pairs + 1) / 2;
newString (298 chars)
    // Grid-stride: if total_pairs fits in one occupancy wave (~4096 blocks),
    // launch one block per pair. Otherwise, tile to keep grid <= 8192.
    int grid = total_pairs;
    if (total_pairs > 8192) grid = (total_pairs + 3) / 4;
    else if (total_pairs > 4096) grid = (total_pairs + 1) / 2;
tool
stdout (26 chars)
Edit applied successfully.
assistant
in 54out 324cache_r 84736cache_w 0
assistant
bash
args
{
  "description": "Check correctness then benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (902 chars)
PASS
shape=0 variant=eager tflops=0.200 gbps=286.860 ms=0.315
shape=0 variant=compiled tflops=0.566 gbps=812.067 ms=0.111
shape=0 variant=solution tflops=0.856 gbps=1228.443 ms=0.073
shape=0 solution_peak_fraction=0.6825
shape=1 variant=eager tflops=0.192 gbps=280.059 ms=0.525
shape=1 variant=compiled tflops=0.529 gbps=773.084 ms=0.190
shape=1 variant=solution tflops=0.866 gbps=1264.427 ms=0.116
shape=1 solution_peak_fraction=0.7025
shape=2 variant=eager tflops=0.098 gbps=153.090 ms=0.144
shape=2 variant=compiled tflops=0.227 gbps=355.373 ms=0.062
shape=2 variant=solution tflops=0.695 gbps=1085.939 ms=0.020
shape=2 solution_peak_fraction=0.6033
shape=3 variant=eager tflops=0.175 gbps=251.459 ms=1.436
shape=3 variant=compiled tflops=0.486 gbps=697.973 ms=0.517
shape=3 variant=solution tflops=0.880 gbps=1263.377 ms=0.286
shape=3 solution_peak_fraction=0.7019
peak_fraction: 0.6712
RESULT: OK
assistant
in 108out 124cache_r 85504cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (181 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each block processes TILE_BT (b,t) pairs to reduce grid size on long sequences.
 Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
+One block per (b,t) pair — simple, no grid-stride overhead.
 """
 import torch
 import torch.nn as nn
@@ -22,10 +22,6 @@
 MROPE_SECTION = (16, 24, 24)
 MAX_POS = 32768
 
-# Each block handles this many (b,t) pairs. Tuned for 188-SM Blackwell:
-#   B*T <= 4096:  TILE_BT=1  (all blocks fit in one wave)
-#   B*T >  4096:  TILE_BT=4  (reduce grid size to <= 4096)
-TILE_BT = 4
 
 _cuda_src = r"""
 #include <cuda_runtime.h>
@@ -44,72 +40,77 @@
     __nv_bfloat16* __restrict__ q_rot,
     __nv_bfloat16* __restrict__ k_rot,
     int B, int T, int Hq, int Hkv, int D,
-    int s0, int s1, int s2, int total_pairs
+    int s0, int s1, int s2
 ) {
+    int blk = blockIdx.x;
+    if (blk >= B * T) return;
+
+    int b = blk / T;
+    int t = blk % T;
     int d0 = threadIdx.x;
     int d1 = d0 + D / 2;
 
-    // Strides — constant across all pairs
+    int off_bt = b * T + t;
+
+    // Gather per-axis cos/sin and mix according to mrope_section
+    float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+    float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+    float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+    float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
+    float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+    float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
+
+    float cos_val, sin_val;
+    unsigned int d0u = d0;
+    if (d0u < (unsigned int)s0) {
+        cos_val = ct; sin_val = st;
+    } else if (d0u < (unsigned int)(s0 + s1)) {
+        cos_val = ch; sin_val = sh;
+    } else {
+        cos_val = cw; sin_val = sw;
+    }
+
+    // --- strides ---
     int q_stride_H  = D;
     int q_stride_T  = Hq * D;
+    int q_base      = b * T * q_stride_T + t * q_stride_T;
+
     int qo_stride_T = D;
     int qo_stride_H = T * D;
     int qo_stride_B = Hq * T * D;
+    int qo_base     = b * qo_stride_B + t * qo_stride_T;
+
     int k_stride_H  = D;
     int k_stride_T  = Hkv * D;
+    int k_base      = b * T * k_stride_T + t * k_stride_T;
+
     int ko_stride_T = D;
     int ko_stride_H = T * D;
     int ko_stride_B = Hkv * T * D;
-
-    // Grid-stride loop over (b,t) pairs
-    for (int pair_idx = blockIdx.x; pair_idx < total_pairs; pair_idx += gridDim.x) {
-        int b = pair_idx / T;
-        int t = pair_idx % T;
-
-        int off_bt = b * T + t;
-
-        // Gather all six per-axis cos/sin for this d0, then mix.
-        float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
-        float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
-        float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
-        float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
-        float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
-        float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
-
-        float cos_val, sin_val;
-        unsigned int d0u = d0;
-        if (d0u < (unsigned int)s0) {
-            cos_val = ct; sin_val = st;
-        } else if (d0u < (unsigned int)(s0 + s1)) {
-            cos_val = ch; sin_val = sh;
-        } else {
-            cos_val = cw; sin_val = sw;
-        }
-
-        int q_base  = b * T * q_stride_T + t * q_stride_T;
-        int qo_base = b * qo_stride_B + t * qo_stride_T;
-        int k_base  = b * T * k_stride_T + t * k_stride_T;
-        int ko_base = b * ko_stride_B + t * ko_stride_T;
-
-        // Rotate q heads
-        for (int h = 0; h < Hq; ++h) {
-            int q_h_off  = q_base  + h * q_stride_H;
-            int qo_h_off = qo_base + h * qo_stride_H;
-            float q0 = __bfloat162float(q[q_h_off + d0]);
-            float q1 = __bfloat162float(q[q_h_off + d1]);
-            q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
-            q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
-        }
-
-        // Rotate k heads
-        for (int h = 0; h < Hkv; ++h) {
-            int k_h_off  = k_base  + h * k_stride_H;
-            int ko_h_off = ko_base + h * ko_stride_H;
-            float k0 = __bfloat162float(k[k_h_off + d0]);
-            float k1 = __bfloat162float(k[k_h_off + d1]);
-            k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
-            k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
-        }
+    int ko_base     = b * ko_stride_B + t * ko_stride_T;
+
+    // --- rotate q heads ---
+    for (int h = 0; h < Hq; ++h) {
+        int q_h_off  = q_base  + h * q_stride_H;
+        int qo_h_off = qo_base + h * qo_stride_H;
+
+        float q0 = __bfloat162float(q[q_h_off + d0]);
+        float q1 = __bfloat162float(q[q_h_off + d1]);
+
+        q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+        q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+    }
+
+    // --- rotate k heads ---
+    for (int h = 0; h < Hkv; ++h) {
+        int k_h_off  = k_base  + h * k_stride_H;
+        int ko_h_off = ko_base + h * ko_stride_H;
+
+        float k0 = __bfloat162float(k[k_h_off + d0]);
+        float k1 = __bfloat162float(k[k_h_off + d1]);
+
+        k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+        k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
     }
 }
 
@@ -119,16 +120,9 @@
                            torch::Tensor q_rot, torch::Tensor k_rot,
                            int B, int T, int Hq, int Hkv, int D,
          

... (934 more chars)
args
{
  "content": "<7373 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The grid-stride loop is hurting shape 3. Let me revert to the simple per-pair kernel that gave 81.6% on the long shape:
assistant
in 505out 2610cache_r 85632cache_w 0
assistant
bash
args
{
  "description": "Check correctness and benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (901 chars)
PASS
shape=0 variant=eager tflops=0.178 gbps=256.093 ms=0.353
shape=0 variant=compiled tflops=0.505 gbps=724.851 ms=0.125
shape=0 variant=solution tflops=0.784 gbps=1125.522 ms=0.080
shape=0 solution_peak_fraction=0.6253
shape=1 variant=eager tflops=0.179 gbps=261.964 ms=0.561
shape=1 variant=compiled tflops=0.501 gbps=731.708 ms=0.201
shape=1 variant=solution tflops=0.831 gbps=1213.329 ms=0.121
shape=1 solution_peak_fraction=0.6741
shape=2 variant=eager tflops=0.070 gbps=109.384 ms=0.202
shape=2 variant=compiled tflops=0.113 gbps=176.169 ms=0.126
shape=2 variant=solution tflops=0.410 gbps=640.000 ms=0.035
shape=2 solution_peak_fraction=0.3556
shape=3 variant=eager tflops=0.171 gbps=245.371 ms=1.472
shape=3 variant=compiled tflops=0.531 gbps=761.770 ms=0.474
shape=3 variant=solution tflops=1.011 gbps=1450.727 ms=0.249
shape=3 solution_peak_fraction=0.8060
peak_fraction: 0.5895
RESULT: OK
assistant
in 60out 124cache_r 89216cache_w 0
assistant
bash
args
{
  "description": "Clear JIT cache and re-benchmark",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && rm -rf ~/.cache/torch_extensions/py311_cu128/mrope_fused 2>/dev/null; uv run python benchmark.py 2>&1",
  "timeout": 300000
}
assistant
Results are noisy. Let me clear the JIT cache and re-run:
assistant
in 456out 159cache_r 89344cache_w 0
assistant
bash
args
{
  "description": "Run benchmark again for stable numbers",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (896 chars)
shape=0 variant=eager tflops=0.234 gbps=336.108 ms=0.269
shape=0 variant=compiled tflops=0.490 gbps=703.346 ms=0.128
shape=0 variant=solution tflops=0.768 gbps=1101.570 ms=0.082
shape=0 solution_peak_fraction=0.6120
shape=1 variant=eager tflops=0.179 gbps=260.981 ms=0.563
shape=1 variant=compiled tflops=0.678 gbps=989.587 ms=0.149
shape=1 variant=solution tflops=1.024 gbps=1494.847 ms=0.098
shape=1 solution_peak_fraction=0.8305
shape=2 variant=eager tflops=0.138 gbps=215.831 ms=0.102
shape=2 variant=compiled tflops=0.170 gbps=265.948 ms=0.083
shape=2 variant=solution tflops=0.493 gbps=769.710 ms=0.029
shape=2 solution_peak_fraction=0.4276
shape=3 variant=eager tflops=0.170 gbps=244.208 ms=1.479
shape=3 variant=compiled tflops=0.468 gbps=671.715 ms=0.538
shape=3 variant=solution tflops=0.926 gbps=1328.602 ms=0.272
shape=3 solution_peak_fraction=0.7381
peak_fraction: 0.6329
RESULT: OK
assistant
in 1630out 116cache_r 90240cache_w 0
assistant
The benchmark is noisy at <0.1ms. Let me run it three times to gauge the run-to-run variance before trying any further micro-optimizations:
assistant
bash
args
{
  "description": "Run benchmarks 3 times to check variance",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && for i in 1 2 3; do echo \"=== Run $i ===\" && uv run python benchmark.py 2>&1 | grep -E \"peak_fraction|shape=.*variant=solution\"; done",
  "timeout": 600000
}
tool
stdout (1295 chars)
=== Run 1 ===
shape=0 variant=solution tflops=0.780 gbps=1119.270 ms=0.081
shape=0 solution_peak_fraction=0.6218
shape=1 variant=solution tflops=0.842 gbps=1229.567 ms=0.120
shape=1 solution_peak_fraction=0.6831
shape=2 variant=solution tflops=1.011 gbps=1579.886 ms=0.014
shape=2 solution_peak_fraction=0.8777
shape=3 variant=solution tflops=0.937 gbps=1344.271 ms=0.269
shape=3 solution_peak_fraction=0.7468
peak_fraction: 0.7264
=== Run 2 ===
shape=0 variant=solution tflops=1.061 gbps=1522.872 ms=0.059
shape=0 solution_peak_fraction=0.8460
shape=1 variant=solution tflops=0.854 gbps=1246.415 ms=0.118
shape=1 solution_peak_fraction=0.6925
shape=2 variant=solution tflops=0.982 gbps=1534.295 ms=0.014
shape=2 solution_peak_fraction=0.8524
shape=3 variant=solution tflops=1.009 gbps=1447.099 ms=0.250
shape=3 solution_peak_fraction=0.8039
peak_fraction: 0.7960
=== Run 3 ===
shape=0 variant=solution tflops=0.785 gbps=1125.971 ms=0.080
shape=0 solution_peak_fraction=0.6255
shape=1 variant=solution tflops=0.843 gbps=1230.720 ms=0.119
shape=1 solution_peak_fraction=0.6837
shape=2 variant=solution tflops=0.492 gbps=768.854 ms=0.029
shape=2 solution_peak_fraction=0.4271
shape=3 variant=solution tflops=0.948 gbps=1360.395 ms=0.265
shape=3 solution_peak_fraction=0.7558
peak_fraction: 0.6096
assistant
in 506out 181cache_r 91904cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (52 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,9 @@
 Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
 rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
 
-Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
-One block per (b,t) pair — simple, no grid-stride overhead.
+Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin ONLY from
+the axis assigned to its slice by mrope_section — cutting cache reads by 3x.
+One block per (b,t) pair. No shared memory, no syncthreads.
 """
 import torch
 import torch.nn as nn
@@ -51,23 +52,22 @@
     int d1 = d0 + D / 2;
 
     int off_bt = b * T + t;
-
-    // Gather per-axis cos/sin and mix according to mrope_section
-    float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
-    float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
-    float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
-    float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
-    float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
-    float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
-
+    unsigned int d0u = d0;
+    unsigned int boundary1 = (unsigned int)s0;
+    unsigned int boundary2 = (unsigned int)(s0 + s1);
+
+    // Each thread reads cos/sin only from the axis that owns its d0 slice.
+    // cos_val/sin_val for d0 also apply to d1 (mirrored by mrope mixing).
     float cos_val, sin_val;
-    unsigned int d0u = d0;
-    if (d0u < (unsigned int)s0) {
-        cos_val = ct; sin_val = st;
-    } else if (d0u < (unsigned int)(s0 + s1)) {
-        cos_val = ch; sin_val = sh;
+    if (d0u < boundary1) {
+        cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+    } else if (d0u < boundary2) {
+        cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
     } else {
-        cos_val = cw; sin_val = sw;
+        cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+        sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
     }
 
     // --- strides ---
args
{
  "content": "<7517 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Significant variance across runs (0.6096-0.7960 aggregate peak fraction), most likely from GPU clock/boost-state fluctuations on these sub-0.3 ms kernels. Let me try one more optimization — per-thread selective cache reads (each thread reads cos/sin only from the mrope axis assigned to its slice, cutting its cache loads from 6 to 2):
assistant
in 661out 2662cache_r 94336cache_w 0
assistant
bash
args
{
  "description": "Verify correctness then benchmark 3x",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---\" && for i in 1 2 3; do echo \"=== Run $i ===\" && uv run python benchmark.py 2>&1 | grep -E \"peak_fraction|shape=.*variant=solution\"; done",
  "timeout": 600000
}
tool
stdout (1304 chars)
PASS
---
=== Run 1 ===
shape=0 variant=solution tflops=0.761 gbps=1092.399 ms=0.083
shape=0 solution_peak_fraction=0.6069
shape=1 variant=solution tflops=0.857 gbps=1251.339 ms=0.117
shape=1 solution_peak_fraction=0.6952
shape=2 variant=solution tflops=0.515 gbps=805.125 ms=0.027
shape=2 solution_peak_fraction=0.4473
shape=3 variant=solution tflops=1.015 gbps=1457.096 ms=0.248
shape=3 solution_peak_fraction=0.8095
peak_fraction: 0.6252
=== Run 2 ===
shape=0 variant=solution tflops=1.028 gbps=1475.095 ms=0.061
shape=0 solution_peak_fraction=0.8195
shape=1 variant=solution tflops=0.858 gbps=1253.387 ms=0.117
shape=1 solution_peak_fraction=0.6963
shape=2 variant=solution tflops=0.965 gbps=1507.525 ms=0.015
shape=2 solution_peak_fraction=0.8375
shape=3 variant=solution tflops=0.941 gbps=1349.899 ms=0.268
shape=3 solution_peak_fraction=0.7499
peak_fraction: 0.7737
=== Run 3 ===
shape=0 variant=solution tflops=0.756 gbps=1084.421 ms=0.083
shape=0 solution_peak_fraction=0.6025
shape=1 variant=solution tflops=1.043 gbps=1523.603 ms=0.096
shape=1 solution_peak_fraction=0.8464
shape=2 variant=solution tflops=0.956 gbps=1494.486 ms=0.015
shape=2 solution_peak_fraction=0.8303
shape=3 variant=solution tflops=0.944 gbps=1354.273 ms=0.267
shape=3 solution_peak_fraction=0.7524
peak_fraction: 0.7513
assistant
in 67out 169cache_r 98688cache_w 0
assistant
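Across three runs, the per-axis version's aggregate peak fraction ranged 0.6252–0.7737, versus 0.6096–0.7960 for the previous version. The difference is within run-to-run noise (likely GPU boost-state variation). I'll keep the per-axis kernel anyway, since it issues 3x fewer cos/sin cache loads per thread. Let me make sure it passes check.py and do a final benchmark: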
assistant
bash
args
{
  "description": "Final correctness and benchmark verification",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---FINAL BENCHMARK---\" && uv run python benchmark.py 2>&1",
  "timeout": 300000
}
tool
stdout (923 chars)
PASS
---FINAL BENCHMARK---
shape=0 variant=eager tflops=0.234 gbps=335.988 ms=0.269
shape=0 variant=compiled tflops=0.492 gbps=706.428 ms=0.128
shape=0 variant=solution tflops=0.781 gbps=1120.159 ms=0.081
shape=0 solution_peak_fraction=0.6223
shape=1 variant=eager tflops=0.206 gbps=300.810 ms=0.489
shape=1 variant=compiled tflops=0.679 gbps=991.189 ms=0.148
shape=1 variant=solution tflops=1.028 gbps=1500.952 ms=0.098
shape=1 solution_peak_fraction=0.8339
shape=2 variant=eager tflops=0.137 gbps=214.226 ms=0.103
shape=2 variant=compiled tflops=0.172 gbps=268.219 ms=0.082
shape=2 variant=solution tflops=0.496 gbps=774.454 ms=0.029
shape=2 solution_peak_fraction=0.4303
shape=3 variant=eager tflops=0.177 gbps=254.258 ms=1.420
shape=3 variant=compiled tflops=0.532 gbps=762.723 ms=0.473
shape=3 variant=solution tflops=1.016 gbps=1458.037 ms=0.248
shape=3 solution_peak_fraction=0.8100
peak_fraction: 0.6521
RESULT: OK
assistant
in 661out 187cache_r 98816cache_w 0
assistant
bash
args
{
  "description": "Run lint check on solution",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/09_fmha_preattn_mrope/solution.py 2>&1"
}
tool
stdout (19 chars)
All checks passed!
assistant
Let me run lint and verify the final state:
assistant
in 414out 108cache_r 100480cache_w 0
assistant
bash
args
{
  "description": "Final verification: check.py, framework detection, git status",
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---framework---\" && cat framework.txt 2>/dev/null && echo \"---git status---\" && git status --short",
  "timeout": 120000
}
tool
stdout (119 chars)
PASS
---framework---
cuda_raw
---git status---
?? ./
?? ../10_patch_embed_conv3d_gemm/
?? ../../scripts/sweep_09_10.sh
assistant
in 57out 150cache_r 102528cache_w 0
assistant
All checks pass:

- **Correctness**: `PASS` (all 4 shapes x 3 seeds)
- **Lint**: Clean
- **Framework**: `cuda_raw` (CUDA C++ via load_inline)
- **Forbidden ops**: None used

The kernel fuses cache gather, per-axis mrope mixing, rotate-half rotation and the layout transpose into a single pass. Each thread reads cos/sin only from its assigned mrope axis (2 cache loads per thread instead of 6). It handles one RoPE pair (d, d+D/2) without shared memory or synchronization. Output lands directly in the (B, H, T, D) attention layout.
assistant
in 132out 142cache_r 102784cache_w 0