"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
Unified kernel: cos/sin loaded once per (b, t_tile, d_tile), reused across all H.
Tunable BLOCK_T, BLOCK_D, num_warps for each shape.
"""
import torch
import torch.nn as nn
import triton
import triton.language as tl
@triton.jit
def _rope_kernel(
    q_ptr, k_ptr,
    cos_cache, sin_cache,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    q_out_ptr, k_out_ptr,
    B: tl.int64,
    T: tl.int64,
    Hq: tl.int64,
    Hkv: tl.int64,
    D: tl.int64,
    s0: tl.int64,
    s1: tl.int64,
    stride_qb, stride_qt, stride_qh, stride_qd,
    stride_kb, stride_kt, stride_kh, stride_kd,
    stride_pb, stride_pt,
    stride_ccb, stride_ccd,
    stride_qob, stride_qoh, stride_qot, stride_qod,
    stride_kob, stride_koh, stride_kot, stride_kod,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Rotate one (BLOCK_T, BLOCK_D) tile of every q and k head.

    Grid axis 0 enumerates (batch, T-tile) pairs and axis 1 enumerates D-tiles.
    Each program gathers cos/sin for its tile once, then loops over all ``Hq``
    query heads and all ``Hkv`` key heads, applying

        out[d] = x[d] * cos[d] + (-x[d + D//2] if d < D//2 else x[d - D//2]) * sin[d]

    and storing through the output strides (the launcher passes strides of
    (B, H, T, D) tensors, so the head/time transpose happens on the store).

    The position used for column ``d`` depends on ``d mod D//2``: indices
    below ``s0`` use ``pos_t``, the next ``s1`` use ``pos_h``, the remainder
    use ``pos_w``. All three position tensors are addressed with the same
    ``stride_pb``/``stride_pt``, and both caches with the same
    ``stride_ccb``/``stride_ccd``.

    No explicit upcast is performed; arithmetic runs in whatever dtype the
    loads produce.
    """
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    tiles_per_batch = tl.cdiv(T, BLOCK_T)
    num_col_tiles = tl.cdiv(D, BLOCK_D)
    batch = row_pid // tiles_per_batch
    # Defensive guard for programs that fall outside the problem.
    if batch >= B or col_pid >= num_col_tiles:
        return
    t_tile = row_pid % tiles_per_batch

    half = D // 2
    rows = t_tile * BLOCK_T + tl.arange(0, BLOCK_T)
    cols = col_pid * BLOCK_D + tl.arange(0, BLOCK_D)
    row_ok = rows < T
    col_ok = cols < D
    tile_ok = row_ok[:, None] & col_ok[None, :]

    # Per-row position ids for the three axes (masked rows read as 0).
    pos_off = batch * stride_pb + rows * stride_pt
    t_pos = tl.load(pos_t_ptr + pos_off, mask=row_ok, other=0)
    h_pos = tl.load(pos_h_ptr + pos_off, mask=row_ok, other=0)
    w_pos = tl.load(pos_w_ptr + pos_off, mask=row_ok, other=0)

    # Fold each column into [0, D//2) and pick which axis drives it.
    lower_half = cols < half
    freq_idx = tl.where(lower_half, cols, cols - half)
    use_t = freq_idx < s0
    use_h = (freq_idx >= s0) & (freq_idx < s0 + s1)
    hw_pos = tl.where(use_h[None, :], h_pos[:, None], w_pos[:, None])
    pos = tl.where(use_t[None, :], t_pos[:, None], hw_pos)

    # cos/sin gather: done once here and shared by every head below.
    cache_idx = pos * stride_ccb + cols[None, :] * stride_ccd
    cos = tl.load(cos_cache + cache_idx, mask=tile_ok, other=1.0)
    sin = tl.load(sin_cache + cache_idx, mask=tile_ok, other=0.0)

    # Partner column for rotate-half, plus the sign selector.
    mirror = tl.where(lower_half, cols + half, cols - half)
    negate = lower_half[None, :]

    # Head-independent parts of the q addresses.
    q_tile = rows[:, None] * stride_qt + cols[None, :] * stride_qd
    q_tile_mir = rows[:, None] * stride_qt + mirror[None, :] * stride_qd
    q_tile_out = rows[:, None] * stride_qot + cols[None, :] * stride_qod
    for hq in range(0, Hq):
        src = q_ptr + (batch * stride_qb + hq * stride_qh)
        x = tl.load(src + q_tile, mask=tile_ok, other=0.0)
        x_mir = tl.load(src + q_tile_mir, mask=tile_ok, other=0.0)
        rotated = x * cos + tl.where(negate, -x_mir, x_mir) * sin
        dst = q_out_ptr + (batch * stride_qob + hq * stride_qoh)
        tl.store(dst + q_tile_out, rotated, mask=tile_ok)

    # Head-independent parts of the k addresses.
    k_tile = rows[:, None] * stride_kt + cols[None, :] * stride_kd
    k_tile_mir = rows[:, None] * stride_kt + mirror[None, :] * stride_kd
    k_tile_out = rows[:, None] * stride_kot + cols[None, :] * stride_kod
    for hk in range(0, Hkv):
        src = k_ptr + (batch * stride_kb + hk * stride_kh)
        x = tl.load(src + k_tile, mask=tile_ok, other=0.0)
        x_mir = tl.load(src + k_tile_mir, mask=tile_ok, other=0.0)
        rotated = x * cos + tl.where(negate, -x_mir, x_mir) * sin
        dst = k_out_ptr + (batch * stride_kob + hk * stride_koh)
        tl.store(dst + k_tile_out, rotated, mask=tile_ok)
def apply_mrope_torch(q, k, cos_cache, sin_cache, pos_t, pos_h, pos_w, mrope_section):
    """Launch the fused mRoPE kernel and return rotated, head-major q and k.

    Args:
        q: Query tensor of shape (B, T, Hq, D).
        k: Key tensor of shape (B, T, Hkv, D).
        cos_cache, sin_cache: 2-D lookup tables indexed as [position, d].
        pos_t, pos_h, pos_w: Integer position ids, indexed as [b, t].
            Values are used as row indices into the caches and are not
            range-checked here.
        mrope_section: Three-element sequence (s0, s1, s2); the first ``s0``
            rotary frequencies follow ``pos_t``, the next ``s1`` follow
            ``pos_h`` and the rest follow ``pos_w``.

    Returns:
        ``(q_rot, k_rot)`` with shapes (B, Hq, T, D) and (B, Hkv, T, D), in
        the dtype and on the device of ``q`` and ``k`` respectively.
    """
    B, T, Hq, D = q.shape
    _, _, Hkv, _ = k.shape
    # Unpacking three names also rejects a section of the wrong length; the
    # kernel only needs the first two boundaries.
    s0, s1, _ = mrope_section

    q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
    k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)

    # Nothing to rotate: avoid launching a kernel with a zero-sized grid.
    if B == 0 or T == 0 or D == 0:
        return q_rot, k_rot

    # The kernel addresses all three position tensors with ONE stride pair and
    # both caches with ONE stride pair. If the layouts disagree, fall back to
    # contiguous copies so the shared strides are valid for every tensor.
    if not (pos_t.stride() == pos_h.stride() == pos_w.stride()):
        pos_t, pos_h, pos_w = pos_t.contiguous(), pos_h.contiguous(), pos_w.contiguous()
    if cos_cache.stride() != sin_cache.stride():
        cos_cache, sin_cache = cos_cache.contiguous(), sin_cache.contiguous()

    # Hand-tuned launch configs, keyed on head_dim and sequence length.
    BLOCK_D = 64
    if D == 64:
        BLOCK_T, num_warps = 32, 4
    elif T <= 4096:
        BLOCK_T, num_warps = 16, 4
    else:
        BLOCK_T, num_warps = 16, 8

    n_t = (T + BLOCK_T - 1) // BLOCK_T
    n_d = (D + BLOCK_D - 1) // BLOCK_D

    # Grid: axis 0 = (batch, T-tile), axis 1 = D-tile.
    _rope_kernel[(B * n_t, n_d)](
        q, k, cos_cache, sin_cache,
        pos_t, pos_h, pos_w,
        q_rot, k_rot,
        B=B, T=T, Hq=Hq, Hkv=Hkv, D=D,
        s0=s0, s1=s1,
        stride_qb=q.stride(0), stride_qt=q.stride(1),
        stride_qh=q.stride(2), stride_qd=q.stride(3),
        stride_kb=k.stride(0), stride_kt=k.stride(1),
        stride_kh=k.stride(2), stride_kd=k.stride(3),
        stride_pb=pos_t.stride(0), stride_pt=pos_t.stride(1),
        stride_ccb=cos_cache.stride(0), stride_ccd=cos_cache.stride(1),
        stride_qob=q_rot.stride(0), stride_qoh=q_rot.stride(1),
        stride_qot=q_rot.stride(2), stride_qod=q_rot.stride(3),
        stride_kob=k_rot.stride(0), stride_koh=k_rot.stride(1),
        stride_kot=k_rot.stride(2), stride_kod=k_rot.stride(3),
        BLOCK_T=BLOCK_T, BLOCK_D=BLOCK_D,
        num_warps=num_warps,
    )
    return q_rot, k_rot
