import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# CUDA translation unit passed to load_inline as `cuda_sources`. It contains the
# fused mRoPE kernel template and the host launcher `launch_mrope_kernel`; the
# launcher's signature must stay identical to the declaration in `_CPP_SRC`.
_CUDA_SRC = r"""
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <climits>
#include <cstdint>
#include <stdexcept>
#include <string>

// Fused multi-axis RoPE: rotates q and k and writes them with the head and
// token strides supplied by the caller (oq_* / ok_*).
//
// Launch layout: 1-D grid with ceil(B*T / BT_PER_BLOCK) blocks and exactly
// BLOCK_SIZE threads per block (the loops stride by BLOCK_SIZE, not blockDim.x).
// Shared memory: static only,
//   2 * BT_PER_BLOCK * D_HALF bf16 values + 3 * BT_PER_BLOCK ints.
//
// Each block handles BT_PER_BLOCK consecutive flattened (b, t) indices:
//   phase 1: stage the three position ids per (b, t) in shared memory;
//   phase 2: gather cos/sin for pair index p in [0, D_HALF) from the cache row
//            selected by the axis owning p (p < s0 -> pos_t,
//            p < s0 + s1 -> pos_h, otherwise pos_w);
//   phase 3: for every head, rotate the pair (x[p], x[p + D_HALF]):
//              out[p]          = x[p] * cos - x[p + D_HALF] * sin
//              out[p + D_HALF] = x[p + D_HALF] * cos + x[p] * sin
//
// Preconditions that are NOT checked on the device:
//   * the head-dim stride of q, k, q_rot and k_rot is one element;
//   * pos_t, pos_h and pos_w share the strides (pos_sb, pos_st);
//   * cos_cache and sin_cache share the strides (cache_sp, cache_sd);
//   * every position id is a valid cache row (ids are narrowed to int);
//   * every element offset fits in a 32-bit int.
template<int D_HALF, int BLOCK_SIZE, int BT_PER_BLOCK>
__global__ void mrope_fused_kernel(
    const __nv_bfloat16* __restrict__ q_ptr,
    const __nv_bfloat16* __restrict__ k_ptr,
    __nv_bfloat16* __restrict__ q_rot_ptr,
    __nv_bfloat16* __restrict__ k_rot_ptr,
    const int64_t* __restrict__ pos_t_ptr,
    const int64_t* __restrict__ pos_h_ptr,
    const int64_t* __restrict__ pos_w_ptr,
    const __nv_bfloat16* __restrict__ cos_cache,
    const __nv_bfloat16* __restrict__ sin_cache,
    const int B, const int T, const int Hq, const int Hkv,
    const int s0, const int s1,
    const int q_sb, const int q_st, const int q_sh,
    const int k_sb, const int k_st, const int k_sh,
    const int oq_sb, const int oq_sh, const int oq_st,
    const int ok_sb, const int ok_sh, const int ok_st,
    const int pos_sb, const int pos_st,
    const int cache_sp, const int cache_sd
) {
    __shared__ __nv_bfloat16 smem_cos[BT_PER_BLOCK][D_HALF];
    __shared__ __nv_bfloat16 smem_sin[BT_PER_BLOCK][D_HALF];
    __shared__ int pos_t_vals[BT_PER_BLOCK];
    __shared__ int pos_h_vals[BT_PER_BLOCK];
    __shared__ int pos_w_vals[BT_PER_BLOCK];

    const int tid = threadIdx.x;
    const int H_total = Hq + Hkv;
    const int bt_start = blockIdx.x * BT_PER_BLOCK;

    // Phase 1: load positions for all (b, t) handled by this block.
    for (int i = tid; i < BT_PER_BLOCK; i += BLOCK_SIZE) {
        int bt = bt_start + i;
        // Tail block: pad with position 0 so phase 2 still reads a valid cache row.
        if (bt >= B * T) { pos_t_vals[i] = 0; pos_h_vals[i] = 0; pos_w_vals[i] = 0; continue; }
        int t = bt % T;
        int b = bt / T;
        pos_t_vals[i] = static_cast<int>(pos_t_ptr[b * pos_sb + t * pos_st]);
        pos_h_vals[i] = static_cast<int>(pos_h_ptr[b * pos_sb + t * pos_st]);
        pos_w_vals[i] = static_cast<int>(pos_w_ptr[b * pos_sb + t * pos_st]);
    }
    __syncthreads();

    // Phase 2: gather cos/sin into shared memory for all (b, t) of this block.
    for (int i = tid; i < BT_PER_BLOCK * D_HALF; i += BLOCK_SIZE) {
        int bti = i / D_HALF;
        int p = i % D_HALF;
        int pos;
        if (p < s0) pos = pos_t_vals[bti];
        else if (p < s0 + s1) pos = pos_h_vals[bti];
        else pos = pos_w_vals[bti];
        smem_cos[bti][p] = cos_cache[pos * cache_sp + p * cache_sd];
        smem_sin[bti][p] = sin_cache[pos * cache_sp + p * cache_sd];
    }
    __syncthreads();

    // Phase 3: rotate every head of q, then of k, for each (b, t) of this block.
    for (int bti = 0; bti < BT_PER_BLOCK; bti++) {
        int bt = bt_start + bti;
        if (bt >= B * T) break;  // uniform across the block; no barrier follows
        int t = bt % T;
        int b = bt / T;
        const int total_pairs = H_total * D_HALF;
        for (int idx = tid; idx < total_pairs; idx += BLOCK_SIZE) {
            const int p = idx % D_HALF;
            const int h = idx / D_HALF;  // heads [0, Hq) are q, [Hq, Hq + Hkv) are k
            const float cv = __bfloat162float(smem_cos[bti][p]);
            const float sv = __bfloat162float(smem_sin[bti][p]);
            if (h < Hq) {
                const int xb = b * q_sb + t * q_st + h * q_sh;
                const float x0 = __bfloat162float(q_ptr[xb + p]);
                const float x1 = __bfloat162float(q_ptr[xb + p + D_HALF]);
                const float o0 = x0 * cv - x1 * sv;
                const float o1 = x1 * cv + x0 * sv;
                const int ob = b * oq_sb + h * oq_sh + t * oq_st;
                q_rot_ptr[ob + p] = __float2bfloat16(o0);
                q_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
            } else {
                const int hk = h - Hq;
                const int xb = b * k_sb + t * k_st + hk * k_sh;
                const float x0 = __bfloat162float(k_ptr[xb + p]);
                const float x1 = __bfloat162float(k_ptr[xb + p + D_HALF]);
                const float o0 = x0 * cv - x1 * sv;
                const float o1 = x1 * cv + x0 * sv;
                const int ob = b * ok_sb + hk * ok_sh + t * ok_st;
                k_rot_ptr[ob + p] = __float2bfloat16(o0);
                k_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
            }
        }
    }
}

// Host entry point. Pointers and the stream arrive as integer addresses and are
// cast back here. Throws std::invalid_argument for unsupported or negative
// sizes, std::overflow_error when B*T does not fit in int, and
// std::runtime_error when the kernel launch is rejected by the runtime.
void launch_mrope_kernel(
    int64_t q_ptr_int, int64_t k_ptr_int,
    int64_t q_rot_ptr_int, int64_t k_rot_ptr_int,
    int64_t pos_t_ptr_int, int64_t pos_h_ptr_int, int64_t pos_w_ptr_int,
    int64_t cos_cache_int, int64_t sin_cache_int,
    int B, int T, int Hq, int Hkv, int D, int s0, int s1,
    int q_sb, int q_st, int q_sh,
    int k_sb, int k_st, int k_sh,
    int oq_sb, int oq_sh, int oq_st,
    int ok_sb, int ok_sh, int ok_st,
    int pos_sb, int pos_st,
    int cache_sp, int cache_sd,
    int64_t stream_int
) {
    if (B < 0 || T < 0 || Hq < 0 || Hkv < 0) {
        throw std::invalid_argument("launch_mrope_kernel: negative dimension");
    }
    // Only D_HALF = 32 and D_HALF = 64 are instantiated below; anything else
    // would index shared memory with the wrong extent.
    const int D_half = D / 2;
    if (D % 2 != 0 || (D_half != 32 && D_half != 64)) {
        throw std::invalid_argument(
            "launch_mrope_kernel: unsupported head dim D=" + std::to_string(D) +
            " (only D=64 and D=128 are supported)");
    }
    const int64_t bt64 = static_cast<int64_t>(B) * static_cast<int64_t>(T);
    if (bt64 > static_cast<int64_t>(INT_MAX)) {
        throw std::overflow_error("launch_mrope_kernel: B*T does not fit in int");
    }
    // Nothing to do; also avoids an invalid zero-sized grid.
    if (bt64 == 0 || Hq + Hkv == 0) {
        return;
    }
    const int BT = static_cast<int>(bt64);

    const __nv_bfloat16* q_ptr = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(q_ptr_int));
    const __nv_bfloat16* k_ptr = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(k_ptr_int));
    __nv_bfloat16* q_rot_ptr = reinterpret_cast<__nv_bfloat16*>(static_cast<intptr_t>(q_rot_ptr_int));
    __nv_bfloat16* k_rot_ptr = reinterpret_cast<__nv_bfloat16*>(static_cast<intptr_t>(k_rot_ptr_int));
    const int64_t* pos_t_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_t_ptr_int));
    const int64_t* pos_h_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_h_ptr_int));
    const int64_t* pos_w_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_w_ptr_int));
    const __nv_bfloat16* cos_cache = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(cos_cache_int));
    const __nv_bfloat16* sin_cache = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(sin_cache_int));
    cudaStream_t cs = reinterpret_cast<cudaStream_t>(static_cast<intptr_t>(stream_int));

    constexpr int kBlock = 256;

// One launch of the kernel specialised for (D_HALF_, BT_PER_BLOCK_).
#define MROPE_LAUNCH(D_HALF_, BT_PER_BLOCK_)                                  \
    mrope_fused_kernel<D_HALF_, kBlock, BT_PER_BLOCK_>                        \
        <<<(BT + (BT_PER_BLOCK_) - 1) / (BT_PER_BLOCK_), kBlock, 0, cs>>>(    \
            q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,                               \
            pos_t_ptr, pos_h_ptr, pos_w_ptr, cos_cache, sin_cache,            \
            B, T, Hq, Hkv, s0, s1,                                            \
            q_sb, q_st, q_sh, k_sb, k_st, k_sh,                               \
            oq_sb, oq_sh, oq_st, ok_sb, ok_sh, ok_st,                         \
            pos_sb, pos_st, cache_sp, cache_sd)

    // For small BT, batch several (b, t) per block to increase work per block.
    // For large BT, use one (b, t) per block for maximum parallelism.
    const bool small_problem = (BT <= 4096);
    if (D_half == 64) {
        if (small_problem) {
            MROPE_LAUNCH(64, 4);
        } else {
            MROPE_LAUNCH(64, 1);
        }
    } else {  // D_half == 32, guaranteed by the validation above
        if (small_problem) {
            MROPE_LAUNCH(32, 8);
        } else {
            MROPE_LAUNCH(32, 1);
        }
    }
#undef MROPE_LAUNCH

    // Kernel launches do not return a status; surface bad configurations here.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(
            std::string("launch_mrope_kernel: kernel launch failed: ") +
            cudaGetErrorString(err));
    }
}
"""
# C++ declaration of the host launcher that is defined in the CUDA source string.
# It is passed to load_inline as `cpp_sources` alongside
# functions=["launch_mrope_kernel"], so it must stay in sync with that definition.
# All pointer arguments and the stream are passed as raw int64 addresses
# (tensor.data_ptr() / stream.cuda_stream on the Python side); strides are in elements.
_CPP_SRC = """
void launch_mrope_kernel(
int64_t q_ptr_int, int64_t k_ptr_int,
int64_t q_rot_ptr_int, int64_t k_rot_ptr_int,
int64_t pos_t_ptr_int, int64_t pos_h_ptr_int, int64_t pos_w_ptr_int,
int64_t cos_cache_int, int64_t sin_cache_int,
int B, int T, int Hq, int Hkv, int D, int s0, int s1,
int q_sb, int q_st, int q_sh,
int k_sb, int k_st, int k_sh,
int oq_sb, int oq_sh, int oq_st,
int ok_sb, int ok_sh, int ok_st,
int pos_sb, int pos_st,
int cache_sp, int cache_sd,
int64_t stream_int
);
"""
# Cache for the compiled extension; filled on the first _get_mod() call.
_mod = None


