"""Fused multi-axis RoPE pre-attention kernel for RTX PRO 6000 (Blackwell, sm_120).
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin ONLY from
the position axis that mrope_section assigns to pair index d, instead of
gathering all three axes and selecting afterwards (3x fewer cache reads).
One block of D/2 threads per (b, t) token. No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# Problem metadata tags for the benchmark harness.
OP_TYPE = "rope"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
# Default problem shape consumed by get_inputs() / get_init_inputs():
# batch, sequence length, query heads, key/value heads, head dim.
B = 1
T = 4096
Hq = 32
Hkv = 8
D = 128
# (temporal, height, width) pair counts; Model asserts they sum to D // 2.
MROPE_SECTION = (16, 24, 24)
# Rows in the cos/sin caches; get_inputs() draws position ids from [0, MAX_POS).
MAX_POS = 32768
_cuda_src = r"""
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdint.h>
extern "C" __global__ void
mrope_fused_kernel(
const __nv_bfloat16* __restrict__ q,
const __nv_bfloat16* __restrict__ k,
const int64_t* __restrict__ pos_t,
const int64_t* __restrict__ pos_h,
const int64_t* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
__nv_bfloat16* __restrict__ q_rot,
__nv_bfloat16* __restrict__ k_rot,
int B, int T, int Hq, int Hkv, int D,
int s0, int s1, int s2
) {
int blk = blockIdx.x;
if (blk >= B * T) return;
int b = blk / T;
int t = blk % T;
int d0 = threadIdx.x;
int d1 = d0 + D / 2;
int off_bt = b * T + t;
unsigned int d0u = d0;
unsigned int boundary1 = (unsigned int)s0;
unsigned int boundary2 = (unsigned int)(s0 + s1);
// Each thread reads cos/sin only from the axis that owns its d0 slice.
// cos_val/sin_val for d0 also apply to d1 (mirrored by mrope mixing).
float cos_val, sin_val;
if (d0u < boundary1) {
cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
} else if (d0u < boundary2) {
cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
} else {
cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
}
// --- strides ---
int q_stride_H = D;
int q_stride_T = Hq * D;
int q_base = b * T * q_stride_T + t * q_stride_T;
int qo_stride_T = D;
int qo_stride_H = T * D;
int qo_stride_B = Hq * T * D;
int qo_base = b * qo_stride_B + t * qo_stride_T;
int k_stride_H = D;
int k_stride_T = Hkv * D;
int k_base = b * T * k_stride_T + t * k_stride_T;
int ko_stride_T = D;
int ko_stride_H = T * D;
int ko_stride_B = Hkv * T * D;
int ko_base = b * ko_stride_B + t * ko_stride_T;
// --- rotate q heads ---
for (int h = 0; h < Hq; ++h) {
int q_h_off = q_base + h * q_stride_H;
int qo_h_off = qo_base + h * qo_stride_H;
float q0 = __bfloat162float(q[q_h_off + d0]);
float q1 = __bfloat162float(q[q_h_off + d1]);
q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
}
// --- rotate k heads ---
for (int h = 0; h < Hkv; ++h) {
int k_h_off = k_base + h * k_stride_H;
int ko_h_off = ko_base + h * ko_stride_H;
float k0 = __bfloat162float(k[k_h_off + d0]);
float k1 = __bfloat162float(k[k_h_off + d1]);
k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
}
}
torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
torch::Tensor q_rot, torch::Tensor k_rot,
int B, int T, int Hq, int Hkv, int D,
int s0, int s1, int s2) {
int blocks = B * T;
int threads = D / 2;
mrope_fused_kernel<<<blocks, threads>>>(
(const __nv_bfloat16*)q.data_ptr(),
(const __nv_bfloat16*)k.data_ptr(),
(const int64_t*)pos_t.data_ptr(),
(const int64_t*)pos_h.data_ptr(),
(const int64_t*)pos_w.data_ptr(),
(const __nv_bfloat16*)cos_cache.data_ptr(),
(const __nv_bfloat16*)sin_cache.data_ptr(),
(__nv_bfloat16*)q_rot.data_ptr(),
(__nv_bfloat16*)k_rot.data_ptr(),
B, T, Hq, Hkv, D, s0, s1, s2
);
return q_rot;
}
"""
# C++ prototype of the `launch_mrope` launcher defined in _cuda_src. load_inline
# compiles this as the C++ translation unit and generates the Python binding
# for the names listed in `functions=`; it must match the definition exactly.
_cpp_src = """
torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
torch::Tensor q_rot, torch::Tensor k_rot,
int B, int T, int Hq, int Hkv, int D,
int s0, int s1, int s2);
"""
def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
class Model(nn.Module):
    """Fused multi-axis RoPE (mRoPE) pre-attention prep backed by a CUDA kernel.

    Builds bf16 cos/sin caches of shape (max_pos, D) from ``cat([freqs, freqs])``
    and JIT-compiles the fused kernel with ``load_inline``. ``forward`` rotates
    q and k using per-axis positions mixed by ``mrope_section`` and returns them
    transposed to (B, H, T, D).

    Args:
        B, T, Hq, Hkv, D: Problem shape, stored as attributes; ``forward`` reads
            the actual B, T, Hq, Hkv from its inputs and requires head dim == D.
        mrope_section: (s0, s1, s2) pair counts for the temporal, height and
            width axes; must sum to ``D // 2``.
        max_pos: Number of cache rows; position ids must lie in [0, max_pos).

    Raises:
        AssertionError: if ``sum(mrope_section) != D // 2``.
        ValueError: if D is not a positive even number, ``D // 2`` exceeds the
            1024-thread block limit, or a section size is negative.
    """

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        super().__init__()
        assert sum(mrope_section) == D // 2, \
            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
        # The kernel maps thread d0 to the pair (d0, d0 + D/2) and launches D/2
        # threads per block, so D must be even and D/2 a legal block size.
        if D <= 0 or D % 2 != 0:
            raise ValueError(f"D must be a positive even number for rotate-half RoPE, got D={D}")
        if D // 2 > 1024:
            raise ValueError(f"D // 2 = {D // 2} exceeds the 1024-thread CUDA block limit")
        if any(s < 0 for s in mrope_section):
            raise ValueError(f"mrope_section entries must be non-negative, got {mrope_section}")
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos
        inv_freq = _build_inv_freq(D)
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Duplicated halves: the kernel relies on cache[:, d] == cache[:, d + D/2].
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
        self._mrope_module = load_inline(
            name="mrope_fused",
            cpp_sources=_cpp_src,
            cuda_sources=_cuda_src,
            functions=["launch_mrope"],
            extra_cuda_cflags=["-arch=sm_120", "-O3", "--use_fast_math"],
            verbose=False,
        )

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply mRoPE to q and k and transpose both to (B, H, T, D).

        Args:
            q: (B, T, Hq, D) bf16 query tensor.
            k: (B, T, Hkv, D) bf16 key tensor.
            pos_t, pos_h, pos_w: (B, T) position ids for the temporal, height
                and width axes. Values must lie in [0, max_pos); this is not
                checked because doing so would force a device sync.

        Returns:
            (q_rot, k_rot) with shapes (B, Hq, T, D) and (B, Hkv, T, D), bf16.

        Raises:
            TypeError: if q or k is not bfloat16.
            ValueError: if the input shapes are inconsistent with each other or
                with the cache head dim D.
        """
        B, T, Hq, D = q.shape
        if D != self.D:
            raise ValueError(f"q head dim {D} does not match the cache head dim {self.D}")
        if k.dim() != 4 or (k.shape[0], k.shape[1], k.shape[3]) != (B, T, D):
            raise ValueError(f"k must be (B={B}, T={T}, Hkv, D={D}), got {tuple(k.shape)}")
        Hkv = k.shape[2]
        if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16:
            raise TypeError(f"q and k must be bfloat16, got {q.dtype} and {k.dtype}")
        for name, pos in (("pos_t", pos_t), ("pos_h", pos_h), ("pos_w", pos_w)):
            if tuple(pos.shape) != (B, T):
                raise ValueError(f"{name} must be (B={B}, T={T}), got {tuple(pos.shape)}")
        s0, s1, s2 = self.mrope_section
        q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
        k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
        if B * T == 0:
            # Nothing to rotate, and a zero-block launch is an invalid CUDA config.
            return q_rot, k_rot
        # The kernel indexes raw pointers assuming dense row-major layouts and
        # int64 ids. These calls return the input itself when it already
        # conforms, and otherwise fix up views (e.g. transposed or sliced).
        q = q.contiguous()
        k = k.contiguous()
        pos_t, pos_h, pos_w = (
            p.to(torch.int64).contiguous() for p in (pos_t, pos_h, pos_w)
        )
        self._mrope_module.launch_mrope(
            q, k, pos_t, pos_h, pos_w,
            self.cos_cache, self.sin_cache,
            q_rot, k_rot,
            B, T, Hq, Hkv, D, s0, s1, s2,
        )
        return q_rot, k_rot