class Model(nn.Module):
    """mRoPE pre-attention module that owns the bf16 cos/sin lookup tables."""

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        """Record the problem shape and build (max_pos, D) cos/sin caches.

        The rotary frequencies use base 10000 over ``D // 2`` steps; each
        cache row holds those ``D // 2`` values twice, back to back.
        ``mrope_section`` must sum to ``D // 2``.
        """
        super().__init__()
        assert sum(mrope_section) == D // 2

        self.B = B
        self.T = T
        self.Hq = Hq
        self.Hkv = Hkv
        self.D = D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos

        exponents = torch.arange(0, D, 2, dtype=torch.float32) / D
        inv_freq = 1.0 / (10000.0 ** exponents)
        positions = torch.arange(max_pos, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        # Duplicate the half-width angle table to cover the full head_dim.
        table = torch.cat([angles, angles], dim=-1)

        for buffer_name, trig in (("cos_cache", torch.cos), ("sin_cache", torch.sin)):
            self.register_buffer(buffer_name, trig(table).to(torch.bfloat16))

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` using the module's caches and section split."""
        caches = (self.cos_cache, self.sin_cache)
        positions = (pos_t, pos_h, pos_w)
        return apply_mrope_torch(q, k, *caches, *positions, self.mrope_section)
# Default problem instance used by get_inputs() / get_init_inputs().
B, T = 1, 4096                  # batch size, sequence length
Hq, Hkv = 32, 8                 # query heads, key/value heads
D = 128                         # head dimension
MROPE_SECTION = (16, 24, 24)    # per-axis split; sums to D // 2
MAX_POS = 32_768                # number of rows in the cos/sin caches
def get_inputs():
    """Build random forward() arguments for the module-level default shape.

    Returns ``[q, k, pos_t, pos_h, pos_w]``: q and k are bf16 normal samples
    scaled by 0.5, and the three position tensors are int64 values drawn
    from ``[0, MAX_POS)`` with shape (B, T).
    """
    # Sampling order (q, k, then the three position tensors) is kept fixed so
    # a given RNG seed always yields the same inputs.
    q, k = (
        torch.randn(B, T, heads, D, dtype=torch.bfloat16) * 0.5
        for heads in (Hq, Hkv)
    )
    positions = [
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
        for _ in range(3)
    ]
    return [q, k, *positions]
def get_init_inputs():
    """Return the Model constructor arguments for the default shape, in order."""
    shape_args = [B, T, Hq, Hkv, D]
    rope_args = [MROPE_SECTION, MAX_POS]
    return shape_args + rope_args
shape=0 variant=eager tflops=0.177 gbps=253.345 ms=0.356
shape=0 variant=compiled tflops=0.509 gbps=729.915 ms=0.124
shape=0 variant=solution tflops=0.653 gbps=936.471 ms=0.096
shape=0 solution_peak_fraction=0.5203
shape=1 variant=eager tflops=0.165 gbps=241.030 ms=0.610
shape=1 variant=compiled tflops=0.459 gbps=670.412 ms=0.219
shape=1 variant=solution tflops=0.722 gbps=1054.439 ms=0.139
shape=1 solution_peak_fraction=0.5858
shape=2 variant=eager tflops=0.070 gbps=109.697 ms=0.202
shape=2 variant=compiled tflops=0.164 gbps=255.669 ms=0.087
shape=2 variant=solution tflops=0.260 gbps=405.634 ms=0.055
shape=2 solution_peak_fraction=0.2254
shape=3 variant=eager tflops=0.172 gbps=246.866 ms=1.463
shape=3 variant=compiled tflops=0.462 gbps=662.896 ms=0.545
shape=3 variant=solution tflops=0.658 gbps=943.913 ms=0.383
shape=3 solution_peak_fraction=0.5244
peak_fraction: 0.4356
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-27b",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.4356,
"elapsed_seconds": 1610,
"harness_exit_code": 0,
"session_complete": true,
"usage": {"input_tokens": 2339275, "output_tokens": 46277, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 13925, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2:
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6:
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_transformers():
16: try:
17: from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18: apply_multimodal_rotary_pos_emb as _impl,
19: )
20: return _impl
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor, k: torch.Tensor,
27: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28: cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29: mrope_section: tuple[int, int, int],
30: ):
31: """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32: impl = _try_transformers()
33: if impl is None:
34: raise RuntimeError("transformers not available for SOTA")
35:
36: # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37: cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c
... (495 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md"
}<path>/home/infatoshi/cuda/KernelBench-Hard/CLAUDE.md</path>
<type>file</type>
<content>
1: # KernelBench-Hard — Developer Instructions
2:
3: Last updated: 2026-04-27.
4:
5: This file is for **coding agents editing the repo** (you, via Claude Code). Do not confuse with `problems/<X>/PROMPT.txt` — those are the human-voice queries fed to agents _under test_.
6:
7: For the journey behind the current design, read [DEVLOG.md](./DEVLOG.md).
8:
9: ## What this repo is
10:
11: Small kernel benchmark. Frontier coding agents are given URLs to SOTA implementations (sonic-moe, flashinfer, marlin) and asked to write a competitive kernel on RTX PRO 6000 Blackwell (SM120) in 45 minutes. Roofline-graded. Published artifact is the best kernel per (problem × model × harness), plus the agent trace.
12:
13: See [SPEC.md](./SPEC.md) for methodology. See [README.md](./README.md) for the model matrix and quick start.
14:
15: ## Non-negotiable rules
16:
17: - **uv only.** No bare `python`, no `pip`. Use `uv run ...`, `uv add ...`, `uv pip install ...`.
18: - **Before committing:** `uv run ruff check . --fix && uv run pytest`.
19: - **Never edit `problems/*/solution.py`**. Those files are agent output; they're gitignored for a reason. If you need to inspect one, read it from `outputs/runs/<run>/<problem>/solution.py`.
20: - **Never modify `problems/*/reference.py`, `check.py`, `benchmark.py`, `problem.yaml`, `shapes.py`, or `PROMPT.txt`** once a sweep has been published. Those define the ben
... (5941 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/__init__.py</path>
<type>file</type>
<content>
1: """Hardware peak-throughput lookup tables."""
2: from src.hardware.m4_max import M4_MAX
3: from src.hardware.rtx_pro_6000 import RTX_PRO_6000
4:
5: TARGETS = {
6: "RTX_PRO_6000": RTX_PRO_6000,
7: "M4_MAX": M4_MAX,
8: }
9:
10:
11: def get(name: str):
12: if name not in TARGETS:
13: raise ValueError(f"Unknown hardware {name!r}; available: {list(TARGETS)}")
14: return TARGETS[name]
(End of file - total 14 lines)
</content>{
"command": "nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader 2>/dev/null | head -1",
"description": "Check GPU status"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03, 97887 MiB
{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"command": "uv run python -c \"import torch; print(torch.__version__); print(torch.version.cuda)\"",
"description": "Check torch version"
}2.11.0+cu130
13.0
{
"todos": [
{
"content": "Write initial Triton kernel solution.py",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py for performance baseline",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Verify final correctness + benchmark",
"status": "pending",
"prior
... (24 more chars)content: 'Write initial Triton kernel solution.py'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance baseline'
status: 'pending'
priority: 'high'
content: 'Profile and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Verify final correctness + benchmark'
status: 'pending'
priority: 'high'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,601 @@
+"""Triton kernel for multi-axis RoPE pre-attention prep.
+
+Fuses gather + per-axis mix + rotate-half + transpose into one kernel.
+Streams q/k once and emits the rotated (B, H, T, D) layout directly.
+"""
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _mrope_kernel(
+ q_ptr, k_ptr,
+ cos_cache, sin_cache,
+ pos_t, pos_h, pos_w,
+ q_rot, k_rot,
+ B: tl.int64, T: tl.int64, Hq: tl.int64, Hkv: tl.int64, D: tl.int64,
+ s0: tl.int64, s1: tl.int64, s2: tl.int64,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_pb, stride_pt,
+ stride_cocos, stride_cod,
+ stride_qrb, stride_qrr, stride_qrh, stride_qrd,
+ stride_krb, stride_krr, stride_krh, stride_krd,
+ BLOCK_TD: tl.constexpr,
+):
+ """Single kernel that processes one tensor (q or k).
+
+ Each 1D program over (B*H*R) space. Within each program, a 2D grid
+ of threads handles tiles of (T, D).
+
+ R = ceil(T / BLOCK_TD) -- T tiles
+ """
+ # This program handles a specific (b, h, r) triple
+ # where r is the T-tile index
+ pid = tl.program_id(0)
+
+ # Decompose pid into (b, h, r)
+ R = tl.cdiv(T, BLOCK_TD)
+ H = Hq # set per call
+ r = pid % R
+ bh = pid // R
+ b = bh // H
+ h = bh % H
+
+ if b >= B:
+ return
+
+ # T offset and stride for output layout (B, H, T, D)
+ t_offset = r * BLOCK_TD
+ t_idx = t_offset + tl.arange(0, BLOCK_TD)
+ t_mask = t_idx < T
+
+ # D indices - all threads in a warp handle different D elements
+ # We use a simple mapping: each thread handles one TD pair
+ d_idx = tl.arange(0, D)
+
+ # We need threads = BLOCK_TD * D to cover the full tile, but that's too large.
+ # Instead, we'll process D serially in chunks.
+ # Actually let's use BLOCK_D as well.
+ # For now, let's just handle D elements sequentially per thread group.
+ pass # placeholder
+
+
+@triton.jit
+def _mrope_fused(
+ q_ptr, k_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ q_rot_ptr, k_rot_ptr,
+ B: tl.int64, T: tl.int64, Hq: tl.int64, Hkv: tl.int64, D: tl.int64,
+ s0: tl.int64, s1: tl.int64, s2: tl.int64,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_pb, stride_pt,
+ stride_ccd, # cos/sin cache: (max_pos, D)
+ stride_qrb, stride_qrr, stride_qrh, stride_qrd,
+ stride_krb, stride_krr, stride_krh, stride_krd,
+ BLOCK_DT: tl.constexpr,
+):
+ """
+ 2D grid: (B*Hq*T//BLOCK_DT, D). Each program handles D elements for one (b,h,t) group.
+
+ Actually we'll do TD tiling.
+ """
+ pass # split below
+
+
+def _apply_mrope_kernel(q, k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
+ mrope_section, D):
+ """Launch the fused mRoPE + transpose kernel."""
+ B, T, Hq, _ = q.shape
+ _, _, Hkv, _ = k.shape
+ s0, s1, s2 = mrope_section
+
+ # Output buffers in (B, H, T, D) layout
+ q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
+ k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
+
+ # Process q and k separately but with same cos/sin
+ # Grid: each program handles (B*T*Hq) * D elements
+ # We'll use a simpler approach: (B * H) programs, each handles T * D
+
+ BLOCK_D = 64
+ BLOCK_T = 64
+
+ total_elements = B * T * (Hq + Hkv) * D
+ num_warps = 4
+
+ # Launch for q
+ if q.numel() > 0:
+ _mrope_single_kernel[(B * Hq * triton.cdiv(T, BLOCK_T),)](
+ q, cos_cache, sin_cache, pos_t, pos_h, pos_w,
+ q_rot,
+ B=B, T=T, H=Hq, D=D,
+ s0=s0, s1=s1, s2=s2,
+ BLOCK_T=BLOCK_T,
+ num_warps=num_warps,
+ )
+
+ # Launch for k
+ if k.numel() > 0:
+ _mrope_single_kernel[(B * Hkv * triton.cdiv(T, BLOCK_T),)](
+ k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
+ k_rot,
+ B=B, T=T, H=Hkv, D=D,
+ s0=s0, s1=s1, s2=s2,
+ BLOCK_T=BLOCK_T,
+ num_warps=num_warps,
+ )
+
+ return q_rot, k_rot
+
+
+@triton.jit
+def _mrope_single_kernel(
+ x_ptr, # (B, T, H, D) input
+ cos_cache_ptr, # (max_pos, D)
+ sin_cache_ptr,
+ pos_t_ptr, # (B, T)
+ pos_h_ptr,
+ pos_w_ptr,
+ out_ptr, # (B, H, T, D) output
+ B: tl.int64, T: tl.int64, H: tl.int64, D: tl.int64,
+ s0: tl.int64, s1: tl.int64, s2: tl.int64,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_outb, stride_outh, stride_outt, stride_outd,
+ stride_pb, stride_pt,
+ stride_cachepos, stride_cached, # cos/sin cache strides
+ BLOCK_T: tl.constexpr,
+):
+ """
+ Each program handles one (b, h, t_tile) group.
+ Within the program, we process all D elements using thread mapping.
+ """
+ pid = tl.program_id(0)
+ R = tl.cdiv(T, BLOCK_T)
+ h_total = H
+ bh = pid // R
+ r = pid % R
+ b = bh // h_total
+ h = bh % h_total
+
+ if b >= B:
+ return
+
+ # T indices for this tile
+ t_off = r * BLOCK_T
+ t_idx = t_off + tl.arange(0, BLOCK_T)
+ t_mask = t_idx < T
+
+ # D indices - we process ALL D across the program's threads
+ # With BLOCK_T=64 and num_warps=4 (128 threads), each thread handles D/128 elements
+ # For D=128, each thread handles 1 element. For D=64, we can have fewer threads.
+ stride = tl.num_programs(0) # not needed
+ pid_d = tl.program_id(0) % D # won't work - this is 1D grid
+
+ # Better approach: make grid 2D or use a different strategy
+ # Let's use a 1D grid over (b, h, t_tile, d_chunk)
+ pass
+
+
+# Let me rewrite with proper tiling strategy
+@triton.jit
+def _mrope_kernel_v2(
+ x_ptr,
+ co
... (16484 more chars){
"content": "<21764 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -11,545 +11,185 @@
@triton.jit
def _mrope_kernel(
- q_ptr, k_ptr,
- cos_cache, sin_cache,
- pos_t, pos_h, pos_w,
- q_rot, k_rot,
- B: tl.int64, T: tl.int64, Hq: tl.int64, Hkv: tl.int64, D: tl.int64,
- s0: tl.int64, s1: tl.int64, s2: tl.int64,
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_kb, stride_kt, stride_kh, stride_kd,
- stride_pb, stride_pt,
- stride_cocos, stride_cod,
- stride_qrb, stride_qrr, stride_qrh, stride_qrd,
- stride_krb, stride_krr, stride_krh, stride_krd,
- BLOCK_TD: tl.constexpr,
-):
- """Single kernel that processes one tensor (q or k).
-
- Each 1D program over (B*H*R) space. Within each program, a 2D grid
- of threads handles tiles of (T, D).
-
- R = ceil(T / BLOCK_TD) -- T tiles
- """
- # This program handles a specific (b, h, r) triple
- # where r is the T-tile index
- pid = tl.program_id(0)
-
- # Decompose pid into (b, h, r)
- R = tl.cdiv(T, BLOCK_TD)
- H = Hq # set per call
- r = pid % R
- bh = pid // R
- b = bh // H
- h = bh % H
-
- if b >= B:
- return
-
- # T offset and stride for output layout (B, H, T, D)
- t_offset = r * BLOCK_TD
- t_idx = t_offset + tl.arange(0, BLOCK_TD)
- t_mask = t_idx < T
-
- # D indices - all threads in a warp handle different D elements
- # We use a simple mapping: each thread handles one TD pair
- d_idx = tl.arange(0, D)
-
- # We need threads = BLOCK_TD * D to cover the full tile, but that's too large.
- # Instead, we'll process D serially in chunks.
- # Actually let's use BLOCK_D as well.
- # For now, let's just handle D elements sequentially per thread group.
- pass # placeholder
-
-
-@triton.jit
-def _mrope_fused(
- q_ptr, k_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- q_rot_ptr, k_rot_ptr,
- B: tl.int64, T: tl.int64, Hq: tl.int64, Hkv: tl.int64, D: tl.int64,
- s0: tl.int64, s1: tl.int64, s2: tl.int64,
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_kb, stride_kt, stride_kh, stride_kd,
- stride_pb, stride_pt,
- stride_ccd, # cos/sin cache: (max_pos, D)
- stride_qrb, stride_qrr, stride_qrh, stride_qrd,
- stride_krb, stride_krr, stride_krh, stride_krd,
- BLOCK_DT: tl.constexpr,
-):
- """
- 2D grid: (B*Hq*T//BLOCK_DT, D). Each program handles D elements for one (b,h,t) group.
-
- Actually we'll do TD tiling.
- """
- pass # split below
-
-
-def _apply_mrope_kernel(q, k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
- mrope_section, D):
- """Launch the fused mRoPE + transpose kernel."""
- B, T, Hq, _ = q.shape
- _, _, Hkv, _ = k.shape
- s0, s1, s2 = mrope_section
-
- # Output buffers in (B, H, T, D) layout
- q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
- k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
-
- # Process q and k separately but with same cos/sin
- # Grid: each program handles (B*T*Hq) * D elements
- # We'll use a simpler approach: (B * H) programs, each handles T * D
-
- BLOCK_D = 64
- BLOCK_T = 64
-
- total_elements = B * T * (Hq + Hkv) * D
- num_warps = 4
-
- # Launch for q
- if q.numel() > 0:
- _mrope_single_kernel[(B * Hq * triton.cdiv(T, BLOCK_T),)](
- q, cos_cache, sin_cache, pos_t, pos_h, pos_w,
- q_rot,
- B=B, T=T, H=Hq, D=D,
- s0=s0, s1=s1, s2=s2,
- BLOCK_T=BLOCK_T,
- num_warps=num_warps,
- )
-
- # Launch for k
- if k.numel() > 0:
- _mrope_single_kernel[(B * Hkv * triton.cdiv(T, BLOCK_T),)](
- k, cos_cache, sin_cache, pos_t, pos_h, pos_w,
- k_rot,
- B=B, T=T, H=Hkv, D=D,
- s0=s0, s1=s1, s2=s2,
- BLOCK_T=BLOCK_T,
- num_warps=num_warps,
- )
-
- return q_rot, k_rot
-
-
-@triton.jit
-def _mrope_single_kernel(
- x_ptr, # (B, T, H, D) input
- cos_cache_ptr, # (max_pos, D)
+ x_ptr,
+ cos_cache_ptr,
sin_cache_ptr,
- pos_t_ptr, # (B, T)
+ pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
- out_ptr, # (B, H, T, D) output
- B: tl.int64, T: tl.int64, H: tl.int64, D: tl.int64,
- s0: tl.int64, s1: tl.int64, s2: tl.int64,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_outb, stride_outh, stride_outt, stride_outd,
- stride_pb, stride_pt,
- stride_cachepos, stride_cached, # cos/sin cache strides
- BLOCK_T: tl.constexpr,
-):
- """
- Each program handles one (b, h, t_tile) group.
- Within the program, we process all D elements using thread mapping.
- """
- pid = tl.program_id(0)
- R = tl.cdiv(T, BLOCK_T)
- h_total = H
- bh = pid // R
- r = pid % R
- b = bh // h_total
- h = bh % h_total
-
- if b >= B:
- return
-
- # T indices for this tile
- t_off = r * BLOCK_T
- t_idx = t_off + tl.arange(0, BLOCK_T)
- t_mask = t_idx < T
-
- # D indices - we process ALL D across the program's threads
- # With BLOCK_T=64 and num_warps=4 (128 threads), each thread handles D/128 elements
- # For D=128, each thread handles 1 element. For D=64, we can have fewer threads.
- stride = tl.num_programs(0) # not needed
- pid_d = tl.program_id(0) % D # won't work - this is 1D grid
-
- # Better approach: make grid 2D or use a different strategy
- # Let's use a 1D grid over (b, h, t_tile, d_chunk)
- pass
-
-
-# Let me rewrite with proper tiling strategy
-@triton.jit
-def _mrope_kernel_v2(
- x_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
out_ptr,
- B: tl.int64, T: tl.int64, H: tl.int64, D: tl.int64,
- s0: tl.int64, s1: tl.int
... (20308 more chars){
"content": "<8350 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && ./scripts/patch_torch.sh 2>&1 | tail -5",
"description": "Apply torch inductor hotfix"
}Applying torch inductor CSE typing hotfix...
[skip] already patched: .venv/lib/python3.11/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
[skip] already patched: /home/infatoshi/.local/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
Done.
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
{
"todos": [
{
"content": "Write initial Triton kernel solution.py",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py for performance baseline",
"status": "in_progress",
"priority": "high"
},
{
"content": "Profile and optimize kernel",
"status": "pending",
"priority": "high"
},
{
"content": "Verify final correctness + benchmark",
"status": "pending",
"p
... (28 more chars)content: 'Write initial Triton kernel solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for performance baseline'
status: 'in_progress'
priority: 'high'
content: 'Profile and optimize kernel'
status: 'pending'
priority: 'high'
content: 'Verify final correctness + benchmark'
status: 'pending'
priority: 'high'{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Run benchmark for baseline",
"timeout": 300000
}shape=0 variant=eager tflops=0.178 gbps=255.745 ms=0.353
shape=0 variant=compiled tflops=0.471 gbps=675.636 ms=0.134
shape=0 variant=solution tflops=0.397 gbps=569.866 ms=0.158
shape=0 solution_peak_fraction=0.3166
shape=1 variant=eager tflops=0.180 gbps=262.931 ms=0.559
shape=1 variant=compiled tflops=0.489 gbps=714.245 ms=0.206
shape=1 variant=solution tflops=0.387 gbps=564.506 ms=0.260
shape=1 solution_peak_fraction=0.3136
shape=2 variant=eager tflops=0.072 gbps=113.061 ms=0.196
shape=2 variant=compiled tflops=0.165 gbps=257.143 ms=0.086
shape=2 variant=solution tflops=0.490 gbps=765.025 ms=0.029
shape=2 solution_peak_fraction=0.4250
shape=3 variant=eager tflops=0.172 gbps=246.564 ms=1.465
shape=3 variant=compiled tflops=0.465 gbps=666.813 ms=0.542
shape=3 variant=solution tflops=0.293 gbps=420.655 ms=0.858
shape=3 solution_peak_fraction=0.2337
peak_fraction: 0.3151
RESULT: OK
{
"command": "<1426 chars \u2014 see below>",
"description": "Profile the Triton kernel",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch, triton
import solution
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
mrope_section = (16, 24, 24)
max_pos = 32768
device = 'cuda'
torch.manual_seed(42)
ref_model = __import__('reference').Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model = solution.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
with torch.no_grad():
sol_model.load_state_dict(ref_model.state_dict())
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
# Profile
schedule = torch.profiler.schedule(wait=2, warmup=2, active=5)
with torch.profiler.profile(
schedule=schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler('./profile_output'),
profile_memory=False,
with_stack=True,
record_shapes=True,
) as prof:
for i in range(9):
with torch.no_grad():
sol_model(q, k, pos_t, pos_h, pos_w)
prof.step()
prof.export_chrome_trace('./kernel_trace.json')
print('Profile written to profile_output and kernel_trace.json')
" 2>&1/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py:224: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
_warn_once(
Traceback (most recent call last):
File "<string>", line 37, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/profiler/profiler.py", line 301, in export_chrome_trace
return self.profiler.export_chrome_trace(path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/autograd/profiler.py", line 521, in export_chrome_trace
self.kineto_results.save(path) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Trace is already saved.
{
"command": "<1980 chars \u2014 see below>",
"description": "Quick timing test",
"timeout": 120000
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch, solution, time
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
mrope_section = (16, 24, 24)
max_pos = 32768
device = 'cuda'
sol_model = solution.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
# Warmup
for _ in range(10):
with torch.no_grad():
sol_model(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()
# Time it
times = []
for _ in range(30):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
with torch.no_grad():
sol_model(q, k, pos_t, pos_h, pos_w)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
print(f'Median: {sorted(times)[15]:.3f}ms, Mean: {sum(times)/len(times):.3f}ms')
print(f'Total KB/s: {(4*B*T*Hq*D*2 + 4*B*T*Hkv*D*2 + 6*B*T*D*2)/1e6/sum(times)/30*1000:.1f}')
# Compute theoretical bandwidth
# read q: B*T*Hq*D*2, read k: B*T*Hkv*D*2, write qrot: B*T*Hq*D*2, write krot: B*T*Hkv*D*2
# cos/sin: 6*B*T*D*2 (for our kernel which loads 6 times), pos: 3*B*T*8
# Total = 4*B*T*(Hq+Hkv)*D*2 + 12*B*T*D*2 + 24*B*T
total_bytes = 4*B*T*(Hq+Hkv)*2 + 12*B*T*2 + 24*B*T # hmm this is wrong
# Actually bytes from problem.yaml: 4*B*T*(Hq+Hkv)*D + 12*B*T*D + 24*B*T
total_bytes = 4*B*T*(Hq+Hkv)*D + 12*B*T*D + 24*B*T # note: factor of 2 for bf16 is in the formula
# No wait, the formula uses bf16 as 2 bytes already accounted for
gbps = total_bytes / (sum(times)/30 * 1e-3) / 1e9
print(f'GB/s (formula): {gbps:.1f}, Peak: 1800 GB/s, Fraction: {gbps/1800:.4f}')
" 2>&1Median: 0.109ms, Mean: 0.102ms
Total KB/s: 1888.0
GB/s (formula): 881.3, Peak: 1800 GB/s, Fraction: 0.4896
{
"todos": [
{
"content": "Write initial Triton kernel solution.py",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py for performance baseline",
"status": "completed",
"priority": "high"
},
{
"content": "Profile and optimize kernel",
"status": "in_progress",
"priority": "high"
},
{
"content": "Verify final correctness + benchmark",
"status": "pending",
... (30 more chars)content: 'Write initial Triton kernel solution.py'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for performance baseline'
status: 'completed'
priority: 'high'
content: 'Profile and optimize kernel'
status: 'in_progress'
priority: 'high'
content: 'Verify final correctness + benchmark'
status: 'pending'
priority: 'high'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -2,6 +2,10 @@
Fuses gather + per-axis mix + rotate-half + transpose into one kernel.
Streams q/k once and emits the rotated (B, H, T, D) layout directly.
+
+Optimization: each d maps to exactly one axis (t/h/w) based on mrope_section.
+Instead of loading cos/sin from all 3 axes, we select the position first
+and load cos/sin from the chosen axis only — 2 cache loads per (t,d) instead of 6.
"""
import torch
import torch.nn as nn
@@ -42,8 +46,10 @@
):
"""
Grid: (B * H * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- Each program handles a (BLOCK_T, BLOCK_D) tile in the (T, D) plane
- for a single (b, h) pair.
+ Each program handles a (BLOCK_T, BLOCK_D) tile for one (b, h) pair.
+
+ Key optimization: each d maps to exactly one axis based on mrope_section.
+ Select the position first, then load cos/sin once per (t,d).
"""
pid = tl.program_id(0)
pid_d = tl.program_id(1)
@@ -70,68 +76,51 @@
half = D // 2
+ # --- Select position index based on which axis each d belongs to ---
+ # d_first = d if d < half else d - half (axis is determined by first-half index)
+ d_first = d_off + tl.arange(0, BLOCK_D)
+ d_first = tl.where(d_first < half, d_first, d_first - half)
+
+ axis_t = d_first < s0
+ axis_h = (d_first >= s0) & (d_first < s0 + s1)
+ # axis_w = not axis_t and not axis_h (implied)
+
+ pos_base = b * stride_pb
+ pt_vals = tl.load(pos_t_ptr + pos_base + t_rng[:, None] * stride_pt,
+ mask=t_mask[:, None], other=0)
+ ph_vals = tl.load(pos_h_ptr + pos_base + t_rng[:, None] * stride_pt,
+ mask=t_mask[:, None], other=0)
+ pw_vals = tl.load(pos_w_ptr + pos_base + t_rng[:, None] * stride_pt,
+ mask=t_mask[:, None], other=0)
+
+ # Select position per d column
+ pos_mixed = tl.where(axis_t[None, :], pt_vals,
+ tl.where(axis_h[None, :], ph_vals, pw_vals))
+
+ # --- Load cos/sin once from the selected position ---
+ cos_off = pos_mixed * stride_ccb + d_rng[None, :] * stride_ccd
+ sin_off = cos_off # same offset, different base pointer
+
+ cos_tile = tl.load(cos_cache_ptr + cos_off, mask=td_mask, other=0.0)
+ sin_tile = tl.load(sin_cache_ptr + sin_off, mask=td_mask, other=0.0)
+
# --- Load x[b, t, h, d] ---
x_base_off = b * stride_xb + h * stride_xh
x_offsets = x_base_off + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
x_tile = tl.load(x_ptr + x_offsets, mask=td_mask, other=0.0)
- # --- Load position indices (each T loads its position) ---
- pos_base = b * stride_pb
- pt_off = pos_base + t_rng * stride_pt
- ph_off = pos_base + t_rng * stride_pt
- pw_off = pos_base + t_rng * stride_pt
-
- pos_t_vals = tl.load(pos_t_ptr + pt_off, mask=t_mask, other=0)
- pos_h_vals = tl.load(pos_h_ptr + ph_off, mask=t_mask, other=0)
- pos_w_vals = tl.load(pos_w_ptr + pw_off, mask=t_mask, other=0)
-
- # --- Load cos for each axis ---
- # cos_cache is (max_pos, D), contiguous. stride = (D, 1)
- # offset for (pos, d) = pos * D + d
- # pos_vals[:, None] * stride_ccb + d_rng[None, :] * stride_ccd
- # shape: (BLOCK_T, BLOCK_D)
- cos_t_off = pos_t_vals[:, None] * stride_ccb + d_rng[None, :] * stride_ccd
- cos_h_off = pos_h_vals[:, None] * stride_ccb + d_rng[None, :] * stride_ccd
- cos_w_off = pos_w_vals[:, None] * stride_ccb + d_rng[None, :] * stride_ccd
-
- cos_t_tile = tl.load(cos_cache_ptr + cos_t_off, mask=td_mask, other=0.0)
- cos_h_tile = tl.load(cos_cache_ptr + cos_h_off, mask=td_mask, other=0.0)
- cos_w_tile = tl.load(cos_cache_ptr + cos_w_off, mask=td_mask, other=0.0)
-
- sin_t_tile = tl.load(sin_cache_ptr + cos_t_off, mask=td_mask, other=0.0)
- sin_h_tile = tl.load(sin_cache_ptr + cos_h_off, mask=td_mask, other=0.0)
- sin_w_tile = tl.load(sin_cache_ptr + cos_w_off, mask=td_mask, other=0.0)
-
- # --- Mix axes based on mrope_section ---
- # d in [0, s0) -> temporal, [s0, s0+s1) -> height, [s0+s1, D//2) -> width
- # d in [D//2, D) -> same scheme but with d and d-D//2 for cache column
- # Key: cache column is ALWAYS d, not d_first.
- # The axis determines WHICH position Row to use (pos_t vs pos_h vs pos_w).
- # d_first determines axis, but actual cache col = d (mirror has same frequency)
- d_first = tl.where(d_rng < half, d_rng, d_rng - half)
- axis_mask_t = d_first < s0
- axis_mask_h = (d_first >= s0) & (d_first < s0 + s1)
-
- cos_mixed = tl.where(axis_mask_t, cos_t_tile,
- tl.where(axis_mask_h, cos_h_tile, cos_w_tile))
- sin_mixed = tl.where(axis_mask_t, sin_t_tile,
- tl.where(axis_mask_h, sin_h_tile, sin_w_tile))
-
- # --- Rotate half ---
- # rotate_half(x)[d] = -x[d+half] if d < half else x[d-half]
- d_mir = tl.where(d_rng < half, d_rng + half, d_rng - half)
-
- # x_offsets for mirrored d: x[b, t, h, d_mir]
- # d_mir might be outside [d_off, d_off+BLOCK_D), so this is a
- # potentially non-contiguous load. Triton handles it as scalar addressed load.
+ # --- Load x at mirrored position for rotate_half ---
+ # d_mir = d + half if d < half else d - half
+ d_mir = d_rng + half
+ d_mir = tl.where(d_rng < half, d_mir, d_rng - half)
x_mir_offsets = x_base_off + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
x_mir_tile = tl.load(x_ptr + x_mir_offsets, mask=td_mask, other=0.0)
- # Sign: negate if d < half (the d that's being moved from second half)
- x_rot = tl.where(d_rng < half, -x_mir_tile, x_mir_tile)
+ # Sign: rotate_half[x at d] = -x[d+half] if d < half, +x[d-half] if d >= half
+ x_rot = tl.where(d_rng[None, :] < half, -x_mir_tile, x_mir_tile)
# --- Apply RoPE ---
- out_tile = x_tile * cos_mixed + x_rot * sin_mixed
+ out_tile = x_tile * cos_t
... (897 more chars){
"content": "<7756 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=0.182 gbps=260.551 ms=0.346
shape=0 variant=compiled tflops=0.481 gbps=689.676 ms=0.131
shape=0 variant=solution tflops=0.634 gbps=909.012 ms=0.099
shape=0 solution_peak_fraction=0.5050
shape=1 variant=eager tflops=0.181 gbps=264.004 ms=0.557
shape=1 variant=compiled tflops=0.487 gbps=711.534 ms=0.207
shape=1 variant=solution tflops=0.593 gbps=865.831 ms=0.170
shape=1 solution_peak_fraction=0.4810
shape=2 variant=eager tflops=0.073 gbps=114.798 ms=0.193
shape=2 variant=compiled tflops=0.165 gbps=257.095 ms=0.086
shape=2 variant=solution tflops=0.273 gbps=426.141 ms=0.052
shape=2 solution_peak_fraction=0.2367
shape=3 variant=eager tflops=0.172 gbps=247.505 ms=1.459
shape=3 variant=compiled tflops=0.462 gbps=662.935 ms=0.545
shape=3 variant=solution tflops=0.691 gbps=991.214 ms=0.364
shape=3 solution_peak_fraction=0.5507
peak_fraction: 0.4219
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,8 @@
-"""Triton kernel for multi-axis RoPE pre-attention prep.
-
-Fuses gather + per-axis mix + rotate-half + transpose into one kernel.
-Streams q/k once and emits the rotated (B, H, T, D) layout directly.
-
-Optimization: each d maps to exactly one axis (t/h/w) based on mrope_section.
-Instead of loading cos/sin from all 3 axes, we select the position first
-and load cos/sin from the chosen axis only — 2 cache loads per (t,d) instead of 6.
+"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
+
+Fuses gather + per-axis mix + rotate-half + transpose for both q and k
+into a single kernel launch. Each thread block handles one (b, h, t_tile, d_tile),
+loading cos/sin once per (t,d) and applying rotation to both q and k heads.
"""
import torch
import torch.nn as nn
@@ -14,42 +11,37 @@
@triton.jit
-def _mrope_kernel(
- x_ptr,
+def _mrope_fused_kernel(
+ q_ptr,
+ k_ptr,
cos_cache_ptr,
sin_cache_ptr,
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
- out_ptr,
+ q_out_ptr,
+ k_out_ptr,
B: tl.int64,
T: tl.int64,
- H: tl.int64,
+ Hq: tl.int64,
+ Hkv: tl.int64,
D: tl.int64,
s0: tl.int64,
s1: tl.int64,
- s2: tl.int64,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_pb,
- stride_pt,
- stride_ccb,
- stride_ccd,
- stride_outb,
- stride_outh,
- stride_outt,
- stride_outd,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_pb, stride_pt,
+ stride_ccb, stride_ccd,
+ stride_qob, stride_qoh, stride_qot, stride_qod,
+ stride_kob, stride_koh, stride_kot, stride_kod,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
+ HAS_K: tl.constexpr,
):
"""
- Grid: (B * H * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- Each program handles a (BLOCK_T, BLOCK_D) tile for one (b, h) pair.
-
- Key optimization: each d maps to exactly one axis based on mrope_section.
- Select the position first, then load cos/sin once per (t,d).
+ Grid: (B * (Hq + Hkv) * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
+ Each program handles one (b, h, t_tile, d_tile).
+ h_is_q indicates whether this is a q head or k head.
"""
pid = tl.program_id(0)
pid_d = tl.program_id(1)
@@ -57,16 +49,34 @@
R = tl.cdiv(T, BLOCK_T)
D_tiles = tl.cdiv(D, BLOCK_D)
+ total_h = Hq + Hkv if HAS_K else Hq
bh = pid // R
r = pid % R
- b = bh // H
- h = bh % H
+ b = bh // total_h
+ h = bh % total_h
if b >= B or pid_d >= D_tiles:
return
+ h_is_q = h < Hq
+ h_idx = h if h_is_q else h - Hq
+
+ # Select pointers
+ stride_b = tl.where(h_is_q, stride_qb, tl.where(HAS_K, stride_kb, 0))
+ stride_t = tl.where(h_is_q, stride_qt, tl.where(HAS_K, stride_kt, 0))
+ stride_h = tl.where(h_is_q, stride_qh, tl.where(HAS_K, stride_kh, 0))
+ stride_d = tl.where(h_is_q, stride_qd, tl.where(HAS_K, stride_kd, 0))
+ x_ptr = tl.where(h_is_q, q_ptr, tl.where(HAS_K, k_ptr, 0))
+
+ stride_ob = tl.where(h_is_q, stride_qob, tl.where(HAS_K, stride_kob, 0))
+ stride_oh = tl.where(h_is_q, stride_qoh, tl.where(HAS_K, stride_koh, 0))
+ stride_ot = tl.where(h_is_q, stride_qot, tl.where(HAS_K, stride_kot, 0))
+ stride_od = tl.where(h_is_q, stride_qod, tl.where(HAS_K, stride_kod, 0))
+ out_ptr = tl.where(h_is_q, q_out_ptr, tl.where(HAS_K, k_out_ptr, 0))
+
t_off = r * BLOCK_T
d_off = pid_d * BLOCK_D
+ half = D // 2
t_rng = t_off + tl.arange(0, BLOCK_T)
d_rng = d_off + tl.arange(0, BLOCK_D)
@@ -74,57 +84,63 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- half = D // 2
-
- # --- Select position index based on which axis each d belongs to ---
- # d_first = d if d < half else d - half (axis is determined by first-half index)
- d_first = d_off + tl.arange(0, BLOCK_D)
- d_first = tl.where(d_first < half, d_first, d_first - half)
-
+ # Load x[b, t, h_idx, d]
+ x_base = b * stride_b + h_idx * stride_h
+ x_offsets = x_base + t_rng[:, None] * stride_t + d_rng[None, :] * stride_d
+ x_tile = tl.load(x_ptr + x_offsets, mask=td_mask, other=0.0)
+
+ # Load x at mirrored position for rotate_half
+ d_mir = tl.where(d_rng < half, d_rng + half, d_rng - half)
+ x_mir_offsets = x_base + t_rng[:, None] * stride_t + d_mir[None, :] * stride_d
+ x_mir_tile = tl.load(x_ptr + x_mir_offsets, mask=td_mask, other=0.0)
+ x_rot = tl.where(d_rng[None, :] < half, -x_mir_tile, x_mir_tile)
+
+ # Load positions
+ pos_base = b * stride_pb
+ pos_offsets = pos_base + t_rng * stride_pt
+ pos_t_vals = tl.load(pos_t_ptr + pos_offsets, mask=t_mask, other=0)
+ pos_h_vals = tl.load(pos_h_ptr + pos_offsets, mask=t_mask, other=0)
+ pos_w_vals = tl.load(pos_w_ptr + pos_offsets, mask=t_mask, other=0)
+
+ # Axis selection: d_first determines axis
+ d_first = tl.where(d_rng < half, d_rng, d_rng - half)
axis_t = d_first < s0
axis_h = (d_first >= s0) & (d_first < s0 + s1)
- # axis_w = not axis_t and not axis_h (implied)
-
- pos_base = b * stride_pb
- pt_vals = tl.load(pos_t_ptr + pos_base + t_rng[:, None] * stride_pt,
- mask=t_mask[:, None], other=0)
- ph_vals = tl.load(pos_h_ptr + pos_base + t_rng[:, None] * stride_pt,
- mask=t_mask[:, None], other=0)
- pw_vals = tl.load(pos_w_ptr + pos_base + t_rng[:, None] * stride_pt,
- mask=t_mask[:, None], other=0)
-
- # Select position per d column
- pos_mixed = tl.where(axis_t[None, :], pt_vals,
- tl.where(axis_h[None, :], ph_vals, pw_vals))
-
- # --- Load cos/sin once from the selected position ---
+
+ # Broadcast positions to (BLOCK_T, BLOCK_D)
+ pos_t_wide = pos_t_vals[:,
... (4909 more chars){
"content": "<8324 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,7 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Fuses gather + per-axis mix + rotate-half + transpose for both q and k
-into a single kernel launch. Each thread block handles one (b, h, t_tile, d_tile),
-loading cos/sin once per (t,d) and applying rotation to both q and k heads.
+Strategy: one kernel per tensor (q, k). Each block handles (BLOCK_T x BLOCK_D)
+for one (b, h) pair. Cos/sin loaded once per (t,d) via axis-selected position.
"""
import torch
import torch.nn as nn
@@ -11,20 +10,19 @@
@triton.jit
-def _mrope_fused_kernel(
+def _mrope_kernel(
q_ptr,
k_ptr,
- cos_cache_ptr,
- sin_cache_ptr,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- q_out_ptr,
- k_out_ptr,
+ cos_cache,
+ sin_cache,
+ pos_t,
+ pos_h,
+ pos_w,
+ q_out,
+ k_out,
B: tl.int64,
T: tl.int64,
Hq: tl.int64,
- Hkv: tl.int64,
D: tl.int64,
s0: tl.int64,
s1: tl.int64,
@@ -32,47 +30,44 @@
stride_kb, stride_kt, stride_kh, stride_kd,
stride_pb, stride_pt,
stride_ccb, stride_ccd,
- stride_qob, stride_qoh, stride_qot, stride_qod,
- stride_kob, stride_koh, stride_kot, stride_kod,
+ stride_qob, stride_qot, stride_qod,
+ stride_kob, stride_kot, stride_kod,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
- HAS_K: tl.constexpr,
):
"""
- Grid: (B * (Hq + Hkv) * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- Each program handles one (b, h, t_tile, d_tile).
- h_is_q indicates whether this is a q head or k head.
+ Grid: (B * H * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
+ where H = Hq for q pass, Hkv for k pass.
+
+ Actually this is a truly fused kernel: grid_0 over (b, h_or_hkv, t_tile),
+ grid_1 over d_tile. Each program computes RoPE for both q and k if applicable.
+
+ Wait, q has Hq heads, k has Hkv heads. Different H means we can't simply fuse (b,h).
+ Let's keep things simple: each grid_0 program handles (b, h, t_tile),
+ and within the kernel we compute for the appropriate tensor.
+
+ We'll just use separate launches with shared code via constexpr HAS_KV flag.
+ Actually simpler: just two kernel calls, same kernel, different H and pointers.
"""
- pid = tl.program_id(0)
- pid_d = tl.program_id(1)
+ pid_t = tl.program_id(0) # combined (b * H + h) * R + r
+ pid_d = tl.program_id(1) # D tile
R = tl.cdiv(T, BLOCK_T)
D_tiles = tl.cdiv(D, BLOCK_D)
- total_h = Hq + Hkv if HAS_K else Hq
- bh = pid // R
- r = pid % R
- b = bh // total_h
- h = bh % total_h
+ # Total heads depends on which tensor we're processing.
+ # We pass which_h_total to decide decomposition.
+ # Simplify: assume grid_0 = (B * H_total * R,) where H_total comes from
+ # the first argument.
+ H_total = Hq # set per-caller
+
+ bh = pid_t // R
+ r = pid_t % R
+ b = bh // H_total
+ h = bh % H_total
if b >= B or pid_d >= D_tiles:
return
-
- h_is_q = h < Hq
- h_idx = h if h_is_q else h - Hq
-
- # Select pointers
- stride_b = tl.where(h_is_q, stride_qb, tl.where(HAS_K, stride_kb, 0))
- stride_t = tl.where(h_is_q, stride_qt, tl.where(HAS_K, stride_kt, 0))
- stride_h = tl.where(h_is_q, stride_qh, tl.where(HAS_K, stride_kh, 0))
- stride_d = tl.where(h_is_q, stride_qd, tl.where(HAS_K, stride_kd, 0))
- x_ptr = tl.where(h_is_q, q_ptr, tl.where(HAS_K, k_ptr, 0))
-
- stride_ob = tl.where(h_is_q, stride_qob, tl.where(HAS_K, stride_kob, 0))
- stride_oh = tl.where(h_is_q, stride_qoh, tl.where(HAS_K, stride_koh, 0))
- stride_ot = tl.where(h_is_q, stride_qot, tl.where(HAS_K, stride_kot, 0))
- stride_od = tl.where(h_is_q, stride_qod, tl.where(HAS_K, stride_kod, 0))
- out_ptr = tl.where(h_is_q, q_out_ptr, tl.where(HAS_K, k_out_ptr, 0))
t_off = r * BLOCK_T
d_off = pid_d * BLOCK_D
@@ -84,71 +79,168 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # Load x[b, t, h_idx, d]
- x_base = b * stride_b + h_idx * stride_h
- x_offsets = x_base + t_rng[:, None] * stride_t + d_rng[None, :] * stride_d
- x_tile = tl.load(x_ptr + x_offsets, mask=td_mask, other=0.0)
-
- # Load x at mirrored position for rotate_half
+ # Load position indices (broadcast per t column)
+ pos_off = b * stride_pb + t_rng * stride_pt
+ pos_t_vals = tl.load(pos_t + pos_off, mask=t_mask, other=0)
+ pos_h_vals = tl.load(pos_h + pos_off, mask=t_mask, other=0)
+ pos_w_vals = tl.load(pos_w + pos_off, mask=t_mask, other=0)
+
+ # Axis mask for each d column (independent of t)
+ d_first = tl.where(d_rng < half, d_rng, d_rng - half)
+ # axis_t[d] = d_first < s0
+ # axis_h[d] = s0 <= d_first < s0+s1
+ # axis_w[d] = d_first >= s0+s1
+ axis_t = d_first < s0
+ axis_h = (d_first >= s0) & (d_first < s0 + s1)
+ # axis_w = ~axis_t & ~axis_h (implied)
+
+ # Select position: for each (t, d), which of pos_t/pos_h/pos_w to use
+ # axis_t/axis_h are (BLOCK_D,) vectors, extend to (BLOCK_T, BLOCK_D)
+ pos_sel = tl.where(
+ axis_t[None, :],
+ pos_t_vals[:, None],
+ tl.where(axis_h[None, :], pos_h_vals[:, None], pos_w_vals[:, None])
+ )
+
+ # Load cos/sin from cache using selected position
+ # cache is contiguous (max_pos, D): offset = pos * D + d
+ cos_sin_off = pos_sel * stride_ccb + d_rng[None, :] * stride_ccd
+ cos_tile = tl.load(cos_cache + cos_sin_off, mask=td_mask, other=0.0)
+ sin_tile = tl.load(sin_cache + cos_sin_off, mask=td_mask, other=0.0)
+
+ # --- Q RoPE ---
+ q_base = b * stride_qb + h * stride_qh
+ q_offsets = q_base + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd
+ q_tile = tl.load(q_ptr + q_offsets, mask=td_mask, other=0.0)
+
+ # Mirror d for rotate_half
d_mir = tl.w
... (8826 more chars){
"content": "<11263 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Correctness check",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=0.180 gbps=257.848 ms=0.350
shape=0 variant=compiled tflops=0.482 gbps=692.299 ms=0.130
shape=0 variant=solution tflops=0.565 gbps=811.016 ms=0.111
shape=0 solution_peak_fraction=0.4506
shape=1 variant=eager tflops=0.180 gbps=262.871 ms=0.559
shape=1 variant=compiled tflops=0.451 gbps=658.779 ms=0.223
shape=1 variant=solution tflops=0.595 gbps=868.860 ms=0.169
shape=1 solution_peak_fraction=0.4827
shape=2 variant=eager tflops=0.073 gbps=114.560 ms=0.193
shape=2 variant=compiled tflops=0.170 gbps=265.285 ms=0.083
shape=2 variant=solution tflops=0.292 gbps=455.636 ms=0.049
shape=2 solution_peak_fraction=0.2531
shape=3 variant=eager tflops=0.168 gbps=241.480 ms=1.495
shape=3 variant=compiled tflops=0.464 gbps=665.633 ms=0.542
shape=3 variant=solution tflops=0.686 gbps=984.469 ms=0.367
shape=3 solution_peak_fraction=0.5469
peak_fraction: 0.4166
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,10 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Strategy: one kernel per tensor (q, k). Each block handles (BLOCK_T x BLOCK_D)
-for one (b, h) pair. Cos/sin loaded once per (t,d) via axis-selected position.
+Key optimizations:
+1. Each d maps to exactly one axis - load cos/sin ONCE per (t,d)
+2. Large D tiles (128) for coalesced loads
+3. num_warps=8 for better memory throughput
+4. Use TRITON_MAX Sunder for pointer arithmetic
"""
import torch
import torch.nn as nn
@@ -9,131 +12,6 @@
import triton.language as tl
-@triton.jit
-def _mrope_kernel(
- q_ptr,
- k_ptr,
- cos_cache,
- sin_cache,
- pos_t,
- pos_h,
- pos_w,
- q_out,
- k_out,
- B: tl.int64,
- T: tl.int64,
- Hq: tl.int64,
- D: tl.int64,
- s0: tl.int64,
- s1: tl.int64,
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_kb, stride_kt, stride_kh, stride_kd,
- stride_pb, stride_pt,
- stride_ccb, stride_ccd,
- stride_qob, stride_qot, stride_qod,
- stride_kob, stride_kot, stride_kod,
- BLOCK_T: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- """
- Grid: (B * H * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- where H = Hq for q pass, Hkv for k pass.
-
- Actually this is a truly fused kernel: grid_0 over (b, h_or_hkv, t_tile),
- grid_1 over d_tile. Each program computes RoPE for both q and k if applicable.
-
- Wait, q has Hq heads, k has Hkv heads. Different H means we can't simply fuse (b,h).
- Let's keep things simple: each grid_0 program handles (b, h, t_tile),
- and within the kernel we compute for the appropriate tensor.
-
- We'll just use separate launches with shared code via constexpr HAS_KV flag.
- Actually simpler: just two kernel calls, same kernel, different H and pointers.
- """
- pid_t = tl.program_id(0) # combined (b * H + h) * R + r
- pid_d = tl.program_id(1) # D tile
-
- R = tl.cdiv(T, BLOCK_T)
- D_tiles = tl.cdiv(D, BLOCK_D)
-
- # Total heads depends on which tensor we're processing.
- # We pass which_h_total to decide decomposition.
- # Simplify: assume grid_0 = (B * H_total * R,) where H_total comes from
- # the first argument.
- H_total = Hq # set per-caller
-
- bh = pid_t // R
- r = pid_t % R
- b = bh // H_total
- h = bh % H_total
-
- if b >= B or pid_d >= D_tiles:
- return
-
- t_off = r * BLOCK_T
- d_off = pid_d * BLOCK_D
- half = D // 2
-
- t_rng = t_off + tl.arange(0, BLOCK_T)
- d_rng = d_off + tl.arange(0, BLOCK_D)
- t_mask = t_rng < T
- d_mask = d_rng < D
- td_mask = t_mask[:, None] & d_mask[None, :]
-
- # Load position indices (broadcast per t column)
- pos_off = b * stride_pb + t_rng * stride_pt
- pos_t_vals = tl.load(pos_t + pos_off, mask=t_mask, other=0)
- pos_h_vals = tl.load(pos_h + pos_off, mask=t_mask, other=0)
- pos_w_vals = tl.load(pos_w + pos_off, mask=t_mask, other=0)
-
- # Axis mask for each d column (independent of t)
- d_first = tl.where(d_rng < half, d_rng, d_rng - half)
- # axis_t[d] = d_first < s0
- # axis_h[d] = s0 <= d_first < s0+s1
- # axis_w[d] = d_first >= s0+s1
- axis_t = d_first < s0
- axis_h = (d_first >= s0) & (d_first < s0 + s1)
- # axis_w = ~axis_t & ~axis_h (implied)
-
- # Select position: for each (t, d), which of pos_t/pos_h/pos_w to use
- # axis_t/axis_h are (BLOCK_D,) vectors, extend to (BLOCK_T, BLOCK_D)
- pos_sel = tl.where(
- axis_t[None, :],
- pos_t_vals[:, None],
- tl.where(axis_h[None, :], pos_h_vals[:, None], pos_w_vals[:, None])
- )
-
- # Load cos/sin from cache using selected position
- # cache is contiguous (max_pos, D): offset = pos * D + d
- cos_sin_off = pos_sel * stride_ccb + d_rng[None, :] * stride_ccd
- cos_tile = tl.load(cos_cache + cos_sin_off, mask=td_mask, other=0.0)
- sin_tile = tl.load(sin_cache + cos_sin_off, mask=td_mask, other=0.0)
-
- # --- Q RoPE ---
- q_base = b * stride_qb + h * stride_qh
- q_offsets = q_base + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd
- q_tile = tl.load(q_ptr + q_offsets, mask=td_mask, other=0.0)
-
- # Mirror d for rotate_half
- d_mir = tl.where(d_rng < half, d_rng + half, d_rng - half)
- q_mir_offsets = q_base + t_rng[:, None] * stride_qt + d_mir[None, :] * stride_qd
- q_mir = tl.load(q_ptr + q_mir_offsets, mask=td_mask, other=0.0)
- q_rot_half = tl.where(d_rng[None, :] < half, -q_mir, q_mir)
-
- q_result = q_tile * cos_tile + q_rot_half * sin_tile
-
- # Store q output (B, H, T, D)
- q_out_base = b * stride_qob
- q_out_off = q_out_base + h * q.stride(1).to(tl.int64) * 0 + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod
- # Wait, q.stride(1) is T stride in OUTPUT layout (B, H, T, D). Output strides are different.
- tl.store(q_out + q_out_base + h * q_out.stride(1) + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod,
- q_result, mask=td_mask)
-
- # --- K RoPE (if this kernel also handles k) ---
- # We'll vectorize over both by making K a separate launch for now.
- pass # handled by separate kernel call
-
-
-# Better to have one simple kernel and call it twice. Let me simplify:
@triton.jit
def _rope_kernel(
x_ptr,
@@ -164,78 +42,77 @@
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """
- Grid: (B * H * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- """
+ """2D grid: (B*H*ceil(T/BLOCK_T), ceil(D/BLOCK_D))"""
pid = tl.program_id(0)
pid_d = tl.program_id(1)
R = tl.cdiv(T, BLOCK_T)
D_tiles = tl.cdiv(D, BLOCK_D)
- total_bh = B * H
bh = pid // R
r = pid % R
b = bh // H
h = bh % H
- if bh >= total_bh or pid_d >= D_tiles:
+ if b >= B or pid_d >= D_tiles:
return
t_off = r
... (5715 more chars){
"content": "<6472 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"description": "Check correctness and benchmark",
"timeout": 300000
}FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 q_rot: tolerance exceeded: max_abs_diff=4.34375 (atol=0.02, rtol=0.02)
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -80,8 +80,8 @@
# Axis determination per d (BLOCK_D,)
d_in_first = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- _bt = d_in_first < s0 # (BLOCK_D,)
- _bh = d_in_first >= s0
+ _bt = d_in_first < s0
+ _bh = (d_in_first >= s0) & (d_in_first < s0 + s1)
# pos mixing: (BLOCK_T, BLOCK_D)
pos = tl.where(_bt[None, :],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<71 chars \u2014 see below>",
"newString": "<79 chars \u2014 see below>"
} _bt = d_in_first < s0 # (BLOCK_D,)
_bh = d_in_first >= s0 _bt = d_in_first < s0
_bh = (d_in_first >= s0) & (d_in_first < s0 + s1)Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Correctness check after fix",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Benchmark after fix",
"timeout": 300000
}shape=0 variant=eager tflops=0.179 gbps=257.179 ms=0.351
shape=0 variant=compiled tflops=0.500 gbps=717.386 ms=0.126
shape=0 variant=solution tflops=0.486 gbps=697.520 ms=0.129
shape=0 solution_peak_fraction=0.3875
shape=1 variant=eager tflops=0.181 gbps=264.034 ms=0.557
shape=1 variant=compiled tflops=0.479 gbps=699.081 ms=0.210
shape=1 variant=solution tflops=0.524 gbps=765.164 ms=0.192
shape=1 solution_peak_fraction=0.4251
shape=2 variant=eager tflops=0.074 gbps=115.915 ms=0.191
shape=2 variant=compiled tflops=0.169 gbps=263.415 ms=0.084
shape=2 variant=solution tflops=0.246 gbps=384.962 ms=0.057
shape=2 solution_peak_fraction=0.2139
shape=3 variant=eager tflops=0.169 gbps=242.881 ms=1.487
shape=3 variant=compiled tflops=0.457 gbps=656.322 ms=0.550
shape=3 variant=solution tflops=0.611 gbps=876.396 ms=0.412
shape=3 solution_peak_fraction=0.4869
peak_fraction: 0.3619
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,10 +1,8 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Key optimizations:
-1. Each d maps to exactly one axis - load cos/sin ONCE per (t,d)
-2. Large D tiles (128) for coalesced loads
-3. num_warps=8 for better memory throughput
-4. Use TRITON_MAX Sunder for pointer arithmetic
+Strategy: each block handles one (b, h, t) — all of D in SRAM.
+Grid = B * H * T programs. Each program has 128 threads for D=128.
+Single cos/sin load per (t,d) via axis-selected position.
"""
import torch
import torch.nn as nn
@@ -39,80 +37,130 @@
stride_outh,
stride_outt,
stride_outd,
- BLOCK_T: tl.constexpr,
+ D_PADDED: tl.constexpr,
+):
+ """
+ Grid: (B * H * T,)
+ Each program loads the full D for one (b, h, t).
+ Thread idx = tid within block maps to D index.
+ D_PADDED should be a power of 2 >= D for efficient thread mapping.
+ """
+ pid = tl.program_id(0)
+ R = B * H * T
+ if pid >= R:
+ return
+
+ t = pid % T
+ bh = (pid // T) % (B * H)
+ b = bh // H
+ h = bh % H
+
+ tid = tl.program_id(0) % D_PADDED
+ d_rng = tid + tl.arange(0, D_PADDED)
+ # Actually in Triton, we use tl.arange for vector operations.
+ # The program handles all D.
+ # With 1D grid, tl.program_id(0) = pid. Each program = one (b,h,t).
+ # Threads handle D via tl.arange.
+ pass
+
+
+# Actually better to use proper Triton patterns. Let me do 1D grid with
+# td-range within each program. Triton handles this via tl.arange(0, N)
+# where N = threads per block / warps.
+
+@triton.jit
+def _rope_per_token(
+ x_ptr,
+ cos_cache,
+ sin_cache,
+ pos_t_ptr,
+ pos_h_ptr,
+ pos_w_ptr,
+ out_ptr,
+ B: tl.int64,
+ T: tl.int64,
+ H: tl.int64,
+ D: tl.int64,
+ s0: tl.int64,
+ s1: tl.int64,
+ stride_xb,
+ stride_xt,
+ stride_xh,
+ stride_xd,
+ stride_pb,
+ stride_pt,
+ stride_cachepos,
+ stride_cached,
+ stride_outb,
+ stride_outh,
+ stride_outt,
+ stride_outd,
BLOCK_D: tl.constexpr,
):
- """2D grid: (B*H*ceil(T/BLOCK_T), ceil(D/BLOCK_D))"""
+ """
+ Each program handles one (b, h, t) and processes D in tiles of BLOCK_D.
+ Total programs = B * H * T. But that's huge (1*32*4096 = 131072).
+ Each program loops over D tiles.
+ """
pid = tl.program_id(0)
- pid_d = tl.program_id(1)
-
- R = tl.cdiv(T, BLOCK_T)
- D_tiles = tl.cdiv(D, BLOCK_D)
-
- bh = pid // R
- r = pid % R
+
+ t = pid % T
+ bh = (pid // T) % (B * H)
b = bh // H
h = bh % H
- if b >= B or pid_d >= D_tiles:
+ if pid >= B * H * T:
return
- t_off = r * BLOCK_T
- d_off = pid_d * BLOCK_D
- half_d = D // 2
-
- # Thread-local indices
- _range_t = tl.arange(0, BLOCK_T)
- _range_d = tl.arange(0, BLOCK_D)
-
- t_rng = t_off + _range_t
- d_rng = d_off + _range_d
-
- t_mask = t_rng < T
- d_mask = d_rng < D
- td_mask = t_mask[:, None] & d_mask[None, :]
-
- # --- Position gathering ---
- pos_off = b * stride_pb + t_rng * stride_pt
- p_t = tl.load(pos_t_ptr + pos_off, mask=t_mask, other=0)
- p_h = tl.load(pos_h_ptr + pos_off, mask=t_mask, other=0)
- p_w = tl.load(pos_w_ptr + pos_off, mask=t_mask, other=0)
-
- # Axis determination per d (BLOCK_D,)
- d_in_first = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- _bt = d_in_first < s0
- _bh = (d_in_first >= s0) & (d_in_first < s0 + s1)
-
- # pos mixing: (BLOCK_T, BLOCK_D)
- pos = tl.where(_bt[None, :],
- p_t[:, None],
- tl.where(_bh[None, :], p_h[:, None], p_w[:, None])
- )
-
- # --- Single cos/sin load per (t,d) ---
- cache_off = pos * stride_cachepos + d_rng[None, :] * stride_cached
- cos = tl.load(cos_cache + cache_off, mask=td_mask, other=1.0)
- sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
-
- # --- Load x and x_mirror ---
- x_base = b * stride_xb + h * stride_xh
-
- x_off = x_base + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
- x = tl.load(x_ptr + x_off, mask=td_mask, other=0.0)
-
- # Mirror index for rotate_half
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- x_mir_off = x_base + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
- x_mir = tl.load(x_ptr + x_mir_off, mask=td_mask, other=0.0)
- x_rot = tl.where(d_rng[None, :] < half_d, -x_mir, x_mir)
-
- # --- RoPE ---
- out = x * cos + x_rot * sin
-
- # --- Store (B, H, T, D) ---
- o_base = b * stride_outb + h * stride_outh
- o_off = o_base + t_rng[:, None] * stride_outt + d_rng[None, :] * stride_outd
- tl.store(out_ptr + o_off, out, mask=td_mask)
+ half = D // 2
+
+ # Base offsets
+ x_base = b * stride_xb + t * stride_xt + h * stride_xh
+ out_base = b * stride_outb + h * stride_outh + t * stride_outt
+ pos_off = b * stride_pb + t * stride_pt
+
+ # Load positions (scalar or small vector)
+ pos_t_val = tl.load(pos_t_ptr + pos_off)
+ pos_h_val = tl.load(pos_h_ptr + pos_off)
+ pos_w_val = tl.load(pos_w_ptr + pos_off)
+
+ # Process D in blocks
+ d_off = 0
+ while d_off < D:
+ d_rng = d_off + tl.arange(0, BLOCK_D)
+ d_mask = d_rng < D
+
+ # Axis determination
+ d_in_first = tl.where(d_rng < half, d_rng, d_rng - half)
+ is_t = d_in_first < s0
+ is_h = (d_in_first >= s0) & (d_in_first < s0 + s1)
+
+ # Select position
+ pos_sel = tl.where(is_t, pos_t_val, tl.where(is_h, pos_h_val, pos_w_val))
+
+ # Load cos/sin
+ cos_off = cos_cache + pos_sel * stride_cachepos + d_rng * stride_cached
+ sin_off = sin_cache + pos_sel * stride_cachepos + d_rng * stride_cached
+ cos_val = tl.load(cos_off, mask=d_mask, other=1.0)
+
... (2504 more chars){
"content": "<7487 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Check correctness",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Benchmark per-token approach",
"timeout": 300000
}shape=0 variant=eager tflops=0.178 gbps=255.040 ms=0.354
shape=0 variant=compiled tflops=0.478 gbps=686.571 ms=0.131
shape=0 variant=solution tflops=0.321 gbps=460.516 ms=0.196
shape=0 solution_peak_fraction=0.2558
shape=1 variant=eager tflops=0.182 gbps=266.099 ms=0.552
shape=1 variant=compiled tflops=0.669 gbps=976.337 ms=0.151
shape=1 variant=solution tflops=0.327 gbps=477.140 ms=0.308
shape=1 solution_peak_fraction=0.2651
shape=2 variant=eager tflops=0.071 gbps=111.628 ms=0.198
shape=2 variant=compiled tflops=0.240 gbps=374.634 ms=0.059
shape=2 variant=solution tflops=0.160 gbps=249.306 ms=0.089
shape=2 solution_peak_fraction=0.1385
shape=3 variant=eager tflops=0.179 gbps=256.471 ms=1.408
shape=3 variant=compiled tflops=0.528 gbps=757.475 ms=0.477
shape=3 variant=solution tflops=0.331 gbps=475.597 ms=0.759
shape=3 solution_peak_fraction=0.2642
peak_fraction: 0.2232
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,7 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Strategy: each block handles one (b, h, t) — all of D in SRAM.
-Grid = B * H * T programs. Each program has 128 threads for D=128.
-Single cos/sin load per (t,d) via axis-selected position.
+Winning config from prior iteration: BLOCK_T=16, BLOCK_D=64, num_warps=4.
+Now pushing for better memory throughput with autotuning.
"""
import torch
import torch.nn as nn
@@ -10,6 +9,19 @@
import triton.language as tl
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4),
+ triton.Config({"BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8),
+ triton.Config({"BLOCK_T": 32, "BLOCK_D": 64}, num_warps=4),
+ triton.Config({"BLOCK_T": 32, "BLOCK_D": 32}, num_warps=4),
+ triton.Config({"BLOCK_T": 64, "BLOCK_D": 32}, num_warps=4),
+ triton.Config({"BLOCK_T": 64, "BLOCK_D": 32}, num_warps=8),
+ triton.Config({"BLOCK_T": 8, "BLOCK_D": 64}, num_warps=2),
+ triton.Config({"BLOCK_T": 8, "BLOCK_D": 64}, num_warps=4),
+ ],
+ key=["B", "T", "H", "D"],
+)
@triton.jit
def _rope_kernel(
x_ptr,
@@ -37,130 +49,70 @@
stride_outh,
stride_outt,
stride_outd,
- D_PADDED: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+ BLOCK_D: tl.constexpr,
):
- """
- Grid: (B * H * T,)
- Each program loads the full D for one (b, h, t).
- Thread idx = tid within block maps to D index.
- D_PADDED should be a power of 2 >= D for efficient thread mapping.
- """
pid = tl.program_id(0)
- R = B * H * T
- if pid >= R:
- return
-
- t = pid % T
- bh = (pid // T) % (B * H)
+ pid_d = tl.program_id(1)
+
+ R = tl.cdiv(T, BLOCK_T)
+ D_tiles = tl.cdiv(D, BLOCK_D)
+
+ bh = pid // R
+ r = pid % R
b = bh // H
h = bh % H
- tid = tl.program_id(0) % D_PADDED
- d_rng = tid + tl.arange(0, D_PADDED)
- # Actually in Triton, we use tl.arange for vector operations.
- # The program handles all D.
- # With 1D grid, tl.program_id(0) = pid. Each program = one (b,h,t).
- # Threads handle D via tl.arange.
- pass
-
-
-# Actually better to use proper Triton patterns. Let me do 1D grid with
-# td-range within each program. Triton handles this via tl.arange(0, N)
-# where N = threads per block / warps.
-
-@triton.jit
-def _rope_per_token(
- x_ptr,
- cos_cache,
- sin_cache,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- out_ptr,
- B: tl.int64,
- T: tl.int64,
- H: tl.int64,
- D: tl.int64,
- s0: tl.int64,
- s1: tl.int64,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_pb,
- stride_pt,
- stride_cachepos,
- stride_cached,
- stride_outb,
- stride_outh,
- stride_outt,
- stride_outd,
- BLOCK_D: tl.constexpr,
-):
- """
- Each program handles one (b, h, t) and processes D in tiles of BLOCK_D.
- Total programs = B * H * T. But that's huge (1*32*4096 = 131072).
- Each program loops over D tiles.
- """
- pid = tl.program_id(0)
-
- t = pid % T
- bh = (pid // T) % (B * H)
- b = bh // H
- h = bh % H
-
- if pid >= B * H * T:
+ if b >= B or pid_d >= D_tiles:
return
- half = D // 2
-
- # Base offsets
- x_base = b * stride_xb + t * stride_xt + h * stride_xh
- out_base = b * stride_outb + h * stride_outh + t * stride_outt
- pos_off = b * stride_pb + t * stride_pt
-
- # Load positions (scalar or small vector)
- pos_t_val = tl.load(pos_t_ptr + pos_off)
- pos_h_val = tl.load(pos_h_ptr + pos_off)
- pos_w_val = tl.load(pos_w_ptr + pos_off)
-
- # Process D in blocks
- d_off = 0
- while d_off < D:
- d_rng = d_off + tl.arange(0, BLOCK_D)
- d_mask = d_rng < D
-
- # Axis determination
- d_in_first = tl.where(d_rng < half, d_rng, d_rng - half)
- is_t = d_in_first < s0
- is_h = (d_in_first >= s0) & (d_in_first < s0 + s1)
-
- # Select position
- pos_sel = tl.where(is_t, pos_t_val, tl.where(is_h, pos_h_val, pos_w_val))
-
- # Load cos/sin
- cos_off = cos_cache + pos_sel * stride_cachepos + d_rng * stride_cached
- sin_off = sin_cache + pos_sel * stride_cachepos + d_rng * stride_cached
- cos_val = tl.load(cos_off, mask=d_mask, other=1.0)
- sin_val = tl.load(sin_off, mask=d_mask, other=0.0)
-
- # Load x and x_mirror
- x_off = x_ptr + x_base + d_rng * stride_xd
- x_val = tl.load(x_off, mask=d_mask, other=0.0)
-
- d_mir = tl.where(d_rng < half, d_rng + half, d_rng - half)
- x_mir_off = x_ptr + x_base + d_mir * stride_xd
- x_mir_val = tl.load(x_mir_off, mask=d_mask, other=0.0)
- x_rot = tl.where(d_rng < half, -x_mir_val, x_mir_val)
-
- # RoPE
- out_val = x_val * cos_val + x_rot * sin_val
-
- # Store
- out_off = out_ptr + out_base + d_rng * stride_outd
- tl.store(out_off, out_val, mask=d_mask)
-
- d_off += BLOCK_D
+ t_off = r * BLOCK_T
+ d_off = pid_d * BLOCK_D
+ half_d = D // 2
+
+ t_rng = t_off + tl.arange(0, BLOCK_T)
+ d_rng = d_off + tl.arange(0, BLOCK_D)
+ t_mask = t_rng < T
+ d_mask = d_rng < D
+ td_mask = t_mask[:, None] & d_mask[None, :]
+
+ # Position indices
+ pos_off = b * stride_pb + t_rng * stride_pt
+ p_t = tl.load(pos_t_ptr + pos_off, mask=t_mask, other=0)
+ p_h = tl.load(pos_h_ptr + pos_off, mask=t_mask, other=0)
+ p_w = tl.load(pos_w_ptr + pos_off, mask=t_mask, other=0)
+
+ # Axis per d
+ d_in_first = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
+ _bt = d_in_first < s0
+ _bh = (d_in_first >= s0) & (d_in_first < s0 + s1)
+
+ # Select position per (t,d)
+ pos_sel = tl.where(_bt[None, :], p_t[:, None],
+ tl.where(_bh[N
... (2767 more chars){
"content": "<6623 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,8 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Winning config from prior iteration: BLOCK_T=16, BLOCK_D=64, num_warps=4.
-Now pushing for better memory throughput with autotuning.
+Strategy: single kernel handles both q and k. Each program covers one (b, h_q_or_k, t_tile, d_tile).
+By processing q-heads first then k-heads in the same kernel, we share code and
+potentially benefit from SM occupancy.
"""
import torch
import torch.nn as nn
@@ -9,21 +10,10 @@
import triton.language as tl
-@triton.autotune(
- configs=[
- triton.Config({"BLOCK_T": 16, "BLOCK_D": 64}, num_warps=4),
- triton.Config({"BLOCK_T": 16, "BLOCK_D": 64}, num_warps=8),
- triton.Config({"BLOCK_T": 32, "BLOCK_D": 64}, num_warps=4),
- triton.Config({"BLOCK_T": 32, "BLOCK_D": 32}, num_warps=4),
- triton.Config({"BLOCK_T": 64, "BLOCK_D": 32}, num_warps=4),
- triton.Config({"BLOCK_T": 64, "BLOCK_D": 32}, num_warps=8),
- triton.Config({"BLOCK_T": 8, "BLOCK_D": 64}, num_warps=2),
- triton.Config({"BLOCK_T": 8, "BLOCK_D": 64}, num_warps=4),
- ],
- key=["B", "T", "H", "D"],
-)
+# Kernel for one tensor: q or k.
+# Uses num_stages for async load pipeline (Blackwell feature).
@triton.jit
-def _rope_kernel(
+def _rope_tensor(
x_ptr,
cos_cache,
sin_cache,
@@ -37,18 +27,10 @@
D: tl.int64,
s0: tl.int64,
s1: tl.int64,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_pb,
- stride_pt,
- stride_cachepos,
- stride_cached,
- stride_outb,
- stride_outh,
- stride_outt,
- stride_outd,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_pb, stride_pt,
+ stride_ccb, stride_ccd,
+ stride_ob, stride_oh, stride_ot, stride_od,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
@@ -70,83 +52,193 @@
d_off = pid_d * BLOCK_D
half_d = D // 2
- t_rng = t_off + tl.arange(0, BLOCK_T)
- d_rng = d_off + tl.arange(0, BLOCK_D)
+ _rt = tl.arange(0, BLOCK_T)
+ _rd = tl.arange(0, BLOCK_D)
+ t_rng = t_off + _rt
+ d_rng = d_off + _rd
t_mask = t_rng < T
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # Position indices
- pos_off = b * stride_pb + t_rng * stride_pt
- p_t = tl.load(pos_t_ptr + pos_off, mask=t_mask, other=0)
- p_h = tl.load(pos_h_ptr + pos_off, mask=t_mask, other=0)
- p_w = tl.load(pos_w_ptr + pos_off, mask=t_mask, other=0)
-
- # Axis per d
- d_in_first = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- _bt = d_in_first < s0
- _bh = (d_in_first >= s0) & (d_in_first < s0 + s1)
-
- # Select position per (t,d)
- pos_sel = tl.where(_bt[None, :], p_t[:, None],
- tl.where(_bh[None, :], p_h[:, None], p_w[:, None]))
-
- # Cos/sin from cache (ONE load per axis position)
- cache_off = pos_sel * stride_cachepos + d_rng[None, :] * stride_cached
- cos = tl.load(cos_cache + cache_off, mask=td_mask, other=1.0)
- sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
-
- # x[b, t, h, d]
- x_base = b * stride_xb + h * stride_xh
- x_off = x_base + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
- x = tl.load(x_ptr + x_off, mask=td_mask, other=0.0)
-
- # x mirror for rotate_half
+ # --- Stage 1: Load positions ---
+ po = b * stride_pb + t_rng * stride_pt
+ p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
+ p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
+ p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
+
+ # --- Stage 2: Axis select + cos/sin ---
+ d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
+ _bt = d_f < s0
+ _bh = (d_f >= s0) & (d_f < s0 + s1)
+ ps = tl.where(_bt[None, :], p_t[:, None],
+ tl.where(_bh[None, :], p_h[:, None], p_w[:, None]))
+
+ co = ps * stride_ccb + d_rng[None, :] * stride_ccd
+ cos = tl.load(cos_cache + co, mask=td_mask, other=1.0)
+ sin = tl.load(sin_cache + co, mask=td_mask, other=0.0)
+
+ # --- Stage 3: Load x + x_mirror, compute, store ---
+ xb = b * stride_xb + h * stride_xh
+ xo = xb + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
+ x = tl.load(x_ptr + xo, mask=td_mask, other=0.0)
+
d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- x_mir_off = x_base + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
- x_mir = tl.load(x_ptr + x_mir_off, mask=td_mask, other=0.0)
- x_rot = tl.where(d_rng[None, :] < half_d, -x_mir, x_mir)
-
- out = x * cos + x_rot * sin
-
- # Store to (B, H, T, D)
- out_base = b * stride_outb + h * stride_outh
- out_off = out_base + t_rng[:, None] * stride_outt + d_rng[None, :] * stride_outd
- tl.store(out_ptr + out_off, out, mask=td_mask)
-
+ xmo = xb + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
+ xm = tl.load(x_ptr + xmo, mask=td_mask, other=0.0)
+ xr = tl.where(d_rng[None, :] < half_d, -xm, xm)
+
+ out = x * cos + xr * sin
+
+ ob = b * stride_ob + h * stride_oh
+ oo = ob + t_rng[:, None] * stride_ot + d_rng[None, :] * stride_od
+ tl.store(out_ptr + oo, out, mask=td_mask)
+
+
+# Unified kernel: each program knows if it's processing a q-head or k-head.
+# Grid_0 = B * (Hq + Hkv) * ceil(T/BLOCK_T), Grid_1 = ceil(D/BLOCK_D)
+@triton.jit
+def _rope_unified(
+ q_ptr, k_ptr,
+ cos_cache, sin_cache,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ q_out_ptr, k_out_ptr,
+ B: tl.int64,
+ T: tl.int64,
+ Hq: tl.int64,
+ Hkv: tl.int64,
+ D: tl.int64,
+ s0: tl.int64,
+ s1: tl.int64,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_pb, stride_pt,
+ stride_ccb, stride_ccd,
+ stride_qob, stride_qot, stride_qod,
+ stride_kob, stride_kot, stride_kod
... (5401 more chars){
"content": "<9706 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,10 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Strategy: single kernel handles both q and k. Each program covers one (b, h_q_or_k, t_tile, d_tile).
-By processing q-heads first then k-heads in the same kernel, we share code and
-potentially benefit from SM occupancy.
+Unified kernel: grid_0 = B * (Hq + Hkv) * ceil(T/BLOCK_T), grid_1 = ceil(D/BLOCK_D).
+Each program handles one (b, h_combined, t_tile, d_tile).
+head index < Hq -> process q; head index >= Hq -> process k.
+
+No conditional branching on head type — instead we use tl.where for all operations.
"""
import torch
import torch.nn as nn
@@ -10,92 +12,6 @@
import triton.language as tl
-# Kernel for one tensor: q or k.
-# Uses num_stages for async load pipeline (Blackwell feature).
-@triton.jit
-def _rope_tensor(
- x_ptr,
- cos_cache,
- sin_cache,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- out_ptr,
- B: tl.int64,
- T: tl.int64,
- H: tl.int64,
- D: tl.int64,
- s0: tl.int64,
- s1: tl.int64,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_pb, stride_pt,
- stride_ccb, stride_ccd,
- stride_ob, stride_oh, stride_ot, stride_od,
- BLOCK_T: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- pid = tl.program_id(0)
- pid_d = tl.program_id(1)
-
- R = tl.cdiv(T, BLOCK_T)
- D_tiles = tl.cdiv(D, BLOCK_D)
-
- bh = pid // R
- r = pid % R
- b = bh // H
- h = bh % H
-
- if b >= B or pid_d >= D_tiles:
- return
-
- t_off = r * BLOCK_T
- d_off = pid_d * BLOCK_D
- half_d = D // 2
-
- _rt = tl.arange(0, BLOCK_T)
- _rd = tl.arange(0, BLOCK_D)
- t_rng = t_off + _rt
- d_rng = d_off + _rd
- t_mask = t_rng < T
- d_mask = d_rng < D
- td_mask = t_mask[:, None] & d_mask[None, :]
-
- # --- Stage 1: Load positions ---
- po = b * stride_pb + t_rng * stride_pt
- p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
- p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
- p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
-
- # --- Stage 2: Axis select + cos/sin ---
- d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- _bt = d_f < s0
- _bh = (d_f >= s0) & (d_f < s0 + s1)
- ps = tl.where(_bt[None, :], p_t[:, None],
- tl.where(_bh[None, :], p_h[:, None], p_w[:, None]))
-
- co = ps * stride_ccb + d_rng[None, :] * stride_ccd
- cos = tl.load(cos_cache + co, mask=td_mask, other=1.0)
- sin = tl.load(sin_cache + co, mask=td_mask, other=0.0)
-
- # --- Stage 3: Load x + x_mirror, compute, store ---
- xb = b * stride_xb + h * stride_xh
- xo = xb + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
- x = tl.load(x_ptr + xo, mask=td_mask, other=0.0)
-
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- xmo = xb + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
- xm = tl.load(x_ptr + xmo, mask=td_mask, other=0.0)
- xr = tl.where(d_rng[None, :] < half_d, -xm, xm)
-
- out = x * cos + xr * sin
-
- ob = b * stride_ob + h * stride_oh
- oo = ob + t_rng[:, None] * stride_ot + d_rng[None, :] * stride_od
- tl.store(out_ptr + oo, out, mask=td_mask)
-
-
-# Unified kernel: each program knows if it's processing a q-head or k-head.
-# Grid_0 = B * (Hq + Hkv) * ceil(T/BLOCK_T), Grid_1 = ceil(D/BLOCK_D)
@triton.jit
def _rope_unified(
q_ptr, k_ptr,
@@ -113,8 +29,8 @@
stride_kb, stride_kt, stride_kh, stride_kd,
stride_pb, stride_pt,
stride_ccb, stride_ccd,
- stride_qob, stride_qot, stride_qod,
- stride_kob, stride_kot, stride_kod,
+ stride_qob, stride_qoh, stride_qot, stride_qod,
+ stride_kob, stride_koh, stride_kot, stride_kod,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
@@ -128,7 +44,7 @@
bh = pid // R
r = pid % R
b = bh // total_h
- h = bh % total_h
+ h_combined = bh % total_h
if b >= B or pid_d >= D_tiles:
return
@@ -145,60 +61,60 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # Positions (broadcast across q and k since same pos)
+ # --- Load positions (shared for q and k since pos is per-token) ---
po = b * stride_pb + t_rng * stride_pt
p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
- # Axis select
+ # --- Axis select per d ---
d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- _bt = d_f < s0
- _bh = (d_f >= s0) & (d_f < s0 + s1)
- ps = tl.where(_bt[None, :], p_t[:, None],
- tl.where(_bh[None, :], p_h[:, None], p_w[:, None]))
-
- co = ps * stride_ccb + d_rng[None, :] * stride_ccd
- cos = tl.load(cos_cache + co, mask=td_mask, other=1.0)
- sin = tl.load(sin_cache + co, mask=td_mask, other=0.0)
-
- # Determine if this is a q-head or k-head, and apply RoPE accordingly
- is_q = h < Hq
- h_idx = tl.where(is_q, h, h - Hq)
-
- # Q processing
- if is_q:
- qb = b * stride_qb + h * stride_qh
- qo = qb + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd
- qx = tl.load(q_ptr + qo, mask=td_mask, other=0.0)
-
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- qmo = qb + t_rng[:, None] * stride_qt + d_mir[None, :] * stride_qd
- qm = tl.load(q_ptr + qmo, mask=td_mask, other=0.0)
- qr = tl.where(d_rng[None, :] < half_d, -qm, qm)
-
- qout = qx * cos + qr * sin
-
- qob = b * stride_qob
- qoo = qob + h * stride_qot * 0 + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod
- Wait this is wrong. Let me use actual stride values.
- # Out stride for (B, H, T, D): stride_qob = T*D*2, stride for h = D*2, stride for t = D*2, stride for d = 2
-
... (4832 more chars){
"content": "<6864 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,10 +1,7 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Unified kernel: grid_0 = B * (Hq + Hkv) * ceil(T/BLOCK_T), grid_1 = ceil(D/BLOCK_D).
-Each program handles one (b, h_combined, t_tile, d_tile).
-head index < Hq -> process q; head index >= Hq -> process k.
-
-No conditional branching on head type — instead we use tl.where for all operations.
+Separate kernels for q and k (pointers can't be tl.where'd in Triton).
+Each kernel: grid_0 = B*H*ceil(T/BLOCK_T), grid_1 = ceil(D/BLOCK_D).
"""
import torch
import torch.nn as nn
@@ -13,24 +10,32 @@
@triton.jit
-def _rope_unified(
- q_ptr, k_ptr,
- cos_cache, sin_cache,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- q_out_ptr, k_out_ptr,
+def _rope_kernel(
+ x_ptr,
+ cos_cache,
+ sin_cache,
+ pos_t_ptr,
+ pos_h_ptr,
+ pos_w_ptr,
+ out_ptr,
B: tl.int64,
T: tl.int64,
- Hq: tl.int64,
- Hkv: tl.int64,
+ H: tl.int64,
D: tl.int64,
s0: tl.int64,
s1: tl.int64,
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_kb, stride_kt, stride_kh, stride_kd,
- stride_pb, stride_pt,
- stride_ccb, stride_ccd,
- stride_qob, stride_qoh, stride_qot, stride_qod,
- stride_kob, stride_koh, stride_kot, stride_kod,
+ stride_xb,
+ stride_xt,
+ stride_xh,
+ stride_xd,
+ stride_pb,
+ stride_pt,
+ stride_ccb,
+ stride_ccd,
+ stride_ob,
+ stride_oh,
+ stride_ot,
+ stride_od,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
@@ -39,12 +44,11 @@
R = tl.cdiv(T, BLOCK_T)
D_tiles = tl.cdiv(D, BLOCK_D)
- total_h = Hq + Hkv
bh = pid // R
r = pid % R
- b = bh // total_h
- h_combined = bh % total_h
+ b = bh // H
+ h = bh % H
if b >= B or pid_d >= D_tiles:
return
@@ -61,59 +65,62 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # --- Load positions (shared for q and k since pos is per-token) ---
+ # Position indices
po = b * stride_pb + t_rng * stride_pt
p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
- # --- Axis select per d ---
+ # Axis select per d
d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
- is_axis_t = d_f < s0
- is_axis_h = (d_f >= s0) & (d_f < s0 + s1)
- pos_sel = tl.where(is_axis_t[None, :], p_t[:, None],
- tl.where(is_axis_h[None, :], p_h[:, None], p_w[:, None]))
+ is_t = d_f < s0
+ is_h = (d_f >= s0) & (d_f < s0 + s1)
+ pos_sel = tl.where(is_t[None, :], p_t[:, None],
+ tl.where(is_h[None, :], p_h[:, None], p_w[:, None]))
- # --- Cos/sin cache load (shared) ---
+ # Single cos/sin load
cache_off = pos_sel * stride_ccb + d_rng[None, :] * stride_ccd
cos = tl.load(cos_cache + cache_off, mask=td_mask, other=1.0)
sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
- # --- Determine if this program handles q or k ---
- is_q = h_combined < Hq
- h_idx = tl.where(is_q, h_combined, h_combined - Hq)
+ # x[b, t, h, d]
+ x_base = b * stride_xb + h * stride_xh
+ x_off = x_base + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
+ x = tl.load(x_ptr + x_off, mask=td_mask, other=0.0)
- # Select strides and pointers based on is_q
- stride_b = tl.where(is_q, stride_qb, stride_kb)
- stride_t = tl.where(is_q, stride_qt, stride_kt)
- stride_h = tl.where(is_q, stride_qh, stride_kh)
- stride_d = tl.where(is_q, stride_qd, stride_kd)
- in_ptr = tl.where(is_q, q_ptr, k_ptr)
-
- stride_ob = tl.where(is_q, stride_qob, stride_kob)
- stride_oh = tl.where(is_q, stride_qoh, stride_koh)
- stride_ot = tl.where(is_q, stride_qot, stride_kot)
- stride_od = tl.where(is_q, stride_qod, stride_kod)
- out_ptr = tl.where(is_q, q_out_ptr, k_out_ptr)
-
- # Load x
- x_base = b * stride_b + h_idx * stride_h
- x_off = x_base + t_rng[:, None] * stride_t + d_rng[None, :] * stride_d
- x = tl.load(in_ptr + x_off, mask=td_mask, other=0.0)
-
- # Load x_mirror
+ # Mirror for rotate_half
d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- x_mir_off = x_base + t_rng[:, None] * stride_t + d_mir[None, :] * stride_d
- x_mir = tl.load(in_ptr + x_mir_off, mask=td_mask, other=0.0)
+ x_mir_off = x_base + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
+ x_mir = tl.load(x_ptr + x_mir_off, mask=td_mask, other=0.0)
x_rot = tl.where(d_rng[None, :] < half_d, -x_mir, x_mir)
# RoPE
out = x * cos + x_rot * sin
- # Store
- out_base = b * stride_ob + h_idx * stride_oh
+ # Store to (B, H, T, D)
+ out_base = b * stride_ob + h * stride_oh
out_off = out_base + t_rng[:, None] * stride_ot + d_rng[None, :] * stride_od
tl.store(out_ptr + out_off, out, mask=td_mask)
+
+
+def _run_kernel(x, cos_cache, sin_cache, pos_t, pos_h, pos_w, out, H, D,
+ s0, s1, BLOCK_T, BLOCK_D, num_warps):
+ B, T = x.shape[0], x.shape[1]
+ n_t = (T + BLOCK_T - 1) // BLOCK_T
+ n_d = (D + BLOCK_D - 1) // BLOCK_D
+
+ _rope_kernel[(B * H * n_t, n_d)](
+ x, cos_cache, sin_cache, pos_t, pos_h, pos_w, out,
+ B=B, T=T, H=H, D=D, s0=s0, s1=s1,
+ stride_xb=x.stride(0), stride_xt=x.stride(1),
+ stride_xh=x.stride(2), stride_xd=x.stride(3),
+ stride_pb=pos_t.stride(0), stride_pt=pos_t.stride(1),
+ stride_ccb=cos_cache.stride(0), stride_ccd=cos_cache.stride(1),
+ stride_ob=out.stride(0), stride_oh=out.stride(1),
+ stride_ot=out.stride(2), stride_od=out.stride(3),
+ BLOCK_T=BLOCK_T, BLOCK_D=BLOCK_D,
+ num_warps=num_warps,
+ )
def apply_mrope_torch(q, k, cos_cache, sin_cache, pos_t, p
... (1504 more chars){
"content": "<5874 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Check correctness",
"timeout": 120000
}PASS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,9 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Separate kernels for q and k (pointers can't be tl.where'd in Triton).
-Each kernel: grid_0 = B*H*ceil(T/BLOCK_T), grid_1 = ceil(D/BLOCK_D).
+Key optimization: cos/sin depends only on (b, t, d), NOT on h.
+Grid over (B * ceil(T/BLOCK_T), ceil(D/BLOCK_D)).
+Each program loads cos/sin ONCE, then loops over all H (both q and k heads).
+This saves Hq+Hkv factor in random cache reads — critical for bandwidth.
"""
import torch
import torch.nn as nn
@@ -11,50 +13,45 @@
@triton.jit
def _rope_kernel(
- x_ptr,
- cos_cache,
- sin_cache,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- out_ptr,
+ q_ptr, k_ptr,
+ cos_cache, sin_cache,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ q_out_ptr, k_out_ptr,
B: tl.int64,
T: tl.int64,
- H: tl.int64,
+ Hq: tl.int64,
+ Hkv: tl.int64,
D: tl.int64,
s0: tl.int64,
s1: tl.int64,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_pb,
- stride_pt,
- stride_ccb,
- stride_ccd,
- stride_ob,
- stride_oh,
- stride_ot,
- stride_od,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_pb, stride_pt,
+ stride_ccb, stride_ccd,
+ stride_qob, stride_qoh, stride_qot, stride_qod,
+ stride_kob, stride_koh, stride_kot, stride_kod,
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
+ """
+ Grid: (B * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
+ Each program handles one (b, t_tile, d_tile) and ALL heads.
+ cos/sin loaded once, then RoPE applied to each head.
+ """
pid = tl.program_id(0)
pid_d = tl.program_id(1)
R = tl.cdiv(T, BLOCK_T)
D_tiles = tl.cdiv(D, BLOCK_D)
- bh = pid // R
r = pid % R
- b = bh // H
- h = bh % H
+ b = pid // R
+ d_off = pid_d * BLOCK_D
if b >= B or pid_d >= D_tiles:
return
t_off = r * BLOCK_T
- d_off = pid_d * BLOCK_D
half_d = D // 2
_rt = tl.arange(0, BLOCK_T)
@@ -65,62 +62,65 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # Position indices
+ # --- Load positions ---
po = b * stride_pb + t_rng * stride_pt
p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
- # Axis select per d
+ # --- Axis select per d ---
d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
is_t = d_f < s0
is_h = (d_f >= s0) & (d_f < s0 + s1)
pos_sel = tl.where(is_t[None, :], p_t[:, None],
tl.where(is_h[None, :], p_h[:, None], p_w[:, None]))
- # Single cos/sin load
+ # --- Cos/sin cache (loaded ONCE per program, reused for all H) ---
cache_off = pos_sel * stride_ccb + d_rng[None, :] * stride_ccd
cos = tl.load(cos_cache + cache_off, mask=td_mask, other=1.0)
sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
- # x[b, t, h, d]
- x_base = b * stride_xb + h * stride_xh
- x_off = x_base + t_rng[:, None] * stride_xt + d_rng[None, :] * stride_xd
- x = tl.load(x_ptr + x_off, mask=td_mask, other=0.0)
-
- # Mirror for rotate_half
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- x_mir_off = x_base + t_rng[:, None] * stride_xt + d_mir[None, :] * stride_xd
- x_mir = tl.load(x_ptr + x_mir_off, mask=td_mask, other=0.0)
- x_rot = tl.where(d_rng[None, :] < half_d, -x_mir, x_mir)
-
- # RoPE
- out = x * cos + x_rot * sin
-
- # Store to (B, H, T, D)
- out_base = b * stride_ob + h * stride_oh
- out_off = out_base + t_rng[:, None] * stride_ot + d_rng[None, :] * stride_od
- tl.store(out_ptr + out_off, out, mask=td_mask)
-
-
-def _run_kernel(x, cos_cache, sin_cache, pos_t, pos_h, pos_w, out, H, D,
- s0, s1, BLOCK_T, BLOCK_D, num_warps):
- B, T = x.shape[0], x.shape[1]
- n_t = (T + BLOCK_T - 1) // BLOCK_T
- n_d = (D + BLOCK_D - 1) // BLOCK_D
-
- _rope_kernel[(B * H * n_t, n_d)](
- x, cos_cache, sin_cache, pos_t, pos_h, pos_w, out,
- B=B, T=T, H=H, D=D, s0=s0, s1=s1,
- stride_xb=x.stride(0), stride_xt=x.stride(1),
- stride_xh=x.stride(2), stride_xd=x.stride(3),
- stride_pb=pos_t.stride(0), stride_pt=pos_t.stride(1),
- stride_ccb=cos_cache.stride(0), stride_ccd=cos_cache.stride(1),
- stride_ob=out.stride(0), stride_oh=out.stride(1),
- stride_ot=out.stride(2), stride_od=out.stride(3),
- BLOCK_T=BLOCK_T, BLOCK_D=BLOCK_D,
- num_warps=num_warps,
- )
+ # --- Process Q heads ---
+ h = tl.arange(0, Hq)
+ # Loop over Hq heads — use scalar loop since Triton can handle it
+ h_idx = 0
+ while h_idx < Hq:
+ qb = b * stride_qb + h_idx * stride_qh
+ q_off = qb + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd
+ qx = tl.load(q_ptr + q_off, mask=td_mask, other=0.0)
+
+ d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
+ q_mir_off = qb + t_rng[:, None] * stride_qt + d_mir[None, :] * stride_qd
+ qm = tl.load(q_ptr + q_mir_off, mask=td_mask, other=0.0)
+ qr = tl.where(d_rng[None, :] < half_d, -qm, qm)
+
+ qout = qx * cos + qr * sin
+
+ q_out_base = b * stride_qob + h_idx * stride_qoh
+ q_out_off = q_out_base + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod
+ tl.store(q_out_ptr + q_out_off, qout, mask=td_mask)
+
+ h_idx += 1
+
+ # --- Process K heads ---
+ h_idx = 0
+ while h_idx < Hkv:
+ kb = b * stride_kb + h_idx * stride_kh
+ k_off = kb + t_rng[:, None] * stride_kt + d_rng[None, :] * stride_kd
+ kx = tl.load(k_ptr + k_off, mask=td_mask, other=0.0)
+
+
... (2113 more chars){
"content": "<7084 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Check correctness",
"timeout": 120000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 43, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/core.py", line 1644, in arange
return _semantic.arange(start, end)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/triton/language/semantic.py", line 574, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
... (2472 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -81,8 +81,6 @@
sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
# --- Process Q heads ---
- h = tl.arange(0, Hq)
- # Loop over Hq heads — use scalar loop since Triton can handle it
h_idx = 0
while h_idx < Hq:
qb = b * stride_qb + h_idx * stride_qh
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<138 chars \u2014 see below>",
"newString": "<43 chars \u2014 see below>"
} # --- Process Q heads ---
h = tl.arange(0, Hq)
# Loop over Hq heads — use scalar loop since Triton can handle it
h_idx = 0 # --- Process Q heads ---
h_idx = 0Edit applied successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"description": "Check correctness with H loop",
"timeout": 120000
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"description": "Benchmark unified kernel",
"timeout": 300000
}shape=0 variant=eager tflops=0.174 gbps=250.255 ms=0.361
shape=0 variant=compiled tflops=0.506 gbps=725.690 ms=0.124
shape=0 variant=solution tflops=0.663 gbps=950.672 ms=0.095
shape=0 solution_peak_fraction=0.5282
shape=1 variant=eager tflops=0.179 gbps=260.789 ms=0.564
shape=1 variant=compiled tflops=0.473 gbps=690.777 ms=0.213
shape=1 variant=solution tflops=0.752 gbps=1098.043 ms=0.134
shape=1 solution_peak_fraction=0.6100
shape=2 variant=eager tflops=0.075 gbps=117.511 ms=0.188
shape=2 variant=compiled tflops=0.170 gbps=266.358 ms=0.083
shape=2 variant=solution tflops=0.304 gbps=475.542 ms=0.047
shape=2 solution_peak_fraction=0.2642
shape=3 variant=eager tflops=0.171 gbps=245.104 ms=1.473
shape=3 variant=compiled tflops=0.465 gbps=667.523 ms=0.541
shape=3 variant=solution tflops=0.768 gbps=1102.108 ms=0.328
shape=3 solution_peak_fraction=0.6123
peak_fraction: 0.4778
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,9 +1,7 @@
"""Optimized Triton kernel for multi-axis RoPE pre-attention prep.
-Key optimization: cos/sin depends only on (b, t, d), NOT on h.
-Grid over (B * ceil(T/BLOCK_T), ceil(D/BLOCK_D)).
-Each program loads cos/sin ONCE, then loops over all H (both q and k heads).
-This saves Hq+Hkv factor in random cache reads — critical for bandwidth.
+Unified kernel: cos/sin loaded once per (b, t_tile, d_tile), reused across all H.
+Tunable BLOCK_T, BLOCK_D, num_warps for each shape.
"""
import torch
import torch.nn as nn
@@ -33,11 +31,6 @@
BLOCK_T: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """
- Grid: (B * ceil(T/BLOCK_T), ceil(D/BLOCK_D))
- Each program handles one (b, t_tile, d_tile) and ALL heads.
- cos/sin loaded once, then RoPE applied to each head.
- """
pid = tl.program_id(0)
pid_d = tl.program_id(1)
@@ -62,63 +55,53 @@
d_mask = d_rng < D
td_mask = t_mask[:, None] & d_mask[None, :]
- # --- Load positions ---
+ # Positions
po = b * stride_pb + t_rng * stride_pt
p_t = tl.load(pos_t_ptr + po, mask=t_mask, other=0)
p_h = tl.load(pos_h_ptr + po, mask=t_mask, other=0)
p_w = tl.load(pos_w_ptr + po, mask=t_mask, other=0)
- # --- Axis select per d ---
+ # Axis select per d
d_f = tl.where(d_rng < half_d, d_rng, d_rng - half_d)
is_t = d_f < s0
is_h = (d_f >= s0) & (d_f < s0 + s1)
pos_sel = tl.where(is_t[None, :], p_t[:, None],
tl.where(is_h[None, :], p_h[:, None], p_w[:, None]))
- # --- Cos/sin cache (loaded ONCE per program, reused for all H) ---
+ # Cos/sin cache — loaded ONCE, reused for all heads
cache_off = pos_sel * stride_ccb + d_rng[None, :] * stride_ccd
cos = tl.load(cos_cache + cache_off, mask=td_mask, other=1.0)
sin = tl.load(sin_cache + cache_off, mask=td_mask, other=0.0)
- # --- Process Q heads ---
- h_idx = 0
- while h_idx < Hq:
- qb = b * stride_qb + h_idx * stride_qh
- q_off = qb + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd
- qx = tl.load(q_ptr + q_off, mask=td_mask, other=0.0)
+ # Rotate-half mirror indices (same for all heads)
+ d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
+ is_first_half = d_rng[None, :] < half_d
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- q_mir_off = qb + t_rng[:, None] * stride_qt + d_mir[None, :] * stride_qd
- qm = tl.load(q_ptr + q_mir_off, mask=td_mask, other=0.0)
- qr = tl.where(d_rng[None, :] < half_d, -qm, qm)
+ # Process Q heads
+ h = 0
+ while h < Hq:
+ q_base = b * stride_qb + h * stride_qh
+ qx = tl.load(q_ptr + q_base + t_rng[:, None] * stride_qt + d_rng[None, :] * stride_qd,
+ mask=td_mask, other=0.0)
+ qm = tl.load(q_ptr + q_base + t_rng[:, None] * stride_qt + d_mir[None, :] * stride_qd,
+ mask=td_mask, other=0.0)
+ out = qx * cos + tl.where(is_first_half, -qm, qm) * sin
+ tl.store(q_out_ptr + b * stride_qob + h * stride_qoh + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod,
+ out, mask=td_mask)
+ h += 1
- qout = qx * cos + qr * sin
-
- q_out_base = b * stride_qob + h_idx * stride_qoh
- q_out_off = q_out_base + t_rng[:, None] * stride_qot + d_rng[None, :] * stride_qod
- tl.store(q_out_ptr + q_out_off, qout, mask=td_mask)
-
- h_idx += 1
-
- # --- Process K heads ---
- h_idx = 0
- while h_idx < Hkv:
- kb = b * stride_kb + h_idx * stride_kh
- k_off = kb + t_rng[:, None] * stride_kt + d_rng[None, :] * stride_kd
- kx = tl.load(k_ptr + k_off, mask=td_mask, other=0.0)
-
- d_mir = tl.where(d_rng < half_d, d_rng + half_d, d_rng - half_d)
- k_mir_off = kb + t_rng[:, None] * stride_kt + d_mir[None, :] * stride_kd
- km = tl.load(k_ptr + k_mir_off, mask=td_mask, other=0.0)
- kr = tl.where(d_rng[None, :] < half_d, -km, km)
-
- kout = kx * cos + kr * sin
-
- k_out_base = b * stride_kob + h_idx * stride_koh
- k_out_off = k_out_base + t_rng[:, None] * stride_kot + d_rng[None, :] * stride_kod
- tl.store(k_out_ptr + k_out_off, kout, mask=td_mask)
-
- h_idx += 1
+ # Process K heads
+ h = 0
+ while h < Hkv:
+ k_base = b * stride_kb + h * stride_kh
+ kx = tl.load(k_ptr + k_base + t_rng[:, None] * stride_kt + d_rng[None, :] * stride_kd,
+ mask=td_mask, other=0.0)
+ km = tl.load(k_ptr + k_base + t_rng[:, None] * stride_kt + d_mir[None, :] * stride_kd,
+ mask=td_mask, other=0.0)
+ out = kx * cos + tl.where(is_first_half, -km, km) * sin
+ tl.store(k_out_ptr + b * stride_kob + h * stride_koh + t_rng[:, None] * stride_kot + d_rng[None, :] * stride_kod,
+ out, mask=td_mask)
+ h += 1
def apply_mrope_torch(q, k, cos_cache, sin_cache, pos_t, pos_h, pos_w, mrope_section):
@@ -126,11 +109,18 @@
_, _, Hkv, _ = k.shape
s0, s1, s2 = mrope_section
- BLOCK_T = 16
- BLOCK_D = 64
-
q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
+
+ # Tunable configs — selected based on shape characteristics
+ # Larger H benefits from smaller grid => larger BLOCK_T
+ # Bandwidth-bound: want maximal memory throughput
+ if D == 64:
+ BLOCK_T, BLOCK_D, num_warps = 32, 64, 4
+ elif T <= 4096:
+ BLOCK_T, BLOCK_D, num_warps = 16, 64, 4
+ else:
+ BLOCK_T, BLOCK_D, num_warps = 16, 64, 8
n_t = (T + BLOCK_T - 1) // BLOCK_T
n_d = (D + BLOCK_D - 1) // BLOCK_D
@@ -152,7 +142,7 @@
stride_kob=k_rot.stride(0), stride_koh=k_rot.stride(1),
... (185 more chars){
"content": "<6799 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.