import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
# Fused mRoPE CUDA source: the kernel plus the host launcher bound via load_inline.
# Each warp handles one (b, t, h) row of q and each of its 32 lanes covers D/32
# consecutive head-dim elements; the gathered cos/sin values are reused to rotate
# k for rows with h < Hkv. Outputs are written in (B, H, T, D) layout.
_CUDA_SRC = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <climits>
#include <cstdint>

// Threads per block; every warp in a block owns one (b, t, h) row of q.
constexpr int kThreads = 256;
// Maximum head-dim elements handled by one lane, i.e. D <= 32 * kMaxEpt. It sizes
// the per-lane cos/sin arrays; the launcher rejects larger D instead of
// overrunning them.
constexpr int kMaxEpt = 4;

__global__ void __launch_bounds__(kThreads)
mrope_direct(
    const __nv_bfloat16* __restrict__ q_in,
    const __nv_bfloat16* __restrict__ k_in,
    __nv_bfloat16* __restrict__ q_out,
    __nv_bfloat16* __restrict__ k_out,
    const int64_t* __restrict__ pos_t,
    const int64_t* __restrict__ pos_h,
    const int64_t* __restrict__ pos_w,
    const __nv_bfloat16* __restrict__ cos_cache,
    const __nv_bfloat16* __restrict__ sin_cache,
    const int B, const int T, const int Hq, const int Hkv, const int D,
    const int s0, const int s1, const int s2  // s2 is implied (D/2 - s0 - s1); unused
) {
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    const int rows_per_block = blockDim.x / 32;
    const int row = blockIdx.x * rows_per_block + warp_id;
    if (row >= B * Hq * T) return;

    // Rows enumerate q in its input order: row = (b * T + t) * Hq + h.
    const int bt = row / Hq;
    const int h = row % Hq;
    const int b = bt / T;
    const int t = bt % T;
    const int half_D = D >> 1;
    const int ept = D / 32;  // elements per lane; the host enforces D % 32 == 0
    const int d_base = lane_id * ept;
    const int64_t bt_idx = (int64_t)b * T + t;
    // q_in is read as (B, T, Hq, D); q_out is written as (B, Hq, T, D).
    const int64_t q_in_off = (int64_t)bt * Hq * D + (int64_t)h * D;
    const int64_t q_out_off = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;

    // Per-element cos/sin for this lane's slice. Both halves of the head dim use
    // the same axis split (d_mod). The loops run to the compile-time bound
    // kMaxEpt guarded by i < ept, so they can be fully unrolled and the arrays
    // are only indexed with constants.
    // NOTE(review): positions are not range-checked here; a value outside the
    // cache's row range reads out of bounds.
    float cos_v[kMaxEpt], sin_v[kMaxEpt];
    #pragma unroll
    for (int i = 0; i < kMaxEpt; i++) {
        if (i < ept) {
            const int d = d_base + i;
            const int d_mod = d < half_D ? d : d - half_D;
            int64_t pos;
            if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
            else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
            else pos = __ldg(&pos_w[bt_idx]);
            cos_v[i] = __bfloat162float(__ldg(&cos_cache[pos * D + d]));
            sin_v[i] = __bfloat162float(__ldg(&sin_cache[pos * D + d]));
        }
    }

    // Rotate q: out[d] = x[d] * cos[d] + rotate_half(x)[d] * sin[d], where
    // rotate_half(x)[d] = -x[d + D/2] for d < D/2 and x[d - D/2] otherwise.
    #pragma unroll
    for (int i = 0; i < kMaxEpt; i++) {
        if (i < ept) {
            const int d = d_base + i;
            const float x_val = __bfloat162float(__ldg(&q_in[q_in_off + d]));
            const int pd = d < half_D ? d + half_D : d - half_D;
            const float pval = __bfloat162float(__ldg(&q_in[q_in_off + pd]));
            const float rh = d < half_D ? -pval : pval;
            const float res = x_val * cos_v[i] + rh * sin_v[i];
            q_out[q_out_off + d] = __float2bfloat16(res);
        }
    }

    // Rotate k with the same cos/sin. Only rows with h < Hkv have a k head; the
    // host enforces Hkv <= Hq so every k head is visited.
    if (h < Hkv) {
        const int64_t k_in_off = (int64_t)bt * Hkv * D + (int64_t)h * D;
        const int64_t k_out_off = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
        #pragma unroll
        for (int i = 0; i < kMaxEpt; i++) {
            if (i < ept) {
                const int d = d_base + i;
                const float x_val = __bfloat162float(__ldg(&k_in[k_in_off + d]));
                const int pd = d < half_D ? d + half_D : d - half_D;
                const float pval = __bfloat162float(__ldg(&k_in[k_in_off + pd]));
                const float rh = d < half_D ? -pval : pval;
                const float res = x_val * cos_v[i] + rh * sin_v[i];
                k_out[k_out_off + d] = __float2bfloat16(res);
            }
        }
    }
}

// Rejects tensors the kernel cannot address correctly through a raw pointer.
static void check_input(const torch::Tensor& t, const char* name,
                        at::ScalarType dtype, const torch::Device& device) {
    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
    TORCH_CHECK(t.device() == device, name, " must be on the same device as q_in");
    TORCH_CHECK(t.scalar_type() == dtype, name, " has dtype ", t.scalar_type(),
                ", expected ", dtype);
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
}

void mrope_direct_forward(
    torch::Tensor q_in, torch::Tensor k_in,
    torch::Tensor q_out, torch::Tensor k_out,
    torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
    torch::Tensor cos_cache, torch::Tensor sin_cache,
    int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
    TORCH_CHECK(B >= 0 && T >= 0 && Hq >= 0 && Hkv >= 0, "sizes must be non-negative");
    TORCH_CHECK(D > 0 && D % 32 == 0 && D / 32 <= kMaxEpt,
                "head dim D must be a positive multiple of 32 and at most ",
                32 * kMaxEpt, ", got ", D);
    TORCH_CHECK(Hkv <= Hq, "kernel requires Hkv <= Hq, got Hkv=", Hkv, ", Hq=", Hq);

    const torch::Device dev = q_in.device();
    check_input(q_in, "q_in", at::kBFloat16, dev);
    check_input(k_in, "k_in", at::kBFloat16, dev);
    check_input(q_out, "q_out", at::kBFloat16, dev);
    check_input(k_out, "k_out", at::kBFloat16, dev);
    check_input(pos_t, "pos_t", at::kLong, dev);
    check_input(pos_h, "pos_h", at::kLong, dev);
    check_input(pos_w, "pos_w", at::kLong, dev);
    check_input(cos_cache, "cos_cache", at::kBFloat16, dev);
    check_input(sin_cache, "sin_cache", at::kBFloat16, dev);

    const int64_t rows64 = (int64_t)B * Hq * T;
    TORCH_CHECK(rows64 <= INT_MAX, "B*Hq*T is too large for 32-bit row indexing");
    const int64_t q_elems = rows64 * D;
    const int64_t k_elems = (int64_t)B * T * Hkv * D;
    const int64_t bt_elems = (int64_t)B * T;
    TORCH_CHECK(q_in.numel() == q_elems && q_out.numel() == q_elems,
                "q_in/q_out must hold B*T*Hq*D elements");
    TORCH_CHECK(k_in.numel() == k_elems && k_out.numel() == k_elems,
                "k_in/k_out must hold B*T*Hkv*D elements");
    TORCH_CHECK(pos_t.numel() == bt_elems && pos_h.numel() == bt_elems &&
                pos_w.numel() == bt_elems, "pos_t/pos_h/pos_w must hold B*T elements");
    TORCH_CHECK(cos_cache.dim() == 2 && cos_cache.size(1) == D,
                "cos_cache must be 2-D with D columns");
    TORCH_CHECK(sin_cache.dim() == 2 && sin_cache.size(1) == D,
                "sin_cache must be 2-D with D columns");

    const int total_rows = (int)rows64;
    if (total_rows == 0) return;  // a zero-sized grid is an invalid launch configuration

    // Launch on the tensors' device and on PyTorch's current stream so the
    // kernel is ordered with surrounding PyTorch work.
    const c10::cuda::CUDAGuard guard(dev);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    constexpr int rows_per_block = kThreads / 32;
    const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
    mrope_direct<<<grid, kThreads, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(k_out.data_ptr<at::BFloat16>()),
        pos_t.data_ptr<int64_t>(),
        pos_h.data_ptr<int64_t>(),
        pos_w.data_ptr<int64_t>(),
        reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
        B, T, Hq, Hkv, D, s0, s1, s2
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
"""
# Forward declaration of the launcher defined in _CUDA_SRC. load_inline compiles
# this as the C++ source and generates a Python binding for each name listed in
# its `functions` argument (here, mrope_direct_forward).
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
void mrope_direct_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2);
"""
# Cached extension module; stays None until the first _get_mod() call builds it.
_mod = None
def _get_mod():
    """Return the compiled mRoPE extension, JIT-building it on first use.

    The module is cached in the module-level ``_mod`` so later calls return
    the same object without recompiling.
    """
    global _mod
    if _mod is not None:
        return _mod
    _mod = load_inline(
        name="mrope_direct_v2",
        cpp_sources=_CPP_DECL,
        cuda_sources=_CUDA_SRC,
        functions=["mrope_direct_forward"],
        extra_cuda_cflags=["-O3", "--use_fast_math"],
        verbose=False,
    )
    return _mod
# Problem metadata; mirrors the same constants in the reference module
# (presumably read by the benchmark harness -- confirm).
OP_TYPE = "rope"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
def _build_inv_freq(D, base=10000.0):
    """Return float32 RoPE inverse frequencies ``base ** (-j / D)`` for even j < D."""
    even_dims = torch.arange(0, D, 2, dtype=torch.float32)
    exponents = even_dims / D
    return 1.0 / (base ** exponents)
class Model(nn.Module):
    """Multi-axis RoPE (mRoPE) pre-attention prep backed by a fused CUDA kernel.

    Construction precomputes bf16 cos/sin caches of shape (max_pos, D). ``forward``
    rotates q (B, T, Hq, D) and k (B, T, Hkv, D). Each head-dim slot uses the
    temporal/height/width position selected by ``mrope_section``. The rotated
    tensors are returned in (B, Hq, T, D) and (B, Hkv, T, D) layout.
    """

    # Largest head dim the kernel supports: 32 lanes x 4 per-lane cos/sin slots.
    # D must also be a multiple of 32 so every head-dim element is written.
    _MAX_D = 128

    def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
        super().__init__()
        assert sum(mrope_section) == D // 2
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos
        # emb[p, j] = p * inv_freq[j % (D // 2)]: both halves of the head dim
        # share the same frequencies, as rotate-half RoPE expects.
        inv_freq = _build_inv_freq(D)
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))

    def forward(self, q, k, pos_t, pos_h, pos_w):
        """Apply mRoPE to q and k and return them in (B, H, T, D) layout.

        Args:
            q: query tensor indexed as (B, T, Hq, D).
            k: key tensor indexed as (B, T, Hkv, D); requires Hkv <= Hq.
            pos_t, pos_h, pos_w: per-token positions of shape (B, T), read by
                the kernel as int64. Values are assumed to lie in
                [0, max_pos); this is not checked.

        Returns:
            (q_out, k_out) with shapes (B, Hq, T, D) and (B, Hkv, T, D).

        Raises:
            ValueError: if shapes are inconsistent or D is unsupported by the kernel.
        """
        B, T, Hq, D = q.shape
        if D != self.D:
            raise ValueError(f"q head dim {D} does not match the cache head dim {self.D}")
        if D % 32 != 0 or D > self._MAX_D:
            raise ValueError(
                f"head dim must be a multiple of 32 and at most {self._MAX_D}, got {D}"
            )
        if k.dim() != 4 or (k.shape[0], k.shape[1], k.shape[3]) != (B, T, D):
            raise ValueError(
                f"k must be (B, T, Hkv, D) = ({B}, {T}, *, {D}), got {tuple(k.shape)}"
            )
        Hkv = k.shape[2]
        if Hkv > Hq:
            # The kernel rotates k heads only from rows with h < Hq.
            raise ValueError(f"kernel requires Hkv <= Hq, got Hkv={Hkv}, Hq={Hq}")
        for name, p in (("pos_t", pos_t), ("pos_h", pos_h), ("pos_w", pos_w)):
            if tuple(p.shape) != (B, T):
                raise ValueError(f"{name} must have shape {(B, T)}, got {tuple(p.shape)}")

        s0, s1, s2 = self.mrope_section
        # The kernel indexes raw pointers assuming dense row-major storage;
        # .contiguous() returns the tensor itself when it already is.
        q, k = q.contiguous(), k.contiguous()
        pos_t, pos_h, pos_w = pos_t.contiguous(), pos_h.contiguous(), pos_w.contiguous()

        q_out = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
        k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
        mod = _get_mod()
        mod.mrope_direct_forward(q, k, q_out, k_out,
                                 pos_t, pos_h, pos_w,
                                 self.cos_cache, self.sin_cache,
                                 B, T, Hq, Hkv, D, s0, s1, s2)
        return q_out, k_out
# Default problem shape used by get_inputs() / get_init_inputs(); matches the
# first (Qwen2-VL base) entry of the shape sweep.
B = 1
T = 4096
Hq = 32
Hkv = 8
D = 128
MROPE_SECTION = (16, 24, 24)  # temporal/height/width split of D // 2
MAX_POS = 32768  # rows in the cos/sin caches; positions are drawn from [0, MAX_POS)
def get_inputs():
    """Draw one random (q, k, pos_t, pos_h, pos_w) input set for the default shape.

    q and k are bf16 standard-normal samples scaled by 0.5. The three position
    tensors are int64 of shape (B, T), drawn uniformly from [0, MAX_POS), in the
    order temporal, height, width.
    """
    scale = 0.5
    q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16) * scale
    k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16) * scale
    positions = [
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
        for _axis in ("t", "h", "w")
    ]
    return [q, k, *positions]
def get_init_inputs():
    """Return Model's positional constructor arguments for the default shape."""
    init_args = [B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS]
    return init_args
