"""Multi-axis RoPE pre-attention prep fused kernel (Triton).
Fuses gather + per-axis mixing + rotate-half + transpose into a single
kernel with direct (B, H, T, D) output.
Key optimization:
- Group by (batch, time): one program loads all 6 cos/sin rows once and
processes ALL heads at that (b,t). Eliminates H_total-way redundant
cos/sin cache reads — the main bandwidth bottleneck.
- Process head_dim in two halves to avoid double-loading q/k.
- All rotation ops are element-local, no cross-thread communication.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Problem metadata. Nothing in this module reads these names; presumably the
# evaluation harness inspects them — TODO confirm against check.py/benchmark.py.
# Operator category tag for this problem.
OP_TYPE = "rope"
# Precisions the solution targets; the caches and sample inputs below are bfloat16.
SUPPORTED_PRECISIONS = ["bf16"]
# Hardware target label the solution was written and tuned for.
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
@triton.jit
def _mrope_kernel(
    q_ptr, k_ptr, out_q_ptr, out_k_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    stride_q_b, stride_q_t, stride_q_h,
    stride_k_b, stride_k_t, stride_k_h,
    stride_out_q_b, stride_out_q_h, stride_out_q_t,
    stride_out_k_b, stride_out_k_h, stride_out_k_t,
    B, T, D,
    s0, s1, s2,
    half_D: tl.constexpr,
    Hq_: tl.constexpr,
    Hkv_: tl.constexpr,
):
    """Rotate every q and k head of one (batch, token) pair.

    Program id 0 selects the batch index and program id 1 the token index.
    The three position ids are read at flat offset ``b * T + t``; each one
    selects a row (``pos * D``) of the cos/sin caches. Within each half of
    the head dimension, lanes ``[0, s0)`` take the temporal row, lanes
    ``[s0, s0 + s1)`` the height row and the remaining lanes the width row.

    q/k are addressed with the supplied (b, t, h) strides and the outputs
    with the supplied (b, h, t) strides; the last axis is always accessed
    with unit stride. ``B`` and ``s2`` are accepted but not read.
    """
    batch_idx = tl.program_id(0)
    tok_idx = tl.program_id(1)

    # Lane offsets for the lower and upper halves of the head dimension.
    lo = tl.arange(0, half_D)
    hi = lo + half_D

    # Position ids are fetched once and turned into cache row offsets.
    token = batch_idx * T + tok_idx
    row_t = tl.load(pos_t_ptr + token) * D
    row_h = tl.load(pos_h_ptr + token) * D
    row_w = tl.load(pos_w_ptr + token) * D

    cos_row_t = cos_cache_ptr + row_t
    cos_row_h = cos_cache_ptr + row_h
    cos_row_w = cos_cache_ptr + row_w
    sin_row_t = sin_cache_ptr + row_t
    sin_row_h = sin_cache_ptr + row_h
    sin_row_w = sin_cache_ptr + row_w

    # Axis selection per lane. The outer select handles lanes below s0, so
    # the inner predicate only needs the upper bound of the height section.
    use_t = lo < s0
    use_h = lo < s0 + s1

    # Mixed cos/sin for the lower half; shared by all heads of this token.
    cos_lo = tl.where(
        use_t, tl.load(cos_row_t + lo),
        tl.where(use_h, tl.load(cos_row_h + lo), tl.load(cos_row_w + lo)))
    sin_lo = tl.where(
        use_t, tl.load(sin_row_t + lo),
        tl.where(use_h, tl.load(sin_row_h + lo), tl.load(sin_row_w + lo)))

    # Same mixing for the upper half, read from the upper half of each row.
    cos_hi = tl.where(
        use_t, tl.load(cos_row_t + hi),
        tl.where(use_h, tl.load(cos_row_h + hi), tl.load(cos_row_w + hi)))
    sin_hi = tl.where(
        use_t, tl.load(sin_row_t + hi),
        tl.where(use_h, tl.load(sin_row_h + hi), tl.load(sin_row_w + hi)))

    # Query heads: out = x * cos + rotate_half(x) * sin, one half at a time.
    q_src_row = q_ptr + batch_idx * stride_q_b + tok_idx * stride_q_t
    q_dst_row = out_q_ptr + batch_idx * stride_out_q_b + tok_idx * stride_out_q_t
    for head in range(Hq_):
        src = q_src_row + head * stride_q_h
        dst = q_dst_row + head * stride_out_q_h
        x_lo = tl.load(src + lo)
        x_hi = tl.load(src + hi)
        tl.store(dst + lo, x_lo * cos_lo + (-x_hi) * sin_lo)
        tl.store(dst + hi, x_hi * cos_hi + x_lo * sin_hi)

    # Key heads: identical rotation with the k strides and head count.
    k_src_row = k_ptr + batch_idx * stride_k_b + tok_idx * stride_k_t
    k_dst_row = out_k_ptr + batch_idx * stride_out_k_b + tok_idx * stride_out_k_t
    for head in range(Hkv_):
        src = k_src_row + head * stride_k_h
        dst = k_dst_row + head * stride_out_k_h
        x_lo = tl.load(src + lo)
        x_hi = tl.load(src + hi)
        tl.store(dst + lo, x_lo * cos_lo + (-x_hi) * sin_lo)
        tl.store(dst + hi, x_hi * cos_hi + x_lo * sin_hi)
class Model(nn.Module):
    """Fused multi-axis RoPE pre-attention prep built on ``_mrope_kernel``.

    Holds bfloat16 ``cos_cache`` / ``sin_cache`` buffers of shape
    ``(max_pos, D)`` and, in ``forward``, rotates q and k while writing them
    in ``(B, H, T, D)`` layout.
    """

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        """Store the shape configuration and precompute the cos/sin caches.

        Args:
            B, T, Hq, Hkv: Nominal batch, sequence and head counts. They are
                only recorded; ``forward`` takes the actual sizes from its
                inputs.
            D: Head dimension. Must be even.
            mrope_section: Widths of the temporal/height/width sections of
                one half of the head dimension. Must sum to ``D // 2``.
            max_pos: Number of rows in the cos/sin caches.

        Raises:
            ValueError: If ``D`` is odd or ``mrope_section`` is not three
                values summing to ``D // 2``.
        """
        super().__init__()
        section = tuple(mrope_section)
        if D % 2 != 0:
            raise ValueError(f"D must be even, got {D}")
        if len(section) != 3 or sum(section) != D // 2:
            raise ValueError(
                f"mrope_section must hold three widths summing to D // 2 = "
                f"{D // 2}, got {section}"
            )
        self.B, self.T = B, T
        self.Hq, self.Hkv = Hq, Hkv
        self.D = D
        self.mrope_section = section
        self.max_pos = max_pos
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Both halves of the head dimension share the same frequencies.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate q and k and return them in ``(B, H, T, D)`` layout.

        Args:
            q: Queries of shape ``(B, T, Hq, D)``.
            k: Keys of shape ``(B, T, Hkv, D)``.
            pos_t, pos_h, pos_w: Per-token position ids with ``B * T``
                elements each; they are used as row indices into the caches.

        Returns:
            ``(out_q, out_k)`` with shapes ``(B, Hq, T, D)`` and
            ``(B, Hkv, T, D)``, in the dtypes of ``q`` and ``k``.

        Raises:
            ValueError: If the head dimension does not match the caches, if
                ``k`` disagrees with ``q`` on batch/time/head-dim sizes, or
                if ``D // 2`` is not a power of two.
        """
        B, T, Hq, D = q.shape
        _, _, Hkv, _ = k.shape

        # The kernel strides through the caches with q's D, so it has to be
        # the D the caches were built with.
        if D != self.D:
            raise ValueError(
                f"q head dim {D} does not match the cache head dim {self.D}"
            )
        if k.shape[0] != B or k.shape[1] != T or k.shape[3] != D:
            raise ValueError(
                f"k shape {tuple(k.shape)} is incompatible with q shape "
                f"{tuple(q.shape)}"
            )
        half_D = D // 2
        # tl.arange(0, half_D) in the kernel needs a power-of-two extent.
        if half_D < 1 or half_D & (half_D - 1):
            raise ValueError(f"D // 2 must be a power of two, got {half_D}")

        # The kernel reads the last axis of q/k with unit stride and indexes
        # the position ids as a flat b * T + t array. Copy only when the
        # incoming layout breaks those assumptions.
        if q.stride(3) != 1:
            q = q.contiguous()
        if k.stride(3) != 1:
            k = k.contiguous()
        pos_t = pos_t.contiguous()
        pos_h = pos_h.contiguous()
        pos_w = pos_w.contiguous()

        out_q = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
        out_k = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
        s0, s1, s2 = self.mrope_section
        # One program per (batch, token).
        # NOTE: T is on grid axis 1, which CUDA caps at 65535 blocks.
        grid = (B, T)
        _mrope_kernel[grid](
            q, k, out_q, out_k,
            self.cos_cache, self.sin_cache,
            pos_t, pos_h, pos_w,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            out_q.stride(0), out_q.stride(1), out_q.stride(2),
            out_k.stride(0), out_k.stride(1), out_k.stride(2),
            B, T, D,
            s0, s1, s2,
            half_D=half_D,
            Hq_=Hq, Hkv_=Hkv,
            num_warps=2,
        )
        return out_q, out_k
# Module-level shape shims (overwritten per-shape by check.py / benchmark.py).
# get_inputs() and get_init_inputs() read these globals at call time.
B = 1  # batch size (leading axis of q, k and the position-id tensors)
T = 4096  # sequence length in tokens
Hq = 32  # number of query heads
Hkv = 8  # number of key heads
D = 128  # head dimension
MROPE_SECTION = (16, 24, 24)  # mrope_section argument forwarded to Model
MAX_POS = 32768  # Model's max_pos; also the exclusive bound for sampled position ids
def get_inputs():
    """Sample one set of ``Model.forward`` arguments for the current shape globals.

    Returns ``[q, k, pos_t, pos_h, pos_w]``: q and k are bfloat16 normal
    samples scaled by 0.5 with shapes ``(B, T, Hq, D)`` and
    ``(B, T, Hkv, D)``; the position ids are int64 tensors of shape
    ``(B, T)`` drawn uniformly from ``[0, MAX_POS)``.
    """
    def _scaled_randn(num_heads):
        return torch.randn(B, T, num_heads, D, dtype=torch.bfloat16) * 0.5

    q = _scaled_randn(Hq)
    k = _scaled_randn(Hkv)
    # Drawn in temporal, height, width order so RNG consumption is unchanged.
    positions = [
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
        for _ in range(3)
    ]
    return [q, k, *positions]
def get_init_inputs():
    """Return the positional arguments for ``Model.__init__`` from the shape globals."""
    head_counts = (Hq, Hkv)
    return [B, T, *head_counts, D, MROPE_SECTION, MAX_POS]
