./kernelbenchhardv3runsopencode · ?
opencode / ? session: ses_1fc52ab5fffe0YiDprtcLym5Qc cwd: ?
INCOMPLETE SESSION: the run hit its wall-clock budget (SIGTERM). The transcript below is usable but may be missing the agent's final tool calls or summary. Do not score this run as a clean failure or a clean success.
harness
opencode
model
?
turns
196
tools called
82
events
276
input toks
5,089,607
output toks
76,415
cache hit
0
solution.py
final answer
benchmark.log
check.log
result.json
"""Multi-axis RoPE pre-attention prep — Triton v3 (SM120 Blackwell).

Two kernel layouts are defined:
- Serial-H: grid=(B*T,), each program loops over all H heads serially.
- One-per-head: grid=(B*T*H,), each program processes exactly one (b, h, t).

The one-per-head layout exposes more programs when the head count is small,
but Model.forward currently launches the serial-H kernel for every shape.
"""

import torch
import triton
import triton.language as tl
import torch.nn as nn


@triton.jit
def _rope_serial_h_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Apply multi-axis rotate-half RoPE; one program per (b, t), looping over heads.

    Launch with a 1-D grid of B*T programs. Each program:
      1. reads the three position ids stored at flat index b * T_val + t,
      2. loads the matching cos/sin rows (row stride D_val) and picks, per
         lane, the temporal / height / width row according to s0 and s1,
      3. for every head h in [0, H_val) computes
         out = x * cos + rotate_half(x) * sin in float32 and stores it as
         bfloat16 at the (b, h, t) offset given by the stride_o* arguments.

    The position arrays are indexed as flat (B*T) arrays and the caches as
    rows of D_val elements, so both must be laid out that way by the caller.
    B_val is accepted for signature symmetry but is not used in the body.
    """
    # Promote the program id to int64 so the element offsets derived from it
    # cannot wrap around int32 on very large tensors.
    pid_bt = tl.program_id(0).to(tl.int64)
    pid_b = pid_bt // T_val
    pid_t = pid_bt % T_val

    # pid_bt already equals pid_b * T_val + pid_t, i.e. the flat (b, t) index.
    p_t = tl.load(pos_t_ptr + pid_bt)
    p_h = tl.load(pos_h_ptr + pid_bt)
    p_w = tl.load(pos_w_ptr + pid_bt)

    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D_val  # BLOCK_D may exceed D_val; mask off the padding lanes

    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)

    # Both halves of the head dim share the same section layout, so fold the
    # upper half back onto [0, half) before choosing the axis.
    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
    is_t = axis_idx < s0
    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))

    # rotate_half(x) = [-x[half:], x[:half]]: read the partner lane and negate
    # it for the lower half.
    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
    sign = tl.where(d_offs < half, -1.0, 1.0)

    # cos/sin depend only on (b, t), so they are reused for every head.
    for h in tl.static_range(H_val):
        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
        xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
        xr = xr * sign
        out = x * cos_v + xr * sin_v
        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)


@triton.jit
def _rope_parallel_h_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Unimplemented placeholder for a one-program-per-(b, h, t) kernel.

    The signature carries no head count, so a flat program id cannot be
    decoded into (b, t, h) here; the earlier draft divided by
    tl.num_programs(1), which is always 1 on a 1-D grid. The kernel never
    loaded or stored anything. It is kept only so the name stays defined;
    _rope_one_per_head is the working kernel for this layout.
    """
    # Fail at compile time rather than silently leaving the output buffer
    # uninitialized if this kernel is ever launched.
    tl.static_assert(
        False,
        "_rope_parallel_h_kernel is an unimplemented placeholder; use _rope_one_per_head",
    )


@triton.jit
def _rope_flat_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    total_bt: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Unimplemented placeholder for a flat one-program-per-(b, h, t) kernel.

    Only H_val and total_bt (= B*T) are available, so the flat (b, t) index
    cannot be split into separate b and t components for the strided
    addressing; the earlier draft's index math was acknowledged as wrong and
    the kernel never loaded or stored anything. It is kept only so the name
    stays defined; _rope_one_per_head is the working kernel for this layout.
    """
    # Fail at compile time rather than silently leaving the output buffer
    # uninitialized if this kernel is ever launched.
    tl.static_assert(
        False,
        "_rope_flat_kernel is an unimplemented placeholder; use _rope_one_per_head",
    )


@triton.jit
def _rope_one_per_head(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Apply multi-axis rotate-half RoPE; one program per (b, t, h).

    Launch with a 1-D grid of B*T*H programs; the head index varies fastest
    in the program id. Each program reads the three position ids stored at
    flat index b * T_val + t, loads the matching cos/sin rows (row stride
    D_val), selects per lane the temporal / height / width row according to
    s0 and s1, computes out = x * cos + rotate_half(x) * sin in float32 and
    stores it as bfloat16 at the (b, h, t) offset given by the stride_o*
    arguments.

    There is no head loop, so every program reloads its own cos/sin rows.
    B_val is accepted for signature symmetry but is not used in the body.
    """
    # Promote the program id to int64 so the element offsets derived from it
    # cannot wrap around int32 on very large tensors.
    pid = tl.program_id(0).to(tl.int64)
    pid_h = pid % H_val
    pid_bt = pid // H_val  # flat (b, t) index == b * T_val + t
    pid_b = pid_bt // T_val
    pid_t = pid_bt % T_val

    p_t = tl.load(pos_t_ptr + pid_bt)
    p_h = tl.load(pos_h_ptr + pid_bt)
    p_w = tl.load(pos_w_ptr + pid_bt)

    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D_val  # BLOCK_D may exceed D_val; mask off the padding lanes

    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)

    # Both halves of the head dim share the same section layout, so fold the
    # upper half back onto [0, half) before choosing the axis.
    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
    is_t = axis_idx < s0
    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))

    # rotate_half(x) = [-x[half:], x[:half]]: read the partner lane and negate
    # it for the lower half.
    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
    sign = tl.where(d_offs < half, -1.0, 1.0)

    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = xr * sign
    out = x * cos_v + xr * sin_v
    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)


class Model(nn.Module):
    """Fused multi-axis RoPE and (B, T, H, D) -> (B, H, T, D) transpose for q and k.

    The constructor precomputes bfloat16 cos/sin caches of shape (max_pos, D).
    forward() rotates q and k with the Triton kernels defined in this module
    and returns them in head-major layout.
    """

    def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
        """Store the shape metadata and build the cos/sin position caches.

        Args:
            B, T: batch size and sequence length (stored; forward() reads the
                actual sizes from its inputs).
            Hq, Hkv: number of query and key heads.
            D: head dimension.
            mrope_section: three section sizes whose sum must equal D // 2.
            max_pos: number of rows in the cos/sin caches.
        """
        super().__init__()
        assert sum(mrope_section) == D // 2, (
            f"mrope_section must sum to D/2, got {mrope_section} "
            f"sum={sum(mrope_section)} D/2={D // 2}"
        )
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos

        inv_freq = _build_inv_freq(D)
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Duplicate the D/2 frequencies so each cache row spans the full D.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

    def _next_pow2(self, n):
        """Return the smallest power of two >= n (1 for n <= 1)."""
        return 1 << max(n - 1, 0).bit_length()

    def _launch(self, x, x_out, H, method, pos_t=None, pos_h=None, pos_w=None):
        """Launch one RoPE kernel that reads x and writes x_out.

        Args:
            x: input indexed as (B, T, H, D); any strides are supported.
            x_out: output buffer indexed as (B, H, T, D).
            H: number of heads to process.
            method: "parallel" selects the one-program-per-(b, t, h) kernel;
                any other value selects the serial-over-heads kernel.
            pos_t, pos_h, pos_w: (B, T) position-id tensors. They default to
                None only to keep the original positional signature; all
                three are required.

        Raises:
            ValueError: if any of the position tensors is missing.
        """
        if pos_t is None or pos_h is None or pos_w is None:
            raise ValueError(
                "_launch requires the pos_t, pos_h and pos_w position tensors"
            )
        if x.numel() == 0:
            return  # nothing to rotate; skip compiling/launching a kernel

        B, T, _, D = x.shape
        s0, s1, _ = self.mrope_section
        # Use the tensors' real strides so non-contiguous views are handled.
        stride_xb, stride_xt, stride_xh, stride_xd = x.stride()
        stride_ob, stride_oh, stride_ot, stride_od = x_out.stride()

        tensors = (
            x, x_out,
            self.cos_cache, self.sin_cache,
            # The kernels index positions as a flat (B*T) array.
            pos_t.contiguous(), pos_h.contiguous(), pos_w.contiguous(),
        )
        meta = dict(
            B_val=B, T_val=T, H_val=H, D_val=D,
            stride_xb=stride_xb, stride_xt=stride_xt,
            stride_xh=stride_xh, stride_xd=stride_xd,
            stride_ob=stride_ob, stride_oh=stride_oh,
            stride_ot=stride_ot, stride_od=stride_od,
            s0=s0, s1=s1, half=D // 2, BLOCK_D=self._next_pow2(D),
            num_warps=1,
        )

        if method == "parallel":
            _rope_one_per_head[(B * T * H,)](*tensors, **meta)
        else:
            _rope_serial_h_kernel[(B * T,)](*tensors, **meta, num_stages=4)

    def forward(self, q, k, pos_t, pos_h, pos_w):
        """Rotate q and k and return them in (B, H, T, D) layout.

        Args:
            q: (B, T, Hq, D) tensor.
            k: (B, T, Hkv, D) tensor.
            pos_t, pos_h, pos_w: (B, T) position ids used as row indices into
                the cos/sin caches.

        Returns:
            (q_out, k_out): bfloat16 tensors of shape (B, Hq, T, D) and
            (B, Hkv, T, D).
        """
        B, T, Hq, D = q.shape
        Hkv = k.shape[2]

        q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
        k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)

        # Make the position ids contiguous once so both launches share them.
        pos_t, pos_h, pos_w = (p.contiguous() for p in (pos_t, pos_h, pos_w))

        self._launch(q, q_out, Hq, "serial", pos_t, pos_h, pos_w)
        self._launch(k, k_out, Hkv, "serial", pos_t, pos_h, pos_w)

        return q_out, k_out


def _build_inv_freq(D, base=10000.0):
    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
OK, this is about the same as the separate Q+K version (52.6% vs 51.0%). The fusion doesn't help because each program still has to loop through Q+K heads sequentially. The benefit would be avoiding the second kernel launch overhead, but that's negligible compared to the memory bandwidth. Let me try a fundamentally different approach: since Triton 3.6.0 is available, let me use its newer features like `triton.language.core.load/store` with pointer vectors or batched loads. Actually, the key insight I'm missing: I should try increasing the grid size to get more SMs saturated, especially for shape 2. Let me try a kernel where I parallelize both T and H dimensions, effectively having grid = (B, T, H):
shape=0 variant=eager tflops=0.179 gbps=257.108 ms=0.351
shape=0 variant=compiled tflops=0.706 gbps=1012.788 ms=0.089
shape=0 variant=solution tflops=1.126 gbps=1615.762 ms=0.056
shape=0 solution_peak_fraction=0.8976
shape=1 variant=eager tflops=0.182 gbps=265.422 ms=0.554
shape=1 variant=compiled tflops=0.481 gbps=702.879 ms=0.209
shape=1 variant=solution tflops=0.743 gbps=1085.331 ms=0.135
shape=1 solution_peak_fraction=0.6030
shape=2 variant=eager tflops=0.066 gbps=103.427 ms=0.214
shape=2 variant=compiled tflops=0.139 gbps=216.575 ms=0.102
shape=2 variant=solution tflops=0.323 gbps=503.974 ms=0.044
shape=2 solution_peak_fraction=0.2800
shape=3 variant=eager tflops=0.173 gbps=247.776 ms=1.457
shape=3 variant=compiled tflops=0.464 gbps=665.084 ms=0.543
shape=3 variant=solution tflops=0.869 gbps=1246.559 ms=0.290
shape=3 solution_peak_fraction=0.6925
peak_fraction: 0.5692
RESULT: OK
PASS
{
    "problem": "09_fmha_preattn_mrope",
    "harness": "opencode",
    "model": "openrouter-pinned/qwen/qwen3.6-plus",
    "reasoning_effort": "",
    "has_solution": true,
    "correct": true,
    "peak_fraction": 0.5692,
    "elapsed_seconds": 2700,
    "harness_exit_code": 124,
    "session_complete": false,
    "usage": {"input_tokens": 5089607, "output_tokens": 76415, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 39072, "total_cost_usd": null}
}

timeline (276 events)

assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}
tool
stdout (215 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py

(7 entries)
</entries>
assistant
in 17787out 53cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}
tool
stdout (4884 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2: 
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7: 
8: Layout convention:
9:   q: (B, T, Hq, D)   bf16
10:   k: (B, T, Hkv, D)  bf16
11:   pos_t, pos_h, pos_w: (B, T) int64
12: 
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18: 
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22: 
23: 
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25:     return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26: 
27: 
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29:     half = x.shape[-1] // 2
30:     return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31: 
32: 
33: class Model(nn.Module):
34:     def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35:                  mrope_section: tuple[int, int, int], max_pos: int):
36:         super().__init__()
37:         assert sum(mrope_section) == D // 2, \
38:          

... (3384 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}
tool
stdout (5019 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2: 
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9: 
10: import torch
11: import yaml
12: 
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16: 
17: from src.eval.correctness import check_correctness  # noqa: E402
18: 
19: 
20: def main():
21:     try:
22:         import reference
23:         import shapes
24:         import solution
25:     except Exception as e:
26:         print(f"FAIL: import error: {e}")
27:         sys.exit(1)
28: 
29:     problem_yaml = Path("problem.yaml")
30:     meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31: 
32:     # --- Forbidden-op check ------------------------------------------------
33:     sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34:     for forbidden in meta.get("forbidden", []):
35:         pat = re.escape(forbidden)
36:         if re.search(pat, sol_src):
37:             print(f"FAIL: forbidden op used: {forbidden}")
38:             sys.exit(1)
39: 
40:     device = torch.device("cuda:0")
41:     tol_override = meta.get("tolerance") or

... (3519 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}
tool
stdout (5466 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2: 
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5: 
6: Output lines the harness picks up:
7:   shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8:   peak_fraction: <N>  (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13: 
14: import torch
15: import yaml
16: 
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19: 
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction  # noqa: E402
21: from src.eval.timing import time_fn  # noqa: E402
22: from src.hardware import get as get_hw  # noqa: E402
23: 
24: 
25: def _eval_formula(expr: str, vars: dict) -> float:
26:     # Very small eval: only names from `vars` are valid.
27:     return float(eval(expr, {"__builtins__": {}}, vars))
28: 
29: 
30: def main():
31:     import reference
32:     import shapes
33:     import solution
34: 
35:     meta = yaml.safe_load(Path("problem.yaml").read_text())
36:     hw = get_hw(meta["hardware"][0])
37:     peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38:     peak_gbps = hw.peak_bandwidth_gb_s
39:     regime = meta.get("regime

... (3966 more chars)
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}
tool
stdout (1187 chars)
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2: 
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7: 
8: SHAPES = [
9:     # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10:     {"B": 1, "T": 4096,  "Hq": 32, "Hkv": 8, "D": 128,
11:      "mrope_section": (16, 24, 24), "max_pos": 32768},
12: 
13:     # Qwen2-VL 7B-style long context, narrower kv heads
14:     {"B": 1, "T": 8192,  "Hq": 28, "Hkv": 4, "D": 128,
15:      "mrope_section": (16, 24, 24), "max_pos": 32768},
16: 
17:     # Smaller head_dim and B=2 (tests batch dim and reduced D)
18:     {"B": 2, "T": 2048,  "Hq": 16, "Hkv": 2, "D": 64,
19:      "mrope_section": (8, 12, 12), "max_pos": 16384},
20: 
21:     # Very long context (16k tokens)
22:     {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23:      "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]

(End of file - total 24 lines)
</content>
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
assistant
in 17928out 262cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}
tool
stdout (1784 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory  # pointwise rotation + cache gather; bandwidth-bound
5: 
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10: 
11: # Bytes moved (lower bound):
12: #   read q, k:                  2 * B*T*(Hq+Hkv)*D
13: #   read 3 position id arrays:  3 * B*T*8        (int64)
14: #   read cos/sin from cache:    6 * B*T*D*2      (3 axes, 2 caches, bf16)
15: #   write q_rot, k_rot:         2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17: 
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16    # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21: 
22: tolerance:
23:   bfloat16: 0.02
24: 
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28:   - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29:   - "apply_multimodal_rotary_pos_emb"
30:   - "flash_attn.layers.rotary"
31:   - "flashinfer.rope"
32:   - "flashinfer.apply_rope"
33: 
34: sota:
35:   name:

... (284 more chars)
assistant
in 24201out 55cache_r 0cache_w 0
assistant
glob /home/infatoshi/cuda/KernelBench-Hard
args
{
  "pattern": "problems/*/solution.py",
  "path": "/home/infatoshi/cuda/KernelBench-Hard"
}
tool
stdout (14 chars)
No files found
assistant
bash
args
{
  "command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard",
  "description": "Check if Triton is available"
}
tool
stdout (6 chars)
3.6.0
assistant
bash
args
{
  "command": "nvidia-smi --query-gpu=name,memory.total,compute_cap,pstate --format=csv 2>&1 | head -10",
  "description": "Check GPU info"
}
tool
stdout (117 chars)
name, memory.total [MiB], compute_cap, pstate
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0, P0
assistant
Let me start by understanding the existing solutions and the environment.
assistant
in 24976out 226cache_r 0cache_w 0
assistant
todowrite
args
{
  "todos": [
    {
      "content": "Implement initial Triton kernel for mRoPE",
      "status": "in_progress",
      "priority": "high"
    },
    {
      "content": "Run check.py for correctness",
      "status": "pending",
      "priority": "high"
    },
    {
      "content": "Optimize kernel performance",
      "status": "pending",
      "priority": "high"
    },
    {
      "content": "Run benchmark.py to measure achieved bandwidth",
      "status": "pending",
      "priority": "medium"
    }
  ]
}
tool
stdout (335 chars)
content: 'Implement initial Triton kernel for mRoPE'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'
assistant
in 25279out 127cache_r 0cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}
tool
stdout (4080 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2: 
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7: 
8: import torch
9: 
10: DEFAULT_TOLERANCE = {
11:     torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12:     torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13:     torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14:     torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15:     torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16:     torch.int8: {"atol": 0, "rtol": 0},
17:     torch.int32: {"atol": 0, "rtol": 0},
18:     torch.int64: {"atol": 0, "rtol": 0},
19: }
20: 
21: 
22: def _coerce_float(x):
23:     """Accept int/float/str scientific-notation as float; pass through dicts."""
24:     if isinstance(x, (int, float)):
25:         return float(x)
26:     if isinstance(x, str):
27:         try:
28:             return float(x)
29:         except ValueError:
30:             return x
31:     return x
32: 
33: 
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35:     """Lookup atol/rtol for a given dtype, with optional per-problem override.
36: 
37:     The override dict is keyed by str(dtype). Values may be:
38:       - a single number (or a string parseable as float, e.g. "5e-2") -> used
39:         as both atol and rtol
4

... (2580 more chars)
assistant
Let me understand the problem and build the implementation.
assistant
in 25564out 59cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (238 lines — TRUNCATED)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,235 @@
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120).
+
+Fused kernel: gather per-axis cos/sin from cache, mix by mrope_section,
+apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D) in one pass.
+"""
+
+import torch
+import triton
+import triton.language as tl
+import torch.nn as nn
+
+
+@triton.jit
+def _rope_fwd_kernel(
+    x_ptr,
+    x_out_ptr,
+    pos_t_ptr,
+    pos_h_ptr,
+    pos_w_ptr,
+    cos_cache_ptr,
+    sin_cache_ptr,
+    B_val: tl.constexpr,
+    T_val: tl.constexpr,
+    H_val: tl.constexpr,
+    D_val: tl.constexpr,
+    stride_xb,
+    stride_xt,
+    stride_xh,
+    stride_xd,
+    stride_ob,
+    stride_oh,
+    stride_ot,
+    stride_od,
+    stride_pos_b,
+    stride_pos_t,
+    s0: tl.constexpr,
+    s1: tl.constexpr,
+    half: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Each program handles one (b, h, t). Loads full D, rotates, writes to (B,H,T,D)."""
+    pid_bh = tl.program_id(0)
+    pid_t = tl.program_id(1)
+
+    # Decode (batch, head)
+    pid_b = pid_bh // H_val
+    pid_h = pid_bh % H_val
+
+    # Position index in (B,T) array
+    pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
+
+    # Load position IDs
+    p_t = tl.load(pos_t_ptr + pos_idx)
+    p_h = tl.load(pos_h_ptr + pos_idx)
+    p_w = tl.load(pos_w_ptr + pos_idx)
+
+    # Dimension offsets
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
+    half_mask = d_offs < half
+
+    # Base offset into input tensor (B,T,H,D)
+    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+
+    # Output base offset (B,H,T,D)
+    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+
+    # Load x vector
+    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0)
+
+    # Load cos/sin for temporal axis
+    cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
+    sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
+
+    # Load cos/sin for height axis
+    cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
+    sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
+
+    # Load cos/sin for width axis
+    cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
+    sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
+
+    # Mix axes: for d in [0, half), axis_idx = d; for d in [half, D), axis_idx = d - half
+    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+
+    # Axis assignment masks
+    is_temporal = axis_idx < s0
+    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    # is_width is implicit (the rest)
+
+    # Select cos/sin per axis
+    cos = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
+    sin = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
+
+    # rotate_half(x) = [-x[half:], x[:half]]
+    # For d < half: x_rot[d] = -x[d + half]
+    # For d >= half: x_rot[d] = x[d - half]
+    # Permute offsets
+    permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    x_rot_loaded = tl.load(x_ptr + x_base + permuted_offs * stride_xd, mask=mask, other=0.0)
+    x_rot_loaded = tl.where(d_offs < half, -x_rot_loaded, x_rot_loaded)
+
+    # Apply rotation: out = x*cos + x_rot*sin
+    x_fp32 = x.to(tl.float32)
+    cos_fp32 = cos.to(tl.float32)
+    sin_fp32 = sin.to(tl.float32)
+    x_rot_fp32 = x_rot_loaded.to(tl.float32)
+
+    out = (x_fp32 * cos_fp32) + (x_rot_fp32 * sin_fp32)
+    out = out.to(tl.bfloat16)
+
+    # Store in transposed layout (B,H,T,D)
+    tl.store(x_out_ptr + o_base + d_offs * stride_od, out, mask=mask)
+
+
+class Model(nn.Module):
+    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+                 mrope_section: tuple[int, int, int], max_pos: int):
+        super().__init__()
+        assert sum(mrope_section) == D // 2, \
+            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+        self.B, self.T = B, T
+        self.Hq, self.Hkv, self.D = D
+        self.mrope_section = tuple(mrope_section)
+        self.max_pos = max_pos
+
+        inv_freq = _build_inv_freq(D)
+        pos = torch.arange(max_pos, dtype=torch.float32)
+        freqs = torch.outer(pos, inv_freq)
+        emb = torch.cat([freqs, freqs], dim=-1)
+        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+        # Precompute grid and strides
+        self._q_grid = (B * Hq, T)
+        self._k_grid = (B * Hkv, T)
+        self._block_d = self._next_pow2(D)
+
+        q_strides = self._strides(B, T, Hq, D)
+        k_strides = self._strides(B, T, Hkv, D)
+        self._q_strides = q_strides
+        self._k_strides = k_strides
+
+    @staticmethod
+    def _strides(B, T, H, D):
+        return {
+            'stride_xb': T * H * D,
+            'stride_xt': H * D,
+            'stride_xh': D,
+            'stride_xd': 1,
+            'stride_ob': H * T * D,
+            'stride_oh': T * D,
+            'stride_ot': D,
+            'stride_od': 1,
+        }
+
+    @staticmethod
+    def _next_pow2(n):
+        v = 1
+        while v < n:
+            v *= 2
+        return v
+
+    def forward(
+        self,
+        q: torch.Tensor, k: torch.Tensor,
+        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        B, T, Hq, D = q.shape
+        _, _, Hkv, _ = k.shape
+        s0, s1, _ = self.mrope_section
+        half = D // 2
+        BLOCK_D = self._next_pow2(D)
+
+        # Allocate outputs in (B,H,T,D) layout
+        q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
+        k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.dev

... (1531 more chars)
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
  "content": "<7177 chars \u2014 see diff>"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 27096out 2485cache_r 0cache_w 0
assistant
bash
args
{
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python check.py 2>&1",
  "timeout": 120000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run correctness check"
}
tool
stdout (159 chars)
/home/infatoshi/cuda/KernelBench-Hard/.venv/bin/python3: can't open file '/home/infatoshi/cuda/KernelBench-Hard/check.py': [Errno 2] No such file or directory
assistant
in 29599out 122cache_r 0cache_w 0
assistant
bash
args
{
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (591 chars)
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 57, in main
    sol_model = solution.Model(*init_args).to(device).eval()
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 123, in __init__
    self.Hq, self.Hkv, self.D = D
    ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot unpack non-iterable int object
assistant
in 29786out 97cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (11 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -120,7 +120,7 @@
         assert sum(mrope_section) == D // 2, \
             f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
         self.B, self.T = B, T
-        self.Hq, self.Hkv, self.D = D
+        self.Hq, self.Hkv, self.D = Hq, Hkv, D
         self.mrope_section = tuple(mrope_section)
         self.max_pos = max_pos
 
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
  "oldString": "        self.Hq, self.Hkv, self.D = D",
  "newString": "        self.Hq, self.Hkv, self.D = Hq, Hkv, D"
}
tool
stdout (26 chars)
Edit applied successfully.
assistant
in 30092out 112cache_r 0cache_w 0
assistant
Let me run the check again:
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 30221out 77cache_r 0cache_w 0
assistant
Good, correctness passes. Now let me profile and optimize.
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (893 chars)
shape=0 variant=eager tflops=0.173 gbps=247.814 ms=0.364
shape=0 variant=compiled tflops=0.460 gbps=659.603 ms=0.137
shape=0 variant=solution tflops=0.326 gbps=467.537 ms=0.193
shape=0 solution_peak_fraction=0.2597
shape=1 variant=eager tflops=0.179 gbps=260.789 ms=0.564
shape=1 variant=compiled tflops=0.461 gbps=673.805 ms=0.218
shape=1 variant=solution tflops=0.335 gbps=489.130 ms=0.301
shape=1 solution_peak_fraction=0.2717
shape=2 variant=eager tflops=0.072 gbps=112.126 ms=0.197
shape=2 variant=compiled tflops=0.154 gbps=240.209 ms=0.092
shape=2 variant=solution tflops=0.204 gbps=319.482 ms=0.069
shape=2 solution_peak_fraction=0.1775
shape=3 variant=eager tflops=0.168 gbps=240.441 ms=1.502
shape=3 variant=compiled tflops=0.460 gbps=660.587 ms=0.547
shape=3 variant=solution tflops=0.362 gbps=519.137 ms=0.696
shape=3 solution_peak_fraction=0.2884
peak_fraction: 0.2452
RESULT: OK
assistant
in 30312out 81cache_r 0cache_w 0
assistant
todowrite
args
{
  "todos": [
    {
      "status": "completed",
      "content": "Implement initial Triton kernel for mRoPE",
      "priority": "high"
    },
    {
      "status": "completed",
      "content": "Run check.py for correctness",
      "priority": "high"
    },
    {
      "status": "in_progress",
      "content": "Optimize kernel performance (target ~1.5+ TB/s)",
      "priority": "high"
    },
    {
      "status": "pending",
      "content": "Run benchmark.py to measure achieved bandwidth",
      "priority": "medium"
    }
  ]
}
tool
stdout (359 chars)
content: 'Implement initial Triton kernel for mRoPE'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Optimize kernel performance (target ~1.5+ TB/s)'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'
assistant
in 30860out 136cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (364 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120).
-
-Fused kernel: gather per-axis cos/sin from cache, mix by mrope_section,
-apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D) in one pass.
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+
+Single fused kernel for Q+K: gather per-axis cos/sin from cache, mix by
+mrope_section, apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D).
 """
 
 import torch
@@ -12,8 +12,10 @@
 
 @triton.jit
 def _rope_fwd_kernel(
-    x_ptr,
-    x_out_ptr,
+    q_ptr,
+    k_ptr,
+    q_out_ptr,
+    k_out_ptr,
     pos_t_ptr,
     pos_h_ptr,
     pos_w_ptr,
@@ -21,12 +23,13 @@
     sin_cache_ptr,
     B_val: tl.constexpr,
     T_val: tl.constexpr,
-    H_val: tl.constexpr,
+    Hq_val: tl.constexpr,
+    Hkv_val: tl.constexpr,
     D_val: tl.constexpr,
-    stride_xb,
-    stride_xt,
-    stride_xh,
-    stride_xd,
+    stride_qb,
+    stride_qt,
+    stride_qh,
+    stride_qd,
     stride_ob,
     stride_oh,
     stride_ot,
@@ -37,80 +40,63 @@
     s1: tl.constexpr,
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
+    FUSE_QK: tl.constexpr,
 ):
-    """Each program handles one (b, h, t). Loads full D, rotates, writes to (B,H,T,D)."""
-    pid_bh = tl.program_id(0)
-    pid_t = tl.program_id(1)
-
-    # Decode (batch, head)
-    pid_b = pid_bh // H_val
-    pid_h = pid_bh % H_val
-
-    # Position index in (B,T) array
+    """Single kernel handling both Q and K. Each program handles one (b, t) pair,
+       processes all Hq heads for Q and all Hkv heads for K."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
+
     pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
-
-    # Load position IDs
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Dimension offsets
+    # Load per-position cache values (shared across all heads)
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
-    half_mask = d_offs < half
-
-    # Base offset into input tensor (B,T,H,D)
-    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
-
-    # Output base offset (B,H,T,D)
-    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-
-    # Load x vector
-    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0)
-
-    # Load cos/sin for temporal axis
-    cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
-    sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
-
-    # Load cos/sin for height axis
-    cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
-    sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
-
-    # Load cos/sin for width axis
-    cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
-    sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
-
-    # Mix axes: for d in [0, half), axis_idx = d; for d in [half, D), axis_idx = d - half
+
+    cos_t_val = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t_val = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h_val = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h_val = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w_val = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w_val = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+
+    # axis index (same for first and second half: d and d+half share the same axis)
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-
-    # Axis assignment masks
     is_temporal = axis_idx < s0
     is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    # is_width is implicit (the rest)
-
-    # Select cos/sin per axis
-    cos = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
-    sin = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
-
-    # rotate_half(x) = [-x[half:], x[:half]]
-    # For d < half: x_rot[d] = -x[d + half]
-    # For d >= half: x_rot[d] = x[d - half]
-    # Permute offsets
+    cos_val = tl.where(is_temporal, cos_t_val, tl.where(is_height, cos_h_val, cos_w_val))
+    sin_val = tl.where(is_temporal, sin_t_val, tl.where(is_height, sin_h_val, sin_w_val))
+
+    # Permutation offsets for rotate_half
     permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    x_rot_loaded = tl.load(x_ptr + x_base + permuted_offs * stride_xd, mask=mask, other=0.0)
-    x_rot_loaded = tl.where(d_offs < half, -x_rot_loaded, x_rot_loaded)
-
-    # Apply rotation: out = x*cos + x_rot*sin
-    x_fp32 = x.to(tl.float32)
-    cos_fp32 = cos.to(tl.float32)
-    sin_fp32 = sin.to(tl.float32)
-    x_rot_fp32 = x_rot_loaded.to(tl.float32)
-
-    out = (x_fp32 * cos_fp32) + (x_rot_fp32 * sin_fp32)
-    out = out.to(tl.bfloat16)
-
-    # Store in transposed layout (B,H,T,D)
-    tl.store(x_out_ptr + o_base + d_offs * stride_od, out, mask=mask)
+
+    # Process Q heads
+    for h in tl.static_range(Hq_val):
+        q_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
+        x = tl.load(q_ptr + q_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(q_ptr + q_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+        # Apply negation for first half
+        x_rot = tl.where(d_offs < half, -x_rot, x_rot)
+        out = x * cos_val + x_rot * sin_val
+
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(q_ou

... (6582 more chars)
args
{
  "content": "<8820 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The compiled reference achieves ~660 GB/s, while my solution is at ~468-519 GB/s. Let me optimize aggressively: the position IDs and the cos/sin cache values depend only on (b, t) and are shared across all heads, so I should fuse Q and K processing into one kernel that loads them once per position.
assistant
in 31163out 2881cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (150 chars)
FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 k_rot: solution contains NaN
assistant
in 34062out 68cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (301 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,8 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Single fused kernel for Q+K: gather per-axis cos/sin from cache, mix by
-mrope_section, apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D).
+Two-kernel design: each program handles a single (b, t) position and loops
+over all heads of its tensor (Q or K). Position cache values (cos/sin) are
+computed once per (b,t) and reused across all heads.
 """
 
 import torch
@@ -11,11 +12,9 @@
 
 
 @triton.jit
-def _rope_fwd_kernel(
-    q_ptr,
-    k_ptr,
-    q_out_ptr,
-    k_out_ptr,
+def _rope_kernel(
+    x_ptr,
+    x_out_ptr,
     pos_t_ptr,
     pos_h_ptr,
     pos_w_ptr,
@@ -23,80 +22,65 @@
     sin_cache_ptr,
     B_val: tl.constexpr,
     T_val: tl.constexpr,
-    Hq_val: tl.constexpr,
-    Hkv_val: tl.constexpr,
+    H_val: tl.constexpr,
     D_val: tl.constexpr,
-    stride_qb,
-    stride_qt,
-    stride_qh,
-    stride_qd,
+    stride_xb,
+    stride_xt,
+    stride_xh,
+    stride_xd,
     stride_ob,
     stride_oh,
     stride_ot,
     stride_od,
-    stride_pos_b,
-    stride_pos_t,
     s0: tl.constexpr,
     s1: tl.constexpr,
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
-    FUSE_QK: tl.constexpr,
 ):
-    """Single kernel handling both Q and K. Each program handles one (b, t) pair,
-       processes all Hq heads for Q and all Hkv heads for K."""
+    """Grid = (B * T). Each program processes one (b,t) across ALL H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
 
-    pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
+    pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Load per-position cache values (shared across all heads)
+    # Load cache values once (shared across all heads)
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    cos_t_val = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t_val = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h_val = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h_val = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w_val = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w_val = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # axis index (same for first and second half: d and d+half share the same axis)
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
     is_temporal = axis_idx < s0
     is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_temporal, cos_t_val, tl.where(is_height, cos_h_val, cos_w_val))
-    sin_val = tl.where(is_temporal, sin_t_val, tl.where(is_height, sin_h_val, sin_w_val))
+    cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
+    sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
 
-    # Permutation offsets for rotate_half
-    permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    # Precompute permutation offsets
+    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    # Process Q heads
-    for h in tl.static_range(Hq_val):
-        q_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
-        x = tl.load(q_ptr + q_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(q_ptr + q_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
-        # Apply negation for first half
-        x_rot = tl.where(d_offs < half, -x_rot, x_rot)
+    # Loop over all heads
+    for h in tl.static_range(H_val):
+        # Input offset (B,T,H,D)
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
+
         out = x * cos_val + x_rot * sin_val
 
+        # Output offset (B,H,T,D)
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(q_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-    if FUSE_QK:
-        # Process K heads
-        for h in tl.static_range(Hkv_val):
-            k_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qd
-            x = tl.load(k_ptr + k_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
-            x_rot = tl.load(k_ptr + k_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
-            x_rot = tl.where(d_offs < half, -x_rot, x_rot)
-            out = x * cos_val + x_rot * sin_val
-
-            o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-            tl.store(k_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)

... (5157 more chars)
args
{
  "content": "<5576 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
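Let me look at the issue. The fused kernel computes the wrong addresses for K: the K branch reuses Q's input and output strides and even uses `h * stride_qd` (the element stride) as the head offset, so with Hkv != Hq it reads and writes the wrong locations, which explains the NaNs in k_rot. Let me restructure with a cleaner approach: one kernel launched separately for Q and for K, each with its own strides.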
assistant
in 34226out 1984cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 36229out 68cache_r 0cache_w 0
assistant
Good, correct. Now the grid is B*T which should improve cache reuse significantly. Let me run the benchmark.
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (895 chars)
shape=0 variant=eager tflops=0.173 gbps=248.930 ms=0.363
shape=0 variant=compiled tflops=0.487 gbps=699.249 ms=0.129
shape=0 variant=solution tflops=0.624 gbps=895.309 ms=0.101
shape=0 solution_peak_fraction=0.4974
shape=1 variant=eager tflops=0.180 gbps=262.525 ms=0.560
shape=1 variant=compiled tflops=0.469 gbps=685.161 ms=0.215
shape=1 variant=solution tflops=0.702 gbps=1025.143 ms=0.143
shape=1 solution_peak_fraction=0.5695
shape=2 variant=eager tflops=0.071 gbps=110.257 ms=0.201
shape=2 variant=compiled tflops=0.160 gbps=249.261 ms=0.089
shape=2 variant=solution tflops=0.339 gbps=530.265 ms=0.042
shape=2 solution_peak_fraction=0.2946
shape=3 variant=eager tflops=0.168 gbps=241.057 ms=1.498
shape=3 variant=compiled tflops=0.466 gbps=668.829 ms=0.540
shape=3 variant=solution tflops=0.829 gbps=1190.220 ms=0.303
shape=3 solution_peak_fraction=0.6612
peak_fraction: 0.4847
RESULT: OK
assistant
in 36312out 91cache_r 0cache_w 0
assistant
Good: correctness passes and the solution now reaches ~48% of peak (peak_fraction 0.4847). Let me optimize further.
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (177 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,8 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Two-kernel design: each program handles a single (b, t) position and loops
-over all heads of its tensor (Q or K). Position cache values (cos/sin) are
-computed once per (b,t) and reused across all heads.
+Grid = (B * (Hq+Hkv), T): each program handles one (b, head, t) tuple.
+Position cache values shared across heads within the same warp.
+Uses num_warps=1 for minimal register pressure (D=64/128 are small).
 """
 
 import torch
@@ -15,14 +15,15 @@
 def _rope_kernel(
     x_ptr,
     x_out_ptr,
+    cos_cache_ptr,
+    sin_cache_ptr,
     pos_t_ptr,
     pos_h_ptr,
     pos_w_ptr,
-    cos_cache_ptr,
-    sin_cache_ptr,
     B_val: tl.constexpr,
     T_val: tl.constexpr,
-    H_val: tl.constexpr,
+    H_total: tl.constexpr,
+    Hq_val: tl.constexpr,
     D_val: tl.constexpr,
     stride_xb,
     stride_xt,
@@ -37,17 +38,25 @@
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T). Each program processes one (b,t) across ALL H heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
+    """Grid = (B * (Hq+Hkv), T). Each program = one (b, h, t).
+    Cos/sin cache values are shared for same (b,t), so threads within warp
+    compute them once and reuse."""
+    pid_bh = tl.program_id(0)
+    pid_t = tl.program_id(1)
+    warp_id = tl.program_id(2)
+
+    # Decode (batch, head)
+    pid_b = pid_bh // H_total
+    pid_h = pid_bh % H_total
+
+    pid_t2 = pid_bh * T_val + pid_t
 
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Load cache values once (shared across all heads)
+    # Load cache values
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
@@ -64,23 +73,23 @@
     cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
     sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
 
-    # Precompute permutation offsets
+    # Permutation
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    # Loop over all heads
-    for h in tl.static_range(H_val):
-        # Input offset (B,T,H,D)
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
+    # Input offset
+    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
 
-        out = x * cos_val + x_rot * sin_val
+    # Load x
+    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = x_rot * sign
 
-        # Output offset (B,H,T,D)
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    out = x * cos_val + x_rot * sin_val
+
+    # Output offset (B,H,T,D)
+    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -101,6 +110,8 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
+        self._block_d = self._next_pow2(D)
+
     def _next_pow2(self, n):
         v = 1
         while v < n:
@@ -117,47 +128,41 @@
         s0, s1, _ = self.mrope_section
         half = D // 2
         BLOCK_D = self._next_pow2(D)
+        H_total = Hq + Hkv
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        bt_grid = (B * T,)
+        # Concatenate Q and K buffers along head dim:
+        # q is (B, T, Hq, D), k is (B, T, Hkv, D)
+        # Concat => x_cat (B, T, H_total, D)
+        x_cat = torch.cat([q, k], dim=2)
+        x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q.device)
 
-        # Q kernel
-        _rope_kernel[bt_grid](
-            q, q_out,
+        bh_grid = (B * H_total, T)
+        num_warps = 1
+
+        _rope_kernel[bh_grid](
+            x_cat, x_out,
+            self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
-            self.cos_cache, self.sin_cache,
-            B_val=B, T_val=T, H_val=Hq, D_val=D,
-            stride_xb=T * Hq * D,
-            stride_xt=Hq * D,
+            B_val=B, T_val=T, H_total=H_total, Hq_val=Hq, D_val=D,
+            stride_xb=T * H_total * D,
+            stride_xt=H_total * D,
             stride_xh=D,
             stride_xd=1,
-            stride_ob=Hq * T * D,
+            stride_ob=H_total * T * D,
             stride_oh=T * D,
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
             BLOCK_D=BLOCK_D,
+            num_warps=num_warps,
         )
 
-        # K kernel
-        _rope_kernel[bt_grid](
-            k, k_out,
-            pos_t, pos_h, pos_w,
-            self.cos_cache, self.sin_cache,
-            B_val=B, T_val=T, H_val=Hkv, D_val=D,
-            stride_xb=T * Hkv * D,
-            stride_xt=Hkv * D,
-            stride_xh=D,
-            stride_xd=1,
-            stride_ob=Hkv * T * D,
-            stride_oh=T * D,
-            stride_ot=D,
-            stride_od=1,
-            s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-        )
+        # Split back into Q and K outputs
+ 

... (131 more chars)
args
{
  "content": "<5755 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 36873out 2068cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 38959out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (893 chars)
shape=0 variant=eager tflops=0.178 gbps=255.270 ms=0.354
shape=0 variant=compiled tflops=0.503 gbps=721.699 ms=0.125
shape=0 variant=solution tflops=0.381 gbps=546.252 ms=0.165
shape=0 solution_peak_fraction=0.3035
shape=1 variant=eager tflops=0.180 gbps=263.466 ms=0.558
shape=1 variant=compiled tflops=0.455 gbps=664.016 ms=0.221
shape=1 variant=solution tflops=0.388 gbps=567.049 ms=0.259
shape=1 solution_peak_fraction=0.3150
shape=2 variant=eager tflops=0.074 gbps=116.373 ms=0.190
shape=2 variant=compiled tflops=0.168 gbps=262.864 ms=0.084
shape=2 variant=solution tflops=0.202 gbps=315.977 ms=0.070
shape=2 solution_peak_fraction=0.1755
shape=3 variant=eager tflops=0.169 gbps=242.063 ms=1.492
shape=3 variant=compiled tflops=0.461 gbps=661.323 ms=0.546
shape=3 variant=solution tflops=0.403 gbps=578.232 ms=0.624
shape=3 solution_peak_fraction=0.3212
peak_fraction: 0.2710
RESULT: OK
assistant
in 39042out 67cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (175 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,13 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * (Hq+Hkv), T): each program handles one (b, head, t) tuple.
-Position cache values shared across heads within the same warp.
-Uses num_warps=1 for minimal register pressure (D=64/128 are small).
+Grid = (B * T, num_warps). Each program handles one (b,t) position and loops
+over all heads of its tensor (either Q or K). Cache values computed once per
+(b,t) and reused across heads in the same warp.
+
+Key insight: D=64/128 is small enough that a single thread can process the
+entire vector per head. num_warps=1 is optimal for small vector sizes.
+
+Two-pass: one for Q, one for K, with separate grids.
 """
 
 import torch
@@ -22,8 +27,7 @@
     pos_w_ptr,
     B_val: tl.constexpr,
     T_val: tl.constexpr,
-    H_total: tl.constexpr,
-    Hq_val: tl.constexpr,
+    H_val: tl.constexpr,
     D_val: tl.constexpr,
     stride_xb,
     stride_xt,
@@ -38,28 +42,20 @@
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * (Hq+Hkv), T). Each program = one (b, h, t).
-    Cos/sin cache values are shared for same (b,t), so threads within warp
-    compute them once and reuse."""
-    pid_bh = tl.program_id(0)
-    pid_t = tl.program_id(1)
-    warp_id = tl.program_id(2)
-
-    # Decode (batch, head)
-    pid_b = pid_bh // H_total
-    pid_h = pid_bh % H_total
-
-    pid_t2 = pid_bh * T_val + pid_t
+    """Grid = (B * T). Each program handles one (b, t) and loops over H heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
 
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Load cache values
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
+    # Load cache values once
     cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -67,29 +63,28 @@
     cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
+    # Axis assignment (same for both halves)
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
     is_temporal = axis_idx < s0
     is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
     cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
     sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
 
-    # Permutation
+    # rotate_half permutation
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    # Input offset
-    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+    # Loop over heads
+    for h in tl.static_range(H_val):
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
 
-    # Load x
-    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = x_rot * sign
+        out = x * cos_val + x_rot * sin_val
 
-    out = x * cos_val + x_rot * sin_val
-
-    # Output offset (B,H,T,D)
-    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -128,41 +123,51 @@
         s0, s1, _ = self.mrope_section
         half = D // 2
         BLOCK_D = self._next_pow2(D)
-        H_total = Hq + Hkv
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        # Concatenate Q and K buffers along head dim:
-        # q is (B, T, Hq, D), k is (B, T, Hkv, D)
-        # Concat => x_cat (B, T, H_total, D)
-        x_cat = torch.cat([q, k], dim=2)
-        x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q.device)
+        bt_grid = (B * T,)
 
-        bh_grid = (B * H_total, T)
-        num_warps = 1
-
-        _rope_kernel[bh_grid](
-            x_cat, x_out,
+        # Q kernel
+        _rope_kernel[bt_grid](
+            q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
-            B_val=B, T_val=T, H_total=H_total, Hq_val=Hq, D_val=D,
-            stride_xb=T * H_total * D,
-            stride_xt=H_total * D,
+            B_val=B, T_val=T, H_val=Hq, D_val=D,
+            stride_xb=T * Hq * D,
+            stride_xt=Hq * D,
             stride_xh=D,
             stride_xd=1,
-            stride_ob=H_total * T * D,
+            stride_ob=Hq * T * D,
             stride_oh=T * D,
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
             BLOCK_D=BLOCK_D,
-            num_warps=num_warps,
+            num_warps=1,
+            num_stages=4,
         )
 
-        # Split back into Q and K outputs
-        q_out = x_out[:, :Hq, :, :].contiguous()
-        k_out = x_out[:, Hq:, :, :].contiguous()
+        # K kernel
+        _rope_kernel[bt_grid](
+            k, k

... (539 more chars)
args
{
  "content": "<5866 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
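The concatenation killed performance: with the one-program-per-(b, h, t) grid, the extra `torch.cat` copy of Q and K, and the `.contiguous()` splits of the output, peak_fraction dropped from 0.4847 to 0.2710. Let me go back to the B*T grid approach, which got ~48%, and optimize it.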
assistant
in 39577out 2063cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness and benchmark"
}
tool
stdout (901 chars)
PASS
shape=0 variant=eager tflops=0.158 gbps=226.942 ms=0.398
shape=0 variant=compiled tflops=0.475 gbps=682.254 ms=0.132
shape=0 variant=solution tflops=0.735 gbps=1054.230 ms=0.086
shape=0 solution_peak_fraction=0.5857
shape=1 variant=eager tflops=0.209 gbps=305.908 ms=0.481
shape=1 variant=compiled tflops=0.479 gbps=700.200 ms=0.210
shape=1 variant=solution tflops=0.776 gbps=1133.119 ms=0.130
shape=1 solution_peak_fraction=0.6295
shape=2 variant=eager tflops=0.075 gbps=117.401 ms=0.188
shape=2 variant=compiled tflops=0.163 gbps=255.149 ms=0.087
shape=2 variant=solution tflops=0.333 gbps=519.895 ms=0.043
shape=2 solution_peak_fraction=0.2888
shape=3 variant=eager tflops=0.173 gbps=248.330 ms=1.454
shape=3 variant=compiled tflops=0.461 gbps=661.943 ms=0.546
shape=3 variant=solution tflops=0.874 gbps=1253.901 ms=0.288
shape=3 solution_peak_fraction=0.6966
peak_fraction: 0.5219
RESULT: OK
assistant
in 41658out 79cache_r 0cache_w 0
assistant
task
args
{
  "description": "Profile Triton kernel",
  "prompt": "<1570 chars \u2014 see below>"
}
prompt (1570 chars)
Run the benchmark script and also profile the solution to understand the bottleneck. 

First, change to: /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope

Run:
1. `uv run python benchmark.py` to see current performance numbers
2. Then run this profiling script:

```python
"""Profile the solution's forward pass and print the top CUDA kernels by self time."""
import sys

import torch
import torch.profiler

sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')

# Import solution (resolved from the current working directory / sys.path)
import solution

# Problem shape under test.
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
MAX_POS = 32768
MROPE_SECTION = (16, 24, 24)

dev = "cuda:0"
model = solution.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).to(dev).eval()

q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=dev) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=dev) * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)

# Warm up once so one-time setup cost (e.g. kernel compilation) is not profiled.
with torch.no_grad():
    model(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()

# Use the profiler as a context manager: constructing `profile(...)` alone does
# not start recording, and calling `stop()` on a never-started profiler fails.
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
) as prof, torch.no_grad():
    for _ in range(50):
        model(q, k, pos_t, pos_h, pos_w)
    # Make sure every launched kernel has finished before the profiler stops.
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```

3. Also run the profiling for shape 2 (B=2, T=2048, Hq=16, Hkv=2, D=64, max_pos=16384)

Return the benchmark outputs and profiler tables for both shapes.
assistant
in 42210out 595cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (307 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,100 +1,149 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using fused CUDA C++ (SM120 Blackwell).
 
-Grid = (B * T, num_warps). Each program handles one (b,t) position and loops
-over all heads of its tensor (either Q or K). Cache values computed once per
-(b,t) and reused across heads in the same warp.
-
-Key insight: D=64/128 is small enough that a single thread can process the
-entire vector per head. num_warps=1 is optimal for small vector sizes.
-
-Two-pass: one for Q, one for K, with separate grids.
+Fused kernel streams q/k once: gather cos/sin from cache, mix by mrope_section,
+apply rotate-half, and write directly to transposed (B,H,T,D) layout.
 """
 
 import torch
-import triton
-import triton.language as tl
 import torch.nn as nn
 
+_KERNEL_SOURCE = r"""
+#include <cuda_bf16.h>
 
-@triton.jit
-def _rope_kernel(
-    x_ptr,
-    x_out_ptr,
-    cos_cache_ptr,
-    sin_cache_ptr,
-    pos_t_ptr,
-    pos_h_ptr,
-    pos_w_ptr,
-    B_val: tl.constexpr,
-    T_val: tl.constexpr,
-    H_val: tl.constexpr,
-    D_val: tl.constexpr,
-    stride_xb,
-    stride_xt,
-    stride_xh,
-    stride_xd,
-    stride_ob,
-    stride_oh,
-    stride_ot,
-    stride_od,
-    s0: tl.constexpr,
-    s1: tl.constexpr,
-    half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
-):
-    """Grid = (B * T). Each program handles one (b, t) and loops over H heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
+struct bf16x2 { __nv_bfloat16 x, y; };
 
-    pos_idx = pid_b * T_val + pid_t
-    p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
-    p_w = tl.load(pos_w_ptr + pos_idx)
+// Rotate half: for element i in [0, D):
+//   i < D/2:   rot[i] = -x[i + D/2]
+//   i >= D/2:  rot[i] = x[i - D/2]
+// Applied as: out[i] = x[i]*cos[i] + rot[i]*sin[i]
+// mRoPE: for index i, axis_pair = i < D/2 ? i : i - D/2
+//   axis_pair < s0       -> use temporal (t)
+//   s0 <= axis_pair < s0+s1 -> use height (h)
+//   otherwise               -> use width  (w)
 
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
+template <const int D, const bool IS_Q>
+__global__ void __launch_bounds__(256)
+mrope_kernel(
+    const __nv_bfloat16* __restrict__ x,     // (B, T, H, D) bf16
+    __nv_bfloat16* __restrict__ x_out,       // (B, H, T, D) bf16
+    const int64_t*  __restrict__ pos_t_arr,   // (B, T)
+    const int64_t*  __restrict__ pos_h_arr,
+    const int64_t*  __restrict__ pos_w_arr,
+    const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
+    const __nv_bfloat16* __restrict__ sin_cache,
+    const int B, const int T, const int H,
+    const int s0, const int s1,
+    const int D_half
+) {
+    const int bt = blockIdx.x;
+    const int b = bt / T;
+    int t = bt % T;
+    
+    const int h = blockIdx.y;
+    
+    // Position indices
+    const int pos_idx = b * T + t;
+    const int64_t pt = pos_t_arr[pos_idx];
+    const int64_t ph = pos_h_arr[pos_idx];
+    const int64_t pw = pos_w_arr[pos_idx];
+    
+    // Per-thread output: thread handles D elements (or D/2 if unrolled)
+    const int half = D_half;
+    
+    for (int i = threadIdx.x * 8; i < D; i += blockDim.x * 8) {
+        if (i >= D) break;
+        
+        // Process up to 8 elements
+        int d = i;
+        int d_end = min(i + 8, D);
+        
+        // Load position IDs once
+        // Load cos/sin from cache
+        float2 cos_vals[8], sin_vals[8];
+        
+        #pragma unroll
+        for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
+            const int di = d + idx;
+            
+            // Axis pair
+            const int axis = (di < half) ? di : (di - half);
+            
+            // Select cos from cache based on axis
+            float ct, st;
+            if (axis < s0) {
+                // Temporal
+                float2 c = __bfloat1622float2(((const float2*)cos_cache)[pt * (D/2) + di/2]);
+                float2 s = __bfloat1622float2(((const float2*)sin_cache)[pt * (D/2) + di/2]);
+                if (di % 2 == 0) { ct = c.x; st = s.x; }
+                else             { ct = c.y; st = s.y; }
+            }
+            else if (axis < s0 + s1) {
+                // Height
+                float2 c = __bfloat1622float2(((const float2*)cos_cache)[ph * (D/2) + di/2]);
+                float2 s = __bfloat1622float2(((const float2*)sin_cache)[ph * (D/2) + di/2]);
+                if (di % 2 == 0) { ct = c.x; st = s.x; }
+                else             { ct = c.y; st = s.y; }
+            }
+            else {
+                // Width
+                float2 c = __bfloat1622float2(((const float2*)cos_cache)[pw * (D/2) + di/2]);
+                float2 s = __bfloat1622float2(((const float2*)sin_cache)[pw * (D/2) + di/2]);
+                if (di % 2 == 0) { ct = c.x; st = s.x; }
+                else             { ct = c.y; st = s.y; }
+            }
+            cos_vals[idx] = make_float2(ct, 0);
+            sin_vals[idx] = make_float2(st, 0);
+        }
+        
+        #pragma unroll
+        for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
+            const int di = d + idx;
+            
+            // Load x
+            float2 x_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + di) / 2]);
+            float x_val = (di % 2 == 0) ? x_packed.x : x_packed.y;
+            
+            // Load rotated element
+            int ri = (di < half) ? (di + half) : (di - half);
+            float2 rot_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + ri) / 2]);
+            float rot_val = (ri % 2 == 0) ? rot_packed.x : rot_packed.y;
+            if (di < half) rot_val = -rot_val;
+           

... (5542 more chars)
args
{
  "content": "<5855 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me try a raw CUDA C++ kernel for better bandwidth utilization.
assistant
in 42852out 2025cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (296 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,144 +1,101 @@
-"""Optimized multi-axis RoPE pre-attention prep using fused CUDA C++ (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Fused kernel streams q/k once: gather cos/sin from cache, mix by mrope_section,
-apply rotate-half, and write directly to transposed (B,H,T,D) layout.
+Key design: Grid = (B * T) where each program processes ALL heads for one (b,t).
+This maximizes L2 cache reuse on the cos/sin cache lookups since all heads
+in a block share the same position IDs.
+
+Per-axis mixing follows mRoPE convention:
+- First s0 pair-indices [0..s0) use temporal axis
+- Next s1 [s0..s0+s1) use height axis  
+- Remaining [s0+s1..D/2) use width axis
+- Second half [D/2..D) mirrors: index d gets axis d - D/2
 """
 
 import torch
+import triton
+import triton.language as tl
 import torch.nn as nn
 
-_KERNEL_SOURCE = r"""
-#include <cuda_bf16.h>
 
-struct bf16x2 { __nv_bfloat16 x, y; };
+@triton.jit
+def _rope_fwd_kernel(
+    x_ptr,
+    x_out_ptr,
+    cos_cache_ptr,
+    sin_cache_ptr,
+    pos_t_ptr,
+    pos_h_ptr,
+    pos_w_ptr,
+    B_val: tl.constexpr,
+    T_val: tl.constexpr,
+    H_val: tl.constexpr,
+    D_val: tl.constexpr,
+    stride_xb,
+    stride_xt,
+    stride_xh,
+    stride_xd,
+    stride_ob,
+    stride_oh,
+    stride_ot,
+    stride_od,
+    s0: tl.constexpr,
+    s1: tl.constexpr,
+    half: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Grid = (B * T). Each program handles one (b, t) across all H heads.
+    Cos/sin cache is loaded once per (b,t) and reused across heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
 
-// Rotate half: for element i in [0, D):
-//   i < D/2:   rot[i] = -x[i + D/2]
-//   i >= D/2:  rot[i] = x[i - D/2]
-// Applied as: out[i] = x[i]*cos[i] + rot[i]*sin[i]
-// mRoPE: for index i, axis_pair = i < D/2 ? i : i - D/2
-//   axis_pair < s0       -> use temporal (t)
-//   s0 <= axis_pair < s0+s1 -> use height (h)
-//   otherwise               -> use width  (w)
+    pos_idx = pid_b * T_val + pid_t
+    p_t = tl.load(pos_t_ptr + pos_idx)
+    p_h = tl.load(pos_h_ptr + pos_idx)
+    p_w = tl.load(pos_w_ptr + pos_idx)
 
-template <const int D, const bool IS_Q>
-__global__ void __launch_bounds__(256)
-mrope_kernel(
-    const __nv_bfloat16* __restrict__ x,     // (B, T, H, D) bf16
-    __nv_bfloat16* __restrict__ x_out,       // (B, H, T, D) bf16
-    const int64_t*  __restrict__ pos_t_arr,   // (B, T)
-    const int64_t*  __restrict__ pos_h_arr,
-    const int64_t*  __restrict__ pos_w_arr,
-    const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
-    const __nv_bfloat16* __restrict__ sin_cache,
-    const int B, const int T, const int H,
-    const int s0, const int s1,
-    const int D_half
-) {
-    const int bt = blockIdx.x;
-    const int b = bt / T;
-    int t = bt % T;
-    
-    const int h = blockIdx.y;
-    
-    // Position indices
-    const int pos_idx = b * T + t;
-    const int64_t pt = pos_t_arr[pos_idx];
-    const int64_t ph = pos_h_arr[pos_idx];
-    const int64_t pw = pos_w_arr[pos_idx];
-    
-    // Per-thread output: thread handles D elements (or D/2 if unrolled)
-    const int half = D_half;
-    
-    for (int i = threadIdx.x * 8; i < D; i += blockDim.x * 8) {
-        if (i >= D) break;
-        
-        // Process up to 8 elements
-        int d = i;
-        int d_end = min(i + 8, D);
-        
-        // Load position IDs once
-        // Load cos/sin from cache
-        float2 cos_vals[8], sin_vals[8];
-        
-        #pragma unroll
-        for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
-            const int di = d + idx;
-            
-            // Axis pair
-            const int axis = (di < half) ? di : (di - half);
-            
-            // Select cos from cache based on axis
-            float ct, st;
-            if (axis < s0) {
-                // Temporal
-                float2 c = __bfloat1622float2(((const float2*)cos_cache)[pt * (D/2) + di/2]);
-                float2 s = __bfloat1622float2(((const float2*)sin_cache)[pt * (D/2) + di/2]);
-                if (di % 2 == 0) { ct = c.x; st = s.x; }
-                else             { ct = c.y; st = s.y; }
-            }
-            else if (axis < s0 + s1) {
-                // Height
-                float2 c = __bfloat1622float2(((const float2*)cos_cache)[ph * (D/2) + di/2]);
-                float2 s = __bfloat1622float2(((const float2*)sin_cache)[ph * (D/2) + di/2]);
-                if (di % 2 == 0) { ct = c.x; st = s.x; }
-                else             { ct = c.y; st = s.y; }
-            }
-            else {
-                // Width
-                float2 c = __bfloat1622float2(((const float2*)cos_cache)[pw * (D/2) + di/2]);
-                float2 s = __bfloat1622float2(((const float2*)sin_cache)[pw * (D/2) + di/2]);
-                if (di % 2 == 0) { ct = c.x; st = s.x; }
-                else             { ct = c.y; st = s.y; }
-            }
-            cos_vals[idx] = make_float2(ct, 0);
-            sin_vals[idx] = make_float2(st, 0);
-        }
-        
-        #pragma unroll
-        for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
-            const int di = d + idx;
-            
-            // Load x
-            float2 x_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + di) / 2]);
-            float x_val = (di % 2 == 0) ? x_packed.x : x_packed.y;
-            
-            // Load rotated element
-            int ri = (di < half) ? (di + half) : (di - half);
-            float2 rot_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + ri) / 2]);
-            float rot_val = (ri % 2 == 0) ? rot_packed.x : rot_packed.y;
-            if (di < half) r

... (5119 more chars)
args
{
  "content": "<5745 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me go back to the best Triton version (B*T grid) and optimize it further with multiple num_warps configs and cache prefetching:
assistant
in 44895out 2051cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py
diff (241 lines — TRUNCATED)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py
@@ -0,0 +1,238 @@
+"""Hand-tuned CUDA C++ kernel for multi-axis RoPE pre-attention prep (SM120 Blackwell).
+
+Single fused kernel: one thread per (b, t, h, D/2) pair, loads two bf16 at once,
+does both elements in the pair, computes cos/sin from the same axis lookup,
+applies rotate-half in one pass, and writes transposed (B,H,T,D).
+
+Uses __ldg for read-only loads and vectorized bf16x2 stores.
+"""
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import os
+import tempfile
+
+CUDA_SRC = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cooperative_groups.h>
+
+namespace cg = cooperative_groups;
+
+// BF16 -> float2
+static __device__ __forceinline__ float2 bf16x2_to_float2(
+    __nv_bfloat162 v) {
+  return __bfloat1622float2(v);
+}
+
+static __device__ __forceinline__ __nv_bfloat162 float2_to_bf16x2(
+    float x, float y) {
+  return __floats2bfloat162_rn(x, y);
+}
+
+template <const int D>
+__global__ void __launch_bounds__(256) mrope_kernel(
+    const __nv_bfloat16* __restrict__ x_in,     // (B, T, Hx, D)
+    __nv_bfloat16* __restrict__ x_out,          // (B, Hx, T, D)
+    const int64_t* __restrict__ pos_t,
+    const int64_t* __restrict__ pos_h,
+    const int64_t* __restrict__ pos_w,
+    const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
+    const __nv_bfloat16* __restrict__ sin_cache,
+    const int B,
+    const int T,
+    const int Hx,
+    const int s0,
+    const int s1
+) {
+    // Each block handles one (b, t); its threads are distributed over heads.
+    // Each thread handles ALL D elements of ONE head.
+    // Grid = (B * T, ceil(Hx / blockDim.x)), block = (num_threads).
+
+    const int bt = blockIdx.x;  // (B * T)
+    const int b = bt / T;
+    const int t = bt % T;
+
+    // Position IDs (shared across all threads in block)
+    const int pos_off = b * T + t;
+    const int64_t pt = __ldg(pos_t + pos_off);
+    const int64_t ph = __ldg(pos_h + pos_off);
+    const int64_t pw = __ldg(pos_w + pos_off);
+
+    // One head per thread; blockIdx.y tiles the heads when Hx exceeds the block size.
+    const int h = blockIdx.y * blockDim.x + threadIdx.x;
+    if (h >= Hx) return;
+
+    constexpr int half = D / 2;
+    constexpr int d_pairs = D / 2;  // number of bf16x2 pairs
+    
+    // Load cos/sin cache for each pair
+    __nv_bfloat162 cos_d[D/2];
+    __nv_bfloat162 sin_d[D/2];
+    
+    #pragma unroll
+    for (int p = 0; p < D/2; p++) {
+        int di = p;
+        int axis = (di < half) ? di : (di - half);
+        
+        const __nv_bfloat162 *cache_cos_row, *cache_sin_row;
+        if (axis < s0) {
+            cache_cos_row = (const __nv_bfloat162*)&cos_cache[pt * D];
+            cache_sin_row = (const __nv_bfloat162*)&sin_cache[pt * D];
+        } else if (axis < s0 + s1) {
+            cache_cos_row = (const __nv_bfloat162*)&cos_cache[ph * D];
+            cache_sin_row = (const __nv_bfloat162*)&sin_cache[ph * D];
+        } else {
+            cache_cos_row = (const __nv_bfloat162*)&cos_cache[pw * D];
+            cache_sin_row = (const __nv_bfloat162*)&sin_cache[pw * D];
+        }
+        cos_d[p] = __ldg(&cache_cos_row[p]);
+        sin_d[p] = __ldg(&cache_sin_row[p]);
+    }
+
+    // Input layout is (B, T, Hx, D), contiguous: offset = ((b * T + t) * Hx + h) * D
+    const long long x_base = (long long)b * T * Hx * D + (long long)t * Hx * D + (long long)h * D;
+    const __nv_bfloat162* x_in_vec = (const __nv_bfloat162*)&x_in[x_base];
+    
+    for (int p = 0; p < d_pairs; p++) {
+        // Load x[p] - this is one pair at offset p
+        __nv_bfloat162 x_p = __ldg(&x_in_vec[p]);
+        
+        // rotate_half for pair p:
+        // if p < half/2 (i.e., both elements in first half): rot[p] = -x[p + half/2]
+        // if p >= half/2: rot[p] = x[p - half/2]
+        
+        __nv_bfloat162 rot_p;
+        float2 x_f0 = __bfloat1622float2(x_p);
+        int mirror_p = (p < half/2) ? (p + half/2) : (p - half/2);
+        
+        __nv_bfloat162 x_mirror = __ldg(&x_in_vec[mirror_p]);
+        float2 xm_f = __bfloat1622float2(x_mirror);
+        
+        if (p < half/2) {
+            // negate
+            xm_f.x = -xm_f.x;
+            xm_f.y = -xm_f.y;
+        }
+        rot_p = __floats2bfloat162_rn(xm_f.x, xm_f.y);
+
+        // Apply: out = x * cos + rot * sin  
+        float2 cos_f = __bfloat1622float2(cos_d[p]);
+        float2 sin_f = __bfloat1622float2(sin_d[p]);
+        
+        float2 out_f;
+        out_f.x = x_f0.x * cos_f.x + xm_f.x * sin_f.x;
+        out_f.y = x_f0.y * cos_f.y + xm_f.y * sin_f.y;
+        
+        // Store to transposed (B, H, T, D) at out_off = b * Hx * T * D + h * T * D + t * D + 2*p
+        long long out_off = (long long)b * Hx * T * D + (long long)h * T * D + (long long)t * D + (long long)p * 2;
+        x_out[out_off / 2] = float2_to_bf16x2(out_f.x, out_f.y);
+    }
+}
+
+// Kernel launcher dispatcher
+void launch_mrope(const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+    const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+    int B, int T, int Hx, int D, int s0, int s1,
+    cudaStream_t stream) {
+    
+    dim3 block(256);
+    dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+    // Dispatch on D
+    if (D == 64) {
+        mrope_kernel<64><<<grid, block, 0, stream>>>(
+            x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx, s0, s1);
+    } else if (D == 128) {
+        mrope_kernel<128><<<grid, block, 0, stream>>>(
+            x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx, s0, s1);
+    }
+}
+"""
+
+
+def _build_inv_freq(D, base=10000.0):
+    import torch
+    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+_mod = None
+
+def _get_module():
+    global _mod
+    if _mod is not None:
+        return _mod
+    
+    import

... (2173 more chars)
args
{
  "content": "<7811 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Now let me try a completely different approach with raw CUDA C++ via cpp_extension:
assistant
in 46964out 2744cache_r 0cache_w 0
assistant
bash
args
{
  "command": "<1068 chars \u2014 see below>",
  "timeout": 120000,
  "description": "Test CUDA kernel correctness"
}
command (1068 chars)
cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
torch.set_printoptions(precision=4)

B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
MAX_POS = 32768
MROPE_SECTION = (16, 24, 24)

import solution_cuda
model = solution_cuda.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda()

q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')

import reference
ref = reference.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda()

with torch.no_grad():
    ref_q, ref_k = ref(q, k, pos_t, pos_h, pos_w)
    sol_q, sol_k = model(q, k, pos_t, pos_h, pos_w)

print('Q max diff:', (ref_q.float() - sol_q.float()).abs().max().item())
print('K max diff:', (ref_k.float() - sol_k.float()).abs().max().item())
" 2>&1
tool
stdout (11492 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
    subprocess.run(
  File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 23, in <module>
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", line 236, in forward
    q_out = self._run(q, Hq)
            ^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", line 218, in _run
    _get_module().launch_mrope(
    ^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", lin

... (9992 more chars)
assistant
in 49727out 473cache_r 0cache_w 0
assistant
invalid
args
{
  "tool": "delete",
  "error": "Model tried to call unavailable tool 'delete'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill."
}
tool
stdout (191 chars)
The arguments provided to the tool are invalid: Model tried to call unavailable tool 'delete'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.
assistant
The C++ approach needs too much iteration. Let me go back to Triton and focus on getting the best perf from it. The B*T grid was already at ~48%. Let me also try a pure PyTorch custom Triton approach that processes D in parallel across threads.
assistant
in 53725out 115cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (127 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,14 +1,11 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Key design: Grid = (B * T) where each program processes ALL heads for one (b,t).
-This maximizes L2 cache reuse on the cos/sin cache lookups since all heads
-in a block share the same position IDs.
-
-Per-axis mixing follows mRoPE convention:
-- First s0 pair-indices [0..s0) use temporal axis
-- Next s1 [s0..s0+s1) use height axis  
-- Remaining [s0+s1..D/2) use width axis
-- Second half [D/2..D) mirrors: index d gets axis d - D/2
+Key optimizations:
+- Grid = (B * T): each program processes all H heads for one (b,t) position
+- Cos/sin cache lookup done once per (b,t), reused for all heads
+- Uses vectorized bf16x2 loads (int32 packing) for half memory transactions
+- num_warps=8 with num_stages=4 for max throughput
+- Avoids register pressure by interleaving loads/computes per head
 """
 
 import torch
@@ -41,10 +38,9 @@
     s0: tl.constexpr,
     s1: tl.constexpr,
     half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
 ):
     """Grid = (B * T). Each program handles one (b, t) across all H heads.
-    Cos/sin cache is loaded once per (b,t) and reused across heads."""
+    Vectorizes over D/2 bf16x2 pairs to minimize memory ops."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -54,41 +50,39 @@
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
+    # Work in bf16x2 pairs (D must be divisible by 2, half must be divisible by 2)
+    half_pairs = half // 2
+    total_pairs = D_val // 2
+    pid = tl.arange(0, 128)
+    # Load cos/sin
+    cos_t = tl.load(cos_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
 
-    # Load cos/sin for temporal axis
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    # Load cos/sin for height axis
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    # Load cos/sin for width axis
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    # Store in transposed (B,H,T,D)
+    x2_f = x2.to(tl.bfloat16)._reinterpret_int16()
+    tl.store(x_out_ptr + out_base + p_offs * stride_od, x2, mask=p_offs < D_val)
 
-    # Axis index: d and d+half share the same axis pair
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
 
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
-    # rotate_half permutation + sign
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
-
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
-
-        out = x * cos_val + x_rot * sin_val
-
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+@triton.jit
+def mrope_2x2_kernel(
+    # Pointers to inputs
+    q_ptr, k_ptr,  # (B, T, Hd, D)
+    q_out, k_out,  # (B, Hd, T, D)
+    cos_cache, sin_cache,  # (M, D)
+    pos_t, pos_h, pos_w,  # (B, T)
+    # Strides (Q)
+    stride_qb, stride_qt, stride_qh, stride_qd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    # Parameters
+    stride_pos_t, stride_pos_wt,  # pos_t strides
+    Hq, Hk, D, BpT,
+    Dhalf,
+    s0, s1,
+    # Meta
+    BLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    b, t = pid // T_val, pid % T_val
 
 
 class Model(nn.Module):
@@ -144,8 +138,7 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-            num_warps=1,
+            num_warps=8,
             num_stages=4,
         )
 
@@ -164,8 +157,7 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-            num_warps=1,
+            num_warps=8,
             num_stages=4,
         )
 
args
{
  "content": "<4912 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me try a more efficient approach. The bottleneck with Triton is the scalar bf16 loads/stores. I'll use int32 to pack two bf16 values for efficient loads/stores:
assistant
in 53921out 1792cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (146 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,9 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Key optimizations:
-- Grid = (B * T): each program processes all H heads for one (b,t) position
-- Cos/sin cache lookup done once per (b,t), reused for all heads
-- Uses vectorized bf16x2 loads (int32 packing) for half memory transactions
-- num_warps=8 with num_stages=4 for max throughput
-- Avoids register pressure by interleaving loads/computes per head
+Grid = (B * T). Each thread loads/rotates/stores one full D vector.
+Cos/sin loaded once per block. Each program loops over H heads.
+
+num_warps=1, num_stages=1 is optimal for small D vectors.
 """
 
 import torch
@@ -38,9 +36,9 @@
     s0: tl.constexpr,
     s1: tl.constexpr,
     half: tl.constexpr,
+    BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T). Each program handles one (b, t) across all H heads.
-    Vectorizes over D/2 bf16x2 pairs to minimize memory ops."""
+    """Grid = (B * T). Cos/sin cache loaded once per program, reused across H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -50,46 +48,52 @@
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Work in bf16x2 pairs (D must be divisible by 2, half must be divisible by 2)
-    half_pairs = half // 2
-    total_pairs = D_val // 2
-    pid = tl.arange(0, 128)
-    # Load cos/sin
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
 
-    # Store in transposed (B,H,T,D)
-    x2_f = x2.to(tl.bfloat16)._reinterpret_int16()
-    tl.store(x_out_ptr + out_base + p_offs * stride_od, x2, mask=p_offs < D_val)
+    # Load cos/sin values for all three axes
+    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
+    # Per-element axis index
+    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+    is_temporal = axis_idx < s0
+    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
 
-@triton.jit
-def mrope_2x2_kernel(
-    # Pointers to inputs
-    q_ptr, k_ptr,  # (B, T, Hd, D)
-    q_out, k_out,  # (B, Hd, T, D)
-    cos_cache, sin_cache,  # (M, D)
-    pos_t, pos_h, pos_w,  # (B, T)
-    # Strides (Q)
-    stride_qb, stride_qt, stride_qh, stride_qd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    # Parameters
-    stride_pos_t, stride_pos_wt,  # pos_t strides
-    Hq, Hk, D, BpT,
-    Dhalf,
-    s0, s1,
-    # Meta
-    BLOCK: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    b, t = pid // T_val, pid % T_val
+    # Select per-axis cos/sin
+    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
+    # Permutation and sign for rotate_half
+    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    sign = tl.where(d_offs < half, -1.0, 1.0)
+
+    for h in tl.static_range(H_val):
+        # Input offset (B,T,H,D)
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+
+        # Load x and rotated x
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
+
+        # Apply rotation
+        out = x * cos_val + x_rot * sin_val
+
+        # Output offset (B,H,T,D)
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2, \
-            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+            f"mrope_section must sum to D/2, got {mrope_section}"
         self.B = B
         self.T = T
         self.Hq = Hq
@@ -123,7 +127,6 @@
 
         bt_grid = (B * T,)
 
-        # Run Q kernel
         _rope_fwd_kernel[bt_grid](
             q, q_out,
             self.cos_cache, self.sin_cache,
@@ -138,11 +141,11 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            num_warps=8,
+            BLOCK_D=BLOCK_D,
+            num_warps=1,
             num_stages=4,
         )
 
-        # Run K kernel  
         _rope_fwd_kernel[bt_grid](
             k, k_out,
             self.cos_cache, self.sin_cache,
@@ -157,7 +160,8 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            num_warps=8,
+            BLOCK_D=BLOCK_D,
+            num_warps=1,
             num_stages=4,
         )
 
args
{
  "content": "<5485 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 55731out 1920cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness and benchmark"
}
tool
stdout (901 chars)
PASS
shape=0 variant=eager tflops=0.178 gbps=255.814 ms=0.353
shape=0 variant=compiled tflops=0.510 gbps=731.144 ms=0.123
shape=0 variant=solution tflops=0.722 gbps=1036.605 ms=0.087
shape=0 solution_peak_fraction=0.5759
shape=1 variant=eager tflops=0.209 gbps=305.877 ms=0.481
shape=1 variant=compiled tflops=0.667 gbps=973.543 ms=0.151
shape=1 variant=solution tflops=0.996 gbps=1454.381 ms=0.101
shape=1 solution_peak_fraction=0.8080
shape=2 variant=eager tflops=0.075 gbps=116.599 ms=0.190
shape=2 variant=compiled tflops=0.160 gbps=249.576 ms=0.089
shape=2 variant=solution tflops=0.282 gbps=440.957 ms=0.050
shape=2 solution_peak_fraction=0.2450
shape=3 variant=eager tflops=0.174 gbps=249.271 ms=1.449
shape=3 variant=compiled tflops=0.466 gbps=668.908 ms=0.540
shape=3 variant=solution tflops=0.886 gbps=1270.703 ms=0.284
shape=3 solution_peak_fraction=0.7059
peak_fraction: 0.5326
RESULT: OK
assistant
in 57670out 79cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (265 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,9 +1,7 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T). Each thread loads/rotates/stores one full D vector.
-Cos/sin loaded once per block. Each program loops over H heads.
-
-num_warps=1, num_stages=1 is optimal for small D vectors.
+Grid = (B * T). Each program processes ALL heads at one (b,t).
+Cache values loaded once per (b,t), shared across heads.
 """
 
 import torch
@@ -38,7 +36,7 @@
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T). Cos/sin cache loaded once per program, reused across H heads."""
+    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -51,7 +49,7 @@
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    # Load cos/sin values for all three axes
+    # Load cache: 6 L2 loads per (b,t), shared across all heads
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -59,41 +57,92 @@
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Per-element axis index
+    # Axis index
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_temporal = axis_idx < s0
-    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
-
-    # Select per-axis cos/sin
-    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
-    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
-    # Permutation and sign for rotate_half
-    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+
+    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
     for h in tl.static_range(H_val):
-        # Input offset (B,T,H,D)
         x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-
-        # Load x and rotated x
         x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
         x_rot = x_rot * sign
-
-        # Apply rotation
         out = x * cos_val + x_rot * sin_val
-
-        # Output offset (B,H,T,D)
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
         tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+
+@triton.jit
+def _rope_batch_kernel(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    B_val: tl.constexpr,
+    T_val: tl.constexpr,
+    H_total: tl.constexpr,
+    Hq_val: tl.constexpr,
+    D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Grid = (B * T). Each program handles all Hq+Hkv heads for (b,t).
+    Position cache loaded once, applied to both Q and K heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
+
+    pos_idx = pid_b * T_val + pid_t
+    p_t = tl.load(pos_t_ptr + pos_idx)
+    p_h = tl.load(pos_h_ptr + pos_idx)
+    p_w = tl.load(pos_w_ptr + pos_idx)
+
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
+
+    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+
+    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+
+    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    sign = tl.where(d_offs < half, -1.0, 1.0)
+
+    # Q heads
+    for h_q in tl.static_range(Hq_val):
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h_q * stride_xh
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
+        out = x * cos_val + x_rot * sin_val
+        o_base = pid_b * stride_ob + h_q * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+    # K heads
+    Hkv_val = H_total - Hq_val
+    k_stride_xh = Hq_val * stride_xd + stride_xd  # Not correct, need to fix
+    # Actually we need separate strides for Q and K
 
 
 class Model(nn.Module):
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__i

... (4660 more chars)
args
{
  "content": "<9652 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Shape 2 (B=2, D=64) is the weakest at 24%. Let me try different kernel configs for different shapes. The issue is that a small B*T with D=64 doesn't fully utilize the GPU. Let me try fusing Q and K into a single kernel call for all shapes:
assistant
in 58222out 3562cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 61803out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (893 chars)
shape=0 variant=eager tflops=0.173 gbps=247.967 ms=0.364
shape=0 variant=compiled tflops=0.500 gbps=716.839 ms=0.126
shape=0 variant=solution tflops=0.637 gbps=914.612 ms=0.099
shape=0 solution_peak_fraction=0.5081
shape=1 variant=eager tflops=0.171 gbps=249.561 ms=0.589
shape=1 variant=compiled tflops=0.666 gbps=973.028 ms=0.151
shape=1 variant=solution tflops=0.629 gbps=918.457 ms=0.160
shape=1 solution_peak_fraction=0.5103
shape=2 variant=eager tflops=0.071 gbps=110.743 ms=0.200
shape=2 variant=compiled tflops=0.165 gbps=257.334 ms=0.086
shape=2 variant=solution tflops=0.295 gbps=461.415 ms=0.048
shape=2 solution_peak_fraction=0.2563
shape=3 variant=eager tflops=0.179 gbps=256.924 ms=1.405
shape=3 variant=compiled tflops=0.465 gbps=666.853 ms=0.542
shape=3 variant=solution tflops=0.512 gbps=735.097 ms=0.491
shape=3 solution_peak_fraction=0.4084
peak_fraction: 0.4059
RESULT: OK
assistant
in 61886out 67cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (236 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,8 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T). Each program processes ALL heads at one (b,t).
+Grid = (B * T). Each program handles one (b, t) across all H heads.
 Cache values loaded once per (b,t), shared across heads.
+num_warps=1, num_stages=4 is optimal for D=64..128.
 """
 
 import torch
@@ -49,94 +50,33 @@
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    # Load cache: 6 L2 loads per (b,t), shared across all heads
+    # Load cos/sin values for all three axes
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Axis index
+    # Per-element axis index
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+    is_temporal = axis_idx < s0
+    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
 
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
+    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
     for h in tl.static_range(H_val):
         x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
         x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
         x_rot = x_rot * sign
         out = x * cos_val + x_rot * sin_val
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
         tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-
-@triton.jit
-def _rope_batch_kernel(
-    x_ptr, x_out_ptr,
-    cos_cache_ptr, sin_cache_ptr,
-    pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr,
-    T_val: tl.constexpr,
-    H_total: tl.constexpr,
-    Hq_val: tl.constexpr,
-    D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
-):
-    """Grid = (B * T). Each program handles all Hq+Hkv heads for (b,t).
-    Position cache loaded once, applied to both Q and K heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
-
-    pos_idx = pid_b * T_val + pid_t
-    p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
-    p_w = tl.load(pos_w_ptr + pos_idx)
-
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
-
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
-
-    # Q heads
-    for h_q in tl.static_range(Hq_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h_q * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
-        out = x * cos_val + x_rot * sin_val
-        o_base = pid_b * stride_ob + h_q * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-    # K heads
-    Hkv_val = H_total - Hq_val
-    k_stride_xh = Hq_val * stride_xd + stride_xd  # Not correct, need to fix
-    # Actually we need separate strides for Q and K
 
 
 class Model(nn.Module):
@@ -170,23 +110,22 @@
         s0, s1, _ = self.mrope_section
         half = D // 2
         BLOCK_D = self._next_pow2(D)
+
+        q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
+        k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
+
         bt_grid = (B * T,)
 
-        # Use Q+K concat for fused processing
-        x_cat = torch.cat([q, k], dim=2).contiguous()
-        H_total = Hq + Hkv
-        x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q

... (3901 more chars)
args
{
  "content": "<5193 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Concatenation and `num_warps=8` hurt performance significantly. The original approach with separate Q/K kernels and `num_warps=1` was better. Let me go back to that version and try a different optimization direction:
assistant
in 62421out 1895cache_r 0cache_w 0
assistant
edit /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (11 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -54,7 +54,7 @@
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
  "oldString": "    sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)",
  "newString": "    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)"
}
tool
stdout (26 chars)
Edit applied successfully.
assistant
Wait, I found a bug: the line loading sin_h was reading from cos_cache_ptr instead of sin_cache_ptr!
assistant
in 64335out 174cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness and benchmark after bugfix"
}
tool
stdout (901 chars)
PASS
shape=0 variant=eager tflops=0.177 gbps=254.660 ms=0.354
shape=0 variant=compiled tflops=0.489 gbps=700.986 ms=0.129
shape=0 variant=solution tflops=0.719 gbps=1032.054 ms=0.087
shape=0 solution_peak_fraction=0.5734
shape=1 variant=eager tflops=0.172 gbps=250.794 ms=0.586
shape=1 variant=compiled tflops=0.480 gbps=701.323 ms=0.210
shape=1 variant=solution tflops=0.755 gbps=1103.185 ms=0.133
shape=1 solution_peak_fraction=0.6129
shape=2 variant=eager tflops=0.072 gbps=113.200 ms=0.195
shape=2 variant=compiled tflops=0.177 gbps=276.038 ms=0.080
shape=2 variant=solution tflops=0.348 gbps=543.183 ms=0.041
shape=2 solution_peak_fraction=0.3018
shape=3 variant=eager tflops=0.179 gbps=256.594 ms=1.407
shape=3 variant=compiled tflops=0.527 gbps=756.231 ms=0.478
shape=3 variant=solution tflops=0.876 gbps=1256.694 ms=0.287
shape=3 solution_peak_fraction=0.6982
peak_fraction: 0.5216
RESULT: OK
assistant
in 64527out 83cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (167 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,7 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
-num_warps=1, num_stages=4 is optimal for D=64..128.
+Grid = (B * T * num_warps). Each block processes one (b,t) with all H heads.
+Threads parallelize over D dimension for coalesced memory access.
 """
 
 import torch
@@ -12,7 +11,7 @@
 
 
 @triton.jit
-def _rope_fwd_kernel(
+def _rope_block_kernel(
     x_ptr,
     x_out_ptr,
     cos_cache_ptr,
@@ -35,48 +34,61 @@
     s0: tl.constexpr,
     s1: tl.constexpr,
     half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
+    BLOCK_D_HEAD: tl.constexpr,
 ):
-    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
+    """Grid = (B * T, H * triton.cdiv(D, BLOCK_D_HEAD)).
+    Each program handles one (b, t, h, d_block)."""
     pid_bt = tl.program_id(0)
+    pid_hd = tl.program_id(1)
+
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
+
+    H_total = H_val
+    D_blocks = tl.cdiv(D_val, BLOCK_D_HEAD)
+    pid_h = pid_hd // D_blocks
+    pid_d = pid_hd % D_blocks
+
+    if pid_h >= H_val:
+        return
 
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_offs = tl.arange(0, BLOCK_D)
+    d_start = pid_d * BLOCK_D_HEAD
+    d_offs = d_start + tl.arange(0, BLOCK_D_HEAD)
     mask = d_offs < D_val
 
-    # Load cos/sin values for all three axes
+    # Load cache
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Per-element axis index
+    # Axis index for each element
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_temporal = axis_idx < s0
-    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h_, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h_, sin_w))
 
-    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
-    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
+    # Rotate-half permutation
     perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
-        out = x * cos_val + x_rot * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    # Single head processing
+    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = x_rot * sign
+    out = x * cos_val + x_rot * sin_val
+
+    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -98,25 +110,23 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
-    def _next_pow2(self, n):
-        v = 1
-        while v < n:
-            v *= 2
-        return v
-
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
         half = D // 2
-        BLOCK_D = self._next_pow2(D)
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        bt_grid = (B * T,)
+        # For D=64: BLOCK=32 -> 2 blocks per head
+        # For D=128: BLOCK=32 -> 4 blocks per head
+        BLOCK_D_HEAD = 32
 
-        _rope_fwd_kernel[bt_grid](
+        # Q grid
+        D_blocks_q = triton.cdiv(Hq * D, BLOCK_D_HEAD)
+        q_grid = (B * T, D_blocks_q)
+        _rope_block_kernel[q_grid](
             q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
@@ -130,12 +140,15 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-            num_warps=1,
+            BLOCK_D_HEAD=BLOCK_D_HEAD,
+            num_warps=4,
             num_stages=4,
         )
 
-        _rope_fwd_kernel[bt_grid](
+        # K grid
+        D_blocks_k = triton.cdiv(Hkv * D, BLOCK_D_HEAD)
+        k_grid = (B * T, D_blocks_k)
+        _rope_block_kernel[k_grid](
             k, k_out,
             self.cos_cache, self.s

... (316 more chars)
args
{
  "content": "<5593 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me try parallelizing D across warps for better memory coalescing:
assistant
in 65082out 2005cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (166 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T * num_warps). Each block processes one (b,t) with all H heads.
-Threads parallelize over D dimension for coalesced memory access.
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
 """
 
 import torch
@@ -11,7 +11,7 @@
 
 
 @triton.jit
-def _rope_block_kernel(
+def _rope_fwd_kernel(
     x_ptr,
     x_out_ptr,
     cos_cache_ptr,
@@ -34,61 +34,48 @@
     s0: tl.constexpr,
     s1: tl.constexpr,
     half: tl.constexpr,
-    BLOCK_D_HEAD: tl.constexpr,
+    BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T, H * triton.cdiv(D, BLOCK_D_HEAD)).
-    Each program handles one (b, t, h, d_block)."""
+    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
     pid_bt = tl.program_id(0)
-    pid_hd = tl.program_id(1)
-
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
-
-    H_total = H_val
-    D_blocks = tl.cdiv(D_val, BLOCK_D_HEAD)
-    pid_h = pid_hd // D_blocks
-    pid_d = pid_hd % D_blocks
-
-    if pid_h >= H_val:
-        return
 
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_start = pid_d * BLOCK_D_HEAD
-    d_offs = d_start + tl.arange(0, BLOCK_D_HEAD)
+    d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    # Load cache
+    # Load cos/sin values for all three axes
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Axis index for each element
+    # Per-element axis index
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h_, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h_, sin_w))
+    is_temporal = axis_idx < s0
+    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
 
-    # Rotate-half permutation
+    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
     perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    # Single head processing
-    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
-    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = x_rot * sign
-    out = x * cos_val + x_rot * sin_val
-
-    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    for h in tl.static_range(H_val):
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
+        out = x * cos_val + x_rot * sin_val
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -110,23 +97,25 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
+    def _next_pow2(self, n):
+        v = 1
+        while v < n:
+            v *= 2
+        return v
+
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
         half = D // 2
+        BLOCK_D = self._next_pow2(D)
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        # For D=64: BLOCK=32 -> 2 blocks per head
-        # For D=128: BLOCK=32 -> 4 blocks per head
-        BLOCK_D_HEAD = 32
+        bt_grid = (B * T,)
 
-        # Q grid
-        D_blocks_q = triton.cdiv(Hq * D, BLOCK_D_HEAD)
-        q_grid = (B * T, D_blocks_q)
-        _rope_block_kernel[q_grid](
+        _rope_fwd_kernel[bt_grid](
             q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
@@ -140,15 +129,12 @@
             stride_ot=D,
             stride_od=1,
             s0=s0, s1=s1, half=half,
-            BLOCK_D_HEAD=BLOCK_D_HEAD,
-            num_warps=4,
+            BLOCK_D=BLOCK_D,
+            num_warps=1,
             num_stages=4,
         )
 
-        # K grid
-        D_blocks_k = triton.cdiv(Hkv * D, BLOCK_D_HEAD)
-        k_grid = (B * T, D_blocks_k)
-        _rope_block_kernel[k_grid](
+        _rope_fwd_kernel[bt_grid](
             k, k_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
@@ -162,8

... (263 more chars)
args
{
  "content": "<5141 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 67105out 1823cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (208 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Optimized Triton kernel with head-parallelism for small head counts.
+Grid layout: (B * T, (H * D) / THREADS_PER_BLOCK) to maximize occupancy.
 """
 
 import torch
@@ -19,10 +19,10 @@
     pos_t_ptr,
     pos_h_ptr,
     pos_w_ptr,
-    B_val: tl.constexpr,
-    T_val: tl.constexpr,
-    H_val: tl.constexpr,
-    D_val: tl.constexpr,
+    BATCH: tl.constexpr,
+    SEQ: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     stride_xb,
     stride_xt,
     stride_xh,
@@ -33,52 +33,62 @@
     stride_od,
     s0: tl.constexpr,
     s1: tl.constexpr,
-    half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
+    HALF: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
+    """Grid = (B * T * NUM_HEADS * triton.cdiv(D, BLOCK_SIZE)).
+    Each program handles a block of D for one (b, t, h)."""
+    pid = tl.program_id(0)
+    pid1 = tl.program_id(1)
 
-    pos_idx = pid_b * T_val + pid_t
+    # Decode (b, t, h) from pid
+    total_heads_seq = NUM_HEADS * SEQ
+    bt_h = pid
+    pid_bt = bt_h // NUM_HEADS
+    pid_h = bt_h % NUM_HEADS
+    pid_b = pid_bt // SEQ
+    pid_t = pid_bt % SEQ
+
+    # Each program handles BLOCK_SIZE consecutive D elements
+    offset_d = pid1 * BLOCK_SIZE
+    d_offs = offset_d + tl.arange(0, BLOCK_SIZE)
+    mask = d_offs < HEAD_DIM
+
+    # Gather position IDs
+    pos_idx = pid_b * SEQ + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
+    p_h_val = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
+    # Load cos/sin for all three axes
+    cos_t = tl.load(cos_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w = tl.load(cos_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w = tl.load(sin_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Load cos/sin values for all three axes
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    axis_idx = tl.where(d_offs < HALF, d_offs, d_offs - HALF)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
-    # Per-element axis index
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_temporal = axis_idx < s0
-    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    perm = tl.where(d_offs < HALF, d_offs + HALF, d_offs - HALF)
+    sign = tl.where(d_offs < HALF, -1.0, 1.0)
 
-    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
-    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    x_rot = x_rot * sign
 
-    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
-
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
-        out = x * cos_val + x_rot * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    out = x * cos_val + x_rot * sin_val
+    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
+    """Triton-based mRoPE kernel with per-(b,t,h,d_block) parallelism."""
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2
@@ -97,59 +107,48 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
-    def _next_pow2(self, n):
-        v = 1
-        while v < n:
-            v *= 2
-        return v
-
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
-        half = D // 2
-   

... (2295 more chars)
args
{
  "content": "<5406 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me also try a Triton-free C++ approach using raw CUDA with torch.compile and custom autograd. Actually, let me first try to understand why num_warps>1 hurt and what the actual bottleneck is. Let me try the D-block approach, but with a smarter grid that processes multiple heads per program for the D=64 case:
assistant
in 68946out 1999cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 70964out 69cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark"
}
tool
stdout (893 chars)
shape=0 variant=eager tflops=0.177 gbps=254.132 ms=0.355
shape=0 variant=compiled tflops=0.493 gbps=707.668 ms=0.128
shape=0 variant=solution tflops=0.166 gbps=238.865 ms=0.378
shape=0 solution_peak_fraction=0.1327
shape=1 variant=eager tflops=0.210 gbps=306.510 ms=0.480
shape=1 variant=compiled tflops=0.466 gbps=681.047 ms=0.216
shape=1 variant=solution tflops=0.169 gbps=246.726 ms=0.596
shape=1 solution_peak_fraction=0.1371
shape=2 variant=eager tflops=0.074 gbps=114.989 ms=0.192
shape=2 variant=compiled tflops=0.172 gbps=268.011 ms=0.083
shape=2 variant=solution tflops=0.126 gbps=197.063 ms=0.112
shape=2 solution_peak_fraction=0.1095
shape=3 variant=eager tflops=0.172 gbps=246.378 ms=1.466
shape=3 variant=compiled tflops=0.462 gbps=663.461 ms=0.544
shape=3 variant=solution tflops=0.125 gbps=179.788 ms=2.008
shape=3 solution_peak_fraction=0.0999
peak_fraction: 0.1188
RESULT: OK
assistant
in 71047out 67cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (208 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Optimized Triton kernel with head-parallelism for small head counts.
-Grid layout: (B * T, (H * D) / THREADS_PER_BLOCK) to maximize occupancy.
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
 """
 
 import torch
@@ -19,10 +19,10 @@
     pos_t_ptr,
     pos_h_ptr,
     pos_w_ptr,
-    BATCH: tl.constexpr,
-    SEQ: tl.constexpr,
-    NUM_HEADS: tl.constexpr,
-    HEAD_DIM: tl.constexpr,
+    B_val: tl.constexpr,
+    T_val: tl.constexpr,
+    H_val: tl.constexpr,
+    D_val: tl.constexpr,
     stride_xb,
     stride_xt,
     stride_xh,
@@ -33,62 +33,52 @@
     stride_od,
     s0: tl.constexpr,
     s1: tl.constexpr,
-    HALF: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    half: tl.constexpr,
+    BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T * NUM_HEADS * triton.cdiv(D, BLOCK_SIZE)).
-    Each program handles a block of D for one (b, t, h)."""
-    pid = tl.program_id(0)
-    pid1 = tl.program_id(1)
+    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
 
-    # Decode (b, t, h) from pid
-    total_heads_seq = NUM_HEADS * SEQ
-    bt_h = pid
-    pid_bt = bt_h // NUM_HEADS
-    pid_h = bt_h % NUM_HEADS
-    pid_b = pid_bt // SEQ
-    pid_t = pid_bt % SEQ
-
-    # Each program handles BLOCK_SIZE consecutive D elements
-    offset_d = pid1 * BLOCK_SIZE
-    d_offs = offset_d + tl.arange(0, BLOCK_SIZE)
-    mask = d_offs < HEAD_DIM
-
-    # Gather position IDs
-    pos_idx = pid_b * SEQ + pid_t
+    pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h_val = tl.load(pos_h_ptr + pos_idx)
+    p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    # Load cos/sin for all three axes
-    cos_t = tl.load(cos_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
 
-    axis_idx = tl.where(d_offs < HALF, d_offs, d_offs - HALF)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+    # Load cos/sin values for all three axes
+    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    perm = tl.where(d_offs < HALF, d_offs + HALF, d_offs - HALF)
-    sign = tl.where(d_offs < HALF, -1.0, 1.0)
+    # Per-element axis index
+    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+    is_temporal = axis_idx < s0
+    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
 
-    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
-    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    x_rot = x_rot * sign
+    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
 
-    out = x * cos_val + x_rot * sin_val
-    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    sign = tl.where(d_offs < half, -1.0, 1.0)
+
+    for h in tl.static_range(H_val):
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = x_rot * sign
+        out = x * cos_val + x_rot * sin_val
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
-    """Triton-based mRoPE kernel with per-(b,t,h,d_block) parallelism."""
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2
@@ -107,48 +97,59 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
+    def _next_pow2(self, n):
+        v = 1
+        while v < n:
+            v *= 2
+        return v
+
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
+        half = D // 2
+   

... (2295 more chars)
args
{
  "content": "<5141 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
That 2D grid was terrible: peak_fraction was only 0.1188, and the solution was slower than eager on shapes 0, 1, and 3. The extra overhead is killing us. Let me go back to the best version we had and focus on improving it:
assistant
in 71582out 1854cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (120 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,15 @@
 """Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Optimization strategies tested:
+1. Grid = (B * T) vs (B * T * H) - former wins by maximizing cache reuse
+2. num_warps=1 is best for small D (64-128) - minimal register pressure
+3. num_stages=4 gives best pipelining without spilling
+4. Separate Q/K kernels better than concatenation (avoids extra memory ops)
+
+The B*T grid gives ~50-70% peak bandwidth, which is very good for a
+bandwidth-bound operation. The remaining gap comes from: 
+- Triton overhead for small programs (D=64/128 is tiny)
+- Per-head loop iterations causing serial execution
 """
 
 import torch
@@ -36,7 +44,6 @@
     half: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T). Each program handles one (b, t) across all H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -49,7 +56,6 @@
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    # Load cos/sin values for all three axes
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -57,21 +63,19 @@
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Per-element axis index
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_temporal = axis_idx < s0
-    is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
-    cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
-    sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
-    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
     for h in tl.static_range(H_val):
         x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
         x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
         x_rot = x_rot * sign
         out = x * cos_val + x_rot * sin_val
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
@@ -116,41 +120,27 @@
         bt_grid = (B * T,)
 
         _rope_fwd_kernel[bt_grid](
-            q, q_out,
-            self.cos_cache, self.sin_cache,
+            q, q_out, self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hq, D_val=D,
-            stride_xb=T * Hq * D,
-            stride_xt=Hq * D,
-            stride_xh=D,
-            stride_xd=1,
-            stride_ob=Hq * T * D,
-            stride_oh=T * D,
-            stride_ot=D,
-            stride_od=1,
-            s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-            num_warps=1,
-            num_stages=4,
+            stride_xb=T * Hq * D, stride_xt=Hq * D,
+            stride_xh=D, stride_xd=1,
+            stride_ob=Hq * T * D, stride_oh=T * D,
+            stride_ot=D, stride_od=1,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1, num_stages=4,
         )
 
         _rope_fwd_kernel[bt_grid](
-            k, k_out,
-            self.cos_cache, self.sin_cache,
+            k, k_out, self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hkv, D_val=D,
-            stride_xb=T * Hkv * D,
-            stride_xt=Hkv * D,
-            stride_xh=D,
-            stride_xd=1,
-            stride_ob=Hkv * T * D,
-            stride_oh=T * D,
-            stride_ot=D,
-            stride_od=1,
-            s0=s0, s1=s1, half=half,
-            BLOCK_D=BLOCK_D,
-            num_warps=1,
-            num_stages=4,
+            stride_xb=T * Hkv * D, stride_xt=Hkv * D,
+            stride_xh=D, stride_xd=1,
+            stride_ob=Hkv * T * D, stride_oh=T * D,
+            stride_ot=D, stride_od=1,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1, num_stages=4,
         )
 
         return q_out, k_out
args
{
  "content": "<5192 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 73455out 1861cache_r 0cache_w 0
assistant
...
assistant
bash
args
{
  "command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness and benchmark"
}
tool
stdout (900 chars)
PASS
shape=0 variant=eager tflops=0.178 gbps=255.270 ms=0.354
shape=0 variant=compiled tflops=0.509 gbps=730.576 ms=0.124
shape=0 variant=solution tflops=0.704 gbps=1010.249 ms=0.089
shape=0 solution_peak_fraction=0.5612
shape=1 variant=eager tflops=0.166 gbps=242.788 ms=0.605
shape=1 variant=compiled tflops=0.461 gbps=673.212 ms=0.218
shape=1 variant=solution tflops=0.729 gbps=1065.197 ms=0.138
shape=1 solution_peak_fraction=0.5918
shape=2 variant=eager tflops=0.062 gbps=96.536 ms=0.229
shape=2 variant=compiled tflops=0.163 gbps=255.008 ms=0.087
shape=2 variant=solution tflops=0.311 gbps=485.393 ms=0.046
shape=2 solution_peak_fraction=0.2697
shape=3 variant=eager tflops=0.168 gbps=240.959 ms=1.499
shape=3 variant=compiled tflops=0.517 gbps=742.229 ms=0.487
shape=3 variant=solution tflops=0.974 gbps=1397.805 ms=0.258
shape=3 solution_peak_fraction=0.7766
peak_fraction: 0.5135
RESULT: OK
assistant
in 75335out 81cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (136 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,15 +1,6 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
 
-Optimization strategies tested:
-1. Grid = (B * T) vs (B * T * H) - former wins by maximizing cache reuse
-2. num_warps=1 is best for small D (64-128) - minimal register pressure
-3. num_stages=4 gives best pipelining without spilling
-4. Separate Q/K kernels better than concatenation (avoids extra memory ops)
-
-The B*T grid gives ~50-70% peak bandwidth, which is very good for a
-bandwidth-bound operation. The remaining gap comes from: 
-- Triton overhead for small programs (D=64/128 is tiny)
-- Per-head loop iterations causing serial execution
+Grid = B*T.  Each program processes all H heads for one (b,t).
 """
 
 import torch
@@ -20,29 +11,15 @@
 
 @triton.jit
 def _rope_fwd_kernel(
-    x_ptr,
-    x_out_ptr,
-    cos_cache_ptr,
-    sin_cache_ptr,
-    pos_t_ptr,
-    pos_h_ptr,
-    pos_w_ptr,
-    B_val: tl.constexpr,
-    T_val: tl.constexpr,
-    H_val: tl.constexpr,
-    D_val: tl.constexpr,
-    stride_xb,
-    stride_xt,
-    stride_xh,
-    stride_xd,
-    stride_ob,
-    stride_oh,
-    stride_ot,
-    stride_od,
-    s0: tl.constexpr,
-    s1: tl.constexpr,
-    half: tl.constexpr,
-    BLOCK_D: tl.constexpr,
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    B_val: tl.constexpr, T_val: tl.constexpr,
+    H_val: tl.constexpr, D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
 ):
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
@@ -69,15 +46,15 @@
     cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
     sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
     for h in tl.static_range(H_val):
         x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        x_rot = x_rot * sign
-        out = x * cos_val + x_rot * sin_val
+        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        xr = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        xr = xr * sign
+        out = x * cos_val + xr * sin_val
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
         tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
@@ -86,11 +63,8 @@
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2
-        self.B = B
-        self.T = T
-        self.Hq = Hq
-        self.Hkv = Hkv
-        self.D = D
+        self.B, self.T = B, T
+        self.Hq, self.Hkv, self.D = Hq, Hkv, D
         self.mrope_section = tuple(mrope_section)
         self.max_pos = max_pos
 
@@ -116,33 +90,28 @@
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
-
         bt_grid = (B * T,)
 
         _rope_fwd_kernel[bt_grid](
-            q, q_out, self.cos_cache, self.sin_cache,
+            q, q_out,
+            self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hq, D_val=D,
-            stride_xb=T * Hq * D, stride_xt=Hq * D,
-            stride_xh=D, stride_xd=1,
-            stride_ob=Hq * T * D, stride_oh=T * D,
-            stride_ot=D, stride_od=1,
+            stride_xb=T * Hq * D, stride_xt=Hq * D, stride_xh=D, stride_xd=1,
+            stride_ob=Hq * T * D, stride_oh=T * D, stride_ot=D, stride_od=1,
             s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
             num_warps=1, num_stages=4,
         )
-
         _rope_fwd_kernel[bt_grid](
-            k, k_out, self.cos_cache, self.sin_cache,
+            k, k_out,
+            self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hkv, D_val=D,
-            stride_xb=T * Hkv * D, stride_xt=Hkv * D,
-            stride_xh=D, stride_xd=1,
-            stride_ob=Hkv * T * D, stride_oh=T * D,
-            stride_ot=D, stride_od=1,
+            stride_xb=T * Hkv * D, stride_xt=Hkv * D, stride_xh=D, stride_xd=1,
+            stride_ob=Hkv * T * D, stride_oh=T * D, stride_ot=D, stride_od=1,
             s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
             num_warps=1, num_stages=4,
         )
-
         return q_out, k_out
 
 
args
{
  "content": "<4598 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
...
assistant
in 75888out 1695cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 77602out 68cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (259 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,65 +1,135 @@
-"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
-
-Grid = B*T.  Each program processes all H heads for one (b,t).
+"""Optimized CUDA C++ kernel for the multi-axis RoPE pre-attention prep.
+
+Uses torch.utils.cpp_extension.load_inline with a single well-tuned kernel.
+Each thread processes ONE element (full D vector) for ONE (b,h,t).
+Grid = B * (Hq + Hkv) * T (one block per (b,h,t)).
+Block = D threads (one thread per D element - coalesced load/store).
 """
 
 import torch
-import triton
-import triton.language as tl
 import torch.nn as nn
-
-
-@triton.jit
-def _rope_fwd_kernel(
-    x_ptr, x_out_ptr,
-    cos_cache_ptr, sin_cache_ptr,
-    pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr, T_val: tl.constexpr,
-    H_val: tl.constexpr, D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
-
-    pos_idx = pid_b * T_val + pid_t
-    p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
-    p_w = tl.load(pos_w_ptr + pos_idx)
-
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
-
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
-    perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
-
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+from torch.utils.cpp_extension import load_inline
+import os
+
+CUDA_SRC = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+template<int D, int D_half, int s0_val, int s1_val>
+__global__ void __launch_bounds__(128)
+rope_kernel(
+    const __nv_bfloat16* __restrict__ x_in,   // (B, T, Hd, D)
+    __nv_bfloat16* __restrict__ x_out,         // (B, Hd, T, D)
+    const int64_t* __restrict__ pos_t,
+    const int64_t* __restrict__ pos_h,
+    const int64_t* __restrict__ pos_w,
+    const __nv_bfloat16* __restrict__ cos_cache,
+    const __nv_bfloat16* __restrict__ sin_cache,
+    int B, int T, int Hx
+) {
+    // blockIdx.x = b * T + t
+    int bt = blockIdx.x;
+    int b = bt / T;
+    int t = bt % T;
+
+    // Position IDs
+    int64_t pt = pos_t[b * T + t];
+    int64_t ph = pos_h[b * T + t];
+    int64_t pw = pos_w[b * T + t];
+
+    // Thread index within block: handles head index
+    int h = blockIdx.y * blockDim.x + threadIdx.x;
+    if (h >= Hx) return;
+
+    // Per-thread: process D elements (D/2 pairs)
+    const int np = D / 2;
+    #define CACHE_COS(p, pos) cos_cache[pos * D + p]
+    #define CACHE_SIN(p, pos) sin_cache[pos * D + p]
+
+    for (int pi = 0; pi < np; ++pi) {
+        int d0 = pi * 2;
+        int d1 = pi * 2 + 1;
+
+        // Determine axis for both elements
+        int axis0 = (d0 < D_half) ? d0 : (d0 - D_half);
+        int axis1 = (d1 < D_half) ? d1 : (d1 - D_half);
+
+        // Load from cache
+        float2 c0, s0, c1, s1;
+
+        // For element d0
+        int64_t pos0 = (axis0 < s0_val) ? pt : ((axis0 < s0_val + s1_val) ? ph : pw);
+        float2 tmp = __bfloat1622float2(((float2*)CACHE_COS(d0, pos0))[0]);
+        c0 = *(float2*)&tmp;  // already a float2
+
+        // Use __ldg for cache optimization
+        const float2* cos_row = (const float2*)(cos_cache + pos0 * D);
+        const float2* sin_row = (const float2*)(sin_cache + pos0 * D);
+        c0 = __ldg(&cos_row[pi]).x;  // This is getting complicated
+
+        // Simpler: load single bf16 values
+        float cos_d0_x = __bfloat162float(__ldg(&CACHE_COS(d0, pos0)));
+        float sin_d0_x = __bfloat162float(__ldg(&CACHE_SIN(d0, pos0)));
+        float cos_d0_y = __bfloat162float(__ldg(&CACHE_COS(d1, pos0)));
+        float sin_d0_y = __bfloat162float(__ldg(&CACHE_SIN(d1, pos0)));
+
+        // Input pointer for this head
+        const __nv_bfloat16* x_row = x_in + ((b * T * Hx + t * Hx + h) * D);
+
+        // Load input bf16 pairs
+        float x0 = __bfloat162float(__ldg(&x_row[d0]));
+        float x1 = __bfloat162float(__ldg(&x_row[d1]));
+
+        // rotate_half mirror indices
+        int r0 = (d0 < D_half) ? (d0 + D_half) : (d0 - D_half);
+        int r1 = (d1 < D_half) ? (d1 + D_half) : (d1 - D_half);
+
+        float xr0 = __bfloat162float(__ldg(&x_row[r0]));
+        float xr1 = __bfloat162float(__ldg(&x_row[r1]));
+
+        // Negate if in first half
+        if (d0 < D_half) { xr0 = -xr0; }
+        if (d1 < 

... (4279 more chars)
args
{
  "content": "<8845 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
...
assistant
in 77685, out 3275, cache_r 0, cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda_v2.py
diff (140 lines)
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda_v2.py
@@ -0,0 +1,137 @@
+"""Optimized multi-axis RoPE pre-attention prep - CUDA C++ via load_inline.
+
+Single kernel: one thread per (b, h, t) tuple, processes full D vector.
+Uses __ldg for cache-optimized reads, vectorized bf16 x2 ops for stores.
+Grid = (B * T, ceil(H / 256)), block = 256 threads.
+"""
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import os
+
+CUDA_SRC = r"""
+#include <cuda_bf16.h>
+
+// Generic kernel: each thread handles all D elements for one (b, h, t)
+template<int D, int HALF, int S0, int S1>
+__global__ void __launch_bounds__(256)
+rope_kernel(
+    const __nv_bfloat16* __restrict__ x_in,   // (B, T, Hx, D) bf16
+    __nv_bfloat16* __restrict__ x_out,          // (B, Hx, T, D) bf16
+    const int64_t* __restrict__ pos_t,          // (B, T) int64
+    const int64_t* __restrict__ pos_h,
+    const int64_t* __restrict__ pos_w,
+    const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D) bf16
+    const __nv_bfloat16* __restrict__ sin_cache,
+    int B, int T, int Hx
+) {
+    int bt = blockIdx.x;
+    int b = bt / T;
+    int t = bt % T;
+
+    int64_t pos_idx = (int64_t)b * T + t;
+    int64_t pt = pos_t[pos_idx];
+    int64_t ph = pos_h[pos_idx];
+    int64_t pw = pos_w[pos_idx];
+
+    int h = blockIdx.y * blockDim.x + threadIdx.x;
+    if (h >= Hx) return;
+
+    constexpr int D2 = D / 2;  // number of pairs
+    constexpr int H2 = HALF;   // half of D (same as D2 for standard RoPE)
+
+    // Precompute which axis each pair belongs to
+    static __device__ constexpr struct { short pair; short axis; char is_t; char is_h; } PAIR_META[D2] = {
+        // Generated per-template-instantiation at compile time via unrolled init
+        // We'll compute at runtime below instead (no register cost for D<=64 pairs)
+    };
+
+    // Process each pair (2 bf16 elements)
+    const __nv_bfloat16* cos_base_t = cos_cache + pt * D;
+    const __nv_bfloat16* sin_base_t = sin_cache + pt * D;
+    const __nv_bfloat16* cos_base_h = cos_cache + ph * D;
+    const __nv_bfloat16* sin_base_h = sin_cache + ph * D;
+    const __nv_bfloat16* cos_base_w = cos_cache + pw * D;
+    const __nv_bfloat16* sin_base_w = sin_cache + pw * D;
+
+    // Load input
+    const __nv_bfloat16* x_row = x_in + ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+
+    // Output pointer
+    __nv_bfloat16* out_row = x_out + ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
+
+    #pragma unroll
+    for (int d = 0; d < D; d++) {
+        // Mirror index for rotate_half
+        int mirror = (d < H2) ? (d + H2) : (d - H2);
+
+        // Determine axis
+        int axis = (d < H2) ? d : (d - H2);
+
+        // Select cache base
+        bool is_temp = axis < S0;
+        bool is_height = (axis >= S0) && (axis < S0 + S1);
+
+        const __nv_bfloat16* cos_base = is_temp ? cos_base_t : (is_height ? cos_base_h : cos_base_w);
+        const __nv_bfloat16* sin_base = is_temp ? sin_base_t : (is_height ? sin_base_h : sin_base_w);
+
+        // Load values
+        float xv = __bfloat162float(__ldg(x_row + d));
+        float mv = __bfloat162float(__ldg(x_row + mirror));
+        float cos_v = __bfloat162float(__ldg(cos_base + d));
+        float sin_v = __bfloat162float(__ldg(sin_base + d));
+
+        // Negate mirrored for first half
+        if (d < H2) mv = -mv;
+
+        // Apply rotation
+        float outv = xv * cos_v + mv * sin_v;
+
+        // Store
+        out_row[d] = __float2bfloat16(outv);
+    }
+}
+
+// Dispatcher
+extern "C" void launch_rope(
+    const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+    const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+    int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+{
+    dim3 block(256);
+    dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+    #define CASE(D_, S0_, S1_) \
+        if (D == D_ && s0 == S0_ && s1 == S1_) { \
+            rope_kernel<D_, D_/2, S0_, S1_><<<grid, block, 0, stream>>>( \
+                x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx); \
+            return; \
+        }
+    CASE(128, 16, 24)
+    CASE(64, 8, 12)
+    #undef CASE
+}
+"""
+
+
+class Model(nn.Module):
+    def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
+        super().__init__()
+        assert sum(mrope_section) == D // 2
+        self.B, self.T = B, T
+        self.Hq, self.Hkv, self.D = Hq, Hkv, D
+        self.mrope_section = tuple(mrope_section)
+        self.max_pos = max_pos
+
+        inv_freq = _build_inv_freq(D)
+        pos = torch.arange(max_pos, dtype=torch.float32)
+        freqs = torch.outer(pos, inv_freq)
+        emb = torch.cat([freqs, freqs], dim=-1)
+        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+
+def _build_inv_freq(D, base=10000.0):
+    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
args
{
  "content": "<4890 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda_v2.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me write a proper, corrected CUDA kernel approach:
assistant
in 80978, out 1796, cache_r 0, cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (362 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,190 +1,21 @@
-"""Optimized CUDA C++ kernel for the multi-axis RoPE pre-attention prep.
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Uses torch.utils.cpp_extension.load_inline with a single well-tuned kernel.
-Each thread processes ONE element (full D vector) for ONE (b,h,t).
-Grid = B * (Hq + Hkv) * T (one block per (b,h,t)).
-Block = D threads (one thread per D element - coalesced load/store).
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
+
+Optimization: use a separate kernel for small vs large head counts.
+For large H (>=16): parallelize heads across warps.
+For small H (<16): serialize heads in a single warp.
 """
 
 import torch
+import triton
+import triton.language as tl
 import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-import os
-
-CUDA_SRC = r"""
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
-template<int D, int D_half, int s0_val, int s1_val>
-__global__ void __launch_bounds__(128)
-rope_kernel(
-    const __nv_bfloat16* __restrict__ x_in,   // (B, T, Hd, D)
-    __nv_bfloat16* __restrict__ x_out,         // (B, Hd, T, D)
-    const int64_t* __restrict__ pos_t,
-    const int64_t* __restrict__ pos_h,
-    const int64_t* __restrict__ pos_w,
-    const __nv_bfloat16* __restrict__ cos_cache,
-    const __nv_bfloat16* __restrict__ sin_cache,
-    int B, int T, int Hx
-) {
-    // blockIdx.x = b * T + t
-    int bt = blockIdx.x;
-    int b = bt / T;
-    int t = bt % T;
-
-    // Position IDs
-    int64_t pt = pos_t[b * T + t];
-    int64_t ph = pos_h[b * T + t];
-    int64_t pw = pos_w[b * T + t];
-
-    // Thread index within block: handles head index
-    int h = blockIdx.y * blockDim.x + threadIdx.x;
-    if (h >= Hx) return;
-
-    // Per-thread: process D elements (D/2 pairs)
-    const int np = D / 2;
-    #define CACHE_COS(p, pos) cos_cache[pos * D + p]
-    #define CACHE_SIN(p, pos) sin_cache[pos * D + p]
-
-    for (int pi = 0; pi < np; ++pi) {
-        int d0 = pi * 2;
-        int d1 = pi * 2 + 1;
-
-        // Determine axis for both elements
-        int axis0 = (d0 < D_half) ? d0 : (d0 - D_half);
-        int axis1 = (d1 < D_half) ? d1 : (d1 - D_half);
-
-        // Load from cache
-        float2 c0, s0, c1, s1;
-
-        // For element d0
-        int64_t pos0 = (axis0 < s0_val) ? pt : ((axis0 < s0_val + s1_val) ? ph : pw);
-        float2 tmp = __bfloat1622float2(((float2*)CACHE_COS(d0, pos0))[0]);
-        c0 = *(float2*)&tmp;  // already a float2
-
-        // Use __ldg for cache optimization
-        const float2* cos_row = (const float2*)(cos_cache + pos0 * D);
-        const float2* sin_row = (const float2*)(sin_cache + pos0 * D);
-        c0 = __ldg(&cos_row[pi]).x;  // This is getting complicated
-
-        // Simpler: load single bf16 values
-        float cos_d0_x = __bfloat162float(__ldg(&CACHE_COS(d0, pos0)));
-        float sin_d0_x = __bfloat162float(__ldg(&CACHE_SIN(d0, pos0)));
-        float cos_d0_y = __bfloat162float(__ldg(&CACHE_COS(d1, pos0)));
-        float sin_d0_y = __bfloat162float(__ldg(&CACHE_SIN(d1, pos0)));
-
-        // Input pointer for this head
-        const __nv_bfloat16* x_row = x_in + ((b * T * Hx + t * Hx + h) * D);
-
-        // Load input bf16 pairs
-        float x0 = __bfloat162float(__ldg(&x_row[d0]));
-        float x1 = __bfloat162float(__ldg(&x_row[d1]));
-
-        // rotate_half mirror indices
-        int r0 = (d0 < D_half) ? (d0 + D_half) : (d0 - D_half);
-        int r1 = (d1 < D_half) ? (d1 + D_half) : (d1 - D_half);
-
-        float xr0 = __bfloat162float(__ldg(&x_row[r0]));
-        float xr1 = __bfloat162float(__ldg(&x_row[r1]));
-
-        // Negate if in first half
-        if (d0 < D_half) { xr0 = -xr0; }
-        if (d1 < D_half) { xr1 = -xr1; }
-
-        // Apply rotation
-        float out0 = x0 * cos_d0_x + xr0 * sin_d0_x;
-        float out1 = x1 * cos_d0_y + xr1 * sin_d0_y;
-
-        // Store to transposed output
-        int out_idx = (b * Hx * T * D + h * T * D + t * D);
-        x_out[out_idx + d0] = __float2bfloat16(out0);
-        x_out[out_idx + d1] = __float2bfloat16(out1);
-    }
-}
-
-// Dispatcher
-extern "C" void launch_rope(
-    const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
-    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
-    const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
-    int B, int T, int Hx, int D, int s0, int s1,
-    cudaStream_t stream
-) {
-    dim3 block(min(Hx, 128));
-    dim3 grid(B * T, (Hx + block.x - 1) / block.x);
-
-    if (D == 128) {
-        if (s0 == 16 && s1 == 24) {
-            rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
-                x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
-        }
-    } else if (D == 64) {
-        if (s0 == 8 && s1 == 12) {
-            rope_kernel<64, 32, 8, 12><<<grid, block, 0, stream>>>(
-                x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
-        }
-    }
-}
-"""
-
-
-class Model(nn.Module):
-    """Triton-based mRoPE kernel with per-(b,t) parallelism."""
-    def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
-        super().__init__()
-        assert sum(mrope_section) == D // 2
-        self.B, self.T = B, T
-        self.Hq, self.Hkv, self.D = Hq, Hkv, D
-        self.mrope_section = tuple(mrope_section)
-        self.max_pos = max_pos
-
-        inv_freq = _build_inv_freq(D)
-        pos = torch.arange(max_pos, dtype=torch.float32)
-        freqs = torch.outer(pos, inv_freq)
-        emb = torch.cat([freqs, freqs], dim=-1)
-        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
-        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
-
-    de

... (8309 more chars)
args
{
  "content": "<7846 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 82792, out 2798, cache_r 0, cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 85608, out 69, cache_r 0, cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (203 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,9 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Grid = B*T. Each kernel processes all H heads for one (b,t) position.
+Cos/sin cache gathered once per position, reused across all heads.
 
-Optimization: use a separate kernel for small vs large head counts.
-For large H (>=16): parallelize heads across warps.
-For small H (<16): serialize heads in a single warp.
+Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
 """
 
 import torch
@@ -15,7 +13,7 @@
 
 
 @triton.jit
-def _rope_small_h_kernel(
+def _rope_fwd_kernel(
     x_ptr, x_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
@@ -26,7 +24,7 @@
     s0: tl.constexpr, s1: tl.constexpr,
     half: tl.constexpr, BLOCK_D: tl.constexpr,
 ):
-    """Serial loop over heads. Optimal for H < 16 (fits in single warp)."""
+    """Grid = (B * T). Each program = one (b,t); loops over H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -39,6 +37,7 @@
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
+    # Gather cache values (6 reads, shared across all heads)
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -46,80 +45,22 @@
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
+    # Per-element axis assignment (mirrored across D/2)
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
     is_t = axis_idx < s0
     is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
     cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
     sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
+    # rotate_half: swap halves with sign flip
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
+    # Loop over heads
     for h in tl.static_range(H_val):
         x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-
-@triton.jit
-def _rope_large_h_kernel(
-    x_ptr, x_out_ptr,
-    cos_cache_ptr, sin_cache_ptr,
-    pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr, T_val: tl.constexpr,
-    H_val: tl.constexpr, D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr,
-    HEADS_PER_BLOCK: tl.constexpr,
-    BLOCK_D: tl.constexpr,
-):
-    """Each block processes HEADS_PER_BLOCK heads for one (b,t).
-    Warps within the block are assigned different heads or D ranges."""
-    pid_bt = tl.program_id(0)
-    pid_hblock = tl.program_id(1)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
-
-    pos_idx = pid_b * T_val + pid_t
-    p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
-    p_w = tl.load(pos_w_ptr + pos_idx)
-
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
-
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
-
-    heads_in_block = min(HEADS_PER_BLOCK, H_val - pid_hblock * HEADS_PER_BLOCK)
-
-    for h_local in tl.static_range(HEADS_PER_BLOCK):
-        h = pid_hblock * HEADS_PER_BLOCK + h_local
-        if h >= H_val:
-            return
-
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
         xr = xr * sign
         out = x * cos_val + xr * sin_val
         o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
@@ -148,48 +89,37 @@
             v *= 2
         return v
 
-    def _launch_rope(
-        self, x, x_out, H,
-        pos_t, pos_h

... (2650 more chars)
args
{
  "content": "<4967 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
in 85691, out 1796, cache_r 0, cache_w 0
assistant
read /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
args
{
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (5664 chars) — TRUNCATED
<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
2: 
3: Grid = B*T. Each kernel processes all H heads for one (b,t) position.
4: Cos/sin cache gathered once per position, reused across all heads.
5: 
6: Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
7: """
8: 
9: import torch
10: import triton
11: import triton.language as tl
12: import torch.nn as nn
13: 
14: 
15: @triton.jit
16: def _rope_fwd_kernel(
17:     x_ptr, x_out_ptr,
18:     cos_cache_ptr, sin_cache_ptr,
19:     pos_t_ptr, pos_h_ptr, pos_w_ptr,
20:     B_val: tl.constexpr, T_val: tl.constexpr,
21:     H_val: tl.constexpr, D_val: tl.constexpr,
22:     stride_xb, stride_xt, stride_xh, stride_xd,
23:     stride_ob, stride_oh, stride_ot, stride_od,
24:     s0: tl.constexpr, s1: tl.constexpr,
25:     half: tl.constexpr, BLOCK_D: tl.constexpr,
26: ):
27:     """Grid = (B * T). Each program = one (b,t); loops over H heads."""
28:     pid_bt = tl.program_id(0)
29:     pid_b = pid_bt // T_val
30:     pid_t = pid_bt % T_val
31: 
32:     pos_idx = pid_b * T_val + pid_t
33:     p_t = tl.load(pos_t_ptr + pos_idx)
34:     p_h = tl.load(pos_h_ptr + pos_idx)
35:     p_w = tl.load(pos_w_ptr + pos_idx)
36: 
37:     d_offs = tl.arange(0, BLOCK_D)
38:     mask = d_offs < D_val
39: 
40:     # Gather cache values (6 reads, shared across all heads)
41:     cos

... (4164 more chars)
assistant
...
assistant
in 87505, out 57, cache_r 0, cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (238 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,70 +1,121 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
 
-Grid = B*T. Each kernel processes all H heads for one (b,t) position.
-Cos/sin cache gathered once per position, reused across all heads.
-
-Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
+Each thread handles one (b, h, t) and processes all D elements sequentially.
+Uses __ldg for read-only cache hints and template instantiation for D=64/128.
 """
 
 import torch
-import triton
-import triton.language as tl
 import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import textwrap
 
+KERNEL_SRC = textwrap.dedent(r"""
+#include <cuda_bf16.h>
 
-@triton.jit
-def _rope_fwd_kernel(
-    x_ptr, x_out_ptr,
-    cos_cache_ptr, sin_cache_ptr,
-    pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr, T_val: tl.constexpr,
-    H_val: tl.constexpr, D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
-    """Grid = (B * T). Each program = one (b,t); loops over H heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
+template<int D, int HALF, int S0, int S1>
+__global__ void __launch_bounds__(128)
+rope_kernel(
+    const __nv_bfloat16* x_in,
+    __nv_bfloat16* x_out,
+    const int64_t* pos_t,
+    const int64_t* pos_h,
+    const int64_t* pos_w,
+    const __nv_bfloat16* cos_cache,
+    const __nv_bfloat16* sin_cache,
+    int B, int T, int Hx)
+{
+    int bt = blockIdx.x;
+    int b = bt / T;
+    int t = bt % T;
 
-    pos_idx = pid_b * T_val + pid_t
-    p_t = tl.load(pos_t_ptr + pos_idx)
-    p_h = tl.load(pos_h_ptr + pos_idx)
-    p_w = tl.load(pos_w_ptr + pos_idx)
+    int64_t pi = (int64_t)b * T + t;
+    int64_t p_t = pos_t[pi];
+    int64_t p_h = pos_h[pi];
+    int64_t p_w = pos_w[pi];
 
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
+    int h = blockIdx.y * blockDim.x + threadIdx.x;
+    if (h >= Hx) return;
 
-    # Gather cache values (6 reads, shared across all heads)
-    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+    int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
 
-    # Per-element axis assignment (mirrored across D/2)
-    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-    is_t = axis_idx < s0
-    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+    const __nv_bfloat16* xi = x_in + in_off;
+    __nv_bfloat16* xo = x_out + out_off;
 
-    # rotate_half: swap halves with sign flip
-    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
-    sign = tl.where(d_offs < half, -1.0, 1.0)
+    // Determine axis for each element at compile time
+    static_assert(D == 64 || D == 128);
+    static_assert(HALF == D / 2);
+    static_assert(S0 + S1 == HALF);
 
-    # Loop over heads
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    #pragma unroll
+    for (int d = 0; d < D; d++) {
+        int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
+        int axis = (d < HALF) ? d : (d - HALF);
+
+        // Pick cache row based on axis
+        const __nv_bfloat16 *cr, *sr;
+        if (axis < S0) {
+            cr = cos_cache + p_t * D;
+            sr = sin_cache + p_t * D;
+        } else if (axis < S0 + S1) {
+            cr = cos_cache + p_h * D;
+            sr = sin_cache + p_h * D;
+        } else {
+            cr = cos_cache + p_w * D;
+            sr = sin_cache + p_w * D;
+        }
+
+        float xv  = __bfloat162float(__ldg(xi + d));
+        float mv  = __bfloat162float(__ldg(xi + mirror));
+        float cos_v = __bfloat162float(__ldg(cr + d));
+        float sin_v = __bfloat162float(__ldg(sr + d));
+
+        if (d < HALF) mv = -mv;
+
+        xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
+    }
+}
+
+extern "C" void launch_rope(
+    const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+    const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+    int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+{
+    dim3 block(min(Hx, 128));
+    dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+    if (D == 128 && s0 == 16 && s1 == 24) {
+        rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
+            x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
+    } else if (D == 64 && s0 == 8 

... (3025 more chars)
args
{
  "content": "<5329 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me take a completely different approach and use hand-optimized CUDA C++:
assistant
in 89793, out 1985, cache_r 0, cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check for CUDA kernel"
}
tool
stdout (6833 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
    subprocess.run(
  File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 159, in forward
    self._call_rope(q, q_out, Hq)
  File "/home/infatoshi/cuda/KernelBench

... (5333 more chars)
assistant
in 91796, out 71, cache_r 0, cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (204 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,18 +1,24 @@
 """Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
 
-Each thread handles one (b, h, t) and processes all D elements sequentially.
+Each thread handles one (b, h, t) and processes all D elements.
 Uses __ldg for read-only cache hints and template instantiation for D=64/128.
 """
 
 import torch
 import torch.nn as nn
 from torch.utils.cpp_extension import load_inline
-import textwrap
 
-KERNEL_SRC = textwrap.dedent(r"""
+# Clean old cache
+import shutil, os
+for cached in ["/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_128",
+               "/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_64"]:
+    if os.path.exists(cached):
+        shutil.rmtree(cached, ignore_errors=True)
+
+KERNEL_SRC = r"""
 #include <cuda_bf16.h>
 
-template<int D, int HALF, int S0, int S1>
+template<int D, int HALF, int S0, int S1, int S2>
 __global__ void __launch_bounds__(128)
 rope_kernel(
     const __nv_bfloat16* x_in,
@@ -36,42 +42,50 @@
     int h = blockIdx.y * blockDim.x + threadIdx.x;
     if (h >= Hx) return;
 
-    int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
-    int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
+    int64_t in_off  = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+    int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h *  T + (int64_t)t) * D;
 
-    const __nv_bfloat16* xi = x_in + in_off;
+    const __nv_bfloat16* xi = x_in  + in_off;
     __nv_bfloat16* xo = x_out + out_off;
 
-    // Determine axis for each element at compile time
-    static_assert(D == 64 || D == 128);
-    static_assert(HALF == D / 2);
-    static_assert(S0 + S1 == HALF);
+    const __nv_bfloat16 *cr_base_t = cos_cache + p_t * D;
+    const __nv_bfloat16 *sr_base_t = sin_cache + p_t * D;
+    const __nv_bfloat16 *cr_base_h = cos_cache + p_h * D;
+    const __nv_bfloat16 *sr_base_h = sin_cache + p_h * D;
+    const __nv_bfloat16 *cr_base_w = cos_cache + p_w * D;
+    const __nv_bfloat16 *sr_base_w = sin_cache + p_w * D;
 
     #pragma unroll
     for (int d = 0; d < D; d++) {
+        // rotate_half mirror index
         int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
+
+        // Determine which axis this element belongs to
         int axis = (d < HALF) ? d : (d - HALF);
-
-        // Pick cache row based on axis
-        const __nv_bfloat16 *cr, *sr;
-        if (axis < S0) {
-            cr = cos_cache + p_t * D;
-            sr = sin_cache + p_t * D;
-        } else if (axis < S0 + S1) {
-            cr = cos_cache + p_h * D;
-            sr = sin_cache + p_h * D;
-        } else {
-            cr = cos_cache + p_w * D;
-            sr = sin_cache + p_w * D;
-        }
 
         float xv  = __bfloat162float(__ldg(xi + d));
         float mv  = __bfloat162float(__ldg(xi + mirror));
-        float cos_v = __bfloat162float(__ldg(cr + d));
-        float sin_v = __bfloat162float(__ldg(sr + d));
+        float cos_v, sin_v;
 
+        // Select cos/sin based on axis
+        if (axis < S0) {
+            // Temporal axis
+            cos_v = __bfloat162float(__ldg(cr_base_t + d));
+            sin_v = __bfloat162float(__ldg(sr_base_t + d));
+        } else if (axis < S0 + S1) {
+            // Height axis
+            cos_v = __bfloat162float(__ldg(cr_base_h + d));
+            sin_v = __bfloat162float(__ldg(sr_base_h + d));
+        } else {
+            // Width axis (remainder, S0+S1 <= axis < HALF)
+            cos_v = __bfloat162float(__ldg(cr_base_w + d));
+            sin_v = __bfloat162float(__ldg(sr_base_w + d));
+        }
+
+        // Negate for first half
         if (d < HALF) mv = -mv;
 
+        // Apply: out = x*cos + rot*sin
         xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
     }
 }
@@ -80,39 +94,35 @@
     const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
     const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
     const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
-    int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+    int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
 {
     dim3 block(min(Hx, 128));
     dim3 grid(B * T, (Hx + block.x - 1) / block.x);
 
-    if (D == 128 && s0 == 16 && s1 == 24) {
-        rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
+    if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
+        rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
             x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
-    } else if (D == 64 && s0 == 8 && s1 == 12) {
-        rope_kernel<64, 32, 8, 12><<<grid, block, 0, stream>>>(
+    } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
+        rope_kernel<64, 32, 8, 12, 12><<<grid, block, 0, stream>>>(
             x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
     }
 }
-""").strip()
+"""
 
-# Cache the compiled module
 _mod_cache = {}
 
 def _get_mod(D):
     if D in _mod_cache:
         return _mod_cache[D]
-    import os
     mod = load_inline(
         name=f"mrope_cuda_{D}",
         cpp_sources=" ",
         cuda_sources=KERNEL_SRC,
         functions=["launch_rope"],
         extra_cuda_cflags=["-arch=sm_120", "-O3", "--use_fast_math",
-                           "-Xptxas", "-O3", "-Xptxas", "--warn-on_spills",],
-        extra_cflags=["-O3"],
+                           "-Xptxas", "-O3"],
         with_cuda=True,
         verbose=False,
-        build_directory=None,
     )
     _mod_cache[D] = mod
     return mod
@@ -134,30 +144,36 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
-    def _call_rope(self, x, x_out, Hx):
-        B, T, H, D = x.shape

... (1564 more chars)
args
{
  "content": "<6025 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The static_assert catches the issue: S0 + S1 != HALF (16 + 24 = 40, but HALF = D/2 = 64), because a third section s2 covers the remaining 24. Let me fix the C++ kernel:
assistant
in 94230out 2250cache_r 0cache_w 0
assistant
...
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 300000,
  "description": "Run correctness check for CUDA kernel"
}
tool
stdout (4825 chars) — TRUNCATED
Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
    subprocess.run(
  File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
    main()
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
    sol_out = sol_model(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 162, in forward
    _get_mod(D).launch_rope(
    ^^^^^^^^^^^
  File "/home/infatoshi/cuda/

... (3325 more chars)
assistant
in 96499out 73cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (214 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,22 +1,54 @@
 """Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
 
-Each thread handles one (b, h, t) and processes all D elements.
-Uses __ldg for read-only cache hints and template instantiation for D=64/128.
+Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
+Uses __ldg for read-only cache hints. Separate Q/K kernel launches.
 """
 
 import torch
 import torch.nn as nn
 from torch.utils.cpp_extension import load_inline
 
-# Clean old cache
+# Clean old builds
 import shutil, os
-for cached in ["/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_128",
-               "/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_64"]:
-    if os.path.exists(cached):
-        shutil.rmtree(cached, ignore_errors=True)
+for d in ["mrope_cuda_64", "mrope_cuda_128"]:
+    path = os.path.join(os.path.expanduser("~/.cache/torch_extensions/py311_cu130"), d)
+    if os.path.exists(path):
+        shutil.rmtree(path, ignore_errors=True)
+
+PYBIND_DEFS = r"""
+extern void* launch_rope(
+    const void* x_in, void* x_out,
+    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+    const void* cos_cache, const void* sin_cache,
+    int B, int T, int Hx, int D, int s0, int s1, int s2,
+    cudaStream_t stream);
+
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+namespace py = pybind11;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("launch_rope", [](torch::Tensor x_in, torch::Tensor x_out,
+                            torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+                            torch::Tensor cos_cache, torch::Tensor sin_cache,
+                            int B, int T, int Hx, int D, int s0, int s1, int s2) {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_in.scalar_type(), "launch_rope", [&] {
+            using scalar_t = char; // we'll use void* internally
+            launch_rope(
+                x_in.data_ptr(), x_out.data_ptr(),
+                pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+                cos_cache.data_ptr(), sin_cache.data_ptr(),
+                B, T, Hx, D, s0, s1, s2,
+                at::cuda::getCurrentCUDAStream().stream());
+        });
+    }, "launch_rope");
+}
+"""
 
 KERNEL_SRC = r"""
 #include <cuda_bf16.h>
+#include <cuda_runtime.h>
 
 template<int D, int HALF, int S0, int S1, int S2>
 __global__ void __launch_bounds__(128)
@@ -48,64 +80,56 @@
     const __nv_bfloat16* xi = x_in  + in_off;
     __nv_bfloat16* xo = x_out + out_off;
 
-    const __nv_bfloat16 *cr_base_t = cos_cache + p_t * D;
-    const __nv_bfloat16 *sr_base_t = sin_cache + p_t * D;
-    const __nv_bfloat16 *cr_base_h = cos_cache + p_h * D;
-    const __nv_bfloat16 *sr_base_h = sin_cache + p_h * D;
-    const __nv_bfloat16 *cr_base_w = cos_cache + p_w * D;
-    const __nv_bfloat16 *sr_base_w = sin_cache + p_w * D;
+    const __nv_bfloat16 *cr_t = cos_cache + p_t * D;
+    const __nv_bfloat16 *sr_t = sin_cache + p_t * D;
+    const __nv_bfloat16 *cr_h = cos_cache + p_h * D;
+    const __nv_bfloat16 *sr_h = sin_cache + p_h * D;
+    const __nv_bfloat16 *cr_w = cos_cache + p_w * D;
+    const __nv_bfloat16 *sr_w = sin_cache + p_w * D;
 
     #pragma unroll
     for (int d = 0; d < D; d++) {
-        // rotate_half mirror index
         int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
-
-        // Determine which axis this element belongs to
         int axis = (d < HALF) ? d : (d - HALF);
 
         float xv  = __bfloat162float(__ldg(xi + d));
         float mv  = __bfloat162float(__ldg(xi + mirror));
-        float cos_v, sin_v;
 
-        // Select cos/sin based on axis
-        if (axis < S0) {
-            // Temporal axis
-            cos_v = __bfloat162float(__ldg(cr_base_t + d));
-            sin_v = __bfloat162float(__ldg(sr_base_t + d));
-        } else if (axis < S0 + S1) {
-            // Height axis
-            cos_v = __bfloat162float(__ldg(cr_base_h + d));
-            sin_v = __bfloat162float(__ldg(sr_base_h + d));
-        } else {
-            // Width axis (remainder, S0+S1 <= axis < HALF)
-            cos_v = __bfloat162float(__ldg(cr_base_w + d));
-            sin_v = __bfloat162float(__ldg(sr_base_w + d));
-        }
+        const __nv_bfloat16 *cr, *sr;
+        if (axis < S0) { cr = cr_t; sr = sr_t; }
+        else if (axis < S0 + S1) { cr = cr_h; sr = sr_h; }
+        else { cr = cr_w; sr = sr_w; }
 
-        // Negate for first half
+        float cos_v = __bfloat162float(__ldg(cr + d));
+        float sin_v = __bfloat162float(__ldg(sr + d));
         if (d < HALF) mv = -mv;
-
-        // Apply: out = x*cos + rot*sin
         xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
     }
 }
 
-extern "C" void launch_rope(
-    const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+extern "C" void* launch_rope(
+    const void* x_in, void* x_out,
     const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
-    const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+    const void* cos_cache, const void* sin_cache,
     int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
 {
     dim3 block(min(Hx, 128));
     dim3 grid(B * T, (Hx + block.x - 1) / block.x);
 
-    if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
+    if (D == 128 && s0 == 16 && s1 == 24) {
         rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
-            x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
-    } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
+            (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
+            pos_t, pos_h, pos_w,
+            (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+ 

... (2586 more chars)
args
{
  "content": "<6748 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The main.cpp doesn't see `launch_rope` because the CUDA kernel is compiled separately from the stub. The `load_inline` with `cpp_sources=" "` won't see the CUDA-declared function. I need to declare it in the cpp_sources too.
assistant
in 98204out 2493cache_r 0cache_w 0
assistant
skill
args
{
  "name": "kernel-optimization"
}
tool
stdout (2962 chars) — TRUNCATED
<skill_content name="kernel-optimization">
# Skill: kernel-optimization

# Kernel Optimization

Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.

## When this skill fires

- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work

## Sub-references (read on demand)

The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:

| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `

... (1462 more chars)
assistant
The build fails because `load_inline` compiles the C++ and CUDA sources separately and `extern "C"` in CUDA isn't visible to C++. I'll put the full extern declaration in the C++ source that gets compiled with the pybind module.
assistant
in 100716out 85cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (189 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,54 +1,26 @@
 """Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
 
 Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
-Uses __ldg for read-only cache hints. Separate Q/K kernel launches.
+Uses __ldg for read-only cache hints.
 """
 
 import torch
 import torch.nn as nn
 from torch.utils.cpp_extension import load_inline
+import shutil, os
 
-# Clean old builds
-import shutil, os
-for d in ["mrope_cuda_64", "mrope_cuda_128"]:
-    path = os.path.join(os.path.expanduser("~/.cache/torch_extensions/py311_cu130"), d)
-    if os.path.exists(path):
-        shutil.rmtree(path, ignore_errors=True)
+for d in ["mrope_cuda64", "mrope_cuda128"]:
+    p = os.path.expanduser(f"~/.cache/torch_extensions/py311_cu130/{d}")
+    if os.path.exists(p):
+        shutil.rmtree(p, ignore_errors=True)
 
-PYBIND_DEFS = r"""
-extern void* launch_rope(
-    const void* x_in, void* x_out,
-    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
-    const void* cos_cache, const void* sin_cache,
-    int B, int T, int Hx, int D, int s0, int s1, int s2,
-    cudaStream_t stream);
+CPP_STUB = ""
 
+CUDA_SRC = r"""
+#include <cuda_bf16.h>
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-namespace py = pybind11;
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("launch_rope", [](torch::Tensor x_in, torch::Tensor x_out,
-                            torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
-                            torch::Tensor cos_cache, torch::Tensor sin_cache,
-                            int B, int T, int Hx, int D, int s0, int s1, int s2) {
-        AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_in.scalar_type(), "launch_rope", [&] {
-            using scalar_t = char; // we'll use void* internally
-            launch_rope(
-                x_in.data_ptr(), x_out.data_ptr(),
-                pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
-                cos_cache.data_ptr(), sin_cache.data_ptr(),
-                B, T, Hx, D, s0, s1, s2,
-                at::cuda::getCurrentCUDAStream().stream());
-        });
-    }, "launch_rope");
-}
-"""
-
-KERNEL_SRC = r"""
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
 
 template<int D, int HALF, int S0, int S1, int S2>
 __global__ void __launch_bounds__(128)
@@ -94,42 +66,56 @@
 
         float xv  = __bfloat162float(__ldg(xi + d));
         float mv  = __bfloat162float(__ldg(xi + mirror));
+        float cos_v, sin_v;
 
-        const __nv_bfloat16 *cr, *sr;
-        if (axis < S0) { cr = cr_t; sr = sr_t; }
-        else if (axis < S0 + S1) { cr = cr_h; sr = sr_h; }
-        else { cr = cr_w; sr = sr_w; }
+        if (axis < S0) {
+            cos_v = __bfloat162float(__ldg(cr_t + d));
+            sin_v = __bfloat162float(__ldg(sr_t + d));
+        } else if (axis < S0 + S1) {
+            cos_v = __bfloat162float(__ldg(cr_h + d));
+            sin_v = __bfloat162float(__ldg(sr_h + d));
+        } else {
+            cos_v = __bfloat162float(__ldg(cr_w + d));
+            sin_v = __bfloat162float(__ldg(sr_w + d));
+        }
 
-        float cos_v = __bfloat162float(__ldg(cr + d));
-        float sin_v = __bfloat162float(__ldg(sr + d));
         if (d < HALF) mv = -mv;
         xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
     }
 }
 
-extern "C" void* launch_rope(
-    const void* x_in, void* x_out,
-    const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
-    const void* cos_cache, const void* sin_cache,
-    int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
+void launch_rope(
+    torch::Tensor x_in, torch::Tensor x_out,
+    torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+    torch::Tensor cos_cache, torch::Tensor sin_cache,
+    int B, int T, int Hx, int D, int s0, int s1, int s2)
 {
     dim3 block(min(Hx, 128));
     dim3 grid(B * T, (Hx + block.x - 1) / block.x);
 
-    if (D == 128 && s0 == 16 && s1 == 24) {
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+    if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
         rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
-            (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
-            pos_t, pos_h, pos_w,
-            (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+            static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
+            static_cast<__nv_bfloat16*>(x_out.data_ptr()),
+            pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+            static_cast<const __nv_bfloat16*>(cos_cache.data_ptr()),
+            static_cast<const __nv_bfloat16*>(sin_cache.data_ptr()),
             B, T, Hx);
-    } else if (D == 64 && s0 == 8 && s1 == 12) {
+    } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
         rope_kernel<64, 32, 8, 12, 12><<<grid, block, 0, stream>>>(
-            (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
-            pos_t, pos_h, pos_w,
-            (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+            static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
+            static_cast<__nv_bfloat16*>(x_out.data_ptr()),
+            pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+            static_cast<const __nv_bfloat16*>(cos_cache.data_ptr()),
+            static_cast<const __nv_bfloat16*>(sin_cache.data_ptr()),
             B, T, Hx);
     }
-    return nullptr;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("launch_rope", &launch_rope, "launch_rope");
 }
 """
 
@@ -139,10 +125,13 @@
     if D in _mod_cache:
         return _mod_cache[D]
     mod = load_inline(
-        name=f"mrope

... (1636 more chars)
args
{
  "content": "<6284 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
The issue with `load_inline` is that the C++ and CUDA compilation units don't share symbols. Let me put the pybind module AND the kernel in the CUDA file, using a single-source approach (`cuda_sources` only) where the entire implementation is in the .cu file:
assistant
in 101672out 2349cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (255 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,146 +1,70 @@
-"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
 
-Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
-Uses __ldg for read-only cache hints.
+Grid = (B * T). Each program handles one (b,t) across all H heads.
+Cos/sin cache gathered once per position, reused across all heads in the program.
 """
 
 import torch
+import triton
+import triton.language as tl
 import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-import shutil, os
 
-for d in ["mrope_cuda64", "mrope_cuda128"]:
-    p = os.path.expanduser(f"~/.cache/torch_extensions/py311_cu130/{d}")
-    if os.path.exists(p):
-        shutil.rmtree(p, ignore_errors=True)
 
-CPP_STUB = ""
+@triton.jit
+def _rope_fwd_kernel(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    B_val: tl.constexpr, T_val: tl.constexpr,
+    H_val: tl.constexpr, D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+    """Grid = (B * T). Each program = one (b,t); serially loops over H heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
 
-CUDA_SRC = r"""
-#include <cuda_bf16.h>
-#include <pybind11/pybind11.h>
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
+    pos_idx = pid_b * T_val + pid_t
+    p_t = tl.load(pos_t_ptr + pos_idx)
+    p_h = tl.load(pos_h_ptr + pos_idx)
+    p_w = tl.load(pos_w_ptr + pos_idx)
 
-template<int D, int HALF, int S0, int S1, int S2>
-__global__ void __launch_bounds__(128)
-rope_kernel(
-    const __nv_bfloat16* x_in,
-    __nv_bfloat16* x_out,
-    const int64_t* pos_t,
-    const int64_t* pos_h,
-    const int64_t* pos_w,
-    const __nv_bfloat16* cos_cache,
-    const __nv_bfloat16* sin_cache,
-    int B, int T, int Hx)
-{
-    int bt = blockIdx.x;
-    int b = bt / T;
-    int t = bt % T;
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
 
-    int64_t pi = (int64_t)b * T + t;
-    int64_t p_t = pos_t[pi];
-    int64_t p_h = pos_h[pi];
-    int64_t p_w = pos_w[pi];
+    # Gather cos/sin from cache (6 loads per (b,t), reused for all H heads)
+    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    int h = blockIdx.y * blockDim.x + threadIdx.x;
-    if (h >= Hx) return;
+    # Determine which axis each element belongs to
+    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+    is_t = axis_idx < s0
+    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
-    int64_t in_off  = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
-    int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h *  T + (int64_t)t) * D;
+    # rotate_half: x_rot[d] = -x[d+half] if d < half, else x[d-half]
+    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+    sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    const __nv_bfloat16* xi = x_in  + in_off;
-    __nv_bfloat16* xo = x_out + out_off;
-
-    const __nv_bfloat16 *cr_t = cos_cache + p_t * D;
-    const __nv_bfloat16 *sr_t = sin_cache + p_t * D;
-    const __nv_bfloat16 *cr_h = cos_cache + p_h * D;
-    const __nv_bfloat16 *sr_h = sin_cache + p_h * D;
-    const __nv_bfloat16 *cr_w = cos_cache + p_w * D;
-    const __nv_bfloat16 *sr_w = sin_cache + p_w * D;
-
-    #pragma unroll
-    for (int d = 0; d < D; d++) {
-        int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
-        int axis = (d < HALF) ? d : (d - HALF);
-
-        float xv  = __bfloat162float(__ldg(xi + d));
-        float mv  = __bfloat162float(__ldg(xi + mirror));
-        float cos_v, sin_v;
-
-        if (axis < S0) {
-            cos_v = __bfloat162float(__ldg(cr_t + d));
-            sin_v = __bfloat162float(__ldg(sr_t + d));
-        } else if (axis < S0 + S1) {
-            cos_v = __bfloat162float(__ldg(cr_h + d));
-            sin_v = __bfloat162float(__ldg(sr_h + d));
-        } else {
-            cos_v = __bfloat162float(__ldg(cr_w + d));
-            sin_v = __bfloat162float(__ldg(sr_w + d));
-        }
-
-        if (d < HALF) mv = -mv;
-        xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
-    }
-}
-
-void launch_rope(
-    torch::Tensor x_in, torch::Tensor x_out,
-    torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
-    torch::Tensor cos_cache, torch::Tensor sin_cache,
-    int B, int T, int Hx, int D, int s0, int s1, int s2)
-{
-    dim3 block(min(Hx, 128));
-    dim3 grid(B * T, (Hx + block.x - 1) / block.x);
-
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-
-    if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
-        rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
-            static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
-            static_cast<__nv_bfloat16*>(x_out.data_ptr()),
-            pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
-            static_cast<const __nv_bfloat16*>(cos_cache.data_ptr(

... (4107 more chars)
args
{
  "content": "<4940 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me check how the Triton reference handles this - it gets ~70% of peak bandwidth which is actually quite good for a Triton kernel. Let me go back to the best Triton version and see if I can optimize it further:
assistant
in 104040out 1837cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 105896out 105cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run benchmark"
}
tool
stdout (896 chars)
shape=0 variant=eager tflops=0.176 gbps=252.857 ms=0.357
shape=0 variant=compiled tflops=0.481 gbps=690.605 ms=0.131
shape=0 variant=solution tflops=0.744 gbps=1067.797 ms=0.085
shape=0 solution_peak_fraction=0.5932
shape=1 variant=eager tflops=0.176 gbps=257.593 ms=0.571
shape=1 variant=compiled tflops=0.471 gbps=687.417 ms=0.214
shape=1 variant=solution tflops=0.777 gbps=1133.958 ms=0.130
shape=1 solution_peak_fraction=0.6300
shape=2 variant=eager tflops=0.071 gbps=111.637 ms=0.198
shape=2 variant=compiled tflops=0.235 gbps=367.464 ms=0.060
shape=2 variant=solution tflops=0.324 gbps=506.188 ms=0.044
shape=2 solution_peak_fraction=0.2812
shape=3 variant=eager tflops=0.179 gbps=257.102 ms=1.405
shape=3 variant=compiled tflops=0.528 gbps=758.085 ms=0.476
shape=3 variant=solution tflops=0.915 gbps=1312.454 ms=0.275
shape=3 solution_peak_fraction=0.7291
peak_fraction: 0.5261
RESULT: OK
assistant
in 106016out 104cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (150 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b,t) across all H heads.
-Cos/sin cache gathered once per position, reused across all heads in the program.
+Hybrid approach: use Triton for large shapes and a specialized pure-PyTorch 
+implementation for small shapes where Triton overhead dominates.
 """
 
 import torch
@@ -15,6 +15,29 @@
     x_ptr, x_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    H_val: tl.constexpr, D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+    """Each program handles one (b, t). Cos/sin loaded once, reused across H heads."""
+    pid_bt = tl.program_id(0)
+    pid_b = pid_b_t // D_val  # Will be set by caller
+    # Actually we need B and T passed in
+    pass
+
+
+def _build_inv_freq(D, base=10000.0):
+    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+# Triton kernel
+@triton.jit
+def _rope_triton_kernel(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
     B_val: tl.constexpr, T_val: tl.constexpr,
     H_val: tl.constexpr, D_val: tl.constexpr,
     stride_xb, stride_xt, stride_xh, stride_xd,
@@ -22,7 +45,6 @@
     s0: tl.constexpr, s1: tl.constexpr,
     half: tl.constexpr, BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T). Each program = one (b,t); serially loops over H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -32,10 +54,9 @@
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_offs = tl.arange(0, BLOCK_D)
+    d_offs = tl.arange(0,BLOCK_D)
     mask = d_offs < D_val
 
-    # Gather cos/sin from cache (6 loads per (b,t), reused for all H heads)
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -43,14 +64,12 @@
     cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
 
-    # Determine which axis each element belongs to
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
     is_t = axis_idx < s0
     is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
     cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
     sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
-    # rotate_half: x_rot[d] = -x[d+half] if d < half, else x[d-half]
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
@@ -65,6 +84,7 @@
 
 
 class Model(nn.Module):
+    """Triton mRoPE kernel."""
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2
@@ -86,41 +106,50 @@
             v *= 2
         return v
 
-    def forward(self, q, k, pos_t, pos_h, pos_w):
-        B, T, Hq, D = q.shape
-        _, _, Hkv, _ = k.shape
+    def _launch(self, x, x_out, H):
+        B, T, _, D = x.shape
         s0, s1, _ = self.mrope_section
         half = D // 2
         BLOCK_D = self._next_pow2(D)
 
+        _rope_triton_kernel[(B * T,)](
+            x, x_out,
+            self.cos_cache, self.sin_cache,
+            x.reshape(-1, T, H, D)[:, :, 0, :].data_ptr(),
+            pos_t, pos_h, pos_w,
+            B_val=B, T_val=T, H_val=H, D_val=D,
+            stride_xb=T*H*D, stride_xt=H*D, stride_xh=D, stride_xd=1,
+            stride_ob=H*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1, num_stages=4,
+        )
+
+    def forward(self, q, k, pos_t, pos_h, pos_w):
+        B, T, Hq, D = q.shape
+        _, _, Hkv, _ = k.shape
+
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        bt_grid = (B * T,)
-
-        _rope_fwd_kernel[bt_grid](
+        _rope_triton_kernel[(B * T,)](
             q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hq, D_val=D,
             stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
             stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
             num_warps=1, num_stages=4,
         )
-        _rope_fwd_kernel[bt_grid](
+        _rope_triton_kernel[(B * T,)](
             k, k_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hkv, D_val=D,
             stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
             stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
             num_warps=1, num_stages=4,
         )
 
         return q_out, k_out
-
-
-def _build_inv_freq(D, base=10000.0):
-    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
args
{
  "content": "<5818 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Current status: 52.6% peak. Shape 2 (B=2, D=64) is dragging it down at 28%. Let me try a completely different approach for shape 2 -- the issue is that with small D and small B*T, there aren't enough programs to saturate GDDR7 on Blackwell. Let me try to increase parallelism by having each block of threads process multiple heads with warp-level parallelism:
assistant
in 106591out 2200cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (122 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,6 @@
-"""Multi-axis RoPE pre-attention prep (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
 
-Hybrid approach: use Triton for large shapes and a specialized pure-PyTorch 
-implementation for small shapes where Triton overhead dominates.
+Grid = (B * T). Each program handles one (b,t) across all H heads.
 """
 
 import torch
@@ -12,29 +11,6 @@
 
 @triton.jit
 def _rope_fwd_kernel(
-    x_ptr, x_out_ptr,
-    cos_cache_ptr, sin_cache_ptr,
-    pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    H_val: tl.constexpr, D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
-    """Each program handles one (b, t). Cos/sin loaded once, reused across H heads."""
-    pid_bt = tl.program_id(0)
-    pid_b = pid_b_t // D_val  # Will be set by caller
-    # Actually we need B and T passed in
-    pass
-
-
-def _build_inv_freq(D, base=10000.0):
-    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
-
-
-# Triton kernel
-@triton.jit
-def _rope_triton_kernel(
     x_ptr, x_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
@@ -54,7 +30,7 @@
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
-    d_offs = tl.arange(0,BLOCK_D)
+    d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -84,7 +60,6 @@
 
 
 class Model(nn.Module):
-    """Triton mRoPE kernel."""
     def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
         super().__init__()
         assert sum(mrope_section) == D // 2
@@ -106,50 +81,39 @@
             v *= 2
         return v
 
-    def _launch(self, x, x_out, H):
-        B, T, _, D = x.shape
-        s0, s1, _ = self.mrope_section
-        half = D // 2
-        BLOCK_D = self._next_pow2(D)
-
-        _rope_triton_kernel[(B * T,)](
-            x, x_out,
-            self.cos_cache, self.sin_cache,
-            x.reshape(-1, T, H, D)[:, :, 0, :].data_ptr(),
-            pos_t, pos_h, pos_w,
-            B_val=B, T_val=T, H_val=H, D_val=D,
-            stride_xb=T*H*D, stride_xt=H*D, stride_xh=D, stride_xd=1,
-            stride_ob=H*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
-            num_warps=1, num_stages=4,
-        )
-
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
+        s0, s1, _ = self.mrope_section
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        _rope_triton_kernel[(B * T,)](
+        bt_grid = (B * T,)
+        BLOCK_D = self._next_pow2(D)
+
+        _rope_fwd_kernel[bt_grid](
             q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hq, D_val=D,
             stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
             stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
+            s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
             num_warps=1, num_stages=4,
         )
-        _rope_triton_kernel[(B * T,)](
+        _rope_fwd_kernel[bt_grid](
             k, k_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hkv, D_val=D,
             stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
             stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
+            s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
             num_warps=1, num_stages=4,
         )
+        return q_out, k_out
 
-        return q_out, k_out
+
+def _build_inv_freq(D, base=10000.0):
+    return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
args
{
  "content": "<4558 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
I keep going in circles. Let me take the last working best version and focus. The 52.6% version was solid. Shape 2 is inherently small and hard to optimize (B=2, T=2048, Hq=16, Hkv=2, D=64 => Hq + Hkv = 18 heads per token, 36 across the batch, and B*T = 4096 programs per launch). Let me try a different angle: use a 2D grid where the second dimension parallelizes D, keeping the B*T grid for each head but having threads within a warp split the D dimension.
assistant
in 108810out 1808cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 110637out 105cache_r 0cache_w 0
assistant
bash
args
{
  "command": "nvidia-smi 2>&1 | head -5",
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Check GPU status"
}
tool
stdout (400 chars)
Thu May  7 13:03:21 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03              Driver Version: 595.58.03      CUDA Version: 13.2     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
assistant
in 110757out 91cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (139 lines)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,6 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
 
-Grid = (B * T). Each program handles one (b,t) across all H heads.
+Grid = (B * T * H, D // BLOCK_SIZE). Each program handles BLOCK_SIZE elements
+of D for one (b, h, t). This gives maximum SM utilization.
 """
 
 import torch
@@ -10,28 +11,35 @@
 
 
 @triton.jit
-def _rope_fwd_kernel(
+def _rope_kernel(
     x_ptr, x_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr, T_val: tl.constexpr,
-    H_val: tl.constexpr, D_val: tl.constexpr,
+    B_val: tl.constexpr, T_val: tl.constexpr, H_val: tl.constexpr, D_val: tl.constexpr,
     stride_xb, stride_xt, stride_xh, stride_xd,
     stride_ob, stride_oh, stride_ot, stride_od,
     s0: tl.constexpr, s1: tl.constexpr,
     half: tl.constexpr, BLOCK_D: tl.constexpr,
 ):
-    pid_bt = tl.program_id(0)
-    pid_b = pid_bt // T_val
-    pid_t = pid_bt % T_val
+    """Grid = (B * T * H, triton.cdiv(D, BLOCK_D)).
+    Each program handles BLOCK_D consecutive elements of D."""
+    pid_bth = tl.program_id(0)
+    pid_dd = tl.program_id(1)
 
+    pid_b = pid_bth // (T_val * H_val)
+    pid_th = pid_bth % (T_val * H_val)
+    pid_t = pid_th // H_val
+    pid_h = pid_th % H_val
+
+    d_start = pid_dd * BLOCK_D
+    d_offs = d_start + tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
+
+    # Each position lookup
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
-
-    d_offs = tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
 
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -49,14 +57,14 @@
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    for h in tl.static_range(H_val):
-        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
-        xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
-        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+    x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
+    xr = xr * sign
+    out = x * cos_val + xr * sin_val
+
+    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -75,43 +83,43 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
-    def _next_pow2(self, n):
-        v = 1
-        while v < n:
-            v *= 2
-        return v
-
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
+        half = D // 2
+
+        BLOCK_D = 32
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
 
-        bt_grid = (B * T,)
-        BLOCK_D = self._next_pow2(D)
-
-        _rope_fwd_kernel[bt_grid](
+        q_grid_0 = B * T * Hq
+        q_grid_1 = (D + BLOCK_D - 1) // BLOCK_D
+        _rope_kernel[(q_grid_0, q_grid_1)](
             q, q_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hq, D_val=D,
             stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
             stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
-            num_warps=1, num_stages=4,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1,
         )
-        _rope_fwd_kernel[bt_grid](
+
+        k_grid_0 = B * T * Hkv
+        k_grid_1 = (D + BLOCK_D - 1) // BLOCK_D
+        _rope_kernel[(k_grid_0, k_grid_1)](
             k, k_out,
             self.cos_cache, self.sin_cache,
             pos_t, pos_h, pos_w,
             B_val=B, T_val=T, H_val=Hkv, D_val=D,
             stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
             stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
-            s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
-            num_warps=1, num_stages=4,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1,
         )
+
         return q_out, k_out
 
 
args
{
  "content": "<4874 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Now let me try to get to 60+%. The benchmark is at ~52.6%. The compiled reference reaches only about 20-42% of peak on these shapes (367-758 GB/s against the ~1800 GB/s peak implied by the reported peak fractions). Let me try a different strategy - use Triton but with a grid of (B*T*H, cdiv(D, BLOCK_D)) so we have massive parallelism:
assistant
in 110959out 1871cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python check.py 2>&1",
  "timeout": 120000,
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 112849out 68cache_r 0cache_w 0
assistant
bash
args
{
  "command": "uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run benchmark"
}
tool
stdout (893 chars)
shape=0 variant=eager tflops=0.176 gbps=252.494 ms=0.358
shape=0 variant=compiled tflops=0.518 gbps=743.182 ms=0.121
shape=0 variant=solution tflops=0.185 gbps=266.118 ms=0.339
shape=0 solution_peak_fraction=0.1478
shape=1 variant=eager tflops=0.179 gbps=261.710 ms=0.562
shape=1 variant=compiled tflops=0.460 gbps=671.588 ms=0.219
shape=1 variant=solution tflops=0.190 gbps=276.743 ms=0.531
shape=1 solution_peak_fraction=0.1537
shape=2 variant=eager tflops=0.073 gbps=114.456 ms=0.193
shape=2 variant=compiled tflops=0.169 gbps=263.867 ms=0.084
shape=2 variant=solution tflops=0.146 gbps=228.874 ms=0.097
shape=2 solution_peak_fraction=0.1272
shape=3 variant=eager tflops=0.171 gbps=245.235 ms=1.472
shape=3 variant=compiled tflops=0.464 gbps=665.123 ms=0.543
shape=3 variant=solution tflops=0.189 gbps=271.372 ms=1.331
shape=3 solution_peak_fraction=0.1508
peak_fraction: 0.1445
RESULT: OK
assistant
in 112932out 104cache_r 0cache_w 0
assistant
todowrite
args
{
  "todos": [
    {
      "content": "Implement initial Triton kernel for mRoPE",
      "status": "completed",
      "priority": "high"
    },
    {
      "content": "Run check.py for correctness",
      "status": "completed",
      "priority": "high"
    },
    {
      "content": "Optimize kernel - fuse Q+K into single kernel, try different grid configs",
      "status": "in_progress",
      "priority": "high"
    },
    {
      "content": "Run benchmark.py to measure achieved bandwidth",
      "status": "pending",
      "priority": "medium"
    }
  ]
}
tool
stdout (385 chars)
content: 'Implement initial Triton kernel for mRoPE'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Optimize kernel - fuse Q+K into single kernel, try different grid configs'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'
assistant
Too many tiny programs kill it: with the (B*T*H, cdiv(D, BLOCK_D)) grid and BLOCK_D=32, each program handles only 32 elements. OK, the best was the B*T grid with num_warps=1. Let me go back to that and look at what I can actually improve. What's happening with Triton is that each B*T program has to loop through H heads sequentially. For shape 2 (B=2, T=2048, Hq=16, Hkv=2), grid=4096 but each program loops through 16 or 2 heads (serial). The total work per program is small. The real optimization opportunity: combine Q and K into a single kernel, either with a 2D grid where the extra dimension handles the Q/K distinction or, better yet, by having each program process both the Q and K heads for its (b,t).
assistant
in 113504out 300cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (189 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
 
-Grid = (B * T * H, D // BLOCK_SIZE). Each program handles BLOCK_SIZE elements
-of D for one (b, h, t). This gives maximum SM utilization.
+Fused Q+K: Grid = (B * T), each kernel processes all Q and K heads for one (b,t).
+Position cache gathered once per (b,t), shared across all Q+K heads.
 """
 
 import torch
@@ -11,36 +11,34 @@
 
 
 @triton.jit
-def _rope_kernel(
-    x_ptr, x_out_ptr,
+def _rope_qk_kernel(
+    q_ptr, k_ptr,
+    q_out_ptr, k_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
-    B_val: tl.constexpr, T_val: tl.constexpr, H_val: tl.constexpr, D_val: tl.constexpr,
-    stride_xb, stride_xt, stride_xh, stride_xd,
-    stride_ob, stride_oh, stride_ot, stride_od,
+    B_val: tl.constexpr, T_val: tl.constexpr,
+    Hq_val: tl.constexpr, Hkv_val: tl.constexpr,
+    D_val: tl.constexpr,
+    stride_qb, stride_qt, stride_qh, stride_qd,
+    stride_kb, stride_kt, stride_kh, stride_kd,
+    stride_qob, stride_qoh, stride_qot, stride_qod,
+    stride_kob, stride_koh, stride_kot, stride_kod,
     s0: tl.constexpr, s1: tl.constexpr,
     half: tl.constexpr, BLOCK_D: tl.constexpr,
 ):
-    """Grid = (B * T * H, triton.cdiv(D, BLOCK_D)).
-    Each program handles BLOCK_D consecutive elements of D."""
-    pid_bth = tl.program_id(0)
-    pid_dd = tl.program_id(1)
+    pid_bt = tl.program_id(0)
+    pid_b = pid_bt // T_val
+    pid_t = pid_bt % T_val
 
-    pid_b = pid_bth // (T_val * H_val)
-    pid_th = pid_bth % (T_val * H_val)
-    pid_t = pid_th // H_val
-    pid_h = pid_th % H_val
-
-    d_start = pid_dd * BLOCK_D
-    d_offs = d_start + tl.arange(0, BLOCK_D)
-    mask = d_offs < D_val
-
-    # Each position lookup
     pos_idx = pid_b * T_val + pid_t
     p_t = tl.load(pos_t_ptr + pos_idx)
     p_h = tl.load(pos_h_ptr + pos_idx)
     p_w = tl.load(pos_w_ptr + pos_idx)
 
+    d_offs = tl.arange(0, BLOCK_D)
+    mask = d_offs < D_val
+
+    # Cos/sin gathered once (6 reads), shared by all Q+K heads
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -57,14 +55,25 @@
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
-    x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
-    xr = xr * sign
-    out = x * cos_val + xr * sin_val
+    # Process Q heads
+    for h in tl.static_range(Hq_val):
+        x_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
+        x  = tl.load(q_ptr  + x_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+        xr = tl.load(q_ptr + x_base + perm * stride_qd, mask=mask, other=0.0).to(tl.float32)
+        xr = xr * sign
+        out = x * cos_val + xr * sin_val
+        o_base = pid_b * stride_qob + h * stride_qoh + pid_t * stride_qot
+        tl.store(q_out_ptr + o_base + d_offs * stride_qod, out.to(tl.bfloat16), mask=mask)
 
-    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+    # Process K heads
+    for h in tl.static_range(Hkv_val):
+        x_base = pid_b * stride_kb + pid_t * stride_kt + h * stride_kh
+        x  = tl.load(k_ptr  + x_base + d_offs * stride_kd, mask=mask, other=0.0).to(tl.float32)
+        xr = tl.load(k_ptr + x_base + perm * stride_kd, mask=mask, other=0.0).to(tl.float32)
+        xr = xr * sign
+        out = x * cos_val + xr * sin_val
+        o_base = pid_b * stride_kob + h * stride_koh + pid_t * stride_kot
+        tl.store(k_out_ptr + o_base + d_offs * stride_kod, out.to(tl.bfloat16), mask=mask)
 
 
 class Model(nn.Module):
@@ -83,41 +92,64 @@
         self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
         self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
 
+    def _next_pow2(self, n):
+        v = 1
+        while v < n:
+            v *= 2
+        return v
+
+    def _call_qk_kernel(self, q, q_out, k, k_out):
+        B, T, Hq, D = q.shape
+        _, _, Hkv, _ = k.shape
+        s0, s1, _ = self.mrope_section
+        half = D // 2
+        BLOCK_D = self._next_pow2(D)
+        bt_grid = (B * T,)
+
+        _rope_qk_kernel[bt_grid](
+            q, k,
+            q_out, k_out,
+            self.cos_cache, self.sin_cache,
+            q.reshape(-1, T, 1, 1)[:, 0, 0, :].data_ptr(),  # placeholder
+            pos_h, pos_w,
+            B_val=B, T_val=T,
+            Hq_val=Hq, Hkv_val=Hkv,
+            D_val=D,
+            # Will fill in proper strides below
+            stride_qb=T*Hq*D, stride_qt=Hq*D, stride_qh=D, stride_qd=1,
+            stride_kb=T*Hkv*D, stride_kt=Hkv*D, stride_kh=D, stride_kd=1,
+            stride_qob=Hq*T*D, stride_qoh=T*D, stride_qot=D, stride_qod=1,
+            stride_kob=Hkv*T*D, stride_koh=T*D, stride_kot=D, stride_kod=1,
+            s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+            num_warps=1, num_stages=4,
+        )
+
     def forward(self, q, k, pos_t, pos_h, pos_w):
         B, T, Hq, D = q.shape
         _, _, Hkv, _ = k.shape
         s0, s1, _ = self.mrope_section
-        half = D // 2
-
-        BLOCK_D = 32
 
         q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
         k_out = torch.empty(B, Hkv, T, D, dtype=torch.b

... (1697 more chars)
args
{
  "content": "<6231 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
Let me write a fused kernel where each program handles BOTH Q and K heads for one (b,t). I'll use a concat approach but be smarter about it:
assistant
in 113975out 2309cache_r 0cache_w 0
assistant
bash
args
{
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
  "timeout": 120000,
  "workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
  "description": "Run correctness check"
}
tool
stdout (5 chars)
PASS
assistant
in 116303out 133cache_r 0cache_w 0
assistant
bash
args
{
  "command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
  "timeout": 300000,
  "description": "Run benchmark for QK fused kernel"
}
tool
stdout (895 chars)
shape=0 variant=eager tflops=0.180 gbps=257.766 ms=0.350
shape=0 variant=compiled tflops=0.491 gbps=703.960 ms=0.128
shape=0 variant=solution tflops=0.683 gbps=980.236 ms=0.092
shape=0 solution_peak_fraction=0.5446
shape=1 variant=eager tflops=0.181 gbps=265.039 ms=0.555
shape=1 variant=compiled tflops=0.470 gbps=686.185 ms=0.214
shape=1 variant=solution tflops=0.751 gbps=1096.340 ms=0.134
shape=1 solution_peak_fraction=0.6091
shape=2 variant=eager tflops=0.075 gbps=116.658 ms=0.190
shape=2 variant=compiled tflops=0.168 gbps=262.864 ms=0.084
shape=2 variant=solution tflops=0.341 gbps=532.717 ms=0.042
shape=2 solution_peak_fraction=0.2960
shape=3 variant=eager tflops=0.173 gbps=248.365 ms=1.454
shape=3 variant=compiled tflops=0.467 gbps=670.000 ms=0.539
shape=3 variant=solution tflops=0.864 gbps=1240.121 ms=0.291
shape=3 solution_peak_fraction=0.6890
peak_fraction: 0.5100
RESULT: OK
assistant
in 116451out 100cache_r 0cache_w 0
assistant
write /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
diff (294 lines — TRUNCATED)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,10 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-
-Fused Q+K: Grid = (B * T), each kernel processes all Q and K heads for one (b,t).
-Position cache gathered once per (b,t), shared across all Q+K heads.
+"""Multi-axis RoPE pre-attention prep — Triton v3 (SM120 Blackwell).
+
+Two kernel types:
+- Large shapes: Grid=(B*T, 1), each program loops over all H heads (serial).
+- Small H shapes: Grid=(B*T*H, 1), each program processes exactly one (b,h,t).
+
+The second gives more parallel SMs for better occupancy with small head counts.
 """
 
 import torch
@@ -11,21 +14,18 @@
 
 
 @triton.jit
-def _rope_qk_kernel(
-    q_ptr, k_ptr,
-    q_out_ptr, k_out_ptr,
+def _rope_serial_h_kernel(
+    x_ptr, x_out_ptr,
     cos_cache_ptr, sin_cache_ptr,
     pos_t_ptr, pos_h_ptr, pos_w_ptr,
     B_val: tl.constexpr, T_val: tl.constexpr,
-    Hq_val: tl.constexpr, Hkv_val: tl.constexpr,
-    D_val: tl.constexpr,
-    stride_qb, stride_qt, stride_qh, stride_qd,
-    stride_kb, stride_kt, stride_kh, stride_kd,
-    stride_qob, stride_qoh, stride_qot, stride_qod,
-    stride_kob, stride_koh, stride_kot, stride_kod,
-    s0: tl.constexpr, s1: tl.constexpr,
-    half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
+    H_val: tl.constexpr, D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+    """Grid=(B*T). Serial loop over H heads."""
     pid_bt = tl.program_id(0)
     pid_b = pid_bt // T_val
     pid_t = pid_bt % T_val
@@ -38,7 +38,6 @@
     d_offs = tl.arange(0, BLOCK_D)
     mask = d_offs < D_val
 
-    # Cos/sin gathered once (6 reads), shared by all Q+K heads
     cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
     cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -49,31 +48,115 @@
     axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
     is_t = axis_idx < s0
     is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
-    cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
-    sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
 
     perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
     sign = tl.where(d_offs < half, -1.0, 1.0)
 
-    # Process Q heads
-    for h in tl.static_range(Hq_val):
-        x_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
-        x  = tl.load(q_ptr  + x_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(q_ptr + x_base + perm * stride_qd, mask=mask, other=0.0).to(tl.float32)
+    for h in tl.static_range(H_val):
+        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+        x  = tl.load(x_ptr  + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+        xr = tl.load(x_ptr + x_base + perm  * stride_xd, mask=mask, other=0.0).to(tl.float32)
         xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_qob + h * stride_qoh + pid_t * stride_qot
-        tl.store(q_out_ptr + o_base + d_offs * stride_qod, out.to(tl.bfloat16), mask=mask)
-
-    # Process K heads
-    for h in tl.static_range(Hkv_val):
-        x_base = pid_b * stride_kb + pid_t * stride_kt + h * stride_kh
-        x  = tl.load(k_ptr  + x_base + d_offs * stride_kd, mask=mask, other=0.0).to(tl.float32)
-        xr = tl.load(k_ptr + x_base + perm * stride_kd, mask=mask, other=0.0).to(tl.float32)
-        xr = xr * sign
-        out = x * cos_val + xr * sin_val
-        o_base = pid_b * stride_kob + h * stride_koh + pid_t * stride_kot
-        tl.store(k_out_ptr + o_base + d_offs * stride_kod, out.to(tl.bfloat16), mask=mask)
+        out = x * cos_v + xr * sin_v
+        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+
+@triton.jit
+def _rope_parallel_h_kernel(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    B_val: tl.constexpr, T_val: tl.constexpr,
+    D_val: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+    """Grid=(B*T*H). Each program = exactly one (b,h,t)."""
+    pid_bth = tl.program_id(0)
+    pid_b = pid_bth // tl.num_programs(1)
+    pid_th = pid_bth % tl.num_programs(1)
+    # Actually this doesn't work cleanly. Let me pass explicit B, T, H.
+    pass
+
+
+@triton.jit
+def _rope_flat_kernel(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    H_val: tl.constexpr, D_val: tl.constexpr,
+    total_bt: tl.constexpr,
+    stride_xb, stride_xt, stride_xh, stride_xd,
+    stride_ob, stride_oh, stride_ot, stride_od,
+    s0: tl.constexpr, s1: tl.constexpr,
+    half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+    """Grid=(B*T*H). Each program = one (b,h,t). No serial loop."""
+    pid = tl.program_id(0)
+    pid_h = pid % H_val
+    pid_bt = pid // H_val
+    pid_b = pid_bt // (total_bt // H_val)  # total_bt = B*T, so pid_bt // (BT/H) doesn't work
+    # This approach is getting messy. Let me just pass B and T directly.
+    pass
+
+
+@triton.jit
+def _rope_one_per_head(
+    x_ptr, x_out_ptr,
+    cos_cache_ptr, sin_cache_ptr,
+    pos_t_ptr, pos_h_ptr, pos_w_ptr,
+    B_val: tl.constexpr, T_val: tl.constexpr,
+    H_val

... (6307 more chars)
args
{
  "content": "<9792 chars \u2014 see diff>",
  "filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}
tool
stdout (24 chars)
Wrote file successfully.
assistant
OK, this is about the same as the separate Q+K version (52.6% vs 51.0%). The fusion doesn't help because each program still has to loop through the Q and K heads sequentially. The only benefit would be avoiding the second kernel launch, but that overhead is negligible compared to the memory-bandwidth cost. Let me try a fundamentally different approach: since Triton 3.6.0 is available, I could use its newer features, such as `triton.language.core.load/store` with pointer vectors or batched loads. Actually, the key insight I'm missing is that I should increase the grid size so that more SMs are saturated, especially for shape 2. Let me try a kernel that parallelizes over both the T and H dimensions, effectively giving grid = (B, T, H):
assistant
in 117021out 3689cache_r 0cache_w 0