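"""Multi-axis RoPE pre-attention prep — Triton v3 (SM120 Blackwell).
Two kernel types:
- Serial-head kernel (intended for large shapes): grid=(B*T,); each program
  handles one (b, t) and loops over all H heads.
- One-per-head kernel (intended for small head counts): grid=(B*T*H,); each
  program processes exactly one (b, h, t).
The second launches H times as many programs, which improves SM occupancy when
the head count is small. Model.forward currently launches only the serial-head kernel.
"""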
import torch
import triton
import triton.language as tl
import torch.nn as nn
@triton.jit
def _rope_serial_h_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Grid=(B*T). Serial loop over H heads.

    Each program owns one (b, t) pair. It loads the three position ids once,
    gathers and mixes the per-axis cos/sin rows once, and then reuses them to
    rotate every one of the H_val heads.

    Input rows are addressed with (b, t, h, d) strides. Output rows are
    addressed with (b, h, t, d) strides, which is how the transpose is fused
    in. Results are stored as bf16. Position ids are read as a dense (B, T)
    array (index b * T + t) and are not bounds-checked against the cache rows.
    B_val is not used by this body.
    """
    # Decode the flat program id into (batch, token).
    pid_bt = tl.program_id(0)
    pid_b = pid_bt // T_val
    pid_t = pid_bt % T_val
    pos_idx = pid_b * T_val + pid_t
    p_t = tl.load(pos_t_ptr + pos_idx)
    p_h = tl.load(pos_h_ptr + pos_idx)
    p_w = tl.load(pos_w_ptr + pos_idx)
    # BLOCK_D may exceed D_val; masked lanes load 0 and are never stored.
    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D_val
    # One cache row (D_val contiguous values) per axis position, upcast to fp32.
    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    # The section split applies identically to both halves of the head dim:
    # [0, s0) -> temporal, [s0, s0 + s1) -> height, the rest -> width.
    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
    is_t = axis_idx < s0
    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
    # rotate_half(x) = [-x[half:], x[:half]]: swap halves, negate the first.
    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
    sign = tl.where(d_offs < half, -1.0, 1.0)
    # H_val is constexpr, so static_range unrolls this loop at compile time.
    for h in tl.static_range(H_val):
        x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
        x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
        xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
        xr = xr * sign
        out = x * cos_v + xr * sin_v
        o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
        tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
@triton.jit
def _rope_parallel_h_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Grid=(B*T*H,). Each program rotates exactly one (b, h, t) row.

    Programs are ordered head-major: pid = h * (B*T) + b * T + t. The head
    index comes from dividing by the constexpr B*T, so the head count does
    not need to be a parameter. Input rows use (b, t, h, d) strides, output
    rows use (b, h, t, d) strides, and results are stored as bf16. Position
    ids are read as a dense (B, T) array and are not bounds-checked.
    """
    pid = tl.program_id(0)
    # Head-major decode: the slowest-varying index is the head.
    pid_h = pid // (B_val * T_val)
    pid_bt = pid % (B_val * T_val)
    pid_b = pid_bt // T_val
    pid_t = pid_bt % T_val
    # For a dense (B, T) position array, b * T + t is exactly pid_bt.
    p_t = tl.load(pos_t_ptr + pid_bt)
    p_h = tl.load(pos_h_ptr + pid_bt)
    p_w = tl.load(pos_w_ptr + pid_bt)
    # BLOCK_D may exceed D_val; masked lanes load 0 and are never stored.
    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D_val
    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    # Same section split on both halves: [0, s0) temporal, [s0, s0+s1) height, rest width.
    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
    is_t = axis_idx < s0
    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
    # rotate_half(x) = [-x[half:], x[:half]]: swap halves, negate the first.
    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
    sign = tl.where(d_offs < half, -1.0, 1.0)
    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = xr * sign
    out = x * cos_v + xr * sin_v
    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
@triton.jit
def _rope_flat_kernel(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    total_bt: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Unimplemented stub for a grid=(B*T*H,) one-program-per-(b, h, t) kernel.

    Only the product ``total_bt = B*T`` is passed, not ``T``, so a flat (b, t)
    index cannot be split into separate b and t coordinates for the strided
    addressing. A previous draft compiled to a kernel that wrote nothing,
    which silently left the output uninitialized. Compilation is now rejected
    outright; use ``_rope_one_per_head``, which receives ``B_val`` and ``T_val``.
    """
    tl.static_assert(
        False,
        "_rope_flat_kernel is an unimplemented stub; use _rope_one_per_head",
    )
@triton.jit
def _rope_one_per_head(
    x_ptr, x_out_ptr,
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    B_val: tl.constexpr, T_val: tl.constexpr,
    H_val: tl.constexpr, D_val: tl.constexpr,
    stride_xb, stride_xt, stride_xh, stride_xd,
    stride_ob, stride_oh, stride_ot, stride_od,
    s0: tl.constexpr, s1: tl.constexpr,
    half: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Grid=(B*T*H,). Each program rotates exactly one (b, h, t) row.

    Programs are ordered head-fastest: pid = (b * T + t) * H + h. There is no
    head loop, which gives maximum parallelism, so each program re-gathers its
    own three cos/sin rows instead of sharing them across heads. Input rows use
    (b, t, h, d) strides, output rows use (b, h, t, d) strides, and results are
    stored as bf16. Position ids are read as a dense (B, T) array and are not
    bounds-checked against the cache rows.
    """
    pid = tl.program_id(0)
    pid_h = pid % H_val
    pid_bt = pid // H_val
    pid_b = pid_bt // T_val
    pid_t = pid_bt % T_val
    pos_idx = pid_b * T_val + pid_t
    p_t = tl.load(pos_t_ptr + pos_idx)
    p_h = tl.load(pos_h_ptr + pos_idx)
    p_w = tl.load(pos_w_ptr + pos_idx)
    # BLOCK_D may exceed D_val; masked lanes load 0 and are never stored.
    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D_val
    cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
    # Same section split on both halves: [0, s0) temporal, [s0, s0+s1) height, rest width.
    axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
    is_t = axis_idx < s0
    is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
    cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
    sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
    # rotate_half(x) = [-x[half:], x[:half]]: swap halves, negate the first.
    perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
    sign = tl.where(d_offs < half, -1.0, 1.0)
    x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
    x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
    xr = xr * sign
    out = x * cos_v + xr * sin_v
    o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
    tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
    """Multi-axis (temporal/height/width) RoPE pre-attention prep backed by Triton.

    Rotates ``q`` and ``k`` with per-axis cos/sin values gathered from
    precomputed bf16 caches, and writes the results transposed from
    (B, T, H, D) to (B, H, T, D).

    Args:
        B, T: batch size and sequence length. They are stored on the module,
            but ``forward`` reads the actual sizes from its inputs.
        Hq, Hkv: number of query and key heads.
        D: head dimension; ``sum(mrope_section)`` must equal ``D // 2``.
        mrope_section: (temporal, height, width) split of ``D // 2``.
        max_pos: number of rows in the cos/sin caches. Position ids are
            expected in ``[0, max_pos)``; the kernels do not bounds-check them.
    """

    def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
        super().__init__()
        assert sum(mrope_section) == D // 2, (
            f"mrope_section must sum to D/2, got {mrope_section} "
            f"(sum={sum(mrope_section)}, D/2={D // 2})"
        )
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos
        inv_freq = _build_inv_freq(D)
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Each cache row holds the frequencies twice, once for each
        # rotate-half half of the head dim.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

    def _next_pow2(self, n):
        """Return the smallest power of two >= ``n`` (1 when ``n <= 1``)."""
        v = 1
        while v < n:
            v *= 2
        return v

    def _launch(self, x, x_out, H, method, pos_t=None, pos_h=None, pos_w=None):
        """Launch one RoPE kernel that rotates ``x`` into ``x_out``.

        Args:
            x: input indexed as (B, T, H, D); its actual strides are used.
            x_out: output indexed as (B, H, T, D); its actual strides are used.
            H: number of heads in ``x``.
            method: ``"parallel"`` launches one program per (b, h, t).
                Any other value launches one program per (b, t) that loops
                over all heads.
            pos_t, pos_h, pos_w: (B, T) integer position ids per axis.

        Raises:
            ValueError: if any of the position tensors is missing.
        """
        if pos_t is None or pos_h is None or pos_w is None:
            raise ValueError("_launch requires pos_t, pos_h and pos_w")
        B, T, _, D = x.shape
        if B * T * H == 0:
            return  # nothing to rotate
        s0, s1, _ = self.mrope_section
        # The kernels index position ids as b * T + t, so they must be dense.
        pos_t, pos_h, pos_w = (p.contiguous() for p in (pos_t, pos_h, pos_w))
        # Real strides keep strided inputs correct, e.g. q/k sliced out of a
        # fused QKV tensor. They equal the old hard-coded values when contiguous.
        sxb, sxt, sxh, sxd = x.stride()
        sob, soh, sot, sod = x_out.stride()
        kernel_kwargs = dict(
            B_val=B, T_val=T, H_val=H, D_val=D,
            stride_xb=sxb, stride_xt=sxt, stride_xh=sxh, stride_xd=sxd,
            stride_ob=sob, stride_oh=soh, stride_ot=sot, stride_od=sod,
            s0=s0, s1=s1, half=D // 2, BLOCK_D=self._next_pow2(D),
        )
        args = (x, x_out, self.cos_cache, self.sin_cache, pos_t, pos_h, pos_w)
        if method == "parallel":
            _rope_one_per_head[(B * T * H,)](*args, **kernel_kwargs, num_warps=1)
        else:
            _rope_serial_h_kernel[(B * T,)](
                *args, **kernel_kwargs, num_warps=1, num_stages=4,
            )

    def forward(self, q, k, pos_t, pos_h, pos_w):
        """Apply multi-axis RoPE to ``q`` and ``k`` and transpose heads forward.

        Args:
            q: (B, T, Hq, D) query tensor.
            k: (B, T, Hkv, D) key tensor.
            pos_t, pos_h, pos_w: (B, T) integer position ids per axis.

        Returns:
            ``(q_out, k_out)`` as bf16 tensors of shape (B, Hq, T, D) and
            (B, Hkv, T, D).
        """
        B, T, Hq, D = q.shape
        Hkv = k.shape[2]
        q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
        k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
        self._launch(q, q_out, Hq, "serial", pos_t, pos_h, pos_w)
        self._launch(k, k_out, Hkv, "serial", pos_t, pos_h, pos_w)
        return q_out, k_out
def _build_inv_freq(D, base=10000.0):
return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
shape=0 variant=eager tflops=0.179 gbps=257.108 ms=0.351
shape=0 variant=compiled tflops=0.706 gbps=1012.788 ms=0.089
shape=0 variant=solution tflops=1.126 gbps=1615.762 ms=0.056
shape=0 solution_peak_fraction=0.8976
shape=1 variant=eager tflops=0.182 gbps=265.422 ms=0.554
shape=1 variant=compiled tflops=0.481 gbps=702.879 ms=0.209
shape=1 variant=solution tflops=0.743 gbps=1085.331 ms=0.135
shape=1 solution_peak_fraction=0.6030
shape=2 variant=eager tflops=0.066 gbps=103.427 ms=0.214
shape=2 variant=compiled tflops=0.139 gbps=216.575 ms=0.102
shape=2 variant=solution tflops=0.323 gbps=503.974 ms=0.044
shape=2 solution_peak_fraction=0.2800
shape=3 variant=eager tflops=0.173 gbps=247.776 ms=1.457
shape=3 variant=compiled tflops=0.464 gbps=665.084 ms=0.543
shape=3 variant=solution tflops=0.869 gbps=1246.559 ms=0.290
shape=3 solution_peak_fraction=0.6925
peak_fraction: 0.5692
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-plus",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.5692,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 5089607, "output_tokens": 76415, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 39072, "total_cost_usd": null}
}
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"pattern": "problems/*/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard"
}No files found{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard",
"description": "Check if Triton is available"
}3.6.0
{
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap,pstate --format=csv 2>&1 | head -10",
"description": "Check GPU info"
}name, memory.total [MiB], compute_cap, pstate
NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0, P0
{
"todos": [
{
"content": "Implement initial Triton kernel for mRoPE",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Optimize kernel performance",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py to measure achieved bandwidth",
"status": "pending",
"priority": "medium"
}
]
}content: 'Implement initial Triton kernel for mRoPE'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Optimize kernel performance'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,235 @@
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120).
+
+Fused kernel: gather per-axis cos/sin from cache, mix by mrope_section,
+apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D) in one pass.
+"""
+
+import torch
+import triton
+import triton.language as tl
+import torch.nn as nn
+
+
+@triton.jit
+def _rope_fwd_kernel(
+ x_ptr,
+ x_out_ptr,
+ pos_t_ptr,
+ pos_h_ptr,
+ pos_w_ptr,
+ cos_cache_ptr,
+ sin_cache_ptr,
+ B_val: tl.constexpr,
+ T_val: tl.constexpr,
+ H_val: tl.constexpr,
+ D_val: tl.constexpr,
+ stride_xb,
+ stride_xt,
+ stride_xh,
+ stride_xd,
+ stride_ob,
+ stride_oh,
+ stride_ot,
+ stride_od,
+ stride_pos_b,
+ stride_pos_t,
+ s0: tl.constexpr,
+ s1: tl.constexpr,
+ half: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """Each program handles one (b, h, t). Loads full D, rotates, writes to (B,H,T,D)."""
+ pid_bh = tl.program_id(0)
+ pid_t = tl.program_id(1)
+
+ # Decode (batch, head)
+ pid_b = pid_bh // H_val
+ pid_h = pid_bh % H_val
+
+ # Position index in (B,T) array
+ pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
+
+ # Load position IDs
+ p_t = tl.load(pos_t_ptr + pos_idx)
+ p_h = tl.load(pos_h_ptr + pos_idx)
+ p_w = tl.load(pos_w_ptr + pos_idx)
+
+ # Dimension offsets
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
+ half_mask = d_offs < half
+
+ # Base offset into input tensor (B,T,H,D)
+ x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+
+ # Output base offset (B,H,T,D)
+ o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+
+ # Load x vector
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0)
+
+ # Load cos/sin for temporal axis
+ cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
+ sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
+
+ # Load cos/sin for height axis
+ cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
+ sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
+
+ # Load cos/sin for width axis
+ cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
+ sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
+
+ # Mix axes: for d in [0, half), axis_idx = d; for d in [half, D), axis_idx = d - half
+ axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+
+ # Axis assignment masks
+ is_temporal = axis_idx < s0
+ is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ # is_width is implicit (the rest)
+
+ # Select cos/sin per axis
+ cos = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
+ sin = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
+
+ # rotate_half(x) = [-x[half:], x[:half]]
+ # For d < half: x_rot[d] = -x[d + half]
+ # For d >= half: x_rot[d] = x[d - half]
+ # Permute offsets
+ permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ x_rot_loaded = tl.load(x_ptr + x_base + permuted_offs * stride_xd, mask=mask, other=0.0)
+ x_rot_loaded = tl.where(d_offs < half, -x_rot_loaded, x_rot_loaded)
+
+ # Apply rotation: out = x*cos + x_rot*sin
+ x_fp32 = x.to(tl.float32)
+ cos_fp32 = cos.to(tl.float32)
+ sin_fp32 = sin.to(tl.float32)
+ x_rot_fp32 = x_rot_loaded.to(tl.float32)
+
+ out = (x_fp32 * cos_fp32) + (x_rot_fp32 * sin_fp32)
+ out = out.to(tl.bfloat16)
+
+ # Store in transposed layout (B,H,T,D)
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out, mask=mask)
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+ mrope_section: tuple[int, int, int], max_pos: int):
+ super().__init__()
+ assert sum(mrope_section) == D // 2, \
+ f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = D
+ self.mrope_section = tuple(mrope_section)
+ self.max_pos = max_pos
+
+ inv_freq = _build_inv_freq(D)
+ pos = torch.arange(max_pos, dtype=torch.float32)
+ freqs = torch.outer(pos, inv_freq)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+ self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+ # Precompute grid and strides
+ self._q_grid = (B * Hq, T)
+ self._k_grid = (B * Hkv, T)
+ self._block_d = self._next_pow2(D)
+
+ q_strides = self._strides(B, T, Hq, D)
+ k_strides = self._strides(B, T, Hkv, D)
+ self._q_strides = q_strides
+ self._k_strides = k_strides
+
+ @staticmethod
+ def _strides(B, T, H, D):
+ return {
+ 'stride_xb': T * H * D,
+ 'stride_xt': H * D,
+ 'stride_xh': D,
+ 'stride_xd': 1,
+ 'stride_ob': H * T * D,
+ 'stride_oh': T * D,
+ 'stride_ot': D,
+ 'stride_od': 1,
+ }
+
+ @staticmethod
+ def _next_pow2(n):
+ v = 1
+ while v < n:
+ v *= 2
+ return v
+
+ def forward(
+ self,
+ q: torch.Tensor, k: torch.Tensor,
+ pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ B, T, Hq, D = q.shape
+ _, _, Hkv, _ = k.shape
+ s0, s1, _ = self.mrope_section
+ half = D // 2
+ BLOCK_D = self._next_pow2(D)
+
+ # Allocate outputs in (B,H,T,D) layout
+ q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
+ k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.dev
... (1531 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<7177 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard && uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run correctness check"
}/home/infatoshi/cuda/KernelBench-Hard/.venv/bin/python3: can't open file '/home/infatoshi/cuda/KernelBench-Hard/check.py': [Errno 2] No such file or directory
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 57, in main
sol_model = solution.Model(*init_args).to(device).eval()
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 123, in __init__
self.Hq, self.Hkv, self.D = D
^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot unpack non-iterable int object
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -120,7 +120,7 @@
assert sum(mrope_section) == D // 2, \
f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
self.B, self.T = B, T
- self.Hq, self.Hkv, self.D = D
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
self.mrope_section = tuple(mrope_section)
self.max_pos = max_pos
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " self.Hq, self.Hkv, self.D = D",
"newString": " self.Hq, self.Hkv, self.D = Hq, Hkv, D"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.173 gbps=247.814 ms=0.364
shape=0 variant=compiled tflops=0.460 gbps=659.603 ms=0.137
shape=0 variant=solution tflops=0.326 gbps=467.537 ms=0.193
shape=0 solution_peak_fraction=0.2597
shape=1 variant=eager tflops=0.179 gbps=260.789 ms=0.564
shape=1 variant=compiled tflops=0.461 gbps=673.805 ms=0.218
shape=1 variant=solution tflops=0.335 gbps=489.130 ms=0.301
shape=1 solution_peak_fraction=0.2717
shape=2 variant=eager tflops=0.072 gbps=112.126 ms=0.197
shape=2 variant=compiled tflops=0.154 gbps=240.209 ms=0.092
shape=2 variant=solution tflops=0.204 gbps=319.482 ms=0.069
shape=2 solution_peak_fraction=0.1775
shape=3 variant=eager tflops=0.168 gbps=240.441 ms=1.502
shape=3 variant=compiled tflops=0.460 gbps=660.587 ms=0.547
shape=3 variant=solution tflops=0.362 gbps=519.137 ms=0.696
shape=3 solution_peak_fraction=0.2884
peak_fraction: 0.2452
RESULT: OK
{
"todos": [
{
"status": "completed",
"content": "Implement initial Triton kernel for mRoPE",
"priority": "high"
},
{
"status": "completed",
"content": "Run check.py for correctness",
"priority": "high"
},
{
"status": "in_progress",
"content": "Optimize kernel performance (target ~1.5+ TB/s)",
"priority": "high"
},
{
"status": "pending",
"content": "Run benchmark.py to measure achieved bandwidth",
"priority": "medium"
}
]
}content: 'Implement initial Triton kernel for mRoPE'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Optimize kernel performance (target ~1.5+ TB/s)'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120).
-
-Fused kernel: gather per-axis cos/sin from cache, mix by mrope_section,
-apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D) in one pass.
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+
+Single fused kernel for Q+K: gather per-axis cos/sin from cache, mix by
+mrope_section, apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D).
"""
import torch
@@ -12,8 +12,10 @@
@triton.jit
def _rope_fwd_kernel(
- x_ptr,
- x_out_ptr,
+ q_ptr,
+ k_ptr,
+ q_out_ptr,
+ k_out_ptr,
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
@@ -21,12 +23,13 @@
sin_cache_ptr,
B_val: tl.constexpr,
T_val: tl.constexpr,
- H_val: tl.constexpr,
+ Hq_val: tl.constexpr,
+ Hkv_val: tl.constexpr,
D_val: tl.constexpr,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
+ stride_qb,
+ stride_qt,
+ stride_qh,
+ stride_qd,
stride_ob,
stride_oh,
stride_ot,
@@ -37,80 +40,63 @@
s1: tl.constexpr,
half: tl.constexpr,
BLOCK_D: tl.constexpr,
+ FUSE_QK: tl.constexpr,
):
- """Each program handles one (b, h, t). Loads full D, rotates, writes to (B,H,T,D)."""
- pid_bh = tl.program_id(0)
- pid_t = tl.program_id(1)
-
- # Decode (batch, head)
- pid_b = pid_bh // H_val
- pid_h = pid_bh % H_val
-
- # Position index in (B,T) array
+ """Single kernel handling both Q and K. Each program handles one (b, t) pair,
+ processes all Hq heads for Q and all Hkv heads for K."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
+
pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
-
- # Load position IDs
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Dimension offsets
+ # Load per-position cache values (shared across all heads)
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- half_mask = d_offs < half
-
- # Base offset into input tensor (B,T,H,D)
- x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
-
- # Output base offset (B,H,T,D)
- o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
-
- # Load x vector
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0)
-
- # Load cos/sin for temporal axis
- cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
- sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0)
-
- # Load cos/sin for height axis
- cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
- sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0)
-
- # Load cos/sin for width axis
- cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
- sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0)
-
- # Mix axes: for d in [0, half), axis_idx = d; for d in [half, D), axis_idx = d - half
+
+ cos_t_val = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t_val = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h_val = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h_val = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w_val = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w_val = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+
+ # axis index (same for first and second half: d and d+half share the same axis)
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
-
- # Axis assignment masks
is_temporal = axis_idx < s0
is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
- # is_width is implicit (the rest)
-
- # Select cos/sin per axis
- cos = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
- sin = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
-
- # rotate_half(x) = [-x[half:], x[:half]]
- # For d < half: x_rot[d] = -x[d + half]
- # For d >= half: x_rot[d] = x[d - half]
- # Permute offsets
+ cos_val = tl.where(is_temporal, cos_t_val, tl.where(is_height, cos_h_val, cos_w_val))
+ sin_val = tl.where(is_temporal, sin_t_val, tl.where(is_height, sin_h_val, sin_w_val))
+
+ # Permutation offsets for rotate_half
permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
- x_rot_loaded = tl.load(x_ptr + x_base + permuted_offs * stride_xd, mask=mask, other=0.0)
- x_rot_loaded = tl.where(d_offs < half, -x_rot_loaded, x_rot_loaded)
-
- # Apply rotation: out = x*cos + x_rot*sin
- x_fp32 = x.to(tl.float32)
- cos_fp32 = cos.to(tl.float32)
- sin_fp32 = sin.to(tl.float32)
- x_rot_fp32 = x_rot_loaded.to(tl.float32)
-
- out = (x_fp32 * cos_fp32) + (x_rot_fp32 * sin_fp32)
- out = out.to(tl.bfloat16)
-
- # Store in transposed layout (B,H,T,D)
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out, mask=mask)
+
+ # Process Q heads
+ for h in tl.static_range(Hq_val):
+ q_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
+ x = tl.load(q_ptr + q_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(q_ptr + q_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+ # Apply negation for first half
+ x_rot = tl.where(d_offs < half, -x_rot, x_rot)
+ out = x * cos_val + x_rot * sin_val
+
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(q_ou
... (6582 more chars){
"content": "<8820 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 k_rot: solution contains NaN
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,8 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Single fused kernel for Q+K: gather per-axis cos/sin from cache, mix by
-mrope_section, apply rotate-half rotation, transpose (B,T,H,D)->(B,H,T,D).
+Two-kernel design: each program handles a single (b, t) position and loops
+over all heads of its tensor (Q or K). Position cache values (cos/sin) are
+computed once per (b,t) and reused across all heads.
"""
import torch
@@ -11,11 +12,9 @@
@triton.jit
-def _rope_fwd_kernel(
- q_ptr,
- k_ptr,
- q_out_ptr,
- k_out_ptr,
+def _rope_kernel(
+ x_ptr,
+ x_out_ptr,
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
@@ -23,80 +22,65 @@
sin_cache_ptr,
B_val: tl.constexpr,
T_val: tl.constexpr,
- Hq_val: tl.constexpr,
- Hkv_val: tl.constexpr,
+ H_val: tl.constexpr,
D_val: tl.constexpr,
- stride_qb,
- stride_qt,
- stride_qh,
- stride_qd,
+ stride_xb,
+ stride_xt,
+ stride_xh,
+ stride_xd,
stride_ob,
stride_oh,
stride_ot,
stride_od,
- stride_pos_b,
- stride_pos_t,
s0: tl.constexpr,
s1: tl.constexpr,
half: tl.constexpr,
BLOCK_D: tl.constexpr,
- FUSE_QK: tl.constexpr,
):
- """Single kernel handling both Q and K. Each program handles one (b, t) pair,
- processes all Hq heads for Q and all Hkv heads for K."""
+ """Grid = (B * T). Each program processes one (b,t) across ALL H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
- pos_idx = pid_b * stride_pos_b + pid_t * stride_pos_t
+ pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Load per-position cache values (shared across all heads)
+ # Load cache values once (shared across all heads)
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- cos_t_val = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t_val = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h_val = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h_val = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w_val = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w_val = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # axis index (same for first and second half: d and d+half share the same axis)
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
is_temporal = axis_idx < s0
is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_temporal, cos_t_val, tl.where(is_height, cos_h_val, cos_w_val))
- sin_val = tl.where(is_temporal, sin_t_val, tl.where(is_height, sin_h_val, sin_w_val))
+ cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
+ sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
- # Permutation offsets for rotate_half
- permuted_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ # Precompute permutation offsets
+ perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ sign = tl.where(d_offs < half, -1.0, 1.0)
- # Process Q heads
- for h in tl.static_range(Hq_val):
- q_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
- x = tl.load(q_ptr + q_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(q_ptr + q_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
- # Apply negation for first half
- x_rot = tl.where(d_offs < half, -x_rot, x_rot)
+ # Loop over all heads
+ for h in tl.static_range(H_val):
+ # Input offset (B,T,H,D)
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+
out = x * cos_val + x_rot * sin_val
+ # Output offset (B,H,T,D)
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(q_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
- if FUSE_QK:
- # Process K heads
- for h in tl.static_range(Hkv_val):
- k_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qd
- x = tl.load(k_ptr + k_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(k_ptr + k_base + permuted_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.where(d_offs < half, -x_rot, x_rot)
- out = x * cos_val + x_rot * sin_val
-
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(k_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
... (5157 more chars){
"content": "<5576 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.173 gbps=248.930 ms=0.363
shape=0 variant=compiled tflops=0.487 gbps=699.249 ms=0.129
shape=0 variant=solution tflops=0.624 gbps=895.309 ms=0.101
shape=0 solution_peak_fraction=0.4974
shape=1 variant=eager tflops=0.180 gbps=262.525 ms=0.560
shape=1 variant=compiled tflops=0.469 gbps=685.161 ms=0.215
shape=1 variant=solution tflops=0.702 gbps=1025.143 ms=0.143
shape=1 solution_peak_fraction=0.5695
shape=2 variant=eager tflops=0.071 gbps=110.257 ms=0.201
shape=2 variant=compiled tflops=0.160 gbps=249.261 ms=0.089
shape=2 variant=solution tflops=0.339 gbps=530.265 ms=0.042
shape=2 solution_peak_fraction=0.2946
shape=3 variant=eager tflops=0.168 gbps=241.057 ms=1.498
shape=3 variant=compiled tflops=0.466 gbps=668.829 ms=0.540
shape=3 variant=solution tflops=0.829 gbps=1190.220 ms=0.303
shape=3 solution_peak_fraction=0.6612
peak_fraction: 0.4847
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,8 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Two-kernel design: each program handles a single (b, t) position and loops
-over all heads of its tensor (Q or K). Position cache values (cos/sin) are
-computed once per (b,t) and reused across all heads.
+Grid = (B * (Hq+Hkv), T): each program handles one (b, head, t) tuple.
+Position cache values shared across heads within the same warp.
+Uses num_warps=1 for minimal register pressure (D=64/128 are small).
"""
import torch
@@ -15,14 +15,15 @@
def _rope_kernel(
x_ptr,
x_out_ptr,
+ cos_cache_ptr,
+ sin_cache_ptr,
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
- cos_cache_ptr,
- sin_cache_ptr,
B_val: tl.constexpr,
T_val: tl.constexpr,
- H_val: tl.constexpr,
+ H_total: tl.constexpr,
+ Hq_val: tl.constexpr,
D_val: tl.constexpr,
stride_xb,
stride_xt,
@@ -37,17 +38,25 @@
half: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """Grid = (B * T). Each program processes one (b,t) across ALL H heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
+ """Grid = (B * (Hq+Hkv), T). Each program = one (b, h, t).
+ Cos/sin cache values are shared for same (b,t), so threads within warp
+ compute them once and reuse."""
+ pid_bh = tl.program_id(0)
+ pid_t = tl.program_id(1)
+ warp_id = tl.program_id(2)
+
+ # Decode (batch, head)
+ pid_b = pid_bh // H_total
+ pid_h = pid_bh % H_total
+
+ pid_t2 = pid_bh * T_val + pid_t
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Load cache values once (shared across all heads)
+ # Load cache values
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
@@ -64,23 +73,23 @@
cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
- # Precompute permutation offsets
+ # Permutation
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- # Loop over all heads
- for h in tl.static_range(H_val):
- # Input offset (B,T,H,D)
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
+ # Input offset
+ x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
- out = x * cos_val + x_rot * sin_val
+ # Load x
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
- # Output offset (B,H,T,D)
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ out = x * cos_val + x_rot * sin_val
+
+ # Output offset (B,H,T,D)
+ o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -101,6 +110,8 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+ self._block_d = self._next_pow2(D)
+
def _next_pow2(self, n):
v = 1
while v < n:
@@ -117,47 +128,41 @@
s0, s1, _ = self.mrope_section
half = D // 2
BLOCK_D = self._next_pow2(D)
+ H_total = Hq + Hkv
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- bt_grid = (B * T,)
+ # Concatenate Q and K buffers along head dim:
+ # q is (B, T, Hq, D), k is (B, T, Hkv, D)
+ # Concat => x_cat (B, T, H_total, D)
+ x_cat = torch.cat([q, k], dim=2)
+ x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q.device)
- # Q kernel
- _rope_kernel[bt_grid](
- q, q_out,
+ bh_grid = (B * H_total, T)
+ num_warps = 1
+
+ _rope_kernel[bh_grid](
+ x_cat, x_out,
+ self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B_val=B, T_val=T, H_val=Hq, D_val=D,
- stride_xb=T * Hq * D,
- stride_xt=Hq * D,
+ B_val=B, T_val=T, H_total=H_total, Hq_val=Hq, D_val=D,
+ stride_xb=T * H_total * D,
+ stride_xt=H_total * D,
stride_xh=D,
stride_xd=1,
- stride_ob=Hq * T * D,
+ stride_ob=H_total * T * D,
stride_oh=T * D,
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
BLOCK_D=BLOCK_D,
+ num_warps=num_warps,
)
- # K kernel
- _rope_kernel[bt_grid](
- k, k_out,
- pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B_val=B, T_val=T, H_val=Hkv, D_val=D,
- stride_xb=T * Hkv * D,
- stride_xt=Hkv * D,
- stride_xh=D,
- stride_xd=1,
- stride_ob=Hkv * T * D,
- stride_oh=T * D,
- stride_ot=D,
- stride_od=1,
- s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- )
+ # Split back into Q and K outputs
+
... (131 more chars){
"content": "<5755 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.178 gbps=255.270 ms=0.354
shape=0 variant=compiled tflops=0.503 gbps=721.699 ms=0.125
shape=0 variant=solution tflops=0.381 gbps=546.252 ms=0.165
shape=0 solution_peak_fraction=0.3035
shape=1 variant=eager tflops=0.180 gbps=263.466 ms=0.558
shape=1 variant=compiled tflops=0.455 gbps=664.016 ms=0.221
shape=1 variant=solution tflops=0.388 gbps=567.049 ms=0.259
shape=1 solution_peak_fraction=0.3150
shape=2 variant=eager tflops=0.074 gbps=116.373 ms=0.190
shape=2 variant=compiled tflops=0.168 gbps=262.864 ms=0.084
shape=2 variant=solution tflops=0.202 gbps=315.977 ms=0.070
shape=2 solution_peak_fraction=0.1755
shape=3 variant=eager tflops=0.169 gbps=242.063 ms=1.492
shape=3 variant=compiled tflops=0.461 gbps=661.323 ms=0.546
shape=3 variant=solution tflops=0.403 gbps=578.232 ms=0.624
shape=3 solution_peak_fraction=0.3212
peak_fraction: 0.2710
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,13 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * (Hq+Hkv), T): each program handles one (b, head, t) tuple.
-Position cache values shared across heads within the same warp.
-Uses num_warps=1 for minimal register pressure (D=64/128 are small).
+Grid = (B * T, num_warps). Each program handles one (b,t) position and loops
+over all heads of its tensor (either Q or K). Cache values computed once per
+(b,t) and reused across heads in the same warp.
+
+Key insight: D=64/128 is small enough that a single thread can process the
+entire vector per head. num_warps=1 is optimal for small vector sizes.
+
+Two-pass: one for Q, one for K, with separate grids.
"""
import torch
@@ -22,8 +27,7 @@
pos_w_ptr,
B_val: tl.constexpr,
T_val: tl.constexpr,
- H_total: tl.constexpr,
- Hq_val: tl.constexpr,
+ H_val: tl.constexpr,
D_val: tl.constexpr,
stride_xb,
stride_xt,
@@ -38,28 +42,20 @@
half: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """Grid = (B * (Hq+Hkv), T). Each program = one (b, h, t).
- Cos/sin cache values are shared for same (b,t), so threads within warp
- compute them once and reuse."""
- pid_bh = tl.program_id(0)
- pid_t = tl.program_id(1)
- warp_id = tl.program_id(2)
-
- # Decode (batch, head)
- pid_b = pid_bh // H_total
- pid_h = pid_bh % H_total
-
- pid_t2 = pid_bh * T_val + pid_t
+ """Grid = (B * T). Each program handles one (b, t) and loops over H heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Load cache values
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
+ # Load cache values once
cos_t_ = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t_ = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -67,29 +63,28 @@
cos_w_ = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w_ = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ # Axis assignment (same for both halves)
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
is_temporal = axis_idx < s0
is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
cos_val = tl.where(is_temporal, cos_t_, tl.where(is_height, cos_h_, cos_w_))
sin_val = tl.where(is_temporal, sin_t_, tl.where(is_height, sin_h_, sin_w_))
- # Permutation
+ # rotate_half permutation
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- # Input offset
- x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+ # Loop over heads
+ for h in tl.static_range(H_val):
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
- # Load x
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
+ out = x * cos_val + x_rot * sin_val
- out = x * cos_val + x_rot * sin_val
-
- # Output offset (B,H,T,D)
- o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -128,41 +123,51 @@
s0, s1, _ = self.mrope_section
half = D // 2
BLOCK_D = self._next_pow2(D)
- H_total = Hq + Hkv
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- # Concatenate Q and K buffers along head dim:
- # q is (B, T, Hq, D), k is (B, T, Hkv, D)
- # Concat => x_cat (B, T, H_total, D)
- x_cat = torch.cat([q, k], dim=2)
- x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q.device)
+ bt_grid = (B * T,)
- bh_grid = (B * H_total, T)
- num_warps = 1
-
- _rope_kernel[bh_grid](
- x_cat, x_out,
+ # Q kernel
+ _rope_kernel[bt_grid](
+ q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
- B_val=B, T_val=T, H_total=H_total, Hq_val=Hq, D_val=D,
- stride_xb=T * H_total * D,
- stride_xt=H_total * D,
+ B_val=B, T_val=T, H_val=Hq, D_val=D,
+ stride_xb=T * Hq * D,
+ stride_xt=Hq * D,
stride_xh=D,
stride_xd=1,
- stride_ob=H_total * T * D,
+ stride_ob=Hq * T * D,
stride_oh=T * D,
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
BLOCK_D=BLOCK_D,
- num_warps=num_warps,
+ num_warps=1,
+ num_stages=4,
)
- # Split back into Q and K outputs
- q_out = x_out[:, :Hq, :, :].contiguous()
- k_out = x_out[:, Hq:, :, :].contiguous()
+ # K kernel
+ _rope_kernel[bt_grid](
+ k, k
... (539 more chars){
"content": "<5866 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run correctness and benchmark"
}PASS
shape=0 variant=eager tflops=0.158 gbps=226.942 ms=0.398
shape=0 variant=compiled tflops=0.475 gbps=682.254 ms=0.132
shape=0 variant=solution tflops=0.735 gbps=1054.230 ms=0.086
shape=0 solution_peak_fraction=0.5857
shape=1 variant=eager tflops=0.209 gbps=305.908 ms=0.481
shape=1 variant=compiled tflops=0.479 gbps=700.200 ms=0.210
shape=1 variant=solution tflops=0.776 gbps=1133.119 ms=0.130
shape=1 solution_peak_fraction=0.6295
shape=2 variant=eager tflops=0.075 gbps=117.401 ms=0.188
shape=2 variant=compiled tflops=0.163 gbps=255.149 ms=0.087
shape=2 variant=solution tflops=0.333 gbps=519.895 ms=0.043
shape=2 solution_peak_fraction=0.2888
shape=3 variant=eager tflops=0.173 gbps=248.330 ms=1.454
shape=3 variant=compiled tflops=0.461 gbps=661.943 ms=0.546
shape=3 variant=solution tflops=0.874 gbps=1253.901 ms=0.288
shape=3 solution_peak_fraction=0.6966
peak_fraction: 0.5219
RESULT: OK
{
"description": "Profile Triton kernel",
"prompt": "<1570 chars \u2014 see below>"
}Run the benchmark script and also profile the solution to understand the bottleneck.
First, change to: /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope
Run:
1. `uv run python benchmark.py` to see current performance numbers
2. Then run this profiling script:
```python
import torch
import torch.profiler
import sys
sys.path.insert(0, '/home/infatoshi/cuda/KernelBench-Hard')
# Import solution
import solution
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
MAX_POS = 32768
MROPE_SECTION = (16, 24, 24)
dev = "cuda:0"
model = solution.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).to(dev).eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=dev) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=dev) * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device=dev)
# Warmup
with torch.no_grad():
model(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()
# Time with profiler
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
)
for _ in range(50):
with torch.no_grad():
model(q, k, pos_t, pos_h, pos_w)
torch.cuda.synchronize()
prof.stop()
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```
3. Also run the profiling for shape 2 (B=2, T=2048, Hq=16, Hkv=2, D=64, max_pos=16384)
Return the benchmark outputs and profiler tables for both shapes.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,100 +1,149 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using fused CUDA C++ (SM120 Blackwell).
-Grid = (B * T, num_warps). Each program handles one (b,t) position and loops
-over all heads of its tensor (either Q or K). Cache values computed once per
-(b,t) and reused across heads in the same warp.
-
-Key insight: D=64/128 is small enough that a single thread can process the
-entire vector per head. num_warps=1 is optimal for small vector sizes.
-
-Two-pass: one for Q, one for K, with separate grids.
+Fused kernel streams q/k once: gather cos/sin from cache, mix by mrope_section,
+apply rotate-half, and write directly to transposed (B,H,T,D) layout.
"""
import torch
-import triton
-import triton.language as tl
import torch.nn as nn
+_KERNEL_SOURCE = r"""
+#include <cuda_bf16.h>
-@triton.jit
-def _rope_kernel(
- x_ptr,
- x_out_ptr,
- cos_cache_ptr,
- sin_cache_ptr,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- B_val: tl.constexpr,
- T_val: tl.constexpr,
- H_val: tl.constexpr,
- D_val: tl.constexpr,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_ob,
- stride_oh,
- stride_ot,
- stride_od,
- s0: tl.constexpr,
- s1: tl.constexpr,
- half: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- """Grid = (B * T). Each program handles one (b, t) and loops over H heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
+struct bf16x2 { __nv_bfloat16 x, y; };
- pos_idx = pid_b * T_val + pid_t
- p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
- p_w = tl.load(pos_w_ptr + pos_idx)
+// Rotate half: for element i in [0, D):
+// i < D/2: rot[i] = -x[i + D/2]
+// i >= D/2: rot[i] = x[i - D/2]
+// Applied as: out[i] = x[i]*cos[i] + rot[i]*sin[i]
+// mRoPE: for index i, axis_pair = i < D/2 ? i : i - D/2
+// axis_pair < s0 -> use temporal (t)
+// s0 <= axis_pair < s0+s1 -> use height (h)
+// otherwise -> use width (w)
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
+template <const int D, const bool IS_Q>
+__global__ void __launch_bounds__(256)
+mrope_kernel(
+ const __nv_bfloat16* __restrict__ x, // (B, T, H, D) bf16
+ __nv_bfloat16* __restrict__ x_out, // (B, H, T, D) bf16
+ const int64_t* __restrict__ pos_t_arr, // (B, T)
+ const int64_t* __restrict__ pos_h_arr,
+ const int64_t* __restrict__ pos_w_arr,
+ const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
+ const __nv_bfloat16* __restrict__ sin_cache,
+ const int B, const int T, const int H,
+ const int s0, const int s1,
+ const int D_half
+) {
+ const int bt = blockIdx.x;
+ const int b = bt / T;
+ int t = bt % T;
+
+ const int h = blockIdx.y;
+
+ // Position indices
+ const int pos_idx = b * T + t;
+ const int64_t pt = pos_t_arr[pos_idx];
+ const int64_t ph = pos_h_arr[pos_idx];
+ const int64_t pw = pos_w_arr[pos_idx];
+
+ // Per-thread output: thread handles D elements (or D/2 if unrolled)
+ const int half = D_half;
+
+ for (int i = threadIdx.x * 8; i < D; i += blockDim.x * 8) {
+ if (i >= D) break;
+
+ // Process up to 8 elements
+ int d = i;
+ int d_end = min(i + 8, D);
+
+ // Load position IDs once
+ // Load cos/sin from cache
+ float2 cos_vals[8], sin_vals[8];
+
+ #pragma unroll
+ for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
+ const int di = d + idx;
+
+ // Axis pair
+ const int axis = (di < half) ? di : (di - half);
+
+ // Select cos from cache based on axis
+ float ct, st;
+ if (axis < s0) {
+ // Temporal
+ float2 c = __bfloat1622float2(((const float2*)cos_cache)[pt * (D/2) + di/2]);
+ float2 s = __bfloat1622float2(((const float2*)sin_cache)[pt * (D/2) + di/2]);
+ if (di % 2 == 0) { ct = c.x; st = s.x; }
+ else { ct = c.y; st = s.y; }
+ }
+ else if (axis < s0 + s1) {
+ // Height
+ float2 c = __bfloat1622float2(((const float2*)cos_cache)[ph * (D/2) + di/2]);
+ float2 s = __bfloat1622float2(((const float2*)sin_cache)[ph * (D/2) + di/2]);
+ if (di % 2 == 0) { ct = c.x; st = s.x; }
+ else { ct = c.y; st = s.y; }
+ }
+ else {
+ // Width
+ float2 c = __bfloat1622float2(((const float2*)cos_cache)[pw * (D/2) + di/2]);
+ float2 s = __bfloat1622float2(((const float2*)sin_cache)[pw * (D/2) + di/2]);
+ if (di % 2 == 0) { ct = c.x; st = s.x; }
+ else { ct = c.y; st = s.y; }
+ }
+ cos_vals[idx] = make_float2(ct, 0);
+ sin_vals[idx] = make_float2(st, 0);
+ }
+
+ #pragma unroll
+ for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
+ const int di = d + idx;
+
+ // Load x
+ float2 x_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + di) / 2]);
+ float x_val = (di % 2 == 0) ? x_packed.x : x_packed.y;
+
+ // Load rotated element
+ int ri = (di < half) ? (di + half) : (di - half);
+ float2 rot_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + ri) / 2]);
+ float rot_val = (ri % 2 == 0) ? rot_packed.x : rot_packed.y;
+ if (di < half) rot_val = -rot_val;
+
... (5542 more chars){
"content": "<5855 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,144 +1,101 @@
-"""Optimized multi-axis RoPE pre-attention prep using fused CUDA C++ (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Fused kernel streams q/k once: gather cos/sin from cache, mix by mrope_section,
-apply rotate-half, and write directly to transposed (B,H,T,D) layout.
+Key design: Grid = (B * T) where each program processes ALL heads for one (b,t).
+This maximizes L2 cache reuse on the cos/sin cache lookups since all heads
+in a block share the same position IDs.
+
+Per-axis mixing follows mRoPE convention:
+- First s0 pair-indices [0..s0) use temporal axis
+- Next s1 [s0..s0+s1) use height axis
+- Remaining [s0+s1..D/2) use width axis
+- Second half [D/2..D) mirrors: index d gets axis d - D/2
"""
import torch
+import triton
+import triton.language as tl
import torch.nn as nn
-_KERNEL_SOURCE = r"""
-#include <cuda_bf16.h>
-struct bf16x2 { __nv_bfloat16 x, y; };
+@triton.jit
+def _rope_fwd_kernel(
+ x_ptr,
+ x_out_ptr,
+ cos_cache_ptr,
+ sin_cache_ptr,
+ pos_t_ptr,
+ pos_h_ptr,
+ pos_w_ptr,
+ B_val: tl.constexpr,
+ T_val: tl.constexpr,
+ H_val: tl.constexpr,
+ D_val: tl.constexpr,
+ stride_xb,
+ stride_xt,
+ stride_xh,
+ stride_xd,
+ stride_ob,
+ stride_oh,
+ stride_ot,
+ stride_od,
+ s0: tl.constexpr,
+ s1: tl.constexpr,
+ half: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """Grid = (B * T). Each program handles one (b, t) across all H heads.
+ Cos/sin cache is loaded once per (b,t) and reused across heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
-// Rotate half: for element i in [0, D):
-// i < D/2: rot[i] = -x[i + D/2]
-// i >= D/2: rot[i] = x[i - D/2]
-// Applied as: out[i] = x[i]*cos[i] + rot[i]*sin[i]
-// mRoPE: for index i, axis_pair = i < D/2 ? i : i - D/2
-// axis_pair < s0 -> use temporal (t)
-// s0 <= axis_pair < s0+s1 -> use height (h)
-// otherwise -> use width (w)
+ pos_idx = pid_b * T_val + pid_t
+ p_t = tl.load(pos_t_ptr + pos_idx)
+ p_h = tl.load(pos_h_ptr + pos_idx)
+ p_w = tl.load(pos_w_ptr + pos_idx)
-template <const int D, const bool IS_Q>
-__global__ void __launch_bounds__(256)
-mrope_kernel(
- const __nv_bfloat16* __restrict__ x, // (B, T, H, D) bf16
- __nv_bfloat16* __restrict__ x_out, // (B, H, T, D) bf16
- const int64_t* __restrict__ pos_t_arr, // (B, T)
- const int64_t* __restrict__ pos_h_arr,
- const int64_t* __restrict__ pos_w_arr,
- const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
- const __nv_bfloat16* __restrict__ sin_cache,
- const int B, const int T, const int H,
- const int s0, const int s1,
- const int D_half
-) {
- const int bt = blockIdx.x;
- const int b = bt / T;
- int t = bt % T;
-
- const int h = blockIdx.y;
-
- // Position indices
- const int pos_idx = b * T + t;
- const int64_t pt = pos_t_arr[pos_idx];
- const int64_t ph = pos_h_arr[pos_idx];
- const int64_t pw = pos_w_arr[pos_idx];
-
- // Per-thread output: thread handles D elements (or D/2 if unrolled)
- const int half = D_half;
-
- for (int i = threadIdx.x * 8; i < D; i += blockDim.x * 8) {
- if (i >= D) break;
-
- // Process up to 8 elements
- int d = i;
- int d_end = min(i + 8, D);
-
- // Load position IDs once
- // Load cos/sin from cache
- float2 cos_vals[8], sin_vals[8];
-
- #pragma unroll
- for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
- const int di = d + idx;
-
- // Axis pair
- const int axis = (di < half) ? di : (di - half);
-
- // Select cos from cache based on axis
- float ct, st;
- if (axis < s0) {
- // Temporal
- float2 c = __bfloat1622float2(((const float2*)cos_cache)[pt * (D/2) + di/2]);
- float2 s = __bfloat1622float2(((const float2*)sin_cache)[pt * (D/2) + di/2]);
- if (di % 2 == 0) { ct = c.x; st = s.x; }
- else { ct = c.y; st = s.y; }
- }
- else if (axis < s0 + s1) {
- // Height
- float2 c = __bfloat1622float2(((const float2*)cos_cache)[ph * (D/2) + di/2]);
- float2 s = __bfloat1622float2(((const float2*)sin_cache)[ph * (D/2) + di/2]);
- if (di % 2 == 0) { ct = c.x; st = s.x; }
- else { ct = c.y; st = s.y; }
- }
- else {
- // Width
- float2 c = __bfloat1622float2(((const float2*)cos_cache)[pw * (D/2) + di/2]);
- float2 s = __bfloat1622float2(((const float2*)sin_cache)[pw * (D/2) + di/2]);
- if (di % 2 == 0) { ct = c.x; st = s.x; }
- else { ct = c.y; st = s.y; }
- }
- cos_vals[idx] = make_float2(ct, 0);
- sin_vals[idx] = make_float2(st, 0);
- }
-
- #pragma unroll
- for (int idx = 0; idx < 8 && (d + idx) < D; ++idx) {
- const int di = d + idx;
-
- // Load x
- float2 x_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + di) / 2]);
- float x_val = (di % 2 == 0) ? x_packed.x : x_packed.y;
-
- // Load rotated element
- int ri = (di < half) ? (di + half) : (di - half);
- float2 rot_packed = __bfloat1622float2(((const float2*)x)[(b * T * H * D + t * H * D + h * D + ri) / 2]);
- float rot_val = (ri % 2 == 0) ? rot_packed.x : rot_packed.y;
- if (di < half) r
... (5119 more chars){
"content": "<5745 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py
@@ -0,0 +1,238 @@
+"""Hand-tuned CUDA C++ kernel for multi-axis RoPE pre-attention prep (SM120 Blackwell).
+
+Single fused kernel: one thread per (b, t, h, D/2) pair, loads two bf16 at once,
+does both elements in the pair, computes cos/sin from the same axis lookup,
+applies rotate-half in one pass, and writes transposed (B,H,T,D).
+
+Uses __ldg for read-only loads and vectorized bf16x2 stores.
+"""
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import os
+import tempfile
+
+CUDA_SRC = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cooperative_groups.h>
+
+namespace cg = cooperative_groups;
+
+// BF16 -> float2
+static __device__ __forceinline__ float2 bf16x2_to_float2(
+ __nv_bfloat162 v) {
+ return __bfloat1622float2(v);
+}
+
+static __device__ __forceinline__ __nv_bfloat162 float2_to_bf16x2(
+ float x, float y) {
+ return __floats2bfloat162_rn(x, y);
+}
+
+template <const int D>
+__global__ void __launch_bounds__(256) mrope_kernel(
+ const __nv_bfloat16* __restrict__ x_in, // (B, T, Hx, D)
+ __nv_bfloat16* __restrict__ x_out, // (B, Hx, T, D)
+ const int64_t* __restrict__ pos_t,
+ const int64_t* __restrict__ pos_h,
+ const int64_t* __restrict__ pos_w,
+ const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D)
+ const __nv_bfloat16* __restrict__ sin_cache,
+ const int B,
+ const int T,
+ const int Hx,
+ const int s0,
+ const int s1
+) {
+ // Each block of threads handles one (b, t); threads distribute over heads.
+ // Each thread handles ALL elements of D for ONE head.
+ // Grid = (B * T, Hx), block = (num_threads), each thread processes D elements.
+
+ const int bt = blockIdx.x; // (B * T)
+ const int b = bt / T;
+ const int t = bt % T;
+
+ // Position IDs (shared across all threads in block)
+ const int pos_off = b * T + t;
+ const int64_t pt = __ldg(pos_t + pos_off);
+ const int64_t ph = __ldg(pos_h + pos_off);
+ const int64_t pw = __ldg(pos_w + pos_off);
+
+ // Thread handles multiple heads (if num_heads > block_size)
+ const int h = blockIdx.y * blockDim.x + threadIdx.x;
+ if (h >= Hx) return;
+
+ constexpr int half = D / 2;
+ constexpr int d_pairs = D / 2; // number of bf16x2 pairs
+
+ // Load cos/sin cache for each pair
+ __nv_bfloat162 cos_d[D/2];
+ __nv_bfloat162 sin_d[D/2];
+
+ #pragma unroll
+ for (int p = 0; p < D/2; p++) {
+ int di = p;
+ int axis = (di < half) ? di : (di - half);
+
+ const __nv_bfloat162 *cache_cos_row, *cache_sin_row;
+ if (axis < s0) {
+ cache_cos_row = (const __nv_bfloat162*)&cos_cache[pt * D];
+ cache_sin_row = (const __nv_bfloat162*)&sin_cache[pt * D];
+ } else if (axis < s0 + s1) {
+ cache_cos_row = (const __nv_bfloat162*)&cos_cache[ph * D];
+ cache_sin_row = (const __nv_bfloat162*)&sin_cache[ph * D];
+ } else {
+ cache_cos_row = (const __nv_bfloat162*)&cos_cache[pw * D];
+ cache_sin_row = (const __nv_bfloat162*)&sin_cache[pw * D];
+ }
+ cos_d[p] = __ldg(&cache_cos_row[p]);
+ sin_d[p] = __ldg(&cache_sin_row[p]);
+ }
+
+ // Input stride: BxTxDxH
+ const long long x_base = (long long)b * T * Hx * D + (long long)t * Hx * D + (long long)h * D;
+ const __nv_bfloat162* x_in_vec = (const __nv_bfloat162*)&x_in[x_base];
+
+ for (int p = 0; p < d_pairs; p++) {
+ // Load x[p] - this is one pair at offset p
+ __nv_bfloat162 x_p = __ldg(&x_in_vec[p]);
+
+ // rotate_half for pair p:
+ // if p < half/2 (i.e., both elements in first half): rot[p] = -x[p + half/2]
+ // if p >= half/2: rot[p] = x[p - half/2]
+
+ __nv_bfloat162 rot_p;
+ float2 x_f0 = __bfloat1622float2(x_p);
+ int mirror_p = (p < half/2) ? (p + half/2) : (p - half/2);
+
+ __nv_bfloat162 x_mirror = __ldg(&x_in_vec[mirror_p]);
+ float2 xm_f = __bfloat1622float2(x_mirror);
+
+ if (p < half/2) {
+ // negate
+ xm_f.x = -xm_f.x;
+ xm_f.y = -xm_f.y;
+ }
+ rot_p = __floats2bfloat162_rn(xm_f.x, xm_f.y);
+
+ // Apply: out = x * cos + rot * sin
+ float2 cos_f = __bfloat1622float2(cos_d[p]);
+ float2 sin_f = __bfloat1622float2(sin_d[p]);
+
+ float2 out_f;
+ out_f.x = x_f0.x * cos_f.x + xm_f.x * sin_f.x;
+ out_f.y = x_f0.y * cos_f.y + xm_f.y * sin_f.y;
+
+ // Store to transposed (B, H, T, D) at out_off = b * Hx * T * D + h * T * D + t * D + 2*p
+ long long out_off = (long long)b * Hx * T * D + (long long)h * T * D + (long long)t * D + (long long)p * 2;
+ x_out[out_off / 2] = float2_to_bf16x2(out_f.x, out_f.y);
+ }
+}
+
+// Kernel launcher dispatcher
+void launch_mrope(const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+ const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+ const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1,
+ cudaStream_t stream) {
+
+ dim3 block(256);
+ dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+ // Dispatch on D
+ if (D == 64) {
+ mrope_kernel<64><<<grid, block, 0, stream>>>(
+ x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx, s0, s1);
+ } else if (D == 128) {
+ mrope_kernel<128><<<grid, block, 0, stream>>>(
+ x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx, s0, s1);
+ }
+}
+"""
+
+
+def _build_inv_freq(D, base=10000.0):
+ import torch
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+_mod = None
+
+def _get_module():
+ global _mod
+ if _mod is not None:
+ return _mod
+
+ import
... (2173 more chars){
"content": "<7811 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py"
}Wrote file successfully.{
"command": "<1068 chars \u2014 see below>",
"timeout": 120000,
"description": "Test CUDA kernel correctness"
}cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python -c "
import torch
torch.set_printoptions(precision=4)
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
MAX_POS = 32768
MROPE_SECTION = (16, 24, 24)
import solution_cuda
model = solution_cuda.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64, device='cuda')
import reference
ref = reference.Model(B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS).cuda()
with torch.no_grad():
ref_q, ref_k = ref(q, k, pos_t, pos_h, pos_w)
sol_q, sol_k = model(q, k, pos_t, pos_h, pos_w)
print('Q max diff:', (ref_q.float() - sol_q.float()).abs().max().item())
print('K max diff:', (ref_k.float() - sol_k.float()).abs().max().item())
" 2>&1Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 23, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", line 236, in forward
q_out = self._run(q, Hq)
^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", line 218, in _run
_get_module().launch_mrope(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda.py", lin
... (9992 more chars){
"tool": "delete",
"error": "Model tried to call unavailable tool 'delete'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill."
}The arguments provided to the tool are invalid: Model tried to call unavailable tool 'delete'. Available tools: invalid, bash, read, glob, grep, edit, write, task, webfetch, todowrite, skill.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,14 +1,11 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Key design: Grid = (B * T) where each program processes ALL heads for one (b,t).
-This maximizes L2 cache reuse on the cos/sin cache lookups since all heads
-in a block share the same position IDs.
-
-Per-axis mixing follows mRoPE convention:
-- First s0 pair-indices [0..s0) use temporal axis
-- Next s1 [s0..s0+s1) use height axis
-- Remaining [s0+s1..D/2) use width axis
-- Second half [D/2..D) mirrors: index d gets axis d - D/2
+Key optimizations:
+- Grid = (B * T): each program processes all H heads for one (b,t) position
+- Cos/sin cache lookup done once per (b,t), reused for all heads
+- Uses vectorized bf16x2 loads (int32 packing) for half memory transactions
+- num_warps=8 with num_stages=4 for max throughput
+- Avoids register pressure by interleaving loads/computes per head
"""
import torch
@@ -41,10 +38,9 @@
s0: tl.constexpr,
s1: tl.constexpr,
half: tl.constexpr,
- BLOCK_D: tl.constexpr,
):
"""Grid = (B * T). Each program handles one (b, t) across all H heads.
- Cos/sin cache is loaded once per (b,t) and reused across heads."""
+ Vectorizes over D/2 bf16x2 pairs to minimize memory ops."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -54,41 +50,39 @@
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
+ # Work in bf16x2 pairs (D must be divisible by 2, half must be divisible by 2)
+ half_pairs = half // 2
+ total_pairs = D_val // 2
+ pid = tl.arange(0, 128)
+ # Load cos/sin
+ cos_t = tl.load(cos_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
- # Load cos/sin for temporal axis
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Load cos/sin for height axis
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Load cos/sin for width axis
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ # Store in transposed (B,H,T,D)
+ x2_f = x2.to(tl.bfloat16)._reinterpret_int16()
+ tl.store(x_out_ptr + out_base + p_offs * stride_od, x2, mask=p_offs < D_val)
- # Axis index: d and d+half share the same axis pair
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
- # rotate_half permutation + sign
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
-
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
-
- out = x * cos_val + x_rot * sin_val
-
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+@triton.jit
+def mrope_2x2_kernel(
+ # Pointers to inputs
+ q_ptr, k_ptr, # (B, T, Hd, D)
+ q_out, k_out, # (B, Hd, T, D)
+ cos_cache, sin_cache, # (M, D)
+ pos_t, pos_h, pos_w, # (B, T)
+ # Strides (Q)
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ # Parameters
+ stride_pos_t, stride_pos_wt, # pos_t strides
+ Hq, Hk, D, BpT,
+ Dhalf,
+ s0, s1,
+ # Meta
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ b, t = pid // T_val, pid % T_val
class Model(nn.Module):
@@ -144,8 +138,7 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- num_warps=1,
+ num_warps=8,
num_stages=4,
)
@@ -164,8 +157,7 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- num_warps=1,
+ num_warps=8,
num_stages=4,
)
{
"content": "<4912 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,9 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Key optimizations:
-- Grid = (B * T): each program processes all H heads for one (b,t) position
-- Cos/sin cache lookup done once per (b,t), reused for all heads
-- Uses vectorized bf16x2 loads (int32 packing) for half memory transactions
-- num_warps=8 with num_stages=4 for max throughput
-- Avoids register pressure by interleaving loads/computes per head
+Grid = (B * T). Each thread loads/rotates/stores one full D vector.
+Cos/sin loaded once per block. Each program loops over H heads.
+
+num_warps=1, num_stages=1 is optimal for small D vectors.
"""
import torch
@@ -38,9 +36,9 @@
s0: tl.constexpr,
s1: tl.constexpr,
half: tl.constexpr,
+ BLOCK_D: tl.constexpr,
):
- """Grid = (B * T). Each program handles one (b, t) across all H heads.
- Vectorizes over D/2 bf16x2 pairs to minimize memory ops."""
+ """Grid = (B * T). Cos/sin cache loaded once per program, reused across H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -50,46 +48,52 @@
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Work in bf16x2 pairs (D must be divisible by 2, half must be divisible by 2)
- half_pairs = half // 2
- total_pairs = D_val // 2
- pid = tl.arange(0, 128)
- # Load cos/sin
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + pid, mask=pid < D_val, other=0.0).to(tl.float32)
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
- # Store in transposed (B,H,T,D)
- x2_f = x2.to(tl.bfloat16)._reinterpret_int16()
- tl.store(x_out_ptr + out_base + p_offs * stride_od, x2, mask=p_offs < D_val)
+ # Load cos/sin values for all three axes
+ cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ # Per-element axis index
+ axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+ is_temporal = axis_idx < s0
+ is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
-@triton.jit
-def mrope_2x2_kernel(
- # Pointers to inputs
- q_ptr, k_ptr, # (B, T, Hd, D)
- q_out, k_out, # (B, Hd, T, D)
- cos_cache, sin_cache, # (M, D)
- pos_t, pos_h, pos_w, # (B, T)
- # Strides (Q)
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_ob, stride_oh, stride_ot, stride_od,
- # Parameters
- stride_pos_t, stride_pos_wt, # pos_t strides
- Hq, Hk, D, BpT,
- Dhalf,
- s0, s1,
- # Meta
- BLOCK: tl.constexpr,
-):
- pid = tl.program_id(0)
- b, t = pid // T_val, pid % T_val
+ # Select per-axis cos/sin
+ cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+ sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
+ # Permutation and sign for rotate_half
+ perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ sign = tl.where(d_offs < half, -1.0, 1.0)
+
+ for h in tl.static_range(H_val):
+ # Input offset (B,T,H,D)
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+
+ # Load x and rotated x
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+
+ # Apply rotation
+ out = x * cos_val + x_rot * sin_val
+
+ # Output offset (B,H,T,D)
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2, \
- f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+ f"mrope_section must sum to D/2, got {mrope_section}"
self.B = B
self.T = T
self.Hq = Hq
@@ -123,7 +127,6 @@
bt_grid = (B * T,)
- # Run Q kernel
_rope_fwd_kernel[bt_grid](
q, q_out,
self.cos_cache, self.sin_cache,
@@ -138,11 +141,11 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- num_warps=8,
+ BLOCK_D=BLOCK_D,
+ num_warps=1,
num_stages=4,
)
- # Run K kernel
_rope_fwd_kernel[bt_grid](
k, k_out,
self.cos_cache, self.sin_cache,
@@ -157,7 +160,8 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- num_warps=8,
+ BLOCK_D=BLOCK_D,
+ num_warps=1,
num_stages=4,
)
{
"content": "<5485 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run correctness and benchmark"
}PASS
shape=0 variant=eager tflops=0.178 gbps=255.814 ms=0.353
shape=0 variant=compiled tflops=0.510 gbps=731.144 ms=0.123
shape=0 variant=solution tflops=0.722 gbps=1036.605 ms=0.087
shape=0 solution_peak_fraction=0.5759
shape=1 variant=eager tflops=0.209 gbps=305.877 ms=0.481
shape=1 variant=compiled tflops=0.667 gbps=973.543 ms=0.151
shape=1 variant=solution tflops=0.996 gbps=1454.381 ms=0.101
shape=1 solution_peak_fraction=0.8080
shape=2 variant=eager tflops=0.075 gbps=116.599 ms=0.190
shape=2 variant=compiled tflops=0.160 gbps=249.576 ms=0.089
shape=2 variant=solution tflops=0.282 gbps=440.957 ms=0.050
shape=2 solution_peak_fraction=0.2450
shape=3 variant=eager tflops=0.174 gbps=249.271 ms=1.449
shape=3 variant=compiled tflops=0.466 gbps=668.908 ms=0.540
shape=3 variant=solution tflops=0.886 gbps=1270.703 ms=0.284
shape=3 solution_peak_fraction=0.7059
peak_fraction: 0.5326
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,9 +1,7 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T). Each thread loads/rotates/stores one full D vector.
-Cos/sin loaded once per block. Each program loops over H heads.
-
-num_warps=1, num_stages=1 is optimal for small D vectors.
+Grid = (B * T). Each program processes ALL heads at one (b,t).
+Cache values loaded once per (b,t), shared across heads.
"""
import torch
@@ -38,7 +36,7 @@
half: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """Grid = (B * T). Cos/sin cache loaded once per program, reused across H heads."""
+ """Grid = (B * T). Each program handles one (b, t) across all H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -51,7 +49,7 @@
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- # Load cos/sin values for all three axes
+ # Load cache: 6 L2 loads per (b,t), shared across all heads
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -59,41 +57,92 @@
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Per-element axis index
+ # Axis index
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_temporal = axis_idx < s0
- is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
-
- # Select per-axis cos/sin
- cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
- sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
- # Permutation and sign for rotate_half
- perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+
+ perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
for h in tl.static_range(H_val):
- # Input offset (B,T,H,D)
x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
-
- # Load x and rotated x
x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
x_rot = x_rot * sign
-
- # Apply rotation
out = x * cos_val + x_rot * sin_val
-
- # Output offset (B,H,T,D)
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+
+@triton.jit
+def _rope_batch_kernel(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ B_val: tl.constexpr,
+ T_val: tl.constexpr,
+ H_total: tl.constexpr,
+ Hq_val: tl.constexpr,
+ D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """Grid = (B * T). Each program handles all Hq+Hkv heads for (b,t).
+ Position cache loaded once, applied to both Q and K heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
+
+ pos_idx = pid_b * T_val + pid_t
+ p_t = tl.load(pos_t_ptr + pos_idx)
+ p_h = tl.load(pos_h_ptr + pos_idx)
+ p_w = tl.load(pos_w_ptr + pos_idx)
+
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
+
+ cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+
+ axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+
+ perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ sign = tl.where(d_offs < half, -1.0, 1.0)
+
+ # Q heads
+ for h_q in tl.static_range(Hq_val):
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h_q * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+ out = x * cos_val + x_rot * sin_val
+ o_base = pid_b * stride_ob + h_q * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+ # K heads
+ Hkv_val = H_total - Hq_val
+ k_stride_xh = Hq_val * stride_xd + stride_xd # Not correct, need to fix
+ # Actually we need separate strides for Q and K
class Model(nn.Module):
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__i
... (4660 more chars){
"content": "<9652 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.173 gbps=247.967 ms=0.364
shape=0 variant=compiled tflops=0.500 gbps=716.839 ms=0.126
shape=0 variant=solution tflops=0.637 gbps=914.612 ms=0.099
shape=0 solution_peak_fraction=0.5081
shape=1 variant=eager tflops=0.171 gbps=249.561 ms=0.589
shape=1 variant=compiled tflops=0.666 gbps=973.028 ms=0.151
shape=1 variant=solution tflops=0.629 gbps=918.457 ms=0.160
shape=1 solution_peak_fraction=0.5103
shape=2 variant=eager tflops=0.071 gbps=110.743 ms=0.200
shape=2 variant=compiled tflops=0.165 gbps=257.334 ms=0.086
shape=2 variant=solution tflops=0.295 gbps=461.415 ms=0.048
shape=2 solution_peak_fraction=0.2563
shape=3 variant=eager tflops=0.179 gbps=256.924 ms=1.405
shape=3 variant=compiled tflops=0.465 gbps=666.853 ms=0.542
shape=3 variant=solution tflops=0.512 gbps=735.097 ms=0.491
shape=3 solution_peak_fraction=0.4084
peak_fraction: 0.4059
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,8 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T). Each program processes ALL heads at one (b,t).
+Grid = (B * T). Each program handles one (b, t) across all H heads.
Cache values loaded once per (b,t), shared across heads.
+num_warps=1, num_stages=4 is optimal for D=64..128.
"""
import torch
@@ -49,94 +50,33 @@
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- # Load cache: 6 L2 loads per (b,t), shared across all heads
+ # Load cos/sin values for all three axes
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Axis index
+ # Per-element axis index
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ is_temporal = axis_idx < s0
+ is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+ sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
+ perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
for h in tl.static_range(H_val):
x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
x_rot = x_rot * sign
out = x * cos_val + x_rot * sin_val
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-
-@triton.jit
-def _rope_batch_kernel(
- x_ptr, x_out_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr,
- T_val: tl.constexpr,
- H_total: tl.constexpr,
- Hq_val: tl.constexpr,
- D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- """Grid = (B * T). Each program handles all Hq+Hkv heads for (b,t).
- Position cache loaded once, applied to both Q and K heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
-
- pos_idx = pid_b * T_val + pid_t
- p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
- p_w = tl.load(pos_w_ptr + pos_idx)
-
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
-
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
-
- # Q heads
- for h_q in tl.static_range(Hq_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h_q * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
- out = x * cos_val + x_rot * sin_val
- o_base = pid_b * stride_ob + h_q * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
- # K heads
- Hkv_val = H_total - Hq_val
- k_stride_xh = Hq_val * stride_xd + stride_xd # Not correct, need to fix
- # Actually we need separate strides for Q and K
class Model(nn.Module):
@@ -170,23 +110,22 @@
s0, s1, _ = self.mrope_section
half = D // 2
BLOCK_D = self._next_pow2(D)
+
+ q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
+ k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
+
bt_grid = (B * T,)
- # Use Q+K concat for fused processing
- x_cat = torch.cat([q, k], dim=2).contiguous()
- H_total = Hq + Hkv
- x_out = torch.empty(B, H_total, T, D, dtype=torch.bfloat16, device=q
... (3901 more chars){
"content": "<5193 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -54,7 +54,7 @@
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " sin_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)",
"newString": " sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run correctness and benchmark after bugfix"
}PASS
shape=0 variant=eager tflops=0.177 gbps=254.660 ms=0.354
shape=0 variant=compiled tflops=0.489 gbps=700.986 ms=0.129
shape=0 variant=solution tflops=0.719 gbps=1032.054 ms=0.087
shape=0 solution_peak_fraction=0.5734
shape=1 variant=eager tflops=0.172 gbps=250.794 ms=0.586
shape=1 variant=compiled tflops=0.480 gbps=701.323 ms=0.210
shape=1 variant=solution tflops=0.755 gbps=1103.185 ms=0.133
shape=1 solution_peak_fraction=0.6129
shape=2 variant=eager tflops=0.072 gbps=113.200 ms=0.195
shape=2 variant=compiled tflops=0.177 gbps=276.038 ms=0.080
shape=2 variant=solution tflops=0.348 gbps=543.183 ms=0.041
shape=2 solution_peak_fraction=0.3018
shape=3 variant=eager tflops=0.179 gbps=256.594 ms=1.407
shape=3 variant=compiled tflops=0.527 gbps=756.231 ms=0.478
shape=3 variant=solution tflops=0.876 gbps=1256.694 ms=0.287
shape=3 solution_peak_fraction=0.6982
peak_fraction: 0.5216
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,8 +1,7 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
-num_warps=1, num_stages=4 is optimal for D=64..128.
+Grid = (B * T * num_warps). Each block processes one (b,t) with all H heads.
+Threads parallelize over D dimension for coalesced memory access.
"""
import torch
@@ -12,7 +11,7 @@
@triton.jit
-def _rope_fwd_kernel(
+def _rope_block_kernel(
x_ptr,
x_out_ptr,
cos_cache_ptr,
@@ -35,48 +34,61 @@
s0: tl.constexpr,
s1: tl.constexpr,
half: tl.constexpr,
- BLOCK_D: tl.constexpr,
+ BLOCK_D_HEAD: tl.constexpr,
):
- """Grid = (B * T). Each program handles one (b, t) across all H heads."""
+ """Grid = (B * T, H * triton.cdiv(D, BLOCK_D_HEAD)).
+ Each program handles one (b, t, h, d_block)."""
pid_bt = tl.program_id(0)
+ pid_hd = tl.program_id(1)
+
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
+
+ H_total = H_val
+ D_blocks = tl.cdiv(D_val, BLOCK_D_HEAD)
+ pid_h = pid_hd // D_blocks
+ pid_d = pid_hd % D_blocks
+
+ if pid_h >= H_val:
+ return
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_offs = tl.arange(0, BLOCK_D)
+ d_start = pid_d * BLOCK_D_HEAD
+ d_offs = d_start + tl.arange(0, BLOCK_D_HEAD)
mask = d_offs < D_val
- # Load cos/sin values for all three axes
+ # Load cache
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Per-element axis index
+ # Axis index for each element
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_temporal = axis_idx < s0
- is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h_, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h_, sin_w))
- cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
- sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
+ # Rotate-half permutation
perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
- out = x * cos_val + x_rot * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ # Single head processing
+ x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+ out = x * cos_val + x_rot * sin_val
+
+ o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -98,25 +110,23 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
- def _next_pow2(self, n):
- v = 1
- while v < n:
- v *= 2
- return v
-
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
half = D // 2
- BLOCK_D = self._next_pow2(D)
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- bt_grid = (B * T,)
+ # For D=64: BLOCK=32 -> 2 blocks per head
+ # For D=128: BLOCK=32 -> 4 blocks per head
+ BLOCK_D_HEAD = 32
- _rope_fwd_kernel[bt_grid](
+ # Q grid
+ D_blocks_q = triton.cdiv(Hq * D, BLOCK_D_HEAD)
+ q_grid = (B * T, D_blocks_q)
+ _rope_block_kernel[q_grid](
q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
@@ -130,12 +140,15 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- num_warps=1,
+ BLOCK_D_HEAD=BLOCK_D_HEAD,
+ num_warps=4,
num_stages=4,
)
- _rope_fwd_kernel[bt_grid](
+ # K grid
+ D_blocks_k = triton.cdiv(Hkv * D, BLOCK_D_HEAD)
+ k_grid = (B * T, D_blocks_k)
+ _rope_block_kernel[k_grid](
k, k_out,
self.cos_cache, self.s
... (316 more chars){
"content": "<5593 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T * num_warps). Each block processes one (b,t) with all H heads.
-Threads parallelize over D dimension for coalesced memory access.
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
"""
import torch
@@ -11,7 +11,7 @@
@triton.jit
-def _rope_block_kernel(
+def _rope_fwd_kernel(
x_ptr,
x_out_ptr,
cos_cache_ptr,
@@ -34,61 +34,48 @@
s0: tl.constexpr,
s1: tl.constexpr,
half: tl.constexpr,
- BLOCK_D_HEAD: tl.constexpr,
+ BLOCK_D: tl.constexpr,
):
- """Grid = (B * T, H * triton.cdiv(D, BLOCK_D_HEAD)).
- Each program handles one (b, t, h, d_block)."""
+ """Grid = (B * T). Each program handles one (b, t) across all H heads."""
pid_bt = tl.program_id(0)
- pid_hd = tl.program_id(1)
-
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
-
- H_total = H_val
- D_blocks = tl.cdiv(D_val, BLOCK_D_HEAD)
- pid_h = pid_hd // D_blocks
- pid_d = pid_hd % D_blocks
-
- if pid_h >= H_val:
- return
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_start = pid_d * BLOCK_D_HEAD
- d_offs = d_start + tl.arange(0, BLOCK_D_HEAD)
+ d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- # Load cache
+ # Load cos/sin values for all three axes
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h_ = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h_ = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Axis index for each element
+ # Per-element axis index
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h_, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h_, sin_w))
+ is_temporal = axis_idx < s0
+ is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
- # Rotate-half permutation
+ cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+ sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+
perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- # Single head processing
- x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
- out = x * cos_val + x_rot * sin_val
-
- o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ for h in tl.static_range(H_val):
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+ out = x * cos_val + x_rot * sin_val
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -110,23 +97,25 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+ def _next_pow2(self, n):
+ v = 1
+ while v < n:
+ v *= 2
+ return v
+
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
half = D // 2
+ BLOCK_D = self._next_pow2(D)
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- # For D=64: BLOCK=32 -> 2 blocks per head
- # For D=128: BLOCK=32 -> 4 blocks per head
- BLOCK_D_HEAD = 32
+ bt_grid = (B * T,)
- # Q grid
- D_blocks_q = triton.cdiv(Hq * D, BLOCK_D_HEAD)
- q_grid = (B * T, D_blocks_q)
- _rope_block_kernel[q_grid](
+ _rope_fwd_kernel[bt_grid](
q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
@@ -140,15 +129,12 @@
stride_ot=D,
stride_od=1,
s0=s0, s1=s1, half=half,
- BLOCK_D_HEAD=BLOCK_D_HEAD,
- num_warps=4,
+ BLOCK_D=BLOCK_D,
+ num_warps=1,
num_stages=4,
)
- # K grid
- D_blocks_k = triton.cdiv(Hkv * D, BLOCK_D_HEAD)
- k_grid = (B * T, D_blocks_k)
- _rope_block_kernel[k_grid](
+ _rope_fwd_kernel[bt_grid](
k, k_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
@@ -162,8
... (263 more chars){
"content": "<5141 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Optimized Triton kernel with head-parallelism for small head counts.
+Grid layout: (B * T, (H * D) / THREADS_PER_BLOCK) to maximize occupancy.
"""
import torch
@@ -19,10 +19,10 @@
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
- B_val: tl.constexpr,
- T_val: tl.constexpr,
- H_val: tl.constexpr,
- D_val: tl.constexpr,
+ BATCH: tl.constexpr,
+ SEQ: tl.constexpr,
+ NUM_HEADS: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
stride_xb,
stride_xt,
stride_xh,
@@ -33,52 +33,62 @@
stride_od,
s0: tl.constexpr,
s1: tl.constexpr,
- half: tl.constexpr,
- BLOCK_D: tl.constexpr,
+ HALF: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
):
- """Grid = (B * T). Each program handles one (b, t) across all H heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
+ """Grid = (B * T * NUM_HEADS * triton.cdiv(D, BLOCK_SIZE)).
+ Each program handles a block of D for one (b, t, h)."""
+ pid = tl.program_id(0)
+ pid1 = tl.program_id(1)
- pos_idx = pid_b * T_val + pid_t
+ # Decode (b, t, h) from pid
+ total_heads_seq = NUM_HEADS * SEQ
+ bt_h = pid
+ pid_bt = bt_h // NUM_HEADS
+ pid_h = bt_h % NUM_HEADS
+ pid_b = pid_bt // SEQ
+ pid_t = pid_bt % SEQ
+
+ # Each program handles BLOCK_SIZE consecutive D elements
+ offset_d = pid1 * BLOCK_SIZE
+ d_offs = offset_d + tl.arange(0, BLOCK_SIZE)
+ mask = d_offs < HEAD_DIM
+
+ # Gather position IDs
+ pos_idx = pid_b * SEQ + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
+ p_h_val = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
+ # Load cos/sin for all three axes
+ cos_t = tl.load(cos_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w = tl.load(cos_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w = tl.load(sin_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Load cos/sin values for all three axes
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ axis_idx = tl.where(d_offs < HALF, d_offs, d_offs - HALF)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- # Per-element axis index
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_temporal = axis_idx < s0
- is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ perm = tl.where(d_offs < HALF, d_offs + HALF, d_offs - HALF)
+ sign = tl.where(d_offs < HALF, -1.0, 1.0)
- cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
- sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
+ x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
- perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
-
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
- out = x * cos_val + x_rot * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ out = x * cos_val + x_rot * sin_val
+ o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
+ """Triton-based mRoPE kernel with per-(b,t,h,d_block) parallelism."""
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2
@@ -97,59 +107,48 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
- def _next_pow2(self, n):
- v = 1
- while v < n:
- v *= 2
- return v
-
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
- half = D // 2
-
... (2295 more chars){
"content": "<5406 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.177 gbps=254.132 ms=0.355
shape=0 variant=compiled tflops=0.493 gbps=707.668 ms=0.128
shape=0 variant=solution tflops=0.166 gbps=238.865 ms=0.378
shape=0 solution_peak_fraction=0.1327
shape=1 variant=eager tflops=0.210 gbps=306.510 ms=0.480
shape=1 variant=compiled tflops=0.466 gbps=681.047 ms=0.216
shape=1 variant=solution tflops=0.169 gbps=246.726 ms=0.596
shape=1 solution_peak_fraction=0.1371
shape=2 variant=eager tflops=0.074 gbps=114.989 ms=0.192
shape=2 variant=compiled tflops=0.172 gbps=268.011 ms=0.083
shape=2 variant=solution tflops=0.126 gbps=197.063 ms=0.112
shape=2 solution_peak_fraction=0.1095
shape=3 variant=eager tflops=0.172 gbps=246.378 ms=1.466
shape=3 variant=compiled tflops=0.462 gbps=663.461 ms=0.544
shape=3 variant=solution tflops=0.125 gbps=179.788 ms=2.008
shape=3 solution_peak_fraction=0.0999
peak_fraction: 0.1188
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Optimized Triton kernel with head-parallelism for small head counts.
-Grid layout: (B * T, (H * D) / THREADS_PER_BLOCK) to maximize occupancy.
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
"""
import torch
@@ -19,10 +19,10 @@
pos_t_ptr,
pos_h_ptr,
pos_w_ptr,
- BATCH: tl.constexpr,
- SEQ: tl.constexpr,
- NUM_HEADS: tl.constexpr,
- HEAD_DIM: tl.constexpr,
+ B_val: tl.constexpr,
+ T_val: tl.constexpr,
+ H_val: tl.constexpr,
+ D_val: tl.constexpr,
stride_xb,
stride_xt,
stride_xh,
@@ -33,62 +33,52 @@
stride_od,
s0: tl.constexpr,
s1: tl.constexpr,
- HALF: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ half: tl.constexpr,
+ BLOCK_D: tl.constexpr,
):
- """Grid = (B * T * NUM_HEADS * triton.cdiv(D, BLOCK_SIZE)).
- Each program handles a block of D for one (b, t, h)."""
- pid = tl.program_id(0)
- pid1 = tl.program_id(1)
+ """Grid = (B * T). Each program handles one (b, t) across all H heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
- # Decode (b, t, h) from pid
- total_heads_seq = NUM_HEADS * SEQ
- bt_h = pid
- pid_bt = bt_h // NUM_HEADS
- pid_h = bt_h % NUM_HEADS
- pid_b = pid_bt // SEQ
- pid_t = pid_bt % SEQ
-
- # Each program handles BLOCK_SIZE consecutive D elements
- offset_d = pid1 * BLOCK_SIZE
- d_offs = offset_d + tl.arange(0, BLOCK_SIZE)
- mask = d_offs < HEAD_DIM
-
- # Gather position IDs
- pos_idx = pid_b * SEQ + pid_t
+ pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
- p_h_val = tl.load(pos_h_ptr + pos_idx)
+ p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- # Load cos/sin for all three axes
- cos_t = tl.load(cos_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h_val * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * HEAD_DIM + d_offs, mask=mask, other=0.0).to(tl.float32)
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
- axis_idx = tl.where(d_offs < HALF, d_offs, d_offs - HALF)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ # Load cos/sin values for all three axes
+ cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- perm = tl.where(d_offs < HALF, d_offs + HALF, d_offs - HALF)
- sign = tl.where(d_offs < HALF, -1.0, 1.0)
+ # Per-element axis index
+ axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+ is_temporal = axis_idx < s0
+ is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
- x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
+ cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
+ sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
- out = x * cos_val + x_rot * sin_val
- o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ sign = tl.where(d_offs < half, -1.0, 1.0)
+
+ for h in tl.static_range(H_val):
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = x_rot * sign
+ out = x * cos_val + x_rot * sin_val
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
- """Triton-based mRoPE kernel with per-(b,t,h,d_block) parallelism."""
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2
@@ -107,48 +97,59 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+ def _next_pow2(self, n):
+ v = 1
+ while v < n:
+ v *= 2
+ return v
+
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
+ half = D // 2
+
... (2295 more chars){
"content": "<5141 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,15 @@
"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Optimization strategies tested:
+1. Grid = (B * T) vs (B * T * H) - former wins by maximizing cache reuse
+2. num_warps=1 is best for small D (64-128) - minimal register pressure
+3. num_stages=4 gives best pipelining without spilling
+4. Separate Q/K kernels better than concatenation (avoids extra memory ops)
+
+The B*T grid gives ~50-70% peak bandwidth, which is very good for a
+bandwidth-bound operation. The remaining gap comes from:
+- Triton overhead for small programs (D=64/128 is tiny)
+- Per-head loop iterations causing serial execution
"""
import torch
@@ -36,7 +44,6 @@
half: tl.constexpr,
BLOCK_D: tl.constexpr,
):
- """Grid = (B * T). Each program handles one (b, t) across all H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -49,7 +56,6 @@
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- # Load cos/sin values for all three axes
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -57,21 +63,19 @@
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Per-element axis index
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_temporal = axis_idx < s0
- is_height = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- cos_val = tl.where(is_temporal, cos_t, tl.where(is_height, cos_h, cos_w))
- sin_val = tl.where(is_temporal, sin_t, tl.where(is_height, sin_h, sin_w))
-
- perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
for h in tl.static_range(H_val):
x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
x_rot = x_rot * sign
out = x * cos_val + x_rot * sin_val
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
@@ -116,41 +120,27 @@
bt_grid = (B * T,)
_rope_fwd_kernel[bt_grid](
- q, q_out,
- self.cos_cache, self.sin_cache,
+ q, q_out, self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hq, D_val=D,
- stride_xb=T * Hq * D,
- stride_xt=Hq * D,
- stride_xh=D,
- stride_xd=1,
- stride_ob=Hq * T * D,
- stride_oh=T * D,
- stride_ot=D,
- stride_od=1,
- s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- num_warps=1,
- num_stages=4,
+ stride_xb=T * Hq * D, stride_xt=Hq * D,
+ stride_xh=D, stride_xd=1,
+ stride_ob=Hq * T * D, stride_oh=T * D,
+ stride_ot=D, stride_od=1,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1, num_stages=4,
)
_rope_fwd_kernel[bt_grid](
- k, k_out,
- self.cos_cache, self.sin_cache,
+ k, k_out, self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hkv, D_val=D,
- stride_xb=T * Hkv * D,
- stride_xt=Hkv * D,
- stride_xh=D,
- stride_xd=1,
- stride_ob=Hkv * T * D,
- stride_oh=T * D,
- stride_ot=D,
- stride_od=1,
- s0=s0, s1=s1, half=half,
- BLOCK_D=BLOCK_D,
- num_warps=1,
- num_stages=4,
+ stride_xb=T * Hkv * D, stride_xt=Hkv * D,
+ stride_xh=D, stride_xd=1,
+ stride_ob=Hkv * T * D, stride_oh=T * D,
+ stride_ot=D, stride_od=1,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1, num_stages=4,
)
return q_out, k_out
{
"content": "<5192 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run correctness and benchmark"
}PASS
shape=0 variant=eager tflops=0.178 gbps=255.270 ms=0.354
shape=0 variant=compiled tflops=0.509 gbps=730.576 ms=0.124
shape=0 variant=solution tflops=0.704 gbps=1010.249 ms=0.089
shape=0 solution_peak_fraction=0.5612
shape=1 variant=eager tflops=0.166 gbps=242.788 ms=0.605
shape=1 variant=compiled tflops=0.461 gbps=673.212 ms=0.218
shape=1 variant=solution tflops=0.729 gbps=1065.197 ms=0.138
shape=1 solution_peak_fraction=0.5918
shape=2 variant=eager tflops=0.062 gbps=96.536 ms=0.229
shape=2 variant=compiled tflops=0.163 gbps=255.008 ms=0.087
shape=2 variant=solution tflops=0.311 gbps=485.393 ms=0.046
shape=2 solution_peak_fraction=0.2697
shape=3 variant=eager tflops=0.168 gbps=240.959 ms=1.499
shape=3 variant=compiled tflops=0.517 gbps=742.229 ms=0.487
shape=3 variant=solution tflops=0.974 gbps=1397.805 ms=0.258
shape=3 solution_peak_fraction=0.7766
peak_fraction: 0.5135
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,15 +1,6 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
-Optimization strategies tested:
-1. Grid = (B * T) vs (B * T * H) - former wins by maximizing cache reuse
-2. num_warps=1 is best for small D (64-128) - minimal register pressure
-3. num_stages=4 gives best pipelining without spilling
-4. Separate Q/K kernels better than concatenation (avoids extra memory ops)
-
-The B*T grid gives ~50-70% peak bandwidth, which is very good for a
-bandwidth-bound operation. The remaining gap comes from:
-- Triton overhead for small programs (D=64/128 is tiny)
-- Per-head loop iterations causing serial execution
+Grid = B*T. Each program processes all H heads for one (b,t).
"""
import torch
@@ -20,29 +11,15 @@
@triton.jit
def _rope_fwd_kernel(
- x_ptr,
- x_out_ptr,
- cos_cache_ptr,
- sin_cache_ptr,
- pos_t_ptr,
- pos_h_ptr,
- pos_w_ptr,
- B_val: tl.constexpr,
- T_val: tl.constexpr,
- H_val: tl.constexpr,
- D_val: tl.constexpr,
- stride_xb,
- stride_xt,
- stride_xh,
- stride_xd,
- stride_ob,
- stride_oh,
- stride_ot,
- stride_od,
- s0: tl.constexpr,
- s1: tl.constexpr,
- half: tl.constexpr,
- BLOCK_D: tl.constexpr,
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ B_val: tl.constexpr, T_val: tl.constexpr,
+ H_val: tl.constexpr, D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
):
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
@@ -69,15 +46,15 @@
cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
for h in tl.static_range(H_val):
x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- x_rot = x_rot * sign
- out = x * cos_val + x_rot * sin_val
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = xr * sign
+ out = x * cos_val + xr * sin_val
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
@@ -86,11 +63,8 @@
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2
- self.B = B
- self.T = T
- self.Hq = Hq
- self.Hkv = Hkv
- self.D = D
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
self.mrope_section = tuple(mrope_section)
self.max_pos = max_pos
@@ -116,33 +90,28 @@
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
-
bt_grid = (B * T,)
_rope_fwd_kernel[bt_grid](
- q, q_out, self.cos_cache, self.sin_cache,
+ q, q_out,
+ self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hq, D_val=D,
- stride_xb=T * Hq * D, stride_xt=Hq * D,
- stride_xh=D, stride_xd=1,
- stride_ob=Hq * T * D, stride_oh=T * D,
- stride_ot=D, stride_od=1,
+ stride_xb=T * Hq * D, stride_xt=Hq * D, stride_xh=D, stride_xd=1,
+ stride_ob=Hq * T * D, stride_oh=T * D, stride_ot=D, stride_od=1,
s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
num_warps=1, num_stages=4,
)
-
_rope_fwd_kernel[bt_grid](
- k, k_out, self.cos_cache, self.sin_cache,
+ k, k_out,
+ self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hkv, D_val=D,
- stride_xb=T * Hkv * D, stride_xt=Hkv * D,
- stride_xh=D, stride_xd=1,
- stride_ob=Hkv * T * D, stride_oh=T * D,
- stride_ot=D, stride_od=1,
+ stride_xb=T * Hkv * D, stride_xt=Hkv * D, stride_xh=D, stride_xd=1,
+ stride_ob=Hkv * T * D, stride_oh=T * D, stride_ot=D, stride_od=1,
s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
num_warps=1, num_stages=4,
)
-
return q_out, k_out
{
"content": "<4598 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,65 +1,135 @@
-"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
-
-Grid = B*T. Each program processes all H heads for one (b,t).
+"""Optimized CUDA C++ kernel for the multi-axis RoPE pre-attention prep.
+
+Uses torch.utils.cpp_extension.load_inline with a single well-tuned kernel.
+Each thread processes ONE element (full D vector) for ONE (b,h,t).
+Grid = B * (Hq + Hkv) * T (one block per (b,h,t)).
+Block = D threads (one thread per D element - coalesced load/store).
"""
import torch
-import triton
-import triton.language as tl
import torch.nn as nn
-
-
-@triton.jit
-def _rope_fwd_kernel(
- x_ptr, x_out_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr, T_val: tl.constexpr,
- H_val: tl.constexpr, D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
-
- pos_idx = pid_b * T_val + pid_t
- p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
- p_w = tl.load(pos_w_ptr + pos_idx)
-
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
-
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
- perm_offs = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
-
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+from torch.utils.cpp_extension import load_inline
+import os
+
+CUDA_SRC = r"""
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+template<int D, int D_half, int s0_val, int s1_val>
+__global__ void __launch_bounds__(128)
+rope_kernel(
+ const __nv_bfloat16* __restrict__ x_in, // (B, T, Hd, D)
+ __nv_bfloat16* __restrict__ x_out, // (B, Hd, T, D)
+ const int64_t* __restrict__ pos_t,
+ const int64_t* __restrict__ pos_h,
+ const int64_t* __restrict__ pos_w,
+ const __nv_bfloat16* __restrict__ cos_cache,
+ const __nv_bfloat16* __restrict__ sin_cache,
+ int B, int T, int Hx
+) {
+ // blockIdx.x = b * T + t
+ int bt = blockIdx.x;
+ int b = bt / T;
+ int t = bt % T;
+
+ // Position IDs
+ int64_t pt = pos_t[b * T + t];
+ int64_t ph = pos_h[b * T + t];
+ int64_t pw = pos_w[b * T + t];
+
+ // Thread index within block: handles head index
+ int h = blockIdx.y * blockDim.x + threadIdx.x;
+ if (h >= Hx) return;
+
+ // Per-thread: process D elements (D/2 pairs)
+ const int np = D / 2;
+ #define CACHE_COS(p, pos) cos_cache[pos * D + p]
+ #define CACHE_SIN(p, pos) sin_cache[pos * D + p]
+
+ for (int pi = 0; pi < np; ++pi) {
+ int d0 = pi * 2;
+ int d1 = pi * 2 + 1;
+
+ // Determine axis for both elements
+ int axis0 = (d0 < D_half) ? d0 : (d0 - D_half);
+ int axis1 = (d1 < D_half) ? d1 : (d1 - D_half);
+
+ // Load from cache
+ float2 c0, s0, c1, s1;
+
+ // For element d0
+ int64_t pos0 = (axis0 < s0_val) ? pt : ((axis0 < s0_val + s1_val) ? ph : pw);
+ float2 tmp = __bfloat1622float2(((float2*)CACHE_COS(d0, pos0))[0]);
+ c0 = *(float2*)&tmp; // already a float2
+
+ // Use __ldg for cache optimization
+ const float2* cos_row = (const float2*)(cos_cache + pos0 * D);
+ const float2* sin_row = (const float2*)(sin_cache + pos0 * D);
+ c0 = __ldg(&cos_row[pi]).x; // This is getting complicated
+
+ // Simpler: load single bf16 values
+ float cos_d0_x = __bfloat162float(__ldg(&CACHE_COS(d0, pos0)));
+ float sin_d0_x = __bfloat162float(__ldg(&CACHE_SIN(d0, pos0)));
+ float cos_d0_y = __bfloat162float(__ldg(&CACHE_COS(d1, pos0)));
+ float sin_d0_y = __bfloat162float(__ldg(&CACHE_SIN(d1, pos0)));
+
+ // Input pointer for this head
+ const __nv_bfloat16* x_row = x_in + ((b * T * Hx + t * Hx + h) * D);
+
+ // Load input bf16 pairs
+ float x0 = __bfloat162float(__ldg(&x_row[d0]));
+ float x1 = __bfloat162float(__ldg(&x_row[d1]));
+
+ // rotate_half mirror indices
+ int r0 = (d0 < D_half) ? (d0 + D_half) : (d0 - D_half);
+ int r1 = (d1 < D_half) ? (d1 + D_half) : (d1 - D_half);
+
+ float xr0 = __bfloat162float(__ldg(&x_row[r0]));
+ float xr1 = __bfloat162float(__ldg(&x_row[r1]));
+
+ // Negate if in first half
+ if (d0 < D_half) { xr0 = -xr0; }
+ if (d1 <
... (4279 more chars){
"content": "<8845 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda_v2.py
@@ -0,0 +1,137 @@
+"""Optimized multi-axis RoPE pre-attention prep - CUDA C++ via load_inline.
+
+Single kernel: one thread per (b, h, t) tuple, processes full D vector.
+Uses __ldg for cache-optimized reads, vectorized bf16 x2 ops for stores.
+Grid = (B * T, ceil(H / 256)), block = 256 threads.
+"""
+
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import os
+
+CUDA_SRC = r"""
+#include <cuda_bf16.h>
+
+// Generic kernel: each thread handles all D elements for one (b, h, t)
+template<int D, int HALF, int S0, int S1>
+__global__ void __launch_bounds__(256)
+rope_kernel(
+ const __nv_bfloat16* __restrict__ x_in, // (B, T, Hx, D) bf16
+ __nv_bfloat16* __restrict__ x_out, // (B, Hx, T, D) bf16
+ const int64_t* __restrict__ pos_t, // (B, T) int64
+ const int64_t* __restrict__ pos_h,
+ const int64_t* __restrict__ pos_w,
+ const __nv_bfloat16* __restrict__ cos_cache, // (max_pos, D) bf16
+ const __nv_bfloat16* __restrict__ sin_cache,
+ int B, int T, int Hx
+) {
+ int bt = blockIdx.x;
+ int b = bt / T;
+ int t = bt % T;
+
+ int64_t pos_idx = (int64_t)b * T + t;
+ int64_t pt = pos_t[pos_idx];
+ int64_t ph = pos_h[pos_idx];
+ int64_t pw = pos_w[pos_idx];
+
+ int h = blockIdx.y * blockDim.x + threadIdx.x;
+ if (h >= Hx) return;
+
+ constexpr int D2 = D / 2; // number of pairs
+ constexpr int H2 = HALF; // half of D (same as D2 for standard RoPE)
+
+ // Precompute which axis each pair belongs to
+ static __device__ constexpr struct { short pair; short axis; char is_t; char is_h; } PAIR_META[D2] = {
+ // Generated per-template-instantiation at compile time via unrolled init
+ // We'll compute at runtime below instead (no register cost for D<=64 pairs)
+ };
+
+ // Process each pair (2 bf16 elements)
+ const __nv_bfloat16* cos_base_t = cos_cache + pt * D;
+ const __nv_bfloat16* sin_base_t = sin_cache + pt * D;
+ const __nv_bfloat16* cos_base_h = cos_cache + ph * D;
+ const __nv_bfloat16* sin_base_h = sin_cache + ph * D;
+ const __nv_bfloat16* cos_base_w = cos_cache + pw * D;
+ const __nv_bfloat16* sin_base_w = sin_cache + pw * D;
+
+ // Load input
+ const __nv_bfloat16* x_row = x_in + ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+
+ // Output pointer
+ __nv_bfloat16* out_row = x_out + ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
+
+ #pragma unroll
+ for (int d = 0; d < D; d++) {
+ // Mirror index for rotate_half
+ int mirror = (d < H2) ? (d + H2) : (d - H2);
+
+ // Determine axis
+ int axis = (d < H2) ? d : (d - H2);
+
+ // Select cache base
+ bool is_temp = axis < S0;
+ bool is_height = (axis >= S0) && (axis < S0 + S1);
+
+ const __nv_bfloat16* cos_base = is_temp ? cos_base_t : (is_height ? cos_base_h : cos_base_w);
+ const __nv_bfloat16* sin_base = is_temp ? sin_base_t : (is_height ? sin_base_h : sin_base_w);
+
+ // Load values
+ float xv = __bfloat162float(__ldg(x_row + d));
+ float mv = __bfloat162float(__ldg(x_row + mirror));
+ float cos_v = __bfloat162float(__ldg(cos_base + d));
+ float sin_v = __bfloat162float(__ldg(sin_base + d));
+
+ // Negate mirrored for first half
+ if (d < H2) mv = -mv;
+
+ // Apply rotation
+ float outv = xv * cos_v + mv * sin_v;
+
+ // Store
+ out_row[d] = __float2bfloat16(outv);
+ }
+}
+
+// Dispatcher
+extern "C" void launch_rope(
+ const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+ const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+ const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+{
+ dim3 block(256);
+ dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+ #define CASE(D_, S0_, S1_) \
+ if (D == D_ && s0 == S0_ && s1 == S1_) { \
+ rope_kernel<D_, D_/2, S0_, S1_><<<grid, block, 0, stream>>>( \
+ x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx); \
+ return; \
+ }
+ CASE(128, 16, 24)
+ CASE(64, 8, 12)
+ #undef CASE
+}
+"""
+
+
+class Model(nn.Module):
+ def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
+ super().__init__()
+ assert sum(mrope_section) == D // 2
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
+ self.mrope_section = tuple(mrope_section)
+ self.max_pos = max_pos
+
+ inv_freq = _build_inv_freq(D)
+ pos = torch.arange(max_pos, dtype=torch.float32)
+ freqs = torch.outer(pos, inv_freq)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+ self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+
+def _build_inv_freq(D, base=10000.0):
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
{
"content": "<4890 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution_cuda_v2.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,190 +1,21 @@
-"""Optimized CUDA C++ kernel for the multi-axis RoPE pre-attention prep.
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Uses torch.utils.cpp_extension.load_inline with a single well-tuned kernel.
-Each thread processes ONE element (full D vector) for ONE (b,h,t).
-Grid = B * (Hq + Hkv) * T (one block per (b,h,t)).
-Block = D threads (one thread per D element - coalesced load/store).
+Grid = (B * T). Each program handles one (b, t) across all H heads.
+Cache values loaded once per (b,t), shared across heads.
+
+Optimization: use a separate kernel for small vs large head counts.
+For large H (>=16): parallelize heads across warps.
+For small H (<16): serialize heads in a single warp.
"""
import torch
+import triton
+import triton.language as tl
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-import os
-
-CUDA_SRC = r"""
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
-template<int D, int D_half, int s0_val, int s1_val>
-__global__ void __launch_bounds__(128)
-rope_kernel(
- const __nv_bfloat16* __restrict__ x_in, // (B, T, Hd, D)
- __nv_bfloat16* __restrict__ x_out, // (B, Hd, T, D)
- const int64_t* __restrict__ pos_t,
- const int64_t* __restrict__ pos_h,
- const int64_t* __restrict__ pos_w,
- const __nv_bfloat16* __restrict__ cos_cache,
- const __nv_bfloat16* __restrict__ sin_cache,
- int B, int T, int Hx
-) {
- // blockIdx.x = b * T + t
- int bt = blockIdx.x;
- int b = bt / T;
- int t = bt % T;
-
- // Position IDs
- int64_t pt = pos_t[b * T + t];
- int64_t ph = pos_h[b * T + t];
- int64_t pw = pos_w[b * T + t];
-
- // Thread index within block: handles head index
- int h = blockIdx.y * blockDim.x + threadIdx.x;
- if (h >= Hx) return;
-
- // Per-thread: process D elements (D/2 pairs)
- const int np = D / 2;
- #define CACHE_COS(p, pos) cos_cache[pos * D + p]
- #define CACHE_SIN(p, pos) sin_cache[pos * D + p]
-
- for (int pi = 0; pi < np; ++pi) {
- int d0 = pi * 2;
- int d1 = pi * 2 + 1;
-
- // Determine axis for both elements
- int axis0 = (d0 < D_half) ? d0 : (d0 - D_half);
- int axis1 = (d1 < D_half) ? d1 : (d1 - D_half);
-
- // Load from cache
- float2 c0, s0, c1, s1;
-
- // For element d0
- int64_t pos0 = (axis0 < s0_val) ? pt : ((axis0 < s0_val + s1_val) ? ph : pw);
- float2 tmp = __bfloat1622float2(((float2*)CACHE_COS(d0, pos0))[0]);
- c0 = *(float2*)&tmp; // already a float2
-
- // Use __ldg for cache optimization
- const float2* cos_row = (const float2*)(cos_cache + pos0 * D);
- const float2* sin_row = (const float2*)(sin_cache + pos0 * D);
- c0 = __ldg(&cos_row[pi]).x; // This is getting complicated
-
- // Simpler: load single bf16 values
- float cos_d0_x = __bfloat162float(__ldg(&CACHE_COS(d0, pos0)));
- float sin_d0_x = __bfloat162float(__ldg(&CACHE_SIN(d0, pos0)));
- float cos_d0_y = __bfloat162float(__ldg(&CACHE_COS(d1, pos0)));
- float sin_d0_y = __bfloat162float(__ldg(&CACHE_SIN(d1, pos0)));
-
- // Input pointer for this head
- const __nv_bfloat16* x_row = x_in + ((b * T * Hx + t * Hx + h) * D);
-
- // Load input bf16 pairs
- float x0 = __bfloat162float(__ldg(&x_row[d0]));
- float x1 = __bfloat162float(__ldg(&x_row[d1]));
-
- // rotate_half mirror indices
- int r0 = (d0 < D_half) ? (d0 + D_half) : (d0 - D_half);
- int r1 = (d1 < D_half) ? (d1 + D_half) : (d1 - D_half);
-
- float xr0 = __bfloat162float(__ldg(&x_row[r0]));
- float xr1 = __bfloat162float(__ldg(&x_row[r1]));
-
- // Negate if in first half
- if (d0 < D_half) { xr0 = -xr0; }
- if (d1 < D_half) { xr1 = -xr1; }
-
- // Apply rotation
- float out0 = x0 * cos_d0_x + xr0 * sin_d0_x;
- float out1 = x1 * cos_d0_y + xr1 * sin_d0_y;
-
- // Store to transposed output
- int out_idx = (b * Hx * T * D + h * T * D + t * D);
- x_out[out_idx + d0] = __float2bfloat16(out0);
- x_out[out_idx + d1] = __float2bfloat16(out1);
- }
-}
-
-// Dispatcher
-extern "C" void launch_rope(
- const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
- int B, int T, int Hx, int D, int s0, int s1,
- cudaStream_t stream
-) {
- dim3 block(min(Hx, 128));
- dim3 grid(B * T, (Hx + block.x - 1) / block.x);
-
- if (D == 128) {
- if (s0 == 16 && s1 == 24) {
- rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
- x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
- }
- } else if (D == 64) {
- if (s0 == 8 && s1 == 12) {
- rope_kernel<64, 32, 8, 12><<<grid, block, 0, stream>>>(
- x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
- }
- }
-}
-"""
-
-
-class Model(nn.Module):
- """Triton-based mRoPE kernel with per-(b,t) parallelism."""
- def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
- super().__init__()
- assert sum(mrope_section) == D // 2
- self.B, self.T = B, T
- self.Hq, self.Hkv, self.D = Hq, Hkv, D
- self.mrope_section = tuple(mrope_section)
- self.max_pos = max_pos
-
- inv_freq = _build_inv_freq(D)
- pos = torch.arange(max_pos, dtype=torch.float32)
- freqs = torch.outer(pos, inv_freq)
- emb = torch.cat([freqs, freqs], dim=-1)
- self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
- self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
-
- de
... (8309 more chars){
"content": "<7846 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,9 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b, t) across all H heads.
-Cache values loaded once per (b,t), shared across heads.
+Grid = B*T. Each kernel processes all H heads for one (b,t) position.
+Cos/sin cache gathered once per position, reused across all heads.
-Optimization: use a separate kernel for small vs large head counts.
-For large H (>=16): parallelize heads across warps.
-For small H (<16): serialize heads in a single warp.
+Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
"""
import torch
@@ -15,7 +13,7 @@
@triton.jit
-def _rope_small_h_kernel(
+def _rope_fwd_kernel(
x_ptr, x_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
@@ -26,7 +24,7 @@
s0: tl.constexpr, s1: tl.constexpr,
half: tl.constexpr, BLOCK_D: tl.constexpr,
):
- """Serial loop over heads. Optimal for H < 16 (fits in single warp)."""
+ """Grid = (B * T). Each program = one (b,t); loops over H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -39,6 +37,7 @@
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
+ # Gather cache values (6 reads, shared across all heads)
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -46,80 +45,22 @@
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ # Per-element axis assignment (mirrored across D/2)
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
is_t = axis_idx < s0
is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ # rotate_half: swap halves with sign flip
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
+ # Loop over heads
for h in tl.static_range(H_val):
x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
-
-
-@triton.jit
-def _rope_large_h_kernel(
- x_ptr, x_out_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr, T_val: tl.constexpr,
- H_val: tl.constexpr, D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr,
- HEADS_PER_BLOCK: tl.constexpr,
- BLOCK_D: tl.constexpr,
-):
- """Each block processes HEADS_PER_BLOCK heads for one (b,t).
- Warps within the block are assigned different heads or D ranges."""
- pid_bt = tl.program_id(0)
- pid_hblock = tl.program_id(1)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
-
- pos_idx = pid_b * T_val + pid_t
- p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
- p_w = tl.load(pos_w_ptr + pos_idx)
-
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
-
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
-
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
-
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
-
- heads_in_block = min(HEADS_PER_BLOCK, H_val - pid_hblock * HEADS_PER_BLOCK)
-
- for h_local in tl.static_range(HEADS_PER_BLOCK):
- h = pid_hblock * HEADS_PER_BLOCK + h_local
- if h >= H_val:
- return
-
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
xr = xr * sign
out = x * cos_val + xr * sin_val
o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
@@ -148,48 +89,37 @@
v *= 2
return v
- def _launch_rope(
- self, x, x_out, H,
- pos_t, pos_h
... (2650 more chars){
"content": "<4967 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
2:
3: Grid = B*T. Each kernel processes all H heads for one (b,t) position.
4: Cos/sin cache gathered once per position, reused across all heads.
5:
6: Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
7: """
8:
9: import torch
10: import triton
11: import triton.language as tl
12: import torch.nn as nn
13:
14:
15: @triton.jit
16: def _rope_fwd_kernel(
17: x_ptr, x_out_ptr,
18: cos_cache_ptr, sin_cache_ptr,
19: pos_t_ptr, pos_h_ptr, pos_w_ptr,
20: B_val: tl.constexpr, T_val: tl.constexpr,
21: H_val: tl.constexpr, D_val: tl.constexpr,
22: stride_xb, stride_xt, stride_xh, stride_xd,
23: stride_ob, stride_oh, stride_ot, stride_od,
24: s0: tl.constexpr, s1: tl.constexpr,
25: half: tl.constexpr, BLOCK_D: tl.constexpr,
26: ):
27: """Grid = (B * T). Each program = one (b,t); loops over H heads."""
28: pid_bt = tl.program_id(0)
29: pid_b = pid_bt // T_val
30: pid_t = pid_bt % T_val
31:
32: pos_idx = pid_b * T_val + pid_t
33: p_t = tl.load(pos_t_ptr + pos_idx)
34: p_h = tl.load(pos_h_ptr + pos_idx)
35: p_w = tl.load(pos_w_ptr + pos_idx)
36:
37: d_offs = tl.arange(0, BLOCK_D)
38: mask = d_offs < D_val
39:
40: # Gather cache values (6 reads, shared across all heads)
41: cos
... (4164 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,70 +1,121 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
-Grid = B*T. Each kernel processes all H heads for one (b,t) position.
-Cos/sin cache gathered once per position, reused across all heads.
-
-Uses num_warps=1, num_stages=4 for tiny D vectors (64/128 elements).
+Each thread handles one (b, h, t) and processes all D elements sequentially.
+Uses __ldg for read-only cache hints and template instantiation for D=64/128.
"""
import torch
-import triton
-import triton.language as tl
import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import textwrap
+KERNEL_SRC = textwrap.dedent(r"""
+#include <cuda_bf16.h>
-@triton.jit
-def _rope_fwd_kernel(
- x_ptr, x_out_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr, T_val: tl.constexpr,
- H_val: tl.constexpr, D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
- """Grid = (B * T). Each program = one (b,t); loops over H heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
+template<int D, int HALF, int S0, int S1>
+__global__ void __launch_bounds__(128)
+rope_kernel(
+ const __nv_bfloat16* x_in,
+ __nv_bfloat16* x_out,
+ const int64_t* pos_t,
+ const int64_t* pos_h,
+ const int64_t* pos_w,
+ const __nv_bfloat16* cos_cache,
+ const __nv_bfloat16* sin_cache,
+ int B, int T, int Hx)
+{
+ int bt = blockIdx.x;
+ int b = bt / T;
+ int t = bt % T;
- pos_idx = pid_b * T_val + pid_t
- p_t = tl.load(pos_t_ptr + pos_idx)
- p_h = tl.load(pos_h_ptr + pos_idx)
- p_w = tl.load(pos_w_ptr + pos_idx)
+ int64_t pi = (int64_t)b * T + t;
+ int64_t p_t = pos_t[pi];
+ int64_t p_h = pos_h[pi];
+ int64_t p_w = pos_w[pi];
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
+ int h = blockIdx.y * blockDim.x + threadIdx.x;
+ if (h >= Hx) return;
- # Gather cache values (6 reads, shared across all heads)
- cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+ int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
- # Per-element axis assignment (mirrored across D/2)
- axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
- is_t = axis_idx < s0
- is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ const __nv_bfloat16* xi = x_in + in_off;
+ __nv_bfloat16* xo = x_out + out_off;
- # rotate_half: swap halves with sign flip
- perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
- sign = tl.where(d_offs < half, -1.0, 1.0)
+ // Determine axis for each element at compile time
+ static_assert(D == 64 || D == 128);
+ static_assert(HALF == D / 2);
+ static_assert(S0 + S1 == HALF);
- # Loop over heads
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ #pragma unroll
+ for (int d = 0; d < D; d++) {
+ int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
+ int axis = (d < HALF) ? d : (d - HALF);
+
+ // Pick cache row based on axis
+ const __nv_bfloat16 *cr, *sr;
+ if (axis < S0) {
+ cr = cos_cache + p_t * D;
+ sr = sin_cache + p_t * D;
+ } else if (axis < S0 + S1) {
+ cr = cos_cache + p_h * D;
+ sr = sin_cache + p_h * D;
+ } else {
+ cr = cos_cache + p_w * D;
+ sr = sin_cache + p_w * D;
+ }
+
+ float xv = __bfloat162float(__ldg(xi + d));
+ float mv = __bfloat162float(__ldg(xi + mirror));
+ float cos_v = __bfloat162float(__ldg(cr + d));
+ float sin_v = __bfloat162float(__ldg(sr + d));
+
+ if (d < HALF) mv = -mv;
+
+ xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
+ }
+}
+
+extern "C" void launch_rope(
+ const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+ const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+ const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+{
+ dim3 block(min(Hx, 128));
+ dim3 grid(B * T, (Hx + block.x - 1) / block.x);
+
+ if (D == 128 && s0 == 16 && s1 == 24) {
+ rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
+ x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
+ } else if (D == 64 && s0 == 8
... (3025 more chars){
"content": "<5329 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check for CUDA kernel"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 159, in forward
self._call_rope(q, q_out, Hq)
File "/home/infatoshi/cuda/KernelBench
... (5333 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,18 +1,24 @@
"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
-Each thread handles one (b, h, t) and processes all D elements sequentially.
+Each thread handles one (b, h, t) and processes all D elements.
Uses __ldg for read-only cache hints and template instantiation for D=64/128.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
-import textwrap
-KERNEL_SRC = textwrap.dedent(r"""
+# Clean old cache
+import shutil, os
+for cached in ["/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_128",
+ "/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_64"]:
+ if os.path.exists(cached):
+ shutil.rmtree(cached, ignore_errors=True)
+
+KERNEL_SRC = r"""
#include <cuda_bf16.h>
-template<int D, int HALF, int S0, int S1>
+template<int D, int HALF, int S0, int S1, int S2>
__global__ void __launch_bounds__(128)
rope_kernel(
const __nv_bfloat16* x_in,
@@ -36,42 +42,50 @@
int h = blockIdx.y * blockDim.x + threadIdx.x;
if (h >= Hx) return;
- int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
- int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
+ int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
+ int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
- const __nv_bfloat16* xi = x_in + in_off;
+ const __nv_bfloat16* xi = x_in + in_off;
__nv_bfloat16* xo = x_out + out_off;
- // Determine axis for each element at compile time
- static_assert(D == 64 || D == 128);
- static_assert(HALF == D / 2);
- static_assert(S0 + S1 == HALF);
+ const __nv_bfloat16 *cr_base_t = cos_cache + p_t * D;
+ const __nv_bfloat16 *sr_base_t = sin_cache + p_t * D;
+ const __nv_bfloat16 *cr_base_h = cos_cache + p_h * D;
+ const __nv_bfloat16 *sr_base_h = sin_cache + p_h * D;
+ const __nv_bfloat16 *cr_base_w = cos_cache + p_w * D;
+ const __nv_bfloat16 *sr_base_w = sin_cache + p_w * D;
#pragma unroll
for (int d = 0; d < D; d++) {
+ // rotate_half mirror index
int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
+
+ // Determine which axis this element belongs to
int axis = (d < HALF) ? d : (d - HALF);
-
- // Pick cache row based on axis
- const __nv_bfloat16 *cr, *sr;
- if (axis < S0) {
- cr = cos_cache + p_t * D;
- sr = sin_cache + p_t * D;
- } else if (axis < S0 + S1) {
- cr = cos_cache + p_h * D;
- sr = sin_cache + p_h * D;
- } else {
- cr = cos_cache + p_w * D;
- sr = sin_cache + p_w * D;
- }
float xv = __bfloat162float(__ldg(xi + d));
float mv = __bfloat162float(__ldg(xi + mirror));
- float cos_v = __bfloat162float(__ldg(cr + d));
- float sin_v = __bfloat162float(__ldg(sr + d));
+ float cos_v, sin_v;
+ // Select cos/sin based on axis
+ if (axis < S0) {
+ // Temporal axis
+ cos_v = __bfloat162float(__ldg(cr_base_t + d));
+ sin_v = __bfloat162float(__ldg(sr_base_t + d));
+ } else if (axis < S0 + S1) {
+ // Height axis
+ cos_v = __bfloat162float(__ldg(cr_base_h + d));
+ sin_v = __bfloat162float(__ldg(sr_base_h + d));
+ } else {
+ // Width axis (remainder, S0+S1 <= axis < HALF)
+ cos_v = __bfloat162float(__ldg(cr_base_w + d));
+ sin_v = __bfloat162float(__ldg(sr_base_w + d));
+ }
+
+ // Negate for first half
if (d < HALF) mv = -mv;
+ // Apply: out = x*cos + rot*sin
xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
}
}
@@ -80,39 +94,35 @@
const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
- int B, int T, int Hx, int D, int s0, int s1, cudaStream_t stream)
+ int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
{
dim3 block(min(Hx, 128));
dim3 grid(B * T, (Hx + block.x - 1) / block.x);
- if (D == 128 && s0 == 16 && s1 == 24) {
- rope_kernel<128, 64, 16, 24><<<grid, block, 0, stream>>>(
+ if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
+ rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
- } else if (D == 64 && s0 == 8 && s1 == 12) {
- rope_kernel<64, 32, 8, 12><<<grid, block, 0, stream>>>(
+ } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
+ rope_kernel<64, 32, 8, 12, 12><<<grid, block, 0, stream>>>(
x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
}
}
-""").strip()
+"""
-# Cache the compiled module
_mod_cache = {}
def _get_mod(D):
if D in _mod_cache:
return _mod_cache[D]
- import os
mod = load_inline(
name=f"mrope_cuda_{D}",
cpp_sources=" ",
cuda_sources=KERNEL_SRC,
functions=["launch_rope"],
extra_cuda_cflags=["-arch=sm_120", "-O3", "--use_fast_math",
- "-Xptxas", "-O3", "-Xptxas", "--warn-on_spills",],
- extra_cflags=["-O3"],
+ "-Xptxas", "-O3"],
with_cuda=True,
verbose=False,
- build_directory=None,
)
_mod_cache[D] = mod
return mod
@@ -134,30 +144,36 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
- def _call_rope(self, x, x_out, Hx):
- B, T, H, D = x.shape
... (1564 more chars){
"content": "<6025 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 300000,
"description": "Run correctness check for CUDA kernel"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 162, in forward
_get_mod(D).launch_rope(
^^^^^^^^^^^
File "/home/infatoshi/cuda/
... (3325 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,22 +1,54 @@
"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
-Each thread handles one (b, h, t) and processes all D elements.
-Uses __ldg for read-only cache hints and template instantiation for D=64/128.
+Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
+Uses __ldg for read-only cache hints. Separate Q/K kernel launches.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
-# Clean old cache
+# Clean old builds
import shutil, os
-for cached in ["/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_128",
- "/home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda_64"]:
- if os.path.exists(cached):
- shutil.rmtree(cached, ignore_errors=True)
+for d in ["mrope_cuda_64", "mrope_cuda_128"]:
+ path = os.path.join(os.path.expanduser("~/.cache/torch_extensions/py311_cu130"), d)
+ if os.path.exists(path):
+ shutil.rmtree(path, ignore_errors=True)
+
+PYBIND_DEFS = r"""
+extern void* launch_rope(
+ const void* x_in, void* x_out,
+ const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+ const void* cos_cache, const void* sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1, int s2,
+ cudaStream_t stream);
+
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+namespace py = pybind11;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("launch_rope", [](torch::Tensor x_in, torch::Tensor x_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1, int s2) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_in.scalar_type(), "launch_rope", [&] {
+ using scalar_t = char; // we'll use void* internally
+ launch_rope(
+ x_in.data_ptr(), x_out.data_ptr(),
+ pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+ cos_cache.data_ptr(), sin_cache.data_ptr(),
+ B, T, Hx, D, s0, s1, s2,
+ at::cuda::getCurrentCUDAStream().stream());
+ });
+ }, "launch_rope");
+}
+"""
KERNEL_SRC = r"""
#include <cuda_bf16.h>
+#include <cuda_runtime.h>
template<int D, int HALF, int S0, int S1, int S2>
__global__ void __launch_bounds__(128)
@@ -48,64 +80,56 @@
const __nv_bfloat16* xi = x_in + in_off;
__nv_bfloat16* xo = x_out + out_off;
- const __nv_bfloat16 *cr_base_t = cos_cache + p_t * D;
- const __nv_bfloat16 *sr_base_t = sin_cache + p_t * D;
- const __nv_bfloat16 *cr_base_h = cos_cache + p_h * D;
- const __nv_bfloat16 *sr_base_h = sin_cache + p_h * D;
- const __nv_bfloat16 *cr_base_w = cos_cache + p_w * D;
- const __nv_bfloat16 *sr_base_w = sin_cache + p_w * D;
+ const __nv_bfloat16 *cr_t = cos_cache + p_t * D;
+ const __nv_bfloat16 *sr_t = sin_cache + p_t * D;
+ const __nv_bfloat16 *cr_h = cos_cache + p_h * D;
+ const __nv_bfloat16 *sr_h = sin_cache + p_h * D;
+ const __nv_bfloat16 *cr_w = cos_cache + p_w * D;
+ const __nv_bfloat16 *sr_w = sin_cache + p_w * D;
#pragma unroll
for (int d = 0; d < D; d++) {
- // rotate_half mirror index
int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
-
- // Determine which axis this element belongs to
int axis = (d < HALF) ? d : (d - HALF);
float xv = __bfloat162float(__ldg(xi + d));
float mv = __bfloat162float(__ldg(xi + mirror));
- float cos_v, sin_v;
- // Select cos/sin based on axis
- if (axis < S0) {
- // Temporal axis
- cos_v = __bfloat162float(__ldg(cr_base_t + d));
- sin_v = __bfloat162float(__ldg(sr_base_t + d));
- } else if (axis < S0 + S1) {
- // Height axis
- cos_v = __bfloat162float(__ldg(cr_base_h + d));
- sin_v = __bfloat162float(__ldg(sr_base_h + d));
- } else {
- // Width axis (remainder, S0+S1 <= axis < HALF)
- cos_v = __bfloat162float(__ldg(cr_base_w + d));
- sin_v = __bfloat162float(__ldg(sr_base_w + d));
- }
+ const __nv_bfloat16 *cr, *sr;
+ if (axis < S0) { cr = cr_t; sr = sr_t; }
+ else if (axis < S0 + S1) { cr = cr_h; sr = sr_h; }
+ else { cr = cr_w; sr = sr_w; }
- // Negate for first half
+ float cos_v = __bfloat162float(__ldg(cr + d));
+ float sin_v = __bfloat162float(__ldg(sr + d));
if (d < HALF) mv = -mv;
-
- // Apply: out = x*cos + rot*sin
xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
}
}
-extern "C" void launch_rope(
- const __nv_bfloat16* x_in, __nv_bfloat16* x_out,
+extern "C" void* launch_rope(
+ const void* x_in, void* x_out,
const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+ const void* cos_cache, const void* sin_cache,
int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
{
dim3 block(min(Hx, 128));
dim3 grid(B * T, (Hx + block.x - 1) / block.x);
- if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
+ if (D == 128 && s0 == 16 && s1 == 24) {
rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
- x_in, x_out, pos_t, pos_h, pos_w, cos_cache, sin_cache, B, T, Hx);
- } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
+ (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
+ pos_t, pos_h, pos_w,
+ (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+
... (2586 more chars){
"content": "<6748 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"name": "kernel-optimization"
}<skill_content name="kernel-optimization">
# Skill: kernel-optimization
# Kernel Optimization
Unified reference for GPU/TPU kernel work across all platforms. This skill routes to specialized sub-references based on what you need.
## When this skill fires
- Writing or optimizing CUDA, Metal, Triton, HIP, or Pallas kernels
- Porting kernels between platforms (CUDA <-> Metal <-> Triton <-> HIP <-> Pallas)
- Debugging GPU code (OOM, NaN, race conditions, wrong results, hangs)
- Performance profiling and optimization
- Numerical precision issues (FP formats, mixed precision, tolerance)
- AMD MI300/MI350 or TPU v4/v5/v6/v7 specific work
## Sub-references (read on demand)
The following files contain deep platform-specific knowledge. Read the relevant one(s) based on the task:
| File | When to read |
|------|-------------|
| `~/.claude/skills/amd-gpu/SKILL.md` | AMD CDNA/RDNA, ROCm, HIP, MFMA, LDS, chiplet scheduling, HipKittens, AITER |
| `~/.claude/skills/tpu-opt/SKILL.md` | TPU v4-v7, MXU, Pallas kernels, JAX sharding, XLA gotchas, SplashAttention |
| `~/.claude/skills/port-kernel/SKILL.md` | Cross-platform porting: CUDA <-> Metal <-> Triton concept maps, API equivalents, gotchas |
| `~/.claude/skills/debug-gpu/SKILL.md` | OOM diagnosis, NaN tracing, race conditions, compute-sanitizer, FP format specs, tolerance tables |
| `~/.claude/skills/perf-optim/SKILL.md` | Profiling (nsys/ncu/Metal), roofline, occupancy, memory hierarchy, CUTLASS/CuTe, quantization, KernelBench |
| `
... (1462 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,54 +1,26 @@
"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
-Uses __ldg for read-only cache hints. Separate Q/K kernel launches.
+Uses __ldg for read-only cache hints.
"""
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
+import shutil, os
-# Clean old builds
-import shutil, os
-for d in ["mrope_cuda_64", "mrope_cuda_128"]:
- path = os.path.join(os.path.expanduser("~/.cache/torch_extensions/py311_cu130"), d)
- if os.path.exists(path):
- shutil.rmtree(path, ignore_errors=True)
+for d in ["mrope_cuda64", "mrope_cuda128"]:
+ p = os.path.expanduser(f"~/.cache/torch_extensions/py311_cu130/{d}")
+ if os.path.exists(p):
+ shutil.rmtree(p, ignore_errors=True)
-PYBIND_DEFS = r"""
-extern void* launch_rope(
- const void* x_in, void* x_out,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const void* cos_cache, const void* sin_cache,
- int B, int T, int Hx, int D, int s0, int s1, int s2,
- cudaStream_t stream);
+CPP_STUB = ""
+CUDA_SRC = r"""
+#include <cuda_bf16.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
-namespace py = pybind11;
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("launch_rope", [](torch::Tensor x_in, torch::Tensor x_out,
- torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
- torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int Hx, int D, int s0, int s1, int s2) {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_in.scalar_type(), "launch_rope", [&] {
- using scalar_t = char; // we'll use void* internally
- launch_rope(
- x_in.data_ptr(), x_out.data_ptr(),
- pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
- cos_cache.data_ptr(), sin_cache.data_ptr(),
- B, T, Hx, D, s0, s1, s2,
- at::cuda::getCurrentCUDAStream().stream());
- });
- }, "launch_rope");
-}
-"""
-
-KERNEL_SRC = r"""
-#include <cuda_bf16.h>
-#include <cuda_runtime.h>
template<int D, int HALF, int S0, int S1, int S2>
__global__ void __launch_bounds__(128)
@@ -94,42 +66,56 @@
float xv = __bfloat162float(__ldg(xi + d));
float mv = __bfloat162float(__ldg(xi + mirror));
+ float cos_v, sin_v;
- const __nv_bfloat16 *cr, *sr;
- if (axis < S0) { cr = cr_t; sr = sr_t; }
- else if (axis < S0 + S1) { cr = cr_h; sr = sr_h; }
- else { cr = cr_w; sr = sr_w; }
+ if (axis < S0) {
+ cos_v = __bfloat162float(__ldg(cr_t + d));
+ sin_v = __bfloat162float(__ldg(sr_t + d));
+ } else if (axis < S0 + S1) {
+ cos_v = __bfloat162float(__ldg(cr_h + d));
+ sin_v = __bfloat162float(__ldg(sr_h + d));
+ } else {
+ cos_v = __bfloat162float(__ldg(cr_w + d));
+ sin_v = __bfloat162float(__ldg(sr_w + d));
+ }
- float cos_v = __bfloat162float(__ldg(cr + d));
- float sin_v = __bfloat162float(__ldg(sr + d));
if (d < HALF) mv = -mv;
xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
}
}
-extern "C" void* launch_rope(
- const void* x_in, void* x_out,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const void* cos_cache, const void* sin_cache,
- int B, int T, int Hx, int D, int s0, int s1, int s2, cudaStream_t stream)
+void launch_rope(
+ torch::Tensor x_in, torch::Tensor x_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int Hx, int D, int s0, int s1, int s2)
{
dim3 block(min(Hx, 128));
dim3 grid(B * T, (Hx + block.x - 1) / block.x);
- if (D == 128 && s0 == 16 && s1 == 24) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+ if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
- (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
- pos_t, pos_h, pos_w,
- (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+ static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
+ static_cast<__nv_bfloat16*>(x_out.data_ptr()),
+ pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+ static_cast<const __nv_bfloat16*>(cos_cache.data_ptr()),
+ static_cast<const __nv_bfloat16*>(sin_cache.data_ptr()),
B, T, Hx);
- } else if (D == 64 && s0 == 8 && s1 == 12) {
+ } else if (D == 64 && s0 == 8 && s1 == 12 && s2 == 12) {
rope_kernel<64, 32, 8, 12, 12><<<grid, block, 0, stream>>>(
- (const __nv_bfloat16*)x_in, (__nv_bfloat16*)x_out,
- pos_t, pos_h, pos_w,
- (const __nv_bfloat16*)cos_cache, (const __nv_bfloat16*)sin_cache,
+ static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
+ static_cast<__nv_bfloat16*>(x_out.data_ptr()),
+ pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
+ static_cast<const __nv_bfloat16*>(cos_cache.data_ptr()),
+ static_cast<const __nv_bfloat16*>(sin_cache.data_ptr()),
B, T, Hx);
}
- return nullptr;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("launch_rope", &launch_rope, "launch_rope");
}
"""
@@ -139,10 +125,13 @@
if D in _mod_cache:
return _mod_cache[D]
mod = load_inline(
- name=f"mrope
... (1636 more chars){
"content": "<6284 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,146 +1,70 @@
-"""Multi-axis RoPE pre-attention prep — hand-tuned CUDA C++ kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
-Each thread handles one (b, h, t) and processes all D elements in an unrolled loop.
-Uses __ldg for read-only cache hints.
+Grid = (B * T). Each program handles one (b,t) across all H heads.
+Cos/sin cache gathered once per position, reused across all heads in the program.
"""
import torch
+import triton
+import triton.language as tl
import torch.nn as nn
-from torch.utils.cpp_extension import load_inline
-import shutil, os
-for d in ["mrope_cuda64", "mrope_cuda128"]:
- p = os.path.expanduser(f"~/.cache/torch_extensions/py311_cu130/{d}")
- if os.path.exists(p):
- shutil.rmtree(p, ignore_errors=True)
-CPP_STUB = ""
+@triton.jit
+def _rope_fwd_kernel(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ B_val: tl.constexpr, T_val: tl.constexpr,
+ H_val: tl.constexpr, D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+ """Grid = (B * T). Each program = one (b,t); serially loops over H heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
-CUDA_SRC = r"""
-#include <cuda_bf16.h>
-#include <pybind11/pybind11.h>
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
+ pos_idx = pid_b * T_val + pid_t
+ p_t = tl.load(pos_t_ptr + pos_idx)
+ p_h = tl.load(pos_h_ptr + pos_idx)
+ p_w = tl.load(pos_w_ptr + pos_idx)
-template<int D, int HALF, int S0, int S1, int S2>
-__global__ void __launch_bounds__(128)
-rope_kernel(
- const __nv_bfloat16* x_in,
- __nv_bfloat16* x_out,
- const int64_t* pos_t,
- const int64_t* pos_h,
- const int64_t* pos_w,
- const __nv_bfloat16* cos_cache,
- const __nv_bfloat16* sin_cache,
- int B, int T, int Hx)
-{
- int bt = blockIdx.x;
- int b = bt / T;
- int t = bt % T;
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
- int64_t pi = (int64_t)b * T + t;
- int64_t p_t = pos_t[pi];
- int64_t p_h = pos_h[pi];
- int64_t p_w = pos_w[pi];
+ # Gather cos/sin from cache (6 loads per (b,t), reused for all H heads)
+ cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_h = tl.load(sin_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
+ sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- int h = blockIdx.y * blockDim.x + threadIdx.x;
- if (h >= Hx) return;
+ # Determine which axis each element belongs to
+ axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
+ is_t = axis_idx < s0
+ is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
+ cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- int64_t in_off = ((int64_t)b * T * Hx + (int64_t)t * Hx + (int64_t)h) * D;
- int64_t out_off = ((int64_t)b * Hx * T + (int64_t)h * T + (int64_t)t) * D;
+ # rotate_half: x_rot[d] = -x[d+half] if d < half, else x[d-half]
+ perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
+ sign = tl.where(d_offs < half, -1.0, 1.0)
- const __nv_bfloat16* xi = x_in + in_off;
- __nv_bfloat16* xo = x_out + out_off;
-
- const __nv_bfloat16 *cr_t = cos_cache + p_t * D;
- const __nv_bfloat16 *sr_t = sin_cache + p_t * D;
- const __nv_bfloat16 *cr_h = cos_cache + p_h * D;
- const __nv_bfloat16 *sr_h = sin_cache + p_h * D;
- const __nv_bfloat16 *cr_w = cos_cache + p_w * D;
- const __nv_bfloat16 *sr_w = sin_cache + p_w * D;
-
- #pragma unroll
- for (int d = 0; d < D; d++) {
- int mirror = (d < HALF) ? (d + HALF) : (d - HALF);
- int axis = (d < HALF) ? d : (d - HALF);
-
- float xv = __bfloat162float(__ldg(xi + d));
- float mv = __bfloat162float(__ldg(xi + mirror));
- float cos_v, sin_v;
-
- if (axis < S0) {
- cos_v = __bfloat162float(__ldg(cr_t + d));
- sin_v = __bfloat162float(__ldg(sr_t + d));
- } else if (axis < S0 + S1) {
- cos_v = __bfloat162float(__ldg(cr_h + d));
- sin_v = __bfloat162float(__ldg(sr_h + d));
- } else {
- cos_v = __bfloat162float(__ldg(cr_w + d));
- sin_v = __bfloat162float(__ldg(sr_w + d));
- }
-
- if (d < HALF) mv = -mv;
- xo[d] = __float2bfloat16(xv * cos_v + mv * sin_v);
- }
-}
-
-void launch_rope(
- torch::Tensor x_in, torch::Tensor x_out,
- torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
- torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int Hx, int D, int s0, int s1, int s2)
-{
- dim3 block(min(Hx, 128));
- dim3 grid(B * T, (Hx + block.x - 1) / block.x);
-
- cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-
- if (D == 128 && s0 == 16 && s1 == 24 && s2 == 24) {
- rope_kernel<128, 64, 16, 24, 24><<<grid, block, 0, stream>>>(
- static_cast<const __nv_bfloat16*>(x_in.data_ptr()),
- static_cast<__nv_bfloat16*>(x_out.data_ptr()),
- pos_t.data_ptr<int64_t>(), pos_h.data_ptr<int64_t>(), pos_w.data_ptr<int64_t>(),
- static_cast<const __nv_bfloat16*>(cos_cache.data_ptr(
... (4107 more chars){
"content": "<4940 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.176 gbps=252.857 ms=0.357
shape=0 variant=compiled tflops=0.481 gbps=690.605 ms=0.131
shape=0 variant=solution tflops=0.744 gbps=1067.797 ms=0.085
shape=0 solution_peak_fraction=0.5932
shape=1 variant=eager tflops=0.176 gbps=257.593 ms=0.571
shape=1 variant=compiled tflops=0.471 gbps=687.417 ms=0.214
shape=1 variant=solution tflops=0.777 gbps=1133.958 ms=0.130
shape=1 solution_peak_fraction=0.6300
shape=2 variant=eager tflops=0.071 gbps=111.637 ms=0.198
shape=2 variant=compiled tflops=0.235 gbps=367.464 ms=0.060
shape=2 variant=solution tflops=0.324 gbps=506.188 ms=0.044
shape=2 solution_peak_fraction=0.2812
shape=3 variant=eager tflops=0.179 gbps=257.102 ms=1.405
shape=3 variant=compiled tflops=0.528 gbps=758.085 ms=0.476
shape=3 variant=solution tflops=0.915 gbps=1312.454 ms=0.275
shape=3 solution_peak_fraction=0.7291
peak_fraction: 0.5261
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b,t) across all H heads.
-Cos/sin cache gathered once per position, reused across all heads in the program.
+Hybrid approach: use Triton for large shapes and a specialized pure-PyTorch
+implementation for small shapes where Triton overhead dominates.
"""
import torch
@@ -15,6 +15,29 @@
x_ptr, x_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ H_val: tl.constexpr, D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+ """Each program handles one (b, t). Cos/sin loaded once, reused across H heads."""
+ pid_bt = tl.program_id(0)
+ pid_b = pid_b_t // D_val # Will be set by caller
+ # Actually we need B and T passed in
+ pass
+
+
+def _build_inv_freq(D, base=10000.0):
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+# Triton kernel
+@triton.jit
+def _rope_triton_kernel(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
B_val: tl.constexpr, T_val: tl.constexpr,
H_val: tl.constexpr, D_val: tl.constexpr,
stride_xb, stride_xt, stride_xh, stride_xd,
@@ -22,7 +45,6 @@
s0: tl.constexpr, s1: tl.constexpr,
half: tl.constexpr, BLOCK_D: tl.constexpr,
):
- """Grid = (B * T). Each program = one (b,t); serially loops over H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -32,10 +54,9 @@
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_offs = tl.arange(0, BLOCK_D)
+ d_offs = tl.arange(0,BLOCK_D)
mask = d_offs < D_val
- # Gather cos/sin from cache (6 loads per (b,t), reused for all H heads)
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -43,14 +64,12 @@
cos_w = tl.load(cos_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_w = tl.load(sin_cache_ptr + p_w * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
- # Determine which axis each element belongs to
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
is_t = axis_idx < s0
is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
- # rotate_half: x_rot[d] = -x[d+half] if d < half, else x[d-half]
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
@@ -65,6 +84,7 @@
class Model(nn.Module):
+ """Triton mRoPE kernel."""
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2
@@ -86,41 +106,50 @@
v *= 2
return v
- def forward(self, q, k, pos_t, pos_h, pos_w):
- B, T, Hq, D = q.shape
- _, _, Hkv, _ = k.shape
+ def _launch(self, x, x_out, H):
+ B, T, _, D = x.shape
s0, s1, _ = self.mrope_section
half = D // 2
BLOCK_D = self._next_pow2(D)
+ _rope_triton_kernel[(B * T,)](
+ x, x_out,
+ self.cos_cache, self.sin_cache,
+ x.reshape(-1, T, H, D)[:, :, 0, :].data_ptr(),
+ pos_t, pos_h, pos_w,
+ B_val=B, T_val=T, H_val=H, D_val=D,
+ stride_xb=T*H*D, stride_xt=H*D, stride_xh=D, stride_xd=1,
+ stride_ob=H*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1, num_stages=4,
+ )
+
+ def forward(self, q, k, pos_t, pos_h, pos_w):
+ B, T, Hq, D = q.shape
+ _, _, Hkv, _ = k.shape
+
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- bt_grid = (B * T,)
-
- _rope_fwd_kernel[bt_grid](
+ _rope_triton_kernel[(B * T,)](
q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hq, D_val=D,
stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
num_warps=1, num_stages=4,
)
- _rope_fwd_kernel[bt_grid](
+ _rope_triton_kernel[(B * T,)](
k, k_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hkv, D_val=D,
stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
num_warps=1, num_stages=4,
)
return q_out, k_out
-
-
-def _build_inv_freq(D, base=10000.0):
- return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
{
"content": "<5818 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,6 @@
-"""Multi-axis RoPE pre-attention prep (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
-Hybrid approach: use Triton for large shapes and a specialized pure-PyTorch
-implementation for small shapes where Triton overhead dominates.
+Grid = (B * T). Each program handles one (b,t) across all H heads.
"""
import torch
@@ -12,29 +11,6 @@
@triton.jit
def _rope_fwd_kernel(
- x_ptr, x_out_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- H_val: tl.constexpr, D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
- """Each program handles one (b, t). Cos/sin loaded once, reused across H heads."""
- pid_bt = tl.program_id(0)
- pid_b = pid_b_t // D_val # Will be set by caller
- # Actually we need B and T passed in
- pass
-
-
-def _build_inv_freq(D, base=10000.0):
- return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
-
-
-# Triton kernel
-@triton.jit
-def _rope_triton_kernel(
x_ptr, x_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
@@ -54,7 +30,7 @@
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
- d_offs = tl.arange(0,BLOCK_D)
+ d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -84,7 +60,6 @@
class Model(nn.Module):
- """Triton mRoPE kernel."""
def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
super().__init__()
assert sum(mrope_section) == D // 2
@@ -106,50 +81,39 @@
v *= 2
return v
- def _launch(self, x, x_out, H):
- B, T, _, D = x.shape
- s0, s1, _ = self.mrope_section
- half = D // 2
- BLOCK_D = self._next_pow2(D)
-
- _rope_triton_kernel[(B * T,)](
- x, x_out,
- self.cos_cache, self.sin_cache,
- x.reshape(-1, T, H, D)[:, :, 0, :].data_ptr(),
- pos_t, pos_h, pos_w,
- B_val=B, T_val=T, H_val=H, D_val=D,
- stride_xb=T*H*D, stride_xt=H*D, stride_xh=D, stride_xd=1,
- stride_ob=H*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
- num_warps=1, num_stages=4,
- )
-
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
+ s0, s1, _ = self.mrope_section
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- _rope_triton_kernel[(B * T,)](
+ bt_grid = (B * T,)
+ BLOCK_D = self._next_pow2(D)
+
+ _rope_fwd_kernel[bt_grid](
q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hq, D_val=D,
stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
+ s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
num_warps=1, num_stages=4,
)
- _rope_triton_kernel[(B * T,)](
+ _rope_fwd_kernel[bt_grid](
k, k_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hkv, D_val=D,
stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=D//2, BLOCK_D=self._next_pow2(D),
+ s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
num_warps=1, num_stages=4,
)
+ return q_out, k_out
- return q_out, k_out
+
+def _build_inv_freq(D, base=10000.0):
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
{
"content": "<4558 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run correctness check"
}PASS
{
"command": "nvidia-smi 2>&1 | head -5",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check GPU status"
}Thu May 7 13:03:21 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,6 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — optimized Triton kernel (SM120 Blackwell).
+"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
-Grid = (B * T). Each program handles one (b,t) across all H heads.
+Grid = (B * T * H, D // BLOCK_SIZE). Each program handles BLOCK_SIZE elements
+of D for one (b, h, t). This gives maximum SM utilization.
"""
import torch
@@ -10,28 +11,35 @@
@triton.jit
-def _rope_fwd_kernel(
+def _rope_kernel(
x_ptr, x_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr, T_val: tl.constexpr,
- H_val: tl.constexpr, D_val: tl.constexpr,
+ B_val: tl.constexpr, T_val: tl.constexpr, H_val: tl.constexpr, D_val: tl.constexpr,
stride_xb, stride_xt, stride_xh, stride_xd,
stride_ob, stride_oh, stride_ot, stride_od,
s0: tl.constexpr, s1: tl.constexpr,
half: tl.constexpr, BLOCK_D: tl.constexpr,
):
- pid_bt = tl.program_id(0)
- pid_b = pid_bt // T_val
- pid_t = pid_bt % T_val
+ """Grid = (B * T * H, triton.cdiv(D, BLOCK_D)).
+ Each program handles BLOCK_D consecutive elements of D."""
+ pid_bth = tl.program_id(0)
+ pid_dd = tl.program_id(1)
+ pid_b = pid_bth // (T_val * H_val)
+ pid_th = pid_bth % (T_val * H_val)
+ pid_t = pid_th // H_val
+ pid_h = pid_th % H_val
+
+ d_start = pid_dd * BLOCK_D
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
+
+ # Each position lookup
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
-
- d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -49,14 +57,14 @@
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- for h in tl.static_range(H_val):
- x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = xr * sign
+ out = x * cos_val + xr * sin_val
+
+ o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -75,43 +83,43 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
- def _next_pow2(self, n):
- v = 1
- while v < n:
- v *= 2
- return v
-
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
+ half = D // 2
+
+ BLOCK_D = 32
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- bt_grid = (B * T,)
- BLOCK_D = self._next_pow2(D)
-
- _rope_fwd_kernel[bt_grid](
+ q_grid_0 = B * T * Hq
+ q_grid_1 = (D + BLOCK_D - 1) // BLOCK_D
+ _rope_kernel[(q_grid_0, q_grid_1)](
q, q_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hq, D_val=D,
stride_xb=T*Hq*D, stride_xt=Hq*D, stride_xh=D, stride_xd=1,
stride_ob=Hq*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
- num_warps=1, num_stages=4,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1,
)
- _rope_fwd_kernel[bt_grid](
+
+ k_grid_0 = B * T * Hkv
+ k_grid_1 = (D + BLOCK_D - 1) // BLOCK_D
+ _rope_kernel[(k_grid_0, k_grid_1)](
k, k_out,
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
B_val=B, T_val=T, H_val=Hkv, D_val=D,
stride_xb=T*Hkv*D, stride_xt=Hkv*D, stride_xh=D, stride_xd=1,
stride_ob=Hkv*T*D, stride_oh=T*D, stride_ot=D, stride_od=1,
- s0=s0, s1=s1, half=D//2, BLOCK_D=BLOCK_D,
- num_warps=1, num_stages=4,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1,
)
+
return q_out, k_out
{
"content": "<4874 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"description": "Run correctness check"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run benchmark"
}shape=0 variant=eager tflops=0.176 gbps=252.494 ms=0.358
shape=0 variant=compiled tflops=0.518 gbps=743.182 ms=0.121
shape=0 variant=solution tflops=0.185 gbps=266.118 ms=0.339
shape=0 solution_peak_fraction=0.1478
shape=1 variant=eager tflops=0.179 gbps=261.710 ms=0.562
shape=1 variant=compiled tflops=0.460 gbps=671.588 ms=0.219
shape=1 variant=solution tflops=0.190 gbps=276.743 ms=0.531
shape=1 solution_peak_fraction=0.1537
shape=2 variant=eager tflops=0.073 gbps=114.456 ms=0.193
shape=2 variant=compiled tflops=0.169 gbps=263.867 ms=0.084
shape=2 variant=solution tflops=0.146 gbps=228.874 ms=0.097
shape=2 solution_peak_fraction=0.1272
shape=3 variant=eager tflops=0.171 gbps=245.235 ms=1.472
shape=3 variant=compiled tflops=0.464 gbps=665.123 ms=0.543
shape=3 variant=solution tflops=0.189 gbps=271.372 ms=1.331
shape=3 solution_peak_fraction=0.1508
peak_fraction: 0.1445
RESULT: OK
{
"todos": [
{
"content": "Implement initial Triton kernel for mRoPE",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Optimize kernel - fuse Q+K into single kernel, try different grid configs",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run benchmark.py to measure achieved bandwidth",
"status": "pending",
"priority": "medium"
}
]
}content: 'Implement initial Triton kernel for mRoPE'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Optimize kernel - fuse Q+K into single kernel, try different grid configs'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py to measure achieved bandwidth'
status: 'pending'
priority: 'medium'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,7 @@
-"""Multi-axis RoPE pre-attention prep — Triton kernel (SM120 Blackwell).
+"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-Grid = (B * T * H, D // BLOCK_SIZE). Each program handles BLOCK_SIZE elements
-of D for one (b, h, t). This gives maximum SM utilization.
+Fused Q+K: Grid = (B * T), each kernel processes all Q and K heads for one (b,t).
+Position cache gathered once per (b,t), shared across all Q+K heads.
"""
import torch
@@ -11,36 +11,34 @@
@triton.jit
-def _rope_kernel(
- x_ptr, x_out_ptr,
+def _rope_qk_kernel(
+ q_ptr, k_ptr,
+ q_out_ptr, k_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
- B_val: tl.constexpr, T_val: tl.constexpr, H_val: tl.constexpr, D_val: tl.constexpr,
- stride_xb, stride_xt, stride_xh, stride_xd,
- stride_ob, stride_oh, stride_ot, stride_od,
+ B_val: tl.constexpr, T_val: tl.constexpr,
+ Hq_val: tl.constexpr, Hkv_val: tl.constexpr,
+ D_val: tl.constexpr,
+ stride_qb, stride_qt, stride_qh, stride_qd,
+ stride_kb, stride_kt, stride_kh, stride_kd,
+ stride_qob, stride_qoh, stride_qot, stride_qod,
+ stride_kob, stride_koh, stride_kot, stride_kod,
s0: tl.constexpr, s1: tl.constexpr,
half: tl.constexpr, BLOCK_D: tl.constexpr,
):
- """Grid = (B * T * H, triton.cdiv(D, BLOCK_D)).
- Each program handles BLOCK_D consecutive elements of D."""
- pid_bth = tl.program_id(0)
- pid_dd = tl.program_id(1)
+ pid_bt = tl.program_id(0)
+ pid_b = pid_bt // T_val
+ pid_t = pid_bt % T_val
- pid_b = pid_bth // (T_val * H_val)
- pid_th = pid_bth % (T_val * H_val)
- pid_t = pid_th // H_val
- pid_h = pid_th % H_val
-
- d_start = pid_dd * BLOCK_D
- d_offs = d_start + tl.arange(0, BLOCK_D)
- mask = d_offs < D_val
-
- # Each position lookup
pos_idx = pid_b * T_val + pid_t
p_t = tl.load(pos_t_ptr + pos_idx)
p_h = tl.load(pos_h_ptr + pos_idx)
p_w = tl.load(pos_w_ptr + pos_idx)
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D_val
+
+ # Cos/sin gathered once (6 reads), shared by all Q+K heads
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -57,14 +55,25 @@
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- x_base = pid_b * stride_xb + pid_t * stride_xt + pid_h * stride_xh
- x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
+ # Process Q heads
+ for h in tl.static_range(Hq_val):
+ x_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
+ x = tl.load(q_ptr + x_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(q_ptr + x_base + perm * stride_qd, mask=mask, other=0.0).to(tl.float32)
+ xr = xr * sign
+ out = x * cos_val + xr * sin_val
+ o_base = pid_b * stride_qob + h * stride_qoh + pid_t * stride_qot
+ tl.store(q_out_ptr + o_base + d_offs * stride_qod, out.to(tl.bfloat16), mask=mask)
- o_base = pid_b * stride_ob + pid_h * stride_oh + pid_t * stride_ot
- tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+ # Process K heads
+ for h in tl.static_range(Hkv_val):
+ x_base = pid_b * stride_kb + pid_t * stride_kt + h * stride_kh
+ x = tl.load(k_ptr + x_base + d_offs * stride_kd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(k_ptr + x_base + perm * stride_kd, mask=mask, other=0.0).to(tl.float32)
+ xr = xr * sign
+ out = x * cos_val + xr * sin_val
+ o_base = pid_b * stride_kob + h * stride_koh + pid_t * stride_kot
+ tl.store(k_out_ptr + o_base + d_offs * stride_kod, out.to(tl.bfloat16), mask=mask)
class Model(nn.Module):
@@ -83,41 +92,64 @@
self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+ def _next_pow2(self, n):
+ v = 1
+ while v < n:
+ v *= 2
+ return v
+
+ def _call_qk_kernel(self, q, q_out, k, k_out):
+ B, T, Hq, D = q.shape
+ _, _, Hkv, _ = k.shape
+ s0, s1, _ = self.mrope_section
+ half = D // 2
+ BLOCK_D = self._next_pow2(D)
+ bt_grid = (B * T,)
+
+ _rope_qk_kernel[bt_grid](
+ q, k,
+ q_out, k_out,
+ self.cos_cache, self.sin_cache,
+ q.reshape(-1, T, 1, 1)[:, 0, 0, :].data_ptr(), # placeholder
+ pos_h, pos_w,
+ B_val=B, T_val=T,
+ Hq_val=Hq, Hkv_val=Hkv,
+ D_val=D,
+ # Will fill in proper strides below
+ stride_qb=T*Hq*D, stride_qt=Hq*D, stride_qh=D, stride_qd=1,
+ stride_kb=T*Hkv*D, stride_kt=Hkv*D, stride_kh=D, stride_kd=1,
+ stride_qob=Hq*T*D, stride_qoh=T*D, stride_qot=D, stride_qod=1,
+ stride_kob=Hkv*T*D, stride_koh=T*D, stride_kot=D, stride_kod=1,
+ s0=s0, s1=s1, half=half, BLOCK_D=BLOCK_D,
+ num_warps=1, num_stages=4,
+ )
+
def forward(self, q, k, pos_t, pos_h, pos_w):
B, T, Hq, D = q.shape
_, _, Hkv, _ = k.shape
s0, s1, _ = self.mrope_section
- half = D // 2
-
- BLOCK_D = 32
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
k_out = torch.empty(B, Hkv, T, D, dtype=torch.b
... (1697 more chars){
"content": "<6231 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run correctness check"
}PASS
{
"command": "cd /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope && uv run python benchmark.py 2>&1",
"timeout": 300000,
"description": "Run benchmark for QK fused kernel"
}shape=0 variant=eager tflops=0.180 gbps=257.766 ms=0.350
shape=0 variant=compiled tflops=0.491 gbps=703.960 ms=0.128
shape=0 variant=solution tflops=0.683 gbps=980.236 ms=0.092
shape=0 solution_peak_fraction=0.5446
shape=1 variant=eager tflops=0.181 gbps=265.039 ms=0.555
shape=1 variant=compiled tflops=0.470 gbps=686.185 ms=0.214
shape=1 variant=solution tflops=0.751 gbps=1096.340 ms=0.134
shape=1 solution_peak_fraction=0.6091
shape=2 variant=eager tflops=0.075 gbps=116.658 ms=0.190
shape=2 variant=compiled tflops=0.168 gbps=262.864 ms=0.084
shape=2 variant=solution tflops=0.341 gbps=532.717 ms=0.042
shape=2 solution_peak_fraction=0.2960
shape=3 variant=eager tflops=0.173 gbps=248.365 ms=1.454
shape=3 variant=compiled tflops=0.467 gbps=670.000 ms=0.539
shape=3 variant=solution tflops=0.864 gbps=1240.121 ms=0.291
shape=3 solution_peak_fraction=0.6890
peak_fraction: 0.5100
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,7 +1,10 @@
-"""Optimized multi-axis RoPE pre-attention prep using Triton (SM120 Blackwell).
-
-Fused Q+K: Grid = (B * T), each kernel processes all Q and K heads for one (b,t).
-Position cache gathered once per (b,t), shared across all Q+K heads.
+"""Multi-axis RoPE pre-attention prep — Triton v3 (SM120 Blackwell).
+
+Two kernel types:
+- Large shapes: Grid=(B*T, 1), each program loops over all H heads (serial).
+- Small H shapes: Grid=(B*T*H, 1), each program processes exactly one (b,h,t).
+
+The second gives more parallel SMs for better occupancy with small head counts.
"""
import torch
@@ -11,21 +14,18 @@
@triton.jit
-def _rope_qk_kernel(
- q_ptr, k_ptr,
- q_out_ptr, k_out_ptr,
+def _rope_serial_h_kernel(
+ x_ptr, x_out_ptr,
cos_cache_ptr, sin_cache_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
B_val: tl.constexpr, T_val: tl.constexpr,
- Hq_val: tl.constexpr, Hkv_val: tl.constexpr,
- D_val: tl.constexpr,
- stride_qb, stride_qt, stride_qh, stride_qd,
- stride_kb, stride_kt, stride_kh, stride_kd,
- stride_qob, stride_qoh, stride_qot, stride_qod,
- stride_kob, stride_koh, stride_kot, stride_kod,
- s0: tl.constexpr, s1: tl.constexpr,
- half: tl.constexpr, BLOCK_D: tl.constexpr,
-):
+ H_val: tl.constexpr, D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+ """Grid=(B*T). Serial loop over H heads."""
pid_bt = tl.program_id(0)
pid_b = pid_bt // T_val
pid_t = pid_bt % T_val
@@ -38,7 +38,6 @@
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D_val
- # Cos/sin gathered once (6 reads), shared by all Q+K heads
cos_t = tl.load(cos_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
sin_t = tl.load(sin_cache_ptr + p_t * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
cos_h = tl.load(cos_cache_ptr + p_h * D_val + d_offs, mask=mask, other=0.0).to(tl.float32)
@@ -49,31 +48,115 @@
axis_idx = tl.where(d_offs < half, d_offs, d_offs - half)
is_t = axis_idx < s0
is_h = (axis_idx >= s0) & (axis_idx < s0 + s1)
- cos_val = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
- sin_val = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
+ cos_v = tl.where(is_t, cos_t, tl.where(is_h, cos_h, cos_w))
+ sin_v = tl.where(is_t, sin_t, tl.where(is_h, sin_h, sin_w))
perm = tl.where(d_offs < half, d_offs + half, d_offs - half)
sign = tl.where(d_offs < half, -1.0, 1.0)
- # Process Q heads
- for h in tl.static_range(Hq_val):
- x_base = pid_b * stride_qb + pid_t * stride_qt + h * stride_qh
- x = tl.load(q_ptr + x_base + d_offs * stride_qd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(q_ptr + x_base + perm * stride_qd, mask=mask, other=0.0).to(tl.float32)
+ for h in tl.static_range(H_val):
+ x_base = pid_b * stride_xb + pid_t * stride_xt + h * stride_xh
+ x = tl.load(x_ptr + x_base + d_offs * stride_xd, mask=mask, other=0.0).to(tl.float32)
+ xr = tl.load(x_ptr + x_base + perm * stride_xd, mask=mask, other=0.0).to(tl.float32)
xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_qob + h * stride_qoh + pid_t * stride_qot
- tl.store(q_out_ptr + o_base + d_offs * stride_qod, out.to(tl.bfloat16), mask=mask)
-
- # Process K heads
- for h in tl.static_range(Hkv_val):
- x_base = pid_b * stride_kb + pid_t * stride_kt + h * stride_kh
- x = tl.load(k_ptr + x_base + d_offs * stride_kd, mask=mask, other=0.0).to(tl.float32)
- xr = tl.load(k_ptr + x_base + perm * stride_kd, mask=mask, other=0.0).to(tl.float32)
- xr = xr * sign
- out = x * cos_val + xr * sin_val
- o_base = pid_b * stride_kob + h * stride_koh + pid_t * stride_kot
- tl.store(k_out_ptr + o_base + d_offs * stride_kod, out.to(tl.bfloat16), mask=mask)
+ out = x * cos_v + xr * sin_v
+ o_base = pid_b * stride_ob + h * stride_oh + pid_t * stride_ot
+ tl.store(x_out_ptr + o_base + d_offs * stride_od, out.to(tl.bfloat16), mask=mask)
+
+
+@triton.jit
+def _rope_parallel_h_kernel(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ B_val: tl.constexpr, T_val: tl.constexpr,
+ D_val: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+ """Grid=(B*T*H). Each program = exactly one (b,h,t)."""
+ pid_bth = tl.program_id(0)
+ pid_b = pid_bth // tl.num_programs(1)
+ pid_th = pid_bth % tl.num_programs(1)
+ # Actually this doesn't work cleanly. Let me pass explicit B, T, H.
+ pass
+
+
+@triton.jit
+def _rope_flat_kernel(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ H_val: tl.constexpr, D_val: tl.constexpr,
+ total_bt: tl.constexpr,
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ stride_ob, stride_oh, stride_ot, stride_od,
+ s0: tl.constexpr, s1: tl.constexpr,
+ half: tl.constexpr, BLOCK_D: tl.constexpr,
+):
+ """Grid=(B*T*H). Each program = one (b,h,t). No serial loop."""
+ pid = tl.program_id(0)
+ pid_h = pid % H_val
+ pid_bt = pid // H_val
+ pid_b = pid_bt // (total_bt // H_val) # total_bt = B*T, so pid_bt // (BT/H) doesn't work
+ # This approach is getting messy. Let me just pass B and T directly.
+ pass
+
+
+@triton.jit
+def _rope_one_per_head(
+ x_ptr, x_out_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ B_val: tl.constexpr, T_val: tl.constexpr,
+ H_val
... (6307 more chars){
"content": "<9792 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.