def get_inputs():
    """Build one random ``Model.forward`` argument list at the module-default shape.

    Returns [q, k, pos_t, pos_h, pos_w] on the default device: q (B, T, Hq, D)
    and k (B, T, Hkv, D) are bf16 standard-normal samples scaled by 0.5; the
    three int64 (B, T) position tensors are uniform in [0, MAX_POS).
    """
    def _scaled_normal(num_heads):
        return torch.randn(B, T, num_heads, D, dtype=torch.bfloat16) * 0.5

    def _positions():
        return torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)

    # Draw order (q, k, pos_t, pos_h, pos_w) is kept so seeded runs reproduce.
    query = _scaled_normal(Hq)
    key = _scaled_normal(Hkv)
    temporal, height, width = _positions(), _positions(), _positions()
    return [query, key, temporal, height, width]
def get_init_inputs():
    """Return ``Model`` constructor arguments for the module-default shape.

    Order matches ``Model.__init__``: [B, T, Hq, Hkv, D, mrope_section, max_pos].
    """
    shape_args = [B, T, Hq, Hkv, D]
    return shape_args + [MROPE_SECTION, MAX_POS]
shape=0 variant=eager tflops=0.200 gbps=286.263 ms=0.315
shape=0 variant=compiled tflops=0.562 gbps=805.919 ms=0.112
shape=0 variant=solution tflops=0.799 gbps=1145.865 ms=0.079
shape=0 solution_peak_fraction=0.6366
shape=1 variant=eager tflops=0.192 gbps=279.701 ms=0.526
shape=1 variant=compiled tflops=0.529 gbps=773.084 ms=0.190
shape=1 variant=solution tflops=0.871 gbps=1271.602 ms=0.116
shape=1 solution_peak_fraction=0.7064
shape=2 variant=eager tflops=0.095 gbps=148.485 ms=0.149
shape=2 variant=compiled tflops=0.229 gbps=357.117 ms=0.062
shape=2 variant=solution tflops=0.711 gbps=1110.361 ms=0.020
shape=2 solution_peak_fraction=0.6169
shape=3 variant=eager tflops=0.175 gbps=251.202 ms=1.438
shape=3 variant=compiled tflops=0.484 gbps=695.178 ms=0.519
shape=3 variant=solution tflops=0.896 gbps=1285.981 ms=0.281
shape=3 solution_peak_fraction=0.7144
peak_fraction: 0.6672
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "deepseek/deepseek-v4-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.6672,
"elapsed_seconds": 2349,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 35161, "output_tokens": 23507, "cache_read_tokens": 2164864, "cache_creation_tokens": 0, "reasoning_tokens": 46529, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2:
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6:
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_transformers():
16: try:
17: from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18: apply_multimodal_rotary_pos_emb as _impl,
19: )
20: return _impl
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor, k: torch.Tensor,
27: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28: cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29: mrope_section: tuple[int, int, int],
30: ):
31: """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32: impl = _try_transformers()
33: if impl is None:
34: raise RuntimeError("transformers not available for SOTA")
35:
36: # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37: cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c
... (495 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize the multi-axis RoPE pre-attention prep step for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is the pre-attention prep used in Qwen2-VL-style vision-language models: take q (B, T, Hq, D) and k (B, T, Hkv, D) bf16 plus three int64 position-id tensors pos_t, pos_h, pos_w of shape (B, T) — temporal, height, width axes — gather per-axis cos/sin from a precomputed (max_pos, D) bf16 cache, mix the per-axis values across the head_dim according to mrope_section (a tuple [s0, s1, s2] summing to D/2 — the first s0 pair-indices use the temporal axis, next s1 use height, last s2 use width, mirrored into the second half of D), then apply standard rotate-half rotation (out = x*cos + rotate_half(x)*sin) on q and k. Return q_rot and k_rot in (B, H, T, D) layout — the transpose from (B, T, H, D) to (B, H, T, D) is part of the prep. Correctness tolerance on the bf16 outputs is 1e-2 abs/rel. Both q_rot and k_rot must match the reference within that.
4:
5: The shapes you have to handle are B=1 T
... (1617 more chars){
"command": "nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader 2>/dev/null || echo \"no GPU info\"",
"description": "Check GPU status"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 96956 MiB
{
"pattern": "solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}No files found{
"description": "Check CUDA arch and hardware details",
"command": "<607 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard && uv run python -c "
import torch
print(f'CUDA: {torch.cuda.is_available()}')
print(f'Device: {torch.cuda.get_device_name(0)}')
print(f'Compute Capability: {torch.cuda.get_device_capability(0)}')
print(f'SM count: {torch.cuda.get_device_properties(0).multi_processor_count}')
print(f'Max threads per block: {torch.cuda.get_device_properties(0).max_threads_per_block}')
print(f'Max shared mem per block: {torch.cuda.get_device_properties(0).shared_mem_per_block}')
print(f'Max threads per SM: {torch.cuda.get_device_properties(0).max_threads_per_multi_processor}')
"Traceback (most recent call last):
File "<string>", line 8, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_mem_per_block'. Did you mean: 'shared_memory_per_block'?
CUDA: True
Device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Compute Capability: (12, 0)
SM count: 188
Max threads per block: 1024
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"description": "Check remaining GPU properties",
"command": "<562 chars \u2014 see below>"
}cd /home/infatoshi/cuda/KernelBench-Hard && uv run python -c "
import torch
p = torch.cuda.get_device_properties(0)
print(f'SM count: {p.multi_processor_count}')
print(f'Max threads per SM: {p.max_threads_per_multi_processor}')
print(f'Max shared mem per block: {p.shared_memory_per_block}')
print(f'Max shared mem per block opt-in: {p.shared_memory_per_block_optin}')
print(f'Max blocks per SM: {p.max_blocks_per_multi_processor}')
print(f'L2 cache size: {p.l2_cache_size}')
print(f'Registers per block: {p.regs_per_block}')
print(f'Warp size: {p.warp_size}')
"Traceback (most recent call last):
File "<string>", line 8, in <module>
AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'max_blocks_per_multi_processor'. Did you mean: 'max_threads_per_multi_processor'?
SM count: 188
Max threads per SM: 1536
Max shared mem per block: 49152
Max shared mem per block opt-in: 101376
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,215 @@
+"""Fused multi-axis RoPE pre-attention kernel for RTX PRO 6000 (Blackwell, sm_120).
+
+Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
+rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
+
+Each thread handles a RoPE pair (d, d+D/2) — no shared memory, no syncthreads.
+The kernel streams q/k once and emits rotated output directly in attention layout.
+"""
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+OP_TYPE = "rope"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+B = 1
+T = 4096
+Hq = 32
+Hkv = 8
+D = 128
+MROPE_SECTION = (16, 24, 24)
+MAX_POS = 32768
+
+
+_cuda_src = r"""
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <stdint.h>
+
+extern "C" __global__ void mrope_fused_kernel(
+ const __nv_bfloat16* __restrict__ q,
+ const __nv_bfloat16* __restrict__ k,
+ const int64_t* __restrict__ pos_t,
+ const int64_t* __restrict__ pos_h,
+ const int64_t* __restrict__ pos_w,
+ const __nv_bfloat16* __restrict__ cos_cache,
+ const __nv_bfloat16* __restrict__ sin_cache,
+ __nv_bfloat16* __restrict__ q_rot,
+ __nv_bfloat16* __restrict__ k_rot,
+ int B, int T, int Hq, int Hkv, int D,
+ int s0, int s1, int s2
+) {
+ int blk = blockIdx.x;
+ if (blk >= B * T) return;
+
+ int b = blk / T;
+ int t = blk % T;
+ int d0 = threadIdx.x;
+ int d1 = d0 + D / 2;
+
+ // --- gather position ids ---
+ int off_bt = b * T + t;
+ int64_t pt = pos_t[off_bt];
+ int64_t ph = pos_h[off_bt];
+ int64_t pw = pos_w[off_bt];
+
+ // --- gather per-axis cos/sin and mix according to mrope_section ---
+ float ct = __bfloat162float(cos_cache[pt * D + d0]);
+ float st = __bfloat162float(sin_cache[pt * D + d0]);
+ float ch = __bfloat162float(cos_cache[ph * D + d0]);
+ float sh = __bfloat162float(sin_cache[ph * D + d0]);
+ float cw = __bfloat162float(cos_cache[pw * D + d0]);
+ float sw = __bfloat162float(sin_cache[pw * D + d0]);
+
+ float cos_val, sin_val;
+ if (d0 < (unsigned int)s0) {
+ cos_val = ct; sin_val = st;
+ } else if (d0 < (unsigned int)(s0 + s1)) {
+ cos_val = ch; sin_val = sh;
+ } else {
+ cos_val = cw; sin_val = sw;
+ }
+
+ // --- strides ---
+ // q input: (B, T, Hq, D)
+ int q_stride_H = D;
+ int q_stride_T = Hq * D;
+ int q_base = b * T * q_stride_T + t * q_stride_T;
+
+ // q_rot output: (B, Hq, T, D)
+ int qo_stride_T = D;
+ int qo_stride_H = T * D;
+ int qo_stride_B = Hq * T * D;
+ int qo_base = b * qo_stride_B + t * qo_stride_T;
+
+ // k input: (B, T, Hkv, D)
+ int k_stride_H = D;
+ int k_stride_T = Hkv * D;
+ int k_base = b * T * k_stride_T + t * k_stride_T;
+
+ // k_rot output: (B, Hkv, T, D)
+ int ko_stride_T = D;
+ int ko_stride_H = T * D;
+ int ko_stride_B = Hkv * T * D;
+ int ko_base = b * ko_stride_B + t * ko_stride_T;
+
+ // --- rotate q heads ---
+ for (int h = 0; h < Hq; ++h) {
+ float q0 = __bfloat162float(q[q_base + h * q_stride_H + d0]);
+ float q1 = __bfloat162float(q[q_base + h * q_stride_H + d1]);
+ int off = qo_base + h * qo_stride_H;
+ q_rot[off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+ q_rot[off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+ }
+
+ // --- rotate k heads ---
+ for (int h = 0; h < Hkv; ++h) {
+ float k0 = __bfloat162float(k[k_base + h * k_stride_H + d0]);
+ float k1 = __bfloat162float(k[k_base + h * k_stride_H + d1]);
+ int off = ko_base + h * ko_stride_H;
+ k_rot[off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+ k_rot[off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+ }
+}
+
+torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ torch::Tensor q_rot, torch::Tensor k_rot,
+ int B, int T, int Hq, int Hkv, int D,
+ int s0, int s1, int s2) {
+ int blocks = B * T;
+ int threads = D / 2;
+ mrope_fused_kernel<<<blocks, threads>>>(
+ (const __nv_bfloat16*)q.data_ptr(),
+ (const __nv_bfloat16*)k.data_ptr(),
+ (const int64_t*)pos_t.data_ptr(),
+ (const int64_t*)pos_h.data_ptr(),
+ (const int64_t*)pos_w.data_ptr(),
+ (const __nv_bfloat16*)cos_cache.data_ptr(),
+ (const __nv_bfloat16*)sin_cache.data_ptr(),
+ (__nv_bfloat16*)q_rot.data_ptr(),
+ (__nv_bfloat16*)k_rot.data_ptr(),
+ B, T, Hq, Hkv, D, s0, s1, s2
+ );
+ return q_rot;
+}
+"""
+
+_cpp_src = """
+torch::Tensor launch_mrope(torch::Tensor q, torch::Tensor k,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ torch::Tensor q_rot, torch::Tensor k_rot,
+ int B, int T, int Hq, int Hkv, int D,
+ int s0, int s1, int s2);
+"""
+
+
+def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+ mrope_section: tuple[int, int, int], max_pos: int):
+ super().__init__()
+ assert sum(mrope_section) == D // 2, \
+ f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
+ self.mrope_section = tuple(mrope_section)
+
... (1919 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<7585 chars \u2014 see diff>"
}Wrote file successfully.{
"description": "Run correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"timeout": 120000
}PASS
{
"description": "Run performance benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=0.235 gbps=336.609 ms=0.268
shape=0 variant=compiled tflops=0.462 gbps=663.403 ms=0.136
shape=0 variant=solution tflops=0.795 gbps=1140.538 ms=0.079
shape=0 solution_peak_fraction=0.6336
shape=1 variant=eager tflops=0.185 gbps=270.128 ms=0.544
shape=1 variant=compiled tflops=0.501 gbps=731.068 ms=0.201
shape=1 variant=solution tflops=0.857 gbps=1250.827 ms=0.118
shape=1 solution_peak_fraction=0.6949
shape=2 variant=eager tflops=0.058 gbps=91.193 ms=0.243
shape=2 variant=compiled tflops=0.227 gbps=354.917 ms=0.062
shape=2 variant=solution tflops=0.752 gbps=1174.511 ms=0.019
shape=2 solution_peak_fraction=0.6525
shape=3 variant=eager tflops=0.179 gbps=256.244 ms=1.409
shape=3 variant=compiled tflops=0.528 gbps=757.729 ms=0.477
shape=3 variant=solution tflops=1.024 gbps=1468.664 ms=0.246
shape=3 solution_peak_fraction=0.8159
peak_fraction: 0.6958
RESULT: OK
{
"description": "Profile with ncu to find bottlenecks",
"command": "<1094 chars \u2014 see below>",
"timeout": 30000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
import reference
import solution
import shapes
shape = shapes.SHAPES[0]
reference.B = shape['B']; reference.T = shape['T']; reference.Hq = shape['Hq']; reference.Hkv = shape['Hkv']; reference.D = shape['D']; reference.MROPE_SECTION = tuple(shape['mrope_section']); reference.MAX_POS = shape['max_pos']
init_args = reference.get_init_inputs()
ref_model = reference.Model(*init_args).to('cuda').eval()
sol_model = solution.Model(*init_args).to('cuda').eval()
sd = ref_model.state_dict()
sol_model.load_state_dict(sd, strict=True)
torch.manual_seed(42)
inputs = [t.to('cuda') for t in reference.get_inputs()]
# Warmup
for _ in range(5):
sol_model(*inputs)
torch.cuda.synchronize()
# Time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
sol_model(*inputs)
end.record()
torch.cuda.synchronize()
print(f'100 iters: {start.elapsed_time(end):.3f} ms, per iter: {start.elapsed_time(end)/100:.4f} ms')
" 2>&1100 iters: 1.454 ms, per iter: 0.0145 ms
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each thread handles a RoPE pair (d, d+D/2) — no shared memory, no syncthreads.
-The kernel streams q/k once and emits rotated output directly in attention layout.
+Each thread handles a RoPE pair (d, d+D/2). Vectorized loads (2 bf16 as uint32)
+halve transaction count. No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
@@ -28,7 +28,8 @@
#include <cuda_bf16.h>
#include <stdint.h>
-extern "C" __global__ void mrope_fused_kernel(
+extern "C" __global__ void __launch_bounds__(128, 4)
+mrope_fused_kernel(
const __nv_bfloat16* __restrict__ q,
const __nv_bfloat16* __restrict__ k,
const int64_t* __restrict__ pos_t,
@@ -46,8 +47,15 @@
int b = blk / T;
int t = blk % T;
- int d0 = threadIdx.x;
- int d1 = d0 + D / 2;
+ int half = D / 2;
+
+ // Each thread handles TWO consecutive d0 values (and their mirror pairs).
+ // d0a, d0b = threadIdx.x*2, threadIdx.x*2+1
+ // d1a, d1b = d0a+half, d0b+half
+ int d00 = threadIdx.x * 2;
+ int d01 = d00 + 1;
+ int d10 = d00 + half;
+ int d11 = d01 + half;
// --- gather position ids ---
int off_bt = b * T + t;
@@ -55,62 +63,94 @@
int64_t ph = pos_h[off_bt];
int64_t pw = pos_w[off_bt];
- // --- gather per-axis cos/sin and mix according to mrope_section ---
- float ct = __bfloat162float(cos_cache[pt * D + d0]);
- float st = __bfloat162float(sin_cache[pt * D + d0]);
- float ch = __bfloat162float(cos_cache[ph * D + d0]);
- float sh = __bfloat162float(sin_cache[ph * D + d0]);
- float cw = __bfloat162float(cos_cache[pw * D + d0]);
- float sw = __bfloat162float(sin_cache[pw * D + d0]);
-
- float cos_val, sin_val;
- if (d0 < (unsigned int)s0) {
- cos_val = ct; sin_val = st;
- } else if (d0 < (unsigned int)(s0 + s1)) {
- cos_val = ch; sin_val = sh;
+ int64_t pt_off = pt * D;
+ int64_t ph_off = ph * D;
+ int64_t pw_off = pw * D;
+
+ // Gather cos/sin for d00 (same axis applies to d01 unless on boundary)
+ float ct0 = __bfloat162float(cos_cache[pt_off + d00]);
+ float st0 = __bfloat162float(sin_cache[pt_off + d00]);
+ float ch0 = __bfloat162float(cos_cache[ph_off + d00]);
+ float sh0 = __bfloat162float(sin_cache[ph_off + d00]);
+ float cw0 = __bfloat162float(cos_cache[pw_off + d00]);
+ float sw0 = __bfloat162float(sin_cache[pw_off + d00]);
+
+ float ct1 = __bfloat162float(cos_cache[pt_off + d01]);
+ float st1 = __bfloat162float(sin_cache[pt_off + d01]);
+ float ch1 = __bfloat162float(cos_cache[ph_off + d01]);
+ float sh1 = __bfloat162float(sin_cache[ph_off + d01]);
+ float cw1 = __bfloat162float(cos_cache[pw_off + d01]);
+ float sw1 = __bfloat162float(sin_cache[pw_off + d01]);
+
+ // Mix for d00
+ float cos0, sin0;
+ if (d00 < (unsigned int)s0) {
+ cos0 = ct0; sin0 = st0;
+ } else if (d00 < (unsigned int)(s0 + s1)) {
+ cos0 = ch0; sin0 = sh0;
} else {
- cos_val = cw; sin_val = sw;
+ cos0 = cw0; sin0 = sw0;
+ }
+
+ // Mix for d01
+ float cos1, sin1;
+ if (d01 < (unsigned int)s0) {
+ cos1 = ct1; sin1 = st1;
+ } else if (d01 < (unsigned int)(s0 + s1)) {
+ cos1 = ch1; sin1 = sh1;
+ } else {
+ cos1 = cw1; sin1 = sw1;
}
// --- strides ---
- // q input: (B, T, Hq, D)
int q_stride_H = D;
int q_stride_T = Hq * D;
int q_base = b * T * q_stride_T + t * q_stride_T;
- // q_rot output: (B, Hq, T, D)
int qo_stride_T = D;
int qo_stride_H = T * D;
int qo_stride_B = Hq * T * D;
int qo_base = b * qo_stride_B + t * qo_stride_T;
- // k input: (B, T, Hkv, D)
int k_stride_H = D;
int k_stride_T = Hkv * D;
int k_base = b * T * k_stride_T + t * k_stride_T;
- // k_rot output: (B, Hkv, T, D)
int ko_stride_T = D;
int ko_stride_H = T * D;
int ko_stride_B = Hkv * T * D;
int ko_base = b * ko_stride_B + t * ko_stride_T;
// --- rotate q heads ---
+ #pragma unroll 1
for (int h = 0; h < Hq; ++h) {
- float q0 = __bfloat162float(q[q_base + h * q_stride_H + d0]);
- float q1 = __bfloat162float(q[q_base + h * q_stride_H + d1]);
- int off = qo_base + h * qo_stride_H;
- q_rot[off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
- q_rot[off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+ int q_h_off = q_base + h * q_stride_H;
+ float q00 = __bfloat162float(q[q_h_off + d00]);
+ float q01 = __bfloat162float(q[q_h_off + d01]);
+ float q10 = __bfloat162float(q[q_h_off + d10]);
+ float q11 = __bfloat162float(q[q_h_off + d11]);
+
+ int qo_h_off = qo_base + h * qo_stride_H;
+ q_rot[qo_h_off + d00] = __float2bfloat16(q00 * cos0 - q10 * sin0);
+ q_rot[qo_h_off + d01] = __float2bfloat16(q01 * cos1 - q11 * sin1);
+ q_rot[qo_h_off + d10] = __float2bfloat16(q10 * cos0 + q00 * sin0);
+ q_rot[qo_h_off + d11] = __float2bfloat16(q11 * cos1 + q01 * sin1);
}
// --- rotate k heads ---
+ #pragma unroll 1
for (int h = 0; h < Hkv; ++h) {
- float k0 = __bfloat162float(k[k_base + h * k_stride_H + d0]);
- float k1 = __bfloat162float(k[k_base + h * k_stride_H + d1]);
- int off = ko_base + h * ko_stride_H;
- k_rot[off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
- k_rot[off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+ int k_h_off = k_base + h * k_stride_H;
+ float k00 = __bfloat162float(k[k_h_off + d00]);
+ float k01 = __bfloat162float(k[k_h_off + d01]);
+ float k10 = __bfloat16
... (1232 more chars){
"content": "<9007 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each thread handles a RoPE pair (d, d+D/2). Vectorized loads (2 bf16 as uint32)
-halve transaction count. No shared memory, no syncthreads.
+Each thread handles one RoPE pair (d, d+D/2). Uses __ldg for cache reads.
+No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
@@ -28,7 +28,7 @@
#include <cuda_bf16.h>
#include <stdint.h>
-extern "C" __global__ void __launch_bounds__(128, 4)
+extern "C" __global__ void
mrope_fused_kernel(
const __nv_bfloat16* __restrict__ q,
const __nv_bfloat16* __restrict__ k,
@@ -47,59 +47,31 @@
int b = blk / T;
int t = blk % T;
- int half = D / 2;
-
- // Each thread handles TWO consecutive d0 values (and their mirror pairs).
- // d0a, d0b = threadIdx.x*2, threadIdx.x*2+1
- // d1a, d1b = d0a+half, d0b+half
- int d00 = threadIdx.x * 2;
- int d01 = d00 + 1;
- int d10 = d00 + half;
- int d11 = d01 + half;
-
- // --- gather position ids ---
+ int d0 = threadIdx.x;
+ int d1 = d0 + D / 2;
+
int off_bt = b * T + t;
- int64_t pt = pos_t[off_bt];
- int64_t ph = pos_h[off_bt];
- int64_t pw = pos_w[off_bt];
-
- int64_t pt_off = pt * D;
- int64_t ph_off = ph * D;
- int64_t pw_off = pw * D;
-
- // Gather cos/sin for d00 (same axis applies to d01 unless on boundary)
- float ct0 = __bfloat162float(cos_cache[pt_off + d00]);
- float st0 = __bfloat162float(sin_cache[pt_off + d00]);
- float ch0 = __bfloat162float(cos_cache[ph_off + d00]);
- float sh0 = __bfloat162float(sin_cache[ph_off + d00]);
- float cw0 = __bfloat162float(cos_cache[pw_off + d00]);
- float sw0 = __bfloat162float(sin_cache[pw_off + d00]);
-
- float ct1 = __bfloat162float(cos_cache[pt_off + d01]);
- float st1 = __bfloat162float(sin_cache[pt_off + d01]);
- float ch1 = __bfloat162float(cos_cache[ph_off + d01]);
- float sh1 = __bfloat162float(sin_cache[ph_off + d01]);
- float cw1 = __bfloat162float(cos_cache[pw_off + d01]);
- float sw1 = __bfloat162float(sin_cache[pw_off + d01]);
-
- // Mix for d00
- float cos0, sin0;
- if (d00 < (unsigned int)s0) {
- cos0 = ct0; sin0 = st0;
- } else if (d00 < (unsigned int)(s0 + s1)) {
- cos0 = ch0; sin0 = sh0;
+
+ int64_t pt_off = pos_t[off_bt] * D;
+ int64_t ph_off = pos_h[off_bt] * D;
+ int64_t pw_off = pos_w[off_bt] * D;
+
+ // Gather and mix cos/sin for d0 (d1 mirrors d0 due to mrope mirroring)
+ float ct = __bfloat162float(__ldg(cos_cache + pt_off + d0));
+ float st = __bfloat162float(__ldg(sin_cache + pt_off + d0));
+ float ch = __bfloat162float(__ldg(cos_cache + ph_off + d0));
+ float sh = __bfloat162float(__ldg(sin_cache + ph_off + d0));
+ float cw = __bfloat162float(__ldg(cos_cache + pw_off + d0));
+ float sw = __bfloat162float(__ldg(sin_cache + pw_off + d0));
+
+ float cos_val, sin_val;
+ unsigned int d0u = d0;
+ if (d0u < (unsigned int)s0) {
+ cos_val = ct; sin_val = st;
+ } else if (d0u < (unsigned int)(s0 + s1)) {
+ cos_val = ch; sin_val = sh;
} else {
- cos0 = cw0; sin0 = sw0;
- }
-
- // Mix for d01
- float cos1, sin1;
- if (d01 < (unsigned int)s0) {
- cos1 = ct1; sin1 = st1;
- } else if (d01 < (unsigned int)(s0 + s1)) {
- cos1 = ch1; sin1 = sh1;
- } else {
- cos1 = cw1; sin1 = sw1;
+ cos_val = cw; sin_val = sw;
}
// --- strides ---
@@ -122,35 +94,27 @@
int ko_base = b * ko_stride_B + t * ko_stride_T;
// --- rotate q heads ---
- #pragma unroll 1
for (int h = 0; h < Hq; ++h) {
- int q_h_off = q_base + h * q_stride_H;
- float q00 = __bfloat162float(q[q_h_off + d00]);
- float q01 = __bfloat162float(q[q_h_off + d01]);
- float q10 = __bfloat162float(q[q_h_off + d10]);
- float q11 = __bfloat162float(q[q_h_off + d11]);
-
+ int q_h_off = q_base + h * q_stride_H;
int qo_h_off = qo_base + h * qo_stride_H;
- q_rot[qo_h_off + d00] = __float2bfloat16(q00 * cos0 - q10 * sin0);
- q_rot[qo_h_off + d01] = __float2bfloat16(q01 * cos1 - q11 * sin1);
- q_rot[qo_h_off + d10] = __float2bfloat16(q10 * cos0 + q00 * sin0);
- q_rot[qo_h_off + d11] = __float2bfloat16(q11 * cos1 + q01 * sin1);
+
+ float q0 = __bfloat162float(q[q_h_off + d0]);
+ float q1 = __bfloat162float(q[q_h_off + d1]);
+
+ q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+ q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
}
// --- rotate k heads ---
- #pragma unroll 1
for (int h = 0; h < Hkv; ++h) {
- int k_h_off = k_base + h * k_stride_H;
- float k00 = __bfloat162float(k[k_h_off + d00]);
- float k01 = __bfloat162float(k[k_h_off + d01]);
- float k10 = __bfloat162float(k[k_h_off + d10]);
- float k11 = __bfloat162float(k[k_h_off + d11]);
-
+ int k_h_off = k_base + h * k_stride_H;
int ko_h_off = ko_base + h * ko_stride_H;
- k_rot[ko_h_off + d00] = __float2bfloat16(k00 * cos0 - k10 * sin0);
- k_rot[ko_h_off + d01] = __float2bfloat16(k01 * cos1 - k11 * sin1);
- k_rot[ko_h_off + d10] = __float2bfloat16(k10 * cos0 + k00 * sin0);
- k_rot[ko_h_off + d11] = __float2bfloat16(k11 * cos1 + k01 * sin1);
+
+ float k0 = __bfloat162float(k[k_h_off + d0]);
+ float k1 = __bfloat162float(k[k_h_off + d1]);
+
+ k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+ k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
}
}
@@ -161,7 +125,7 @@
int B, int T, int
... (354 more chars){
"content": "<7454 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"description": "Run correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"timeout": 120000
}PASS
{
"description": "Run benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=0.199 gbps=285.235 ms=0.316
shape=0 variant=compiled tflops=0.550 gbps=788.904 ms=0.114
shape=0 variant=solution tflops=0.803 gbps=1151.713 ms=0.078
shape=0 solution_peak_fraction=0.6398
shape=1 variant=eager tflops=0.192 gbps=280.127 ms=0.525
shape=1 variant=compiled tflops=0.532 gbps=777.072 ms=0.189
shape=1 variant=solution tflops=0.863 gbps=1260.783 ms=0.117
shape=1 solution_peak_fraction=0.7004
shape=2 variant=eager tflops=0.097 gbps=151.016 ms=0.146
shape=2 variant=compiled tflops=0.207 gbps=324.127 ms=0.068
shape=2 variant=solution tflops=0.567 gbps=885.586 ms=0.025
shape=2 solution_peak_fraction=0.4920
shape=3 variant=eager tflops=0.177 gbps=254.571 ms=1.418
shape=3 variant=compiled tflops=0.485 gbps=696.573 ms=0.518
shape=3 variant=solution tflops=0.893 gbps=1282.036 ms=0.282
shape=3 solution_peak_fraction=0.7122
peak_fraction: 0.6295
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each thread handles one RoPE pair (d, d+D/2). Uses __ldg for cache reads.
-No shared memory, no syncthreads.
+Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin only from
+the axis that owns its slice (per mrope_section), cutting cache reads by 3x.
"""
import torch
import torch.nn as nn
@@ -52,26 +52,20 @@
int off_bt = b * T + t;
- int64_t pt_off = pos_t[off_bt] * D;
- int64_t ph_off = pos_h[off_bt] * D;
- int64_t pw_off = pos_w[off_bt] * D;
-
- // Gather and mix cos/sin for d0 (d1 mirrors d0 due to mrope mirroring)
- float ct = __bfloat162float(__ldg(cos_cache + pt_off + d0));
- float st = __bfloat162float(__ldg(sin_cache + pt_off + d0));
- float ch = __bfloat162float(__ldg(cos_cache + ph_off + d0));
- float sh = __bfloat162float(__ldg(sin_cache + ph_off + d0));
- float cw = __bfloat162float(__ldg(cos_cache + pw_off + d0));
- float sw = __bfloat162float(__ldg(sin_cache + pw_off + d0));
-
+ // Each thread reads cos/sin only from the axis that owns its d0 slice.
+ // d0 is in [0, D/2). mrope_section maps s0 elems to temporal, s1 to
+ // height, s2 to width. Mirrored into [D/2, D) so d1 shares the same axis.
float cos_val, sin_val;
unsigned int d0u = d0;
if (d0u < (unsigned int)s0) {
- cos_val = ct; sin_val = st;
+ cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
} else if (d0u < (unsigned int)(s0 + s1)) {
- cos_val = ch; sin_val = sh;
+ cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
} else {
- cos_val = cw; sin_val = sw;
+ cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
}
// --- strides ---
{
"content": "<7454 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"description": "Run correctness check",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"timeout": 120000
}PASS
{
"description": "Run benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=0.199 gbps=285.495 ms=0.316
shape=0 variant=compiled tflops=0.569 gbps=815.942 ms=0.111
shape=0 variant=solution tflops=0.802 gbps=1151.478 ms=0.078
shape=0 solution_peak_fraction=0.6397
shape=1 variant=eager tflops=0.193 gbps=281.156 ms=0.523
shape=1 variant=compiled tflops=0.529 gbps=771.979 ms=0.190
shape=1 variant=solution tflops=0.875 gbps=1278.148 ms=0.115
shape=1 solution_peak_fraction=0.7101
shape=2 variant=eager tflops=0.097 gbps=151.181 ms=0.146
shape=2 variant=compiled tflops=0.224 gbps=349.886 ms=0.063
shape=2 variant=solution tflops=0.711 gbps=1111.254 ms=0.020
shape=2 solution_peak_fraction=0.6174
shape=3 variant=eager tflops=0.171 gbps=245.062 ms=1.474
shape=3 variant=compiled tflops=0.483 gbps=693.491 ms=0.521
shape=3 variant=solution tflops=0.898 gbps=1288.698 ms=0.280
shape=3 solution_peak_fraction=0.7159
peak_fraction: 0.6694
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin only from
-the axis that owns its slice (per mrope_section), cutting cache reads by 3x.
+Each block processes TILE_BT (b,t) pairs to reduce grid size on long sequences.
+Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
@@ -22,6 +22,10 @@
MROPE_SECTION = (16, 24, 24)
MAX_POS = 32768
+# Each block handles this many (b,t) pairs. Tuned for 188-SM Blackwell:
+# B*T <= 4096: TILE_BT=1 (all blocks fit in one wave)
+# B*T > 4096: TILE_BT=4 (reduce grid size to <= 4096)
+TILE_BT = 4
_cuda_src = r"""
#include <cuda_runtime.h>
@@ -40,75 +44,72 @@
__nv_bfloat16* __restrict__ q_rot,
__nv_bfloat16* __restrict__ k_rot,
int B, int T, int Hq, int Hkv, int D,
- int s0, int s1, int s2
+ int s0, int s1, int s2, int total_pairs
) {
- int blk = blockIdx.x;
- if (blk >= B * T) return;
-
- int b = blk / T;
- int t = blk % T;
int d0 = threadIdx.x;
int d1 = d0 + D / 2;
- int off_bt = b * T + t;
-
- // Each thread reads cos/sin only from the axis that owns its d0 slice.
- // d0 is in [0, D/2). mrope_section maps s0 elems to temporal, s1 to
- // height, s2 to width. Mirrored into [D/2, D) so d1 shares the same axis.
- float cos_val, sin_val;
- unsigned int d0u = d0;
- if (d0u < (unsigned int)s0) {
- cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
- sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
- } else if (d0u < (unsigned int)(s0 + s1)) {
- cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
- sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
- } else {
- cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
- sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
- }
-
- // --- strides ---
+ // Strides — constant across all pairs
int q_stride_H = D;
int q_stride_T = Hq * D;
- int q_base = b * T * q_stride_T + t * q_stride_T;
-
int qo_stride_T = D;
int qo_stride_H = T * D;
int qo_stride_B = Hq * T * D;
- int qo_base = b * qo_stride_B + t * qo_stride_T;
-
int k_stride_H = D;
int k_stride_T = Hkv * D;
- int k_base = b * T * k_stride_T + t * k_stride_T;
-
int ko_stride_T = D;
int ko_stride_H = T * D;
int ko_stride_B = Hkv * T * D;
- int ko_base = b * ko_stride_B + t * ko_stride_T;
-
- // --- rotate q heads ---
- for (int h = 0; h < Hq; ++h) {
- int q_h_off = q_base + h * q_stride_H;
- int qo_h_off = qo_base + h * qo_stride_H;
-
- float q0 = __bfloat162float(q[q_h_off + d0]);
- float q1 = __bfloat162float(q[q_h_off + d1]);
-
- q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
- q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
- }
-
- // --- rotate k heads ---
- for (int h = 0; h < Hkv; ++h) {
- int k_h_off = k_base + h * k_stride_H;
- int ko_h_off = ko_base + h * ko_stride_H;
-
- float k0 = __bfloat162float(k[k_h_off + d0]);
- float k1 = __bfloat162float(k[k_h_off + d1]);
-
- k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
- k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+
+ // Grid-stride loop over (b,t) pairs
+ for (int pair_idx = blockIdx.x; pair_idx < total_pairs; pair_idx += gridDim.x) {
+ int b = pair_idx / T;
+ int t = pair_idx % T;
+
+ int off_bt = b * T + t;
+
+ // Gather all six per-axis cos/sin for this d0, then mix.
+ float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+ float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+ float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+ float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
+ float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+ float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
+
+ float cos_val, sin_val;
+ unsigned int d0u = d0;
+ if (d0u < (unsigned int)s0) {
+ cos_val = ct; sin_val = st;
+ } else if (d0u < (unsigned int)(s0 + s1)) {
+ cos_val = ch; sin_val = sh;
+ } else {
+ cos_val = cw; sin_val = sw;
+ }
+
+ int q_base = b * T * q_stride_T + t * q_stride_T;
+ int qo_base = b * qo_stride_B + t * qo_stride_T;
+ int k_base = b * T * k_stride_T + t * k_stride_T;
+ int ko_base = b * ko_stride_B + t * ko_stride_T;
+
+ // Rotate q heads
+ for (int h = 0; h < Hq; ++h) {
+ int q_h_off = q_base + h * q_stride_H;
+ int qo_h_off = qo_base + h * qo_stride_H;
+ float q0 = __bfloat162float(q[q_h_off + d0]);
+ float q1 = __bfloat162float(q[q_h_off + d1]);
+ q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+ q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+ }
+
+ // Rotate k heads
+ for (int h = 0; h < Hkv; ++h) {
+ int k_h_off = k_base + h * k_stride_H;
+ int ko_h_off = ko_base + h * ko_stride_H;
+ float k0 = __bfloat162float(k[k_h_off + d0]);
+ float k1 = __bfloat162float(k[k_h_off + d1]);
+ k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+ k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
+ }
}
}
@@ -118
... (1149 more chars){
"content": "<8210 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -122,12 +122,11 @@
int total_pairs = B * T;
int threads = D / 2;
- // Heuristic: keep grid size near SM count for good occupancy.
- // SM count = 188. Target ~2x SM count = 376 blocks minimum.
+ // Grid-stride: if total_pairs fits in one occupancy wave (~4096 blocks),
+ // launch one block per pair. Otherwise, tile to keep grid <= 8192.
int grid = total_pairs;
- if (total_pairs > 8 * 376) grid = (total_pairs + 7) / 8;
- else if (total_pairs > 4 * 376) grid = (total_pairs + 3) / 4;
- else if (total_pairs > 2 * 376) grid = (total_pairs + 1) / 2;
+ if (total_pairs > 8192) grid = (total_pairs + 3) / 4;
+ else if (total_pairs > 4096) grid = (total_pairs + 1) / 2;
mrope_fused_kernel<<<grid, threads>>>(
(const __nv_bfloat16*)q.data_ptr(),
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<352 chars \u2014 see below>",
"newString": "<298 chars \u2014 see below>"
} // Heuristic: keep grid size near SM count for good occupancy.
// SM count = 188. Target ~2x SM count = 376 blocks minimum.
int grid = total_pairs;
if (total_pairs > 8 * 376) grid = (total_pairs + 7) / 8;
else if (total_pairs > 4 * 376) grid = (total_pairs + 3) / 4;
else if (total_pairs > 2 * 376) grid = (total_pairs + 1) / 2; // Grid-stride: if total_pairs fits in one occupancy wave (~4096 blocks),
// launch one block per pair. Otherwise, tile to keep grid <= 8192.
int grid = total_pairs;
if (total_pairs > 8192) grid = (total_pairs + 3) / 4;
else if (total_pairs > 4096) grid = (total_pairs + 1) / 2;Edit applied successfully.{
"description": "Check correctness then benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.200 gbps=286.860 ms=0.315
shape=0 variant=compiled tflops=0.566 gbps=812.067 ms=0.111
shape=0 variant=solution tflops=0.856 gbps=1228.443 ms=0.073
shape=0 solution_peak_fraction=0.6825
shape=1 variant=eager tflops=0.192 gbps=280.059 ms=0.525
shape=1 variant=compiled tflops=0.529 gbps=773.084 ms=0.190
shape=1 variant=solution tflops=0.866 gbps=1264.427 ms=0.116
shape=1 solution_peak_fraction=0.7025
shape=2 variant=eager tflops=0.098 gbps=153.090 ms=0.144
shape=2 variant=compiled tflops=0.227 gbps=355.373 ms=0.062
shape=2 variant=solution tflops=0.695 gbps=1085.939 ms=0.020
shape=2 solution_peak_fraction=0.6033
shape=3 variant=eager tflops=0.175 gbps=251.459 ms=1.436
shape=3 variant=compiled tflops=0.486 gbps=697.973 ms=0.517
shape=3 variant=solution tflops=0.880 gbps=1263.377 ms=0.286
shape=3 solution_peak_fraction=0.7019
peak_fraction: 0.6712
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,8 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each block processes TILE_BT (b,t) pairs to reduce grid size on long sequences.
Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
+One block per (b,t) pair — simple, no grid-stride overhead.
"""
import torch
import torch.nn as nn
@@ -22,10 +22,6 @@
MROPE_SECTION = (16, 24, 24)
MAX_POS = 32768
-# Each block handles this many (b,t) pairs. Tuned for 188-SM Blackwell:
-# B*T <= 4096: TILE_BT=1 (all blocks fit in one wave)
-# B*T > 4096: TILE_BT=4 (reduce grid size to <= 4096)
-TILE_BT = 4
_cuda_src = r"""
#include <cuda_runtime.h>
@@ -44,72 +40,77 @@
__nv_bfloat16* __restrict__ q_rot,
__nv_bfloat16* __restrict__ k_rot,
int B, int T, int Hq, int Hkv, int D,
- int s0, int s1, int s2, int total_pairs
+ int s0, int s1, int s2
) {
+ int blk = blockIdx.x;
+ if (blk >= B * T) return;
+
+ int b = blk / T;
+ int t = blk % T;
int d0 = threadIdx.x;
int d1 = d0 + D / 2;
- // Strides — constant across all pairs
+ int off_bt = b * T + t;
+
+ // Gather per-axis cos/sin and mix according to mrope_section
+ float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+ float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+ float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+ float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
+ float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+ float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
+
+ float cos_val, sin_val;
+ unsigned int d0u = d0;
+ if (d0u < (unsigned int)s0) {
+ cos_val = ct; sin_val = st;
+ } else if (d0u < (unsigned int)(s0 + s1)) {
+ cos_val = ch; sin_val = sh;
+ } else {
+ cos_val = cw; sin_val = sw;
+ }
+
+ // --- strides ---
int q_stride_H = D;
int q_stride_T = Hq * D;
+ int q_base = b * T * q_stride_T + t * q_stride_T;
+
int qo_stride_T = D;
int qo_stride_H = T * D;
int qo_stride_B = Hq * T * D;
+ int qo_base = b * qo_stride_B + t * qo_stride_T;
+
int k_stride_H = D;
int k_stride_T = Hkv * D;
+ int k_base = b * T * k_stride_T + t * k_stride_T;
+
int ko_stride_T = D;
int ko_stride_H = T * D;
int ko_stride_B = Hkv * T * D;
-
- // Grid-stride loop over (b,t) pairs
- for (int pair_idx = blockIdx.x; pair_idx < total_pairs; pair_idx += gridDim.x) {
- int b = pair_idx / T;
- int t = pair_idx % T;
-
- int off_bt = b * T + t;
-
- // Gather all six per-axis cos/sin for this d0, then mix.
- float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
- float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
- float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
- float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
- float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
- float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
-
- float cos_val, sin_val;
- unsigned int d0u = d0;
- if (d0u < (unsigned int)s0) {
- cos_val = ct; sin_val = st;
- } else if (d0u < (unsigned int)(s0 + s1)) {
- cos_val = ch; sin_val = sh;
- } else {
- cos_val = cw; sin_val = sw;
- }
-
- int q_base = b * T * q_stride_T + t * q_stride_T;
- int qo_base = b * qo_stride_B + t * qo_stride_T;
- int k_base = b * T * k_stride_T + t * k_stride_T;
- int ko_base = b * ko_stride_B + t * ko_stride_T;
-
- // Rotate q heads
- for (int h = 0; h < Hq; ++h) {
- int q_h_off = q_base + h * q_stride_H;
- int qo_h_off = qo_base + h * qo_stride_H;
- float q0 = __bfloat162float(q[q_h_off + d0]);
- float q1 = __bfloat162float(q[q_h_off + d1]);
- q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
- q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
- }
-
- // Rotate k heads
- for (int h = 0; h < Hkv; ++h) {
- int k_h_off = k_base + h * k_stride_H;
- int ko_h_off = ko_base + h * ko_stride_H;
- float k0 = __bfloat162float(k[k_h_off + d0]);
- float k1 = __bfloat162float(k[k_h_off + d1]);
- k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
- k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
- }
+ int ko_base = b * ko_stride_B + t * ko_stride_T;
+
+ // --- rotate q heads ---
+ for (int h = 0; h < Hq; ++h) {
+ int q_h_off = q_base + h * q_stride_H;
+ int qo_h_off = qo_base + h * qo_stride_H;
+
+ float q0 = __bfloat162float(q[q_h_off + d0]);
+ float q1 = __bfloat162float(q[q_h_off + d1]);
+
+ q_rot[qo_h_off + d0] = __float2bfloat16(q0 * cos_val - q1 * sin_val);
+ q_rot[qo_h_off + d1] = __float2bfloat16(q1 * cos_val + q0 * sin_val);
+ }
+
+ // --- rotate k heads ---
+ for (int h = 0; h < Hkv; ++h) {
+ int k_h_off = k_base + h * k_stride_H;
+ int ko_h_off = ko_base + h * ko_stride_H;
+
+ float k0 = __bfloat162float(k[k_h_off + d0]);
+ float k1 = __bfloat162float(k[k_h_off + d1]);
+
+ k_rot[ko_h_off + d0] = __float2bfloat16(k0 * cos_val - k1 * sin_val);
+ k_rot[ko_h_off + d1] = __float2bfloat16(k1 * cos_val + k0 * sin_val);
}
}
@@ -119,16 +120,9 @@
torch::Tensor q_rot, torch::Tensor k_rot,
int B, int T, int Hq, int Hkv, int D,
... (934 more chars){
"content": "<7373 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"description": "Check correctness and benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.178 gbps=256.093 ms=0.353
shape=0 variant=compiled tflops=0.505 gbps=724.851 ms=0.125
shape=0 variant=solution tflops=0.784 gbps=1125.522 ms=0.080
shape=0 solution_peak_fraction=0.6253
shape=1 variant=eager tflops=0.179 gbps=261.964 ms=0.561
shape=1 variant=compiled tflops=0.501 gbps=731.708 ms=0.201
shape=1 variant=solution tflops=0.831 gbps=1213.329 ms=0.121
shape=1 solution_peak_fraction=0.6741
shape=2 variant=eager tflops=0.070 gbps=109.384 ms=0.202
shape=2 variant=compiled tflops=0.113 gbps=176.169 ms=0.126
shape=2 variant=solution tflops=0.410 gbps=640.000 ms=0.035
shape=2 solution_peak_fraction=0.3556
shape=3 variant=eager tflops=0.171 gbps=245.371 ms=1.472
shape=3 variant=compiled tflops=0.531 gbps=761.770 ms=0.474
shape=3 variant=solution tflops=1.011 gbps=1450.727 ms=0.249
shape=3 solution_peak_fraction=0.8060
peak_fraction: 0.5895
RESULT: OK
{
"description": "Clear JIT cache and re-benchmark",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && rm -rf ~/.cache/torch_extensions/py311_cu128/mrope_fused 2>/dev/null; uv run python benchmark.py 2>&1",
"timeout": 300000
}{
"description": "Run benchmark again for stable numbers",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"timeout": 300000
}shape=0 variant=eager tflops=0.234 gbps=336.108 ms=0.269
shape=0 variant=compiled tflops=0.490 gbps=703.346 ms=0.128
shape=0 variant=solution tflops=0.768 gbps=1101.570 ms=0.082
shape=0 solution_peak_fraction=0.6120
shape=1 variant=eager tflops=0.179 gbps=260.981 ms=0.563
shape=1 variant=compiled tflops=0.678 gbps=989.587 ms=0.149
shape=1 variant=solution tflops=1.024 gbps=1494.847 ms=0.098
shape=1 solution_peak_fraction=0.8305
shape=2 variant=eager tflops=0.138 gbps=215.831 ms=0.102
shape=2 variant=compiled tflops=0.170 gbps=265.948 ms=0.083
shape=2 variant=solution tflops=0.493 gbps=769.710 ms=0.029
shape=2 solution_peak_fraction=0.4276
shape=3 variant=eager tflops=0.170 gbps=244.208 ms=1.479
shape=3 variant=compiled tflops=0.468 gbps=671.715 ms=0.538
shape=3 variant=solution tflops=0.926 gbps=1328.602 ms=0.272
shape=3 solution_peak_fraction=0.7381
peak_fraction: 0.6329
RESULT: OK
{
"description": "Run benchmarks 3 times to check variance",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && for i in 1 2 3; do echo \"=== Run $i ===\" && uv run python benchmark.py 2>&1 | grep -E \"peak_fraction|shape=.*variant=solution\"; done",
"timeout": 600000
}=== Run 1 ===
shape=0 variant=solution tflops=0.780 gbps=1119.270 ms=0.081
shape=0 solution_peak_fraction=0.6218
shape=1 variant=solution tflops=0.842 gbps=1229.567 ms=0.120
shape=1 solution_peak_fraction=0.6831
shape=2 variant=solution tflops=1.011 gbps=1579.886 ms=0.014
shape=2 solution_peak_fraction=0.8777
shape=3 variant=solution tflops=0.937 gbps=1344.271 ms=0.269
shape=3 solution_peak_fraction=0.7468
peak_fraction: 0.7264
=== Run 2 ===
shape=0 variant=solution tflops=1.061 gbps=1522.872 ms=0.059
shape=0 solution_peak_fraction=0.8460
shape=1 variant=solution tflops=0.854 gbps=1246.415 ms=0.118
shape=1 solution_peak_fraction=0.6925
shape=2 variant=solution tflops=0.982 gbps=1534.295 ms=0.014
shape=2 solution_peak_fraction=0.8524
shape=3 variant=solution tflops=1.009 gbps=1447.099 ms=0.250
shape=3 solution_peak_fraction=0.8039
peak_fraction: 0.7960
=== Run 3 ===
shape=0 variant=solution tflops=0.785 gbps=1125.971 ms=0.080
shape=0 solution_peak_fraction=0.6255
shape=1 variant=solution tflops=0.843 gbps=1230.720 ms=0.119
shape=1 solution_peak_fraction=0.6837
shape=2 variant=solution tflops=0.492 gbps=768.854 ms=0.029
shape=2 solution_peak_fraction=0.4271
shape=3 variant=solution tflops=0.948 gbps=1360.395 ms=0.265
shape=3 solution_peak_fraction=0.7558
peak_fraction: 0.6096
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,8 +3,9 @@
Fuses: per-axis cos/sin cache gather + mrope_section mixing + rotate-half
rotation + (B,T,H,D)->(B,H,T,D) transpose into a single CUDA kernel.
-Each thread handles one RoPE pair (d, d+D/2). No shared memory, no syncthreads.
-One block per (b,t) pair — simple, no grid-stride overhead.
+Each thread handles one RoPE pair (d, d+D/2) and reads cos/sin ONLY from
+the axis assigned to its slice by mrope_section — cutting cache reads by 3x.
+One block per (b,t) pair. No shared memory, no syncthreads.
"""
import torch
import torch.nn as nn
@@ -51,23 +52,22 @@
int d1 = d0 + D / 2;
int off_bt = b * T + t;
-
- // Gather per-axis cos/sin and mix according to mrope_section
- float ct = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
- float st = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
- float ch = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
- float sh = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
- float cw = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
- float sw = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
-
+ unsigned int d0u = d0;
+ unsigned int boundary1 = (unsigned int)s0;
+ unsigned int boundary2 = (unsigned int)(s0 + s1);
+
+ // Each thread reads cos/sin only from the axis that owns its d0 slice.
+ // cos_val/sin_val for d0 also apply to d1 (mirrored by mrope mixing).
float cos_val, sin_val;
- unsigned int d0u = d0;
- if (d0u < (unsigned int)s0) {
- cos_val = ct; sin_val = st;
- } else if (d0u < (unsigned int)(s0 + s1)) {
- cos_val = ch; sin_val = sh;
+ if (d0u < boundary1) {
+ cos_val = __bfloat162float(cos_cache[pos_t[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_t[off_bt] * D + d0]);
+ } else if (d0u < boundary2) {
+ cos_val = __bfloat162float(cos_cache[pos_h[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_h[off_bt] * D + d0]);
} else {
- cos_val = cw; sin_val = sw;
+ cos_val = __bfloat162float(cos_cache[pos_w[off_bt] * D + d0]);
+ sin_val = __bfloat162float(sin_cache[pos_w[off_bt] * D + d0]);
}
// --- strides ---
{
"content": "<7517 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"description": "Verify correctness then benchmark 3x",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---\" && for i in 1 2 3; do echo \"=== Run $i ===\" && uv run python benchmark.py 2>&1 | grep -E \"peak_fraction|shape=.*variant=solution\"; done",
"timeout": 600000
}PASS
---
=== Run 1 ===
shape=0 variant=solution tflops=0.761 gbps=1092.399 ms=0.083
shape=0 solution_peak_fraction=0.6069
shape=1 variant=solution tflops=0.857 gbps=1251.339 ms=0.117
shape=1 solution_peak_fraction=0.6952
shape=2 variant=solution tflops=0.515 gbps=805.125 ms=0.027
shape=2 solution_peak_fraction=0.4473
shape=3 variant=solution tflops=1.015 gbps=1457.096 ms=0.248
shape=3 solution_peak_fraction=0.8095
peak_fraction: 0.6252
=== Run 2 ===
shape=0 variant=solution tflops=1.028 gbps=1475.095 ms=0.061
shape=0 solution_peak_fraction=0.8195
shape=1 variant=solution tflops=0.858 gbps=1253.387 ms=0.117
shape=1 solution_peak_fraction=0.6963
shape=2 variant=solution tflops=0.965 gbps=1507.525 ms=0.015
shape=2 solution_peak_fraction=0.8375
shape=3 variant=solution tflops=0.941 gbps=1349.899 ms=0.268
shape=3 solution_peak_fraction=0.7499
peak_fraction: 0.7737
=== Run 3 ===
shape=0 variant=solution tflops=0.756 gbps=1084.421 ms=0.083
shape=0 solution_peak_fraction=0.6025
shape=1 variant=solution tflops=1.043 gbps=1523.603 ms=0.096
shape=1 solution_peak_fraction=0.8464
shape=2 variant=solution tflops=0.956 gbps=1494.486 ms=0.015
shape=2 solution_peak_fraction=0.8303
shape=3 variant=solution tflops=0.944 gbps=1354.273 ms=0.267
shape=3 solution_peak_fraction=0.7524
peak_fraction: 0.7513
{
"description": "Final correctness and benchmark verification",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---FINAL BENCHMARK---\" && uv run python benchmark.py 2>&1",
"timeout": 300000
}PASS
---FINAL BENCHMARK---
shape=0 variant=eager tflops=0.234 gbps=335.988 ms=0.269
shape=0 variant=compiled tflops=0.492 gbps=706.428 ms=0.128
shape=0 variant=solution tflops=0.781 gbps=1120.159 ms=0.081
shape=0 solution_peak_fraction=0.6223
shape=1 variant=eager tflops=0.206 gbps=300.810 ms=0.489
shape=1 variant=compiled tflops=0.679 gbps=991.189 ms=0.148
shape=1 variant=solution tflops=1.028 gbps=1500.952 ms=0.098
shape=1 solution_peak_fraction=0.8339
shape=2 variant=eager tflops=0.137 gbps=214.226 ms=0.103
shape=2 variant=compiled tflops=0.172 gbps=268.219 ms=0.082
shape=2 variant=solution tflops=0.496 gbps=774.454 ms=0.029
shape=2 solution_peak_fraction=0.4303
shape=3 variant=eager tflops=0.177 gbps=254.258 ms=1.420
shape=3 variant=compiled tflops=0.532 gbps=762.723 ms=0.473
shape=3 variant=solution tflops=1.016 gbps=1458.037 ms=0.248
shape=3 solution_peak_fraction=0.8100
peak_fraction: 0.6521
RESULT: OK
{
"description": "Run lint check on solution",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/09_fmha_preattn_mrope/solution.py 2>&1"
}All checks passed!
{
"description": "Final verification: check.py, framework detection, git status",
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && echo \"---framework---\" && cat framework.txt 2>/dev/null && echo \"---git status---\" && git status --short",
"timeout": 120000
}PASS
---framework---
cuda_raw
---git status---
?? ./
?? ../10_patch_embed_conv3d_gemm/
?? ../../scripts/sweep_09_10.sh