shape=0 variant=eager tflops=0.200 gbps=286.991 ms=0.315
shape=0 variant=compiled tflops=0.569 gbps=817.006 ms=0.110
shape=0 variant=solution tflops=0.755 gbps=1083.796 ms=0.083
shape=0 solution_peak_fraction=0.6021
shape=1 variant=eager tflops=0.192 gbps=280.178 ms=0.525
shape=1 variant=compiled tflops=0.531 gbps=776.022 ms=0.189
shape=1 variant=solution tflops=0.798 gbps=1165.461 ms=0.126
shape=1 solution_peak_fraction=0.6475
shape=2 variant=eager tflops=0.097 gbps=151.529 ms=0.146
shape=2 variant=compiled tflops=0.229 gbps=357.302 ms=0.062
shape=2 variant=solution tflops=0.467 gbps=729.884 ms=0.030
shape=2 solution_peak_fraction=0.4055
shape=3 variant=eager tflops=0.175 gbps=251.372 ms=1.437
shape=3 variant=compiled tflops=0.486 gbps=697.434 ms=0.518
shape=3 variant=solution tflops=0.879 gbps=1261.470 ms=0.286
shape=3 solution_peak_fraction=0.7008
peak_fraction: 0.5769
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "deepseek/deepseek-v4-flash",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.5769,
"elapsed_seconds": 2160,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 52148, "output_tokens": 32634, "cache_read_tokens": 7308160, "cache_creation_tokens": 0, "reasoning_tokens": 83989, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "pending",
"priority": "high"
},
{
"content": "Check correctness with python check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize iteratively",
"status": "pending",
"priority": "high"
},
{
"content": "Benchmark and finalize",
... (64 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'in_progress'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'pending'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'pending'
priority: 'high'
content: 'Benchmark and finalize'
status: 'pending'
priority: 'high'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2:
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6:
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_transformers():
16: try:
17: from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18: apply_multimodal_rotary_pos_emb as _impl,
19: )
20: return _impl
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor, k: torch.Tensor,
27: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28: cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29: mrope_section: tuple[int, int, int],
30: ):
31: """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32: impl = _try_transformers()
33: if impl is None:
34: raise RuntimeError("transformers not available for SOTA")
35:
36: # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37: cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c
... (495 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"pattern": "**/hardware.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard/src"
}No files found{
"pattern": "**/hardware/**",
"path": "/home/infatoshi/cuda/KernelBench-Hard/src"
}/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/rtx_pro_6000.cpython-312.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/m4_max.cpython-312.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/__init__.cpython-312.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/rtx_pro_6000.cpython-311.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/__init__.cpython-311.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__pycache__/m4_max.cpython-311.pyc
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/m4_max.py
/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"command": "nvidia-smi && uv run python --version && uv run python -c \"import triton; print(f'Triton {triton.__version__}')\" 2>&1 | head -20",
"description": "Check GPU, Python, Triton availability"
}Thu May 7 03:13:21 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 Off | Off |
| 30% 26C P8 6W / 600W | 271MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|======================================================================================
... (309 more chars){
"command": "uv run python -c \"import torch; print(f'CUDA {torch.version.cuda}, Torch {torch.__version__}, SM: {torch.cuda.get_device_capability()}')\" 2>&1",
"description": "Check torch and CUDA version"
}CUDA 13.0, Torch 2.11.0+cu130, SM: (12, 0)
{
"command": "<316 chars \u2014 see below>",
"description": "Check Triton API signatures"
}uv run python -c "
import triton.language as tl
import inspect
# Check cat signature
try:
sig = inspect.signature(tl.cat)
print(f'tl.cat: {sig}')
except:
print('no signature')
# Check reshape
try:
sig = inspect.signature(tl.reshape)
print(f'tl.reshape: {sig}')
except:
print('no signature')
"tl.cat: (input, other, can_reorder=False, _semantic=None)
tl.reshape: (input, *shape, can_reorder=False, _semantic=None, _generator=None)
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "completed"
},
{
"priority": "high",
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "in_progress"
},
{
"priority": "high",
"content": "Check correctness with python check.py",
"status": "pending"
},
{
"priority": "high",
"content": "Profile and optimize iteratively",
"status": "pending"
},
{
"priority": "high",
"content
... (66 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'in_progress'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'pending'
priority: 'high'
content: 'Benchmark and finalize'
status: 'pending'
priority: 'high'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,150 @@
+"""Multi-axis RoPE pre-attention prep fused kernel (Triton).
+
+Fuses gather + per-axis mixing + rotate-half + transpose into a single
+read of q/k with direct (B, H, T, D) output.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "rope"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+
+@triton.jit
+def _mrope_kernel(
+ x_ptr, out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ stride_x_b, stride_x_t, stride_x_h,
+ stride_out_b, stride_out_h, stride_out_t,
+ B, T, H, D,
+ s0, s1, s2,
+ half_D: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ b = tl.program_id(0)
+ h = tl.program_id(1)
+ t = tl.program_id(2)
+
+ # Load position IDs for this (b, t) — (B, T) int64 layout
+ pos_offs = b * T + t
+ pt = tl.load(pos_t_ptr + pos_offs)
+ ph = tl.load(pos_h_ptr + pos_offs)
+ pw = tl.load(pos_w_ptr + pos_offs)
+
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ # Load x[b, t, h, :]
+ x_offs = b * stride_x_b + t * stride_x_t + h * stride_x_h + d_offs
+ x = tl.load(x_ptr + x_offs, mask=mask, other=0.0)
+
+ # Determine which position axis each d uses
+ d_mod = tl.where(d_offs < half_D, d_offs, d_offs - half_D)
+ pos = tl.where(d_mod < s0, pt, tl.where(d_mod < s0 + s1, ph, pw))
+
+ # Load cos and sin from cache (max_pos, D) bf16, indexed by pos and d
+ cos_offs = pos * D + d_offs
+ cos = tl.load(cos_cache_ptr + cos_offs, mask=mask, other=0.0)
+ sin = tl.load(sin_cache_ptr + cos_offs, mask=mask, other=0.0)
+
+ # rotate_half(x) = cat(-x[half:], x[:half])
+ x_2d = tl.reshape(x, 2, half_D)
+ first = x_2d[0, :]
+ second = x_2d[1, :]
+ rotated = tl.cat(-second, first)
+
+ out = x * cos + rotated * sin
+
+ # Write output in (B, H, T, D) layout
+ out_offs = b * stride_out_b + h * stride_out_h + t * stride_out_t + d_offs
+ tl.store(out_ptr + out_offs, out, mask=mask)
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+ mrope_section: tuple[int, int, int], max_pos: int):
+ super().__init__()
+ self.B, self.T = B, T
+ self.Hq, self.Hkv = Hq, Hkv
+ self.D = D
+ self.mrope_section = tuple(mrope_section)
+ self.max_pos = max_pos
+
+ inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+ pos = torch.arange(max_pos, dtype=torch.float32)
+ freqs = torch.outer(pos, inv_freq)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+ self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+ def forward(
+ self,
+ q: torch.Tensor, k: torch.Tensor,
+ pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ B, T, Hq, D = q.shape
+ _, _, Hkv, _ = k.shape
+
+ out_q = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
+ out_k = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
+
+ s0, s1, s2 = self.mrope_section
+ half_D = D // 2
+
+ BLOCK_D = D
+ num_warps = 4 if D == 128 else 2
+
+ grid_q = (B, Hq, T)
+ _mrope_kernel[grid_q](
+ q, out_q,
+ self.cos_cache, self.sin_cache,
+ pos_t, pos_h, pos_w,
+ q.stride(0), q.stride(1), q.stride(2),
+ out_q.stride(0), out_q.stride(1), out_q.stride(2),
+ B, T, Hq, D,
+ s0, s1, s2,
+ half_D=half_D, BLOCK_D=BLOCK_D,
+ num_warps=num_warps,
+ )
+
+ grid_k = (B, Hkv, T)
+ _mrope_kernel[grid_k](
+ k, out_k,
+ self.cos_cache, self.sin_cache,
+ pos_t, pos_h, pos_w,
+ k.stride(0), k.stride(1), k.stride(2),
+ out_k.stride(0), out_k.stride(1), out_k.stride(2),
+ B, T, Hkv, D,
+ s0, s1, s2,
+ half_D=half_D, BLOCK_D=BLOCK_D,
+ num_warps=num_warps,
+ )
+
+ return out_q, out_k
+
+
+# Module-level shape shims (overwritten per-shape by check.py / benchmark.py).
+B = 1
+T = 4096
+Hq = 32
+Hkv = 8
+D = 128
+MROPE_SECTION = (16, 24, 24)
+MAX_POS = 32768
+
+
+def get_inputs():
+ q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16) * 0.5
+ k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16) * 0.5
+ pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ return [q, k, pos_t, pos_h, pos_w]
+
+
+def get_init_inputs():
+ return [B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4659 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 | tail -30",
"description": "Run correctness check on initial Triton kernel",
"timeout": 120000
}File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 720, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 849, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 304, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.er
... (614 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -52,10 +52,10 @@
sin = tl.load(sin_cache_ptr + cos_offs, mask=mask, other=0.0)
# rotate_half(x) = cat(-x[half:], x[:half])
- x_2d = tl.reshape(x, 2, half_D)
- first = x_2d[0, :]
- second = x_2d[1, :]
- rotated = tl.cat(-second, first)
+ # Load paired elements; x is already in L1 from the previous load
+ pair_offs = tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D)
+ x_pair = tl.load(x_ptr + b * stride_x_b + t * stride_x_t + h * stride_x_h + pair_offs, mask=mask, other=0.0)
+ rotated = tl.where(d_offs < half_D, -x_pair, x_pair)
out = x * cos + rotated * sin
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<202 chars \u2014 see below>",
"newString": "<398 chars \u2014 see below>"
} # rotate_half(x) = cat(-x[half:], x[:half])
x_2d = tl.reshape(x, 2, half_D)
first = x_2d[0, :]
second = x_2d[1, :]
rotated = tl.cat(-second, first)
out = x * cos + rotated * sin # rotate_half(x) = cat(-x[half:], x[:half])
# Load paired elements; x is already in L1 from the previous load
pair_offs = tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D)
x_pair = tl.load(x_ptr + b * stride_x_b + t * stride_x_t + h * stride_x_h + pair_offs, mask=mask, other=0.0)
rotated = tl.where(d_offs < half_D, -x_pair, x_pair)
out = x * cos + rotated * sinEdit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 | tail -30",
"description": "Check correctness after fixing rotate_half",
"timeout": 120000
}PASS
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "completed"
},
{
"priority": "high",
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "completed"
},
{
"priority": "high",
"content": "Check correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Profile and optimize iteratively",
"status": "in_progress"
},
{
"priority": "high",
"con
... (70 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'completed'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'in_progress'
priority: 'high'
content: 'Benchmark and finalize'
status: 'pending'
priority: 'high'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run benchmark.py 2>&1",
"description": "Run benchmark to see performance baseline",
"timeout": 300000
}shape=0 variant=eager tflops=0.201 gbps=288.208 ms=0.313
shape=0 variant=compiled tflops=0.559 gbps=802.595 ms=0.112
shape=0 variant=solution tflops=0.409 gbps=586.207 ms=0.154
shape=0 solution_peak_fraction=0.3257
shape=1 variant=eager tflops=0.193 gbps=281.535 ms=0.522
shape=1 variant=compiled tflops=0.530 gbps=773.930 ms=0.190
shape=1 variant=solution tflops=0.415 gbps=606.264 ms=0.242
shape=1 solution_peak_fraction=0.3368
shape=2 variant=eager tflops=0.097 gbps=151.363 ms=0.146
shape=2 variant=compiled tflops=0.222 gbps=346.466 ms=0.064
shape=2 variant=solution tflops=0.262 gbps=409.600 ms=0.054
shape=2 solution_peak_fraction=0.2276
shape=3 variant=eager tflops=0.175 gbps=251.308 ms=1.437
shape=3 variant=compiled tflops=0.485 gbps=695.671 ms=0.519
shape=3 variant=solution tflops=0.440 gbps=631.759 ms=0.572
shape=3 solution_peak_fraction=0.3510
peak_fraction: 0.3059
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -2,6 +2,11 @@
Fuses gather + per-axis mixing + rotate-half + transpose into a single
read of q/k with direct (B, H, T, D) output.
+
+Optimization strategy:
+- Process halves separately to avoid double-loading q/k for rotate_half
+- Each thread handles one element from each half, all ops are element-local
+- Two kernel launches: one for q heads, one for k heads
"""
import torch
import torch.nn as nn
@@ -23,45 +28,35 @@
B, T, H, D,
s0, s1, s2,
half_D: tl.constexpr,
- BLOCK_D: tl.constexpr,
):
b = tl.program_id(0)
h = tl.program_id(1)
t = tl.program_id(2)
- # Load position IDs for this (b, t) — (B, T) int64 layout
pos_offs = b * T + t
pt = tl.load(pos_t_ptr + pos_offs)
ph = tl.load(pos_h_ptr + pos_offs)
pw = tl.load(pos_w_ptr + pos_offs)
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D
+ d_offs = tl.arange(0, half_D)
+ base_x = b * stride_x_b + t * stride_x_t + h * stride_x_h
+ base_out = b * stride_out_b + h * stride_out_h + t * stride_out_t
- # Load x[b, t, h, :]
- x_offs = b * stride_x_b + t * stride_x_t + h * stride_x_h + d_offs
- x = tl.load(x_ptr + x_offs, mask=mask, other=0.0)
+ x_first = tl.load(x_ptr + base_x + d_offs)
+ x_second = tl.load(x_ptr + base_x + d_offs + half_D)
- # Determine which position axis each d uses
- d_mod = tl.where(d_offs < half_D, d_offs, d_offs - half_D)
- pos = tl.where(d_mod < s0, pt, tl.where(d_mod < s0 + s1, ph, pw))
+ pos_axis = tl.where(d_offs < s0, pt, tl.where(d_offs < s0 + s1, ph, pw))
- # Load cos and sin from cache (max_pos, D) bf16, indexed by pos and d
- cos_offs = pos * D + d_offs
- cos = tl.load(cos_cache_ptr + cos_offs, mask=mask, other=0.0)
- sin = tl.load(sin_cache_ptr + cos_offs, mask=mask, other=0.0)
+ cos_first = tl.load(cos_cache_ptr + pos_axis * D + d_offs)
+ sin_first = tl.load(sin_cache_ptr + pos_axis * D + d_offs)
+ cos_second = tl.load(cos_cache_ptr + pos_axis * D + d_offs + half_D)
+ sin_second = tl.load(sin_cache_ptr + pos_axis * D + d_offs + half_D)
- # rotate_half(x) = cat(-x[half:], x[:half])
- # Load paired elements; x is already in L1 from the previous load
- pair_offs = tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D)
- x_pair = tl.load(x_ptr + b * stride_x_b + t * stride_x_t + h * stride_x_h + pair_offs, mask=mask, other=0.0)
- rotated = tl.where(d_offs < half_D, -x_pair, x_pair)
+ out_first = x_first * cos_first + (-x_second) * sin_first
+ out_second = x_second * cos_second + x_first * sin_second
- out = x * cos + rotated * sin
-
- # Write output in (B, H, T, D) layout
- out_offs = b * stride_out_b + h * stride_out_h + t * stride_out_t + d_offs
- tl.store(out_ptr + out_offs, out, mask=mask)
+ tl.store(out_ptr + base_out + d_offs, out_first)
+ tl.store(out_ptr + base_out + d_offs + half_D, out_second)
class Model(nn.Module):
@@ -95,11 +90,13 @@
s0, s1, s2 = self.mrope_section
half_D = D // 2
- BLOCK_D = D
- num_warps = 4 if D == 128 else 2
+ num_warps = half_D // 32
+ if D == 128:
+ num_warps = 4
+ elif D == 64:
+ num_warps = 2
- grid_q = (B, Hq, T)
- _mrope_kernel[grid_q](
+ _mrope_kernel[(B, Hq, T)](
q, out_q,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
@@ -107,12 +104,11 @@
out_q.stride(0), out_q.stride(1), out_q.stride(2),
B, T, Hq, D,
s0, s1, s2,
- half_D=half_D, BLOCK_D=BLOCK_D,
+ half_D=half_D,
num_warps=num_warps,
)
- grid_k = (B, Hkv, T)
- _mrope_kernel[grid_k](
+ _mrope_kernel[(B, Hkv, T)](
k, out_k,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
@@ -120,14 +116,13 @@
out_k.stride(0), out_k.stride(1), out_k.stride(2),
B, T, Hkv, D,
s0, s1, s2,
- half_D=half_D, BLOCK_D=BLOCK_D,
+ half_D=half_D,
num_warps=num_warps,
)
return out_q, out_k
-# Module-level shape shims (overwritten per-shape by check.py / benchmark.py).
B = 1
T = 4096
Hq = 32
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4564 chars \u2014 see diff>"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -91,10 +91,6 @@
half_D = D // 2
num_warps = half_D // 32
- if D == 128:
- num_warps = 4
- elif D == 64:
- num_warps = 2
_mrope_kernel[(B, Hq, T)](
q, out_q,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<163 chars \u2014 see below>",
"newString": "<68 chars \u2014 see below>"
} num_warps = half_D // 32
if D == 128:
num_warps = 4
elif D == 64:
num_warps = 2
_mrope_kernel[(B, Hq, T)]( num_warps = half_D // 32
_mrope_kernel[(B, Hq, T)](Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run benchmark.py 2>&1",
"description": "Benchmark after half_D optimization",
"timeout": 300000
}shape=0 variant=eager tflops=0.200 gbps=287.459 ms=0.314
shape=0 variant=compiled tflops=0.510 gbps=731.808 ms=0.123
shape=0 variant=solution tflops=0.484 gbps=694.772 ms=0.130
shape=0 solution_peak_fraction=0.3860
shape=1 variant=eager tflops=0.193 gbps=282.253 ms=0.521
shape=1 variant=compiled tflops=0.528 gbps=771.008 ms=0.191
shape=1 variant=solution tflops=0.528 gbps=771.590 ms=0.191
shape=1 solution_peak_fraction=0.4287
shape=2 variant=eager tflops=0.096 gbps=150.473 ms=0.147
shape=2 variant=compiled tflops=0.209 gbps=326.346 ms=0.068
shape=2 variant=solution tflops=0.284 gbps=443.361 ms=0.050
shape=2 solution_peak_fraction=0.2463
shape=3 variant=eager tflops=0.177 gbps=253.743 ms=1.423
shape=3 variant=compiled tflops=0.479 gbps=687.073 ms=0.526
shape=3 variant=solution tflops=0.593 gbps=851.370 ms=0.424
shape=3 solution_peak_fraction=0.4730
peak_fraction: 0.3726
RESULT: OK
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1",
"description": "Check correctness after half_D optimization",
"timeout": 120000
}PASS
{
"description": "Profile kernel with simple timing breakdown",
"command": "<2124 chars \u2014 see below>",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
import triton
import triton.language as tl
from solution import Model, MROPE_SECTION, MAX_POS
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
# Warmup
for _ in range(10):
m(q, k, pos_t, pos_h, pos_w)
# Time each kernel separately
import time
n_iters = 100
# Time q kernel
m.q_only = lambda: m._mrope_kernel[(1, 32, 4096)](
q, torch.empty(1, 32, 4096, 128, dtype=torch.bfloat16, device='cuda'),
m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
q.stride(0), q.stride(1), q.stride(2),
4096*128, 128, 1, # out strides
1, 4096, 32, 128,
MROPE_SECTION[0], MROPE_SECTION[1], MROPE_SECTION[2],
half_D=64,
num_warps=2,
)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(n_iters):
m.q_only()
torch.cuda.synchronize()
t1 = time.time()
print(f'Q kernel: {(t1-t0)/n_iters*1000:.3f} ms')
# Time k kernel
m.k_only = lambda: m._mrope_kernel[(1, 8, 4096)](
k, torch.empty(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda'),
m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
k.stride(0), k.stride(1), k.stride(2),
4096*128, 128, 1,
1, 4096, 8, 128,
MROPE_SECTION[0], MROPE_SECTION[1], MROPE_SECTION[2],
half_D=64,
num_warps=2,
)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(n_iters):
m.k_only()
torch.cuda.synchronize()
t1 = time.time()
print(f'K kernel: {(t1-t0)/n_iters*1000:.3f} ms')
# Time combined
torch.cuda.synchronize()
t0 = time.time()
for _ in range(n_iters):
m(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()
t1 = time.time()
print(f'Combined: {(t1-t0)/n_iters*1000:.3f} ms')
" 2>&1Traceback (most recent call last):
File "<string>", line 38, in <module>
File "<string>", line 25, in <lambda>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1968, in __getattr__
raise AttributeError(
AttributeError: 'Model' object has no attribute '_mrope_kernel'
{
"command": "<1097 chars \u2014 see below>",
"description": "Time combined forward pass",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
import time
from solution import Model, MROPE_SECTION, MAX_POS
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
# Warmup
for _ in range(10):
m(q, k, pos_t, pos_h, pos_w)
n_iters = 100
torch.cuda.synchronize()
t0 = time.time()
for _ in range(n_iters):
m(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()
t1 = time.time()
print(f'Combined forward: {(t1-t0)/n_iters*1000:.3f} ms')
# Now let's also check with nsight compute
import subprocess
ret = subprocess.run(['ncu', '--list-sets'], capture_output=True, text=True)
print(ret.stdout[:500])
" 2>&1Combined forward: 0.074 ms
---------- --------------------------------------------------------------------------- ------- -----------------
Identifier Sections Enabled Estimated Metrics
---------- --------------------------------------------------------------------------- ------- -----------------
basic LaunchStats, Occupancy, SpeedOfLight, WorkloadDistribution yes 213
detailed ComputeWorkloadAnalysis, LaunchStats,
{
"command": "<2618 chars \u2014 see below>",
"description": "Sweep num_warps for best performance",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
from solution import Model, MROPE_SECTION, MAX_POS
import time
# Test different num_warps settings
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
for nw in [1, 2, 4, 8]:
import solution
# Hot-patch num_warps by recompiling
try:
from triton.runtime import jit
# Just directly modify Model.forward
original_forward = Model.forward
except:
pass
out_q = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device='cuda')
out_k = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device='cuda')
s0, s1, s2 = MROPE_SECTION
half_D = 64
# Warmup with specific num_warps
solution._mrope_kernel[(B, Hq, T)](
q, out_q, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
q.stride(0), q.stride(1), q.stride(2),
out_q.stride(0), out_q.stride(1), out_q.stride(2),
B, T, Hq, D, s0, s1, s2,
half_D=half_D, num_warps=nw,
)
solution._mrope_kernel[(B, Hkv, T)](
k, out_k, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
k.stride(0), k.stride(1), k.stride(2),
out_k.stride(0), out_k.stride(1), out_k.stride(2),
B, T, Hkv, D, s0, s1, s2,
half_D=half_D, num_warps=nw,
)
torch.cuda.synchronize()
n_iters = 100
t0 = time.time()
for _ in range(n_iters):
solution._mrope_kernel[(B, Hq, T)](
q, out_q, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
q.stride(0), q.stride(1), q.stride(2),
out_q.stride(0), out_q.stride(1), out_q.stride(2),
B, T, Hq, D, s0, s1, s2,
half_D=half_D, num_warps=nw,
)
solution._mrope_kernel[(B, Hkv, T)](
k, out_k, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
k.stride(0), k.stride(1), k.stride(2),
out_k.stride(0), out_k.stride(1), out_k.stride(2),
B, T, Hkv, D, s0, s1, s2,
half_D=half_D, num_warps=nw,
)
torch.cuda.synchronize()
t1 = time.time()
ms = (t1-t0)/n_iters*1000
print(f'num_warps={nw}: {ms:.4f} ms total')
" 2>&1num_warps=1: 0.0765 ms total
num_warps=2: 0.0743 ms total
num_warps=4: 0.0755 ms total
num_warps=8: 0.1039 ms total
{
"command": "cat /home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py 2>/dev/null || find /home/infatoshi/cuda/KernelBench-Hard -name \"timing.py\" -exec cat {} \\;",
"description": "Read timing utility"
}"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import statistics
import torch
# Blackwell consumer (RTX PRO 6000) L2 is 96 MB. Allocate strictly larg
... (1067 more chars){
"command": "<1473 chars \u2014 see below>",
"description": "Measure with L2 flush to match benchmark",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
import time
from solution import Model, MROPE_SECTION, MAX_POS
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
# L2 flush helper
l2_scratch = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
def time_with_flush(fn, n_iters=30):
for _ in range(10):
fn()
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
l2_scratch.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn()
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
med = sorted(times)[len(times)//2]
return med
# Test combined
fn = lambda: m(q, k, pos_t, pos_h, pos_w)
ms = time_with_flush(fn)
print(f'Combined (with L2 flush): {ms:.3f} ms')
# For reference: benchmark.py output
print(f'Expected benchmark: ~{ms:.3f} ms (full forward)')
" 2>&1Combined (with L2 flush): 0.129 ms
Expected benchmark: ~0.129 ms (full forward)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -6,7 +6,7 @@
Optimization strategy:
- Process halves separately to avoid double-loading q/k for rotate_half
- Each thread handles one element from each half, all ops are element-local
-- Two kernel launches: one for q heads, one for k heads
+- Combined q/k kernel: single launch instead of two
"""
import torch
import torch.nn as nn
@@ -20,18 +20,22 @@
@triton.jit
def _mrope_kernel(
- x_ptr, out_ptr,
+ q_ptr, k_ptr, out_q_ptr, out_k_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
- stride_x_b, stride_x_t, stride_x_h,
- stride_out_b, stride_out_h, stride_out_t,
- B, T, H, D,
+ stride_q_b, stride_q_t, stride_q_h,
+ stride_k_b, stride_k_t, stride_k_h,
+ stride_out_q_b, stride_out_q_h, stride_out_q_t,
+ stride_out_k_b, stride_out_k_h, stride_out_k_t,
+ B, T, Hq, Hkv, D,
s0, s1, s2,
half_D: tl.constexpr,
):
b = tl.program_id(0)
- h = tl.program_id(1)
+ h_global = tl.program_id(1)
t = tl.program_id(2)
+
+ is_q = h_global < Hq
pos_offs = b * T + t
pt = tl.load(pos_t_ptr + pos_offs)
@@ -39,11 +43,18 @@
pw = tl.load(pos_w_ptr + pos_offs)
d_offs = tl.arange(0, half_D)
- base_x = b * stride_x_b + t * stride_x_t + h * stride_x_h
- base_out = b * stride_out_b + h * stride_out_h + t * stride_out_t
- x_first = tl.load(x_ptr + base_x + d_offs)
- x_second = tl.load(x_ptr + base_x + d_offs + half_D)
+ if is_q:
+ base_x = b * stride_q_b + t * stride_q_t + h_global * stride_q_h
+ base_out = b * stride_out_q_b + h_global * stride_out_q_h + t * stride_out_q_t
+ x_first = tl.load(q_ptr + base_x + d_offs)
+ x_second = tl.load(q_ptr + base_x + d_offs + half_D)
+ else:
+ h_kv = h_global - Hq
+ base_x = b * stride_k_b + t * stride_k_t + h_kv * stride_k_h
+ base_out = b * stride_out_k_b + h_kv * stride_out_k_h + t * stride_out_k_t
+ x_first = tl.load(k_ptr + base_x + d_offs)
+ x_second = tl.load(k_ptr + base_x + d_offs + half_D)
pos_axis = tl.where(d_offs < s0, pt, tl.where(d_offs < s0 + s1, ph, pw))
@@ -55,8 +66,12 @@
out_first = x_first * cos_first + (-x_second) * sin_first
out_second = x_second * cos_second + x_first * sin_second
- tl.store(out_ptr + base_out + d_offs, out_first)
- tl.store(out_ptr + base_out + d_offs + half_D, out_second)
+ if is_q:
+ tl.store(out_q_ptr + base_out + d_offs, out_first)
+ tl.store(out_q_ptr + base_out + d_offs + half_D, out_second)
+ else:
+ tl.store(out_k_ptr + base_out + d_offs, out_first)
+ tl.store(out_k_ptr + base_out + d_offs + half_D, out_second)
class Model(nn.Module):
@@ -89,28 +104,18 @@
s0, s1, s2 = self.mrope_section
half_D = D // 2
-
num_warps = half_D // 32
- _mrope_kernel[(B, Hq, T)](
- q, out_q,
+ grid = (B, Hq + Hkv, T)
+ _mrope_kernel[grid](
+ q, k, out_q, out_k,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
q.stride(0), q.stride(1), q.stride(2),
+ k.stride(0), k.stride(1), k.stride(2),
out_q.stride(0), out_q.stride(1), out_q.stride(2),
- B, T, Hq, D,
- s0, s1, s2,
- half_D=half_D,
- num_warps=num_warps,
- )
-
- _mrope_kernel[(B, Hkv, T)](
- k, out_k,
- self.cos_cache, self.sin_cache,
- pos_t, pos_h, pos_w,
- k.stride(0), k.stride(1), k.stride(2),
out_k.stride(0), out_k.stride(1), out_k.stride(2),
- B, T, Hkv, D,
+ B, T, Hq, Hkv, D,
s0, s1, s2,
half_D=half_D,
num_warps=num_warps,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4921 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1",
"description": "Check correctness after combined kernel",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run benchmark.py 2>&1",
"description": "Benchmark combined kernel",
"timeout": 300000
}shape=0 variant=eager tflops=0.232 gbps=332.405 ms=0.272
shape=0 variant=compiled tflops=0.738 gbps=1059.178 ms=0.085
shape=0 variant=solution tflops=0.653 gbps=936.316 ms=0.096
shape=0 solution_peak_fraction=0.5202
shape=1 variant=eager tflops=0.161 gbps=234.442 ms=0.627
shape=1 variant=compiled tflops=0.684 gbps=998.406 ms=0.147
shape=1 variant=solution tflops=0.683 gbps=996.997 ms=0.147
shape=1 solution_peak_fraction=0.5539
shape=2 variant=eager tflops=0.138 gbps=215.160 ms=0.103
shape=2 variant=compiled tflops=0.172 gbps=268.688 ms=0.082
shape=2 variant=solution tflops=0.182 gbps=284.796 ms=0.078
shape=2 solution_peak_fraction=0.1582
shape=3 variant=eager tflops=0.178 gbps=255.340 ms=1.414
shape=3 variant=compiled tflops=0.526 gbps=754.487 ms=0.479
shape=3 variant=solution tflops=0.706 gbps=1013.515 ms=0.356
shape=3 solution_peak_fraction=0.5631
peak_fraction: 0.4003
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,12 +1,10 @@
"""Multi-axis RoPE pre-attention prep fused kernel (Triton).
Fuses gather + per-axis mixing + rotate-half + transpose into a single
-read of q/k with direct (B, H, T, D) output.
+pass with direct (B, H, T, D) output.
-Optimization strategy:
-- Process halves separately to avoid double-loading q/k for rotate_half
-- Each thread handles one element from each half, all ops are element-local
-- Combined q/k kernel: single launch instead of two
+Strategy: process full head_dim D per program with BLOCK_D = D.
+Double-load of q/k for rotate_half is absorbed by L1 cache.
"""
import torch
import torch.nn as nn
@@ -30,6 +28,7 @@
B, T, Hq, Hkv, D,
s0, s1, s2,
half_D: tl.constexpr,
+ BLOCK_D: tl.constexpr,
):
b = tl.program_id(0)
h_global = tl.program_id(1)
@@ -42,36 +41,35 @@
ph = tl.load(pos_h_ptr + pos_offs)
pw = tl.load(pos_w_ptr + pos_offs)
- d_offs = tl.arange(0, half_D)
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
if is_q:
base_x = b * stride_q_b + t * stride_q_t + h_global * stride_q_h
base_out = b * stride_out_q_b + h_global * stride_out_q_h + t * stride_out_q_t
- x_first = tl.load(q_ptr + base_x + d_offs)
- x_second = tl.load(q_ptr + base_x + d_offs + half_D)
+ x = tl.load(q_ptr + base_x + d_offs, mask=mask, other=0.0)
+ x_pair = tl.load(q_ptr + base_x + tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D), mask=mask, other=0.0)
else:
h_kv = h_global - Hq
base_x = b * stride_k_b + t * stride_k_t + h_kv * stride_k_h
base_out = b * stride_out_k_b + h_kv * stride_out_k_h + t * stride_out_k_t
- x_first = tl.load(k_ptr + base_x + d_offs)
- x_second = tl.load(k_ptr + base_x + d_offs + half_D)
+ x = tl.load(k_ptr + base_x + d_offs, mask=mask, other=0.0)
+ x_pair = tl.load(k_ptr + base_x + tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D), mask=mask, other=0.0)
- pos_axis = tl.where(d_offs < s0, pt, tl.where(d_offs < s0 + s1, ph, pw))
+ # Axis selection per element
+ d_mod = tl.where(d_offs < half_D, d_offs, d_offs - half_D)
+ pos = tl.where(d_mod < s0, pt, tl.where(d_mod < s0 + s1, ph, pw))
- cos_first = tl.load(cos_cache_ptr + pos_axis * D + d_offs)
- sin_first = tl.load(sin_cache_ptr + pos_axis * D + d_offs)
- cos_second = tl.load(cos_cache_ptr + pos_axis * D + d_offs + half_D)
- sin_second = tl.load(sin_cache_ptr + pos_axis * D + d_offs + half_D)
+ cos = tl.load(cos_cache_ptr + pos * D + d_offs, mask=mask, other=0.0)
+ sin = tl.load(sin_cache_ptr + pos * D + d_offs, mask=mask, other=0.0)
- out_first = x_first * cos_first + (-x_second) * sin_first
- out_second = x_second * cos_second + x_first * sin_second
+ rotated = tl.where(d_offs < half_D, -x_pair, x_pair)
+ out = x * cos + rotated * sin
if is_q:
- tl.store(out_q_ptr + base_out + d_offs, out_first)
- tl.store(out_q_ptr + base_out + d_offs + half_D, out_second)
+ tl.store(out_q_ptr + base_out + d_offs, out, mask=mask)
else:
- tl.store(out_k_ptr + base_out + d_offs, out_first)
- tl.store(out_k_ptr + base_out + d_offs + half_D, out_second)
+ tl.store(out_k_ptr + base_out + d_offs, out, mask=mask)
class Model(nn.Module):
@@ -104,7 +102,8 @@
s0, s1, s2 = self.mrope_section
half_D = D // 2
- num_warps = half_D // 32
+ BLOCK_D = D
+ num_warps = D // 32
grid = (B, Hq + Hkv, T)
_mrope_kernel[grid](
@@ -117,7 +116,7 @@
out_k.stride(0), out_k.stride(1), out_k.stride(2),
B, T, Hq, Hkv, D,
s0, s1, s2,
- half_D=half_D,
+ half_D=half_D, BLOCK_D=BLOCK_D,
num_warps=num_warps,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4861 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1",
"description": "Check correctness after BLOCK_D=D design",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run benchmark.py 2>&1",
"description": "Benchmark BLOCK_D=D design",
"timeout": 300000
}shape=0 variant=eager tflops=0.195 gbps=279.596 ms=0.323
shape=0 variant=compiled tflops=0.503 gbps=721.606 ms=0.125
shape=0 variant=solution tflops=0.365 gbps=524.372 ms=0.172
shape=0 solution_peak_fraction=0.2913
shape=1 variant=eager tflops=0.206 gbps=300.593 ms=0.489
shape=1 variant=compiled tflops=0.493 gbps=719.390 ms=0.204
shape=1 variant=solution tflops=0.412 gbps=601.935 ms=0.244
shape=1 solution_peak_fraction=0.3344
shape=2 variant=eager tflops=0.062 gbps=97.277 ms=0.227
shape=2 variant=compiled tflops=0.466 gbps=728.346 ms=0.030
shape=2 variant=solution tflops=0.343 gbps=535.814 ms=0.041
shape=2 solution_peak_fraction=0.2977
shape=3 variant=eager tflops=0.178 gbps=255.308 ms=1.414
shape=3 variant=compiled tflops=0.523 gbps=750.847 ms=0.481
shape=3 variant=solution tflops=0.452 gbps=648.869 ms=0.557
shape=3 solution_peak_fraction=0.3605
peak_fraction: 0.3198
RESULT: OK
{
"command": "<3463 chars \u2014 see below>",
"description": "Sweep num_warps for D=64 shape",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch, time
import triton, triton.language as tl
B, T, Hq, Hkv, D = 2, 2048, 16, 2, 64
mrope_section = (8, 12, 12)
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
cos_cache = torch.empty(16384, D, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
sin_cache = torch.empty(16384, D, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
out_q = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device='cuda')
out_k = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device='cuda')
s0, s1, s2 = mrope_section
half_D = D // 2
# Test various num_warps for half_D approach
@triton.jit
def test_half_kernel(
q_ptr, k_ptr, out_q_ptr, out_k_ptr, cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
sq_b, sq_t, sq_h, sk_b, sk_t, sk_h,
soq_b, soq_h, soq_t, sok_b, sok_h, sok_t,
B, T, Hq, Hkv, D, s0, s1, s2, half_D: tl.constexpr):
b= tl.program_id(0); hg= tl.program_id(1); t= tl.program_id(2)
is_q= hg < Hq
po= b*T+t
pt=tl.load(pos_t_ptr+po); ph=tl.load(pos_h_ptr+po); pw=tl.load(pos_w_ptr+po)
do= tl.arange(0, half_D)
if is_q:
bx= b*sq_b + t*sq_t + hg*sq_h; bo= b*soq_b + hg*soq_h + t*soq_t
x1= tl.load(q_ptr+bx+do); x2= tl.load(q_ptr+bx+do+half_D)
else:
hk= hg-Hq; bx= b*sk_b + t*sk_t + hk*sk_h; bo= b*sok_b + hk*sok_h + t*sok_t
x1= tl.load(k_ptr+bx+do); x2= tl.load(k_ptr+bx+do+half_D)
pa= tl.where(do < s0, pt, tl.where(do < s0+s1, ph, pw))
c1= tl.load(cos_cache_ptr+pa*D+do); s1_= tl.load(sin_cache_ptr+pa*D+do)
c2= tl.load(cos_cache_ptr+pa*D+do+half_D); s2_= tl.load(sin_cache_ptr+pa*D+do+half_D)
o1= x1*c1 + (-x2)*s1_; o2= x2*c2 + x1*s2_
if is_q: tl.store(out_q_ptr+bo+do, o1); tl.store(out_q_ptr+bo+do+half_D, o2)
else: tl.store(out_k_ptr+bo+do, o1); tl.store(out_k_ptr+bo+do+half_D, o2)
l2 = torch.empty(128*1024*1024//4, dtype=torch.float32, device='cuda')
for nw in [1, 2, 4]:
for _ in range(10):
test_half_kernel[(B, Hq+Hkv, T)](q,k,out_q,out_k,cos_cache,sin_cache,pos_t,pos_h,pos_w,
q.stride(0),q.stride(1),q.stride(2),k.stride(0),k.stride(1),k.stride(2),
out_q.stride(0),out_q.stride(1),out_q.stride(2),out_k.stride(0),out_k.stride(1),out_k.stride(2),
B,T,Hq,Hkv,D,s0,s1,s2,half_D=half_D, num_warps=nw)
torch.cuda.synchronize()
times=[]
for _ in range(30):
l2.zero_(); torch.cuda.synchronize()
s= torch.cuda.Event(enable_timing=True); e= torch.cuda.Event(enable_timing=True)
s.record()
test_half_kernel[(B, Hq+Hkv, T)](q,k,out_q,out_k,cos_cache,sin_cache,pos_t,pos_h,pos_w,
q.stride(0),q.stride(1),q.stride(2),k.stride(0),k.stride(1),k.stride(2),
out_q.stride(0),out_q.stride(1),out_q.stride(2),out_k.stride(0),out_k.stride(1),out_k.stride(2),
B,T,Hq,Hkv,D,s0,s1,s2,half_D=half_D, num_warps=nw)
e.record(); torch.cuda.synchronize()
times.append(s.elapsed_time(e))
med= sorted(times)[len(times)//2]
print(f'D=64 half_D approach num_warps={nw}: {med:.3f} ms')
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 461, in __init__
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1240, in getsourcelines
lines, lnum = findsource(object)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/inspect.py", line 1077, in findsource
raise OSError('could not get source code')
OSError: could not get source code
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 24, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 935, in jit
return decorator(fn)
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 923, in decorator
return JITFunction(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 756, in __init__
super().__init__(fn)
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 463, in __
... (150 more chars){
"description": "Sweep num_warps for D=64 using test file",
"command": "<3440 chars \u2014 see below>",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && cat > /tmp/test_nw.py << 'PYEOF'
import torch
import triton
import triton.language as tl
import time
@triton.jit
def test_half_kernel(
    q_ptr, k_ptr, out_q_ptr, out_k_ptr, cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    sq_b, sq_t, sq_h, sk_b, sk_t, sk_h,
    soq_b, soq_h, soq_t, sok_b, sok_h, sok_t,
    B, T, Hq, Hkv, D, s0, s1, s2, half_D: tl.constexpr,
):
    """Apply the rotation to one (batch, head, time) row, split into two halves.

    Program ids 0/1/2 select batch, combined head index and time step.
    Combined head indices below ``Hq`` read from ``q_ptr`` and write to
    ``out_q_ptr``; the rest are shifted down by ``Hq`` and use the k tensors.
    Each program loads ``half_D`` elements at offset 0 and ``half_D`` more at
    offset ``half_D``, so every input element is read exactly once.
    """
    batch = tl.program_id(0)
    head = tl.program_id(1)
    step = tl.program_id(2)
    from_q = head < Hq

    # One position id per axis, indexed by the flattened (batch, time) slot.
    slot = batch * T + step
    id_t = tl.load(pos_t_ptr + slot)
    id_h = tl.load(pos_h_ptr + slot)
    id_w = tl.load(pos_w_ptr + slot)

    lane = tl.arange(0, half_D)

    if from_q:
        src = batch * sq_b + step * sq_t + head * sq_h
        dst = batch * soq_b + head * soq_h + step * soq_t
        lo = tl.load(q_ptr + src + lane)
        hi = tl.load(q_ptr + src + lane + half_D)
    else:
        kv_head = head - Hq
        src = batch * sk_b + step * sk_t + kv_head * sk_h
        dst = batch * sok_b + kv_head * sok_h + step * sok_t
        lo = tl.load(k_ptr + src + lane)
        hi = tl.load(k_ptr + src + lane + half_D)

    # Per-lane cache row: lanes [0, s0) use id_t, [s0, s0 + s1) use id_h,
    # and the remaining lanes use id_w.
    row = tl.where(lane < s0, id_t, tl.where(lane < s0 + s1, id_h, id_w))
    cos_lo = tl.load(cos_cache_ptr + row * D + lane)
    sin_lo = tl.load(sin_cache_ptr + row * D + lane)
    cos_hi = tl.load(cos_cache_ptr + row * D + lane + half_D)
    sin_hi = tl.load(sin_cache_ptr + row * D + lane + half_D)

    # Rotate-half: the low half pairs with the negated high half,
    # the high half pairs with the low half.
    out_lo = lo * cos_lo + (-hi) * sin_lo
    out_hi = hi * cos_hi + lo * sin_hi

    if from_q:
        tl.store(out_q_ptr + dst + lane, out_lo)
        tl.store(out_q_ptr + dst + lane + half_D, out_hi)
    else:
        tl.store(out_k_ptr + dst + lane, out_lo)
        tl.store(out_k_ptr + dst + lane + half_D, out_hi)
# Problem size for the D=64 sweep; (s0, s1, s2) are the per-axis section
# widths the kernel uses to pick a position id for each lane.
B, T, Hq, Hkv, D = 2, 2048, 16, 2, 64
s0, s1, s2 = 8, 12, 12
half_D = D // 2

# Inputs are laid out (B, T, H, D). Creation order is kept unchanged so the
# random stream matches the original script.
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, 16384, (B, T), dtype=torch.int64, device='cuda')
cos_cache = torch.empty(16384, D, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
sin_cache = torch.empty(16384, D, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)

# Outputs are allocated as (B, H, T, D).
out_q = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device='cuda')
out_k = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device='cuda')

# 128 MiB float32 scratch buffer, zeroed before every timed launch
# (presumably to evict the L2 cache between samples -- TODO confirm).
l2 = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')

grid = (B, Hq + Hkv, T)
for nw in [1, 2, 4]:
    # Warm-up launches: trigger compilation for this num_warps before timing.
    for _ in range(10):
        test_half_kernel[grid](
            q, k, out_q, out_k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            out_q.stride(0), out_q.stride(1), out_q.stride(2),
            out_k.stride(0), out_k.stride(1), out_k.stride(2),
            B, T, Hq, Hkv, D, s0, s1, s2,
            half_D=half_D, num_warps=nw,
        )
    torch.cuda.synchronize()

    times = []
    for _ in range(30):
        l2.zero_()
        torch.cuda.synchronize()
        start_evt = torch.cuda.Event(enable_timing=True)
        stop_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        test_half_kernel[grid](
            q, k, out_q, out_k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            out_q.stride(0), out_q.stride(1), out_q.stride(2),
            out_k.stride(0), out_k.stride(1), out_k.stride(2),
            B, T, Hq, Hkv, D, s0, s1, s2,
            half_D=half_D, num_warps=nw,
        )
        stop_evt.record()
        torch.cuda.synchronize()
        times.append(start_evt.elapsed_time(stop_evt))

    # Report the median sample (upper median for an even sample count).
    times.sort()
    med = times[len(times) // 2]
    print(f'D=64 half_D approach num_warps={nw}: {med:.3f} ms')
PYEOF
uv run python /tmp/test_nw.py 2>&1D=64 half_D approach num_warps=1: 0.045 ms
D=64 half_D approach num_warps=2: 0.043 ms
D=64 half_D approach num_warps=4: 0.066 ms
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,10 +1,15 @@
"""Multi-axis RoPE pre-attention prep fused kernel (Triton).
Fuses gather + per-axis mixing + rotate-half + transpose into a single
-pass with direct (B, H, T, D) output.
+kernel with direct (B, H, T, D) output.
-Strategy: process full head_dim D per program with BLOCK_D = D.
-Double-load of q/k for rotate_half is absorbed by L1 cache.
+Key optimization: process head_dim in two halves to avoid double-loading
+q/k for rotate_half. Each thread handles one element from each half,
+all rotation ops are element-local. Combined q/k grid eliminates second
+kernel launch.
+
+Memory-bound: fused approach reduces DRAM traffic vs reference by
+eliminating intermediate cos/sin tensors and the transpose pass.
"""
import torch
import torch.nn as nn
@@ -28,7 +33,6 @@
B, T, Hq, Hkv, D,
s0, s1, s2,
half_D: tl.constexpr,
- BLOCK_D: tl.constexpr,
):
b = tl.program_id(0)
h_global = tl.program_id(1)
@@ -41,35 +45,36 @@
ph = tl.load(pos_h_ptr + pos_offs)
pw = tl.load(pos_w_ptr + pos_offs)
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D
+ d_offs = tl.arange(0, half_D)
if is_q:
base_x = b * stride_q_b + t * stride_q_t + h_global * stride_q_h
base_out = b * stride_out_q_b + h_global * stride_out_q_h + t * stride_out_q_t
- x = tl.load(q_ptr + base_x + d_offs, mask=mask, other=0.0)
- x_pair = tl.load(q_ptr + base_x + tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D), mask=mask, other=0.0)
+ x_first = tl.load(q_ptr + base_x + d_offs)
+ x_second = tl.load(q_ptr + base_x + d_offs + half_D)
else:
h_kv = h_global - Hq
base_x = b * stride_k_b + t * stride_k_t + h_kv * stride_k_h
base_out = b * stride_out_k_b + h_kv * stride_out_k_h + t * stride_out_k_t
- x = tl.load(k_ptr + base_x + d_offs, mask=mask, other=0.0)
- x_pair = tl.load(k_ptr + base_x + tl.where(d_offs < half_D, d_offs + half_D, d_offs - half_D), mask=mask, other=0.0)
+ x_first = tl.load(k_ptr + base_x + d_offs)
+ x_second = tl.load(k_ptr + base_x + d_offs + half_D)
- # Axis selection per element
- d_mod = tl.where(d_offs < half_D, d_offs, d_offs - half_D)
- pos = tl.where(d_mod < s0, pt, tl.where(d_mod < s0 + s1, ph, pw))
+ pos_axis = tl.where(d_offs < s0, pt, tl.where(d_offs < s0 + s1, ph, pw))
- cos = tl.load(cos_cache_ptr + pos * D + d_offs, mask=mask, other=0.0)
- sin = tl.load(sin_cache_ptr + pos * D + d_offs, mask=mask, other=0.0)
+ cos_first = tl.load(cos_cache_ptr + pos_axis * D + d_offs)
+ sin_first = tl.load(sin_cache_ptr + pos_axis * D + d_offs)
+ cos_second = tl.load(cos_cache_ptr + pos_axis * D + d_offs + half_D)
+ sin_second = tl.load(sin_cache_ptr + pos_axis * D + d_offs + half_D)
- rotated = tl.where(d_offs < half_D, -x_pair, x_pair)
- out = x * cos + rotated * sin
+ out_first = x_first * cos_first + (-x_second) * sin_first
+ out_second = x_second * cos_second + x_first * sin_second
if is_q:
- tl.store(out_q_ptr + base_out + d_offs, out, mask=mask)
+ tl.store(out_q_ptr + base_out + d_offs, out_first)
+ tl.store(out_q_ptr + base_out + d_offs + half_D, out_second)
else:
- tl.store(out_k_ptr + base_out + d_offs, out, mask=mask)
+ tl.store(out_k_ptr + base_out + d_offs, out_first)
+ tl.store(out_k_ptr + base_out + d_offs + half_D, out_second)
class Model(nn.Module):
@@ -102,8 +107,9 @@
s0, s1, s2 = self.mrope_section
half_D = D // 2
- BLOCK_D = D
- num_warps = D // 32
+
+ # At least 2 warps for latency hiding (critical for D=64 where half_D=32)
+ num_warps = max(half_D // 32, 2)
grid = (B, Hq + Hkv, T)
_mrope_kernel[grid](
@@ -116,13 +122,14 @@
out_k.stride(0), out_k.stride(1), out_k.stride(2),
B, T, Hq, Hkv, D,
s0, s1, s2,
- half_D=half_D, BLOCK_D=BLOCK_D,
+ half_D=half_D,
num_warps=num_warps,
)
return out_q, out_k
+# Module-level shape shims (overwritten per-shape by check.py / benchmark.py).
B = 1
T = 4096
Hq = 32
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<5224 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 && uv run benchmark.py 2>&1",
"description": "Check + benchmark after num_warps fix",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.231 gbps=332.053 ms=0.272
shape=0 variant=compiled tflops=0.724 gbps=1038.704 ms=0.087
shape=0 variant=solution tflops=0.650 gbps=933.219 ms=0.097
shape=0 solution_peak_fraction=0.5185
shape=1 variant=eager tflops=0.206 gbps=300.879 ms=0.489
shape=1 variant=compiled tflops=0.694 gbps=1013.271 ms=0.145
shape=1 variant=solution tflops=0.674 gbps=984.181 ms=0.149
shape=1 solution_peak_fraction=0.5468
shape=2 variant=eager tflops=0.137 gbps=214.359 ms=0.103
shape=2 variant=compiled tflops=0.459 gbps=716.641 ms=0.031
shape=2 variant=solution tflops=0.331 gbps=516.592 ms=0.043
shape=2 solution_peak_fraction=0.2870
shape=3 variant=eager tflops=0.178 gbps=255.698 ms=1.412
shape=3 variant=compiled tflops=0.525 gbps=753.831 ms=0.479
shape=3 variant=solution tflops=0.654 gbps=937.794 ms=0.385
shape=3 solution_peak_fraction=0.5210
peak_fraction: 0.4537
RESULT: OK
{
"command": "<767 chars \u2014 see below>",
"description": "Profile kernel with ncu",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run ncu --set basic -c 1 --kernel-name _mrope_kernel --page details python -c "
from solution import Model, MROPE_SECTION, MAX_POS
import torch
m = Model(1, 4096, 32, 8, 128, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(1, 4096, 32, 128, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
for _ in range(5): m(q,k,pos_t,pos_h,pos_w)
torch.cuda.synchronize()
" 2>&1 | head -80==PROF== Connected to process 365708 (/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==PROF== Profiling "_mrope_kernel": 0%....50%....100% - 9 passes
==PROF== Disconnected from process 365708
[365708] python3.11@127.0.0.1
_mrope_kernel (1, 40, 4096)x(64, 1, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 13.35
SM Frequency Ghz 2.45
Elapsed Cycles cycle 240,315
Memory Throughput % 50.14
DRAM Throughput % 50.14
Duration us 97.34
L1/TEX Cache Throughput % 19.19
L2 Cache Throughput % 24.19
SM Active Cycles cycle 217,940.88
Compute (SM) Throughput % 17.55
----------------------- ----------- ------------
OPT This workload exhibits low compute throughput and memory bandwidth utilization relative to the peak
performance of this device. Achieved compute throughput and/or memory bandwidth below 60.0% of peak
typically indicate latency issues. Look at Scheduler Statistics and Warp State Statistics for potential
... (3671 more chars){
"todos": [
{
"priority": "high",
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "completed"
},
{
"priority": "high",
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "completed"
},
{
"priority": "high",
"content": "Check correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Profile and optimize iteratively",
"status": "in_progress"
},
{
"priority": "high",
"con
... (70 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'completed'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'in_progress'
priority: 'high'
content: 'Benchmark and finalize'
status: 'pending'
priority: 'high'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,13 +3,12 @@
Fuses gather + per-axis mixing + rotate-half + transpose into a single
kernel with direct (B, H, T, D) output.
-Key optimization: process head_dim in two halves to avoid double-loading
-q/k for rotate_half. Each thread handles one element from each half,
-all rotation ops are element-local. Combined q/k grid eliminates second
-kernel launch.
-
-Memory-bound: fused approach reduces DRAM traffic vs reference by
-eliminating intermediate cos/sin tensors and the transpose pass.
+Key optimization:
+- Group by (batch, time): one program loads all 6 cos/sin rows once and
+ processes ALL heads at that (b,t). Eliminates H_total-way redundant
+ cos/sin cache reads — the main bandwidth bottleneck.
+- Process head_dim in two halves to avoid double-loading q/k.
+- All rotation ops are element-local, no cross-thread communication.
"""
import torch
import torch.nn as nn
@@ -30,16 +29,16 @@
stride_k_b, stride_k_t, stride_k_h,
stride_out_q_b, stride_out_q_h, stride_out_q_t,
stride_out_k_b, stride_out_k_h, stride_out_k_t,
- B, T, Hq, Hkv, D,
+ B, T, D,
s0, s1, s2,
half_D: tl.constexpr,
+ Hq_: tl.constexpr,
+ Hkv_: tl.constexpr,
):
b = tl.program_id(0)
- h_global = tl.program_id(1)
- t = tl.program_id(2)
+ t = tl.program_id(1)
- is_q = h_global < Hq
-
+ # --- Load position IDs once per (b, t) ---
pos_offs = b * T + t
pt = tl.load(pos_t_ptr + pos_offs)
ph = tl.load(pos_h_ptr + pos_offs)
@@ -47,32 +46,51 @@
d_offs = tl.arange(0, half_D)
- if is_q:
- base_x = b * stride_q_b + t * stride_q_t + h_global * stride_q_h
- base_out = b * stride_out_q_b + h_global * stride_out_q_h + t * stride_out_q_t
+ # --- Load all 6 cos/sin rows once (shared across all heads) ---
+ cos_t = tl.load(cos_cache_ptr + pt * D + d_offs)
+ sin_t = tl.load(sin_cache_ptr + pt * D + d_offs)
+ cos_h = tl.load(cos_cache_ptr + ph * D + d_offs)
+ sin_h = tl.load(sin_cache_ptr + ph * D + d_offs)
+ cos_w = tl.load(cos_cache_ptr + pw * D + d_offs)
+ sin_w = tl.load(sin_cache_ptr + pw * D + d_offs)
+
+ cos_t_2 = tl.load(cos_cache_ptr + pt * D + d_offs + half_D)
+ sin_t_2 = tl.load(sin_cache_ptr + pt * D + d_offs + half_D)
+ cos_h_2 = tl.load(cos_cache_ptr + ph * D + d_offs + half_D)
+ sin_h_2 = tl.load(sin_cache_ptr + ph * D + d_offs + half_D)
+ cos_w_2 = tl.load(cos_cache_ptr + pw * D + d_offs + half_D)
+ sin_w_2 = tl.load(sin_cache_ptr + pw * D + d_offs + half_D)
+
+ # --- Per-element axis mixing ---
+ axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
+ cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
+ sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
+ cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
+ sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))
+
+ # --- Process all q heads ---
+ for h in range(Hq_):
+ base_x = b * stride_q_b + t * stride_q_t + h * stride_q_h
x_first = tl.load(q_ptr + base_x + d_offs)
x_second = tl.load(q_ptr + base_x + d_offs + half_D)
- else:
- h_kv = h_global - Hq
- base_x = b * stride_k_b + t * stride_k_t + h_kv * stride_k_h
- base_out = b * stride_out_k_b + h_kv * stride_out_k_h + t * stride_out_k_t
+
+ out_first = x_first * cos_first + (-x_second) * sin_first
+ out_second = x_second * cos_second + x_first * sin_second
+
+ base_out = b * stride_out_q_b + h * stride_out_q_h + t * stride_out_q_t
+ tl.store(out_q_ptr + base_out + d_offs, out_first)
+ tl.store(out_q_ptr + base_out + d_offs + half_D, out_second)
+
+ # --- Process all k heads ---
+ for h in range(Hkv_):
+ base_x = b * stride_k_b + t * stride_k_t + h * stride_k_h
x_first = tl.load(k_ptr + base_x + d_offs)
x_second = tl.load(k_ptr + base_x + d_offs + half_D)
- pos_axis = tl.where(d_offs < s0, pt, tl.where(d_offs < s0 + s1, ph, pw))
+ out_first = x_first * cos_first + (-x_second) * sin_first
+ out_second = x_second * cos_second + x_first * sin_second
- cos_first = tl.load(cos_cache_ptr + pos_axis * D + d_offs)
- sin_first = tl.load(sin_cache_ptr + pos_axis * D + d_offs)
- cos_second = tl.load(cos_cache_ptr + pos_axis * D + d_offs + half_D)
- sin_second = tl.load(sin_cache_ptr + pos_axis * D + d_offs + half_D)
-
- out_first = x_first * cos_first + (-x_second) * sin_first
- out_second = x_second * cos_second + x_first * sin_second
-
- if is_q:
- tl.store(out_q_ptr + base_out + d_offs, out_first)
- tl.store(out_q_ptr + base_out + d_offs + half_D, out_second)
- else:
+ base_out = b * stride_out_k_b + h * stride_out_k_h + t * stride_out_k_t
tl.store(out_k_ptr + base_out + d_offs, out_first)
tl.store(out_k_ptr + base_out + d_offs + half_D, out_second)
@@ -108,10 +126,7 @@
s0, s1, s2 = self.mrope_section
half_D = D // 2
- # At least 2 warps for latency hiding (critical for D=64 where half_D=32)
- num_warps = max(half_D // 32, 2)
-
- grid = (B, Hq + Hkv, T)
+ grid = (B, T)
_mrope_kernel[grid](
q, k, out_q, out_k,
self.cos_cache, self.sin_cache,
@@ -120,10 +135,11 @@
k.stride(0), k.stride(1), k.stride(2),
out_q.stride(0), out_q.stride(1), out_q.stride(2),
out_k.stride(0), out_k.stride(1), out_k.stride(2),
- B, T, Hq, Hkv, D,
+ B, T, D,
s0, s1, s2,
half_D=half_D,
- num_warps=num_warps,
+ Hq_=Hq, Hkv_=Hkv,
+ num_warps=4,
)
return out_q, out_k
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<6140 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1",
"description": "Check correctness of head-grouping kernel",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run benchmark.py 2>&1",
"description": "Benchmark head-grouping kernel",
"timeout": 300000
}shape=0 variant=eager tflops=0.234 gbps=335.908 ms=0.269
shape=0 variant=compiled tflops=0.511 gbps=732.663 ms=0.123
shape=0 variant=solution tflops=0.904 gbps=1297.365 ms=0.070
shape=0 solution_peak_fraction=0.7208
shape=1 variant=eager tflops=0.207 gbps=301.728 ms=0.487
shape=1 variant=compiled tflops=0.677 gbps=988.204 ms=0.149
shape=1 variant=solution tflops=0.942 gbps=1375.760 ms=0.107
shape=1 solution_peak_fraction=0.7643
shape=2 variant=eager tflops=0.138 gbps=214.959 ms=0.103
shape=2 variant=compiled tflops=0.451 gbps=705.306 ms=0.031
shape=2 variant=solution tflops=0.497 gbps=777.066 ms=0.028
shape=2 solution_peak_fraction=0.4317
shape=3 variant=eager tflops=0.177 gbps=254.398 ms=1.419
shape=3 variant=compiled tflops=0.528 gbps=757.322 ms=0.477
shape=3 variant=solution tflops=0.990 gbps=1419.878 ms=0.254
shape=3 solution_peak_fraction=0.7888
peak_fraction: 0.6581
RESULT: OK
{
"command": "<767 chars \u2014 see below>",
"description": "ncu profile of new kernel",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run ncu --set basic -c 1 --kernel-name _mrope_kernel --page details python -c "
from solution import Model, MROPE_SECTION, MAX_POS
import torch
m = Model(1, 4096, 32, 8, 128, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(1, 4096, 32, 128, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
for _ in range(5): m(q,k,pos_t,pos_h,pos_w)
torch.cuda.synchronize()
" 2>&1 | head -80==PROF== Connected to process 368006 (/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/bin/python3.11)
==PROF== Profiling "_mrope_kernel": 0%....50%....100% - 9 passes
==PROF== Disconnected from process 368006
[368006] python3.11@127.0.0.1
_mrope_kernel (1, 4096, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 13.34
SM Frequency Ghz 2.51
Elapsed Cycles cycle 161,244
Memory Throughput % 79.19
DRAM Throughput % 79.19
Duration us 64.10
L1/TEX Cache Throughput % 35.05
L2 Cache Throughput % 19.63
SM Active Cycles cycle 87,273.41
Compute (SM) Throughput % 19.04
----------------------- ----------- ------------
OPT Memory is more heavily utilized than Compute: Look at the Memory Workload Analysis section to identify the
DRAM bottleneck. Check memory replay (coalescing) metrics to make sure you're efficiently utilizing the
bytes transferred. Also consider whether it is possible to do more work per memory access (kernel fusion) or
... (3937 more chars){
"command": "<2150 chars \u2014 see below>",
"description": "Sweep num_warps for head-grouping kernel",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch, time
import solution
from solution import Model, MROPE_SECTION, MAX_POS

# Sweep num_warps for the head-grouping kernel, launching it directly so the
# value hard-coded in Model.forward is bypassed.
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda().eval()

# Inputs in (B, T, H, D) order; creation order kept so the RNG stream matches.
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')

s0, s1, s2 = MROPE_SECTION
half_D = D // 2

# Outputs are allocated as (B, H, T, D).
out_q = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device='cuda')
out_k = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device='cuda')

# 128 MiB float32 scratch buffer, zeroed before every timed launch
# (presumably to evict the L2 cache between samples -- TODO confirm).
l2 = torch.empty(128 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')

grid = (B, T)
for nw in [2, 4, 8]:
    # Warm-up launches: trigger compilation for this num_warps before timing.
    for _ in range(10):
        solution._mrope_kernel[grid](
            q, k, out_q, out_k, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            out_q.stride(0), out_q.stride(1), out_q.stride(2),
            out_k.stride(0), out_k.stride(1), out_k.stride(2),
            B, T, D, s0, s1, s2,
            half_D=half_D, Hq_=Hq, Hkv_=Hkv, num_warps=nw,
        )
    torch.cuda.synchronize()

    times = []
    for _ in range(30):
        l2.zero_()
        torch.cuda.synchronize()
        start_evt = torch.cuda.Event(enable_timing=True)
        stop_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        solution._mrope_kernel[grid](
            q, k, out_q, out_k, m.cos_cache, m.sin_cache, pos_t, pos_h, pos_w,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            out_q.stride(0), out_q.stride(1), out_q.stride(2),
            out_k.stride(0), out_k.stride(1), out_k.stride(2),
            B, T, D, s0, s1, s2,
            half_D=half_D, Hq_=Hq, Hkv_=Hkv, num_warps=nw,
        )
        stop_evt.record()
        torch.cuda.synchronize()
        times.append(start_evt.elapsed_time(stop_evt))

    # Report the median sample (upper median for an even sample count).
    times.sort()
    med = times[len(times) // 2]
    print(f'num_warps={nw}: {med:.3f} ms')
" 2>&1num_warps=2: 0.060 ms
num_warps=4: 0.072 ms
num_warps=8: 0.093 ms
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -139,7 +139,7 @@
s0, s1, s2,
half_D=half_D,
Hq_=Hq, Hkv_=Hkv,
- num_warps=4,
+ num_warps=2,
)
return out_q, out_k
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " num_warps=4,",
"newString": " num_warps=2,"
}Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 && uv run benchmark.py 2>&1",
"description": "Check + benchmark with num_warps=2",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.199 gbps=286.190 ms=0.315
shape=0 variant=compiled tflops=0.563 gbps=808.228 ms=0.112
shape=0 variant=solution tflops=0.753 gbps=1080.268 ms=0.084
shape=0 solution_peak_fraction=0.6001
shape=1 variant=eager tflops=0.191 gbps=279.182 ms=0.527
shape=1 variant=compiled tflops=0.532 gbps=776.285 ms=0.189
shape=1 variant=solution tflops=0.794 gbps=1159.577 ms=0.127
shape=1 solution_peak_fraction=0.6442
shape=2 variant=eager tflops=0.094 gbps=146.658 ms=0.151
shape=2 variant=compiled tflops=0.222 gbps=347.337 ms=0.064
shape=2 variant=solution tflops=0.467 gbps=729.884 ms=0.030
shape=2 solution_peak_fraction=0.4055
shape=3 variant=eager tflops=0.175 gbps=251.689 ms=1.435
shape=3 variant=compiled tflops=0.483 gbps=693.022 ms=0.521
shape=3 variant=solution tflops=0.876 gbps=1257.534 ms=0.287
shape=3 solution_peak_fraction=0.6986
peak_fraction: 0.5753
RESULT: OK
{
"command": "<879 chars \u2014 see below>",
"description": "Quick ncu check for num_warps=2",
"timeout": 180000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run ncu --set basic -c 1 --kernel-name _mrope_kernel --page details python -c "
from solution import Model, MROPE_SECTION, MAX_POS
import torch
m = Model(1, 4096, 32, 8, 128, MROPE_SECTION, MAX_POS).cuda().eval()
q = torch.randn(1, 4096, 32, 128, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (1, 4096), dtype=torch.int64, device='cuda')
for _ in range(5): m(q,k,pos_t,pos_h,pos_w)
torch.cuda.synchronize()
" 2>&1 | grep -E 'Waves|Partial|Tail|Duration|Memory Throughput|DRAM Throughput|Occupancy|Registers|Block Limit|Warp|num_warps=2'Memory Throughput % 87.43
DRAM Throughput % 87.43
Duration us 31.01
Registers Per Thread register/thread 30
Waves Per SM 0.91
Section: Occupancy
Overall GPU Occupancy % 0
Cluster Occupancy % 0
Block Limit Barriers block 24
Block Limit SM block 24
Block Limit Registers block 32
Block Limit Shared Mem block 32
Block Limit Warps block 24
Theoretical Active Warps per SM warp 48
Theoretical Occupancy % 100
Achieved Occupancy % 87.14
Achieved Active Warps Per SM warp 41.83
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -139,7 +139,7 @@
s0, s1, s2,
half_D=half_D,
Hq_=Hq, Hkv_=Hkv,
- num_warps=2,
+ num_warps=4,
)
return out_q, out_k
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " num_warps=2,",
"newString": " num_warps=4,"
}Edit applied successfully.{
"command": "<5766 chars \u2014 see below>",
"description": "Compare in-loop vs hoisted axis selection",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && cat > /tmp/test_inline.py << 'PYEOF'
import torch, triton, triton.language as tl, time
# V1: axis mixing inside head loop (saves regs)
@triton.jit
def _mrope_inline(
    q_ptr, k_ptr, out_q_ptr, out_k_ptr, cos_ptr, sin_ptr,
    pt_ptr, ph_ptr, pw_ptr,
    sq_b, sq_t, sq_h, sk_b, sk_t, sk_h,
    soq_b, soq_h, soq_t, sok_b, sok_h, sok_t,
    B, T, D, s0, s1, s2,
    half_D: tl.constexpr, Hq_: tl.constexpr, Hkv_: tl.constexpr,
):
    """Head-grouping variant that redoes the per-lane axis select for every head.

    One program handles one (batch, time) pair (program ids 0 and 1). It loads
    the cos/sin rows for the three position ids once, then loops over the
    ``Hq_`` q heads and the ``Hkv_`` k heads. The ``tl.where`` selection that
    mixes the three rows is deliberately kept inside both loops; this is the
    variant being compared against hoisting it.
    """
    batch = tl.program_id(0)
    step = tl.program_id(1)

    # One position id per axis for this (batch, time) slot.
    slot = batch * T + step
    id_t = tl.load(pt_ptr + slot)
    id_h = tl.load(ph_ptr + slot)
    id_w = tl.load(pw_ptr + slot)

    lane = tl.arange(0, half_D)

    # Low-half cos/sin rows for each axis.
    cos_t_lo = tl.load(cos_ptr + id_t * D + lane)
    sin_t_lo = tl.load(sin_ptr + id_t * D + lane)
    cos_h_lo = tl.load(cos_ptr + id_h * D + lane)
    sin_h_lo = tl.load(sin_ptr + id_h * D + lane)
    cos_w_lo = tl.load(cos_ptr + id_w * D + lane)
    sin_w_lo = tl.load(sin_ptr + id_w * D + lane)

    # High-half cos/sin rows for each axis.
    cos_t_hi = tl.load(cos_ptr + id_t * D + lane + half_D)
    sin_t_hi = tl.load(sin_ptr + id_t * D + lane + half_D)
    cos_h_hi = tl.load(cos_ptr + id_h * D + lane + half_D)
    sin_h_hi = tl.load(sin_ptr + id_h * D + lane + half_D)
    cos_w_hi = tl.load(cos_ptr + id_w * D + lane + half_D)
    sin_w_hi = tl.load(sin_ptr + id_w * D + lane + half_D)

    # Lane masks: [0, s0) uses the t row, [s0, s0 + s1) the h row, the rest w.
    use_t = lane < s0
    use_h = (lane >= s0) & (lane < (s0 + s1))

    for h in range(Hq_):
        src = batch * sq_b + step * sq_t + h * sq_h
        dst = batch * soq_b + h * soq_h + step * soq_t
        lo = tl.load(q_ptr + src + lane)
        hi = tl.load(q_ptr + src + lane + half_D)
        cos_lo = tl.where(use_t, cos_t_lo, tl.where(use_h, cos_h_lo, cos_w_lo))
        sin_lo = tl.where(use_t, sin_t_lo, tl.where(use_h, sin_h_lo, sin_w_lo))
        cos_hi = tl.where(use_t, cos_t_hi, tl.where(use_h, cos_h_hi, cos_w_hi))
        sin_hi = tl.where(use_t, sin_t_hi, tl.where(use_h, sin_h_hi, sin_w_hi))
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_q_ptr + dst + lane, out_lo)
        tl.store(out_q_ptr + dst + lane + half_D, out_hi)

    for h in range(Hkv_):
        src = batch * sk_b + step * sk_t + h * sk_h
        dst = batch * sok_b + h * sok_h + step * sok_t
        lo = tl.load(k_ptr + src + lane)
        hi = tl.load(k_ptr + src + lane + half_D)
        cos_lo = tl.where(use_t, cos_t_lo, tl.where(use_h, cos_h_lo, cos_w_lo))
        sin_lo = tl.where(use_t, sin_t_lo, tl.where(use_h, sin_h_lo, sin_w_lo))
        cos_hi = tl.where(use_t, cos_t_hi, tl.where(use_h, cos_h_hi, cos_w_hi))
        sin_hi = tl.where(use_t, sin_t_hi, tl.where(use_h, sin_h_hi, sin_w_hi))
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_k_ptr + dst + lane, out_lo)
        tl.store(out_k_ptr + dst + lane + half_D, out_hi)
# V2: axis mixing outside head loop (current approach)
@triton.jit
def _mrope_outside(
    q_ptr, k_ptr, out_q_ptr, out_k_ptr, cos_ptr, sin_ptr,
    pt_ptr, ph_ptr, pw_ptr,
    sq_b, sq_t, sq_h, sk_b, sk_t, sk_h,
    soq_b, soq_h, soq_t, sok_b, sok_h, sok_t,
    B, T, D, s0, s1, s2,
    half_D: tl.constexpr, Hq_: tl.constexpr, Hkv_: tl.constexpr,
):
    """Head-grouping variant with the per-lane axis select hoisted out of the loops.

    One program handles one (batch, time) pair (program ids 0 and 1). It loads
    the cos/sin rows for the three position ids, mixes them per lane once, and
    reuses the mixed vectors for all ``Hq_`` q heads and ``Hkv_`` k heads.
    """
    batch = tl.program_id(0)
    step = tl.program_id(1)

    # One position id per axis for this (batch, time) slot.
    slot = batch * T + step
    id_t = tl.load(pt_ptr + slot)
    id_h = tl.load(ph_ptr + slot)
    id_w = tl.load(pw_ptr + slot)

    lane = tl.arange(0, half_D)

    # Low-half cos/sin rows for each axis.
    cos_t_lo = tl.load(cos_ptr + id_t * D + lane)
    sin_t_lo = tl.load(sin_ptr + id_t * D + lane)
    cos_h_lo = tl.load(cos_ptr + id_h * D + lane)
    sin_h_lo = tl.load(sin_ptr + id_h * D + lane)
    cos_w_lo = tl.load(cos_ptr + id_w * D + lane)
    sin_w_lo = tl.load(sin_ptr + id_w * D + lane)

    # High-half cos/sin rows for each axis.
    cos_t_hi = tl.load(cos_ptr + id_t * D + lane + half_D)
    sin_t_hi = tl.load(sin_ptr + id_t * D + lane + half_D)
    cos_h_hi = tl.load(cos_ptr + id_h * D + lane + half_D)
    sin_h_hi = tl.load(sin_ptr + id_h * D + lane + half_D)
    cos_w_hi = tl.load(cos_ptr + id_w * D + lane + half_D)
    sin_w_hi = tl.load(sin_ptr + id_w * D + lane + half_D)

    # Lane masks: [0, s0) uses the t row, [s0, s0 + s1) the h row, the rest w.
    use_t = lane < s0
    use_h = (lane >= s0) & (lane < (s0 + s1))

    # Mixed cos/sin vectors, computed once and shared by every head.
    cos_lo = tl.where(use_t, cos_t_lo, tl.where(use_h, cos_h_lo, cos_w_lo))
    sin_lo = tl.where(use_t, sin_t_lo, tl.where(use_h, sin_h_lo, sin_w_lo))
    cos_hi = tl.where(use_t, cos_t_hi, tl.where(use_h, cos_h_hi, cos_w_hi))
    sin_hi = tl.where(use_t, sin_t_hi, tl.where(use_h, sin_h_hi, sin_w_hi))

    for h in range(Hq_):
        src = batch * sq_b + step * sq_t + h * sq_h
        dst = batch * soq_b + h * soq_h + step * soq_t
        lo = tl.load(q_ptr + src + lane)
        hi = tl.load(q_ptr + src + lane + half_D)
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_q_ptr + dst + lane, out_lo)
        tl.store(out_q_ptr + dst + lane + half_D, out_hi)

    for h in range(Hkv_):
        src = batch * sk_b + step * sk_t + h * sk_h
        dst = batch * sok_b + h * sok_h + step * sok_t
        lo = tl.load(k_ptr + src + lane)
        hi = tl.load(k_ptr + src + lane + half_D)
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_k_ptr + dst + lane, out_lo)
        tl.store(out_k_ptr + dst + lane + half_D, out_hi)
B,T,Hq,Hkv,D = 1,4096,32,8,128
s0,s1,s2 = 16,24,24
hD = D//2
q = torch.randn(B,T,Hq,D, dtype=torch.bfloat16, device='cuda
... (1766 more chars)inline (in-loop axis): 0.071 ms
outside (hoisted axis): 0.072 ms
{
"command": "<3697 chars \u2014 see below>",
"description": "Test with memory hints",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && cp solution.py solution.py.bak && cat > /tmp/test_hints.py << 'PYEOF'
import torch, triton, triton.language as tl, time
# V3: With tl.max_contiguous / tl.multiple_of hints (NOTE: the kernel below does not call either hint yet)
@triton.jit
def _mrope_hints(
    q_ptr, k_ptr, out_q_ptr, out_k_ptr, cos_ptr, sin_ptr,
    pt_ptr, ph_ptr, pw_ptr,
    sq_b, sq_t, sq_h, sk_b, sk_t, sk_h,
    soq_b, soq_h, soq_t, sok_b, sok_h, sok_t,
    B, T, D, s0, s1, s2,
    half_D: tl.constexpr, Hq_: tl.constexpr, Hkv_: tl.constexpr,
):
    """Head-grouping kernel used for the memory-hint experiment.

    One program handles one (batch, time) pair (program ids 0 and 1). It loads
    the cos/sin rows for the three position ids, mixes them per lane once, and
    applies the rotation to all ``Hq_`` q heads and ``Hkv_`` k heads.

    NOTE(review): this body contains no ``tl.max_contiguous`` or
    ``tl.multiple_of`` call, so as written it measures the hoisted-select
    kernel without any extra compiler hints.
    """
    batch = tl.program_id(0)
    step = tl.program_id(1)

    # One position id per axis for this (batch, time) slot.
    slot = batch * T + step
    id_t = tl.load(pt_ptr + slot)
    id_h = tl.load(ph_ptr + slot)
    id_w = tl.load(pw_ptr + slot)

    lane = tl.arange(0, half_D)

    # Low-half cos/sin rows for each axis.
    cos_t_lo = tl.load(cos_ptr + id_t * D + lane)
    sin_t_lo = tl.load(sin_ptr + id_t * D + lane)
    cos_h_lo = tl.load(cos_ptr + id_h * D + lane)
    sin_h_lo = tl.load(sin_ptr + id_h * D + lane)
    cos_w_lo = tl.load(cos_ptr + id_w * D + lane)
    sin_w_lo = tl.load(sin_ptr + id_w * D + lane)

    # High-half cos/sin rows for each axis.
    cos_t_hi = tl.load(cos_ptr + id_t * D + lane + half_D)
    sin_t_hi = tl.load(sin_ptr + id_t * D + lane + half_D)
    cos_h_hi = tl.load(cos_ptr + id_h * D + lane + half_D)
    sin_h_hi = tl.load(sin_ptr + id_h * D + lane + half_D)
    cos_w_hi = tl.load(cos_ptr + id_w * D + lane + half_D)
    sin_w_hi = tl.load(sin_ptr + id_w * D + lane + half_D)

    # Lane masks: [0, s0) uses the t row, [s0, s0 + s1) the h row, the rest w.
    use_t = lane < s0
    use_h = (lane >= s0) & (lane < (s0 + s1))

    # Mixed cos/sin vectors, computed once and shared by every head.
    cos_lo = tl.where(use_t, cos_t_lo, tl.where(use_h, cos_h_lo, cos_w_lo))
    sin_lo = tl.where(use_t, sin_t_lo, tl.where(use_h, sin_h_lo, sin_w_lo))
    cos_hi = tl.where(use_t, cos_t_hi, tl.where(use_h, cos_h_hi, cos_w_hi))
    sin_hi = tl.where(use_t, sin_t_hi, tl.where(use_h, sin_h_hi, sin_w_hi))

    for h in range(Hq_):
        src = batch * sq_b + step * sq_t + h * sq_h
        dst = batch * soq_b + h * soq_h + step * soq_t
        lo = tl.load(q_ptr + src + lane)
        hi = tl.load(q_ptr + src + lane + half_D)
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_q_ptr + dst + lane, out_lo)
        tl.store(out_q_ptr + dst + lane + half_D, out_hi)

    for h in range(Hkv_):
        src = batch * sk_b + step * sk_t + h * sk_h
        dst = batch * sok_b + h * sok_h + step * sok_t
        lo = tl.load(k_ptr + src + lane)
        hi = tl.load(k_ptr + src + lane + half_D)
        out_lo = lo * cos_lo + (-hi) * sin_lo
        out_hi = hi * cos_hi + lo * sin_hi
        tl.store(out_k_ptr + dst + lane, out_lo)
        tl.store(out_k_ptr + dst + lane + half_D, out_hi)
# Benchmark driver for _mrope_hints.
# NOTE(review): this transcript shows the script with its indentation removed;
# the loop nesting described in the comments below is inferred - confirm
# against the original file before reusing.
# Problem size: batch, tokens, query heads, key/value heads, head dim.
B,T,Hq,Hkv,D = 1,4096,32,8,128
# Section sizes passed to the kernel as s0/s1/s2 (they sum to hD = 64);
# hD is half the head dim and becomes the kernel's half_D constexpr.
s0,s1,s2 = 16,24,24; hD=D//2
# Random bf16 inputs scaled by 0.5, laid out (B, T, heads, D).
q=torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device='cuda')*0.5
k=torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device='cuda')*0.5
# One random int64 position id in [0, 32768) per (batch, token) for each axis.
pt=torch.randint(0,32768,(B,T),dtype=torch.int64,device='cuda')
ph=torch.randint(0,32768,(B,T),dtype=torch.int64,device='cuda')
pw=torch.randint(0,32768,(B,T),dtype=torch.int64,device='cuda')
# Stand-in cos/sin tables: 32768 rows of D values drawn uniformly from
# (-1, 1). They are not real cos/sin values; only memory traffic is measured.
cc=torch.empty(32768,D,dtype=torch.bfloat16,device='cuda').uniform_(-1,1)
sc=torch.empty(32768,D,dtype=torch.bfloat16,device='cuda').uniform_(-1,1)
# Output buffers laid out (B, heads, T, D).
oq=torch.empty(B,Hq,T,D,dtype=torch.bfloat16,device='cuda')
ok=torch.empty(B,Hkv,T,D,dtype=torch.bfloat16,device='cuda')
# 128 MiB float32 scratch buffer, zeroed before every timed launch -
# presumably to evict the kernel's data from the L2 cache between runs.
l2=torch.empty(128*1024*1024//4,dtype=torch.float32,device='cuda')
# Sweep the num_warps launch option.
for nw in [2,4]:
# Five untimed launches first (the first one also triggers JIT compilation
# for this num_warps value).
for _ in range(5):
# Grid is (B, T): one program per (batch, token).
_mrope_hints[(B,T)](q,k,oq,ok,cc,sc,pt,ph,pw,
q.stride(0),q.stride(1),q.stride(2),k.stride(0),k.stride(1),k.stride(2),
oq.stride(0),oq.stride(1),oq.stride(2),ok.stride(0),ok.stride(1),ok.stride(2),
B,T,D,s0,s1,s2,hD,Hq_=Hq,Hkv_=Hkv,num_warps=nw)
torch.cuda.synchronize()
# Per-launch timings in milliseconds for the current num_warps value.
times=[]
for _ in range(30):
# Finish the scratch-buffer write before the start event is recorded so
# it is not included in the measured interval.
l2.zero_(); torch.cuda.synchronize()
s=torch.cuda.Event(enable_timing=True); e=torch.cuda.Event(enable_timing=True)
s.record()
_mrope_hints[(B,T)](q,k,oq,ok,cc,sc,pt,ph,pw,
q.stride(0),q.stride(1),q.stride(2),k.stride(0),k.stride(1),k.stride(2),
oq.stride(0),oq.stride(1),oq.stride(2),ok.stride(0),ok.stride(1),ok.stride(2),
B,T,D,s0,s1,s2,hD,Hq_=Hq,Hkv_=Hkv,num_warps=nw)
# elapsed_time needs both events completed, hence the synchronize.
e.record(); torch.cuda.synchronize()
times.append(s.elapsed_time(e))
# Element at index len // 2 of the sorted samples (the upper median when
# the sample count is even, as it is here with 30).
med=sorted(times)[len(times)//2]
print(f'hints kernel num_warps={nw}: {med:.3f} ms')
PYEOF
uv run python /tmp/test_hints.py 2>&1hints kernel num_warps=2: 0.059 ms
hints kernel num_warps=4: 0.072 ms
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Multi-axis RoPE pre-attention prep fused kernel (Triton).
2:
3: Fuses gather + per-axis mixing + rotate-half + transpose into a single
4: kernel with direct (B, H, T, D) output.
5:
6: Key optimization:
7: - Group by (batch, time): one program loads all 6 cos/sin rows once and
8: processes ALL heads at that (b,t). Eliminates H_total-way redundant
9: cos/sin cache reads — the main bandwidth bottleneck.
10: - Process head_dim in two halves to avoid double-loading q/k.
11: - All rotation ops are element-local, no cross-thread communication.
12: """
13: import torch
14: import torch.nn as nn
15: import triton
16: import triton.language as tl
17:
18: OP_TYPE = "rope"
19: SUPPORTED_PRECISIONS = ["bf16"]
20: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
21:
22:
23: @triton.jit
24: def _mrope_kernel(
25: q_ptr, k_ptr, out_q_ptr, out_k_ptr,
26: cos_cache_ptr, sin_cache_ptr,
27: pos_t_ptr, pos_h_ptr, pos_w_ptr,
28: stride_q_b, stride_q_t, stride_q_h,
29: stride_k_b, stride_k_t, stride_k_h,
30: stride_out_q_b, stride_out_q_h, stride_out_q_t,
31: stride_out_k_b, stride_out_k_h, stride_out_k_t,
32: B, T, D,
33: s0, s1, s2,
34: half_D: tl.constexpr,
35: Hq_: tl.constexpr,
36: Hkv_: tl.constexpr,
37: ):
38: b = tl.program_id(0)
39: t = tl.program_id(1)
40:
41: # --- Load position IDs once per (b, t) ---
42
... (5537 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -62,11 +62,12 @@
sin_w_2 = tl.load(sin_cache_ptr + pw * D + d_offs + half_D)
# --- Per-element axis mixing ---
- axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
- cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
- sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
- cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
- sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))
+ is_t = d_offs < s0
+ is_h = (d_offs < s0 + s1) & (d_offs >= s0)
+ cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
+ sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2))
# --- Process all q heads ---
for h in range(Hq_):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<433 chars \u2014 see below>",
"newString": "<393 chars \u2014 see below>"
} # --- Per-element axis mixing ---
axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2)) # --- Per-element axis mixing ---
is_t = d_offs < s0
is_h = (d_offs < s0 + s1) & (d_offs >= s0)
cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2))Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 && for nw in 2 4; do sed -i \"s/num_warps=[0-9]\\+/num_warps=$nw/\" solution.py && uv run benchmark.py 2>&1 | grep -E \"shape=|solution_peak|peak_fraction\"; done",
"description": "Benchmark both num_warps=2 and num_warps=4",
"timeout": 600000
}PASS
shape=0 variant=eager tflops=0.200 gbps=286.641 ms=0.315
shape=0 variant=compiled tflops=0.569 gbps=817.006 ms=0.110
shape=0 variant=solution tflops=0.756 gbps=1084.629 ms=0.083
shape=0 solution_peak_fraction=0.6026
shape=1 variant=eager tflops=0.191 gbps=279.394 ms=0.526
shape=1 variant=compiled tflops=0.530 gbps=773.865 ms=0.190
shape=1 variant=solution tflops=0.797 gbps=1164.427 ms=0.126
shape=1 solution_peak_fraction=0.6469
shape=2 variant=eager tflops=0.097 gbps=151.513 ms=0.146
shape=2 variant=compiled tflops=0.226 gbps=352.743 ms=0.063
shape=2 variant=solution tflops=0.450 gbps=703.511 ms=0.031
shape=2 solution_peak_fraction=0.3908
shape=3 variant=eager tflops=0.178 gbps=254.717 ms=1.418
shape=3 variant=compiled tflops=0.485 gbps=696.637 ms=0.518
shape=3 variant=solution tflops=0.878 gbps=1259.148 ms=0.287
shape=3 solution_peak_fraction=0.6995
peak_fraction: 0.5714
shape=0 variant=eager tflops=0.200 gbps=287.503 ms=0.314
shape=0 variant=compiled tflops=0.568 gbps=814.529 ms=0.111
shape=0 variant=solution tflops=0.683 gbps=980.407 ms=0.092
shape=0 solution_peak_fraction=0.5447
shape=1 variant=eager tflops=0.191 gbps=279.004 ms=0.527
shape=1 variant=compiled tflops=0.530 gbps=773.539 ms=0.190
shape=1 variant=solution tflops=0.766 gbps=1118.224 ms=0.131
shape=1 solution_peak_fraction=0.6212
shape=2 variant=eager tflops=0.097 gbps=151.729 ms=0.146
shape=2 variant=compiled tflops=0.227 gbps=354.098 ms=0.062
shape=2 variant=solution tflops=0.477 gbps=745.229 ms=0.030
sh
... (274 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -62,12 +62,11 @@
sin_w_2 = tl.load(sin_cache_ptr + pw * D + d_offs + half_D)
# --- Per-element axis mixing ---
- is_t = d_offs < s0
- is_h = (d_offs < s0 + s1) & (d_offs >= s0)
- cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
- sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2))
+ axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
+ cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
+ sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
+ cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
+ sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))
# --- Process all q heads ---
for h in range(Hq_):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<393 chars \u2014 see below>",
"newString": "<433 chars \u2014 see below>"
} # --- Per-element axis mixing ---
is_t = d_offs < s0
is_h = (d_offs < s0 + s1) & (d_offs >= s0)
cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2)) # --- Per-element axis mixing ---
axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1 +1 @@
- num_warps=2,+ num_warps=4,{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " num_warps=2,",
"newString": " num_warps=4,"
}{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && grep -n \"num_warps\" solution.py",
"description": "Check current num_warps value"
}142: num_warps=4,
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 && uv run benchmark.py 2>&1",
"description": "Final benchmark with restored axis code + num_warps=4",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.200 gbps=286.947 ms=0.315
shape=0 variant=compiled tflops=0.566 gbps=812.652 ms=0.111
shape=0 variant=solution tflops=0.681 gbps=977.011 ms=0.092
shape=0 solution_peak_fraction=0.5428
shape=1 variant=eager tflops=0.191 gbps=279.641 ms=0.526
shape=1 variant=compiled tflops=0.534 gbps=779.380 ms=0.189
shape=1 variant=solution tflops=0.766 gbps=1117.952 ms=0.131
shape=1 solution_peak_fraction=0.6211
shape=2 variant=eager tflops=0.097 gbps=151.181 ms=0.146
shape=2 variant=compiled tflops=0.231 gbps=361.034 ms=0.061
shape=2 variant=solution tflops=0.484 gbps=756.650 ms=0.029
shape=2 solution_peak_fraction=0.4204
shape=3 variant=eager tflops=0.175 gbps=251.754 ms=1.434
shape=3 variant=compiled tflops=0.486 gbps=696.766 ms=0.518
shape=3 variant=solution tflops=0.861 gbps=1235.031 ms=0.292
shape=3 solution_peak_fraction=0.6861
peak_fraction: 0.5584
RESULT: OK
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && for i in 1 2 3; do echo \"=== Run $i ===\" && uv run benchmark.py 2>&1 | grep -E \"shape=.*solution|peak_fraction\"; done",
"description": "Run benchmark 3 times for consistency",
"timeout": 600000
}=== Run 1 ===
shape=0 variant=solution tflops=0.683 gbps=979.726 ms=0.092
shape=0 solution_peak_fraction=0.5443
shape=1 variant=solution tflops=0.763 gbps=1114.561 ms=0.132
shape=1 solution_peak_fraction=0.6192
shape=2 variant=solution tflops=0.479 gbps=748.052 ms=0.030
shape=2 solution_peak_fraction=0.4156
shape=3 variant=solution tflops=0.863 gbps=1238.828 ms=0.291
shape=3 solution_peak_fraction=0.6882
peak_fraction: 0.5572
=== Run 2 ===
shape=0 variant=solution tflops=0.678 gbps=972.465 ms=0.093
shape=0 solution_peak_fraction=0.5403
shape=1 variant=solution tflops=0.765 gbps=1116.593 ms=0.132
shape=1 solution_peak_fraction=0.6203
shape=2 variant=solution tflops=0.473 gbps=739.647 ms=0.030
shape=2 solution_peak_fraction=0.4109
shape=3 variant=solution tflops=0.864 gbps=1239.848 ms=0.291
shape=3 solution_peak_fraction=0.6888
peak_fraction: 0.5550
=== Run 3 ===
shape=0 variant=solution tflops=0.680 gbps=975.997 ms=0.092
shape=0 solution_peak_fraction=0.5422
shape=1 variant=solution tflops=0.765 gbps=1116.729 ms=0.132
shape=1 solution_peak_fraction=0.6204
shape=2 variant=solution tflops=0.472 gbps=736.887 ms=0.030
shape=2 solution_peak_fraction=0.4094
shape=3 variant=solution tflops=0.863 gbps=1238.556 ms=0.292
shape=3 solution_peak_fraction=0.6881
peak_fraction: 0.5548
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && git log --oneline -5 -- solution.py 2>/dev/null; echo \"---\"; git diff HEAD -- solution.py 2>/dev/null | head -80",
"description": "Check git history for solution.py"
}---
{
"command": "diff /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py.bak /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py 2>/dev/null || echo \"no backup\"",
"description": "Compare with backup"
}(no output){
"command": "nvidia-smi --query-gpu=temperature.gpu,clocks.current.sm,clocks.current.memory,power.draw --format=csv,noheader 2>&1",
"description": "Check GPU state"
}29, 180 MHz, 405 MHz, 6.85 W
{
"command": "<1364 chars \u2014 see below>",
"description": "Manual benchmark to verify",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
# Manual timing of solution.Model as a cross-check of the benchmark numbers.
# NOTE(review): indentation is missing in this transcript; loop nesting is
# inferred from the statements themselves - confirm before reusing.
from solution import Model, MROPE_SECTION, MAX_POS
import torch
# Constructor arguments mirror the tensor sizes built below
# (1, 4096, 32, 8, 128); their exact meaning is defined in solution.py.
m = Model(1, 4096, 32, 8, 128, MROPE_SECTION, MAX_POS).cuda().eval()
# Random bf16 inputs scaled by 0.5 and random int64 position ids in [0, 32768).
q = torch.randn(1, 4096, 32, 128, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, 32768, (1, 4096), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, 32768, (1, 4096), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, 32768, (1, 4096), dtype=torch.int64, device='cuda')
# 128 MiB float32 scratch buffer; zeroing it is the L2 flush used below.
l2 = torch.empty(128*1024*1024//4, dtype=torch.float32, device='cuda')
# Warmup: ten untimed forward calls.
for _ in range(10):
m(q,k,pos_t,pos_h,pos_w)
torch.cuda.synchronize()
# Measured with L2 flush
# NOTE(review): time is imported but never used in this script.
import time
# Per-call timings in milliseconds.
times = []
for _ in range(30):
# Flush, then wait so the flush is not part of the timed interval.
l2.zero_()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
m(q,k,pos_t,pos_h,pos_w)
e.record()
# elapsed_time needs both events completed.
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
# Element at index len // 2 of the sorted samples (upper median of 30).
med = sorted(times)[len(times)//2]
print(f'Manual benchmark (with L2 flush): {med:.3f} ms')
# Estimated bytes moved per call. Presumably: 4 bytes per q/k element
# (2-byte bf16 read plus 2-byte write) over 32 + 8 = 40 heads; 12 bytes per
# head-dim element per token (six 2-byte cos/sin rows); 24 bytes per token
# (three int64 position ids) - TODO confirm against benchmark.py.
B_bytes = 4 * 1 * 4096 * 40 * 128 + 12 * 1 * 4096 * 128 + 24 * 1 * 4096
# bytes / ms / 1e6 equals GB/s.
gbps = B_bytes / med / 1e6
print(f'GB/s: {gbps:.1f}')
# 1800 is presumably the assumed peak memory bandwidth in GB/s - verify.
print(f'Peak fraction: {gbps/1800*100:.1f}%')
" 2>&1Manual benchmark (with L2 flush): 0.088 ms
GB/s: 1030.4
Peak fraction: 57.2%
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"offset": 55,
"limit": 45
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
55: sin_w = tl.load(sin_cache_ptr + pw * D + d_offs)
56:
57: cos_t_2 = tl.load(cos_cache_ptr + pt * D + d_offs + half_D)
58: sin_t_2 = tl.load(sin_cache_ptr + pt * D + d_offs + half_D)
59: cos_h_2 = tl.load(cos_cache_ptr + ph * D + d_offs + half_D)
60: sin_h_2 = tl.load(sin_cache_ptr + ph * D + d_offs + half_D)
61: cos_w_2 = tl.load(cos_cache_ptr + pw * D + d_offs + half_D)
62: sin_w_2 = tl.load(sin_cache_ptr + pw * D + d_offs + half_D)
63:
64: # --- Per-element axis mixing ---
65: axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
66: cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
67: sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
68: cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
69: sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))
70:
71: # --- Process all q heads ---
72: for h in range(Hq_):
73: base_x = b * stride_q_b + t * stride_q_t + h * stride_q_h
74: x_first = tl.load(q_ptr + base_x + d_offs)
75: x_second = tl.load(q_ptr + base_x + d_offs + half_D)
76:
77: out_first = x_first * cos_first + (-x_second) * sin_first
78: out_second = x_second * cos_second + x_first * sin_second
79:
80: base_out = b * stride
... (997 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -62,11 +62,12 @@
sin_w_2 = tl.load(sin_cache_ptr + pw * D + d_offs + half_D)
# --- Per-element axis mixing ---
- axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
- cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
- sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
- cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
- sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2))
+ is_t = d_offs < s0
+ is_h = (d_offs >= s0) & (d_offs < s0 + s1)
+ cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
+ sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2))
# --- Process all q heads ---
for h in range(Hq_):
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<433 chars \u2014 see below>",
"newString": "<393 chars \u2014 see below>"
} # --- Per-element axis mixing ---
axis = tl.where(d_offs < s0, 0, tl.where(d_offs < s0 + s1, 1, 2))
cos_first = tl.where(axis == 0, cos_t, tl.where(axis == 1, cos_h, cos_w))
sin_first = tl.where(axis == 0, sin_t, tl.where(axis == 1, sin_h, sin_w))
cos_second = tl.where(axis == 0, cos_t_2, tl.where(axis == 1, cos_h_2, cos_w_2))
sin_second = tl.where(axis == 0, sin_t_2, tl.where(axis == 1, sin_h_2, sin_w_2)) # --- Per-element axis mixing ---
is_t = d_offs < s0
is_h = (d_offs >= s0) & (d_offs < s0 + s1)
cos_first = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_first = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
cos_second = tl.where(is_t, cos_t_2, tl.where(is_h, cos_h_2, cos_w_2))
sin_second = tl.where(is_t, sin_t_2, tl.where(is_h, sin_h_2, sin_w_2))Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -140,7 +140,7 @@
s0, s1, s2,
half_D=half_D,
Hq_=Hq, Hkv_=Hkv,
- num_warps=4,
+ num_warps=2,
)
return out_q, out_k
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " num_warps=4,",
"newString": " num_warps=2,"
}Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run check.py 2>&1 && uv run benchmark.py 2>&1",
"description": "Final benchmark with boolean+num_warps=2",
"timeout": 300000
}PASS
shape=0 variant=eager tflops=0.200 gbps=287.605 ms=0.314
shape=0 variant=compiled tflops=0.565 gbps=811.249 ms=0.111
shape=0 variant=solution tflops=0.754 gbps=1081.510 ms=0.083
shape=0 solution_peak_fraction=0.6008
shape=1 variant=eager tflops=0.191 gbps=279.437 ms=0.526
shape=1 variant=compiled tflops=0.526 gbps=767.979 ms=0.191
shape=1 variant=solution tflops=0.798 gbps=1165.165 ms=0.126
shape=1 solution_peak_fraction=0.6473
shape=2 variant=eager tflops=0.096 gbps=150.769 ms=0.147
shape=2 variant=compiled tflops=0.226 gbps=353.194 ms=0.063
shape=2 variant=solution tflops=0.454 gbps=708.923 ms=0.031
shape=2 solution_peak_fraction=0.3938
shape=3 variant=eager tflops=0.177 gbps=254.536 ms=1.419
shape=3 variant=compiled tflops=0.486 gbps=696.939 ms=0.518
shape=3 variant=solution tflops=0.877 gbps=1258.867 ms=0.287
shape=3 solution_peak_fraction=0.6994
peak_fraction: 0.5721
RESULT: OK
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "completed"
},
{
"priority": "high",
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "completed"
},
{
"priority": "high",
"content": "Check correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Profile and optimize iteratively",
"status": "completed"
},
{
"priority": "high",
"conte
... (72 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'completed'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'completed'
priority: 'high'
content: 'Benchmark and finalize'
status: 'in_progress'
priority: 'high'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run ruff check problems/09_fmha_preattn_mrope/solution.py --fix 2>&1",
"description": "Lint solution.py"
}All checks passed!
{
"command": "rm -f /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py.bak /tmp/test_nw.py /tmp/test_inline.py /tmp/test_hints.py",
"description": "Clean up temp files"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Multi-axis RoPE pre-attention prep fused kernel (Triton).
2:
3: Fuses gather + per-axis mixing + rotate-half + transpose into a single
4: kernel with direct (B, H, T, D) output.
5:
6: Key optimization:
7: - Group by (batch, time): one program loads all 6 cos/sin rows once and
8: processes ALL heads at that (b,t). Eliminates H_total-way redundant
9: cos/sin cache reads — the main bandwidth bottleneck.
10: - Process head_dim in two halves to avoid double-loading q/k.
11: - All rotation ops are element-local, no cross-thread communication.
12: """
13: import torch
14: import torch.nn as nn
15: import triton
16: import triton.language as tl
17:
18: OP_TYPE = "rope"
19: SUPPORTED_PRECISIONS = ["bf16"]
20: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
21:
22:
23: @triton.jit
24: def _mrope_kernel(
25: q_ptr, k_ptr, out_q_ptr, out_k_ptr,
26: cos_cache_ptr, sin_cache_ptr,
27: pos_t_ptr, pos_h_ptr, pos_w_ptr,
28: stride_q_b, stride_q_t, stride_q_h,
29: stride_k_b, stride_k_t, stride_k_h,
30: stride_out_q_b, stride_out_q_h, stride_out_q_t,
31: stride_out_k_b, stride_out_k_h, stride_out_k_t,
32: B, T, D,
33: s0, s1, s2,
34: half_D: tl.constexpr,
35: Hq_: tl.constexpr,
36: Hkv_: tl.constexpr,
37: ):
38: b = tl.program_id(0)
39: t = tl.program_id(1)
40:
41: # --- Load position IDs once per (b, t) ---
42
... (5502 more chars){
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && cat framework.txt 2>/dev/null && echo \"---\" && grep -c \"triton\" solution.py && echo \"triton references found\"",
"description": "Verify framework label"
}triton
---
3
triton references found
{
"command": "<592 chars \u2014 see below>",
"description": "Verify no forbidden ops"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
# Verify forbidden ops are not present
# Reads solution.py from the current directory and fails with exit status 1
# as soon as any listed name appears anywhere in the file text.
import re
src = open('solution.py').read()
# Names that must not occur in the solution source. The second entry is a
# substring of the first, so any match of the first is also caught by it.
forbidden = [
'transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb',
'apply_multimodal_rotary_pos_emb',
'flash_attn.layers.rotary',
'flashinfer.rope',
'flashinfer.apply_rope',
]
for f in forbidden:
# re.escape turns the entry into a literal pattern, so this is a plain
# substring test; it also matches inside comments and strings.
if re.search(re.escape(f), src):
print(f'FAIL: forbidden: {f}')
exit(1)
print('No forbidden ops found')
# Line count of solution.py, printed for the record.
print(f'Lines: {len(src.splitlines())}')
" 2>&1No forbidden ops found
Lines: 169
{
"todos": [
{
"priority": "high",
"content": "Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml",
"status": "completed"
},
{
"priority": "high",
"content": "Implement baseline Triton kernel for mRoPE prep",
"status": "completed"
},
{
"priority": "high",
"content": "Check correctness with python check.py",
"status": "completed"
},
{
"priority": "high",
"content": "Profile and optimize iteratively",
"status": "completed"
},
{
"priority": "high",
"conte
... (70 more chars)content: 'Read reference.py, solution.py, check.py, benchmark.py, and problem.yaml'
status: 'completed'
priority: 'high'
content: 'Implement baseline Triton kernel for mRoPE prep'
status: 'completed'
priority: 'high'
content: 'Check correctness with python check.py'
status: 'completed'
priority: 'high'
content: 'Profile and optimize iteratively'
status: 'completed'
priority: 'high'
content: 'Benchmark and finalize'
status: 'completed'
priority: 'high'