shape=0 variant=eager tflops=0.183 gbps=262.113 ms=0.344
shape=0 variant=compiled tflops=0.507 gbps=727.562 ms=0.124
shape=0 variant=solution tflops=0.649 gbps=930.756 ms=0.097
shape=0 solution_peak_fraction=0.5171
shape=1 variant=eager tflops=0.211 gbps=308.527 ms=0.476
shape=1 variant=compiled tflops=0.481 gbps=702.664 ms=0.209
shape=1 variant=solution tflops=0.691 gbps=1008.820 ms=0.146
shape=1 solution_peak_fraction=0.5605
shape=2 variant=eager tflops=0.074 gbps=115.915 ms=0.191
shape=2 variant=compiled tflops=0.166 gbps=259.459 ms=0.085
shape=2 variant=solution tflops=0.421 gbps=657.034 ms=0.034
shape=2 solution_peak_fraction=0.3650
shape=3 variant=eager tflops=0.173 gbps=248.065 ms=1.456
shape=3 variant=compiled tflops=0.470 gbps=675.030 ms=0.535
shape=3 variant=solution tflops=0.844 gbps=1210.392 ms=0.298
shape=3 solution_peak_fraction=0.6724
peak_fraction: 0.5164
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "openrouter-pinned/xiaomi/mimo-v2.5-pro",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.5164,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 264558, "output_tokens": 36580, "cache_read_tokens": 5345792, "cache_creation_tokens": 0, "reasoning_tokens": 104685, "total_cost_usd": null}
}
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml to understand the op",
"status": "in_progress",
"priority": "high"
},
{
"content": "Read and understand the mRoPE math from transformers Qwen2-VL",
"status": "pending",
"priority": "high"
},
{
"content": "Implement initial fused CUDA kernel in solution.py",
"status": "pending",
"priority": "high"
},
{
"content": "Verify correctness with check.py",
"status": "pending",
"priority": "high"
},
{
"conte
... (229 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml to understand the op'
status: 'in_progress'
priority: 'high'
content: 'Read and understand the mRoPE math from transformers Qwen2-VL'
status: 'pending'
priority: 'high'
content: 'Implement initial fused CUDA kernel in solution.py'
status: 'pending'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile with ncu/nsys and optimize'
status: 'pending'
priority: 'medium'
content: 'Benchmark and squeeze performance'
status: 'pending'
priority: 'medium'{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
1: """Naive multi-axis RoPE pre-attention reference (correctness only).
2:
3: Mirrors the Qwen2-VL apply_multimodal_rotary_pos_emb semantics: three position
4: axes (temporal, height, width), each contributing a slice of the head_dim via
5: mrope_section. Standard rotate-half on q and k once per-axis cos/sin have been
6: mixed.
7:
8: Layout convention:
9: q: (B, T, Hq, D) bf16
10: k: (B, T, Hkv, D) bf16
11: pos_t, pos_h, pos_w: (B, T) int64
12:
13: Output: q_rot, k_rot in (B, H, T, D) — i.e. the layout an attention kernel
14: expects (head dim second). The transpose is part of the "pre-attention prep".
15: """
16: import torch
17: import torch.nn as nn
18:
19: OP_TYPE = "rope"
20: SUPPORTED_PRECISIONS = ["bf16"]
21: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
22:
23:
24: def _build_inv_freq(D: int, base: float = 10000.0) -> torch.Tensor:
25: return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
26:
27:
28: def _rotate_half(x: torch.Tensor) -> torch.Tensor:
29: half = x.shape[-1] // 2
30: return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
31:
32:
33: class Model(nn.Module):
34: def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
35: mrope_section: tuple[int, int, int], max_pos: int):
36: super().__init__()
37: assert sum(mrope_section) == D // 2, \
38:
... (3384 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py</path>
<type>file</type>
<content>
1: """Correctness runner for multi-axis RoPE pre-attention prep.
2:
3: Runs solution.Model vs reference.Model across all shapes in shapes.py, 3 seeds
4: each, with per-dtype atol/rtol. Also rejects forbidden ops by grep.
5: """
6: import re
7: import sys
8: from pathlib import Path
9:
10: import torch
11: import yaml
12:
13: # Make the repo's src/ importable
14: REPO_ROOT = Path(__file__).resolve().parents[2]
15: sys.path.insert(0, str(REPO_ROOT))
16:
17: from src.eval.correctness import check_correctness # noqa: E402
18:
19:
20: def main():
21: try:
22: import reference
23: import shapes
24: import solution
25: except Exception as e:
26: print(f"FAIL: import error: {e}")
27: sys.exit(1)
28:
29: problem_yaml = Path("problem.yaml")
30: meta = yaml.safe_load(problem_yaml.read_text()) if problem_yaml.exists() else {}
31:
32: # --- Forbidden-op check ------------------------------------------------
33: sol_src = Path("solution.py").read_text() if Path("solution.py").exists() else ""
34: for forbidden in meta.get("forbidden", []):
35: pat = re.escape(forbidden)
36: if re.search(pat, sol_src):
37: print(f"FAIL: forbidden op used: {forbidden}")
38: sys.exit(1)
39:
40: device = torch.device("cuda:0")
41: tol_override = meta.get("tolerance") or
... (3519 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/benchmark.py</path>
<type>file</type>
<content>
1: """Roofline benchmark for multi-axis RoPE pre-attention prep.
2:
3: For each shape: times eager reference, compiled reference, SOTA (if available),
4: and the agent's solution. Reports achieved TFLOPS, GB/s, and peak_fraction.
5:
6: Output lines the harness picks up:
7: shape=<idx> variant=<name> tflops=<N> gbps=<N> ms=<N>
8: peak_fraction: <N> (geomean over shapes of solution's peak_fraction)
9: """
10: import sys
11: from math import exp, log
12: from pathlib import Path
13:
14: import torch
15: import yaml
16:
17: REPO_ROOT = Path(__file__).resolve().parents[2]
18: sys.path.insert(0, str(REPO_ROOT))
19:
20: from src.eval.roofline import compute_gbps, compute_tflops, peak_fraction # noqa: E402
21: from src.eval.timing import time_fn # noqa: E402
22: from src.hardware import get as get_hw # noqa: E402
23:
24:
25: def _eval_formula(expr: str, vars: dict) -> float:
26: # Very small eval: only names from `vars` are valid.
27: return float(eval(expr, {"__builtins__": {}}, vars))
28:
29:
30: def main():
31: import reference
32: import shapes
33: import solution
34:
35: meta = yaml.safe_load(Path("problem.yaml").read_text())
36: hw = get_hw(meta["hardware"][0])
37: peak_tflops = hw.peak_tflops_dense.get(meta["peak_tflops_key"], 0.0)
38: peak_gbps = hw.peak_bandwidth_gb_s
39: regime = meta.get("regime
... (3966 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/problem.yaml</path>
<type>file</type>
<content>
1: name: 09_fmha_preattn_mrope
2: display_name: "Multi-axis RoPE Pre-attention Prep"
3: precision: bf16
4: regime: memory # pointwise rotation + cache gather; bandwidth-bound
5:
6: # FLOPs per call: rotate-half on q and k.
7: # Per element: out = a*cos + b*sin -> 3 flops/elem (1 mul + 1 mul + 1 add).
8: # We do this on q ((B*T*Hq*D) elems) and k ((B*T*Hkv*D) elems).
9: flops_formula: "3 * B * T * (Hq + Hkv) * D"
10:
11: # Bytes moved (lower bound):
12: # read q, k: 2 * B*T*(Hq+Hkv)*D
13: # read 3 position id arrays: 3 * B*T*8 (int64)
14: # read cos/sin from cache: 6 * B*T*D*2 (3 axes, 2 caches, bf16)
15: # write q_rot, k_rot: 2 * B*T*(Hq+Hkv)*D
16: bytes_formula: "4 * B * T * (Hq + Hkv) * D + 12 * B * T * D + 24 * B * T"
17:
18: hardware: [RTX_PRO_6000]
19: peak_tflops_key: bf16 # informational; regime=memory uses bandwidth
20: peak_bandwidth_key: dram
21:
22: tolerance:
23: bfloat16: 0.02
24:
25: # Forbidden ops -- agent must implement the gather + per-axis mix + rotation.
26: # Banning vendor mRoPE entry points and obvious shortcuts.
27: forbidden:
28: - "transformers.models.qwen2_vl.modeling_qwen2_vl.apply_multimodal_rotary_pos_emb"
29: - "apply_multimodal_rotary_pos_emb"
30: - "flash_attn.layers.rotary"
31: - "flashinfer.rope"
32: - "flashinfer.apply_rope"
33:
34: sota:
35: name:
... (284 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/shapes.py</path>
<type>file</type>
<content>
1: """Canonical shape sweep for multi-axis RoPE pre-attention.
2:
3: Sized for Qwen2-VL-style vision-language inference: long T, GQA, mrope_section
4: splitting D/2 across temporal/height/width axes. Mix of base, long-context,
5: small-head-dim, and very-long-context.
6: """
7:
8: SHAPES = [
9: # Qwen2-VL base: T=4k, head_dim=128, GQA 32:8, mrope [16,24,24]
10: {"B": 1, "T": 4096, "Hq": 32, "Hkv": 8, "D": 128,
11: "mrope_section": (16, 24, 24), "max_pos": 32768},
12:
13: # Qwen2-VL 7B-style long context, narrower kv heads
14: {"B": 1, "T": 8192, "Hq": 28, "Hkv": 4, "D": 128,
15: "mrope_section": (16, 24, 24), "max_pos": 32768},
16:
17: # Smaller head_dim and B=2 (tests batch dim and reduced D)
18: {"B": 2, "T": 2048, "Hq": 16, "Hkv": 2, "D": 64,
19: "mrope_section": (8, 12, 12), "max_pos": 16384},
20:
21: # Very long context (16k tokens)
22: {"B": 1, "T": 16384, "Hq": 32, "Hkv": 8, "D": 128,
23: "mrope_section": (16, 24, 24), "max_pos": 65536},
24: ]
(End of file - total 24 lines)
</content>{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml to understand the op",
"priority": "high",
"status": "completed"
},
{
"content": "Read and understand the mRoPE math from transformers Qwen2-VL",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial fused CUDA kernel in solution.py",
"priority": "high",
"status": "in_progress"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "pending"
},
{
"c
... (233 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml to understand the op'
status: 'completed'
priority: 'high'
content: 'Read and understand the mRoPE math from transformers Qwen2-VL'
status: 'completed'
priority: 'high'
content: 'Implement initial fused CUDA kernel in solution.py'
status: 'in_progress'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'pending'
priority: 'high'
content: 'Profile with ncu/nsys and optimize'
status: 'pending'
priority: 'medium'
content: 'Benchmark and squeeze performance'
status: 'pending'
priority: 'medium'{
"command": "nvidia-smi 2>&1 | head -5 && nvcc --version 2>&1 | tail -3",
"description": "Check GPU and CUDA version"
}Thu May 7 13:36:51 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
Built on Mon_Mar_02_09:52:23_PM_PST_2026
Cuda compilation tools, release 13.2, V13.2.51
Build cuda_13.2.r13.2/compiler.37434383_0
{
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/",
"description": "List problem directory"
}benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope</path>
<type>directory</type>
<entries>
benchmark.py
check.py
problem.yaml
PROMPT.txt
reference.py
shapes.py
sota.py
(7 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/sota.py</path>
<type>file</type>
<content>
1: """SOTA reference for multi-axis RoPE pre-attention.
2:
3: Wraps transformers' apply_multimodal_rotary_pos_emb. The full pipeline (gather
4: per-axis cos/sin from cache, then call the vendor rotation) is timed, matching
5: what the reference and the agent solution do.
6:
7: If transformers isn't installed or the impl signature has drifted, returns
8: is_available()=False and the benchmark omits the SOTA line.
9: """
10: from __future__ import annotations
11:
12: import torch
13:
14:
15: def _try_transformers():
16: try:
17: from transformers.models.qwen2_vl.modeling_qwen2_vl import (
18: apply_multimodal_rotary_pos_emb as _impl,
19: )
20: return _impl
21: except Exception:
22: return None
23:
24:
25: def sota_forward(
26: q: torch.Tensor, k: torch.Tensor,
27: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
28: cos_cache: torch.Tensor, sin_cache: torch.Tensor,
29: mrope_section: tuple[int, int, int],
30: ):
31: """q,k: (B, T, H, D); pos_*: (B, T); cos_cache,sin_cache: (max_pos, D)."""
32: impl = _try_transformers()
33: if impl is None:
34: raise RuntimeError("transformers not available for SOTA")
35:
36: # Gather per-axis cos/sin and stack into (3, B, T, D) as transformers expects.
37: cos = torch.stack([cos_cache[pos_t], cos_cache[pos_h], c
... (495 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py</path>
<type>file</type>
<content>
1: """Per-dtype correctness tolerance.
2:
3: Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
4: We extend to fp8 (0.1, permissive) and int (exact).
5: """
6: from __future__ import annotations
7:
8: import torch
9:
10: DEFAULT_TOLERANCE = {
11: torch.float32: {"atol": 1e-4, "rtol": 1e-4},
12: torch.float16: {"atol": 1e-2, "rtol": 1e-2},
13: torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
14: torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
15: torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
16: torch.int8: {"atol": 0, "rtol": 0},
17: torch.int32: {"atol": 0, "rtol": 0},
18: torch.int64: {"atol": 0, "rtol": 0},
19: }
20:
21:
22: def _coerce_float(x):
23: """Accept int/float/str scientific-notation as float; pass through dicts."""
24: if isinstance(x, (int, float)):
25: return float(x)
26: if isinstance(x, str):
27: try:
28: return float(x)
29: except ValueError:
30: return x
31: return x
32:
33:
34: def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
35: """Lookup atol/rtol for a given dtype, with optional per-problem override.
36:
37: The override dict is keyed by str(dtype). Values may be:
38: - a single number (or a string parseable as float, e.g. "5e-2") -> used
39: as both atol and rtol
4
... (2580 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/roofline.py</path>
<type>file</type>
<content>
1: """Roofline math: achieved TFLOPS / GB/s, peak fraction.
2:
3: FLOPS and bytes formulas come from the problem's `problem.yaml`. They are the
4: dense-equivalent algorithmic work; kernels that skip work (sparsity, early
5: exit) are still measured against the dense cost.
6: """
7: from __future__ import annotations
8:
9:
10: def compute_tflops(flops: float, time_ms: float) -> float:
11: """flops per invocation and wall time in ms -> TFLOPS."""
12: if time_ms <= 0:
13: return 0.0
14: return flops / (time_ms * 1e-3) / 1e12
15:
16:
17: def compute_gbps(bytes_moved: float, time_ms: float) -> float:
18: """bytes per invocation and wall time in ms -> GB/s."""
19: if time_ms <= 0:
20: return 0.0
21: return bytes_moved / (time_ms * 1e-3) / 1e9
22:
23:
24: def peak_fraction(
25: achieved: float,
26: peak: float,
27: ) -> float:
28: """Return achieved / peak, clamped to [0, infinity)."""
29: if peak <= 0:
30: return 0.0
31: return max(0.0, achieved / peak)
(End of file - total 31 lines)
</content>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py</path>
<type>file</type>
<content>
1: """Centralized GPU timing for benchmark.py files.
2:
3: All problems' benchmark.py should call `time_fn` from here so we have one
4: implementation of warmup, L2 flush, and CUDA event capture to maintain.
5:
6: Methodology:
7: - 10 warmup calls absorb Triton autotune (typical ~7 configs) and
8: torch.compile reduce-overhead CUDA-graph capture.
9: - Between each timed call, `_l2_flush()` writes 128 MB to evict L2
10: (Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
11: we measure HBM-load bandwidth, not L2-cached bandwidth).
12: - GPU timing via cuda.Event with synchronize() AFTER record() but
13: BEFORE elapsed_time() — the canonical NVIDIA pattern.
14: - Reported value is the median of `iters` trials, robust to outliers.
15:
16: Notes / known biases not addressed here:
17: - torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
18: launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
19: On small shapes where launch overhead matters, this gives compile an
20: artificial advantage. We accept this as the cost of using the agreed
21: "compiled" baseline policy.
22: - cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
23: - Median over a fairly small number of trials (default 30) is fine for
24: headline numbers but won't catch bimodal distributions. Use --ext
... (1478 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware.py"
}{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware</path>
<type>directory</type>
<entries>
__init__.py
__pycache__/
m4_max.py
rtx_pro_6000.py
(4 entries)
</entries>{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/src/hardware/rtx_pro_6000.py</path>
<type>file</type>
<content>
1: """RTX PRO 6000 Blackwell Workstation — SM120, consumer-lineage Blackwell.
2:
3: Peak tensor-core throughputs are dense-matrix advertised peaks. Actual kernels
4: will see 60-85% of peak on well-tuned code.
5: """
6: from dataclasses import dataclass
7:
8:
9: @dataclass(frozen=True)
10: class HardwareTarget:
11: name: str
12: sm: str
13: vram_gb: int
14: peak_bandwidth_gb_s: float # DRAM
15: peak_tflops_dense: dict[str, float] # dtype -> TFLOPS
16:
17:
18: RTX_PRO_6000 = HardwareTarget(
19: name="RTX PRO 6000 Blackwell Workstation",
20: sm="sm_120a",
21: vram_gb=96,
22: peak_bandwidth_gb_s=1800.0,
23: peak_tflops_dense={
24: "fp4": 800.0,
25: "nvfp4": 800.0,
26: "mxfp4": 800.0,
27: "fp6": 800.0,
28: "fp8": 400.0,
29: "bf16": 200.0,
30: "fp16": 200.0,
31: "tf32": 100.0,
32: "fp32": 12.0, # non-tensor-core SIMT fp32
33: "int8": 400.0,
34: "int4": 800.0,
35: },
36: )
(End of file - total 36 lines)
</content>{
"command": "uv run python -c \"import torch; print(torch.cuda.get_device_capability()); print(torch.version.cuda); print(torch.cuda.get_device_name())\"",
"description": "Check GPU architecture and CUDA version",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}(12, 0)
13.0
NVIDIA RTX PRO 6000 Blackwell Workstation Edition
{
"command": "<277 chars \u2014 see below>",
"description": "Determine correct arch flag",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}uv run python -c "
import torch
# Check compute capability
cap = torch.cuda.get_device_capability()
print(f'Compute capability: {cap[0]}.{cap[1]}')
print(f'SM: sm_{cap[0]}{cap[1]}')
# Check if we can use sm_120a
arch = f'sm_{cap[0]}{cap[1]}'
print(f'Arch flag: -arch={arch}')
"Compute capability: 12.0
SM: sm_120
Arch flag: -arch=sm_120
--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+_CUDA_SRC = r"""
+#include <cuda_bf16.h>
+
+__global__ void mrope_rotate(
+ const __nv_bfloat16* __restrict__ x_in,
+ __nv_bfloat16* __restrict__ x_out,
+ const long long* __restrict__ pos_t,
+ const long long* __restrict__ pos_h,
+ const long long* __restrict__ pos_w,
+ const __nv_bfloat16* __restrict__ cos_cache,
+ const __nv_bfloat16* __restrict__ sin_cache,
+ const int B, const int T, const int H, const int D,
+ const int s0, const int s1, const int s2
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int total = B * H * T * D;
+ if (idx >= total) return;
+
+ const int d = idx % D;
+ const int tmp = idx / D;
+ const int t = tmp % T;
+ const int tmp2 = tmp / T;
+ const int h = tmp2 % H;
+ const int b = tmp2 / H;
+
+ const int half_D = D >> 1;
+ const int d_mod = d < half_D ? d : d - half_D;
+ const long long bt = (long long)b * T + t;
+
+ long long pos;
+ if (d_mod < s0) pos = pos_t[bt];
+ else if (d_mod < s0 + s1) pos = pos_h[bt];
+ else pos = pos_w[bt];
+
+ const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
+ const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
+
+ const long long in_base = (long long)b * T * H * D
+ + (long long)t * H * D
+ + (long long)h * D;
+ const float x_val = __bfloat162float(x_in[in_base + d]);
+
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float pval = __bfloat162float(x_in[in_base + pd]);
+ const float rh = d < half_D ? -pval : pval;
+
+ const float res = x_val * cos_v + rh * sin_v;
+
+ const long long oi = (long long)b * H * T * D
+ + (long long)h * T * D
+ + (long long)t * D + d;
+ x_out[oi] = __float2bfloat16(res);
+}
+"""
+
+_mod = None
+
+def _get_mod():
+ global _mod
+ if _mod is None:
+ _mod = load_inline(
+ name="mrope_ext",
+ cpp_sources="",
+ cuda_sources=_CUDA_SRC,
+ functions=["mrope_rotate"],
+ extra_cuda_cflags=["-O3", "--use_fast_math"],
+ verbose=False,
+ )
+ return _mod
+
+
+OP_TYPE = "rope"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+
+def _build_inv_freq(D, base=10000.0):
+ return 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+
+
+class Model(nn.Module):
+ def __init__(self, B, T, Hq, Hkv, D, mrope_section, max_pos):
+ super().__init__()
+ assert sum(mrope_section) == D // 2
+ self.B, self.T = B, T
+ self.Hq, self.Hkv, self.D = Hq, Hkv, D
+ self.mrope_section = tuple(mrope_section)
+ self.max_pos = max_pos
+
+ inv_freq = _build_inv_freq(D)
+ pos = torch.arange(max_pos, dtype=torch.float32)
+ freqs = torch.outer(pos, inv_freq)
+ emb = torch.cat([freqs, freqs], dim=-1)
+ self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
+ self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
+
+ def forward(self, q, k, pos_t, pos_h, pos_w):
+ B, T, Hq, D = q.shape
+ Hkv = k.shape[2]
+ s0, s1, s2 = self.mrope_section
+
+ q_out = torch.empty(B, Hq, T, D, dtype=q.dtype, device=q.device)
+ k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
+
+ mod = _get_mod()
+ THR = 256
+
+ nq = B * Hq * T * D
+ mod.mrope_rotate(((nq + THR - 1) // THR,), (THR,),
+ q, q_out, pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hq, D, s0, s1, s2)
+
+ nk = B * Hkv * T * D
+ mod.mrope_rotate(((nk + THR - 1) // THR,), (THR,),
+ k, k_out, pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hkv, D, s0, s1, s2)
+
+ return q_out, k_out
+
+
+B = 1
+T = 4096
+Hq = 32
+Hkv = 8
+D = 128
+MROPE_SECTION = (16, 24, 24)
+MAX_POS = 32768
+
+
+def get_inputs():
+ q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16) * 0.5
+ k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16) * 0.5
+ pos_t = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_h = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ pos_w = torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
+ return [q, k, pos_t, pos_h, pos_w]
+
+
+def get_init_inputs():
+ return [B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS]
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"content": "<4519 chars \u2014 see diff>"
}Wrote file successfully.{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml to understand the op",
"priority": "high",
"status": "completed"
},
{
"content": "Read and understand the mRoPE math from transformers Qwen2-VL",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial fused CUDA kernel in solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "in_progress"
},
{
... (235 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml to understand the op'
status: 'completed'
priority: 'high'
content: 'Read and understand the mRoPE math from transformers Qwen2-VL'
status: 'completed'
priority: 'high'
content: 'Implement initial fused CUDA kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'in_progress'
priority: 'high'
content: 'Profile with ncu/nsys and optimize'
status: 'pending'
priority: 'medium'
content: 'Benchmark and squeeze performance'
status: 'pending'
priority: 'medium'{
"command": "uv run check.py 2>&1",
"description": "Run correctness check",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"timeout": 300000
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 109, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/Ker
... (3335 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -57,6 +57,15 @@
}
"""
+_CPP_SRC = r"""
+#include <torch/extension.h>
+void mrope_rotate(
+ torch::Tensor x_in, torch::Tensor x_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int H, int D, int s0, int s1, int s2);
+"""
+
_mod = None
def _get_mod():
@@ -64,7 +73,7 @@
if _mod is None:
_mod = load_inline(
name="mrope_ext",
- cpp_sources="",
+ cpp_sources=_CPP_SRC,
cuda_sources=_CUDA_SRC,
functions=["mrope_rotate"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<338 chars \u2014 see below>",
"newString": "<636 chars \u2014 see below>"
}_mod = None
def _get_mod():
global _mod
if _mod is None:
_mod = load_inline(
name="mrope_ext",
cpp_sources="",
cuda_sources=_CUDA_SRC,
functions=["mrope_rotate"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
return _mod_CPP_SRC = r"""
#include <torch/extension.h>
void mrope_rotate(
torch::Tensor x_in, torch::Tensor x_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int H, int D, int s0, int s1, int s2);
"""
_mod = None
def _get_mod():
global _mod
if _mod is None:
_mod = load_inline(
name="mrope_ext",
cpp_sources=_CPP_SRC,
cuda_sources=_CUDA_SRC,
functions=["mrope_rotate"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
return _modEdit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,6 +3,7 @@
from torch.utils.cpp_extension import load_inline
_CUDA_SRC = r"""
+#include <torch/extension.h>
#include <cuda_bf16.h>
__global__ void mrope_rotate(
@@ -55,15 +56,28 @@
+ (long long)t * D + d;
x_out[oi] = __float2bfloat16(res);
}
-"""
-_CPP_SRC = r"""
-#include <torch/extension.h>
-void mrope_rotate(
+void mrope_forward(
torch::Tensor x_in, torch::Tensor x_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int H, int D, int s0, int s1, int s2);
+ int B, int T, int H, int D, int s0, int s1, int s2
+) {
+ const int total = B * H * T * D;
+ const int THR = 256;
+ const int blocks = (total + THR - 1) / THR;
+
+ mrope_rotate<<<blocks, THR>>>(
+ reinterpret_cast<const __nv_bfloat16*>(x_in.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(x_out.data_ptr<at::BFloat16>()),
+ pos_t.data_ptr<long long>(),
+ pos_h.data_ptr<long long>(),
+ pos_w.data_ptr<long long>(),
+ reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
+ B, T, H, D, s0, s1, s2
+ );
+}
"""
_mod = None
@@ -73,9 +87,9 @@
if _mod is None:
_mod = load_inline(
name="mrope_ext",
- cpp_sources=_CPP_SRC,
+ cpp_sources="",
cuda_sources=_CUDA_SRC,
- functions=["mrope_rotate"],
+ functions=["mrope_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
@@ -116,19 +130,14 @@
k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
mod = _get_mod()
- THR = 256
- nq = B * Hq * T * D
- mod.mrope_rotate(((nq + THR - 1) // THR,), (THR,),
- q, q_out, pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B, T, Hq, D, s0, s1, s2)
+ mod.mrope_forward(q, q_out, pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hq, D, s0, s1, s2)
- nk = B * Hkv * T * D
- mod.mrope_rotate(((nk + THR - 1) // THR,), (THR,),
- k, k_out, pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B, T, Hkv, D, s0, s1, s2)
+ mod.mrope_forward(k, k_out, pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hkv, D, s0, s1, s2)
return q_out, k_out
{
"content": "<5226 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2774, in _run_ninja_build
subprocess.run(
File "/home/infatoshi/.local/share/uv/python/cpython-3.11.15-linux-x86_64-gnu/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 132, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/Ker
... (3341 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,6 +1,15 @@
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
+
+_CPP_DECL = r"""
+#include <torch/extension.h>
+void mrope_forward(
+ torch::Tensor x_in, torch::Tensor x_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int H, int D, int s0, int s1, int s2);
+"""
_CUDA_SRC = r"""
#include <torch/extension.h>
@@ -87,7 +96,7 @@
if _mod is None:
_mod = load_inline(
name="mrope_ext",
- cpp_sources="",
+ cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<3057 chars \u2014 see below>",
"newString": "<3358 chars \u2014 see below>"
}_CUDA_SRC = r"""
#include <torch/extension.h>
#include <cuda_bf16.h>
__global__ void mrope_rotate(
const __nv_bfloat16* __restrict__ x_in,
__nv_bfloat16* __restrict__ x_out,
const long long* __restrict__ pos_t,
const long long* __restrict__ pos_h,
const long long* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
const int B, const int T, const int H, const int D,
const int s0, const int s1, const int s2
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total = B * H * T * D;
if (idx >= total) return;
const int d = idx % D;
const int tmp = idx / D;
const int t = tmp % T;
const int tmp2 = tmp / T;
const int h = tmp2 % H;
const int b = tmp2 / H;
const int half_D = D >> 1;
const int d_mod = d < half_D ? d : d - half_D;
const long long bt = (long long)b * T + t;
long long pos;
if (d_mod < s0) pos = pos_t[bt];
else if (d_mod < s0 + s1) pos = pos_h[bt];
else pos = pos_w[bt];
const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
const long long in_base = (long long)b * T * H * D
+ (long long)t * H * D
+ (long long)h * D;
const float x_val = __bfloat162float(x_in[in_base + d]);
const int pd = d < half_D ? d + half_D : d - half_D;
const float pval = __bfloat162float(x_in[in_base + pd]);
const float rh = d < half_D ? -pval : pval;
const float res = x_val * cos_v + rh * sin_v;
const long long oi = (long long)b * H * T * D
+ (long long)h * T * D
+ (long long)t * D + d;
x_out[oi] = __float2bfloat16(res);
}
void mrope_forward(
torch::Tensor x_in, torch::Tensor x_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int H, int D, int s0, int s1, int s2
) {
const int total = B * H * T * D;
const int THR = 256;
const int blocks = (total + THR - 1) / THR;
mrope_rotate<<<blocks, THR>>>(
reinterpret_cast<const __nv_bfloat16*>(x_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(x_out.data_ptr<at::BFloat16>()),
pos_t.data_ptr<long long>(),
pos_h.data_ptr<long long>(),
pos_w.data_ptr<long long>(),
reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
B, T, H, D, s0, s1, s2
);
}
"""
_mod = None
def _get_mod():
global _mod
if _mod is None:
_mod = load_inline(
name="mrope_ext",
cpp_sources="",
cuda_sources=_CUDA_SRC,
functions=["mrope_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
return _mod_CPP_DECL = r"""
#include <torch/extension.h>
void mrope_forward(
torch::Tensor x_in, torch::Tensor x_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int H, int D, int s0, int s1, int s2);
"""
_CUDA_SRC = r"""
#include <torch/extension.h>
#include <cuda_bf16.h>
__global__ void mrope_rotate(
const __nv_bfloat16* __restrict__ x_in,
__nv_bfloat16* __restrict__ x_out,
const long long* __restrict__ pos_t,
const long long* __restrict__ pos_h,
const long long* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
const int B, const int T, const int H, const int D,
const int s0, const int s1, const int s2
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total = B * H * T * D;
if (idx >= total) return;
const int d = idx % D;
const int tmp = idx / D;
const int t = tmp % T;
const int tmp2 = tmp / T;
const int h = tmp2 % H;
const int b = tmp2 / H;
const int half_D = D >> 1;
const int d_mod = d < half_D ? d : d - half_D;
const long long bt = (long long)b * T + t;
long long pos;
if (d_mod < s0) pos = pos_t[bt];
else if (d_mod < s0 + s1) pos = pos_h[bt];
else pos = pos_w[bt];
const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
const long long in_base = (long long)b * T * H * D
+ (long long)t * H * D
+ (long long)h * D;
const float x_val = __bfloat162float(x_in[in_base + d]);
const int pd = d < half_D ? d + half_D : d - half_D;
const float pval = __bfloat162float(x_in[in_base + pd]);
const float rh = d < half_D ? -pval : pval;
const float res = x_val * cos_v + rh * sin_v;
const long long oi = (long long)b * H * T * D
+ (long long)h * T * D
+ (long long)t * D + d;
x_out[oi] = __float2bfloat16(res);
}
void mrope_forward(
torch::Tensor x_in, torch::Tensor x_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int H, int D, int s0, int s1, int s2
) {
const int total = B * H * T * D;
const int THR = 256;
const int blocks = (total + THR - 1) / THR;
mrope_rotate<<<blocks, THR>>>(
reinterpret_cast<const __nv_bfloat16*>(x_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(x_out.data_ptr<at::BFloat16>()),
pos_t.data_ptr<long long>(),
pos_h.data_ptr<long long>(),
pos_w.data_ptr<long long>(),
reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
B, T, H, D, s0, s1, s2
);
}
"""
_mod = None
def _get_mod():
global _mod
if _mod is None:
_mod = load_inline(
name="mrope_ext",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
return _modEdit applied successfully.{
"command": "rm -rf /home/infatoshi/.cache/torch_extensions/py311_cu130/mrope_ext/ 2>/dev/null; uv run check.py 2>&1",
"description": "Clear cache and run check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -95,7 +95,7 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_ext",
+ name="mrope_ext2",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_forward"],
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " name=\"mrope_ext\",",
"newString": " name=\"mrope_ext2\","
}Edit applied successfully.{
"command": "uv run check.py 2>&1",
"description": "Run check with new extension name",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 141, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 97, in _get_mod
_mod = load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2294, in _jit_compile
return _import_module_from_library(name, build_directory, is_python_module)
... (704 more chars){
"command": "ls /home/infatoshi/cuda/KernelBench-Hard/problems/",
"description": "List all problems"
}01_fp8_gemm
02_kda_cutlass
03_paged_attention
04_kahan_softmax
05_topk_bitonic
06_sonic_moe_swiglu
07_w4a16_gemm
08_metal_lightning_attn
09_fmha_preattn_mrope
10_patch_embed_conv3d_gemm
{
"command": "for d in /home/infatoshi/cuda/KernelBench-Hard/problems/*/; do if [ -f \"$d/solution.py\" ] && grep -q \"load_inline\" \"$d/solution.py\" 2>/dev/null; then echo \"$d\"; fi; done",
"description": "Find solutions using load_inline"
}/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/
{
"command": "for d in /home/infatoshi/cuda/KernelBench-Hard/problems/*/; do if [ -f \"$d/solution.py\" ]; then echo \"=== $(basename $d) ===\" && head -5 \"$d/solution.py\"; fi; done",
"description": "Check solution patterns"
}=== 09_fmha_preattn_mrope ===
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
_CPP_DECL = r"""
{
"command": "uv run python -c \"import triton; print(triton.__version__)\" 2>&1",
"description": "Check if triton is available",
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}3.6.0
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -18,9 +18,9 @@
__global__ void mrope_rotate(
const __nv_bfloat16* __restrict__ x_in,
__nv_bfloat16* __restrict__ x_out,
- const long long* __restrict__ pos_t,
- const long long* __restrict__ pos_h,
- const long long* __restrict__ pos_w,
+ const int64_t* __restrict__ pos_t,
+ const int64_t* __restrict__ pos_h,
+ const int64_t* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
const int B, const int T, const int H, const int D,
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<122 chars \u2014 see below>",
"newString": "<116 chars \u2014 see below>"
} const long long* __restrict__ pos_t,
const long long* __restrict__ pos_h,
const long long* __restrict__ pos_w, const int64_t* __restrict__ pos_t,
const int64_t* __restrict__ pos_h,
const int64_t* __restrict__ pos_w,Edit applied successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: from torch.utils.cpp_extension import load_inline
4:
5: _CPP_DECL = r"""
6: #include <torch/extension.h>
7: void mrope_forward(
8: torch::Tensor x_in, torch::Tensor x_out,
9: torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
10: torch::Tensor cos_cache, torch::Tensor sin_cache,
11: int B, int T, int H, int D, int s0, int s1, int s2);
12: """
13:
14: _CUDA_SRC = r"""
15: #include <torch/extension.h>
16: #include <cuda_bf16.h>
17:
18: __global__ void mrope_rotate(
19: const __nv_bfloat16* __restrict__ x_in,
20: __nv_bfloat16* __restrict__ x_out,
21: const int64_t* __restrict__ pos_t,
22: const int64_t* __restrict__ pos_h,
23: const int64_t* __restrict__ pos_w,
24: const __nv_bfloat16* __restrict__ cos_cache,
25: const __nv_bfloat16* __restrict__ sin_cache,
26: const int B, const int T, const int H, const int D,
27: const int s0, const int s1, const int s2
28: ) {
29: const int idx = blockIdx.x * blockDim.x + threadIdx.x;
30: const int total = B * H * T * D;
31: if (idx >= total) return;
32:
33: const int d = idx % D;
34: const int tmp = idx / D;
35: const int t = tmp % T;
36: const int tmp2 = tmp / T;
37: const int h = tmp2 % H;
38: const int b = tmp2 / H;
39:
40: const int half_D = D >> 1;
41: const in
... (4944 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -2,18 +2,10 @@
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
-_CPP_DECL = r"""
-#include <torch/extension.h>
-void mrope_forward(
- torch::Tensor x_in, torch::Tensor x_out,
- torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
- torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int H, int D, int s0, int s1, int s2);
-"""
-
_CUDA_SRC = r"""
#include <torch/extension.h>
#include <cuda_bf16.h>
+#include <cstdint>
__global__ void mrope_rotate(
const __nv_bfloat16* __restrict__ x_in,
@@ -39,9 +31,9 @@
const int half_D = D >> 1;
const int d_mod = d < half_D ? d : d - half_D;
- const long long bt = (long long)b * T + t;
+ const int64_t bt = (int64_t)b * T + t;
- long long pos;
+ int64_t pos;
if (d_mod < s0) pos = pos_t[bt];
else if (d_mod < s0 + s1) pos = pos_h[bt];
else pos = pos_w[bt];
@@ -49,9 +41,9 @@
const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
- const long long in_base = (long long)b * T * H * D
- + (long long)t * H * D
- + (long long)h * D;
+ const int64_t in_base = (int64_t)b * T * H * D
+ + (int64_t)t * H * D
+ + (int64_t)h * D;
const float x_val = __bfloat162float(x_in[in_base + d]);
const int pd = d < half_D ? d + half_D : d - half_D;
@@ -60,9 +52,9 @@
const float res = x_val * cos_v + rh * sin_v;
- const long long oi = (long long)b * H * T * D
- + (long long)h * T * D
- + (long long)t * D + d;
+ const int64_t oi = (int64_t)b * H * T * D
+ + (int64_t)h * T * D
+ + (int64_t)t * D + d;
x_out[oi] = __float2bfloat16(res);
}
@@ -79,14 +71,24 @@
mrope_rotate<<<blocks, THR>>>(
reinterpret_cast<const __nv_bfloat16*>(x_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(x_out.data_ptr<at::BFloat16>()),
- pos_t.data_ptr<long long>(),
- pos_h.data_ptr<long long>(),
- pos_w.data_ptr<long long>(),
+ pos_t.data_ptr<int64_t>(),
+ pos_h.data_ptr<int64_t>(),
+ pos_w.data_ptr<int64_t>(),
reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
B, T, H, D, s0, s1, s2
);
}
+"""
+
+_CPP_DECL = r"""
+#include <torch/extension.h>
+#include <cstdint>
+void mrope_forward(
+ torch::Tensor x_in, torch::Tensor x_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int H, int D, int s0, int s1, int s2);
"""
_mod = None
@@ -95,7 +97,7 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_ext2",
+ name="mrope_ext3",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_forward"],
{
"content": "<5524 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1",
"description": "Run correctness check",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
{
"todos": [
{
"content": "Read reference.py, check.py, benchmark.py, problem.yaml to understand the op",
"priority": "high",
"status": "completed"
},
{
"content": "Read and understand the mRoPE math from transformers Qwen2-VL",
"priority": "high",
"status": "completed"
},
{
"content": "Implement initial fused CUDA kernel in solution.py",
"priority": "high",
"status": "completed"
},
{
"content": "Verify correctness with check.py",
"priority": "high",
"status": "completed"
},
{
"c
... (237 more chars)content: 'Read reference.py, check.py, benchmark.py, problem.yaml to understand the op'
status: 'completed'
priority: 'high'
content: 'Read and understand the mRoPE math from transformers Qwen2-VL'
status: 'completed'
priority: 'high'
content: 'Implement initial fused CUDA kernel in solution.py'
status: 'completed'
priority: 'high'
content: 'Verify correctness with check.py'
status: 'completed'
priority: 'high'
content: 'Profile with ncu/nsys and optimize'
status: 'in_progress'
priority: 'medium'
content: 'Benchmark and squeeze performance'
status: 'pending'
priority: 'medium'{
"command": "uv run benchmark.py 2>&1",
"description": "Run performance benchmark",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}shape=0 variant=eager tflops=0.177 gbps=253.823 ms=0.356
shape=0 variant=compiled tflops=0.481 gbps=690.689 ms=0.131
shape=0 variant=solution tflops=0.592 gbps=849.351 ms=0.106
shape=0 solution_peak_fraction=0.4719
shape=1 variant=eager tflops=0.180 gbps=262.510 ms=0.560
shape=1 variant=compiled tflops=0.492 gbps=718.715 ms=0.205
shape=1 variant=solution tflops=0.624 gbps=910.538 ms=0.161
shape=1 solution_peak_fraction=0.5059
shape=2 variant=eager tflops=0.071 gbps=111.215 ms=0.199
shape=2 variant=compiled tflops=0.167 gbps=260.978 ms=0.085
shape=2 variant=solution tflops=0.426 gbps=665.575 ms=0.033
shape=2 solution_peak_fraction=0.3698
shape=3 variant=eager tflops=0.172 gbps=246.647 ms=1.464
shape=3 variant=compiled tflops=0.466 gbps=668.017 ms=0.541
shape=3 variant=solution tflops=0.683 gbps=979.386 ms=0.369
shape=3 solution_peak_fraction=0.5441
peak_fraction: 0.4681
RESULT: OK
{
"command": "<731 chars \u2014 see below>",
"description": "Profile with ncu",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}ncu --set basic --target-processes all -k mrope_rotate -o /tmp/opencode/ncu_report python -c "
import torch, sys
sys.path.insert(0, '.')
import solution
B, T, Hq, Hkv, D = 1, 4096, 32, 8, 128
m = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 32768).cuda().eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device='cuda') * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device='cuda') * 0.5
pt = torch.randint(0, 32768, (B, T), dtype=torch.int64, device='cuda')
ph = torch.randint(0, 32768, (B, T), dtype=torch.int64, device='cuda')
pw = torch.randint(0, 32768, (B, T), dtype=torch.int64, device='cuda')
with torch.no_grad():
for _ in range(3): m(q, k, pt, ph, pw)
torch.cuda.synchronize()
" 2>&1 | tail -40==PROF== Connected to process 1768505 (/usr/bin/python3.12)
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Profiling "mrope_rotate": 0%....50%....100% - 9 passes
==PROF== Disconnected from process 1768505
==PROF== Report: /tmp/opencode/ncu_report.ncu-rep
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,88 +7,90 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void mrope_rotate(
- const __nv_bfloat16* __restrict__ x_in,
- __nv_bfloat16* __restrict__ x_out,
+__global__ void __launch_bounds__(256)
+mrope_fused(
+ const __nv_bfloat16* __restrict__ q_in,
+ const __nv_bfloat16* __restrict__ k_in,
+ __nv_bfloat16* __restrict__ q_out,
+ __nv_bfloat16* __restrict__ k_out,
const int64_t* __restrict__ pos_t,
const int64_t* __restrict__ pos_h,
const int64_t* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
- const int B, const int T, const int H, const int D,
+ const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- const int total = B * H * T * D;
- if (idx >= total) return;
+ extern __shared__ float smem[];
- const int d = idx % D;
- const int tmp = idx / D;
- const int t = tmp % T;
- const int tmp2 = tmp / T;
- const int h = tmp2 % H;
- const int b = tmp2 / H;
+ const int bt = blockIdx.x;
+ const int h = blockIdx.y;
+ const int d = threadIdx.x;
+ if (h >= Hq || d >= D) return;
+
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
+
const int d_mod = d < half_D ? d : d - half_D;
- const int64_t bt = (int64_t)b * T + t;
+ const int64_t bt_idx = (int64_t)b * T + t;
int64_t pos;
- if (d_mod < s0) pos = pos_t[bt];
- else if (d_mod < s0 + s1) pos = pos_h[bt];
- else pos = pos_w[bt];
+ if (d_mod < s0) pos = pos_t[bt_idx];
+ else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
+ else pos = pos_w[bt_idx];
const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
- const int64_t in_base = (int64_t)b * T * H * D
- + (int64_t)t * H * D
- + (int64_t)h * D;
- const float x_val = __bfloat162float(x_in[in_base + d]);
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float sign = d < half_D ? -1.0f : 1.0f;
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float pval = __bfloat162float(x_in[in_base + pd]);
- const float rh = d < half_D ? -pval : pval;
+ // Process q: load into smem, sync, read self+partner from smem
+ {
+ const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(q_in[in_base + d]);
+ __syncthreads();
- const float res = x_val * cos_v + rh * sin_v;
+ const float x_val = smem[d];
+ const float pval = smem[pd];
+ const float rh = sign * pval;
+ const float res = x_val * cos_v + rh * sin_v;
- const int64_t oi = (int64_t)b * H * T * D
- + (int64_t)h * T * D
- + (int64_t)t * D + d;
- x_out[oi] = __float2bfloat16(res);
-}
+ const int64_t oi = ((int64_t)b * Hq + h) * T * D + (int64_t)t * D + d;
+ q_out[oi] = __float2bfloat16(res);
+ }
-void mrope_forward(
- torch::Tensor x_in, torch::Tensor x_out,
- torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
- torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int H, int D, int s0, int s1, int s2
-) {
- const int total = B * H * T * D;
- const int THR = 256;
- const int blocks = (total + THR - 1) / THR;
+ // Barrier: all threads done reading q from smem
+ __syncthreads();
- mrope_rotate<<<blocks, THR>>>(
- reinterpret_cast<const __nv_bfloat16*>(x_in.data_ptr<at::BFloat16>()),
- reinterpret_cast<__nv_bfloat16*>(x_out.data_ptr<at::BFloat16>()),
- pos_t.data_ptr<int64_t>(),
- pos_h.data_ptr<int64_t>(),
- pos_w.data_ptr<int64_t>(),
- reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
- reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
- B, T, H, D, s0, s1, s2
- );
+ // Process k
+ if (h < Hkv) {
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(k_in[in_base + d]);
+ __syncthreads();
+
+ const float x_val = smem[d];
+ const float pval = smem[pd];
+ const float rh = sign * pval;
+ const float res = x_val * cos_v + rh * sin_v;
+
+ const int64_t oi = ((int64_t)b * Hkv + h) * T * D + (int64_t)t * D + d;
+ k_out[oi] = __float2bfloat16(res);
+ }
}
"""
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_forward(
- torch::Tensor x_in, torch::Tensor x_out,
+void mrope_fused(
+ torch::Tensor q_in, torch::Tensor k_in,
+ torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int H, int D, int s0, int s1, int s2);
+ int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2);
"""
_mod = None
@@ -97,10 +99,10 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_ext3",
+ name="mrope_fused_v1",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
- functions=["mrope_forward"],
+ functions=["mrope_fused"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
@@ -142,13 +144,15 @@
mod = _get_mod()
- mod.mrope_forward(q, q_out, pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
-
... (564 more chars){
"content": "<5390 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1",
"description": "Verify fused kernel correctness",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}Traceback (most recent call last):
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 119, in <module>
main()
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/check.py", line 76, in main
sol_out = sol_model(*inputs)
^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 145, in forward
mod = _get_mod()
^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py", line 101, in _get_mod
_mod = load_inline(
^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2174, in load_inline
return _jit_compile(
^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2294, in _jit_compile
return _import_module_from_library(name, build_directory, is_python_module)
... (736 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,12 +3,10 @@
from torch.utils.cpp_extension import load_inline
_CUDA_SRC = r"""
-#include <torch/extension.h>
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void __launch_bounds__(256)
-mrope_fused(
+__global__ void mrope_fused(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -47,7 +45,7 @@
const int pd = d < half_D ? d + half_D : d - half_D;
const float sign = d < half_D ? -1.0f : 1.0f;
- // Process q: load into smem, sync, read self+partner from smem
+ // Process q
{
const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
smem[d] = __bfloat162float(q_in[in_base + d]);
@@ -55,26 +53,23 @@
const float x_val = smem[d];
const float pval = smem[pd];
- const float rh = sign * pval;
- const float res = x_val * cos_v + rh * sin_v;
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
const int64_t oi = ((int64_t)b * Hq + h) * T * D + (int64_t)t * D + d;
q_out[oi] = __float2bfloat16(res);
}
- // Barrier: all threads done reading q from smem
__syncthreads();
// Process k
if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * d;
smem[d] = __bfloat162float(k_in[in_base + d]);
__syncthreads();
const float x_val = smem[d];
const float pval = smem[pd];
- const float rh = sign * pval;
- const float res = x_val * cos_v + rh * sin_v;
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
const int64_t oi = ((int64_t)b * Hkv + h) * T * D + (int64_t)t * D + d;
k_out[oi] = __float2bfloat16(res);
@@ -82,15 +77,43 @@
}
"""
-_CPP_DECL = r"""
+_CPP_SRC = r"""
#include <torch/extension.h>
+#include <cuda_bf16.h>
#include <cstdint>
+
+// Forward declaration of CUDA kernel
void mrope_fused(
+ const __nv_bfloat16*, const __nv_bfloat16*,
+ __nv_bfloat16*, __nv_bfloat16*,
+ const int64_t*, const int64_t*, const int64_t*,
+ const __nv_bfloat16*, const __nv_bfloat16*,
+ int, int, int, int, int, int, int, int);
+
+void mrope_fused_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2);
+ int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
+) {
+ const dim3 grid(B * T, Hq);
+ const dim3 block(D);
+ const int smem = D * (int)sizeof(float);
+
+ mrope_fused<<<grid, block, smem>>>(
+ reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(k_out.data_ptr<at::BFloat16>()),
+ pos_t.data_ptr<int64_t>(),
+ pos_h.data_ptr<int64_t>(),
+ pos_w.data_ptr<int64_t>(),
+ reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
+ B, T, Hq, Hkv, D, s0, s1, s2
+ );
+}
"""
_mod = None
@@ -99,10 +122,10 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_fused_v1",
- cpp_sources=_CPP_DECL,
+ name="mrope_fused_v2",
+ cpp_sources=_CPP_SRC,
cuda_sources=_CUDA_SRC,
- functions=["mrope_fused"],
+ functions=["mrope_fused_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
@@ -143,16 +166,10 @@
k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
mod = _get_mod()
-
- grid = (B * T, Hq)
- block = (D,)
- smem = D * 4 # D floats
-
- mod.mrope_fused(grid, block, smem,
- q, k, q_out, k_out,
- pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B, T, Hq, Hkv, D, s0, s1, s2)
+ mod.mrope_fused_forward(q, k, q_out, k_out,
+ pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hq, Hkv, D, s0, s1, s2)
return q_out, k_out
{
"content": "<6172 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,6 +3,7 @@
from torch.utils.cpp_extension import load_inline
_CUDA_SRC = r"""
+#include <torch/extension.h>
#include <cuda_bf16.h>
#include <cstdint>
@@ -55,7 +56,7 @@
const float pval = smem[pd];
const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t oi = ((int64_t)b * Hq + h) * T * D + (int64_t)t * D + d;
+ const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
q_out[oi] = __float2bfloat16(res);
}
@@ -63,7 +64,7 @@
// Process k
if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * d;
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
smem[d] = __bfloat162float(k_in[in_base + d]);
__syncthreads();
@@ -71,24 +72,10 @@
const float pval = smem[pd];
const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t oi = ((int64_t)b * Hkv + h) * T * D + (int64_t)t * D + d;
+ const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
k_out[oi] = __float2bfloat16(res);
}
}
-"""
-
-_CPP_SRC = r"""
-#include <torch/extension.h>
-#include <cuda_bf16.h>
-#include <cstdint>
-
-// Forward declaration of CUDA kernel
-void mrope_fused(
- const __nv_bfloat16*, const __nv_bfloat16*,
- __nv_bfloat16*, __nv_bfloat16*,
- const int64_t*, const int64_t*, const int64_t*,
- const __nv_bfloat16*, const __nv_bfloat16*,
- int, int, int, int, int, int, int, int);
void mrope_fused_forward(
torch::Tensor q_in, torch::Tensor k_in,
@@ -99,7 +86,7 @@
) {
const dim3 grid(B * T, Hq);
const dim3 block(D);
- const int smem = D * (int)sizeof(float);
+ const size_t smem = D * sizeof(float);
mrope_fused<<<grid, block, smem>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
@@ -116,14 +103,25 @@
}
"""
+_CPP_DECL = r"""
+#include <torch/extension.h>
+#include <cstdint>
+void mrope_fused_forward(
+ torch::Tensor q_in, torch::Tensor k_in,
+ torch::Tensor q_out, torch::Tensor k_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2);
+"""
+
_mod = None
def _get_mod():
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_fused_v2",
- cpp_sources=_CPP_SRC,
+ name="mrope_fused_v3",
+ cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_fused_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
{
"content": "<6214 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1",
"description": "Check correctness of fused kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
{
"command": "uv run benchmark.py 2>&1",
"description": "Benchmark fused kernel",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}shape=0 variant=eager tflops=0.164 gbps=235.968 ms=0.383
shape=0 variant=compiled tflops=0.669 gbps=960.218 ms=0.094
shape=0 variant=solution tflops=0.759 gbps=1089.025 ms=0.083
shape=0 solution_peak_fraction=0.6050
shape=1 variant=eager tflops=0.209 gbps=305.684 ms=0.481
shape=1 variant=compiled tflops=0.672 gbps=981.238 ms=0.150
shape=1 variant=solution tflops=0.753 gbps=1099.094 ms=0.134
shape=1 solution_peak_fraction=0.6106
shape=2 variant=eager tflops=0.075 gbps=117.063 ms=0.189
shape=2 variant=compiled tflops=0.169 gbps=263.766 ms=0.084
shape=2 variant=solution tflops=0.312 gbps=487.447 ms=0.045
shape=2 solution_peak_fraction=0.2708
shape=3 variant=eager tflops=0.178 gbps=254.901 ms=1.417
shape=3 variant=compiled tflops=0.528 gbps=757.475 ms=0.477
shape=3 variant=solution tflops=0.835 gbps=1198.691 ms=0.301
shape=3 solution_peak_fraction=0.6659
peak_fraction: 0.5080
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -17,96 +17,84 @@
const int64_t* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
- const int B, const int T, const int Hq, const int Hkv, const int D,
- const int s0, const int s1, const int s2
+ const int B, const int T, const int Hq, const int Hkv, const int D
) {
extern __shared__ float smem[];
- const int bt = blockIdx.x;
- const int h = blockIdx.y;
- const int d = threadIdx.x;
-
- if (h >= Hq || d >= D) return;
-
- const int b = bt / T;
- const int t = bt % T;
+ const int groups_per_block = blockDim.x / D;
+ const int group_idx = threadIdx.x / D;
+ const int d = threadIdx.x % D;
const int half_D = D >> 1;
+ const int global_group = blockIdx.x * groups_per_block + group_idx;
+ const int active = global_group < B * T * Hq;
+
+ int bt = 0, h = 0, b = 0, t = 0;
+ if (active) {
+ bt = global_group / Hq;
+ h = global_group % Hq;
+ b = bt / T;
+ t = bt % T;
+ }
+
+ // Axis determination
const int d_mod = d < half_D ? d : d - half_D;
- const int64_t bt_idx = (int64_t)b * T + t;
-
- int64_t pos;
- if (d_mod < s0) pos = pos_t[bt_idx];
- else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
- else pos = pos_w[bt_idx];
-
- const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
- const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
+ float cos_v = 0.0f, sin_v = 0.0f;
+ if (active) {
+ const int64_t bt_idx = (int64_t)b * T + t;
+ int64_t pos;
+ if (d_mod < s0) pos = pos_t[bt_idx];
+ else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
+ else pos = pos_w[bt_idx];
+ cos_v = __bfloat162float(cos_cache[pos * D + d]);
+ sin_v = __bfloat162float(sin_cache[pos * D + d]);
+ }
const int pd = d < half_D ? d + half_D : d - half_D;
const float sign = d < half_D ? -1.0f : 1.0f;
- // Process q
- {
+ // Separate smem buffers for q and k per group
+ float* s_q = smem + group_idx * D * 2;
+ float* s_k = s_q + D;
+
+ // Load q
+ if (active) {
const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- smem[d] = __bfloat162float(q_in[in_base + d]);
- __syncthreads();
+ s_q[d] = __bfloat162float(q_in[in_base + d]);
+ }
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
+ // Load k
+ if (active && h < Hkv) {
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ s_k[d] = __bfloat162float(k_in[in_base + d]);
+ }
+ __syncthreads();
+
+ // Compute q rotation
+ if (active) {
+ const float x_val = s_q[d];
+ const float pval = s_q[pd];
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
q_out[oi] = __float2bfloat16(res);
}
- __syncthreads();
-
- // Process k
- if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- smem[d] = __bfloat162float(k_in[in_base + d]);
- __syncthreads();
-
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
-
+ // Compute k rotation
+ if (active && h < Hkv) {
+ const float x_val = s_k[d];
+ const float pval = s_k[pd];
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
k_out[oi] = __float2bfloat16(res);
}
-}
-
-void mrope_fused_forward(
- torch::Tensor q_in, torch::Tensor k_in,
- torch::Tensor q_out, torch::Tensor k_out,
- torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
- torch::Tensor cos_cache, torch::Tensor sin_cache,
- int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
-) {
- const dim3 grid(B * T, Hq);
- const dim3 block(D);
- const size_t smem = D * sizeof(float);
-
- mrope_fused<<<grid, block, smem>>>(
- reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
- reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
- reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
- reinterpret_cast<__nv_bfloat16*>(k_out.data_ptr<at::BFloat16>()),
- pos_t.data_ptr<int64_t>(),
- pos_h.data_ptr<int64_t>(),
- pos_w.data_ptr<int64_t>(),
- reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
- reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
- B, T, Hq, Hkv, D, s0, s1, s2
- );
}
"""
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_fused_forward(
+void mrope_fused(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
@@ -120,10 +108,10 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_fused_v3",
+ name="mrope_fused_v4",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
- functions=["mrope_fused_forward"],
+ functions=["mrope_fused"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
@@ -164,10 +152,18 @@
k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
mod = _get_mod()
- mod.mrope_fused_forward(q, k, q_out, k_out,
-
... (718 more chars){
"content": "<5865 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: import torch
2: import torch.nn as nn
3: from torch.utils.cpp_extension import load_inline
4:
5: _CUDA_SRC = r"""
6: #include <torch/extension.h>
7: #include <cuda_bf16.h>
8: #include <cstdint>
9:
10: __global__ void mrope_fused(
11: const __nv_bfloat16* __restrict__ q_in,
12: const __nv_bfloat16* __restrict__ k_in,
13: __nv_bfloat16* __restrict__ q_out,
14: __nv_bfloat16* __restrict__ k_out,
15: const int64_t* __restrict__ pos_t,
16: const int64_t* __restrict__ pos_h,
17: const int64_t* __restrict__ pos_w,
18: const __nv_bfloat16* __restrict__ cos_cache,
19: const __nv_bfloat16* __restrict__ sin_cache,
20: const int B, const int T, const int Hq, const int Hkv, const int D
21: ) {
22: extern __shared__ float smem[];
23:
24: const int groups_per_block = blockDim.x / D;
25: const int group_idx = threadIdx.x / D;
26: const int d = threadIdx.x % D;
27: const int half_D = D >> 1;
28:
29: const int global_group = blockIdx.x * groups_per_block + group_idx;
30: const int active = global_group < B * T * Hq;
31:
32: int bt = 0, h = 0, b = 0, t = 0;
33: if (active) {
34: bt = global_group / Hq;
35: h = global_group % Hq;
36: b = bt / T;
37: t = bt % T;
38: }
39:
40: // Axis determination
41: const int d_mod = d < half_D ? d : d - half_D;
42
... (5372 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -17,7 +17,8 @@
const int64_t* __restrict__ pos_w,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
- const int B, const int T, const int Hq, const int Hkv, const int D
+ const int B, const int T, const int Hq, const int Hkv, const int D,
+ const int s0, const int s1, const int s2
) {
extern __shared__ float smem[];
@@ -37,7 +38,6 @@
t = bt % T;
}
- // Axis determination
const int d_mod = d < half_D ? d : d - half_D;
float cos_v = 0.0f, sin_v = 0.0f;
if (active) {
@@ -53,17 +53,13 @@
const int pd = d < half_D ? d + half_D : d - half_D;
const float sign = d < half_D ? -1.0f : 1.0f;
- // Separate smem buffers for q and k per group
float* s_q = smem + group_idx * D * 2;
float* s_k = s_q + D;
- // Load q
if (active) {
const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
s_q[d] = __bfloat162float(q_in[in_base + d]);
}
-
- // Load k
if (active && h < Hkv) {
const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
s_k[d] = __bfloat162float(k_in[in_base + d]);
@@ -71,7 +67,6 @@
__syncthreads();
- // Compute q rotation
if (active) {
const float x_val = s_q[d];
const float pval = s_q[pd];
@@ -80,7 +75,6 @@
q_out[oi] = __float2bfloat16(res);
}
- // Compute k rotation
if (active && h < Hkv) {
const float x_val = s_k[d];
const float pval = s_k[pd];
@@ -89,12 +83,39 @@
k_out[oi] = __float2bfloat16(res);
}
}
+
+void mrope_fused_forward(
+ torch::Tensor q_in, torch::Tensor k_in,
+ torch::Tensor q_out, torch::Tensor k_out,
+ torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
+ torch::Tensor cos_cache, torch::Tensor sin_cache,
+ int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
+) {
+ const int groups_per_block = max(1, 256 / D);
+ const int total_groups = B * T * Hq;
+ const int grid = (total_groups + groups_per_block - 1) / groups_per_block;
+ const int block = groups_per_block * D;
+ const size_t smem = groups_per_block * D * 2 * sizeof(float);
+
+ mrope_fused<<<grid, block, smem>>>(
+ reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
+ reinterpret_cast<__nv_bfloat16*>(k_out.data_ptr<at::BFloat16>()),
+ pos_t.data_ptr<int64_t>(),
+ pos_h.data_ptr<int64_t>(),
+ pos_w.data_ptr<int64_t>(),
+ reinterpret_cast<const __nv_bfloat16*>(cos_cache.data_ptr<at::BFloat16>()),
+ reinterpret_cast<const __nv_bfloat16*>(sin_cache.data_ptr<at::BFloat16>()),
+ B, T, Hq, Hkv, D, s0, s1, s2
+ );
+}
"""
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_fused(
+void mrope_fused_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
@@ -108,10 +129,10 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_fused_v4",
+ name="mrope_fused_v5",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
- functions=["mrope_fused"],
+ functions=["mrope_fused_forward"],
extra_cuda_cflags=["-O3", "--use_fast_math"],
verbose=False,
)
@@ -152,18 +173,10 @@
k_out = torch.empty(B, Hkv, T, D, dtype=k.dtype, device=k.device)
mod = _get_mod()
-
- groups_per_block = max(1, 256 // D)
- total_groups = B * T * Hq
- grid = ((total_groups + groups_per_block - 1) // groups_per_block,)
- block = (groups_per_block * D,)
- smem = groups_per_block * D * 2 * 4 # 2 float buffers per group
-
- mod.mrope_fused(grid, block, smem,
- q, k, q_out, k_out,
- pos_t, pos_h, pos_w,
- self.cos_cache, self.sin_cache,
- B, T, Hq, Hkv, D, s0, s1, s2)
+ mod.mrope_fused_forward(q, k, q_out, k_out,
+ pos_t, pos_h, pos_w,
+ self.cos_cache, self.sin_cache,
+ B, T, Hq, Hkv, D, s0, s1, s2)
return q_out, k_out
{
"content": "<6742 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -91,7 +91,7 @@
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const int groups_per_block = max(1, 256 / D);
+ const int groups_per_block = (256 / D) > 1 ? (256 / D) : 1;
const int total_groups = B * T * Hq;
const int grid = (total_groups + groups_per_block - 1) / groups_per_block;
const int block = groups_per_block * D;
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " const int groups_per_block = max(1, 256 / D);",
"newString": " const int groups_per_block = (256 / D) > 1 ? (256 / D) : 1;"
}Edit applied successfully.{
"command": "uv run check.py 2>&1",
"description": "Run correctness check on multi-group fused kernel",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
{
"command": "uv run benchmark.py 2>&1",
"description": "Benchmark multi-group fused kernel",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}shape=0 variant=eager tflops=0.180 gbps=258.037 ms=0.350
shape=0 variant=compiled tflops=0.492 gbps=705.456 ms=0.128
shape=0 variant=solution tflops=0.431 gbps=618.463 ms=0.146
shape=0 solution_peak_fraction=0.3436
shape=1 variant=eager tflops=0.181 gbps=264.505 ms=0.556
shape=1 variant=compiled tflops=0.469 gbps=685.468 ms=0.214
shape=1 variant=solution tflops=0.423 gbps=618.051 ms=0.238
shape=1 solution_peak_fraction=0.3434
shape=2 variant=eager tflops=0.063 gbps=98.056 ms=0.226
shape=2 variant=compiled tflops=0.169 gbps=263.565 ms=0.084
shape=2 variant=solution tflops=0.361 gbps=563.325 ms=0.039
shape=2 solution_peak_fraction=0.3130
shape=3 variant=eager tflops=0.170 gbps=244.478 ms=1.477
shape=3 variant=compiled tflops=0.466 gbps=668.571 ms=0.540
shape=3 variant=solution tflops=0.493 gbps=707.602 ms=0.510
shape=3 solution_peak_fraction=0.3931
peak_fraction: 0.3471
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -22,63 +22,56 @@
) {
extern __shared__ float smem[];
- const int groups_per_block = blockDim.x / D;
- const int group_idx = threadIdx.x / D;
- const int d = threadIdx.x % D;
+ const int bt = blockIdx.x;
+ const int h = blockIdx.y;
+ const int d = threadIdx.x;
+
+ if (h >= Hq || d >= D) return;
+
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
- const int global_group = blockIdx.x * groups_per_block + group_idx;
- const int active = global_group < B * T * Hq;
+ const int d_mod = d < half_D ? d : d - half_D;
+ const int64_t bt_idx = (int64_t)b * T + t;
- int bt = 0, h = 0, b = 0, t = 0;
- if (active) {
- bt = global_group / Hq;
- h = global_group % Hq;
- b = bt / T;
- t = bt % T;
- }
+ int64_t pos;
+ if (d_mod < s0) pos = pos_t[bt_idx];
+ else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
+ else pos = pos_w[bt_idx];
- const int d_mod = d < half_D ? d : d - half_D;
- float cos_v = 0.0f, sin_v = 0.0f;
- if (active) {
- const int64_t bt_idx = (int64_t)b * T + t;
- int64_t pos;
- if (d_mod < s0) pos = pos_t[bt_idx];
- else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
- else pos = pos_w[bt_idx];
- cos_v = __bfloat162float(cos_cache[pos * D + d]);
- sin_v = __bfloat162float(sin_cache[pos * D + d]);
- }
+ const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
+ const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
const int pd = d < half_D ? d + half_D : d - half_D;
const float sign = d < half_D ? -1.0f : 1.0f;
- float* s_q = smem + group_idx * D * 2;
- float* s_k = s_q + D;
+ // Process q
+ {
+ const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(q_in[in_base + d]);
+ __syncthreads();
- if (active) {
- const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- s_q[d] = __bfloat162float(q_in[in_base + d]);
- }
- if (active && h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- s_k[d] = __bfloat162float(k_in[in_base + d]);
+ const float x_val = smem[d];
+ const float pval = smem[pd];
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
+
+ const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
+ q_out[oi] = __float2bfloat16(res);
}
__syncthreads();
- if (active) {
- const float x_val = s_q[d];
- const float pval = s_q[pd];
+ // Process k
+ if (h < Hkv) {
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(k_in[in_base + d]);
+ __syncthreads();
+
+ const float x_val = smem[d];
+ const float pval = smem[pd];
const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
- q_out[oi] = __float2bfloat16(res);
- }
- if (active && h < Hkv) {
- const float x_val = s_k[d];
- const float pval = s_k[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
k_out[oi] = __float2bfloat16(res);
}
@@ -91,11 +84,9 @@
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const int groups_per_block = (256 / D) > 1 ? (256 / D) : 1;
- const int total_groups = B * T * Hq;
- const int grid = (total_groups + groups_per_block - 1) / groups_per_block;
- const int block = groups_per_block * D;
- const size_t smem = groups_per_block * D * 2 * sizeof(float);
+ const dim3 grid(B * T, Hq);
+ const dim3 block(D);
+ const size_t smem = D * sizeof(float);
mrope_fused<<<grid, block, smem>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
@@ -129,7 +120,7 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="mrope_fused_v5",
+ name="mrope_fused_v6",
cpp_sources=_CPP_DECL,
cuda_sources=_CUDA_SRC,
functions=["mrope_fused_forward"],
{
"content": "<6208 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1 && echo \"---\" && uv run benchmark.py 2>&1",
"description": "Check and benchmark v6",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
---
shape=0 variant=eager tflops=0.182 gbps=261.712 ms=0.345
shape=0 variant=compiled tflops=0.479 gbps=687.826 ms=0.131
shape=0 variant=solution tflops=0.664 gbps=952.277 ms=0.095
shape=0 solution_peak_fraction=0.5290
shape=1 variant=eager tflops=0.182 gbps=265.760 ms=0.553
shape=1 variant=compiled tflops=0.477 gbps=696.749 ms=0.211
shape=1 variant=solution tflops=0.724 gbps=1056.865 ms=0.139
shape=1 solution_peak_fraction=0.5871
shape=2 variant=eager tflops=0.071 gbps=111.161 ms=0.199
shape=2 variant=compiled tflops=0.170 gbps=265.897 ms=0.083
shape=2 variant=solution tflops=0.312 gbps=486.932 ms=0.045
shape=2 solution_peak_fraction=0.2705
shape=3 variant=eager tflops=0.172 gbps=247.331 ms=1.460
shape=3 variant=compiled tflops=0.464 gbps=665.260 ms=0.543
shape=3 variant=solution tflops=0.806 gbps=1156.315 ms=0.312
shape=3 solution_peak_fraction=0.6424
peak_fraction: 0.4820
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,7 +7,7 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void mrope_fused(
+__global__ void mrope_fused_vec(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -20,75 +20,87 @@
const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- extern __shared__ float smem[];
-
- const int bt = blockIdx.x;
- const int h = blockIdx.y;
- const int d = threadIdx.x;
-
- if (h >= Hq || d >= D) return;
-
- const int b = bt / T;
- const int t = bt % T;
+ const int warp_id = threadIdx.x / 32;
+ const int lane_id = threadIdx.x % 32;
+ const int rows_per_block = blockDim.x / 32;
+ const int row = blockIdx.x * rows_per_block + warp_id;
+
+ if (row >= B * Hq * T) return;
+
+ const int bt = row / Hq;
+ const int h = row % Hq;
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
-
- const int d_mod = d < half_D ? d : d - half_D;
+ const int ept = D / 32; // elements per thread: 2 for D=64, 4 for D=128
+
+ const int d_base = lane_id * ept;
+ const int partner_lane = (lane_id + 16) % 32;
const int64_t bt_idx = (int64_t)b * T + t;
-
- int64_t pos;
- if (d_mod < s0) pos = pos_t[bt_idx];
- else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
- else pos = pos_w[bt_idx];
-
- const float cos_v = __bfloat162float(cos_cache[pos * D + d]);
- const float sin_v = __bfloat162float(sin_cache[pos * D + d]);
-
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float sign = d < half_D ? -1.0f : 1.0f;
-
- // Process q
- {
- const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- smem[d] = __bfloat162float(q_in[in_base + d]);
- __syncthreads();
-
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
-
- const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
- q_out[oi] = __float2bfloat16(res);
+ const float sign = lane_id < 16 ? -1.0f : 1.0f;
+
+ // Load q values
+ const int64_t q_base = (int64_t)bt * Hq * D + (int64_t)h * D;
+ float qx[4];
+ for (int i = 0; i < ept; i++)
+ qx[i] = __bfloat162float(q_in[q_base + d_base + i]);
+
+ // Partner values via warp shuffle
+ float px[4];
+ for (int i = 0; i < ept; i++)
+ px[i] = __shfl_sync(0xFFFFFFFF, qx[i], partner_lane, 32);
+
+ // Cos/sin per element
+ float cs[4], sn[4];
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const int d_mod = d < half_D ? d : d - half_D;
+ int64_t pos;
+ if (d_mod < s0) pos = pos_t[bt_idx];
+ else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
+ else pos = pos_w[bt_idx];
+ cs[i] = __bfloat162float(cos_cache[pos * D + d]);
+ sn[i] = __bfloat162float(sin_cache[pos * D + d]);
}
- __syncthreads();
+ // Compute and write q rotation
+ const int64_t qo = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;
+ for (int i = 0; i < ept; i++) {
+ const float res = qx[i] * cs[i] + (sign * px[i]) * sn[i];
+ q_out[qo + d_base + i] = __float2bfloat16(res);
+ }
// Process k
if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- smem[d] = __bfloat162float(k_in[in_base + d]);
- __syncthreads();
-
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
-
- const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
- k_out[oi] = __float2bfloat16(res);
+ const int64_t k_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ float kx[4];
+ for (int i = 0; i < ept; i++)
+ kx[i] = __bfloat162float(k_in[k_base + d_base + i]);
+
+ float kp[4];
+ for (int i = 0; i < ept; i++)
+ kp[i] = __shfl_sync(0xFFFFFFFF, kx[i], partner_lane, 32);
+
+ const int64_t ko = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
+ for (int i = 0; i < ept; i++) {
+ const float res = kx[i] * cs[i] + (sign * kp[i]) * sn[i];
+ k_out[ko + d_base + i] = __float2bfloat16(res);
+ }
}
}
-void mrope_fused_forward(
+void mrope_fused_vec_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const dim3 grid(B * T, Hq);
- const dim3 block(D);
- const size_t smem = D * sizeof(float);
-
- mrope_fused<<<grid, block, smem>>>(
+ const int rows_per_block = 8;
+ const int total_rows = B * Hq * T;
+ const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
+
+ mrope_fused_vec<<<grid, 256>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
@@ -106,7 +118,7 @@
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_fused_forward(
+void mrope_fused_vec_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
@@ -120,10 +132,10 @@
global _mod
if _mod is None:
_mod = load_inline(
- name="m
... (967 more chars){
"content": "<6964 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1 && echo \"---\" && uv run benchmark.py 2>&1",
"description": "Check and benchmark warp-shuffle kernel",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
---
shape=0 variant=eager tflops=0.179 gbps=257.437 ms=0.351
shape=0 variant=compiled tflops=0.504 gbps=722.901 ms=0.125
shape=0 variant=solution tflops=0.578 gbps=828.766 ms=0.109
shape=0 solution_peak_fraction=0.4604
shape=1 variant=eager tflops=0.181 gbps=264.719 ms=0.555
shape=1 variant=compiled tflops=0.481 gbps=702.127 ms=0.209
shape=1 variant=solution tflops=0.627 gbps=915.802 ms=0.161
shape=1 solution_peak_fraction=0.5088
shape=2 variant=eager tflops=0.069 gbps=107.246 ms=0.206
shape=2 variant=compiled tflops=0.160 gbps=250.616 ms=0.088
shape=2 variant=solution tflops=0.279 gbps=435.951 ms=0.051
shape=2 solution_peak_fraction=0.2422
shape=3 variant=eager tflops=0.172 gbps=247.350 ms=1.460
shape=3 variant=compiled tflops=0.469 gbps=673.258 ms=0.536
shape=3 variant=solution tflops=0.733 gbps=1052.068 ms=0.343
shape=3 solution_peak_fraction=0.5845
peak_fraction: 0.4267
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,7 +7,8 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void mrope_fused_vec(
+__global__ void __launch_bounds__(128)
+mrope_fused(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -20,87 +21,76 @@
const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- const int warp_id = threadIdx.x / 32;
- const int lane_id = threadIdx.x % 32;
- const int rows_per_block = blockDim.x / 32;
- const int row = blockIdx.x * rows_per_block + warp_id;
+ extern __shared__ float smem[];
- if (row >= B * Hq * T) return;
+ const int bt = blockIdx.x;
+ const int h = blockIdx.y;
+ const int d = threadIdx.x;
- const int bt = row / Hq;
- const int h = row % Hq;
- const int b = bt / T;
- const int t = bt % T;
+ if (h >= Hq || d >= D) return;
+
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
- const int ept = D / 32; // elements per thread: 2 for D=64, 4 for D=128
- const int d_base = lane_id * ept;
- const int partner_lane = (lane_id + 16) % 32;
+ const int d_mod = d < half_D ? d : d - half_D;
const int64_t bt_idx = (int64_t)b * T + t;
- const float sign = lane_id < 16 ? -1.0f : 1.0f;
- // Load q values
- const int64_t q_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- float qx[4];
- for (int i = 0; i < ept; i++)
- qx[i] = __bfloat162float(q_in[q_base + d_base + i]);
+ int64_t pos;
+ if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
+ else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
+ else pos = __ldg(&pos_w[bt_idx]);
- // Partner values via warp shuffle
- float px[4];
- for (int i = 0; i < ept; i++)
- px[i] = __shfl_sync(0xFFFFFFFF, qx[i], partner_lane, 32);
+ const int64_t cache_idx = pos * D + d;
+ const float cos_v = __bfloat162float(__ldg(&cos_cache[cache_idx]));
+ const float sin_v = __bfloat162float(__ldg(&sin_cache[cache_idx]));
- // Cos/sin per element
- float cs[4], sn[4];
- for (int i = 0; i < ept; i++) {
- const int d = d_base + i;
- const int d_mod = d < half_D ? d : d - half_D;
- int64_t pos;
- if (d_mod < s0) pos = pos_t[bt_idx];
- else if (d_mod < s0 + s1) pos = pos_h[bt_idx];
- else pos = pos_w[bt_idx];
- cs[i] = __bfloat162float(cos_cache[pos * D + d]);
- sn[i] = __bfloat162float(sin_cache[pos * D + d]);
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float sign = d < half_D ? -1.0f : 1.0f;
+
+ // Process q: load into smem, compute, write
+ {
+ const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(__ldg(&q_in[in_base + d]));
+ __syncthreads();
+
+ const float x_val = smem[d];
+ const float pval = smem[pd];
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
+
+ const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
+ q_out[oi] = __float2bfloat16(res);
}
- // Compute and write q rotation
- const int64_t qo = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;
- for (int i = 0; i < ept; i++) {
- const float res = qx[i] * cs[i] + (sign * px[i]) * sn[i];
- q_out[qo + d_base + i] = __float2bfloat16(res);
- }
+ __syncthreads();
// Process k
if (h < Hkv) {
- const int64_t k_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- float kx[4];
- for (int i = 0; i < ept; i++)
- kx[i] = __bfloat162float(k_in[k_base + d_base + i]);
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ smem[d] = __bfloat162float(__ldg(&k_in[in_base + d]));
+ __syncthreads();
- float kp[4];
- for (int i = 0; i < ept; i++)
- kp[i] = __shfl_sync(0xFFFFFFFF, kx[i], partner_lane, 32);
+ const float x_val = smem[d];
+ const float pval = smem[pd];
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t ko = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
- for (int i = 0; i < ept; i++) {
- const float res = kx[i] * cs[i] + (sign * kp[i]) * sn[i];
- k_out[ko + d_base + i] = __float2bfloat16(res);
- }
+ const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
+ k_out[oi] = __float2bfloat16(res);
}
}
-void mrope_fused_vec_forward(
+void mrope_fused_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const int rows_per_block = 8;
- const int total_rows = B * Hq * T;
- const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
+ const dim3 grid(B * T, Hq);
+ const dim3 block(D);
+ const size_t smem = D * sizeof(float);
- mrope_fused_vec<<<grid, 256>>>(
+ mrope_fused<<<grid, block, smem>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
@@ -118,7 +108,7 @@
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_fused_vec_forward(
+void mrope_fused_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_
... (1098 more chars){
"content": "<6355 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1 && echo \"---\" && uv run benchmark.py 2>&1",
"description": "Final check and benchmark",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
---
shape=0 variant=eager tflops=0.174 gbps=249.878 ms=0.361
shape=0 variant=compiled tflops=0.477 gbps=683.742 ms=0.132
shape=0 variant=solution tflops=0.671 gbps=963.004 ms=0.094
shape=0 solution_peak_fraction=0.5350
shape=1 variant=eager tflops=0.182 gbps=265.330 ms=0.554
shape=1 variant=compiled tflops=0.478 gbps=698.019 ms=0.211
shape=1 variant=solution tflops=0.722 gbps=1055.045 ms=0.139
shape=1 solution_peak_fraction=0.5861
shape=2 variant=eager tflops=0.072 gbps=112.253 ms=0.197
shape=2 variant=compiled tflops=0.175 gbps=272.986 ms=0.081
shape=2 variant=solution tflops=0.325 gbps=507.117 ms=0.044
shape=2 solution_peak_fraction=0.2817
shape=3 variant=eager tflops=0.172 gbps=246.769 ms=1.463
shape=3 variant=compiled tflops=0.468 gbps=672.115 ms=0.537
shape=3 variant=solution tflops=0.811 gbps=1163.528 ms=0.310
shape=3 solution_peak_fraction=0.6464
peak_fraction: 0.4888
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,8 +7,8 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void __launch_bounds__(128)
-mrope_fused(
+__global__ void __launch_bounds__(256)
+mrope_direct(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -21,76 +21,85 @@
const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- extern __shared__ float smem[];
-
- const int bt = blockIdx.x;
- const int h = blockIdx.y;
- const int d = threadIdx.x;
-
- if (h >= Hq || d >= D) return;
-
- const int b = bt / T;
- const int t = bt % T;
+ const int warp_id = threadIdx.x / 32;
+ const int lane_id = threadIdx.x % 32;
+ const int rows_per_block = blockDim.x / 32;
+ const int row = blockIdx.x * rows_per_block + warp_id;
+
+ if (row >= B * Hq * T) return;
+
+ const int bt = row / Hq;
+ const int h = row % Hq;
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
-
- const int d_mod = d < half_D ? d : d - half_D;
+ const int ept = D / 32; // elements per thread
+
+ const int d_base = lane_id * ept;
const int64_t bt_idx = (int64_t)b * T + t;
-
- int64_t pos;
- if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
- else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
- else pos = __ldg(&pos_w[bt_idx]);
-
- const int64_t cache_idx = pos * D + d;
- const float cos_v = __bfloat162float(__ldg(&cos_cache[cache_idx]));
- const float sin_v = __bfloat162float(__ldg(&sin_cache[cache_idx]));
-
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float sign = d < half_D ? -1.0f : 1.0f;
-
- // Process q: load into smem, compute, write
+ const int64_t q_in_off = (int64_t)bt * Hq * D + (int64_t)h * D;
+ const int64_t q_out_off = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;
+
+ // Precompute axis and pos for each d in this warp's range
+ int64_t pos_arr[4];
+ float cos_v[4], sin_v[4];
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const int d_mod = d < half_D ? d : d - half_D;
+ int64_t pos;
+ if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
+ else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
+ else pos = __ldg(&pos_w[bt_idx]);
+ pos_arr[i] = pos;
+ cos_v[i] = __bfloat162float(__ldg(&cos_cache[pos * D + d]));
+ sin_v[i] = __bfloat162float(__ldg(&sin_cache[pos * D + d]));
+ }
+
+ // Process q
{
- const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- smem[d] = __bfloat162float(__ldg(&q_in[in_base + d]));
- __syncthreads();
-
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
-
- const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
- q_out[oi] = __float2bfloat16(res);
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const float x_val = __bfloat162float(__ldg(&q_in[q_in_off + d]));
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float pval = __bfloat162float(__ldg(&q_in[q_in_off + pd]));
+ const float rh = d < half_D ? -pval : pval;
+ const float res = x_val * cos_v[i] + rh * sin_v[i];
+ q_out[q_out_off + d] = __float2bfloat16(res);
+ }
}
-
- __syncthreads();
// Process k
if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- smem[d] = __bfloat162float(__ldg(&k_in[in_base + d]));
- __syncthreads();
-
- const float x_val = smem[d];
- const float pval = smem[pd];
- const float res = x_val * cos_v + (sign * pval) * sin_v;
-
- const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
- k_out[oi] = __float2bfloat16(res);
+ const int64_t k_in_off = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ const int64_t k_out_off = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
+
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const float x_val = __bfloat162float(__ldg(&k_in[k_in_off + d]));
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float pval = __bfloat162float(__ldg(&k_in[k_in_off + pd]));
+ const float rh = d < half_D ? -pval : pval;
+ const float res = x_val * cos_v[i] + rh * sin_v[i];
+ k_out[k_out_off + d] = __float2bfloat16(res);
+ }
}
}
-void mrope_fused_forward(
+void mrope_direct_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const dim3 grid(B * T, Hq);
- const dim3 block(D);
- const size_t smem = D * sizeof(float);
-
- mrope_fused<<<grid, block, smem>>>(
+ const int rows_per_block = 256 / 32; // 8 warps
+ const int total_rows = B * Hq * T;
+ const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
+
+ mrope_direct<<<grid, 256>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
@@ -108,7 +117,7 @@
_CPP_DECL = r"""
#include <torch/extensio
... (1300 more chars){
"content": "<7154 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1 && echo \"---\" && uv run benchmark.py 2>&1",
"description": "Check and benchmark direct global read approach",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
---
shape=0 variant=eager tflops=0.181 gbps=259.748 ms=0.348
shape=0 variant=compiled tflops=0.452 gbps=649.205 ms=0.139
shape=0 variant=solution tflops=0.745 gbps=1069.214 ms=0.084
shape=0 solution_peak_fraction=0.5940
shape=1 variant=eager tflops=0.172 gbps=251.212 ms=0.585
shape=1 variant=compiled tflops=0.415 gbps=605.864 ms=0.243
shape=1 variant=solution tflops=0.637 gbps=929.891 ms=0.158
shape=1 solution_peak_fraction=0.5166
shape=2 variant=eager tflops=0.061 gbps=94.704 ms=0.234
shape=2 variant=compiled tflops=0.149 gbps=232.102 ms=0.095
shape=2 variant=solution tflops=0.411 gbps=641.783 ms=0.034
shape=2 solution_peak_fraction=0.3565
shape=3 variant=eager tflops=0.172 gbps=247.098 ms=1.461
shape=3 variant=compiled tflops=0.471 gbps=675.293 ms=0.535
shape=3 variant=solution tflops=0.844 gbps=1211.366 ms=0.298
shape=3 solution_peak_fraction=0.6730
peak_fraction: 0.5209
RESULT: OK
{
"command": "uv run benchmark.py 2>&1",
"description": "Re-run benchmark for stability",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}shape=0 variant=eager tflops=0.178 gbps=255.005 ms=0.354
shape=0 variant=compiled tflops=0.494 gbps=709.180 ms=0.127
shape=0 variant=solution tflops=0.640 gbps=918.782 ms=0.098
shape=0 solution_peak_fraction=0.5104
shape=1 variant=eager tflops=0.179 gbps=261.093 ms=0.563
shape=1 variant=compiled tflops=0.448 gbps=654.601 ms=0.225
shape=1 variant=solution tflops=0.688 gbps=1005.068 ms=0.146
shape=1 solution_peak_fraction=0.5584
shape=2 variant=eager tflops=0.070 gbps=109.793 ms=0.201
shape=2 variant=compiled tflops=0.167 gbps=261.570 ms=0.085
shape=2 variant=solution tflops=0.427 gbps=667.826 ms=0.033
shape=2 solution_peak_fraction=0.3710
shape=3 variant=eager tflops=0.172 gbps=246.941 ms=1.462
shape=3 variant=compiled tflops=0.467 gbps=670.717 ms=0.538
shape=3 variant=solution tflops=0.841 gbps=1207.219 ms=0.299
shape=3 solution_peak_fraction=0.6707
peak_fraction: 0.5161
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,8 +7,8 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void __launch_bounds__(256)
-mrope_direct(
+__global__ void __launch_bounds__(128)
+mrope_2d(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -21,85 +21,63 @@
const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- const int warp_id = threadIdx.x / 32;
- const int lane_id = threadIdx.x % 32;
- const int rows_per_block = blockDim.x / 32;
- const int row = blockIdx.x * rows_per_block + warp_id;
+ const int bt = blockIdx.x;
+ const int h = blockIdx.y;
+ const int d = threadIdx.x;
- if (row >= B * Hq * T) return;
+ if (h >= Hq || d >= D) return;
- const int bt = row / Hq;
- const int h = row % Hq;
- const int b = bt / T;
- const int t = bt % T;
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
- const int ept = D / 32; // elements per thread
- const int d_base = lane_id * ept;
+ const int d_mod = d < half_D ? d : d - half_D;
const int64_t bt_idx = (int64_t)b * T + t;
- const int64_t q_in_off = (int64_t)bt * Hq * D + (int64_t)h * D;
- const int64_t q_out_off = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;
- // Precompute axis and pos for each d in this warp's range
- int64_t pos_arr[4];
- float cos_v[4], sin_v[4];
- #pragma unroll
- for (int i = 0; i < ept; i++) {
- const int d = d_base + i;
- const int d_mod = d < half_D ? d : d - half_D;
- int64_t pos;
- if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
- else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
- else pos = __ldg(&pos_w[bt_idx]);
- pos_arr[i] = pos;
- cos_v[i] = __bfloat162float(__ldg(&cos_cache[pos * D + d]));
- sin_v[i] = __bfloat162float(__ldg(&sin_cache[pos * D + d]));
- }
+ int64_t pos;
+ if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
+ else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
+ else pos = __ldg(&pos_w[bt_idx]);
- // Process q
+ const int64_t cache_idx = pos * D + d;
+ const float cos_v = __bfloat162float(__ldg(&cos_cache[cache_idx]));
+ const float sin_v = __bfloat162float(__ldg(&sin_cache[cache_idx]));
+
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float sign = d < half_D ? -1.0f : 1.0f;
+
+ // Process q: direct global reads (partner likely L1 hit)
{
- #pragma unroll
- for (int i = 0; i < ept; i++) {
- const int d = d_base + i;
- const float x_val = __bfloat162float(__ldg(&q_in[q_in_off + d]));
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float pval = __bfloat162float(__ldg(&q_in[q_in_off + pd]));
- const float rh = d < half_D ? -pval : pval;
- const float res = x_val * cos_v[i] + rh * sin_v[i];
- q_out[q_out_off + d] = __float2bfloat16(res);
- }
+ const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
+ const float x_val = __bfloat162float(__ldg(&q_in[in_base + d]));
+ const float pval = __bfloat162float(__ldg(&q_in[in_base + pd]));
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
+ const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
+ q_out[oi] = __float2bfloat16(res);
}
// Process k
if (h < Hkv) {
- const int64_t k_in_off = (int64_t)bt * Hkv * D + (int64_t)h * D;
- const int64_t k_out_off = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
-
- #pragma unroll
- for (int i = 0; i < ept; i++) {
- const int d = d_base + i;
- const float x_val = __bfloat162float(__ldg(&k_in[k_in_off + d]));
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float pval = __bfloat162float(__ldg(&k_in[k_in_off + pd]));
- const float rh = d < half_D ? -pval : pval;
- const float res = x_val * cos_v[i] + rh * sin_v[i];
- k_out[k_out_off + d] = __float2bfloat16(res);
- }
+ const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ const float x_val = __bfloat162float(__ldg(&k_in[in_base + d]));
+ const float pval = __bfloat162float(__ldg(&k_in[in_base + pd]));
+ const float res = x_val * cos_v + (sign * pval) * sin_v;
+ const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
+ k_out[oi] = __float2bfloat16(res);
}
}
-void mrope_direct_forward(
+void mrope_2d_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const int rows_per_block = 256 / 32; // 8 warps
- const int total_rows = B * Hq * T;
- const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
+ const dim3 grid(B * T, Hq);
+ const dim3 block(D);
- mrope_direct<<<grid, 256>>>(
+ mrope_2d<<<grid, block>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
@@ -117,7 +95,7 @@
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_direct_forward(
+void mrope_2d_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_ou
... (1116 more chars){
"content": "<6197 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run check.py 2>&1 && echo \"---\" && uv run benchmark.py 2>&1",
"description": "Check and benchmark 2D grid direct read approach",
"timeout": 600000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope"
}PASS
---
shape=0 variant=eager tflops=0.182 gbps=260.624 ms=0.346
shape=0 variant=compiled tflops=0.504 gbps=723.271 ms=0.125
shape=0 variant=solution tflops=0.585 gbps=839.619 ms=0.108
shape=0 solution_peak_fraction=0.4665
shape=1 variant=eager tflops=0.181 gbps=264.292 ms=0.556
shape=1 variant=compiled tflops=0.482 gbps=703.363 ms=0.209
shape=1 variant=solution tflops=0.719 gbps=1049.860 ms=0.140
shape=1 solution_peak_fraction=0.5833
shape=2 variant=eager tflops=0.072 gbps=112.071 ms=0.197
shape=2 variant=compiled tflops=0.161 gbps=251.437 ms=0.088
shape=2 variant=solution tflops=0.307 gbps=480.167 ms=0.046
shape=2 solution_peak_fraction=0.2668
shape=3 variant=eager tflops=0.172 gbps=247.513 ms=1.459
shape=3 variant=compiled tflops=0.468 gbps=672.215 ms=0.537
shape=3 variant=solution tflops=0.812 gbps=1164.789 ms=0.310
shape=3 solution_peak_fraction=0.6471
peak_fraction: 0.4655
RESULT: OK
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -7,8 +7,8 @@
#include <cuda_bf16.h>
#include <cstdint>
-__global__ void __launch_bounds__(128)
-mrope_2d(
+__global__ void __launch_bounds__(256)
+mrope_direct(
const __nv_bfloat16* __restrict__ q_in,
const __nv_bfloat16* __restrict__ k_in,
__nv_bfloat16* __restrict__ q_out,
@@ -21,63 +21,81 @@
const int B, const int T, const int Hq, const int Hkv, const int D,
const int s0, const int s1, const int s2
) {
- const int bt = blockIdx.x;
- const int h = blockIdx.y;
- const int d = threadIdx.x;
-
- if (h >= Hq || d >= D) return;
-
- const int b = bt / T;
- const int t = bt % T;
+ const int warp_id = threadIdx.x / 32;
+ const int lane_id = threadIdx.x % 32;
+ const int rows_per_block = blockDim.x / 32;
+ const int row = blockIdx.x * rows_per_block + warp_id;
+
+ if (row >= B * Hq * T) return;
+
+ const int bt = row / Hq;
+ const int h = row % Hq;
+ const int b = bt / T;
+ const int t = bt % T;
const int half_D = D >> 1;
-
- const int d_mod = d < half_D ? d : d - half_D;
+ const int ept = D / 32;
+
+ const int d_base = lane_id * ept;
const int64_t bt_idx = (int64_t)b * T + t;
-
- int64_t pos;
- if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
- else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
- else pos = __ldg(&pos_w[bt_idx]);
-
- const int64_t cache_idx = pos * D + d;
- const float cos_v = __bfloat162float(__ldg(&cos_cache[cache_idx]));
- const float sin_v = __bfloat162float(__ldg(&sin_cache[cache_idx]));
-
- const int pd = d < half_D ? d + half_D : d - half_D;
- const float sign = d < half_D ? -1.0f : 1.0f;
-
- // Process q: direct global reads (partner likely L1 hit)
- {
- const int64_t in_base = (int64_t)bt * Hq * D + (int64_t)h * D;
- const float x_val = __bfloat162float(__ldg(&q_in[in_base + d]));
- const float pval = __bfloat162float(__ldg(&q_in[in_base + pd]));
- const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t oi = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D + d;
- q_out[oi] = __float2bfloat16(res);
+ const int64_t q_in_off = (int64_t)bt * Hq * D + (int64_t)h * D;
+ const int64_t q_out_off = ((int64_t)b * Hq + h) * (int64_t)T * D + (int64_t)t * D;
+
+ // Precompute axis and pos for each d
+ float cos_v[4], sin_v[4];
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const int d_mod = d < half_D ? d : d - half_D;
+ int64_t pos;
+ if (d_mod < s0) pos = __ldg(&pos_t[bt_idx]);
+ else if (d_mod < s0 + s1) pos = __ldg(&pos_h[bt_idx]);
+ else pos = __ldg(&pos_w[bt_idx]);
+ cos_v[i] = __bfloat162float(__ldg(&cos_cache[pos * D + d]));
+ sin_v[i] = __bfloat162float(__ldg(&sin_cache[pos * D + d]));
+ }
+
+ // Process q
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const float x_val = __bfloat162float(__ldg(&q_in[q_in_off + d]));
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float pval = __bfloat162float(__ldg(&q_in[q_in_off + pd]));
+ const float rh = d < half_D ? -pval : pval;
+ const float res = x_val * cos_v[i] + rh * sin_v[i];
+ q_out[q_out_off + d] = __float2bfloat16(res);
}
// Process k
if (h < Hkv) {
- const int64_t in_base = (int64_t)bt * Hkv * D + (int64_t)h * D;
- const float x_val = __bfloat162float(__ldg(&k_in[in_base + d]));
- const float pval = __bfloat162float(__ldg(&k_in[in_base + pd]));
- const float res = x_val * cos_v + (sign * pval) * sin_v;
- const int64_t oi = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D + d;
- k_out[oi] = __float2bfloat16(res);
+ const int64_t k_in_off = (int64_t)bt * Hkv * D + (int64_t)h * D;
+ const int64_t k_out_off = ((int64_t)b * Hkv + h) * (int64_t)T * D + (int64_t)t * D;
+
+ #pragma unroll
+ for (int i = 0; i < ept; i++) {
+ const int d = d_base + i;
+ const float x_val = __bfloat162float(__ldg(&k_in[k_in_off + d]));
+ const int pd = d < half_D ? d + half_D : d - half_D;
+ const float pval = __bfloat162float(__ldg(&k_in[k_in_off + pd]));
+ const float rh = d < half_D ? -pval : pval;
+ const float res = x_val * cos_v[i] + rh * sin_v[i];
+ k_out[k_out_off + d] = __float2bfloat16(res);
+ }
}
}
-void mrope_2d_forward(
+void mrope_direct_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
torch::Tensor cos_cache, torch::Tensor sin_cache,
int B, int T, int Hq, int Hkv, int D, int s0, int s1, int s2
) {
- const dim3 grid(B * T, Hq);
- const dim3 block(D);
-
- mrope_2d<<<grid, block>>>(
+ const int rows_per_block = 256 / 32;
+ const int total_rows = B * Hq * T;
+ const int grid = (total_rows + rows_per_block - 1) / rows_per_block;
+
+ mrope_direct<<<grid, 256>>>(
reinterpret_cast<const __nv_bfloat16*>(q_in.data_ptr<at::BFloat16>()),
reinterpret_cast<const __nv_bfloat16*>(k_in.data_ptr<at::BFloat16>()),
reinterpret_cast<__nv_bfloat16*>(q_out.data_ptr<at::BFloat16>()),
@@ -95,7 +113,7 @@
_CPP_DECL = r"""
#include <torch/extension.h>
#include <cstdint>
-void mrope_2d_forward(
+void mrope_direct_forward(
torch::Tensor q_in, torch::Tensor k_in,
torch::Tensor q_out, torch::Tensor k_out,
torch::Tensor pos_t, torch::Tensor pos_h, torch::Tensor pos_w,
@@ -109,10 +127,10 @@
global _mod
if _mod is None:
_mo
... (972 more chars){
"content": "<6995 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.