def _get_mod():
    """Return the compiled extension module, building it once on first use."""
    global _mod
    if _mod is not None:
        return _mod
    build_kwargs = dict(
        name="mrope_cuda_v2",
        cpp_sources=_CPP_SRC,
        cuda_sources=_CUDA_SRC,
        functions=["launch_mrope_kernel"],
        verbose=False,
        # Code is generated for sm_120 only.
        extra_cuda_cflags=["-O3", "-gencode=arch=compute_120,code=sm_120"],
    )
    _mod = load_inline(**build_kwargs)
    return _mod
class Model(nn.Module):
    """Multi-axis RoPE (mRoPE) applied to q and k by a fused CUDA kernel.

    The module precomputes bf16 ``cos_cache`` / ``sin_cache`` buffers of shape
    ``(max_pos, D)`` and, in ``forward``, rotates ``q`` and ``k`` while writing
    them in ``(B, H, T, D)`` layout.
    """

    # Head dims with a compiled kernel instantiation (the launcher handles D/2 in {32, 64}).
    _SUPPORTED_HEAD_DIMS = (64, 128)
    # The kernel performs all element-offset arithmetic in 32-bit int.
    _INT32_MAX = 2**31 - 1

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        """Store the problem configuration and build the cos/sin caches.

        Args:
            B, T, Hq, Hkv, D: nominal batch, sequence, query-head, kv-head and
                head-dim sizes (recorded only; ``forward`` reads the actual
                sizes from its inputs).
            mrope_section: number of rotary pair indices per axis; must sum to ``D // 2``.
            max_pos: number of rows in the cos/sin caches.

        Raises:
            AssertionError: if ``sum(mrope_section) != D // 2``.
        """
        super().__init__()
        assert sum(mrope_section) == D // 2, \
            f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Duplicate the D/2 frequencies so both halves of the head dim share them.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

    @staticmethod
    def _max_offset(tensor: torch.Tensor) -> int:
        """Largest element offset addressable through ``tensor``'s shape and strides."""
        return sum((size - 1) * stride for size, stride in zip(tensor.shape, tensor.stride()))

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` and return them as ``(B, H, T, D)`` tensors.

        Args:
            q: ``(B, T, Hq, D)`` bf16 CUDA tensor.
            k: ``(B, T, Hkv, D)`` bf16 CUDA tensor on the same device.
            pos_t, pos_h, pos_w: ``(B, T)`` int64 position ids on the same device.
                Values must lie in ``[0, max_pos)``; this is not verified here
                because checking would force a device synchronisation.

        Returns:
            ``(q_rot, k_rot)`` with shapes ``(B, Hq, T, D)`` and ``(B, Hkv, T, D)``.

        Raises:
            ValueError: wrong device, shape, head dim, or tensors too large for
                32-bit offsets.
            TypeError: wrong dtype on inputs or caches.
        """
        # The kernel dereferences raw pointers, so everything must live on one GPU.
        device = q.device
        tensors = (q, k, pos_t, pos_h, pos_w, self.cos_cache, self.sin_cache)
        if device.type != "cuda" or any(t.device != device for t in tensors):
            raise ValueError(
                "q, k, pos_t, pos_h, pos_w and the cos/sin caches must all be on the same CUDA device"
            )
        if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16:
            raise TypeError(f"q and k must be bfloat16, got {q.dtype} and {k.dtype}")
        if self.cos_cache.dtype != torch.bfloat16 or self.sin_cache.dtype != torch.bfloat16:
            raise TypeError("cos_cache and sin_cache must be bfloat16")
        if any(p.dtype != torch.int64 for p in (pos_t, pos_h, pos_w)):
            raise TypeError("pos_t, pos_h and pos_w must be int64")

        if q.dim() != 4 or k.dim() != 4:
            raise ValueError("q and k must be 4-D tensors laid out as (B, T, H, D)")
        B, T, Hq, D = q.shape
        Hkv = k.shape[2]
        if k.shape[0] != B or k.shape[1] != T or k.shape[3] != D:
            raise ValueError(
                f"k shape {tuple(k.shape)} is incompatible with q shape {tuple(q.shape)}"
            )
        if any(p.shape != (B, T) for p in (pos_t, pos_h, pos_w)):
            raise ValueError(f"pos_t, pos_h and pos_w must have shape ({B}, {T})")
        if D != self.D:
            raise ValueError(f"head dim {D} does not match the configured D={self.D}")
        if D not in self._SUPPORTED_HEAD_DIMS:
            raise ValueError(
                f"head dim {D} is not supported by the CUDA kernel "
                f"(supported: {self._SUPPORTED_HEAD_DIMS})"
            )
        if (self.sin_cache.shape != self.cos_cache.shape
                or self.sin_cache.stride() != self.cos_cache.stride()):
            raise ValueError("cos_cache and sin_cache must share shape and strides")

        # The kernel assumes a unit stride along the head dim of q and k ...
        if q.stride(-1) != 1:
            q = q.contiguous()
        if k.stride(-1) != 1:
            k = k.contiguous()
        # ... and indexes all three position tensors with pos_t's strides.
        pos_t = pos_t.contiguous()
        pos_h = pos_h.contiguous()
        pos_w = pos_w.contiguous()

        s0, s1, s2 = self.mrope_section
        q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=device)
        k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=device)
        if B * T == 0:
            # Nothing to rotate; skip the launch (a zero-sized grid is invalid).
            return q_rot, k_rot

        for tensor in (q, k, q_rot, k_rot, self.cos_cache):
            if self._max_offset(tensor) > self._INT32_MAX:
                raise ValueError("tensor is too large for the kernel's 32-bit offset arithmetic")

        mod = _get_mod()
        # Launch on q's device and on that device's current stream.
        with torch.cuda.device(device):
            stream = torch.cuda.current_stream(device).cuda_stream
            mod.launch_mrope_kernel(
                q.data_ptr(), k.data_ptr(),
                q_rot.data_ptr(), k_rot.data_ptr(),
                pos_t.data_ptr(), pos_h.data_ptr(), pos_w.data_ptr(),
                self.cos_cache.data_ptr(), self.sin_cache.data_ptr(),
                B, T, Hq, Hkv, D, s0, s1,
                q.stride(0), q.stride(1), q.stride(2),
                k.stride(0), k.stride(1), k.stride(2),
                q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
                k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
                pos_t.stride(0), pos_t.stride(1),
                self.cos_cache.stride(0), self.cos_cache.stride(1),
                stream,
            )
        return q_rot, k_rot
# Default problem configuration read by get_inputs() and get_init_inputs().
B, T = 1, 4096
Hq, Hkv, D = 32, 8, 128
MROPE_SECTION = (16, 24, 24)
MAX_POS = 32768
def get_inputs():
    """Create one random input set ``[q, k, pos_t, pos_h, pos_w]``.

    q and k are bf16 normal samples scaled by 0.5 with shapes (B, T, Hq, D) and
    (B, T, Hkv, D); the three position tensors are int64 of shape (B, T) drawn
    from [0, MAX_POS). Sizes come from the module-level constants.
    """
    def _half_scale_normal(num_heads):
        return torch.randn(B, T, num_heads, D, dtype=torch.bfloat16) * 0.5

    q = _half_scale_normal(Hq)
    k = _half_scale_normal(Hkv)
    # Drawn in t, h, w order so the random stream is consumed in a fixed sequence.
    pos_t, pos_h, pos_w = (
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64) for _ in range(3)
    )
    return [q, k, pos_t, pos_h, pos_w]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` as a list."""
    init_args = [B, T, Hq, Hkv, D]
    init_args.append(MROPE_SECTION)
    init_args.append(MAX_POS)
    return init_args
shape=0 variant=eager tflops=0.199 gbps=286.045 ms=0.316
shape=0 variant=compiled tflops=0.526 gbps=754.512 ms=0.120
shape=0 variant=solution tflops=0.762 gbps=1092.822 ms=0.083
shape=0 solution_peak_fraction=0.6071
shape=1 variant=eager tflops=0.194 gbps=282.939 ms=0.520
shape=1 variant=compiled tflops=0.531 gbps=776.088 ms=0.189
shape=1 variant=solution tflops=0.840 gbps=1226.776 ms=0.120
shape=1 solution_peak_fraction=0.6815
shape=2 variant=eager tflops=0.099 gbps=155.099 ms=0.143
shape=2 variant=compiled tflops=0.231 gbps=360.752 ms=0.061
shape=2 variant=solution tflops=0.460 gbps=718.503 ms=0.031
shape=2 solution_peak_fraction=0.3992
shape=3 variant=eager tflops=0.176 gbps=251.998 ms=1.433
shape=3 variant=compiled tflops=0.486 gbps=696.831 ms=0.518
shape=3 variant=solution tflops=0.914 gbps=1311.463 ms=0.275
shape=3 solution_peak_fraction=0.7286
peak_fraction: 0.5890
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "zai/glm-5.1",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.5890,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 128655, "output_tokens": 20020, "cache_read_tokens": 1988736, "cache_creation_tokens": 0, "reasoning_tokens": 71925, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py to understand the problem",
"status": "in_progress",
"priority": "high"
},
{
"content": "Study mRoPE math and Qwen2-VL implementation details",
"status": "pending",
"priority": "high"
},
{
"content": "Implement fused CUDA/Triton kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"content": "Profile and o
... (210 more chars)content: 'Read reference.py, check.py, benchmark.py to understand the problem'
status: 'in_progress'
priority: 'high'
content: 'Study mRoPE math and Qwen2-VL implementation details'
status: 'pending'
priority: 'high'
content: 'Implement fused CUDA/Triton kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize for SM120 Blackwell'
status: 'pending'
priority: 'medium'
content: 'Benchmark with benchmark.py'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2:
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6:
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_transformers():
16: try:
17: from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18: apply_multimodal_rotary_pos_emb as _impl,
19: )
20: return _impl
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor, k: torch.Tensor,
27: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28: cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29: mrope_section: tuple[int, int, int],
30: ):
31: """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32: impl = _try_transformers()
33: if impl is None:
34: raise RuntimeError("transformers not available for SOTA")
35:
36: # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37: cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c
... (495 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/PROMPT.txt</path>
<type>file</type>
<content>
1: I need you to optimize the multi-axis RoPE pre-attention prep step for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
2:
3: The op is the pre-attention prep used in Qwen2-VL-style vision-language models: take q (B, T, Hq, D) and k (B, T, Hkv, D) bf16 plus three int64 position-id tensors pos_t, pos_h, pos_w of shape (B, T) — temporal, height, width axes — gather per-axis cos/sin from a precomputed (max_pos, D) bf16 cache, mix the per-axis values across the head_dim according to mrope_section (a tuple [s0, s1, s2] summing to D/2 — the first s0 pair-indices use the temporal axis, next s1 use height, last s2 use width, mirrored into the second half of D), then apply standard rotate-half rotation (out = x*cos + rotate_half(x)*sin) on q and k. Return q_rot and k_rot in (B, H, T, D) layout — the transpose from (B, T, H, D) to (B, H, T, D) is part of the prep. Correctness tolerance on the bf16 outputs is 1e-2 abs/rel. Both q_rot and k_rot must match the reference within that.
4:
5: The shapes you have to handle are B=1 T
... (1617 more chars){
"pattern": "**/solution.py",
"path": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}No files found{
"command": "uv run python -c \"import triton; print(triton.__version__)\"",
"description": "Check Triton version"
}3.6.0
{
"command": "nvcc --version 2>/dev/null | tail -1",
"description": "Check CUDA version"
}Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "nvidia-smi --query-gpu=name,memory.total,compute_cap --format=csv,noheader 2>/dev/null | head -3",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 97887 MiB, 12.0
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py to understand the problem",
"priority": "high",
"status": "completed"
},
{
"content": "Implement fused Triton kernel in solution.py",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"content": "Profile and optimize for SM120 Blackwell",
"priority": "medium",
"status": "pending"
},
{
"content": "Benchmark with benchmark.p
... (69 more chars)content: 'Read reference.py, check.py, benchmark.py to understand the problem'
status: 'completed'
priority: 'high'
content: 'Implement fused Triton kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile and optimize for SM120 Blackwell'
status: 'pending'
priority: 'medium'
content: 'Benchmark with benchmark.py'
status: 'pending'
priority: 'medium'--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def mrope_fused_kernel(
+ x_ptr, out_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ cos_cache_ptr, sin_cache_ptr,
+ x_sb, x_st, x_sh, x_sd,
+ o_sb, o_sh, o_st, o_sd,
+ p_sb, p_st,
+ c_sp, c_sd,
+ B, H, T, D_half,
+ s0, s1,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
+
+ total = B * H * T * D_half
+ mask = offs < total
+
+ p = offs % D_half
+ rest = offs // D_half
+ t_idx = rest % T
+ rest2 = rest // T
+ h_idx = rest2 % H
+ b_idx = rest2 // H
+
+ use_t = p < s0
+ use_h = (p >= s0) & (p < s0 + s1)
+
+ pos_off = b_idx * p_sb + t_idx * p_st
+ pos_t_val = tl.load(pos_t_ptr + pos_off, mask=mask).to(tl.int32)
+ pos_h_val = tl.load(pos_h_ptr + pos_off, mask=mask).to(tl.int32)
+ pos_w_val = tl.load(pos_w_ptr + pos_off, mask=mask).to(tl.int32)
+
+ pos = tl.where(use_t, pos_t_val, tl.where(use_h, pos_h_val, pos_w_val))
+
+ cache_off = pos * c_sp + p * c_sd
+ cos_val = tl.load(cos_cache_ptr + cache_off, mask=mask).to(tl.float32)
+ sin_val = tl.load(sin_cache_ptr + cache_off, mask=mask).to(tl.float32)
+
+ x_base = b_idx * x_sb + t_idx * x_st + h_idx * x_sh
+ x_0 = tl.load(x_ptr + x_base + p * x_sd, mask=mask).to(tl.float32)
+ x_1 = tl.load(x_ptr + x_base + (p + D_half) * x_sd, mask=mask).to(tl.float32)
+
+ out_0 = x_0 * cos_val - x_1 * sin_val
+ out_1 = x_1 * cos_val + x_0 * sin_val
+
+ out_base = b_idx * o_sb + h_idx * o_sh + t_idx * o_st
+ tl.store(out_ptr + out_base + p * o_sd, out_0.to(tl.bfloat16), mask=mask)
+ tl.store(out_ptr + out_base + (p + D_half) * o_sd, out_1.to(tl.bfloat16), mask=mask)
+
+
+def mrope_fused(x, pos_t, pos_h, pos_w, cos_cache, sin_cache, mrope_section):
+ B, T, H, D = x.shape
+ D_half = D // 2
+ s0, s1, s2 = mrope_section
+
+ out = torch.empty(B, H, T, D, dtype=x.dtype, device=x.device)
+
+ pos_t_i32 = pos_t.to(torch.int32)
+ pos_h_i32 = pos_h.to(torch.int32)
+ pos_w_i32 = pos_w.to(torch.int32)
+
+ total = B * H * T * D_half
+ BLOCK = 512
+ grid = (triton.cdiv(total, BLOCK),)
+
+ mrope_fused_kernel[grid](
+ x, out,
+ pos_t_i32, pos_h_i32, pos_w_i32,
+ cos_cache, sin_cache,
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3),
+ pos_t_i32.stride(0), pos_t_i32.stride(1),
+ cos_cache.stride(0), cos_cache.stride(1),
+ B, H, T, D_half,
+ s0, s1,
+ BLOCK=BLOCK,
+ num_warps=8,
+ )
+
+ return out
+
+
+class Model(nn.Module):
+ def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
+ mrope_section: tuple[int, int, int], max_pos: int):
+ super().__init__()
+ assert sum(mrope_section) == D // 2, \
+ f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
+ self.mrope_section = tuple(mrope_section)
+ self.max_pos = max_pos
+
+ inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+ pos = torch.arange(max_pos, dtype=torch.float32)
+ freqs = torch.outer(pos, inv_freq)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+ self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+ def forward(
+ self,
+ q: torch.Tensor, k: torch.Tensor,
+ pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ q_rot = mrope_fused(q, pos_t, pos_h, pos_w, self.cos_cache, self.sin_cache, self.mrope_section)
+ k_rot = mrope_fused(k, pos_t, pos_h, pos_w, self.cos_cache, self.sin_cache, self.mrope_section)
+ return q_rot, k_rot
+
+
+B = 1
+T = 4096
+Hq = 32
+Hkv = 8
+D = 128
+MROPE_SECTION = (16, 24, 24)
+MAX_POS = 32768
+
+
+def get_inputs():
+ q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16) * 0.5
+ k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16) * 0.5
+ pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ return [q, k, pos_t, pos_h, pos_w]
+
+
+def get_init_inputs():
+ return [B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4435 chars \u2014 see diff>"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Run correctness check",
"timeout": 180000
}PASS
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py to understand the problem",
"priority": "high",
"status": "completed"
},
{
"content": "Implement fused Triton kernel in solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "completed"
},
{
"content": "Benchmark and optimize for SM120 Blackwell",
"priority": "high",
"status": "in_progress"
},
{
"content": "Benchmark with benchma
... (73 more chars)content: 'Read reference.py, check.py, benchmark.py to understand the problem'
status: 'completed'
priority: 'high'
content: 'Implement fused Triton kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Benchmark and optimize for SM120 Blackwell'
status: 'in_progress'
priority: 'high'
content: 'Benchmark with benchmark.py'
status: 'pending'
priority: 'medium'{
"command": "uv run python benchmark.py 2>&1",
"description": "Run benchmark",
"timeout": 300000
}shape=0 variant=eager tflops=0.200 gbps=286.801 ms=0.315
shape=0 variant=compiled tflops=0.568 gbps=814.882 ms=0.111
shape=0 variant=solution tflops=0.689 gbps=989.172 ms=0.091
shape=0 solution_peak_fraction=0.5495
shape=1 variant=eager tflops=0.190 gbps=276.760 ms=0.531
shape=1 variant=compiled tflops=0.529 gbps=773.214 ms=0.190
shape=1 variant=solution tflops=0.789 gbps=1152.594 ms=0.128
shape=1 solution_peak_fraction=0.6403
shape=2 variant=eager tflops=0.096 gbps=149.886 ms=0.148
shape=2 variant=compiled tflops=0.223 gbps=348.036 ms=0.064
shape=2 variant=solution tflops=0.260 gbps=406.349 ms=0.054
shape=2 solution_peak_fraction=0.2257
shape=3 variant=eager tflops=0.176 gbps=252.806 ms=1.428
shape=3 variant=compiled tflops=0.483 gbps=693.683 ms=0.521
shape=3 variant=solution tflops=0.963 gbps=1382.140 ms=0.261
shape=3 solution_peak_fraction=0.7679
peak_fraction: 0.4970
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: import triton
4: import triton.language as tl
5:
6:
7: @triton.jit
8: def mrope_fused_kernel(
9: x_ptr, out_ptr,
10: pos_t_ptr, pos_h_ptr, pos_w_ptr,
11: cos_cache_ptr, sin_cache_ptr,
12: x_sb, x_st, x_sh, x_sd,
13: o_sb, o_sh, o_st, o_sd,
14: p_sb, p_st,
15: c_sp, c_sd,
16: B, H, T, D_half,
17: s0, s1,
18: BLOCK: tl.constexpr,
19: ):
20: pid = tl.program_id(0)
21: offs = pid * BLOCK + tl.arange(0, BLOCK)
22:
23: total = B * H * T * D_half
24: mask = offs < total
25:
26: p = offs % D_half
27: rest = offs // D_half
28: t_idx = rest % T
29: rest2 = rest // T
30: h_idx = rest2 % H
31: b_idx = rest2 // H
32:
33: use_t = p < s0
34: use_h = (p >= s0) & (p < s0 + s1)
35:
36: pos_off = b_idx * p_sb + t_idx * p_st
37: pos_t_val = tl.load(pos_t_ptr + pos_off, mask=mask).to(tl.int32)
38: pos_h_val = tl.load(pos_h_ptr + pos_off, mask=mask).to(tl.int32)
39: pos_w_val = tl.load(pos_w_ptr + pos_off, mask=mask).to(tl.int32)
40:
41: pos = tl.where(use_t, pos_t_val, tl.where(use_h, pos_h_val, pos_w_val))
42:
43: cache_off = pos * c_sp + p * c_sd
44: cos_val = tl.load(cos_cache_ptr + cache_off, mask=mask).to(tl.float32)
45: sin_val = tl.load(sin_cache_ptr + cache_off, mask=mask).to(tl.float32)
46:
47:
... (3682 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -5,30 +5,37 @@
@triton.jit
-def mrope_fused_kernel(
- x_ptr, out_ptr,
+def mrope_fused_qk_kernel(
+ q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
pos_t_ptr, pos_h_ptr, pos_w_ptr,
cos_cache_ptr, sin_cache_ptr,
- x_sb, x_st, x_sh, x_sd,
- o_sb, o_sh, o_st, o_sd,
+ q_sb, q_st, q_sh, q_sd,
+ k_sb, k_st, k_sh, k_sd,
+ oq_sb, oq_sh, oq_st, oq_sd,
+ ok_sb, ok_sh, ok_st, ok_sd,
p_sb, p_st,
c_sp, c_sd,
- B, H, T, D_half,
+ B, Hq, Hkv, T, D_half,
s0, s1,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
- total = B * H * T * D_half
+ H_total = Hq + Hkv
+ total = B * H_total * T * D_half
mask = offs < total
p = offs % D_half
rest = offs // D_half
t_idx = rest % T
rest2 = rest // T
- h_idx = rest2 % H
- b_idx = rest2 // H
+ h_idx = rest2 % H_total
+ b_idx = rest2 // H_total
+
+ is_q = h_idx < Hq
+ h_q = h_idx
+ h_k = h_idx - Hq
use_t = p < s0
use_h = (p >= s0) & (p < s0 + s1)
@@ -44,48 +51,30 @@
cos_val = tl.load(cos_cache_ptr + cache_off, mask=mask).to(tl.float32)
sin_val = tl.load(sin_cache_ptr + cache_off, mask=mask).to(tl.float32)
- x_base = b_idx * x_sb + t_idx * x_st + h_idx * x_sh
- x_0 = tl.load(x_ptr + x_base + p * x_sd, mask=mask).to(tl.float32)
- x_1 = tl.load(x_ptr + x_base + (p + D_half) * x_sd, mask=mask).to(tl.float32)
+ x_base_q = b_idx * q_sb + t_idx * q_st + h_q * q_sh
+ x_base_k = b_idx * k_sb + t_idx * k_st + h_k * k_sh
+
+ x_0_q = tl.load(q_ptr + x_base_q + p * q_sd, mask=mask & is_q).to(tl.float32)
+ x_1_q = tl.load(q_ptr + x_base_q + (p + D_half) * q_sd, mask=mask & is_q).to(tl.float32)
+ x_0_k = tl.load(k_ptr + x_base_k + p * k_sd, mask=mask & (~is_q)).to(tl.float32)
+ x_1_k = tl.load(k_ptr + x_base_k + (p + D_half) * k_sd, mask=mask & (~is_q)).to(tl.float32)
+
+ x_0 = tl.where(is_q, x_0_q, x_0_k)
+ x_1 = tl.where(is_q, x_1_q, x_1_k)
out_0 = x_0 * cos_val - x_1 * sin_val
out_1 = x_1 * cos_val + x_0 * sin_val
- out_base = b_idx * o_sb + h_idx * o_sh + t_idx * o_st
- tl.store(out_ptr + out_base + p * o_sd, out_0.to(tl.bfloat16), mask=mask)
- tl.store(out_ptr + out_base + (p + D_half) * o_sd, out_1.to(tl.bfloat16), mask=mask)
+ out_base_q = b_idx * oq_sb + h_q * oq_sh + t_idx * oq_st
+ out_base_k = b_idx * ok_sb + h_k * ok_sh + t_idx * ok_st
+ q_mask = mask & is_q
+ k_mask = mask & (~is_q)
-def mrope_fused(x, pos_t, pos_h, pos_w, cos_cache, sin_cache, mrope_section):
- B, T, H, D = x.shape
- D_half = D // 2
- s0, s1, s2 = mrope_section
-
- out = torch.empty(B, H, T, D, dtype=x.dtype, device=x.device)
-
- pos_t_i32 = pos_t.to(torch.int32)
- pos_h_i32 = pos_h.to(torch.int32)
- pos_w_i32 = pos_w.to(torch.int32)
-
- total = B * H * T * D_half
- BLOCK = 512
- grid = (triton.cdiv(total, BLOCK),)
-
- mrope_fused_kernel[grid](
- x, out,
- pos_t_i32, pos_h_i32, pos_w_i32,
- cos_cache, sin_cache,
- x.stride(0), x.stride(1), x.stride(2), x.stride(3),
- out.stride(0), out.stride(1), out.stride(2), out.stride(3),
- pos_t_i32.stride(0), pos_t_i32.stride(1),
- cos_cache.stride(0), cos_cache.stride(1),
- B, H, T, D_half,
- s0, s1,
- BLOCK=BLOCK,
- num_warps=8,
- )
-
- return out
+ tl.store(q_rot_ptr + out_base_q + p * oq_sd, out_0.to(tl.bfloat16), mask=q_mask)
+ tl.store(q_rot_ptr + out_base_q + (p + D_half) * oq_sd, out_1.to(tl.bfloat16), mask=q_mask)
+ tl.store(k_rot_ptr + out_base_k + p * ok_sd, out_0.to(tl.bfloat16), mask=k_mask)
+ tl.store(k_rot_ptr + out_base_k + (p + D_half) * ok_sd, out_1.to(tl.bfloat16), mask=k_mask)
class Model(nn.Module):
@@ -111,8 +100,35 @@
q: torch.Tensor, k: torch.Tensor,
pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
- q_rot = mrope_fused(q, pos_t, pos_h, pos_w, self.cos_cache, self.sin_cache, self.mrope_section)
- k_rot = mrope_fused(k, pos_t, pos_h, pos_w, self.cos_cache, self.sin_cache, self.mrope_section)
+ B, T, Hq, D = q.shape
+ Hkv = k.shape[2]
+ D_half = D // 2
+ s0, s1, s2 = self.mrope_section
+
+ q_rot = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
+ k_rot = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
+
+ H_total = Hq + Hkv
+ total = B * H_total * T * D_half
+ BLOCK = 1024
+ grid = (triton.cdiv(total, BLOCK),)
+
+ mrope_fused_qk_kernel[grid](
+ q, k, q_rot, k_rot,
+ pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+ q_rot.stride(0), q_rot.stride(1), q_rot.stride(2), q_rot.stride(3),
+ k_rot.stride(0), k_rot.stride(1), k_rot.stride(2), k_rot.stride(3),
+ pos_t.stride(0), pos_t.stride(1),
+ self.cos_cache.stride(0), self.cos_cache.stride(1),
+ B, Hq, Hkv, T, D_half,
+ s0, s1,
+ BLOCK=BLOCK,
+ num_warps=8,
+ )
+
return q_rot, k_rot
{
"content": "<5259 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check correctness of fused q+k kernel",
"timeout": 180000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark fused q+k kernel",
"timeout": 300000
}shape=0 variant=eager tflops=0.199 gbps=285.929 ms=0.316
shape=0 variant=compiled tflops=0.561 gbps=804.540 ms=0.112
shape=0 variant=solution tflops=0.785 gbps=1126.645 ms=0.080
shape=0 solution_peak_fraction=0.6259
shape=1 variant=eager tflops=0.191 gbps=278.623 ms=0.528
shape=1 variant=compiled tflops=0.532 gbps=776.481 ms=0.189
shape=1 variant=solution tflops=0.853 gbps=1244.895 ms=0.118
shape=1 solution_peak_fraction=0.6916
shape=2 variant=eager tflops=0.096 gbps=150.687 ms=0.147
shape=2 variant=compiled tflops=0.227 gbps=355.373 ms=0.062
shape=2 variant=solution tflops=0.430 gbps=672.374 ms=0.033
shape=2 solution_peak_fraction=0.3735
shape=3 variant=eager tflops=0.177 gbps=254.043 ms=1.421
shape=3 variant=compiled tflops=0.486 gbps=697.563 ms=0.518
shape=3 variant=solution tflops=0.985 gbps=1412.856 ms=0.256
shape=3 solution_peak_fraction=0.7849
peak_fraction: 0.5969
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,80 +1,167 @@
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def mrope_fused_qk_kernel(
- q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- cos_cache_ptr, sin_cache_ptr,
- q_sb, q_st, q_sh, q_sd,
- k_sb, k_st, k_sh, k_sd,
- oq_sb, oq_sh, oq_st, oq_sd,
- ok_sb, ok_sh, ok_st, ok_sd,
- p_sb, p_st,
- c_sp, c_sd,
- B, Hq, Hkv, T, D_half,
- s0, s1,
- BLOCK: tl.constexpr,
-):
- pid = tl.program_id(0)
- offs = pid * BLOCK + tl.arange(0, BLOCK)
-
- H_total = Hq + Hkv
- total = B * H_total * T * D_half
- mask = offs < total
-
- p = offs % D_half
- rest = offs // D_half
- t_idx = rest % T
- rest2 = rest // T
- h_idx = rest2 % H_total
- b_idx = rest2 // H_total
-
- is_q = h_idx < Hq
- h_q = h_idx
- h_k = h_idx - Hq
-
- use_t = p < s0
- use_h = (p >= s0) & (p < s0 + s1)
-
- pos_off = b_idx * p_sb + t_idx * p_st
- pos_t_val = tl.load(pos_t_ptr + pos_off, mask=mask).to(tl.int32)
- pos_h_val = tl.load(pos_h_ptr + pos_off, mask=mask).to(tl.int32)
- pos_w_val = tl.load(pos_w_ptr + pos_off, mask=mask).to(tl.int32)
-
- pos = tl.where(use_t, pos_t_val, tl.where(use_h, pos_h_val, pos_w_val))
-
- cache_off = pos * c_sp + p * c_sd
- cos_val = tl.load(cos_cache_ptr + cache_off, mask=mask).to(tl.float32)
- sin_val = tl.load(sin_cache_ptr + cache_off, mask=mask).to(tl.float32)
-
- x_base_q = b_idx * q_sb + t_idx * q_st + h_q * q_sh
- x_base_k = b_idx * k_sb + t_idx * k_st + h_k * k_sh
-
- x_0_q = tl.load(q_ptr + x_base_q + p * q_sd, mask=mask & is_q).to(tl.float32)
- x_1_q = tl.load(q_ptr + x_base_q + (p + D_half) * q_sd, mask=mask & is_q).to(tl.float32)
- x_0_k = tl.load(k_ptr + x_base_k + p * k_sd, mask=mask & (~is_q)).to(tl.float32)
- x_1_k = tl.load(k_ptr + x_base_k + (p + D_half) * k_sd, mask=mask & (~is_q)).to(tl.float32)
-
- x_0 = tl.where(is_q, x_0_q, x_0_k)
- x_1 = tl.where(is_q, x_1_q, x_1_k)
-
- out_0 = x_0 * cos_val - x_1 * sin_val
- out_1 = x_1 * cos_val + x_0 * sin_val
-
- out_base_q = b_idx * oq_sb + h_q * oq_sh + t_idx * oq_st
- out_base_k = b_idx * ok_sb + h_k * ok_sh + t_idx * ok_st
-
- q_mask = mask & is_q
- k_mask = mask & (~is_q)
-
- tl.store(q_rot_ptr + out_base_q + p * oq_sd, out_0.to(tl.bfloat16), mask=q_mask)
- tl.store(q_rot_ptr + out_base_q + (p + D_half) * oq_sd, out_1.to(tl.bfloat16), mask=q_mask)
- tl.store(k_rot_ptr + out_base_k + p * ok_sd, out_0.to(tl.bfloat16), mask=k_mask)
- tl.store(k_rot_ptr + out_base_k + (p + D_half) * ok_sd, out_1.to(tl.bfloat16), mask=k_mask)
+from torch.utils.cpp_extension import load_inline
+
+_CUDA_SRC = r"""
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+template<int D_HALF, int BLOCK_SIZE>
+__global__ void mrope_fused_kernel(
+ const __nv_bfloat16* __restrict__ q_ptr,
+ const __nv_bfloat16* __restrict__ k_ptr,
+ __nv_bfloat16* __restrict__ q_rot_ptr,
+ __nv_bfloat16* __restrict__ k_rot_ptr,
+ const int64_t* __restrict__ pos_t_ptr,
+ const int64_t* __restrict__ pos_h_ptr,
+ const int64_t* __restrict__ pos_w_ptr,
+ const __nv_bfloat16* __restrict__ cos_cache,
+ const __nv_bfloat16* __restrict__ sin_cache,
+ const int B, const int T, const int Hq, const int Hkv,
+ const int s0, const int s1,
+ const int q_sb, const int q_st, const int q_sh,
+ const int k_sb, const int k_st, const int k_sh,
+ const int oq_sb, const int oq_sh, const int oq_st,
+ const int ok_sb, const int ok_sh, const int ok_st,
+ const int pos_sb, const int pos_st,
+ const int cache_sp, const int cache_sd
+) {
+ __shared__ __nv_bfloat16 smem_cos[D_HALF];
+ __shared__ __nv_bfloat16 smem_sin[D_HALF];
+ __shared__ int pos_t_val, pos_h_val, pos_w_val;
+
+ const int bt_idx = blockIdx.x;
+ const int t_idx = bt_idx % T;
+ const int b_idx = bt_idx / T;
+ const int tid = threadIdx.x;
+ const int H_total = Hq + Hkv;
+
+ if (tid == 0) pos_t_val = static_cast<int>(pos_t_ptr[b_idx * pos_sb + t_idx * pos_st]);
+ if (tid == 1) pos_h_val = static_cast<int>(pos_h_ptr[b_idx * pos_sb + t_idx * pos_st]);
+ if (tid == 2) pos_w_val = static_cast<int>(pos_w_ptr[b_idx * pos_sb + t_idx * pos_st]);
+ __syncthreads();
+
+ for (int p = tid; p < D_HALF; p += BLOCK_SIZE) {
+ int pos;
+ if (p < s0) pos = pos_t_val;
+ else if (p < s0 + s1) pos = pos_h_val;
+ else pos = pos_w_val;
+ smem_cos[p] = cos_cache[pos * cache_sp + p * cache_sd];
+ smem_sin[p] = sin_cache[pos * cache_sp + p * cache_sd];
+ }
+ __syncthreads();
+
+ const int total_pairs = H_total * D_HALF;
+ for (int idx = tid; idx < total_pairs; idx += BLOCK_SIZE) {
+ const int p = idx % D_HALF;
+ const int h = idx / D_HALF;
+
+ const float cv = __bfloat162float(smem_cos[p]);
+ const float sv = __bfloat162float(smem_sin[p]);
+
+ if (h < Hq) {
+ const int xb = b_idx * q_sb + t_idx * q_st + h * q_sh;
+ const float x0 = __bfloat162float(q_ptr[xb + p]);
+ const float x1 = __bfloat162float(q_ptr[xb + p + D_HALF]);
+ const float o0 = x0 * cv - x1 * sv;
+ const float o1 = x1 * cv + x0 * sv;
+ const int ob = b_idx * oq_sb + h * oq_sh + t_idx * oq_st;
+ q_rot_ptr[ob + p] = __float2bfloat16(o0);
+ q_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
+ } else {
+ const int hk = h - Hq;
+ const int xb = b_idx * k_sb + t_idx * k_st + hk * k_sh;
+ const float x0 = __bfloat162float(k_ptr[xb + p]);
+ const float x1 = __bfloat162float(k_ptr[xb + p + D_HALF]);
+ const float o0 = x0 * cv - x1 * sv;
+ const float o1 =
... (5095 more chars){
"content": "<8904 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check CUDA C++ kernel correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 197, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/Ker
... (4714 more chars){
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_cuda/",
"description": "Clean cached extension build"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -9,15 +9,15 @@
template<int D_HALF, int BLOCK_SIZE>
__global__ void mrope_fused_kernel(
- const __nv_bfloat16* __restrict__ q_ptr,
- const __nv_bfloat16* __restrict__ k_ptr,
- __nv_bfloat16* __restrict__ q_rot_ptr,
- __nv_bfloat16* __restrict__ k_rot_ptr,
+ const void* __restrict__ q_ptr_v,
+ const void* __restrict__ k_ptr_v,
+ void* __restrict__ q_rot_ptr_v,
+ void* __restrict__ k_rot_ptr_v,
const int64_t* __restrict__ pos_t_ptr,
const int64_t* __restrict__ pos_h_ptr,
const int64_t* __restrict__ pos_w_ptr,
- const __nv_bfloat16* __restrict__ cos_cache,
- const __nv_bfloat16* __restrict__ sin_cache,
+ const void* __restrict__ cos_cache_v,
+ const void* __restrict__ sin_cache_v,
const int B, const int T, const int Hq, const int Hkv,
const int s0, const int s1,
const int q_sb, const int q_st, const int q_sh,
@@ -27,6 +27,13 @@
const int pos_sb, const int pos_st,
const int cache_sp, const int cache_sd
) {
+ const __nv_bfloat16* q_ptr = static_cast<const __nv_bfloat16*>(q_ptr_v);
+ const __nv_bfloat16* k_ptr = static_cast<const __nv_bfloat16*>(k_ptr_v);
+ __nv_bfloat16* q_rot_ptr = static_cast<__nv_bfloat16*>(q_rot_ptr_v);
+ __nv_bfloat16* k_rot_ptr = static_cast<__nv_bfloat16*>(k_rot_ptr_v);
+ const __nv_bfloat16* cos_cache = static_cast<const __nv_bfloat16*>(cos_cache_v);
+ const __nv_bfloat16* sin_cache = static_cast<const __nv_bfloat16*>(sin_cache_v);
+
__shared__ __nv_bfloat16 smem_cos[D_HALF];
__shared__ __nv_bfloat16 smem_sin[D_HALF];
__shared__ int pos_t_val, pos_h_val, pos_w_val;
@@ -83,12 +90,11 @@
}
}
-// Dispatch for D_HALF=32 and D_HALF=64
void launch_mrope_kernel(
- const __nv_bfloat16* q_ptr, const __nv_bfloat16* k_ptr,
- __nv_bfloat16* q_rot_ptr, __nv_bfloat16* k_rot_ptr,
+ const void* q_ptr, const void* k_ptr,
+ void* q_rot_ptr, void* k_rot_ptr,
const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
+ const void* cos_cache, const void* sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1,
int q_sb, int q_st, int q_sh,
int k_sb, int k_st, int k_sh,
@@ -96,21 +102,14 @@
int ok_sb, int ok_sh, int ok_st,
int pos_sb, int pos_st,
int cache_sp, int cache_sd,
- cudaStream_t stream
+ int64_t stream
) {
const int D_half = D / 2;
const int BT = B * T;
const int BLOCK = 256;
+ cudaStream_t cs = reinterpret_cast<cudaStream_t>(stream);
if (D_half == 64) {
- mrope_fused_kernel<64, BLOCK><<<BT, BLOCK, 0, stream>>>(
- q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
- pos_t, pos_h, pos_w, cos_cache, sin_cache,
- B, T, Hq, Hkv, s0, s1,
- q_sb, q_st, q_sh, k_sb, k_st, k_sh,
- oq_sb, oq_sh, oq_st, ok_sb, ok_sh, ok_st,
- pos_sb, pos_st, cache_sp, cache_sd);
- } else if (D_half == 32) {
- mrope_fused_kernel<32, BLOCK><<<BT, BLOCK, 0, stream>>>(
+ mrope_fused_kernel<64, BLOCK><<<BT, BLOCK, 0, cs>>>(
q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
pos_t, pos_h, pos_w, cos_cache, sin_cache,
B, T, Hq, Hkv, s0, s1,
@@ -118,9 +117,7 @@
oq_sb, oq_sh, oq_st, ok_sb, ok_sh, ok_st,
pos_sb, pos_st, cache_sp, cache_sd);
} else {
- // Fallback for other D_half values (shouldn't happen for our shapes)
- // Use D_half=64 as default - will produce wrong results for other values
- mrope_fused_kernel<64, BLOCK><<<BT, BLOCK, 0, stream>>>(
+ mrope_fused_kernel<32, BLOCK><<<BT, BLOCK, 0, cs>>>(
q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
pos_t, pos_h, pos_w, cos_cache, sin_cache,
B, T, Hq, Hkv, s0, s1,
@@ -131,23 +128,6 @@
}
"""
-_CPP_SRC = """
-void launch_mrope_kernel(
- const __nv_bfloat16* q_ptr, const __nv_bfloat16* k_ptr,
- __nv_bfloat16* q_rot_ptr, __nv_bfloat16* k_rot_ptr,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const __nv_bfloat16* cos_cache, const __nv_bfloat16* sin_cache,
- int B, int T, int Hq, int Hkv, int D, int s0, int s1,
- int q_sb, int q_st, int q_sh,
- int k_sb, int k_st, int k_sh,
- int oq_sb, int oq_sh, int oq_st,
- int ok_sb, int ok_sh, int ok_st,
- int pos_sb, int pos_st,
- int cache_sp, int cache_sd,
- cudaStream_t stream
-);
-"""
-
_mod = None
def _get_mod():
@@ -155,11 +135,11 @@
if _mod is None:
_mod = load_inline(
name="mrope_cuda",
- cpp_sources=_CPP_SRC,
+ cpp_sources="",
cuda_sources=_CUDA_SRC,
functions=["launch_mrope_kernel"],
verbose=False,
- extra_cuda_cflags=["-O3", "-arch=sm_120"],
+ extra_cuda_cflags=["-O3", "-gencode=arch=compute_120,code=sm_120"],
)
return _mod
{
"content": "<8175 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check CUDA C++ kernel correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 177, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/Ker
... (3329 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -102,12 +102,12 @@
int ok_sb, int ok_sh, int ok_st,
int pos_sb, int pos_st,
int cache_sp, int cache_sd,
- int64_t stream
+ int64_t stream_int
) {
const int D_half = D / 2;
const int BT = B * T;
const int BLOCK = 256;
- cudaStream_t cs = reinterpret_cast<cudaStream_t>(stream);
+ cudaStream_t cs = reinterpret_cast<cudaStream_t>(static_cast<intptr_t>(stream_int));
if (D_half == 64) {
mrope_fused_kernel<64, BLOCK><<<BT, BLOCK, 0, cs>>>(
q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
@@ -128,6 +128,23 @@
}
"""
+_CPP_SRC = """
+void launch_mrope_kernel(
+ const void* q_ptr, const void* k_ptr,
+ void* q_rot_ptr, void* k_rot_ptr,
+ const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
+ const void* cos_cache, const void* sin_cache,
+ int B, int T, int Hq, int Hkv, int D, int s0, int s1,
+ int q_sb, int q_st, int q_sh,
+ int k_sb, int k_st, int k_sh,
+ int oq_sb, int oq_sh, int oq_st,
+ int ok_sb, int ok_sh, int ok_st,
+ int pos_sb, int pos_st,
+ int cache_sp, int cache_sd,
+ int64_t stream_int
+);
+"""
+
_mod = None
def _get_mod():
@@ -135,7 +152,7 @@
if _mod is None:
_mod = load_inline(
name="mrope_cuda",
- cpp_sources="",
+ cpp_sources=_CPP_SRC,
cuda_sources=_CUDA_SRC,
functions=["launch_mrope_kernel"],
verbose=False,
{
"content": "<8744 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check CUDA C++ kernel correctness",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 197, in forward
mod.launch_mrope_kernel(
TypeError: launch_mrope_kernel(): incompatible function arguments. The following argument types are supported:
1. (arg0: typing_extensions.CapsuleType, arg1: typing_extensions.CapsuleType, arg2: typing_extensions.CapsuleType, arg3: typing_extensions.CapsuleType, arg4: typing.SupportsInt, arg5: typing.SupportsInt, arg6: typing.SupportsInt, arg7: typing_extensions.CapsuleType, arg8: typing_extensions.CapsuleType, arg9: typing.SupportsInt, arg10: typing.SupportsInt, arg11: typing.SupportsInt, arg12: typing.SupportsInt, arg13: typing.SupportsInt, arg14: typing.SupportsInt, arg1
... (792 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -9,15 +9,15 @@
template<int D_HALF, int BLOCK_SIZE>
__global__ void mrope_fused_kernel(
- const void* __restrict__ q_ptr_v,
- const void* __restrict__ k_ptr_v,
- void* __restrict__ q_rot_ptr_v,
- void* __restrict__ k_rot_ptr_v,
+ const __nv_bfloat16* __restrict__ q_ptr,
+ const __nv_bfloat16* __restrict__ k_ptr,
+ __nv_bfloat16* __restrict__ q_rot_ptr,
+ __nv_bfloat16* __restrict__ k_rot_ptr,
const int64_t* __restrict__ pos_t_ptr,
const int64_t* __restrict__ pos_h_ptr,
const int64_t* __restrict__ pos_w_ptr,
- const void* __restrict__ cos_cache_v,
- const void* __restrict__ sin_cache_v,
+ const __nv_bfloat16* __restrict__ cos_cache,
+ const __nv_bfloat16* __restrict__ sin_cache,
const int B, const int T, const int Hq, const int Hkv,
const int s0, const int s1,
const int q_sb, const int q_st, const int q_sh,
@@ -27,13 +27,6 @@
const int pos_sb, const int pos_st,
const int cache_sp, const int cache_sd
) {
- const __nv_bfloat16* q_ptr = static_cast<const __nv_bfloat16*>(q_ptr_v);
- const __nv_bfloat16* k_ptr = static_cast<const __nv_bfloat16*>(k_ptr_v);
- __nv_bfloat16* q_rot_ptr = static_cast<__nv_bfloat16*>(q_rot_ptr_v);
- __nv_bfloat16* k_rot_ptr = static_cast<__nv_bfloat16*>(k_rot_ptr_v);
- const __nv_bfloat16* cos_cache = static_cast<const __nv_bfloat16*>(cos_cache_v);
- const __nv_bfloat16* sin_cache = static_cast<const __nv_bfloat16*>(sin_cache_v);
-
__shared__ __nv_bfloat16 smem_cos[D_HALF];
__shared__ __nv_bfloat16 smem_sin[D_HALF];
__shared__ int pos_t_val, pos_h_val, pos_w_val;
@@ -91,10 +84,10 @@
}
void launch_mrope_kernel(
- const void* q_ptr, const void* k_ptr,
- void* q_rot_ptr, void* k_rot_ptr,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const void* cos_cache, const void* sin_cache,
+ int64_t q_ptr_int, int64_t k_ptr_int,
+ int64_t q_rot_ptr_int, int64_t k_rot_ptr_int,
+ int64_t pos_t_ptr_int, int64_t pos_h_ptr_int, int64_t pos_w_ptr_int,
+ int64_t cos_cache_int, int64_t sin_cache_int,
int B, int T, int Hq, int Hkv, int D, int s0, int s1,
int q_sb, int q_st, int q_sh,
int k_sb, int k_st, int k_sh,
@@ -104,6 +97,16 @@
int cache_sp, int cache_sd,
int64_t stream_int
) {
+ const __nv_bfloat16* q_ptr = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(q_ptr_int));
+ const __nv_bfloat16* k_ptr = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(k_ptr_int));
+ __nv_bfloat16* q_rot_ptr = reinterpret_cast<__nv_bfloat16*>(static_cast<intptr_t>(q_rot_ptr_int));
+ __nv_bfloat16* k_rot_ptr = reinterpret_cast<__nv_bfloat16*>(static_cast<intptr_t>(k_rot_ptr_int));
+ const int64_t* pos_t_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_t_ptr_int));
+ const int64_t* pos_h_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_h_ptr_int));
+ const int64_t* pos_w_ptr = reinterpret_cast<const int64_t*>(static_cast<intptr_t>(pos_w_ptr_int));
+ const __nv_bfloat16* cos_cache = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(cos_cache_int));
+ const __nv_bfloat16* sin_cache = reinterpret_cast<const __nv_bfloat16*>(static_cast<intptr_t>(sin_cache_int));
+
const int D_half = D / 2;
const int BT = B * T;
const int BLOCK = 256;
@@ -111,7 +114,7 @@
if (D_half == 64) {
mrope_fused_kernel<64, BLOCK><<<BT, BLOCK, 0, cs>>>(
q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
- pos_t, pos_h, pos_w, cos_cache, sin_cache,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr, cos_cache, sin_cache,
B, T, Hq, Hkv, s0, s1,
q_sb, q_st, q_sh, k_sb, k_st, k_sh,
oq_sb, oq_sh, oq_st, ok_sb, ok_sh, ok_st,
@@ -119,7 +122,7 @@
} else {
mrope_fused_kernel<32, BLOCK><<<BT, BLOCK, 0, cs>>>(
q_ptr, k_ptr, q_rot_ptr, k_rot_ptr,
- pos_t, pos_h, pos_w, cos_cache, sin_cache,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr, cos_cache, sin_cache,
B, T, Hq, Hkv, s0, s1,
q_sb, q_st, q_sh, k_sb, k_st, k_sh,
oq_sb, oq_sh, oq_st, ok_sb, ok_sh, ok_st,
@@ -130,10 +133,10 @@
_CPP_SRC = """
void launch_mrope_kernel(
- const void* q_ptr, const void* k_ptr,
- void* q_rot_ptr, void* k_rot_ptr,
- const int64_t* pos_t, const int64_t* pos_h, const int64_t* pos_w,
- const void* cos_cache, const void* sin_cache,
+ int64_t q_ptr_int, int64_t k_ptr_int,
+ int64_t q_rot_ptr_int, int64_t k_rot_ptr_int,
+ int64_t pos_t_ptr_int, int64_t pos_h_ptr_int, int64_t pos_w_ptr_int,
+ int64_t cos_cache_int, int64_t sin_cache_int,
int B, int T, int Hq, int Hkv, int D, int s0, int s1,
int q_sb, int q_st, int q_sh,
int k_sb, int k_st, int k_sh,
{
"content": "<9329 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check CUDA C++ kernel correctness",
"timeout": 300000
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"description": "Benchmark CUDA C++ kernel",
"timeout": 300000
}shape=0 variant=eager tflops=0.234 gbps=335.448 ms=0.269
shape=0 variant=compiled tflops=0.735 gbps=1055.216 ms=0.086
shape=0 variant=solution tflops=0.994 gbps=1426.970 ms=0.063
shape=0 solution_peak_fraction=0.7928
shape=1 variant=eager tflops=0.206 gbps=300.436 ms=0.489
shape=1 variant=compiled tflops=0.657 gbps=958.711 ms=0.153
shape=1 variant=solution tflops=1.050 gbps=1533.522 ms=0.096
shape=1 solution_peak_fraction=0.8520
shape=2 variant=eager tflops=0.138 gbps=215.327 ms=0.103
shape=2 variant=compiled tflops=0.416 gbps=650.235 ms=0.034
shape=2 variant=solution tflops=0.501 gbps=782.343 ms=0.028
shape=2 solution_peak_fraction=0.4346
shape=3 variant=eager tflops=0.178 gbps=255.265 ms=1.415
shape=3 variant=compiled tflops=0.526 gbps=754.764 ms=0.478
shape=3 variant=solution tflops=1.067 gbps=1530.514 ms=0.236
shape=3 solution_peak_fraction=0.8503
peak_fraction: 0.7068
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,7 +7,7 @@
#include <cuda_bf16.h>
#include <cuda_runtime.h>
-template<int D_HALF, int BLOCK_SIZE>
+template<int D_HALF, int BLOCK_SIZE, int BT_PER_BLOCK>
__global__ void mrope_fused_kernel(
const __nv_bfloat16* __restrict__ q_ptr,
const __nv_bfloat16* __restrict__ k_ptr,
@@ -27,58 +27,76 @@
const int pos_sb, const int pos_st,
const int cache_sp, const int cache_sd
) {
- __shared__ __nv_bfloat16 smem_cos[D_HALF];
- __shared__ __nv_bfloat16 smem_sin[D_HALF];
- __shared__ int pos_t_val, pos_h_val, pos_w_val;
-
- const int bt_idx = blockIdx.x;
- const int t_idx = bt_idx % T;
- const int b_idx = bt_idx / T;
+ __shared__ __nv_bfloat16 smem_cos[BT_PER_BLOCK][D_HALF];
+ __shared__ __nv_bfloat16 smem_sin[BT_PER_BLOCK][D_HALF];
+ __shared__ int pos_t_vals[BT_PER_BLOCK];
+ __shared__ int pos_h_vals[BT_PER_BLOCK];
+ __shared__ int pos_w_vals[BT_PER_BLOCK];
+
const int tid = threadIdx.x;
const int H_total = Hq + Hkv;
-
- if (tid == 0) pos_t_val = static_cast<int>(pos_t_ptr[b_idx * pos_sb + t_idx * pos_st]);
- if (tid == 1) pos_h_val = static_cast<int>(pos_h_ptr[b_idx * pos_sb + t_idx * pos_st]);
- if (tid == 2) pos_w_val = static_cast<int>(pos_w_ptr[b_idx * pos_sb + t_idx * pos_st]);
+ const int bt_start = blockIdx.x * BT_PER_BLOCK;
+
+ // Phase 1: Load positions for all (b,t) in this block
+ for (int i = tid; i < BT_PER_BLOCK; i += BLOCK_SIZE) {
+ int bt = bt_start + i;
+ if (bt >= B * T) { pos_t_vals[i] = 0; pos_h_vals[i] = 0; pos_w_vals[i] = 0; continue; }
+ int t = bt % T;
+ int b = bt / T;
+ pos_t_vals[i] = static_cast<int>(pos_t_ptr[b * pos_sb + t * pos_st]);
+ pos_h_vals[i] = static_cast<int>(pos_h_ptr[b * pos_sb + t * pos_st]);
+ pos_w_vals[i] = static_cast<int>(pos_w_ptr[b * pos_sb + t * pos_st]);
+ }
__syncthreads();
- for (int p = tid; p < D_HALF; p += BLOCK_SIZE) {
+ // Phase 2: Load cos/sin into shared memory for all (b,t)
+ for (int i = tid; i < BT_PER_BLOCK * D_HALF; i += BLOCK_SIZE) {
+ int bti = i / D_HALF;
+ int p = i % D_HALF;
int pos;
- if (p < s0) pos = pos_t_val;
- else if (p < s0 + s1) pos = pos_h_val;
- else pos = pos_w_val;
- smem_cos[p] = cos_cache[pos * cache_sp + p * cache_sd];
- smem_sin[p] = sin_cache[pos * cache_sp + p * cache_sd];
+ if (p < s0) pos = pos_t_vals[bti];
+ else if (p < s0 + s1) pos = pos_h_vals[bti];
+ else pos = pos_w_vals[bti];
+ smem_cos[bti][p] = cos_cache[pos * cache_sp + p * cache_sd];
+ smem_sin[bti][p] = sin_cache[pos * cache_sp + p * cache_sd];
}
__syncthreads();
- const int total_pairs = H_total * D_HALF;
- for (int idx = tid; idx < total_pairs; idx += BLOCK_SIZE) {
- const int p = idx % D_HALF;
- const int h = idx / D_HALF;
-
- const float cv = __bfloat162float(smem_cos[p]);
- const float sv = __bfloat162float(smem_sin[p]);
-
- if (h < Hq) {
- const int xb = b_idx * q_sb + t_idx * q_st + h * q_sh;
- const float x0 = __bfloat162float(q_ptr[xb + p]);
- const float x1 = __bfloat162float(q_ptr[xb + p + D_HALF]);
- const float o0 = x0 * cv - x1 * sv;
- const float o1 = x1 * cv + x0 * sv;
- const int ob = b_idx * oq_sb + h * oq_sh + t_idx * oq_st;
- q_rot_ptr[ob + p] = __float2bfloat16(o0);
- q_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
- } else {
- const int hk = h - Hq;
- const int xb = b_idx * k_sb + t_idx * k_st + hk * k_sh;
- const float x0 = __bfloat162float(k_ptr[xb + p]);
- const float x1 = __bfloat162float(k_ptr[xb + p + D_HALF]);
- const float o0 = x0 * cv - x1 * sv;
- const float o1 = x1 * cv + x0 * sv;
- const int ob = b_idx * ok_sb + hk * ok_sh + t_idx * ok_st;
- k_rot_ptr[ob + p] = __float2bfloat16(o0);
- k_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
+ // Phase 3: Process all heads for all (b,t) in this block
+ for (int bti = 0; bti < BT_PER_BLOCK; bti++) {
+ int bt = bt_start + bti;
+ if (bt >= B * T) break;
+ int t = bt % T;
+ int b = bt / T;
+
+ const int total_pairs = H_total * D_HALF;
+ for (int idx = tid; idx < total_pairs; idx += BLOCK_SIZE) {
+ const int p = idx % D_HALF;
+ const int h = idx / D_HALF;
+
+ const float cv = __bfloat162float(smem_cos[bti][p]);
+ const float sv = __bfloat162float(smem_sin[bti][p]);
+
+ if (h < Hq) {
+ const int xb = b * q_sb + t * q_st + h * q_sh;
+ const float x0 = __bfloat162float(q_ptr[xb + p]);
+ const float x1 = __bfloat162float(q_ptr[xb + p + D_HALF]);
+ const float o0 = x0 * cv - x1 * sv;
+ const float o1 = x1 * cv + x0 * sv;
+ const int ob = b * oq_sb + h * oq_sh + t * oq_st;
+ q_rot_ptr[ob + p] = __float2bfloat16(o0);
+ q_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
+ } else {
+ const int hk = h - Hq;
+ const int xb = b * k_sb + t * k_st + hk * k_sh;
+ const float x0 = __bfloat162float(k_ptr[xb + p]);
+ const float x1 = __bfloat162float(k_ptr[xb + p + D_HALF]);
+ const float o0 = x0 * cv - x1 * sv;
+ const float o1 = x1 * cv + x0 * sv;
+ const int ob = b * ok_sb + hk * ok_sh + t * ok_st;
+ k_rot_ptr[ob + p] = __float2bfloat16(o0);
+ k_rot_ptr[ob + p + D_HALF] = __float2bfloat16(o1);
+ }
}
}
}
@@ -106,27 +124,62 @@
const int
... (4144 more chars){
"content": "<11835 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"description": "Check BT_PER_BLOCK kernel correctness",
"timeout": 300000
}PASS