"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
Two-stage with separate q and k rotation kernels:
Stage 1: grid=(B*T,). Gather+mix cos/sin once per position => (B*T, D).
Stage 2: grid=(1, B*T, Hq) for q, (1, B*T, Hkv) for k.
Each program loads the pre-mixed cos/sin row for its position once, applies rotate-half, and writes the output transposed to (B, H, T, D).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import triton
import triton.language as tl
# Module-level metadata. None of these names is referenced elsewhere in this
# file; presumably they are read by the external benchmark harness — TODO confirm.
OP_TYPE = "rope"
# Precisions this solution targets (the kernels below allocate bf16 outputs).
SUPPORTED_PRECISIONS = ["bf16"]
# Hardware tag for the intended target GPU.
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
@triton.jit
def _mix_cos_sin_kernel(
    cos_cache_ptr, sin_cache_ptr,
    pos_t_ptr, pos_h_ptr, pos_w_ptr,
    cos_out_ptr, sin_out_ptr,
    B, T, D, max_pos,
    s0, s1, s2, half_D,
    stride_pos,
    BLOCK_D: tl.constexpr,
):
    """Gather the per-axis cos/sin rows for one position and mix them by section.

    Grid: (B*T,) -- one program per flattened (b, t) position.

    For each position the three position ids are read, the matching rows of the
    cos and sin caches are loaded, and a mixed row of width D is written to
    ``cos_out``/``sin_out`` at row ``pid``. Within each half of the row the
    first ``s0`` columns come from the ``t`` axis, the next ``s1`` from the
    ``h`` axis and the remainder from the ``w`` axis.

    Layout assumptions made by the indexing below:
      * the caches and both outputs are row-major with row stride ``D``;
      * all three position tensors share ``stride_pos`` as their row stride
        and have unit stride along T.

    ``B``, ``max_pos`` and ``s2`` are accepted for interface stability but are
    not read; position ids are NOT range-checked against the cache.
    """
    # Promote to int64 so that pid * D and b * stride_pos cannot wrap around
    # 32-bit arithmetic when B*T*D reaches 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    b = pid // T
    t = pid % T

    # The same element offset is used for all three position tensors.
    pos_base = b * stride_pos + t
    pt = tl.load(pos_t_ptr + pos_base)
    ph = tl.load(pos_h_ptr + pos_base)
    pw = tl.load(pos_w_ptr + pos_base)

    d_offs = tl.arange(0, BLOCK_D)
    mask = d_offs < D  # BLOCK_D may be padded beyond D

    # Element offsets of the selected cache rows.
    row_t = pt.to(tl.int64) * D
    row_h = ph.to(tl.int64) * D
    row_w = pw.to(tl.int64) * D

    cos_t = tl.load(cos_cache_ptr + row_t + d_offs, mask=mask)
    cos_h = tl.load(cos_cache_ptr + row_h + d_offs, mask=mask)
    cos_w = tl.load(cos_cache_ptr + row_w + d_offs, mask=mask)
    sin_t = tl.load(sin_cache_ptr + row_t + d_offs, mask=mask)
    sin_h = tl.load(sin_cache_ptr + row_h + d_offs, mask=mask)
    sin_w = tl.load(sin_cache_ptr + row_w + d_offs, mask=mask)

    # Section boundaries, evaluated once and shared by the cos and sin selects.
    # The predicates are tested in this order, so each one only matters for
    # columns that failed all earlier ones.
    cut_h = s0 + s1
    below_s0 = d_offs < s0
    below_h = d_offs < cut_h
    below_half = d_offs < half_D
    below_half_s0 = d_offs < half_D + s0
    below_half_h = d_offs < half_D + cut_h

    cos_mixed = tl.where(
        below_s0, cos_t,
        tl.where(below_h, cos_h,
                 tl.where(below_half, cos_w,
                          tl.where(below_half_s0, cos_t,
                                   tl.where(below_half_h, cos_h, cos_w)))))
    sin_mixed = tl.where(
        below_s0, sin_t,
        tl.where(below_h, sin_h,
                 tl.where(below_half, sin_w,
                          tl.where(below_half_s0, sin_t,
                                   tl.where(below_half_h, sin_h, sin_w)))))

    out_base = pid * D
    tl.store(cos_out_ptr + out_base + d_offs, cos_mixed, mask=mask)
    tl.store(sin_out_ptr + out_base + d_offs, sin_mixed, mask=mask)
@triton.jit
def _rope_transpose_kernel(
    x_ptr, cos_ptr, sin_ptr, out_ptr,
    T, H, D, half_D,
    stride_x_bt, stride_x_h,
    stride_cs,
    stride_out_b, stride_out_h, stride_out_t,
    BLOCK_D: tl.constexpr,
):
    """Rotate + transpose kernel. Writes directly to (B,H,T,D).

    Grid: (1, B*T, H) -- axis 1 is the flattened (b, t) position, axis 2 the head.
    x: (B*T, H, D) flat view, unit stride along D.
    cos/sin: (B*T, D) flat, unit stride along D.
    out: (B, H, T, D), unit stride along D.

    For d in [0, half_D):
        out[d]          = x[d]          * cos[d] - x[d + half_D] * sin[d]
        out[d + half_D] = x[d + half_D] * cos[d] + x[d]          * sin[d]

    Both halves use cos[d]/sin[d] from the FIRST half of the row, so this
    relies on cos[d] == cos[d + half_D] (and likewise for sin).

    ``H`` is accepted for interface stability but is not read.
    """
    # Promote to int64: pid_bt * stride_x_bt, b * stride_out_b and
    # pid_h * stride_out_h are element offsets that exceed 2**31 for large
    # tensors and would otherwise wrap in 32-bit arithmetic.
    pid_bt = tl.program_id(1).to(tl.int64)
    pid_h = tl.program_id(2).to(tl.int64)
    b = pid_bt // T
    t = pid_bt % T

    d_offs = tl.arange(0, BLOCK_D)
    mask_d = d_offs < D
    mask_half = d_offs < half_D
    mask_second = (d_offs + half_D) < D

    # Load the pre-mixed cos/sin row for this position. Every (position, head)
    # program performs this load itself; the row is the same for all heads.
    cs_base = pid_bt * stride_cs
    cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
    sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)

    # Load both halves of x for this head, aligned lane-by-lane.
    x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
    x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
    x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)

    # Rotate-half: only lanes below half_D carry meaningful values.
    out_first = x_d * cos_vals - x_dh * sin_vals
    out_second = x_dh * cos_vals + x_d * sin_vals

    # Write to the (B,H,T,D) layout directly.
    out_base = b * stride_out_b + pid_h * stride_out_h + t * stride_out_t
    tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
    tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)
def _next_power_of_2(n):
return 1 << (n - 1).bit_length()
class Model(nn.Module):
    """Multi-axis RoPE pre-attention prep driven by the Triton kernels above.

    ``forward`` rotates q and k with cos/sin values gathered per axis
    (t / h / w) from a precomputed cache and returns both tensors transposed
    from (B, T, H, D) to (B, H, T, D).
    """

    def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
                 mrope_section: tuple[int, int, int], max_pos: int):
        """Build the (max_pos, D) bf16 cos/sin caches.

        Args:
            B, T, Hq, Hkv: stored for reference; ``forward`` takes the actual
                sizes from its inputs.
            D: head dimension; fixes the cache width and the kernel block size.
            mrope_section: (s0, s1, s2) column counts per axis; must sum to D // 2.
            max_pos: number of rows (positions) in the caches.
        """
        super().__init__()
        assert sum(mrope_section) == D // 2, "mrope_section must sum to D // 2"
        self.B, self.T = B, T
        self.Hq, self.Hkv, self.D = Hq, Hkv, D
        self.mrope_section = tuple(mrope_section)
        self.max_pos = max_pos

        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
        pos = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.outer(pos, inv_freq)
        # Frequencies are duplicated so both halves of a row are identical;
        # the rotation kernel depends on this.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos().to(torch.bfloat16))
        self.register_buffer("sin_cache", emb.sin().to(torch.bfloat16))
        # tl.arange needs a power-of-two extent, so D is padded up.
        self._BLOCK_D = _next_power_of_2(D)

    def _rotate_and_transpose(
        self, x: torch.Tensor, cos_mixed: torch.Tensor, sin_mixed: torch.Tensor,
    ) -> torch.Tensor:
        """Launch the rotation kernel for one (B, T, H, D) tensor; return (B, H, T, D)."""
        B, T, H, D = x.shape
        BT = B * T
        # The kernel assumes unit stride along D and a flattenable (B, T) prefix.
        x_flat = x.contiguous().view(BT, H, D)
        out = torch.empty(B, H, T, D, dtype=torch.bfloat16, device=x.device)
        # NOTE(review): the flattened (b, t) index rides on grid axis 1, which
        # CUDA caps at 65535 -- confirm B*T stays below that for target shapes.
        _rope_transpose_kernel[(1, BT, H)](
            x_flat, cos_mixed, sin_mixed, out,
            T, H, D, D // 2,
            x_flat.stride(0), x_flat.stride(1),
            cos_mixed.stride(0),
            out.stride(0), out.stride(1), out.stride(2),
            BLOCK_D=self._BLOCK_D,
            num_warps=4,
        )
        return out

    def forward(
        self,
        q: torch.Tensor, k: torch.Tensor,
        pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply multi-axis RoPE to q and k.

        Args:
            q: (B, T, Hq, D) tensor.
            k: (B, T, Hkv, D) tensor.
            pos_t, pos_h, pos_w: (B, T) integer position ids, one per axis,
                used as row indices into the caches (not range-checked).

        Returns:
            (q_rot, k_rot) as bf16 tensors of shape (B, Hq, T, D) and (B, Hkv, T, D).

        Raises:
            ValueError: if the head dimension differs from the one the module
                was built with, or if k does not match q in B, T and D.
        """
        B, T, Hq, D = q.shape
        k_B, k_T, Hkv, k_D = k.shape
        if D != self.D:
            # BLOCK_D and the cache row width were fixed from self.D; a
            # different D would silently truncate rows or mis-index the cache.
            raise ValueError(
                f"head dim of q is {D}, but the module was built with D={self.D}"
            )
        if (k_B, k_T, k_D) != (B, T, D):
            raise ValueError(
                f"k shape {tuple(k.shape)} is incompatible with q shape {tuple(q.shape)}"
            )

        half_D = D // 2
        s0, s1, s2 = self.mrope_section
        BT = B * T

        # The mix kernel reads all three position tensors with pos_t's row
        # stride and unit stride along T, so normalise their layout first.
        pos_t = pos_t.contiguous()
        pos_h = pos_h.contiguous()
        pos_w = pos_w.contiguous()

        # Stage 1: gather + mix cos/sin once per position.
        cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=q.device)
        sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=q.device)
        _mix_cos_sin_kernel[(BT,)](
            self.cos_cache, self.sin_cache,
            pos_t, pos_h, pos_w,
            cos_mixed, sin_mixed,
            B, T, D, self.max_pos,
            s0, s1, s2, half_D,
            pos_t.stride(0),
            BLOCK_D=self._BLOCK_D,
            num_warps=4,
        )

        # Stage 2: rotate + transpose q and k with the shared cos/sin rows.
        q_rot = self._rotate_and_transpose(q, cos_mixed, sin_mixed)
        k_rot = self._rotate_and_transpose(k, cos_mixed, sin_mixed)
        return q_rot, k_rot
# Default problem size used by get_inputs() / get_init_inputs().
B, T = 1, 4096            # q/k leading dimensions: (B, T, heads, D)
Hq, Hkv = 32, 8           # head counts of q and k
D = 128                   # head dimension
MROPE_SECTION = (16, 24, 24)  # per-axis column counts; sums to D // 2
MAX_POS = 32768           # rows in the cos/sin caches
def get_inputs():
    """Build one random set of forward() inputs at the module's default size.

    Returns ``[q, k, pos_t, pos_h, pos_w]``: q and k are bf16 normal noise
    scaled by 0.5; the three position tensors are int64 ids drawn from
    ``[0, MAX_POS)`` with shape (B, T).
    """
    def _scaled_noise(num_heads):
        return torch.randn(B, T, num_heads, D, dtype=torch.bfloat16) * 0.5

    # Draw order (q, k, then the three position tensors) is kept fixed so a
    # seeded generator yields the same values.
    q = _scaled_noise(Hq)
    k = _scaled_noise(Hkv)
    position_ids = [
        torch.randint(0, MAX_POS, (B, T), dtype=torch.int64)
        for _ in range(3)
    ]
    return [q, k, *position_ids]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` as a list."""
    init_args = (B, T, Hq, Hkv, D, MROPE_SECTION, MAX_POS)
    return list(init_args)
shape=0 variant=eager tflops=0.177 gbps=254.109 ms=0.355
shape=0 variant=compiled tflops=0.488 gbps=699.856 ms=0.129
shape=0 variant=solution tflops=0.560 gbps=804.082 ms=0.112
shape=0 solution_peak_fraction=0.4467
shape=1 variant=eager tflops=0.179 gbps=261.501 ms=0.562
shape=1 variant=compiled tflops=0.460 gbps=671.392 ms=0.219
shape=1 variant=solution tflops=0.642 gbps=938.153 ms=0.157
shape=1 solution_peak_fraction=0.5212
shape=2 variant=eager tflops=0.070 gbps=109.863 ms=0.201
shape=2 variant=compiled tflops=0.172 gbps=269.474 ms=0.082
shape=2 variant=solution tflops=0.204 gbps=319.482 ms=0.069
shape=2 solution_peak_fraction=0.1775
shape=3 variant=eager tflops=0.170 gbps=244.081 ms=1.479
shape=3 variant=compiled tflops=0.462 gbps=662.877 ms=0.545
shape=3 variant=solution tflops=0.742 gbps=1064.072 ms=0.339
shape=3 solution_peak_fraction=0.5912
peak_fraction: 0.3953
RESULT: OK
PASS
{
"problem": "09_fmha_preattn_mrope",
"harness": "opencode",
"model": "openrouter-pinned/qwen/qwen3.6-max-preview",
"reasoning_effort": "",
"has_solution": true,
"correct": true,
"peak_fraction": 0.3953,
"elapsed_seconds": 2700,
"harness_exit_code": 124,
"session_complete": false,
"usage": {"input_tokens": 4525648, "output_tokens": 45439, "cache_read_tokens": 0, "cache_creation_tokens": 0, "reasoning_tokens": 17324, "total_cost_usd": null}
}
{
"command": "nvidia-smi --query-gpu=name,driver_version,memory.total,compute_cap --format=csv,noheader",
"description": "Check GPU info"
}NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 595.58.03, 97887 MiB, 12.0
{
"prompt": "Explore /home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope thoroughly. Read ALL files in this directory - reference.py, check.py, benchmark.py, solution.py (if exists), shapes.py, problem.yaml, PROMPT.txt, and any other files. Return the COMPLETE contents of each file so I can implement an optimized kernel. Be very thorough - read every file completely.",
"description": "Explore problem directory",
"subagent_type": "explore"
}task_id: ses_1fc7c0623ffegxCD7L8oNSc7s5 (for resuming to continue this task if needed)
<task_result>
Here is the complete contents of every file in the problem directory, along with the key supporting infrastructure code:
---
## File Listing
The directory `/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/` contains 7 files:
1. `problem.yaml`
2. `PROMPT.txt`
3. `reference.py`
4. `check.py`
5. `benchmark.py`
6. `shapes.py`
7. `sota.py`
There is **no `solution.py`** yet -- you need to create it.
---
## file: `PROMPT.txt`
```
I need you to optimize the multi-axis RoPE pre-attention prep step for the RTX PRO 6000 (SM120 Blackwell, GDDR7, 1.8 TB/s). The reference implementation is in reference.py and your kernel goes in solution.py. You can make whatever mess you want in this directory — scratch files, intermediate kernels, build artifacts, profiling traces — but the final answer has to be in solution.py with the same Model, get_inputs, and get_init_inputs interface as reference.py.
The op is the pre-attention prep used in Qwen2-VL-style vision-language models: take q (B, T, Hq, D) and k (B, T, Hkv, D) bf16 plus three int64 position-id tensors pos_t, pos_h, pos_w of shape (B, T) — temporal, height, width axes — gather per-axis cos/sin from a precomputed (max_pos, D) bf16 cache, mix the per-axis values across the head_dim according to mrope_section (a tuple [s0, s1, s2] summing to D/2 — the first s0 pair-indices use the temporal axis, next s1 use height
... (21653 more chars){
"todos": [
{
"content": "Read reference.py carefully, understand the _mix_axes bug/behavior",
"status": "in_progress",
"priority": "high"
},
{
"content": "Implement Triton kernel: gather + mix + rotate + transpose fused",
"status": "pending",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py for performance",
"status": "pending",
"priority": "medium"
},
{
"content": "Iterate on performanc
... (117 more chars)content: 'Read reference.py carefully, understand the _mix_axes bug/behavior'
status: 'in_progress'
priority: 'high'
content: 'Implement Triton kernel: gather + mix + rotate + transpose fused'
status: 'pending'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'pending'
priority: 'medium'
content: 'Iterate on performance (tune block sizes, load coalescing, shared mem)'
status: 'pending'
priority: 'medium'{
"command": "<924 chars \u2014 see below>",
"description": "Test reference mix_axes behavior"
}uv run python -c "
import torch
import sys
sys.path.insert(0, '.')
import reference
# Initialize the model
model = reference.Model(1, 4096, 32, 8, 128, (16, 24, 24), 32768)
model = model.cuda()
# Test _mix_axes with distinct values per axis
c_t = torch.ones(1, 1, 128, dtype=torch.bfloat16, device='cuda') * 1.0
c_h = torch.ones(1, 1, 128, dtype=torch.bfloat16, device='cuda') * 2.0
c_w = torch.ones(1, 1, 128, dtype=torch.bfloat16, device='cuda') * 3.0
out = model._mix_axes(c_t, c_h, c_w)
print('Shape:', out.shape)
print('First 16 elements (should be from t, val=1):', out[0, 0, :16])
print('Idx 16-40 (should be from h, val=2):', out[0, 0, 16:40])
print('Idx 40-64 (should be from w, val=3):', out[0, 0, 40:64])
print('Idx 64-80 (2nd half t):', out[0, 0, 64:80])
print('Idx 80-104 (2nd half h):', out[0, 0, 80:104])
print('Idx 104-128 (2nd half w):', out[0, 0, 104:128])
print()
print('Full output:', out[0, 0, :])
"
Shape: torch.Size([1, 1, 128])
First 16 elements (should be from t, val=1): tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
device='cuda:0', dtype=torch.bfloat16)
Idx 16-40 (should be from h, val=2): tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2.], device='cuda:0', dtype=torch.bfloat16)
Idx 40-64 (should be from w, val=3): tensor([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
3., 3., 3., 3., 3., 3.], device='cuda:0', dtype=torch.bfloat16)
Idx 64-80 (2nd half t): tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
device='cuda:0', dtype=torch.bfloat16)
Idx 80-104 (2nd half h): tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2.], device='cuda:0', dtype=torch.bfloat16)
Idx 104-128 (2nd half w): tensor([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
3., 3., 3., 3., 3., 3.], device='cuda:0', dtype=torch.bfloat16)
Full output: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2.,
... (201 more chars){
"command": "<1447 chars \u2014 see below>",
"description": "Investigate mix_axes indexing bug"
}uv run python -c "
import sys
sys.path.insert(0, '.')
import reference
import torch
model = reference.Model(1, 4096, 32, 8, 128, (16, 24, 24), 32768)
h = 128 // 2
s0, s1, s2 = 16, 24, 24
# The suspicious line in _mix_axes:
idx = h + s0 + h + s0 + s1 + s2
print(f'Suspicious index: h+s0+h+s0+s1+s2 = {h}+{s0}+{h}+{s0}+{s1}+{s2} = {idx}')
print(f'c_w shape: (1, 1, 128)')
print(f'This should be OOB!')
# Let me test with non-uniform values to see what really happens
c_w = torch.arange(128, dtype=torch.bfloat16).unsqueeze(0).unsqueeze(0).cuda()
print(f'c_w values at that index: would be index {idx} which is OOB for size 128')
# Test if the line broadcasts a scalar
try:
out = torch.empty(1, 1, 128, dtype=torch.bfloat16, device='cuda')
out[..., h+s0+s1:h+s0+s1+s2] = c_w[..., h+s0+h+s0+s1+s2]
print(f'Result of assignment: out[104:128] = {out[0, 0, 104:128]}')
except Exception as e:
print(f'Error: {e}')
# What SHOULD it be?
print(f'Expected: c_w[..., h+s0+s1:h+s0+s1+s2] = c_w[..., {h+s0+s1}:{h+s0+s1+s2}] = c_w[..., 104:128]')
print(f'c_w[104:128] = {c_w[0, 0, 104:128]}')
print()
print(f'Actually assigned index: {idx}, but size is 128')
# Hmm... let me check if PyTorch wraps around
# Actually 208 > 128, but maybe negative indexing or something?
# 208 - 128 = 80, maybe it wraps?
print(f'If wrapped: 208 mod 128 = {208 % 128}, c_w[80] = {c_w[0,0,80]}')
print(f'c_w[..., 208] with wrap = element {208 % 128} = {80}')
"
Suspicious index: h+s0+h+s0+s1+s2 = 64+16+64+16+24+24 = 208
c_w shape: (1, 1, 128)
This should be OOB!
c_w values at that index: would be index 208 which is OOB for size 128
Error: index 208 is out of bounds for dimension 1 with size 128
Expected: c_w[..., h+s0+s1:h+s0+s1+s2] = c_w[..., 104:128] = c_w[..., 104:128]
c_w[104:128] = tensor([104., 105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115.,
116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.],
device='cuda:0', dtype=torch.bfloat16)
Actually assigned index: 208, but size is 128
If wrapped: 208 mod 128 = 80, c_w[80] = 80.0
c_w[..., 208] with wrap = element 80 = 80
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py",
"offset": 64,
"limit": 10
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py</path>
<type>file</type>
<content>
64: out[..., h+s0+s1:h+s0+s1+s2] = c_w[..., h+s0+s1:h+s0+s1+s2]
65: return out
66:
67: def forward(
68: self,
69: q: torch.Tensor, k: torch.Tensor,
70: pos_t: torch.Tensor, pos_h: torch.Tensor, pos_w: torch.Tensor,
71: ) -> tuple[torch.Tensor, torch.Tensor]:
72: # Gather per-axis cos/sin: (B, T, D)
73: cos_t = self.cos_cache[pos_t]
(Showing lines 64-73 of 113. Use offset=74 to continue.)
</content>{
"todos": [
{
"content": "Read reference.py carefully, understand the _mix_axes bug/behavior",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton kernel: gather + mix + rotate + transpose fused",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "pending",
"priority": "high"
},
{
"content": "Run benchmark.py for performance",
"status": "pending",
"priority": "medium"
},
{
"content": "Iterate on performa
... (119 more chars)content: 'Read reference.py carefully, understand the _mix_axes bug/behavior'
status: 'completed'
priority: 'high'
content: 'Implement Triton kernel: gather + mix + rotate + transpose fused'
status: 'in_progress'
priority: 'high'
content: 'Run check.py for correctness'
status: 'pending'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'pending'
priority: 'medium'
content: 'Iterate on performance (tune block sizes, load coalescing, shared mem)'
status: 'pending'
priority: 'medium'{
"command": "uv run python -c \"import torch; print('torch:', torch.__version__); import triton; print('triton:', triton.__version__)\"",
"description": "Check torch and triton versions"
}torch: 2.11.0+cu130
triton: 3.6.0
{
"command": "cat /home/infatoshi/cuda/KernelBench-Hard/src/eval/correctness.py",
"description": "Read correctness check impl"
}"""Per-dtype correctness tolerance.
Stanford's KernelBench uses 1e-4 atol/rtol for fp32 and 1e-2 for fp16/bf16.
We extend to fp8 (0.1, permissive) and int (exact).
"""
from __future__ import annotations
import torch
DEFAULT_TOLERANCE = {
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
torch.float16: {"atol": 1e-2, "rtol": 1e-2},
torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
torch.float8_e4m3fn: {"atol": 1e-1, "rtol": 1e-1},
torch.float8_e5m2: {"atol": 1e-1, "rtol": 1e-1},
torch.int8: {"atol": 0, "rtol": 0},
torch.int32: {"atol": 0, "rtol": 0},
torch.int64: {"atol": 0, "rtol": 0},
}
def _coerce_float(x):
"""Accept int/float/str scientific-notation as float; pass through dicts."""
if isinstance(x, (int, float)):
return float(x)
if isinstance(x, str):
try:
return float(x)
except ValueError:
return x
return x
def tolerance_for_dtype(dtype: torch.dtype, override: dict | None = None) -> dict:
"""Lookup atol/rtol for a given dtype, with optional per-problem override.
The override dict is keyed by str(dtype). Values may be:
- a single number (or a string parseable as float, e.g. "5e-2") -> used
as both atol and rtol
- a dict {"atol": ..., "rtol": ...} -> used directly, with values coerced
"""
# Accept several spellings of the same dtype key (PyYAML quirks):
# "torch.bfloat16", "bfloat16", torch.bfloat16
if override is not None:
type_
... (2072 more chars){
"command": "cat /home/infatoshi/cuda/KernelBench-Hard/src/eval/timing.py",
"description": "Read timing utility"
}"""Centralized GPU timing for benchmark.py files.
All problems' benchmark.py should call `time_fn` from here so we have one
implementation of warmup, L2 flush, and CUDA event capture to maintain.
Methodology:
- 10 warmup calls absorb Triton autotune (typical ~7 configs) and
torch.compile reduce-overhead CUDA-graph capture.
- Between each timed call, `_l2_flush()` writes 128 MB to evict L2
(Blackwell L2 is 96 MB; we want any prior L2 contents flushed so
we measure HBM-load bandwidth, not L2-cached bandwidth).
- GPU timing via cuda.Event with synchronize() AFTER record() but
BEFORE elapsed_time() — the canonical NVIDIA pattern.
- Reported value is the median of `iters` trials, robust to outliers.
Notes / known biases not addressed here:
- torch.compile(mode="reduce-overhead") gets CUDA graphs which eliminate
launch overhead. Custom Triton/CUDA kernels do NOT get this treatment.
On small shapes where launch overhead matters, this gives compile an
artificial advantage. We accept this as the cost of using the agreed
"compiled" baseline policy.
- cuBLAS / cuDNN allocate workspaces on first call. 10 warmup absorbs.
- Median over a fairly small number of trials (default 30) is fine for
headline numbers but won't catch bimodal distributions. Use --extra
instrumentation if you ever care.
"""
from __future__ import annotations
import statistics
import torch
# Blackwell consumer (RTX PRO 6000) L2 is 96 MB. Allocate strictly larg
... (1067 more chars)--- /dev/null
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -0,0 +1,165 @@
+"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
+
+Strategy:
+ 1. Precompute mixed cos/sin: gather per-axis values from cache and mix
+ according to mrope_section. Shape (B, T, D). Done once (shared across heads).
+ 2. Apply rotate-half + transpose in a single fused kernel per tensor (q, k).
+ Grid: (B*T, num_heads). Each block processes all D elements, writes
+ directly into (B, H, T, D) output layout.
+
+This avoids redundant gather+mix across heads and minimizes kernel launches.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+OP_TYPE = "rope"
+SUPPORTED_PRECISIONS = ["bf16"]
+HARDWARE_REQUIRED = ["RTX_PRO_6000"]
+
+# ---------------------------------------------------------------------------
+# Triton kernels
+# ---------------------------------------------------------------------------
+
+@triton.jit
+def _mix_cos_sin_kernel(
+ cos_cache_ptr, sin_cache_ptr,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ cos_out_ptr, sin_out_ptr,
+ B, T, D, max_pos,
+ s0, s1, s2, half_D,
+ stride_pos_bt,
+ BLOCK_D: tl.constexpr,
+):
+ """Gather per-axis cos/sin, mix by mrope_section, write (B,T,D).
+
+ Grid: (B*T,) one program per (b,t).
+ """
+ pid = tl.program_id(0)
+ b = pid // T
+ t = pid % T
+
+ pos_base = b * stride_pos_bt + t
+ pt = tl.load(pos_t_ptr + pos_base)
+ ph = tl.load(pos_h_ptr + pos_base)
+ pw = tl.load(pos_w_ptr + pos_base)
+
+ d_offsets = tl.arange(0, BLOCK_D)
+ mask = d_offsets < D
+
+ # Load cos values: cos_cache[pos, :] for all D columns
+ cos_t_row = pt * max_pos
+ cos_h_row = ph * max_pos
+ cos_w_row = pw * max_pos
+ max_cos_row = max_pos * D # not needed, we use row * D + d
+
+ cos_t = tl.load(cos_cache_ptr + cos_t_row + d_offsets, mask=mask)
+ cos_h = tl.load(cos_cache_ptr + cos_h_row + d_offsets, mask=mask)
+ cos_w = tl.load(cos_cache_ptr + cos_w_row + d_offsets, mask=mask)
+
+ sin_t_row = pt * max_pos
+ sin_h_row = ph * max_pos
+ sin_w_row = pw * max_pos
+
+ sin_t = tl.load(sin_cache_ptr + sin_t_row + d_offsets, mask=mask)
+ sin_h = tl.load(sin_cache_ptr + sin_h_row + d_offsets, mask=mask)
+ sin_w = tl.load(sin_cache_ptr + sin_w_row + d_offsets, mask=mask)
+
+ # Mix: select per-index based on mrope_section
+ # First half [0:half_D]: [0:s0)->t, [s0:s0+s1)->h, [s0+s1:D/2)->w
+ # Second half [half_D:D]: [half_D:half_D+s0)->t, [half_D+s0:half_D+s0+s1)->h, [half_D+s0+s1:D)->w
+ cut1 = s0
+ cut2 = s0 + s1
+ cut3 = half_D
+ cut4 = half_D + s0
+ cut5 = half_D + s0 + s1
+
+ cos_mixed = tl.where(
+ d_offsets < cut1, cos_t,
+ tl.where(d_offsets < cut2, cos_h,
+ tl.where(d_offsets < cut3, cos_w,
+ tl.where(d_offsets < cut4, cos_t,
+ tl.where(d_offsets < cut5, cos_h, cos_w)))))
+
+ sin_mixed = tl.where(
+ d_offsets < cut1, sin_t,
+ tl.where(d_offsets < cut2, sin_h,
+ tl.where(d_offsets < cut3, sin_w,
+ tl.where(d_offsets < cut4, sin_t,
+ tl.where(d_offsets < cut5, sin_h, sin_w)))))
+
+ out_base = pid * D
+ tl.store(cos_out_ptr + out_base + d_offsets, cos_mixed, mask=mask)
+ tl.store(sin_out_ptr + out_base + d_offsets, sin_mixed, mask=mask)
+
+
+@triton.jit
+def _rope_transpose_kernel(
+ x_ptr, cos_ptr, sin_ptr, out_ptr,
+ B, T, H, D, half_D,
+ # Input strides: x is (B, T, H, D)
+ stride_xb, stride_xt, stride_xh, stride_xd,
+ # cos/sin strides: (B, T, D)
+ stride_cb, stride_ct, stride_cd,
+ # Output strides: (B, H, T, D)
+ stride_ob, stride_oh, stride_ot, stride_od,
+ BLOCK_D: tl.constexpr,
+):
+ """Apply rotate-half with pre-mixed cos/sin and write transposed output.
+
+ Grid: (B*T, H) — one program per (b, t, h).
+ """
+ pid_bt = tl.program_id(0)
+ pid_h = tl.program_id(1)
+
+ b = pid_rt // T
+ t = pid_rt % T
+
+ # Load cos/sin for this (b, t) — already mixed, shape (B*T, D)
+ cos_sin_base = pid_rt * D
+ d_offsets = tl.arange(0, BLOCK_D)
+ mask = d_offsets < D
+
+ cos_vals = tl.load(cos_ptr + cos_sin_base + d_offsets, mask=mask)
+ sin_vals = tl.load(sin_ptr + cos_sin_base + d_offsets, mask=mask)
+
+ # First half of D
+ half_mask = d_offsets < half_D
+ d_first = tl.where(half_mask, d_offsets, d_offsets) # identity, just for loading
+
+ # Load x values
+ x_base = b * stride_xb + t * stride_xt + pid_h * stride_xh
+
+ # First half [0: half_D)
+ first_half_mask = half_mask
+ x_first = tl.load(x_ptr + x_base + d_offsets, mask=first_half_mask)
+
+ # Second half [half_D: D): these are d - half_D
+ d_second = d_offsets - half_D
+ second_half_mask = d_offsets >= half_D
+ x_second = tl.load(x_ptr + x_base + d_second, mask=second_half_mask)
+ # rotate_half: second half goes to first half positions, negated
+
+ # Apply rotation for first half:
+ # out[d] = x[d] * cos[d] + x[d + half_D] * sin[d]
+ # out[d + half_D] = x[d + half_D] * cos[d + half_D] - x[d] * sin[d + half_D]
+ # Note: cos/sin are duplicated (freqs || freqs), so cos[d] == cos[d + half_D] typically
+
+ x_d = x_first
+ x_d_half = x_second # this is x[d + half_D]
+
+ # First half output
+ out_first = x_d * cos_vals + x_d_half * sin_vals
+ # Second half output
+ out_second = x_d_half * cos_vals - x_d * sin_vals
+
+ # Write in transposed layout (B, H, T, D)
+ out_base = b * stride_ob + pid_h * stride_oh + t * stride_ot
+
+ # First half
+ tl.store(out_ptr + out_base + d_offsets, out_first, mask=first_half_mask)
+ # Second half
+ tl.store(out_ptr + out_base + d_offsets, out_second, mask=second_half_mask)
{
"content": "<5600 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -3,11 +3,9 @@
Strategy:
1. Precompute mixed cos/sin: gather per-axis values from cache and mix
according to mrope_section. Shape (B, T, D). Done once (shared across heads).
- 2. Apply rotate-half + transpose in a single fused kernel per tensor (q, k).
+ 2. Fused rotate-half + transpose kernel per tensor (q, k).
Grid: (B*T, num_heads). Each block processes all D elements, writes
directly into (B, H, T, D) output layout.
-
-This avoids redundant gather+mix across heads and minimizes kernel launches.
"""
from __future__ import annotations
@@ -34,9 +32,9 @@
stride_pos_bt,
BLOCK_D: tl.constexpr,
):
- """Gather per-axis cos/sin, mix by mrope_section, write (B,T,D).
-
- Grid: (B*T,) one program per (b,t).
+ """Gather per-axis cos/sin from cache, mix by mrope_section, write (B,T,D).
+
+ Grid: (B*T,) -- one program per (b,t).
"""
pid = tl.program_id(0)
b = pid // T
@@ -47,119 +45,213 @@
ph = tl.load(pos_h_ptr + pos_base)
pw = tl.load(pos_w_ptr + pos_base)
- d_offsets = tl.arange(0, BLOCK_D)
- mask = d_offsets < D
-
- # Load cos values: cos_cache[pos, :] for all D columns
- cos_t_row = pt * max_pos
- cos_h_row = ph * max_pos
- cos_w_row = pw * max_pos
- max_cos_row = max_pos * D # not needed, we use row * D + d
-
- cos_t = tl.load(cos_cache_ptr + cos_t_row + d_offsets, mask=mask)
- cos_h = tl.load(cos_cache_ptr + cos_h_row + d_offsets, mask=mask)
- cos_w = tl.load(cos_cache_ptr + cos_w_row + d_offsets, mask=mask)
-
- sin_t_row = pt * max_pos
- sin_h_row = ph * max_pos
- sin_w_row = pw * max_pos
-
- sin_t = tl.load(sin_cache_ptr + sin_t_row + d_offsets, mask=mask)
- sin_h = tl.load(sin_cache_ptr + sin_h_row + d_offsets, mask=mask)
- sin_w = tl.load(sin_cache_ptr + sin_w_row + d_offsets, mask=mask)
-
- # Mix: select per-index based on mrope_section
- # First half [0:half_D]: [0:s0)->t, [s0:s0+s1)->h, [s0+s1:D/2)->w
- # Second half [half_D:D]: [half_D:half_D+s0)->t, [half_D+s0:half_D+s0+s1)->h, [half_D+s0+s1:D)->w
- cut1 = s0
- cut2 = s0 + s1
- cut3 = half_D
- cut4 = half_D + s0
- cut5 = half_D + s0 + s1
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ # Compute row offsets into (max_pos, D) cache
+ row_t = pt.to(tl.int64) * D
+ row_h = ph.to(tl.int64) * D
+ row_w = pw.to(tl.int64) * D
+
+ cos_t = tl.load(cos_cache_ptr + row_t + d_offs, mask=mask)
+ cos_h = tl.load(cos_cache_ptr + row_h + d_offs, mask=mask)
+ cos_w = tl.load(cos_cache_ptr + row_w + d_offs, mask=mask)
+
+ sin_t = tl.load(sin_cache_ptr + row_t + d_offs, mask=mask)
+ sin_h = tl.load(sin_cache_ptr + row_h + d_offs, mask=mask)
+ sin_w = tl.load(sin_cache_ptr + row_w + d_offs, mask=mask)
+
+ # Mix: choose per-index axis according to mrope_section
+ # [0:s0) -> t, [s0:s0+s1) -> h, [s0+s1:half_D) -> w
+ # [half_D:half_D+s0) -> t, [half_D+s0:half_D+s0+s1) -> h, [half_D+s0+s1:D) -> w
+ cut0_s1 = s0 + s1
cos_mixed = tl.where(
- d_offsets < cut1, cos_t,
- tl.where(d_offsets < cut2, cos_h,
- tl.where(d_offsets < cut3, cos_w,
- tl.where(d_offsets < cut4, cos_t,
- tl.where(d_offsets < cut5, cos_h, cos_w)))))
+ d_offs < s0, cos_t,
+ tl.where(d_offs < cut0_s1, cos_h,
+ tl.where(d_offs < half_D, cos_w,
+ tl.where(d_offs < half_D + s0, cos_t,
+ tl.where(d_offs < half_D + cut0_s1, cos_h, cos_w)))))
sin_mixed = tl.where(
- d_offsets < cut1, sin_t,
- tl.where(d_offsets < cut2, sin_h,
- tl.where(d_offsets < cut3, sin_w,
- tl.where(d_offsets < cut4, sin_t,
- tl.where(d_offsets < cut5, sin_h, sin_w)))))
+ d_offs < s0, sin_t,
+ tl.where(d_offs < cut0_s1, sin_h,
+ tl.where(d_offs < half_D, sin_w,
+ tl.where(d_offs < half_D + s0, sin_t,
+ tl.where(d_offs < half_D + cut0_s1, sin_h, sin_w)))))
out_base = pid * D
- tl.store(cos_out_ptr + out_base + d_offsets, cos_mixed, mask=mask)
- tl.store(sin_out_ptr + out_base + d_offsets, sin_mixed, mask=mask)
+ tl.store(cos_out_ptr + out_base + d_offs, cos_mixed, mask=mask)
+ tl.store(sin_out_ptr + out_base + d_offs, sin_mixed, mask=mask)
@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
B, T, H, D, half_D,
- # Input strides: x is (B, T, H, D)
- stride_xb, stride_xt, stride_xh, stride_xd,
- # cos/sin strides: (B, T, D)
- stride_cb, stride_ct, stride_cd,
- # Output strides: (B, H, T, D)
- stride_ob, stride_oh, stride_ot, stride_od,
+ stride_x_b, stride_x_t, stride_x_h,
+ stride_cos_b, stride_cos_t, stride_cos_d,
+ stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
- """Apply rotate-half with pre-mixed cos/sin and write transposed output.
-
- Grid: (B*T, H) — one program per (b, t, h).
+ """Apply rotate-half with pre-mixed cos/sin, write transposed to (B,H,T,D).
+
+ Grid: (B*T, H) -- program 0 gives (b,t), program 1 gives h.
+
+ rotate_half(x): [-x[half_D:D], x[0:half_D]]
+ out = x * cos + rotate_half(x) * sin
+
+ For d in [0, half_D):
+ out[d] = x[d]*cos[d] - x[d+half_D]*sin[d]
+ For d in [half_D, D): (let d'=d-half_D)
+ out[d] = x[d]*cos[d'] + x[d']*sin[d']
"""
pid_bt = tl.program_id(0)
- pid_h = tl.program_id(1)
-
- b = pid_rt // T
- t = pid_rt % T
-
- # Load cos/sin for this (b, t) — already mixed, shape (B*T, D)
- cos_sin_base = pid_rt * D
- d_offsets = tl.arange(0, BLOCK_D)
- mask = d_offsets < D
-
- cos_vals = tl.load(cos_ptr + cos_sin_base + d_offsets, mask=mask)
- sin_vals = tl.load(sin_ptr + cos_sin_base +
... (6935 more chars){
"content": "<8919 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"todos": [
{
"content": "Read reference.py carefully, understand the _mix_axes bug/behavior",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton kernel: gather + mix + rotate + transpose fused",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "in_progress",
"priority": "high"
},
{
"content": "Run benchmark.py for performance",
"status": "pending",
"priority": "medium"
},
{
"content": "Iterate on perfor
... (121 more chars)content: 'Read reference.py carefully, understand the _mix_axes bug/behavior'
status: 'completed'
priority: 'high'
content: 'Implement Triton kernel: gather + mix + rotate + transpose fused'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'in_progress'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'pending'
priority: 'medium'
content: 'Iterate on performance (tune block sizes, load coalescing, shared mem)'
status: 'pending'
priority: 'medium'{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run check.py for correctness"
}FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 q_rot: tolerance exceeded: max_abs_diff=4.42188 (atol=0.02, rtol=0.02)
{
"command": "<1695 chars \u2014 see below>",
"description": "Debug solution correctness"
}uv run python -c "
import torch
import sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
mrope_section = (16, 24, 24)
max_pos = 1024
torch.manual_seed(42)
ref_model = reference.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model = solution.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, Hkv, T, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
ref_q, ref_k = ref_model(q, k, pos_t, pos_h, pos_w)
sol_q, sol_k = sol_model(q, k, pos_t, pos_h, pos_w)
print('Ref q_rot shape:', ref_q.shape, 'sol q_rot shape:', sol_q.shape)
print('Max diff q:', (ref_q.float() - sol_q.float()).abs().max().item())
print('Ref k_rot shape:', ref_k.shape, 'sol k_rot shape:', sol_k.shape)
print('Max diff k:', (ref_k.float() - sol_k.float()).abs().max().item())
# Check a specific element
print()
b, t, h = 0, 0, 0
print(f'Ref q[0,0,0,:8]:', ref_q[b, h, t, :8])
print(f'Sol q[0,0,0,:8]:', sol_q[b, h, t, :8])
print(f'Diff:', (ref_q[b, h, t, :8].float() - sol_q[b, h, t, :8].float()).abs())
# Also check intermediate cos/sin mixing
ref_cos_t = ref_model.cos_cache[pos_t] # (B, T, D)
ref_sin_t = ref_model.sin_cache[pos_t]
print()
print('Ref cos_cache[pos_t] shape:', ref_cos_t.shape)
print('cos_cache[pos_t][0,0,:8]:', ref_cos_t[0, 0, :8])
" 2>&1
Traceback (most recent call last):
File "<string>", line 21, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/reference.py", line 87, in forward
k_rot = (k * cos) + (_rotate_half(k) * sin)
~~^~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (64) at non-singleton dimension 1
{
"command": "<1522 chars \u2014 see below>",
"description": "Debug solution with correct shapes"
}uv run python -c "
import torch
import sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
mrope_section = (16, 24, 24)
max_pos = 1024
torch.manual_seed(42)
ref_model = reference.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model = solution.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
ref_q, ref_k = ref_model(q, k, pos_t, pos_h, pos_w)
sol_q, sol_k = sol_model(q, k, pos_t, pos_h, pos_w)
print('Ref q_rot shape:', ref_q.shape, 'sol q_rot shape:', sol_q.shape)
print('Max diff q:', (ref_q.float() - sol_q.float()).abs().max().item())
print()
# Specific element
b, t, h = 0, 0, 0
print(f'Ref q[b={b},h={h},t={t},:16]:', ref_q[b, h, t, :16])
print(f'Sol q[b={b},h={h},t={t},:16]:', sol_q[b, h, t, :16])
print(f'Diff:', (ref_q[b, h, t, :16].float() - sol_q[b, h, t, :16].float()).abs())
print()
print(f'Ref q[b,h,t,64:80]:', ref_q[b, h, t, 64:80])
print(f'Sol q[b,h,t,64:80]:', sol_q[b, h, t, 64:80])
print(f'Diff:', (ref_q[b, h, t, 64:80].float() - sol_q[b, h, t, 64:80].float()).abs())
" 2>&1
Ref q_rot shape: torch.Size([1, 4, 64, 128]) sol q_rot shape: torch.Size([1, 4, 64, 128])
Max diff q: 4.0
Ref q[b=0,h=0,t=0,:16]: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4902, -0.2324, -0.1426, -0.2715,
-0.0479, -0.0664, 0.1670, 0.1377, -0.0918, -0.2275, 0.1357, 0.5820],
device='cuda:0', dtype=torch.bfloat16)
Sol q[b=0,h=0,t=0,:16]: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4922, -0.2324, -0.1426, -0.2695,
-0.0474, -0.0684, 0.1670, 0.1377, -0.0918, -0.2275, 0.1367, 0.5820],
device='cuda:0', dtype=torch.bfloat16)
Diff: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0020, 0.0000, 0.0000, 0.0020, 0.0005,
0.0020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0010, 0.0000],
device='cuda:0')
Ref q[b,h,t,64:80]: tensor([-0.3223, -0.6445, -0.2617, -0.3848, -0.8438, -0.3457, 0.2969, 0.3320,
-0.3145, 1.3047, 0.8164, 0.7969, 0.5742, -0.4297, 0.5859, 0.8203],
device='cuda:0', dtype=torch.bfloat16)
Sol q[b,h,t,64:80]: tensor([-0.3223, -0.6445, -0.2617, -0.3848, -0.8477, -0.3457, 0.2949, 0.3320,
-0.3145, 1.3047, 0.8164, 0.7969, 0.5742, -0.4297, 0.5859, 0.8203],
device='cuda:0', dtype=torch.bfloat16)
Diff: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0020, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
device='cuda:0')
{
"command": "<1887 chars \u2014 see below>",
"description": "Find big diff positions"
}uv run python -c "
import torch
import sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
mrope_section = (16, 24, 24)
max_pos = 1024
torch.manual_seed(42)
ref_model = reference.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model = solution.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
sol_model.load_state_dict(ref_model.state_dict(), strict=True)
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
ref_q, ref_k = ref_model(q, k, pos_t, pos_h, pos_w)
sol_q, sol_k = sol_model(q, k, pos_t, pos_h, pos_w)
# Find where the big diffs are
diff_q = (ref_q.float() - sol_q.float()).abs()
diff_k = (ref_k.float() - sol_k.float()).abs()
# Find indices where diff > 1.0
big_q = (diff_q > 1.0).nonzero()
big_k = (diff_k > 1.0).nonzero()
print(f'big diff positions in q: {big_q.shape[0]}')
print(f'big diff positions in k: {big_k.shape[0]}')
if big_q.shape[0] > 0:
for idx in big_q[:5]:
b, h, t, d = idx.tolist()
print(f' q[{b},{h},{t},{d}]: ref={ref_q[b,h,t,d]} sol={sol_q[b,h,t,d]} diff={diff_q[b,h,t,d]}')
if big_k.shape[0] > 0:
for idx in big_k[:5]:
b, h, t, d = idx.tolist()
print(f' k[{b},{h},{t},{d}]: ref={ref_k[b,h,t,d]} sol={sol_k[b,h,t,d]} diff={diff_k[b,h,t,d]}')
# Overall stats
print(f'Max q diff at: b,h,t,d = ', end='')
idx = diff_q.argmax().item()
idx_flat = diff_q.flatten()
flat_idx = torch.unravel_index(idx, diff_q.shape)
print(flat_idx, f'ref={ref_q[flat_idx]} sol={sol_q[flat_idx]} diff={diff_q[flat_idx]}')
" 2>&1
Traceback (most recent call last):
File "<string>", line 48, in <module>
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/functional.py", line 1962, in unravel_index
res_tensor = _unravel_index(indices, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/infatoshi/cuda/KernelBench-Hard/.venv/lib/python3.11/site-packages/torch/functional.py", line 1968, in _unravel_index
not indices.is_complex()
^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'is_complex'
big diff positions in q: 4386
big diff positions in k: 2168
q[0,0,1,15]: ref=0.859375 sol=-0.81640625 diff=1.67578125
q[0,0,1,18]: ref=1.15625 sol=-0.126953125 diff=1.283203125
q[0,0,1,24]: ref=0.404296875 sol=-1.265625 diff=1.669921875
q[0,0,1,37]: ref=0.5078125 sol=-1.140625 diff=1.6484375
q[0,0,1,38]: ref=0.62890625 sol=-0.404296875 diff=1.033203125
k[0,0,1,2]: ref=0.78515625 sol=-0.69140625 diff=1.4765625
k[0,0,1,3]: ref=-0.486328125 sol=0.5625 diff=1.048828125
k[0,0,1,7]: ref=1.15625 sol=-1.046875 diff=2.203125
k[0,0,1,10]: ref=-0.6328125 sol=1.171875 diff=1.8046875
k[0,0,1,13]: ref=-0.890625 sol=0.6796875 diff=1.5703125
Max q diff at: b,h,t,d = {
"command": "<1965 chars \u2014 see below>",
"description": "Trace the indexing bug"
}uv run python -c "
import torch
import sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B * T
mrope_section = (16, 24, 24)
max_pos = 1024
half_D = D // 2
torch.manual_seed(42)
ref_model = reference.Model(B, T, Hq, Hkv, D, mrope_section, max_pos).to(device).eval()
q = torch.randn(B, T, Hq, D, dtype=torch.bfloat16, device=device) * 0.5
k = torch.randn(B, T, Hkv, D, dtype=torch.bfloat16, device=device) * 0.5
pos_t = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, max_pos, (B, T), dtype=torch.int64, device=device)
# Manually run the mix kernel
import triton.language as tl
s0, s1, s2 = mrope_section
BLOCK_D = 128
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref_model.cos_cache, ref_model.sin_cache,
pos_t, pos_h, pos_w,
cos_mixed, sin_mixed,
B, T, D, max_pos,
s0, s1, s2, half_D,
pos_t.stride(0),
BLOCK_D=BLOCK_D,
)
print('cos_mixed shape:', cos_mixed.shape, 'strides:', cos_mixed.stride())
print('cos_mixed.stride(0):', cos_mixed.stride(0))
print('cos_mixed.stride(1):', cos_mixed.stride(1))
# What the kernel sees for the rope kernel:
# cos_base = b * stride_cos_b + t * stride_cos_t
# In the rope kernel, we pass: cos_mixed.stride(0), cos_mixed.stride(1), cos_mixed.stride(-1)
# For (BT, D) tensor: stride(0) = D, stride(1) = 1
# So: b * D + t * 1 = b*D + t. But we want: (b*T + t)*D = b*T*D + t*D
# For b=0, t=1: kernel computes 0*128 + 1*1 = 1
# Should be: (0*64 + 1)*128 = 128
print()
print('For b=0, t=1:')
print(f' Kernel would compute cos_base = 0*{D} + 1*1 = 1')
print(f' Should be: (0*{T} + 1)*{D} = {1*D}')
print(f' cos_mixed[0,:4] = {cos_mixed[0,:4]}')
print(f' cos_mixed[1,:4] = {cos_mixed[1,:4]}')
" 2>&1
cos_mixed shape: torch.Size([64, 128]) strides: (128, 1)
cos_mixed.stride(0): 128
cos_mixed.stride(1): 1
For b=0, t=1:
Kernel would compute cos_base = 0*128 + 1*1 = 1
Should be: (0*64 + 1)*128 = 128
cos_mixed[0,:4] = tensor([-0.7109, -0.5117, -0.9883, -0.0603], device='cuda:0',
dtype=torch.bfloat16)
cos_mixed[1,:4] = tensor([-0.0486, -0.4453, -0.9336, -1.0000], device='cuda:0',
dtype=torch.bfloat16)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Fused multi-axis RoPE pre-attention prep via Triton kernels.
2:
3: Strategy:
4: 1. Precompute mixed cos/sin: gather per-axis values from cache and mix
5: according to mrope_section. Shape (B, T, D). Done once (shared across heads).
6: 2. Fused rotate-half + transpose kernel per tensor (q, k).
7: Grid: (B*T, num_heads). Each block processes all D elements, writes
8: directly into (B, H, T, D) output layout.
9: """
10: from __future__ import annotations
11:
12: import torch
13: import torch.nn as nn
14: import triton
15: import triton.language as tl
16:
17: OP_TYPE = "rope"
18: SUPPORTED_PRECISIONS = ["bf16"]
19: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
20:
21: # ---------------------------------------------------------------------------
22: # Triton kernels
23: # ---------------------------------------------------------------------------
24:
25: @triton.jit
26: def _mix_cos_sin_kernel(
27: cos_cache_ptr, sin_cache_ptr,
28: pos_t_ptr, pos_h_ptr, pos_w_ptr,
29: cos_out_ptr, sin_out_ptr,
30: B, T, D, max_pos,
31: s0, s1, s2, half_D,
32: stride_pos_bt,
33: BLOCK_D: tl.constexpr,
34: ):
35: """Gather per-axis cos/sin from cache, mix by mrope_section, write (B,T,D).
36:
37: Grid: (B*T,) -- one program per (b,t).
38: """
39: pid = tl.program_id(0)
40: b = pid // T
41: t = pid % T
42:
... (8761 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -90,13 +90,14 @@
x_ptr, cos_ptr, sin_ptr, out_ptr,
B, T, H, D, half_D,
stride_x_b, stride_x_t, stride_x_h,
- stride_cos_b, stride_cos_t, stride_cos_d,
+ stride_cos_bt,
stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
"""Apply rotate-half with pre-mixed cos/sin, write transposed to (B,H,T,D).
Grid: (B*T, H) -- program 0 gives (b,t), program 1 gives h.
+ cos/sin layout: (B*T, D) flat, indexed by pid_bt * D.
rotate_half(x): [-x[half_D:D], x[0:half_D]]
out = x * cos + rotate_half(x) * sin
@@ -112,11 +113,11 @@
b = pid_bt // T
t = pid_bt % T
- # Load cos/sin for this (b,t)
+ # Load cos/sin for this (b,t) -- cos_mixed is (B*T, D) flat
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
- cos_base = b * stride_cos_b + t * stride_cos_t
+ cos_base = pid_bt * stride_cos_bt
cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<1046 chars \u2014 see below>",
"newString": "<1094 chars \u2014 see below>"
}@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
B, T, H, D, half_D,
stride_x_b, stride_x_t, stride_x_h,
stride_cos_b, stride_cos_t, stride_cos_d,
stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
"""Apply rotate-half with pre-mixed cos/sin, write transposed to (B,H,T,D).
Grid: (B*T, H) -- program 0 gives (b,t), program 1 gives h.
rotate_half(x): [-x[half_D:D], x[0:half_D]]
out = x * cos + rotate_half(x) * sin
For d in [0, half_D):
out[d] = x[d]*cos[d] - x[d+half_D]*sin[d]
For d in [half_D, D): (let d'=d-half_D)
out[d] = x[d]*cos[d'] + x[d']*sin[d']
"""
pid_bt = tl.program_id(0)
pid_h = tl.program_id(1)
b = pid_bt // T
t = pid_bt % T
# Load cos/sin for this (b,t)
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
cos_base = b * stride_cos_b + t * stride_cos_t
cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
B, T, H, D, half_D,
stride_x_b, stride_x_t, stride_x_h,
stride_cos_bt,
stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
"""Apply rotate-half with pre-mixed cos/sin, write transposed to (B,H,T,D).
Grid: (B*T, H) -- program 0 gives (b,t), program 1 gives h.
cos/sin layout: (B*T, D) flat, indexed by pid_bt * D.
rotate_half(x): [-x[half_D:D], x[0:half_D]]
out = x * cos + rotate_half(x) * sin
For d in [0, half_D):
out[d] = x[d]*cos[d] - x[d+half_D]*sin[d]
For d in [half_D, D): (let d'=d-half_D)
out[d] = x[d]*cos[d'] + x[d']*sin[d']
"""
pid_bt = tl.program_id(0)
pid_h = tl.program_id(1)
b = pid_bt // T
t = pid_bt % T
# Load cos/sin for this (b,t) -- cos_mixed is (B*T, D) flat
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
cos_base = pid_bt * stride_cos_bt
cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -215,7 +215,7 @@
q, cos_mixed, sin_mixed, q_rot,
B, T, Hq, D, half_D,
q.stride(0), q.stride(1), q.stride(2),
- cos_mixed.stride(0), cos_mixed.stride(1), cos_mixed.stride(-1) if cos_mixed.dim() > 1 else 1,
+ cos_mixed.stride(0),
q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
BLOCK_D=BLOCK_D,
)
@@ -227,7 +227,7 @@
k, cos_mixed, sin_mixed, k_rot,
B, T, Hkv, D, half_D,
k.stride(0), k.stride(1), k.stride(2),
- cos_mixed.stride(0), cos_mixed.stride(1), cos_mixed.stride(-1),
+ cos_mixed.stride(0),
k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
BLOCK_D=BLOCK_D,
)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<1045 chars \u2014 see below>",
"newString": "<929 chars \u2014 see below>"
} # Step 2: fused rotate-half + transpose for q
q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
grid_q = (BT, Hq)
_rope_transpose_kernel[grid_q](
q, cos_mixed, sin_mixed, q_rot,
B, T, Hq, D, half_D,
q.stride(0), q.stride(1), q.stride(2),
cos_mixed.stride(0), cos_mixed.stride(1), cos_mixed.stride(-1) if cos_mixed.dim() > 1 else 1,
q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
BLOCK_D=BLOCK_D,
)
# Step 3: fused rotate-half + transpose for k
k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
grid_k = (BT, Hkv)
_rope_transpose_kernel[grid_k](
k, cos_mixed, sin_mixed, k_rot,
B, T, Hkv, D, half_D,
k.stride(0), k.stride(1), k.stride(2),
cos_mixed.stride(0), cos_mixed.stride(1), cos_mixed.stride(-1),
k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
BLOCK_D=BLOCK_D,
) # Step 2: fused rotate-half + transpose for q
q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
grid_q = (BT, Hq)
_rope_transpose_kernel[grid_q](
q, cos_mixed, sin_mixed, q_rot,
B, T, Hq, D, half_D,
q.stride(0), q.stride(1), q.stride(2),
cos_mixed.stride(0),
q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
BLOCK_D=BLOCK_D,
)
# Step 3: fused rotate-half + transpose for k
k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
grid_k = (BT, Hkv)
_rope_transpose_kernel[grid_k](
k, cos_mixed, sin_mixed, k_rot,
B, T, Hkv, D, half_D,
k.stride(0), k.stride(1), k.stride(2),
cos_mixed.stride(0),
k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
BLOCK_D=BLOCK_D,
)Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run check.py after fix"
}PASS
{
"todos": [
{
"content": "Read reference.py carefully, understand the _mix_axes bug/behavior",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton kernel: gather + mix + rotate + transpose fused",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py for performance",
"status": "in_progress",
"priority": "medium"
},
{
"content": "Iterate on perf
... (123 more chars)content: 'Read reference.py carefully, understand the _mix_axes bug/behavior'
status: 'completed'
priority: 'high'
content: 'Implement Triton kernel: gather + mix + rotate + transpose fused'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'in_progress'
priority: 'medium'
content: 'Iterate on performance (tune block sizes, load coalescing, shared mem)'
status: 'pending'
priority: 'medium'{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run benchmark.py for performance"
}shape=0 variant=eager tflops=0.178 gbps=255.120 ms=0.354
shape=0 variant=compiled tflops=0.493 gbps=707.846 ms=0.128
shape=0 variant=solution tflops=0.584 gbps=838.621 ms=0.108
shape=0 solution_peak_fraction=0.4659
shape=1 variant=eager tflops=0.176 gbps=256.414 ms=0.573
shape=1 variant=compiled tflops=0.493 gbps=719.728 ms=0.204
shape=1 variant=solution tflops=0.657 gbps=959.713 ms=0.153
shape=1 solution_peak_fraction=0.5332
shape=2 variant=eager tflops=0.074 gbps=114.894 ms=0.193
shape=2 variant=compiled tflops=0.168 gbps=263.164 ms=0.084
shape=2 variant=solution tflops=0.214 gbps=334.236 ms=0.066
shape=2 solution_peak_fraction=0.1857
shape=3 variant=eager tflops=0.177 gbps=254.275 ms=1.420
shape=3 variant=compiled tflops=0.460 gbps=659.873 ms=0.547
shape=3 variant=solution tflops=0.792 gbps=1136.117 ms=0.318
shape=3 solution_peak_fraction=0.6312
peak_fraction: 0.4131
RESULT: OK
{
"todos": [
{
"content": "Read reference.py carefully, understand the _mix_axes bug/behavior",
"status": "completed",
"priority": "high"
},
{
"content": "Implement Triton kernel: gather + mix + rotate + transpose fused",
"status": "completed",
"priority": "high"
},
{
"content": "Run check.py for correctness",
"status": "completed",
"priority": "high"
},
{
"content": "Run benchmark.py for performance",
"status": "completed",
"priority": "medium"
},
{
"content": "Iterate: fuse gat
... (407 more chars)content: 'Read reference.py carefully, understand the _mix_axes bug/behavior'
status: 'completed'
priority: 'high'
content: 'Implement Triton kernel: gather + mix + rotate + transpose fused'
status: 'completed'
priority: 'high'
content: 'Run check.py for correctness'
status: 'completed'
priority: 'high'
content: 'Run benchmark.py for performance'
status: 'completed'
priority: 'medium'
content: 'Iterate: fuse gather+mix into rotation kernel (eliminate intermediate)'
status: 'in_progress'
priority: 'high'
content: 'Iterate: share cos/sin gather across heads in single kernel'
status: 'pending'
priority: 'medium'
content: 'Profile with ncu and optimize coalescing/block sizes'
status: 'pending'
priority: 'low'--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,11 +1,14 @@
"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
-Strategy:
- 1. Precompute mixed cos/sin: gather per-axis values from cache and mix
- according to mrope_section. Shape (B, T, D). Done once (shared across heads).
- 2. Fused rotate-half + transpose kernel per tensor (q, k).
- Grid: (B*T, num_heads). Each block processes all D elements, writes
- directly into (B, H, T, D) output layout.
+Fully-fused single-kernel approach:
+ - Grid: (B*T,) -- one program per (b,t) position.
+ - Each program processes ALL q heads and ALL k heads.
+ - Position IDs loaded once, cos/sin gathered+mixed once, shared across heads.
+ - Writes directly into transposed (B,H,T,D) output layout.
+
+This eliminates all intermediate tensors and maximally shares the cos/sin
+bandwidth across heads. Memory traffic per (b,t): 2*D (q) + 2*D (k) once,
+plus ~2*D (x cos/sin) per axis = 6*D gather reads, then 2*D output writes.
"""
from __future__ import annotations
@@ -23,35 +26,22 @@
# ---------------------------------------------------------------------------
@triton.jit
-def _mix_cos_sin_kernel(
+def _mix_cos_sin(
cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- cos_out_ptr, sin_out_ptr,
- B, T, D, max_pos,
- s0, s1, s2, half_D,
- stride_pos_bt,
+ pos_t, pos_h, pos_w,
+ D, max_pos, half_D, s0, s1, s2,
+ cos_vals, sin_vals,
BLOCK_D: tl.constexpr,
):
- """Gather per-axis cos/sin from cache, mix by mrope_section, write (B,T,D).
-
- Grid: (B*T,) -- one program per (b,t).
- """
- pid = tl.program_id(0)
- b = pid // T
- t = pid % T
-
- pos_base = b * stride_pos_bt + t
- pt = tl.load(pos_t_ptr + pos_base)
- ph = tl.load(pos_h_ptr + pos_base)
- pw = tl.load(pos_w_ptr + pos_base)
-
+ """Gather per-axis cos/sin from 3 position IDs, mix by mrope_section.
+ Called inline per (b,t). Produces cos_vals and sin_vals arrays of length D.
+ pos_t, pos_h, pos_w are scalar position IDs."""
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
- # Compute row offsets into (max_pos, D) cache
- row_t = pt.to(tl.int64) * D
- row_h = ph.to(tl.int64) * D
- row_w = pw.to(tl.int64) * D
+ row_t = pos_t.to(tl.int64) * D
+ row_h = pos_h.to(tl.int64) * D
+ row_w = pos_w.to(tl.int64) * D
cos_t = tl.load(cos_cache_ptr + row_t + d_offs, mask=mask)
cos_h = tl.load(cos_cache_ptr + row_h + d_offs, mask=mask)
@@ -61,102 +51,139 @@
sin_h = tl.load(sin_cache_ptr + row_h + d_offs, mask=mask)
sin_w = tl.load(sin_cache_ptr + row_w + d_offs, mask=mask)
- # Mix: choose per-index axis according to mrope_section
- # [0:s0) -> t, [s0:s0+s1) -> h, [s0+s1:half_D) -> w
- # [half_D:half_D+s0) -> t, [half_D+s0:half_D+s0+s1) -> h, [half_D+s0+s1:D) -> w
cut0_s1 = s0 + s1
-
- cos_mixed = tl.where(
+ cos_vals = tl.where(
d_offs < s0, cos_t,
tl.where(d_offs < cut0_s1, cos_h,
tl.where(d_offs < half_D, cos_w,
tl.where(d_offs < half_D + s0, cos_t,
tl.where(d_offs < half_D + cut0_s1, cos_h, cos_w)))))
- sin_mixed = tl.where(
+ sin_vals = tl.where(
d_offs < s0, sin_t,
tl.where(d_offs < cut0_s1, sin_h,
tl.where(d_offs < half_D, sin_w,
tl.where(d_offs < half_D + s0, sin_t,
tl.where(d_offs < half_D + cut0_s1, sin_h, sin_w)))))
- out_base = pid * D
- tl.store(cos_out_ptr + out_base + d_offs, cos_mixed, mask=mask)
- tl.store(sin_out_ptr + out_base + d_offs, sin_mixed, mask=mask)
+ return cos_vals, sin_vals
@triton.jit
-def _rope_transpose_kernel(
- x_ptr, cos_ptr, sin_ptr, out_ptr,
- B, T, H, D, half_D,
- stride_x_b, stride_x_t, stride_x_h,
- stride_cos_bt,
- stride_out_b, stride_out_h, stride_out_t,
+def _apply_rotate(
+ x_ptr, cos_vals, sin_vals, out_ptr, D, half_D,
BLOCK_D: tl.constexpr,
):
- """Apply rotate-half with pre-mixed cos/sin, write transposed to (B,H,T,D).
-
- Grid: (B*T, H) -- program 0 gives (b,t), program 1 gives h.
- cos/sin layout: (B*T, D) flat, indexed by pid_bt * D.
-
- rotate_half(x): [-x[half_D:D], x[0:half_D]]
- out = x * cos + rotate_half(x) * sin
-
- For d in [0, half_D):
- out[d] = x[d]*cos[d] - x[d+half_D]*sin[d]
- For d in [half_D, D): (let d'=d-half_D)
- out[d] = x[d]*cos[d'] + x[d']*sin[d']
- """
- pid_bt = tl.program_id(0)
- pid_h = tl.program_id(1)
-
- b = pid_bt // T
- t = pid_bt % T
-
- # Load cos/sin for this (b,t) -- cos_mixed is (B*T, D) flat
+ """Apply rotate-half: out[d] = x[d]*cos - x[d+H]*sin for first half,
+ out[d+H] = x[d+H]*cos + x[d]*sin for second half. x_ptr and out_ptr
+ are pointers to arrays of length D."""
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
-
- cos_base = pid_bt * stride_cos_bt
- cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
- sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)
-
- # Load x for this (b,t,h)
- x_base = b * stride_x_b + t * stride_x_t + pid_h * stride_x_h
-
- # First half [0: half_D)
mask_first = d_offs < half_D
- x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_first)
-
- # Second half [half_D: D), loaded at offset d_offs (same range)
- # We load x[d' + half_D] where d' = d_offs
- mask_second_src = (d_offs + half_D) < D
- x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second_src)
-
- # Compute rotated values
- # First half: out[d] = x[d]*cos[d] - x[d+half_D]*sin[d]
+ mask_second = (d_offs + half_D) < D
+
+ x_d = tl.load(x_ptr + d_offs, mask=mask_first)
+ x_dh = tl.load(x_ptr + d_offs + half_D, mask=mask_second)
+
out_first = x_d * cos_vals - x_dh * sin_vals
-
- #
... (6674 more chars){
"content": "<8862 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "<902 chars \u2014 see below>",
"description": "Quick correctness test on fused kernel"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
print('q max diff:', (rq.float()-sq.float()).abs().max().item())
print('k max diff:', (rk.float()-sk.float()).abs().max().item())
" 2>&1
q max diff: 0.0078125
k max diff: 0.0078125
{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Run check.py on fused kernel"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Benchmark fused kernel"
}shape=0 variant=eager tflops=0.180 gbps=258.890 ms=0.349
shape=0 variant=compiled tflops=0.486 gbps=696.659 ms=0.130
shape=0 variant=solution tflops=0.633 gbps=908.281 ms=0.099
shape=0 solution_peak_fraction=0.5046
shape=1 variant=eager tflops=0.209 gbps=305.714 ms=0.481
shape=1 variant=compiled tflops=0.488 gbps=712.141 ms=0.206
shape=1 variant=solution tflops=0.720 gbps=1051.302 ms=0.140
shape=1 solution_peak_fraction=0.5841
shape=2 variant=eager tflops=0.067 gbps=104.034 ms=0.213
shape=2 variant=compiled tflops=0.171 gbps=266.924 ms=0.083
shape=2 variant=solution tflops=0.285 gbps=445.074 ms=0.050
shape=2 solution_peak_fraction=0.2473
shape=3 variant=eager tflops=0.179 gbps=256.588 ms=1.407
shape=3 variant=compiled tflops=0.467 gbps=669.682 ms=0.539
shape=3 variant=solution tflops=0.877 gbps=1258.025 ms=0.287
shape=3 solution_peak_fraction=0.6989
peak_fraction: 0.4751
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"offset": 145
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
145: q_h_in = q_in_base + h * stride_q_h
146: q_h_out = q_out_base + h * stride_qo_h
147:
148: mask_first = d_offs < half_D
149: mask_second = (d_offs + half_D) < D
150:
151: x_d = tl.load(q_ptr + q_h_in + d_offs, mask=mask_first)
152: x_dh = tl.load(q_ptr + q_h_in + d_offs + half_D, mask=mask_second)
153:
154: out_first = x_d * cos_vals - x_dh * sin_vals
155: out_second = x_dh * cos_vals + x_d * sin_vals
156:
157: tl.store(q_out_ptr + q_h_out + d_offs, out_first, mask=mask_first)
158: tl.store(q_out_ptr + q_h_out + d_offs + half_D, out_second, mask=mask_second)
159:
160: # Process all k heads
161: k_in_base = b * stride_k_b + t * stride_k_t
162: k_out_base = b * stride_ko_b + t * stride_ko_t
163: for h in range(Hkv):
164: k_h_in = k_in_base + h * stride_k_h
165: k_h_out = k_out_base + h * stride_ko_h
166:
167: mask_first = d_offs < half_D
168: mask_second = (d_offs + half_D) < D
169:
170: x_d = tl.load(k_ptr + k_h_in + d_offs, mask=mask_first)
171: x_dh = tl.load(k_ptr + k_h_in + d_offs + half_D, mask=mask_second)
172:
173: out_first = x_d * cos_vals - x_dh * sin_vals
174: out_second = x_dh * cos_vals + x_d * sin_vals
175:
176: tl.store(k_out_ptr + k_h_out + d_offs, out_first, mask
... (3420 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,16 +1,16 @@
"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
-Fully-fused single-kernel approach:
- - Grid: (B*T,) -- one program per (b,t) position.
- - Each program processes ALL q heads and ALL k heads.
- - Position IDs loaded once, cos/sin gathered+mixed once, shared across heads.
- - Writes directly into transposed (B,H,T,D) output layout.
-
-This eliminates all intermediate tensors and maximally shares the cos/sin
-bandwidth across heads. Memory traffic per (b,t): 2*D (q) + 2*D (k) once,
-plus ~2*D (x cos/sin) per axis = 6*D gather reads, then 2*D output writes.
+Two-stage approach for maximal parallelism:
+ Stage 1: grid=(B*T,). Gather+mix cos/sin once per position => (B*T, D).
+ Stage 2: grid=(B*T, Hq+Hkv). Each block processes ONE head:
+ - reads pre-mixed cos/sin from stage 1 output
+ - applies rotate-half
+ - writes transposed output
+ Cos/sin bandwidth shared across ALL heads with minimal latency.
"""
from __future__ import annotations
+
+import functools
import torch
import torch.nn as nn
@@ -22,26 +22,37 @@
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
# ---------------------------------------------------------------------------
-# Triton kernels
+# Stage 1: gather + mix cos/sin
# ---------------------------------------------------------------------------
@triton.jit
-def _mix_cos_sin(
+def _mix_cos_sin_kernel(
cos_cache_ptr, sin_cache_ptr,
- pos_t, pos_h, pos_w,
- D, max_pos, half_D, s0, s1, s2,
- cos_vals, sin_vals,
+ pos_t_ptr, pos_h_ptr, pos_w_ptr,
+ cos_out_ptr, sin_out_ptr,
+ B, T, D, max_pos,
+ s0, s1, s2, half_D,
+ stride_pos,
BLOCK_D: tl.constexpr,
):
- """Gather per-axis cos/sin from 3 position IDs, mix by mrope_section.
- Called inline per (b,t). Produces cos_vals and sin_vals arrays of length D.
- pos_t, pos_h, pos_w are scalar position IDs."""
+ pid = tl.program_id(0)
+
+ pos_base = pid * stride_pos + (pid % T) # wrong -- fix below
+ # Actually: pid = b * T + t, and pos has shape (B, T) with stride[0]=T, stride[1]=1
+ # So pos_idx = b * T + t = pid
+ b = pid // T
+ t = pid % T
+ pos_base = b * stride_pos + t
+ pt = tl.load(pos_t_ptr + pos_base)
+ ph = tl.load(pos_h_ptr + pos_base)
+ pw = tl.load(pos_w_ptr + pos_base)
+
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
- row_t = pos_t.to(tl.int64) * D
- row_h = pos_h.to(tl.int64) * D
- row_w = pos_w.to(tl.int64) * D
+ row_t = pt.to(tl.int64) * D
+ row_h = ph.to(tl.int64) * D
+ row_w = pw.to(tl.int64) * D
cos_t = tl.load(cos_cache_ptr + row_t + d_offs, mask=mask)
cos_h = tl.load(cos_cache_ptr + row_h + d_offs, mask=mask)
@@ -52,138 +63,64 @@
sin_w = tl.load(sin_cache_ptr + row_w + d_offs, mask=mask)
cut0_s1 = s0 + s1
- cos_vals = tl.where(
+
+ cos_mixed = tl.where(
d_offs < s0, cos_t,
tl.where(d_offs < cut0_s1, cos_h,
tl.where(d_offs < half_D, cos_w,
tl.where(d_offs < half_D + s0, cos_t,
tl.where(d_offs < half_D + cut0_s1, cos_h, cos_w)))))
- sin_vals = tl.where(
+ sin_mixed = tl.where(
d_offs < s0, sin_t,
tl.where(d_offs < cut0_s1, sin_h,
tl.where(d_offs < half_D, sin_w,
tl.where(d_offs < half_D + s0, sin_t,
tl.where(d_offs < half_D + cut0_s1, sin_h, sin_w)))))
- return cos_vals, sin_vals
-
+ out_base = pid * D
+ tl.store(cos_out_ptr + out_base + d_offs, cos_mixed, mask=mask)
+ tl.store(sin_out_ptr + out_base + d_offs, sin_mixed, mask=mask)
+
+
+# ---------------------------------------------------------------------------
+# Stage 2: rotate + transpose, one block per (b, t, head)
+# Each block loads cos/sin ONCE, applies to one head.
+# ---------------------------------------------------------------------------
@triton.jit
-def _apply_rotate(
- x_ptr, cos_vals, sin_vals, out_ptr, D, half_D,
+def _rope_transpose_kernel(
+ x_ptr, cos_ptr, sin_ptr, out_ptr,
+ B, T, H, D, half_D,
+ stride_x_b, stride_x_t, stride_x_h,
+ stride_cs,
+ stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
- """Apply rotate-half: out[d] = x[d]*cos - x[d+H]*sin for first half,
- out[d+H] = x[d+H]*cos + x[d]*sin for second half. x_ptr and out_ptr
- are pointers to arrays of length D."""
+ pid_bt = tl.program_id(0)
+ pid_h = tl.program_id(1)
+
d_offs = tl.arange(0, BLOCK_D)
mask = d_offs < D
- mask_first = d_offs < half_D
- mask_second = (d_offs + half_D) < D
-
- x_d = tl.load(x_ptr + d_offs, mask=mask_first)
- x_dh = tl.load(x_ptr + d_offs + half_D, mask=mask_second)
+
+ cos_base = pid_bt * stride_cs
+ cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
+ sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)
+
+ x_base = pid_bt * stride_x_b + pid_h * stride_x_h
+ x_d = tl.load(x_ptr + x_base + d_offs, mask=(d_offs < D))
+ x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=((d_offs + half_D) < D))
out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
- tl.store(out_ptr + d_offs, out_first, mask=mask_first)
- tl.store(out_ptr + d_offs + half_D, out_second, mask=mask_second)
-
-
-@triton.jit
-def _fused_mrope_kernel(
- # Inputs
- q_ptr, k_ptr,
- cos_cache_ptr, sin_cache_ptr,
- pos_t_ptr, pos_h_ptr, pos_w_ptr,
- # Outputs
- q_out_ptr, k_out_ptr,
- # Shapes
- B, T, Hq, Hkv, D, max_pos, half_D,
- s0, s1, s2,
- # Input strides (B, T, H, D layout)
- stride_q_b, stride_q_t, stride_q_h,
- stride_k_b, stride_k_t, stride_k_h,
- # Position strides (B, T) int64
- stride_pos,
- # Output strides (B, H, T, D layout)
- stride_qo_b, stride_qo
... (5746 more chars){
"content": "<7537 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check correctness of 2-stage approach"
}FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 q_rot: tolerance exceeded: max_abs_diff=3.38281 (atol=0.02, rtol=0.02)
{
"command": "<1356 chars \u2014 see below>",
"description": "Debug 2-stage kernel indexing"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
# Check q reshape
q_flat = q.view(BT, Hq, D).contiguous()
print('q_flat shape:', q_flat.shape)
print('q_flat strides:', q_flat.stride())
print('q_flat[0,0,:4]:', q_flat[0,0,:4])
print('q[0,0,0,:4]:', q[0,0,0,:4]) # should be same as q_flat[0,0]
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
print('rq shape:', rq.shape, 'sq shape:', sq.shape)
print('q max diff:', (rq.float()-sq.float()).abs().max().item())
# Check specific elements
for t in [0, 1, 2]:
for h in range(min(2, Hq)):
d = (rq[0, h, t, :].float() - sq[0, h, t, :].float()).abs().max().item()
print(f' t={t} h={h} max_diff={d:.6f}')
" 2>&1
q_flat shape: torch.Size([64, 4, 128])
q_flat strides: (512, 128, 1)
q_flat[0,0,:4]: tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
q[0,0,0,:4]: tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
rq shape: torch.Size([1, 4, 64, 128]) sq shape: torch.Size([1, 4, 64, 128])
q max diff: 3.25
t=0 h=0 max_diff=0.602783
t=0 h=1 max_diff=2.007812
t=1 h=0 max_diff=1.029938
t=1 h=1 max_diff=2.843750
t=2 h=0 max_diff=0.936981
t=2 h=1 max_diff=1.544922
{
"command": "<2256 chars \u2014 see below>",
"description": "Debug stage separation"
}uv run python -c "
import torch, sys, reference
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
# Run stage 1
import solution
s0, s1, s2 = (16, 24, 24)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache,
pos_t, pos_h, pos_w,
cos_mixed, sin_mixed,
B, T, D, 1024,
s0, s1, s2, half_D,
pos_t.stride(0),
BLOCK_D=128,
)
# Check stage 1 output
cos_t = ref.cos_cache[pos_t] # (B, T, D)
sin_t = ref.sin_cache[pos_t]
cos_h = ref.cos_cache[pos_h]
sin_h = ref.sin_cache[pos_h]
cos_w = ref.cos_cache[pos_w]
sin_w = ref.sin_cache[pos_w]
ref_cos_mixed = ref._mix_axes(cos_t, cos_h, cos_w) # (B, T, D)
ref_sin_mixed = ref._mix_axes(sin_t, sin_h, sin_w)
print('ref_cos_mixed shape:', ref_cos_mixed.shape)
print('cos_mixed shape:', cos_mixed.shape)
print()
# Compare cos_mixed
print('Stage 1 cos_mixed t=0, h=0:')
print(' ref[0,0,:8]:', ref_cos_mixed[0, 0, :8])
print(' sol[0,:8] :', cos_mixed[0, :8])
print(' Max diff:', (ref_cos_mixed[0, 0].float() - cos_mixed[0].float()).abs().max().item())
# Test: run kernel manually on one element
q_flat = q.view(BT, Hq, D).contiguous()
rq_ref, _ = ref(q, k, pos_t, pos_h, pos_w)
# Manually compute what ref gives for (b=0, h=1, t=0)
print()
print('Manual ref computation for (h=1, t=0):')
x = q[0, 0, 1, :] # (D,)
cos = ref_cos_mixed[0, 0, :] # (D,)
sin = ref_sin_mixed[0, 0, :] # (D,)
half = half_D
rh = torch.cat([-x[half:], x[:half]], dim=0)
out_ref = x * cos + rh * sin
print(' Manual out[:8]:', out_ref[:8])
print(' Ref out[:8]: ', rq_ref[0, 1, 0, :8])
print(' Match:', torch.allclose(out_ref.float(), rq_ref[0, 1, 0].float(), atol=0.01))
" 2>&1
ref_cos_mixed shape: torch.Size([1, 64, 128])
cos_mixed shape: torch.Size([64, 128])
Stage 1 cos_mixed t=0, h=0:
ref[0,0,:8]: tensor([-0.7109, -0.5117, -0.9883, -0.0603, 0.6602, -0.9531, 0.5664, 0.8359],
device='cuda:0', dtype=torch.bfloat16)
sol[0,:8] : tensor([-0.7109, -0.5117, -0.9883, -0.0603, 0.6602, -0.9531, 0.5664, 0.8359],
device='cuda:0', dtype=torch.bfloat16)
Max diff: 0.0
Manual ref computation for (h=1, t=0):
Manual out[:8]: tensor([-0.2061, -0.3379, 0.0571, -0.5508, 0.3633, 0.3516, -0.0283, 1.4141],
device='cuda:0', dtype=torch.bfloat16)
Ref out[:8]: tensor([-0.2061, -0.3379, 0.0571, -0.5508, 0.3633, 0.3516, -0.0283, 1.4141],
device='cuda:0', dtype=torch.bfloat16)
Match: True
{
"command": "<2262 chars \u2014 see below>",
"description": "Debug stage 2 kernel output"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
s0, s1, s2 = (16, 24, 24)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache,
pos_t, pos_h, pos_w,
cos_mixed, sin_mixed,
B, T, D, 1024,
s0, s1, s2, half_D,
pos_t.stride(0),
BLOCK_D=128,
)
# Run kernel on a single head
q_flat = q.view(BT, Hq, D).contiguous()
q_rot_test = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=device)
# Run with grid (1, 1) to test one element
# pid_bt=0, pid_h=0
solution._rope_transpose_kernel[(1, 1)](
q_flat, cos_mixed, sin_mixed, q_rot_test,
B, T, Hq, D, half_D,
q_flat.stride(0), D, 1,
cos_mixed.stride(0),
q_rot_test.stride(0), q_rot_test.stride(1), q_rot_test.stride(2),
BLOCK_D=128,
)
# Check output for b=0, h=0, t=0
rq_ref, _ = ref(q, torch.randn(1,T,2,D,dtype=torch.bfloat16,device=device)*0.5, pos_t, pos_h, pos_w)
print('q_rot_test[0,0,0,:8]:', q_rot_test[0, 0, 0, :8])
print('rq_ref[0,0,0,:8]: ', rq_ref[0, 0, 0, :8])
print('Diff:', (q_rot_test[0, 0, 0].float() - rq_ref[0, 0, 0].float()).abs().max().item())
# Also check a non-zero head
q_rot_test2 = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_kernel[(BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_rot_test2,
B, T, Hq, D, half_D,
q_flat.stride(0), D, 1,
cos_mixed.stride(0),
q_rot_test2.stride(0), q_rot_test2.stride(1), q_rot_test2.stride(2),
BLOCK_D=128,
)
print()
print('Full grid:')
for t in [0, 1]:
for h in range(4):
d = (q_rot_test2[0, h, t].float() - rq_ref[0, h, t].float()).abs().max().item()
print(f' t={t} h={h} max_diff={d:.6f}')
" 2>&1
q_rot_test[0,0,0,:8]: tensor([ 0.2598, -1.0469, -0.2891, -0.3770, 0.8945, 0.4082, -0.0505, -0.3516],
device='cuda:0', dtype=torch.bfloat16)
rq_ref[0,0,0,:8]: tensor([ 0.2617, -1.0469, -0.2891, -0.3770, 0.8945, 0.4082, -0.0505, -0.3516],
device='cuda:0', dtype=torch.bfloat16)
Diff: 0.90234375
Full grid:
t=0 h=0 max_diff=0.711884
t=0 h=1 max_diff=1.769531
t=0 h=2 max_diff=2.066406
t=0 h=3 max_diff=1.597656
t=1 h=0 max_diff=0.905220
t=1 h=1 max_diff=2.460938
t=1 h=2 max_diff=1.941406
t=1 h=3 max_diff=2.058594
{
"command": "<1931 chars \u2014 see below>",
"description": "Find exact element diffs"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k_dummy = torch.randn(1,T,2,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
s0, s1, s2 = (16, 24, 24)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, s0, s1, s2, half_D,
pos_t.stride(0), BLOCK_D=128,
)
q_flat = q.view(BT, Hq, D).contiguous()
q_rot_test = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_kernel[(BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_rot_test,
B, T, Hq, D, half_D,
q_flat.stride(0), D, 1,
cos_mixed.stride(0),
q_rot_test.stride(0), q_rot_test.stride(1), q_rot_test.stride(2),
BLOCK_D=128,
)
rq_ref, _ = ref(q, k_dummy, pos_t, pos_h, pos_w)
# Check full d range for big diffs
for t in [0]:
for h in range(4):
diff = (q_rot_test[0, h, t].float() - rq_ref[0, h, t].float()).abs()
big = (diff > 0.5).nonzero()
if big.numel() > 0:
for idx in big[:6]:
d = idx.item()
print(f't={t} h={h} d={d}: ref={rq_ref[0,h,t,d]:.4f} sol={q_rot_test[0,h,t,d]:.4f} diff={diff[d]:.4f}')
print(f' Total big diffs for t={t} h={h}: {big.numel()}')
else:
print(f't={t} h={h}: no big diffs, max={diff.max().item():.4f}')
" 2>&1
t=0 h=0: no big diffs, max=0.0039
t=0 h=1 d=1: ref=-0.3379 sol=0.2832 diff=0.6211
t=0 h=1 d=2: ref=0.0571 sol=-0.5117 diff=0.5688
t=0 h=1 d=7: ref=1.4141 sol=0.3203 diff=1.0938
t=0 h=1 d=8: ref=-0.7422 sol=1.2656 diff=2.0078
t=0 h=1 d=9: ref=-0.0811 sol=-0.7422 diff=0.6611
t=0 h=1 d=10: ref=0.2656 sol=0.8086 diff=0.5430
Total big diffs for t=0 h=1: 56
t=0 h=2 d=0: ref=1.5000 sol=-0.1338 diff=1.6338
t=0 h=2 d=1: ref=0.2852 sol=-0.7695 diff=1.0547
t=0 h=2 d=2: ref=-0.0942 sol=0.9219 diff=1.0161
t=0 h=2 d=3: ref=-0.6641 sol=0.2383 diff=0.9023
t=0 h=2 d=5: ref=-0.0869 sol=0.4277 diff=0.5146
t=0 h=2 d=6: ref=-1.3594 sol=0.2969 diff=1.6562
Total big diffs for t=0 h=2: 62
t=0 h=3 d=1: ref=-0.1797 sol=0.3281 diff=0.5078
t=0 h=3 d=4: ref=0.5391 sol=-0.3652 diff=0.9043
t=0 h=3 d=6: ref=0.0713 sol=0.5820 diff=0.5107
t=0 h=3 d=7: ref=0.5234 sol=-0.6875 diff=1.2109
t=0 h=3 d=8: ref=-0.4062 sol=0.6797 diff=1.0859
t=0 h=3 d=12: ref=0.8672 sol=-0.8125 diff=1.6797
Total big diffs for t=0 h=3: 58
{
"command": "<3599 chars \u2014 see below>",
"description": "Trace exact indexing for h=0 vs h=1"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
# Check view mapping
q_flat = q.view(BT, Hq, D).contiguous()
print('q shape:', q.shape, 'q_flat shape:', q_flat.shape)
print('q_flat[0,1,:4] =', q_flat[0,1,:4])
print('q[0,0,1,:4] =', q[0,0,1,:4])
print('q_flat[0,0,:4] =', q_flat[0,0,:4])
print('q[0,0,0,:4] =', q[0,0,0,:4])
print()
# Check what the kernel computes for pid_bt=0, pid_h=1
# x_base = 0*512 + 1*128 = 128
print('Kernel for pid_bt=0, pid_h=1:')
print(' x_base = 128')
print(' loads from q_flat at offset 128: q_flat[0,1,0:4] =', q_flat[0,1,0:4])
print(' BUT q_flat[0,0,:4] =', q_flat[0,0,:4])
# The issue might be that pid_bt=0 maps to q_flat[0,:] = q_flat row 0 = q_flat[0, h, :] for varying h
# But wait, pid_bt=0 should be the first (b,t) position: b=0, t=0
# For this position, all h share the same cos/sin (from cos_mixed[0])
# The x values for head 0: q_flat[0, 0, :], head 1: q_flat[0, 1, :], etc.
# In kernel: x_base = pid_bt * stride_x_b + pid_h * stride_x_h
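# stride_x_b = 512 (stride to next bt row). NOTE(review): the launches pass `q_flat.stride(0), D, 1`,
# which the kernel signature (stride_x_b, stride_x_t, stride_x_h) binds as stride_x_t = 128 and stride_x_h = 1.
# The intended value is stride_x_h = 128 (stride to next head within same row); with it,
# for pid_bt=0, pid_h=1: x_base = 0 * 512 + 1 * 128 = 128
# This means it reads q_flat[128:(128+128)] = q_flat[0,1,:]
# With stride_x_h = 1 as actually passed, x_base = 0 * 512 + 1 * 1 = 1 instead: a one-element shift
# inside head 0 rather than the start of head 1.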
print('That should be correct! q_flat[0,1,:] = q[0,0,1,:]')
print()
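# Wait, but the output says h=0 is perfect but h≥1 is wrong.
# If the per-(b,t) indexing was wrong, BOTH would be wrong. A wrong per-head stride would still fit this
# pattern, though: pid_h * stride_x_h is 0 for h=0 whatever the stride value is, so only h≥1 is affected.
# Let me look at the ref output for this position.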
# Run the reference
pos_t = torch.randint(0, 1024, (B, T), dtype=torch.int64, device=device)
pos_h = torch.randint(0, 1024, (B, T), dtype=torch.int64, device=device)
pos_w = torch.randint(0, 1024, (B, T), dtype=torch.int64, device=device)
k_dummy = torch.randn(1, T, 2, D, dtype=torch.bfloat16, device=device) * 0.5
rq_ref, _ = ref(q, k_dummy, pos_t, pos_h, pos_w)
# rq_ref shape: (B=1, H=4, T=64, D=128)
print('Ref rq[0,1,0,:4]:', rq_ref[0,1,0,:4])
print('Ref rq[0,0,0,:4]:', rq_ref[0,0,0,:4])
# Manual computation for (b=0, h=1, t=0)
cos_t = ref.cos_cache[pos_t]
cos_h_ = ref.cos_cache[pos_h]
cos_w = ref.cos_cache[pos_w]
sin_t = ref.sin_cache[pos_t]
sin_h_ = ref.sin_cache[pos_h]
sin_w = ref.sin_cache[pos_w]
cos_m = ref._mix_axes(cos_t, cos_h_, cos_w) # (B,T,D)
sin_m = ref._mix_axes(sin_t, sin_h_, sin_w)
# cos_m shape is (1, 64, 128)
# For (b=0, t=0): cos_m[0,0,:]
print()
print('Manual cos_m[0,0,:4]:', cos_m[0,0,:4])
print('cos_mixed[0,:4]: ', solution._mix_cos_sin_kernel.__module__) # can't access directly
# Just run the kernel and check specific head output
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
print()
print('cos_mixed[0,:4]:', cos_mixed[0,:4])
print('cos_m[0,0,:4]:', cos_m[0,0,:4])
# Manually apply rotation for (b=0, h=1, t=0)
x = q[0, 0, 1, :] # (D,)
cos = cos_m[0, 0, :] # (D,)
sin = sin_m[0, 0, :] # (D,)
rh = torch.cat([-x[64:], x[:64]], dim=0)
expected = x * cos + rh * sin
print()
print('Manual expected[:4]:', expected[:4])
print('Ref rq[0,1,0,:4]: ', rq_ref[0,1,0,:4])
print('Match:', torch.allclose(expected.float(), rq_ref[0,1,0].float(), atol=0.01))
" 2>&1
q shape: torch.Size([1, 64, 4, 128]) q_flat shape: torch.Size([64, 4, 128])
q_flat[0,1,:4] = tensor([ 0.0986, 0.4883, -0.0143, 0.8398], device='cuda:0',
dtype=torch.bfloat16)
q[0,0,1,:4] = tensor([ 0.0986, 0.4883, -0.0143, 0.8398], device='cuda:0',
dtype=torch.bfloat16)
q_flat[0,0,:4] = tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
q[0,0,0,:4] = tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
Kernel for pid_bt=0, pid_h=1:
x_base = 128
loads from q_flat at offset 128: q_flat[0,1,0:4] = tensor([ 0.0986, 0.4883, -0.0143, 0.8398], device='cuda:0',
dtype=torch.bfloat16)
BUT q_flat[0,0,:4] = tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
That should be correct! q_flat[0,1,:] = q[0,0,1,:]
Ref rq[0,1,0,:4]: tensor([ 0.1533, -0.3535, -0.2949, -0.8008], device='cuda:0',
dtype=torch.bfloat16)
Ref rq[0,0,0,:4]: tensor([ 0.2617, -1.0469, -0.2891, -0.3770], device='cuda:0',
dtype=torch.bfloat16)
Manual cos_m[0,0,:4]: tensor([ 0.9531, -0.5469, 0.2598, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
cos_mixed[0,:4]: solution
cos_mixed[0,:4]: tensor([ 0.9531, -0.5469, 0.2598, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
cos_m[0,0,:4]: tensor([ 0.9531, -0.5469, 0.2598, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
Manual expected[:4]: tensor([ 0.1533, -0.3535, -0.2949, -0
... (177 more chars){
"command": "<2356 chars \u2014 see below>",
"description": "Test individual heads in isolation"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k_dummy = torch.randn(1,T,2,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
rq_ref, _ = ref(q, k_dummy, pos_t, pos_h, pos_w)
# Run stage 1
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
# Now test individual heads
q_flat = q.view(BT, Hq, D).contiguous()
for h_test in range(4):
q_out = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=device)
# Run with BT blocks, but only check head h_test
solution._rope_transpose_kernel[(BT, 1)](
q_flat, cos_mixed, sin_mixed, q_out,
B, T, 1, D, half_D,
q_flat.stride(0), 128, 1, # stride for (BT,1,D) shape
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1), q_out.stride(2),
BLOCK_D=128,
)
diff = (q_out[0, 0, :, :].float() - rq_ref[0, h_test, :, :].float()).abs().max().item()
print(f'h={h_test}: max_diff={diff:.6f} (pid_h is always 0 in this call)')
# Test with full grid but Hq=1
print()
for h_test in range(4):
q_h = q[:, :, h_test:h_test+1, :].contiguous() # (1, 64, 1, 128)
q_h_flat = q_h.view(BT, 1, D).contiguous()
q_out = torch.empty(B, 1, T, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_kernel[(BT, 1)](
q_h_flat, cos_mixed, sin_mixed, q_out,
B, T, 1, D, half_D,
q_h_flat.stride(0), 128, 1,
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1), q_out.stride(2),
BLOCK_D=128,
)
diff = (q_out[0, 0, :, :].float() - rq_ref[0, h_test, :, :].float()).abs().max().item()
print(f'h={h_test} (separate call): max_diff={diff:.6f}')
" 2>&1
h=0: max_diff=1.501953 (pid_h is always 0 in this call)
h=1: max_diff=2.773438 (pid_h is always 0 in this call)
h=2: max_diff=2.851562 (pid_h is always 0 in this call)
h=3: max_diff=2.601562 (pid_h is always 0 in this call)
h=0 (separate call): max_diff=1.396729
h=1 (separate call): max_diff=1.419922
h=2 (separate call): max_diff=1.529785
h=3 (separate call): max_diff=1.900146
{
"command": "<3067 chars \u2014 see below>",
"description": "Trace kernel computation element by element"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 64, 4, 2, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k_dummy = torch.randn(1,T,2,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
rq_ref, _ = ref(q, k_dummy, pos_t, pos_h, pos_w)
# Run stage 1
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
# Test single block
q_flat = q.view(BT, Hq, D).contiguous()
# Check: what does q_flat[0,1,:] look like vs q_flat[0,0,:]?
# And what does the kernel compute?
# Manually compute expected for (b=0, t=0, h=0)
x0 = q[0, 0, 0, :]
cos0 = cos_mixed[0, :]
sin0 = sin_mixed[0, :]
rh0 = torch.cat([-x0[64:], x0[:64]], dim=0)
expected0 = x0 * cos0 + rh0 * sin0
# And for (b=0, t=0, h=1)
x1 = q[0, 0, 1, :]
rh1 = torch.cat([-x1[64:], x1[:64]], dim=0)
expected1 = x1 * cos0 + rh1 * sin0 # same cos/sin
print('Expected h=0, t=0, first 8:', expected0[:8])
print('Expected h=1, t=0, first 8:', expected1[:8])
print('Ref h=0, t=0, first 8:', rq_ref[0, 0, 0, :8])
print('Ref h=1, t=0, first 8:', rq_ref[0, 1, 0, :8])
print()
print('Expected h=0 matches ref:', torch.allclose(expected0.float(), rq_ref[0,0,0].float(), atol=0.01))
print('Expected h=1 matches ref:', torch.allclose(expected1.float(), rq_ref[0,1,0].float(), atol=0.01))
# Now run the kernel with grid (1, 2) and check
q_out = torch.empty(B, 2, T, D, dtype=torch.bfloat16, device=device).fill_(float('nan'))
solution._rope_transpose_kernel[(1, 2)](
q_flat, cos_mixed, sin_mixed, q_out,
B, T, 2, D, half_D,
q_flat.stride(0), D, 1,
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1), q_out.stride(2),
BLOCK_D=128,
)
print()
print('Kernel h=0, t=0, first 8:', q_out[0, 0, 0, :8])
print('Kernel h=1, t=0, first 8:', q_out[0, 1, 0, :8])
print('Diff h=0:', (q_out[0, 0, 0, :8].float() - rq_ref[0, 0, 0, :8].float()).abs().max().item())
print('Diff h=1:', (q_out[0, 1, 0, :8].float() - rq_ref[0, 1, 0, :8].float()).abs().max().item())
print()
# Check ALL elements
for h in range(2):
valid = ~torch.isnan(q_out[0, h, 0]).any()
diff = (q_out[0, h, 0, :128].float() - rq_ref[0, h, 0, :128].float()).abs()
print(f'h={h}: has_nan={not valid}, max_diff={diff.max().item():.6f}')
if diff.max() > 0.05:
print(f' first 16 kernel: {q_out[0, h, 0, :16]}')
print(f' first 16 ref: {rq_ref[0, h, 0, :16]}')
print(f' first 16 man: {expected0 if h==0 else expected1}')
" 2>&1
Expected h=0, t=0, first 8: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4902, -0.2324, -0.1426, -0.2715],
device='cuda:0', dtype=torch.bfloat16)
Expected h=1, t=0, first 8: tensor([-0.2061, -0.3379, 0.0571, -0.5508, 0.3633, 0.3516, -0.0283, 1.4141],
device='cuda:0', dtype=torch.bfloat16)
Ref h=0, t=0, first 8: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4902, -0.2324, -0.1426, -0.2715],
device='cuda:0', dtype=torch.bfloat16)
Ref h=1, t=0, first 8: tensor([-0.2061, -0.3379, 0.0571, -0.5508, 0.3633, 0.3516, -0.0283, 1.4141],
device='cuda:0', dtype=torch.bfloat16)
Expected h=0 matches ref: True
Expected h=1 matches ref: True
Kernel h=0, t=0, first 8: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4922, -0.2324, -0.1426, -0.2695],
device='cuda:0', dtype=torch.bfloat16)
Kernel h=1, t=0, first 8: tensor([-0.3789, 0.2832, -0.5117, -0.1328, 0.0220, 0.3242, -0.1250, 0.3203],
device='cuda:0', dtype=torch.bfloat16)
Diff h=0: 0.001953125
Diff h=1: 1.09375
h=0: has_nan=False, max_diff=0.925781
first 16 kernel: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4922, -0.2324, -0.1426, -0.2695,
-0.0474, -0.0684, 0.1670, 0.1377, -0.0918, -0.2275, 0.1367, 0.5820],
device='cuda:0', dtype=torch.bfloat16)
first 16 ref: tensor([-0.4551, -1.0234, 0.1245, -0.6680, -0.4902, -0.2324, -0.1426, -0.2715,
-0.0479, -0.0664, 0.1670, 0.1377, -0.0918, -0.2275, 0.1357, 0.5820],
device='cuda:0', dtype
... (3791 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,16 +1,14 @@
"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
-Two-stage approach for maximal parallelism:
- Stage 1: grid=(B*T,). Gather+mix cos/sin once per position => (B*T, D).
- Stage 2: grid=(B*T, Hq+Hkv). Each block processes ONE head:
- - reads pre-mixed cos/sin from stage 1 output
- - applies rotate-half
- - writes transposed output
- Cos/sin bandwidth shared across ALL heads with minimal latency.
+Three-kernel approach:
+ 1. grid=(B*T,): gather + mix cos/sin once per position => (B*T, D)
+ 2. grid=(B*T*Hq,): rotate + transpose for q. Each program handles one head.
+ 3. grid=(B*T*Hkv,): rotate + transpose for k. Each program handles one head.
+
+Kernels 2/3 read pre-mixed cos/sin (shared across heads), avoiding redundant
+gather bandwidth. The flat (B*T*H,) grid maximizes GPU occupancy.
"""
from __future__ import annotations
-
-import functools
import torch
import torch.nn as nn
@@ -32,17 +30,14 @@
cos_out_ptr, sin_out_ptr,
B, T, D, max_pos,
s0, s1, s2, half_D,
- stride_pos,
+ stride_pos_b,
BLOCK_D: tl.constexpr,
):
pid = tl.program_id(0)
-
- pos_base = pid * stride_pos + (pid % T) # wrong -- fix below
- # Actually: pid = b * T + t, and pos has shape (B, T) with stride[0]=T, stride[1]=1
- # So pos_idx = b * T + t = pid
b = pid // T
t = pid % T
- pos_base = b * stride_pos + t
+
+ pos_base = b * stride_pos_b + t
pt = tl.load(pos_t_ptr + pos_base)
ph = tl.load(pos_h_ptr + pos_base)
pw = tl.load(pos_w_ptr + pos_base)
@@ -84,43 +79,49 @@
# ---------------------------------------------------------------------------
-# Stage 2: rotate + transpose, one block per (b, t, head)
-# Each block loads cos/sin ONCE, applies to one head.
+# Stage 2/3: rotate + transpose
+# Grid: (B*T*H,) flat. Each program handles one head.
+# Read cos/sin from precomputed (B*T, D) buffer.
# ---------------------------------------------------------------------------
@triton.jit
-def _rope_transpose_kernel(
+def _rope_transpose_flat_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
- B, T, H, D, half_D,
- stride_x_b, stride_x_t, stride_x_h,
- stride_cs,
- stride_out_b, stride_out_h, stride_out_t,
+ BTH, T, H, D, half_D,
+ stride_x_bt, stride_x_h, # x: (B*T, H, D) flat view
+ stride_cs, # cos/sin: (B*T, D) flat view
+ stride_out_h, stride_out_t, # out: (H, T, D) for a single-b view
BLOCK_D: tl.constexpr,
):
- pid_bt = tl.program_id(0)
- pid_h = tl.program_id(1)
+ pid = tl.program_id(0)
+ bt = pid // H # which (b,t) position
+ h = pid % H # which head
d_offs = tl.arange(0, BLOCK_D)
- mask = d_offs < D
-
- cos_base = pid_bt * stride_cs
- cos_vals = tl.load(cos_ptr + cos_base + d_offs, mask=mask)
- sin_vals = tl.load(sin_ptr + cos_base + d_offs, mask=mask)
-
- x_base = pid_bt * stride_x_b + pid_h * stride_x_h
- x_d = tl.load(x_ptr + x_base + d_offs, mask=(d_offs < D))
+
+ # Load cos/sin
+ cs_base = bt * stride_cs
+ cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=(d_offs < D))
+ sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=(d_offs < D))
+
+ # Load x
+ x_base = bt * stride_x_bt + h * stride_x_h
+ x_d = tl.load(x_ptr + x_base + d_offs, mask=(d_offs < D))
x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=((d_offs + half_D) < D))
- out_first = x_d * cos_vals - x_dh * sin_vals
+ # Rotate
+ out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
- out_base = pid_h * stride_out_h + pid_bt * stride_out_t
+ # Write transposed: out[h, bt, d]
+ out_base = h * stride_out_h + bt * stride_out_t
tl.store(out_ptr + out_base + d_offs, out_first, mask=(d_offs < D))
tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=((d_offs + half_D) < D))
def _next_power_of_2(n):
return 1 << (n - 1).bit_length()
+
class Model(nn.Module):
def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
@@ -168,31 +169,35 @@
BLOCK_D=BLOCK_D,
)
- # Stage 2: fused rotate + transpose
- # For q: grid=(B*T, Hq), input (B*T*Hq, D) flat, output (B, Hq, T, D)
- # Reshape q from (B,T,Hq,D) to (B*T, Hq, D) for simpler indexing
- q_flat = q.view(BT, Hq, D).contiguous()
- q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
- _rope_transpose_kernel[(BT, Hq)](
- q_flat, cos_mixed, sin_mixed, q_rot,
- B, T, Hq, D, half_D,
- q_flat.stride(0), D, 1,
+ # Stage 2: rotate + transpose q
+ # Reshape q from (B, T, Hq, D) to (BT, Hq, D) - contiguous
+ q_flat = q.reshape(BT, Hq, D) # O(1) view if contiguous in BTHD
+ q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=q.device)
+
+ _rope_transpose_flat_kernel[(BT * Hq,)](
+ q_flat, cos_mixed, sin_mixed, q_out,
+ BT, T, Hq, D, half_D,
+ q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
- q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
+ q_out.stride(0), q_out.stride(1),
BLOCK_D=BLOCK_D,
)
-
- # For k
- k_flat = k.view(BT, Hkv, D).contiguous()
- k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
- _rope_transpose_kernel[(BT, Hkv)](
- k_flat, cos_mixed, sin_mixed, k_rot,
- B, T, Hkv, D, half_D,
- k_flat.stride(0), D, 1,
+ # Reshape: (Hq, BT, D) -> (B, Hq, T, D)
+ q_rot = q_out.reshape(Hq, B, T, D).transpose(0, 1) # (B, Hq, T, D)
+
+ # Stage 3: rotate + transpose k
+ k_flat = k.reshape(BT, Hkv, D)
+ k_out = torch.empty(Hkv, BT, D, d
... (524 more chars){
"content": "<7749 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check flat kernel correctness"
}FAIL: shape 0 {'B': 1, 'T': 4096, 'Hq': 32, 'Hkv': 8, 'D': 128, 'mrope_section': (16, 24, 24), 'max_pos': 32768} seed 42 q_rot: tolerance exceeded: max_abs_diff=1.66705 (atol=0.02, rtol=0.02)
{
"command": "<2559 chars \u2014 see below>",
"description": "Debug flat kernel with small shape"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 4, 2, 1, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 256).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
print('ref q shape:', rq.shape) # (1, 2, 4, 128)
sol = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 256).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
print('sol q shape:', sq.shape)
print('q max diff:', (rq.float()-sq.float()).abs().max().item())
# Debug: check kernel stages
s0, s1, s2 = (16, 24, 24)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 256, s0, s1, s2, half_D,
pos_t.stride(0), BLOCK_D=128,
)
q_flat = q.reshape(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
print()
print('q_flat shape:', q_flat.shape, 'strides:', q_flat.stride())
print('q_out shape:', q_out.shape, 'strides:', q_out.stride())
print('cos_mixed shape:', cos_mixed.shape, 'strides:', cos_mixed.stride())
print()
print('BT*Hq =', BT*Hq, f'(grid size)')
solution._rope_transpose_flat_kernel[(BT * Hq,)](
q_flat, cos_mixed, sin_mixed, q_out,
BT, T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1),
BLOCK_D=128,
)
# Reshape
q_rot = q_out.reshape(Hq, B, T, D).transpose(0, 1)
print('q_rot shape:', q_rot.shape)
print('Diff after reshape:', (rq.float() - q_rot.float()).abs().max().item())
# Check specific elements
for bt in range(BT):
for h in range(Hq):
pid = bt * Hq + h
diff = (q_out[h, bt, :].float() - rq[0, h, bt % T, :].float()).abs().max().item()
if diff > 0.05:
print(f' pid={pid} bt={bt} h={h}: diff={diff:.6f}')
# Compare
print(f' kernel[0:4]: {q_out[h, bt, :4]}')
print(f' ref[0:4]: {rq[0, h, bt % T, :4]}')
print()
print('All elements checked.')
" 2>&1
ref q shape: torch.Size([1, 2, 4, 128])
sol q shape: torch.Size([1, 2, 4, 128])
q max diff: 0.99072265625
q_flat shape: torch.Size([4, 2, 128]) strides: (256, 128, 1)
q_out shape: torch.Size([2, 4, 128]) strides: (512, 128, 1)
cos_mixed shape: torch.Size([4, 128]) strides: (128, 1)
BT*Hq = 8 (grid size)
q_rot shape: torch.Size([1, 2, 4, 128])
Diff after reshape: 1.29931640625
pid=0 bt=0 h=0: diff=0.902739
kernel[0:4]: tensor([ 0.4297, 1.1953, -0.0654, -0.3848], device='cuda:0',
dtype=torch.bfloat16)
ref[0:4]: tensor([ 0.4297, 1.1953, -0.0654, -0.3828], device='cuda:0',
dtype=torch.bfloat16)
pid=1 bt=0 h=1: diff=0.932129
kernel[0:4]: tensor([ 0.2002, 0.4570, -0.1387, -0.0698], device='cuda:0',
dtype=torch.bfloat16)
ref[0:4]: tensor([ 0.2002, 0.4570, -0.1387, -0.0703], device='cuda:0',
dtype=torch.bfloat16)
pid=2 bt=1 h=0: diff=0.990723
kernel[0:4]: tensor([ 1.2031, 0.3438, 0.1953, -0.3145], device='cuda:0',
dtype=torch.bfloat16)
ref[0:4]: tensor([ 1.2031, 0.3438, 0.1953, -0.3145], device='cuda:0',
dtype=torch.bfloat16)
pid=3 bt=1 h=1: diff=0.942871
kernel[0:4]: tensor([-0.6641, -0.1338, -0.0574, 0.3984], device='cuda:0',
dtype=torch.bfloat16)
ref[0:4]: tensor([-0.6641, -0.1338, -0.0571, 0.3984], device='cuda:0',
dtype=torch.bfloat16)
pid=6 bt=3 h=0: diff=1.299316
kernel[0:4]: tensor([-0.2832, -0.3848, -0.1699, 0.5781], device='cuda:0',
... (154 more chars){
"command": "<2038 chars \u2014 see below>",
"description": "Check first vs second half of D"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 4, 2, 1, 128
BT = B*T
half_D = D // 2
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 256).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
s0, s1, s2 = (16, 24, 24)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 256, s0, s1, s2, half_D,
pos_t.stride(0), BLOCK_D=128,
)
q_flat = q.reshape(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_flat_kernel[(BT * Hq,)](
q_flat, cos_mixed, sin_mixed, q_out,
BT, T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1),
BLOCK_D=128,
)
# Check first/second half separately
for bt in range(BT):
for h in range(Hq):
diff_first = (q_out[h, bt, :64].float() - rq[0, h, bt % T, :64].float()).abs().max().item()
diff_second = (q_out[h, bt, 64:].float() - rq[0, h, bt % T, 64:].float()).abs().max().item()
print(f'bt={bt} h={h}: first_half_max_diff={diff_first:.6f} second_half_max_diff={diff_second:.6f}')
# Manual check for (bt=0, h=0)
print()
x = q[0, 0, 0, :]
cos = cos_mixed[0, :]
sin = sin_mixed[0, :]
rh = torch.cat([-x[64:], x[:64]], dim=0)
expected = x * cos + rh * sin
print('Manual expected[64:72]:', expected[64:72])
print('Kernel out[64:72]: ', q_out[0, 0, 64:72])
print('Ref[64:72]: ', rq[0, 0, 0, 64:72])
" 2>&1
bt=0 h=0: first_half_max_diff=0.007812 second_half_max_diff=0.902739
bt=0 h=1: first_half_max_diff=0.003906 second_half_max_diff=0.003906
bt=1 h=0: first_half_max_diff=0.003906 second_half_max_diff=0.990723
bt=1 h=1: first_half_max_diff=0.003906 second_half_max_diff=0.003906
bt=2 h=0: first_half_max_diff=0.003906 second_half_max_diff=0.007812
bt=2 h=1: first_half_max_diff=0.003906 second_half_max_diff=0.003906
bt=3 h=0: first_half_max_diff=0.007812 second_half_max_diff=0.007812
bt=3 h=1: first_half_max_diff=0.007812 second_half_max_diff=0.007812
Manual expected[64:72]: tensor([ 0.3594, 0.1855, -0.2832, -0.6641, -0.8203, -0.0684, -0.0840, -0.2539],
device='cuda:0', dtype=torch.bfloat16)
Kernel out[64:72]: tensor([ 0.4219, -0.4473, -0.2393, -0.2871, 0.0718, 0.1729, -0.0498, 0.0757],
device='cuda:0', dtype=torch.bfloat16)
Ref[64:72]: tensor([ 0.3594, 0.1855, -0.2832, -0.6641, -0.8203, -0.0684, -0.0840, -0.2539],
device='cuda:0', dtype=torch.bfloat16)
{
"command": "<1424 chars \u2014 see below>",
"description": "Check write address collision"
}uv run python -c "
# Trace exactly what each program writes
# For bt=0 h=0: out_base = 0*512 + 0*128 = 0
# stores first half at [0:64] and second half at [0+64 : 0+128] = [64:128]
# writes to q_out[0,0,:128] = indices 0..127 ✓
# For bt=0 h=1: pid=1, so bt=1//2=0, h=1%2=1
# out_base = h*512 + bt*128 = 1*512 + 0*128 = 512
# First half at [512:512+64] = [512:576], second half at [576:640]
# This is q_out[1,0,:] ✓ - correct, no collision
# For bt=1 h=0: pid=2, bt=2//2=1, h=2%2=0
# out_base = 0*512 + 1*128 = 128
# First half at [128:192], second half at [192:256] = q_out[0,1,:] ✓
# Hmm no collision. But why is only the SECOND HALF of q_out[0,0,:] wrong?
# Let me check: is the first half correct?
# Actually the first half of q_out[0,0,:] is correct (diff=0.007812).
# Only the second half is wrong (diff=0.902739).
# This means the FIRST HALF store is correct but the SECOND HALF store for pid=0 is wrong.
# The second half store is at: out_base + d_offs + half_D
# For pid=0: base=0, so [64:128]
# Is another program also writing to [64:128]?
# Let me check ALL programs' write addresses:
for pid in range(8):
Hq = 2
bt = pid // Hq
h = pid % Hq
out_base = h * 512 + bt * 128
first = (out_base, out_base + 64)
second = (out_base + 64, out_base + 128)
print(f'pid={pid}: bt={bt} h={h} out_base={out_base} first={first} second={second}')
" 2>&1
pid=0: bt=0 h=0 out_base=0 first=(0, 64) second=(64, 128)
pid=1: bt=0 h=1 out_base=512 first=(512, 576) second=(576, 640)
pid=2: bt=1 h=0 out_base=128 first=(128, 192) second=(192, 256)
pid=3: bt=1 h=1 out_base=640 first=(640, 704) second=(704, 768)
pid=4: bt=2 h=0 out_base=256 first=(256, 320) second=(320, 384)
pid=5: bt=2 h=1 out_base=768 first=(768, 832) second=(832, 896)
pid=6: bt=3 h=0 out_base=384 first=(384, 448) second=(448, 512)
pid=7: bt=3 h=1 out_base=896 first=(896, 960) second=(960, 1024)
{
"command": "<2514 chars \u2014 see below>",
"description": "Verify cos/sin and manual computation"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 1, 4, 2, 1, 128
BT = B*T
half_D = 64
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 256).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,256,(B,T),dtype=torch.int64,device=device)
# Run stage 1
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 256, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
# Manually compute what the output should be
x = q[0, 0, 0, :] # for (bt=0, h=0)
cos = cos_mixed[0, :] # cos for (bt=0)
sin = sin_mixed[0, :] # sin for (bt=0)
# Manual rotation
x_d = x[:64]
x_dh = x[64:]
out_first = x_d * cos[:64] - x_dh * sin[:64]
out_second = x_dh * cos[64:128] + x_d * sin[64:128]
print('Manual out_first[0:4]:', out_first[:4])
print('Manual out_second[0:4]:', out_second[:4])
# Run kernel
q_flat = q.reshape(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_flat_kernel[(BT * Hq,)](
q_flat, cos_mixed, sin_mixed, q_out,
BT, T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1),
BLOCK_D=128,
)
print()
print('Kernel q_out[0,0,0:4]:', q_out[0,0,:4])
print('Kernel q_out[0,0,64:68]:', q_out[0,0,64:68])
print()
print('Expected first[0:4]:', out_first[:4])
print('Expected second[0:4]:', out_second[:4])
print()
print('Match first:', torch.allclose(q_out[0,0,:64].float(), out_first.float(), atol=0.01))
print('Match second:', torch.allclose(q_out[0,0,64:].float(), out_second.float(), atol=0.01))
# What about cos[64:68] and sin[64:68]?
print()
print('cos[64:68]:', cos_mixed[0, 64:68])
print('sin[64:68]:', sin_mixed[0, 64:68])
print('cos[0:4]:', cos_mixed[0, 0:4])
print('sin[0:4]:', sin_mixed[0, 0:4])
# For RoPE, cos[d] == cos[d+D/2] and sin[d] == sin[d+D/2] because freqs is concatenated
# Check reference
cos_ref = ref.cos_cache[pos_t[0,0]] # (D,)
sin_ref = ref.sin_cache[pos_t[0,0]]
print()
print('Ref cos[pos_t[0]][64:68]:', cos_ref[64:68])
print('Ref cos[pos_t[0]][0:4]:', cos_ref[0:4])
" 2>&1
Manual out_first[0:4]: tensor([-0.5039, -0.3730, 0.2773, 0.5391], device='cuda:0',
dtype=torch.bfloat16)
Manual out_second[0:4]: tensor([ 0.2412, 1.1484, -0.0854, -0.5469], device='cuda:0',
dtype=torch.bfloat16)
Kernel q_out[0,0,0:4]: tensor([-0.5039, -0.3730, 0.2773, 0.5391], device='cuda:0',
dtype=torch.bfloat16)
Kernel q_out[0,0,64:68]: tensor([ 0.2412, 1.1484, -0.0854, -0.5469], device='cuda:0',
dtype=torch.bfloat16)
Expected first[0:4]: tensor([-0.5039, -0.3730, 0.2773, 0.5391], device='cuda:0',
dtype=torch.bfloat16)
Expected second[0:4]: tensor([ 0.2412, 1.1484, -0.0854, -0.5469], device='cuda:0',
dtype=torch.bfloat16)
Match first: True
Match second: True
cos[64:68]: tensor([ 0.2676, -0.7070, -0.5625, 0.9805], device='cuda:0',
dtype=torch.bfloat16)
sin[64:68]: tensor([ 0.9648, 0.7070, -0.8242, 0.1924], device='cuda:0',
dtype=torch.bfloat16)
cos[0:4]: tensor([ 0.2676, -0.7070, -0.5625, 0.9805], device='cuda:0',
dtype=torch.bfloat16)
sin[0:4]: tensor([ 0.9648, 0.7070, -0.8242, 0.1924], device='cuda:0',
dtype=torch.bfloat16)
Ref cos[pos_t[0]][64:68]: tensor([ 0.2676, -0.7070, -0.5625, 0.9805], device='cuda:0',
dtype=torch.bfloat16)
Ref cos[pos_t[0]][0:4]: tensor([ 0.2676, -0.7070, -0.5625, 0.9805], device='cuda:0',
dtype=torch.bfloat16)
{
"command": "<1493 chars \u2014 see below>",
"description": "Test with T=64"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
T = 64
B, Hq, Hkv, D = 1, 4, 2, 128
BT = B*T
half_D = 64
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
print(f'T={T}: q max_diff={(rq.float()-sq.float()).abs().max().item():.6f}')
print(f'T={T}: k max_diff={(rk.float()-sk.float()).abs().max().item():.6f}')
# Check first vs second half
q_diff = (rq.float() - sq.float()).abs()
print(f' first_half_max_diff={q_diff[:,:,:,0:64].max().item():.6f}')
print(f' second_half_max_diff={q_diff[:,:,:,64:128].max().item():.6f}')
print()
# Check all (bt, h)
for h in range(Hq):
for t in [0, 1, 10, 50, 63]:
d_first = q_diff[0, h, t, 0:64].max().item()
d_second = q_diff[0, h, t, 64:128].max().item()
if d_first > 0.01 or d_second > 0.01:
print(f' h={h} t={t}: first_half={d_first:.6f} second_half={d_second:.6f}')
" 2>&1
T=64: q max_diff=1.488281
T=64: k max_diff=1.380859
first_half_max_diff=0.007812
second_half_max_diff=1.488281
h=0 t=0: first_half=0.003906 second_half=0.925781
h=0 t=10: first_half=0.003906 second_half=1.369385
h=0 t=63: first_half=0.003906 second_half=0.925293
h=1 t=10: first_half=0.003906 second_half=0.820435
h=2 t=1: first_half=0.003906 second_half=1.184204
h=3 t=1: first_half=0.003906 second_half=1.185303
h=3 t=10: first_half=0.003906 second_half=0.418945
{
"command": "<3041 chars \u2014 see below>",
"description": "Verify cos/sin equality and rotation"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
T = 64
B, Hq, Hkv, D = 1, 4, 2, 128
BT = B*T
half_D = 64
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
# Run stage 1
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
# Run kernel manually
q_flat = q.reshape(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_flat_kernel[(BT * Hq,)](
q_flat, cos_mixed, sin_mixed, q_out,
BT, T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1),
BLOCK_D=128,
)
# For (bt=0, h=0): manual compute
x = q[0, 0, 0, :]
cos = cos_mixed[0, :]
sin = sin_mixed[0, :]
# Verify cos == cos[half_D:D]
print('cos[0:64] == cos[64:128]:', torch.allclose(cos[:64], cos[64:]))
# Manual rotation
x_d = x[:64]; x_dh = x[64:]
out_first = x_d * cos[:64] - x_dh * sin[:64]
out_second = x_dh * cos[64:128] + x_d * sin[64:128]
print('Manual out_second[0:4]:', out_second[:4])
print('Kernel q_out[0,0,64:68]:', q_out[0,0,64:68])
print('Match:', torch.allclose(out_second.float(), q_out[0,0,64:].float(), atol=0.01))
# Compare with reference
rq, _ = ref(q, torch.randn(1,T,2,D,dtype=torch.bfloat16,device=device)*0.5, pos_t, pos_h, pos_w)
print()
print('Ref rq[0,0,0,64:68]:', rq[0,0,0,64:68])
print('Manual first half[:4]:', out_first[:4])
print('Ref first half[:4]:', rq[0,0,0,0:4])
print('Match first:', torch.allclose(out_first.float(), rq[0,0,0,0:64].float(), atol=0.01))
print('Match second:', torch.allclose(out_second.float(), rq[0,0,0,64:128].float(), atol=0.01))
print()
# What about the reference _rotate_half computation?
from reference import _rotate_half
rh = _rotate_half(q[0:1, 0:1, 0:1, :]).squeeze() # (D,)
print('rh[:4]:', rh[:4])
print('rh[64:68]:', rh[64:68])
print('-x[64:68]:', -x[64:68])
print('x[0:4]:', x[0:4])
# Full reference computation for (b=0, t=0, h=0)
cos_mix = ref._mix_axes(ref.cos_cache[pos_t], ref.cos_cache[pos_h], ref.cos_cache[pos_w])[0, 0, :]
sin_mix = ref._mix_axes(ref.sin_cache[pos_t], ref.sin_cache[pos_h], ref.sin_cache[pos_w])[0, 0, :]
cos_mix = cos_mix.unsqueeze(0) # (1, D)
sin_mix = sin_mix.unsqueeze(0)
q_h = q[0:1, 0:1, 0:1, :] # (1, 1, 1, D)
out = (q_h * cos_mix) + (_rotate_half(q_h) * sin_mix)
print()
print('Full ref out[64:68]:', out[0, 0, 0, 64:68])
print('cos_mix[64:68]:', cos_mix[0, 64:68])
print('sin_mix[64:68]:', sin_mix[0, 64:68])
" 2>&1
cos[0:64] == cos[64:128]: True
Manual out_second[0:4]: tensor([ 0.4941, -0.6016, -0.0107, 0.6680], device='cuda:0',
dtype=torch.bfloat16)
Kernel q_out[0,0,64:68]: tensor([ 0.4941, -0.6016, -0.0106, 0.6680], device='cuda:0',
dtype=torch.bfloat16)
Match: False
Ref rq[0,0,0,64:68]: tensor([ 0.4941, -0.6016, -0.0107, 0.6680], device='cuda:0',
dtype=torch.bfloat16)
Manual first half[:4]: tensor([ 0.2617, -1.0469, -0.2891, -0.3770], device='cuda:0',
dtype=torch.bfloat16)
Ref first half[:4]: tensor([ 0.2617, -1.0469, -0.2891, -0.3770], device='cuda:0',
dtype=torch.bfloat16)
Match first: True
Match second: True
rh[:4]: tensor([-0.5508, 0.5508, -0.2773, 0.6406], device='cuda:0',
dtype=torch.bfloat16)
rh[64:68]: tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
-x[64:68]: tensor([-0.5508, 0.5508, -0.2773, 0.6406], device='cuda:0',
dtype=torch.bfloat16)
x[0:4]: tensor([ 0.0972, 1.0781, -0.0859, 0.4238], device='cuda:0',
dtype=torch.bfloat16)
Full ref out[64:68]: tensor([ 0.4941, -0.6016, -0.0107, 0.6680], device='cuda:0',
dtype=torch.bfloat16)
cos_mix[64:68]: tensor([ 0.9531, -0.5469, 0.2598, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
sin_mix[64:68]: tensor([-0.3047, -0.8359, 0.9648, 0.0693], device='cuda:0',
dtype=torch.bfloat16)
{
"command": "<2888 chars \u2014 see below>",
"description": "Check element-level diff in second half"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
T = 64
B, Hq, Hkv, D = 1, 4, 2, 128
BT = B*T
half_D = 64
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
# Run stage 1
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
pos_t = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,1024,(B,T),dtype=torch.int64,device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, 1024, 16, 24, 24, half_D,
pos_t.stride(0), BLOCK_D=128,
)
# Run kernel
q_flat = q.reshape(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_flat_kernel[(BT * Hq,)](
q_flat, cos_mixed, sin_mixed, q_out,
BT, T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_out.stride(0), q_out.stride(1),
BLOCK_D=128,
)
# Full manual for ALL elements
for bt in [0]:
for h in [0]:
x = q_flat[bt, h, :]
cos = cos_mixed[bt, :]
sin = sin_mixed[bt, :]
x_d = x[:64]; x_dh = x[64:]
out_first = x_d * cos[:64] - x_dh * sin[:64]
out_second = x_dh * cos[64:] + x_d * sin[64:]
expected = torch.cat([out_first, out_second])
kernel_out = q_out[h, bt, :]
diff = (expected.float() - kernel_out.float()).abs()
max_diff = diff.max().item()
max_idx = diff.argmax().item()
print(f'bt={bt} h={h}: max_diff={max_diff:.8f} at idx={max_idx}')
if max_diff > 0.001:
print(f' expected[{max_idx}]:', expected[max_idx])
print(f' kernel[{max_idx}]:', kernel_out[max_idx])
# Also compare first 8 and last 8
print(f' first 8 diff: {(expected[:8].float() - kernel_out[:8].float()).abs().max():.8f}')
print(f' last 8 diff: {(expected[-8:].float() - kernel_out[-8:].float()).abs().max():.8f}')
# Now compare with the full solution
sol = solution.Model(B, T, Hq, Hkv, D, (16,24,24), 1024).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
print()
print(f'sol q max_diff with ref: {(rq.float()-sq.float()).abs().max().item():.6f}')
# The sol uses q.reshape which may have different stride
print(f'q strides: {q.stride()}')
print(f'q_flat strides in sol: q.reshape({BT}, {Hq}, {D})')
print(f'q reshape works: {q.reshape(BT, Hq, D).is_contiguous()}')
" 2>&1
bt=0 h=0: max_diff=0.90234375 at idx=65
expected[65]: tensor(-0.6016, device='cuda:0', dtype=torch.bfloat16)
kernel[65]: tensor(0.3008, device='cuda:0', dtype=torch.bfloat16)
first 8 diff: 0.00195312
last 8 diff: 0.01269531
sol q max_diff with ref: 1.563477
q strides: (32768, 512, 128, 1)
q_flat strides in sol: q.reshape(64, 4, 128)
q reshape works: True
--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -1,12 +1,9 @@
"""Fused multi-axis RoPE pre-attention prep via Triton kernels.
-Three-kernel approach:
- 1. grid=(B*T,): gather + mix cos/sin once per position => (B*T, D)
- 2. grid=(B*T*Hq,): rotate + transpose for q. Each program handles one head.
- 3. grid=(B*T*Hkv,): rotate + transpose for k. Each program handles one head.
-
-Kernels 2/3 read pre-mixed cos/sin (shared across heads), avoiding redundant
-gather bandwidth. The flat (B*T*H,) grid maximizes GPU occupancy.
+Two-stage with separate q and k rotation kernels:
+ Stage 1: grid=(B*T,). Gather+mix cos/sin once per position => (B*T, D).
+ Stage 2: grid=(B*T, Hq) for q, (B*T, Hkv) for k.
+ Each block loads cos/sin ONCE, applies rotate-half, writes transposed output.
"""
from __future__ import annotations
@@ -19,9 +16,6 @@
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000"]
-# ---------------------------------------------------------------------------
-# Stage 1: gather + mix cos/sin
-# ---------------------------------------------------------------------------
@triton.jit
def _mix_cos_sin_kernel(
@@ -30,14 +24,14 @@
cos_out_ptr, sin_out_ptr,
B, T, D, max_pos,
s0, s1, s2, half_D,
- stride_pos_b,
+ stride_pos,
BLOCK_D: tl.constexpr,
):
pid = tl.program_id(0)
b = pid // T
t = pid % T
- pos_base = b * stride_pos_b + t
+ pos_base = b * stride_pos + t
pt = tl.load(pos_t_ptr + pos_base)
ph = tl.load(pos_h_ptr + pos_base)
pw = tl.load(pos_w_ptr + pos_base)
@@ -78,45 +72,47 @@
tl.store(sin_out_ptr + out_base + d_offs, sin_mixed, mask=mask)
-# ---------------------------------------------------------------------------
-# Stage 2/3: rotate + transpose
-# Grid: (B*T*H,) flat. Each program handles one head.
-# Read cos/sin from precomputed (B*T, D) buffer.
-# ---------------------------------------------------------------------------
-
@triton.jit
-def _rope_transpose_flat_kernel(
+def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
- BTH, T, H, D, half_D,
- stride_x_bt, stride_x_h, # x: (B*T, H, D) flat view
- stride_cs, # cos/sin: (B*T, D) flat view
- stride_out_h, stride_out_t, # out: (H, T, D) for a single-b view
+ T_strided, H, D, half_D,
+ stride_x_bt, stride_x_h,
+ stride_cs,
BLOCK_D: tl.constexpr,
):
- pid = tl.program_id(0)
- bt = pid // H # which (b,t) position
- h = pid % H # which head
+ """Rotate + transpose kernel.
+
+ Grid: (1, B*T, H). program_id(1)=pid_bt (0..B*T-1), program_id(2)=pid_h (0..H-1).
+ x: (B*T, H, D) flat view (after reshape).
+ cos/sin: (B*T, D) flat.
+ out: (H, B*T, D) which gets reshaped to (B, H, T, D).
+ """
+ pid_bt = tl.program_id(1)
+ pid_h = tl.program_id(2)
d_offs = tl.arange(0, BLOCK_D)
-
- # Load cos/sin
- cs_base = bt * stride_cs
- cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=(d_offs < D))
- sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=(d_offs < D))
-
- # Load x
- x_base = bt * stride_x_bt + h * stride_x_h
- x_d = tl.load(x_ptr + x_base + d_offs, mask=(d_offs < D))
- x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=((d_offs + half_D) < D))
+ mask_d = d_offs < D
+ mask_half = d_offs < half_D
+ mask_second = (d_offs + half_D) < D
+
+ # Load cos/sin ONCE (shared across all heads for this position)
+ cs_base = pid_bt * stride_cs
+ cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
+ sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)
+
+ # Load x for this head
+ x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
+ x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
+ x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)
# Rotate
- out_first = x_d * cos_vals - x_dh * sin_vals
+ out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
# Write transposed: out[h, bt, d]
- out_base = h * stride_out_h + bt * stride_out_t
- tl.store(out_ptr + out_base + d_offs, out_first, mask=(d_offs < D))
- tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=((d_offs + half_D) < D))
+ out_base = pid_h * T_strided + pid_bt * D
+ tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
+ tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)
def _next_power_of_2(n):
@@ -127,9 +123,7 @@
def __init__(self, B: int, T: int, Hq: int, Hkv: int, D: int,
mrope_section: tuple[int, int, int], max_pos: int):
super().__init__()
- assert sum(mrope_section) == D // 2, (
- f"mrope_section must sum to D/2, got {mrope_section} sum={sum(mrope_section)} D/2={D//2}"
- )
+ assert sum(mrope_section) == D // 2
self.B, self.T = B, T
self.Hq, self.Hkv, self.D = Hq, Hkv, D
self.mrope_section = tuple(mrope_section)
@@ -155,11 +149,12 @@
BT = B * T
BLOCK_D = self._BLOCK_D
- # Stage 1: gather + mix cos/sin => (B*T, D)
+ # Stage 1: gather + mix cos/sin
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=q.device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=q.device)
- _mix_cos_sin_kernel[(BT,)](
+ grid_mix = (BT,)
+ _mix_cos_sin_kernel[grid_mix](
self.cos_cache, self.sin_cache,
pos_t, pos_h, pos_w,
cos_mixed, sin_mixed,
@@ -167,42 +162,38 @@
s0, s1, s2, half_D,
pos_t.stride(0),
BLOCK_D=BLOCK_D,
+ num_warps=4,
)
- # Stage 2: rotate + transpose q
- # Reshape q from (B, T, Hq, D) to (BT, Hq, D) - contiguous
- q_flat = q.reshape(BT, Hq, D) # O(1) vi
... (1639 more chars){
"content": "<6931 chars \u2014 see diff>",
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}Wrote file successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check 3D grid kernel"
}FAIL: shape 2 {'B': 2, 'T': 2048, 'Hq': 16, 'Hkv': 2, 'D': 64, 'mrope_section': (8, 12, 12), 'max_pos': 16384} seed 42 q_rot: tolerance exceeded: max_abs_diff=3.4375 (atol=0.02, rtol=0.02)
{
"command": "<1369 chars \u2014 see below>",
"description": "Quick test multiple shapes"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
for shape in [
{'B': 1, 'T': 64, 'Hq': 4, 'Hkv': 2, 'D': 128, 'mrope_section': (16,24,24), 'max_pos': 1024},
{'B': 2, 'T': 32, 'Hq': 4, 'Hkv': 2, 'D': 64, 'mrope_section': (8,12,12), 'max_pos': 256},
]:
B, T, Hq, Hkv, D = shape['B'], shape['T'], shape['Hq'], shape['Hkv'], shape['D']
mp = shape['max_pos']
ms = shape['mrope_section']
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
sol = solution.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
rq, rk = ref(q, k, pos_t, pos_h, pos_w)
sq, sk = sol(q, k, pos_t, pos_h, pos_w)
qd = (rq.float()-sq.float()).abs().max().item()
kd = (rk.float()-sk.float()).abs().max().item()
print(f'B={B} T={T} Hq={Hq} Hkv={Hkv} D={D}: q_diff={qd:.6f} k_diff={kd:.6f}',
'PASS' if qd < 0.02 and kd < 0.02 else 'FAIL')
" 2>&1
B=1 T=64 Hq=4 Hkv=2 D=128: q_diff=0.007812 k_diff=0.007812 PASS
B=2 T=32 Hq=4 Hkv=2 D=64: q_diff=2.773438 k_diff=2.625000 FAIL
{
"command": "<2773 chars \u2014 see below>",
"description": "Debug D=64 case"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 2, 32, 4, 2, 64
BT = B*T
half_D = D // 2
mp, ms = 256, (8, 12, 12)
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
sol = solution.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
sol.load_state_dict(ref.state_dict(), strict=True)
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
# Run manual stages
s0, s1, s2 = ms
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, mp, s0, s1, s2, half_D,
pos_t.stride(0), BLOCK_D=64, num_warps=4,
)
# Verify stage 1
cos_ref = ref.cos_cache[pos_t] # (B, T, D)
sin_ref = ref.sin_cache[pos_t]
cos_mix_ref = ref._mix_axes(cos_ref, ref.cos_cache[pos_h], ref.cos_cache[pos_w])
print('cos_mixed shape:', cos_mixed.shape)
print('cos_ref shape:', cos_mix_ref.shape)
print('cos_mixed[0,:4]:', cos_mixed[0,:4])
print('cos_mix_ref[0,0,:4]:', cos_mix_ref[0,0,:4])
print('Stage 1 cos diff:', (cos_mixed.float() - cos_mix_ref.view(BT, D).float()).abs().max().item())
# Try the q rotation
q_flat = q.view(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
BLOCK_D_q = 64
solution._rope_transpose_kernel[(1, BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_out,
BT * D, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
BLOCK_D=BLOCK_D_q, num_warps=4,
)
# Reshape and compare
q_rot = q_out.view(B, Hq, T, D)
rq, _ = ref(q, k, pos_t, pos_h, pos_w)
print()
print('q_rot shape:', q_rot.shape, 'rq shape:', rq.shape)
qd = (rq.float() - q_rot.float()).abs()
print(f'q first_half diff: {qd[:,:,:,:32].max().item():.6f}')
print(f'q second_half diff: {qd[:,:,:,32:64].max().item():.6f}')
print(f'q max_diff: {qd.max().item():.6f}')
# Manual check for one element
bt = 0; h = 0
x = q_flat[bt, h, :] # (D,)
cos = cos_mixed[bt, :] # (D,)
sin = sin_mixed[bt, :] # (D,)
x_d = x[0:32]; x_dh = x[32:64]
manual_first = x_d * cos[:32] - x_dh * sin[:32]
manual_second = x_dh * cos[32:64] + x_d * sin[32:64]
print()
print(f'bt={bt} h={h}:')
print(f' manual_first[:4]: {manual_first[:4]}')
print(f' kernel_out[:4]: {q_out[h, bt, :4]}')
print(f' manual_second[:4]: {manual_second[:4]}')
print(f' kernel_out[32:36]: {q_out[h, bt, 32:36]}')
" 2>&1
cos_mixed shape: torch.Size([64, 64])
cos_ref shape: torch.Size([2, 32, 64])
cos_mixed[0,:4]: tensor([ 0.7656, -0.8633, -0.3770, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
cos_mix_ref[0,0,:4]: tensor([ 0.7656, -0.8633, -0.3770, -0.9961], device='cuda:0',
dtype=torch.bfloat16)
Stage 1 cos diff: 0.0
q_rot shape: torch.Size([2, 4, 32, 64]) rq shape: torch.Size([2, 4, 32, 64])
q first_half diff: 2.507812
q second_half diff: 2.773438
q max_diff: 2.773438
bt=0 h=0:
manual_first[:4]: tensor([ 0.3457, -1.2891, -0.2676, -0.4434], device='cuda:0',
dtype=torch.bfloat16)
kernel_out[:4]: tensor([ 0.3457, -1.2891, -0.2676, -0.4453], device='cuda:0',
dtype=torch.bfloat16)
manual_second[:4]: tensor([ 0.2598, -0.0703, -0.2012, -0.1680], device='cuda:0',
dtype=torch.bfloat16)
kernel_out[32:36]: tensor([ 0.2598, -0.0698, -0.2012, -0.1670], device='cuda:0',
dtype=torch.bfloat16)
{
"command": "<2292 chars \u2014 see below>",
"description": "Find which elements are wrong"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 2, 32, 4, 2, 64
BT = B*T
half_D = D // 2
mp, ms = 256, (8, 12, 12)
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, mp, 8, 12, 12, half_D,
pos_t.stride(0), BLOCK_D=64, num_warps=4,
)
q_flat = q.view(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_kernel[(1, BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_out,
BT * D, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
BLOCK_D=64, num_warps=4,
)
q_rot = q_out.view(B, Hq, T, D)
rq, _ = ref(q, k, pos_t, pos_h, pos_w)
# Check per-element diffs
errors = []
for b in range(B):
for h in range(Hq):
for t in range(T):
d = (rq[b, h, t].float() - q_rot[b, h, t].float()).abs()
m = d.max().item()
if m > 0.05:
d_idx = d.argmax().item()
errors.append(f'b={b} h={h} t={t} d={d_idx}: ref={rq[b,h,t,d_idx]:.4f} sol={q_rot[b,h,t,d_idx]:.4f} diff={m:.4f}')
print(f'Total errors (diff>0.05): {len(errors)} out of {B*Hq*T} elements')
for e in errors[:20]:
print(f' {e}')
# Check which (bt, h) combos have errors
# In the q_out view: q_out[h, bt, :] where bt = b*T + t
print()
bad_pairs = set()
for b in range(B):
for h in range(Hq):
for t in range(T):
bt = b * T + t
d = (rq[b, h, t].float() - q_out[h, bt].float()).abs().max().item()
if d > 0.05:
bad_pairs.add((bt, h))
print(f'Bad (bt, h) pairs: {len(bad_pairs)} out of {BT * Hq} total')
" 2>&1
Total errors (diff>0.05): 192 out of 256 elements
b=0 h=1 t=0 d=43: ref=-0.7305 sol=0.5391 diff=1.2695
b=0 h=1 t=1 d=18: ref=-1.0625 sol=0.7656 diff=1.8281
b=0 h=1 t=2 d=33: ref=-1.0938 sol=0.8516 diff=1.9453
b=0 h=1 t=3 d=35: ref=0.7500 sol=-1.6797 diff=2.4297
b=0 h=1 t=4 d=25: ref=0.6445 sol=-0.9531 diff=1.5977
b=0 h=1 t=5 d=19: ref=-0.1465 sol=1.4531 diff=1.5996
b=0 h=1 t=6 d=44: ref=-0.2139 sol=-1.6172 diff=1.4033
b=0 h=1 t=7 d=20: ref=-0.8359 sol=0.6406 diff=1.4766
b=0 h=1 t=8 d=24: ref=0.2383 sol=-1.5859 diff=1.8242
b=0 h=1 t=9 d=32: ref=-0.1846 sol=1.1797 diff=1.3643
b=0 h=1 t=10 d=3: ref=0.1089 sol=-1.4375 diff=1.5464
b=0 h=1 t=11 d=25: ref=-0.3203 sol=1.4062 diff=1.7266
b=0 h=1 t=12 d=48: ref=0.5547 sol=-1.1953 diff=1.7500
b=0 h=1 t=13 d=9: ref=0.8984 sol=-1.0234 diff=1.9219
b=0 h=1 t=14 d=5: ref=-1.1406 sol=1.1406 diff=2.2812
b=0 h=1 t=15 d=32: ref=1.5234 sol=-0.1089 diff=1.6323
b=0 h=1 t=16 d=56: ref=-0.1934 sol=1.1953 diff=1.3887
b=0 h=1 t=17 d=37: ref=0.3555 sol=-1.2891 diff=1.6445
b=0 h=1 t=18 d=14: ref=0.3457 sol=-1.2422 diff=1.5879
b=0 h=1 t=19 d=56: ref=-0.7617 sol=0.9023 diff=1.6641
Bad (bt, h) pairs: 0 out of 256 total
{
"command": "<2214 chars \u2014 see below>",
"description": "Check view correctness"
}uv run python -c "
import torch, sys, reference, solution
device = torch.device('cuda:0')
B, T, Hq, Hkv, D = 2, 32, 4, 2, 64
BT = B*T
half_D = D // 2
mp, ms = 256, (8, 12, 12)
torch.manual_seed(42)
ref = reference.Model(B, T, Hq, Hkv, D, ms, mp).to(device).eval()
q = torch.randn(B,T,Hq,D,dtype=torch.bfloat16,device=device)*0.5
k = torch.randn(B,T,Hkv,D,dtype=torch.bfloat16,device=device)*0.5
pos_t = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_h = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
pos_w = torch.randint(0,mp,(B,T),dtype=torch.int64,device=device)
cos_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
sin_mixed = torch.empty(BT, D, dtype=torch.bfloat16, device=device)
solution._mix_cos_sin_kernel[(BT,)](
ref.cos_cache, ref.sin_cache, pos_t, pos_h, pos_w,
cos_mixed, sin_mixed, B, T, D, mp, 8, 12, 12, half_D,
pos_t.stride(0), BLOCK_D=64, num_warps=4,
)
q_flat = q.view(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=device)
solution._rope_transpose_kernel[(1, BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_out,
BT * D, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
BLOCK_D=64, num_warps=4,
)
# Compare q_out[h, bt, :] vs rq after view
rq, _ = ref(q, k, pos_t, pos_h, pos_w)
# Check: is q_out itself correct?
for bt in [0, 1, 32, 33]: # bt = b*32 + t
for h in [0, 1, 2]:
b = bt // T
t = bt % T
d = (q_out[h, bt].float() - rq[b, h, t].float()).abs().max().item()
print(f'q_out[h={h},bt={bt}] vs ref[b={b},h={h},t={t}]: diff={d:.6f}')
# Now check view
q_rot_view = q_out.view(B, Hq, T, D)
print()
print('q_out shape:', q_out.shape, 'strides:', q_out.stride())
print('q_rot_view shape:', q_rot_view.shape, 'strides:', q_rot_view.stride())
# Check if view is correct
for b in [0, 1]:
for h in [0, 1, 2]:
for t in [0, 1]:
d1 = (q_rot_view[b, h, t].float() - rq[b, h, t].float()).abs().max().item()
bt = b * T + t
d2 = (q_out[h, bt].float() - rq[b, h, t].float()).abs().max().item()
print(f'view[b={b},h={h},t={t}] diff={d1:.6f} vs q_out[h={h},bt={bt}] diff={d2:.6f}')
" 2>&1
q_out[h=0,bt=0] vs ref[b=0,h=0,t=0]: diff=0.007812
q_out[h=1,bt=0] vs ref[b=0,h=1,t=0]: diff=0.003906
q_out[h=2,bt=0] vs ref[b=0,h=2,t=0]: diff=0.003906
q_out[h=0,bt=1] vs ref[b=0,h=0,t=1]: diff=0.003906
q_out[h=1,bt=1] vs ref[b=0,h=1,t=1]: diff=0.007812
q_out[h=2,bt=1] vs ref[b=0,h=2,t=1]: diff=0.003906
q_out[h=0,bt=32] vs ref[b=1,h=0,t=0]: diff=0.007812
q_out[h=1,bt=32] vs ref[b=1,h=1,t=0]: diff=0.003906
q_out[h=2,bt=32] vs ref[b=1,h=2,t=0]: diff=0.003906
q_out[h=0,bt=33] vs ref[b=1,h=0,t=1]: diff=0.007812
q_out[h=1,bt=33] vs ref[b=1,h=1,t=1]: diff=0.003906
q_out[h=2,bt=33] vs ref[b=1,h=2,t=1]: diff=0.003906
q_out shape: torch.Size([4, 64, 64]) strides: (4096, 64, 1)
q_rot_view shape: torch.Size([2, 4, 32, 64]) strides: (8192, 2048, 64, 1)
view[b=0,h=0,t=0] diff=0.007812 vs q_out[h=0,bt=0] diff=0.007812
view[b=0,h=0,t=1] diff=0.003906 vs q_out[h=0,bt=1] diff=0.003906
view[b=0,h=1,t=0] diff=1.269531 vs q_out[h=1,bt=0] diff=0.003906
view[b=0,h=1,t=1] diff=1.828125 vs q_out[h=1,bt=1] diff=0.007812
view[b=0,h=2,t=0] diff=1.679688 vs q_out[h=2,bt=0] diff=0.003906
view[b=0,h=2,t=1] diff=1.785156 vs q_out[h=2,bt=1] diff=0.003906
view[b=1,h=0,t=0] diff=1.738281 vs q_out[h=0,bt=32] diff=0.007812
view[b=1,h=0,t=1] diff=1.496094 vs q_out[h=0,bt=33] diff=0.007812
view[b=1,h=1,t=0] diff=1.531250 vs q_out[h=1,bt=32] diff=0.003906
view[b=1,h=1,t=1] diff=1.601562 vs q_out[h=1,bt=33] diff=0.003906
view[b=1,h=2,t=0] diff=1.482422 vs q_out[h=2,bt=32] diff=0.003906
view[b=1,h=2,t=1] diff=1.73
... (39 more chars)--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -176,7 +176,7 @@
BLOCK_D=BLOCK_D,
num_warps=4,
)
- q_rot = q_out.view(B, Hq, T, D)
+ q_rot = q_out.view(Hq, B, T, D).transpose(0, 1)
# Stage 2k: rotate + transpose k
k_flat = k.view(BT, Hkv, D)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": " q_rot = q_out.view(B, Hq, T, D)",
"newString": " q_rot = q_out.view(Hq, B, T, D).transpose(0, 1)"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -176,7 +176,7 @@
BLOCK_D=BLOCK_D,
num_warps=4,
)
- q_rot = q_out.view(Hq, B, T, D).transpose(0, 1)
+ q_rot = q_out.view(Hq, B, T, D).transpose(0, 1).contiguous()
# Stage 2k: rotate + transpose k
k_flat = k.view(BT, Hkv, D)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"newString": " q_rot = q_out.view(Hq, B, T, D).transpose(0, 1).contiguous()",
"oldString": " q_rot = q_out.view(Hq, B, T, D).transpose(0, 1)"
}Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -189,7 +189,7 @@
BLOCK_D=BLOCK_D,
num_warps=4,
)
- k_rot = k_out.view(B, Hkv, T, D)
+ k_rot = k_out.view(Hkv, B, T, D).transpose(0, 1).contiguous()
return q_rot, k_rot
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"newString": " k_rot = k_out.view(Hkv, B, T, D).transpose(0, 1).contiguous()",
"oldString": " k_rot = k_out.view(B, Hkv, T, D)"
}Edit applied successfully.{
"command": "uv run python check.py 2>&1",
"timeout": 120000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check view fix"
}PASS
{
"command": "uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Benchmark version 3"
}shape=0 variant=eager tflops=0.177 gbps=254.315 ms=0.355
shape=0 variant=compiled tflops=0.503 gbps=721.699 ms=0.125
shape=0 variant=solution tflops=0.562 gbps=806.957 ms=0.112
shape=0 solution_peak_fraction=0.4483
shape=1 variant=eager tflops=0.178 gbps=260.648 ms=0.564
shape=1 variant=compiled tflops=0.468 gbps=683.886 ms=0.215
shape=1 variant=solution tflops=0.608 gbps=888.351 ms=0.165
shape=1 solution_peak_fraction=0.4935
shape=2 variant=eager tflops=0.070 gbps=108.765 ms=0.203
shape=2 variant=compiled tflops=0.161 gbps=250.844 ms=0.088
shape=2 variant=solution tflops=0.174 gbps=271.698 ms=0.081
shape=2 solution_peak_fraction=0.1509
shape=3 variant=eager tflops=0.170 gbps=243.415 ms=1.483
shape=3 variant=compiled tflops=0.462 gbps=662.507 ms=0.545
shape=3 variant=solution tflops=0.742 gbps=1064.523 ms=0.339
shape=3 solution_peak_fraction=0.5914
peak_fraction: 0.3749
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"offset": 88,
"limit": 30
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
88: out: (H, B*T, D) which gets reshaped to (B, H, T, D).
89: """
90: pid_bt = tl.program_id(1)
91: pid_h = tl.program_id(2)
92:
93: d_offs = tl.arange(0, BLOCK_D)
94: mask_d = d_offs < D
95: mask_half = d_offs < half_D
96: mask_second = (d_offs + half_D) < D
97:
98: # Load cos/sin ONCE (shared across all heads for this position)
99: cs_base = pid_bt * stride_cs
100: cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
101: sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)
102:
103: # Load x for this head
104: x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
105: x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
106: x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)
107:
108: # Rotate
109: out_first = x_d * cos_vals - x_dh * sin_vals
110: out_second = x_dh * cos_vals + x_d * sin_vals
111:
112: # Write transposed: out[h, bt, d]
113: out_base = pid_h * T_strided + pid_bt * D
114: tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
115: tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)
116:
117:
(Showing lines 88-117 of 214. Use offset=118 to continue.)
</content>--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -75,20 +75,22 @@
@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
- T_strided, H, D, half_D,
+ T, H, D, half_D,
stride_x_bt, stride_x_h,
stride_cs,
+ stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
- """Rotate + transpose kernel.
-
- Grid: (B*T, H). pid_dim1=pid_bt (0..B*T-1), pid_dim2=pid_h (0..H-1).
- x: (B*T, H, D) flat view (after reshape).
+ """Rotate + transpose kernel. Writes directly to (B,H,T,D).
+ Grid: (1, B*T, H).
+ x: (B*T, H, D) flat view.
cos/sin: (B*T, D) flat.
- out: (H, B*T, D) which gets reshaped to (B, H, T, D).
- """
+ out: (B, H, T, D)."""
pid_bt = tl.program_id(1)
pid_h = tl.program_id(2)
+
+ b = pid_bt // T
+ t = pid_bt % T
d_offs = tl.arange(0, BLOCK_D)
mask_d = d_offs < D
@@ -109,8 +111,8 @@
out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
- # Write transposed: out[h, bt, d]
- out_base = pid_h * T_strided + pid_bt * D
+ # Write to (B,H,T,D) layout directly
+ out_base = b * stride_out_b + pid_h * stride_out_h + t * stride_out_t
tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"oldString": "<1411 chars \u2014 see below>",
"newString": "<1443 chars \u2014 see below>"
}@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
T_strided, H, D, half_D,
stride_x_bt, stride_x_h,
stride_cs,
BLOCK_D: tl.constexpr,
):
"""Rotate + transpose kernel.
Grid: (B*T, H). pid_dim1=pid_bt (0..B*T-1), pid_dim2=pid_h (0..H-1).
x: (B*T, H, D) flat view (after reshape).
cos/sin: (B*T, D) flat.
out: (H, B*T, D) which gets reshaped to (B, H, T, D).
"""
pid_bt = tl.program_id(1)
pid_h = tl.program_id(2)
d_offs = tl.arange(0, BLOCK_D)
mask_d = d_offs < D
mask_half = d_offs < half_D
mask_second = (d_offs + half_D) < D
# Load cos/sin ONCE (shared across all heads for this position)
cs_base = pid_bt * stride_cs
cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)
# Load x for this head
x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)
# Rotate
out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
# Write transposed: out[h, bt, d]
out_base = pid_h * T_strided + pid_bt * D
tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)@triton.jit
def _rope_transpose_kernel(
x_ptr, cos_ptr, sin_ptr, out_ptr,
T, H, D, half_D,
stride_x_bt, stride_x_h,
stride_cs,
stride_out_b, stride_out_h, stride_out_t,
BLOCK_D: tl.constexpr,
):
"""Rotate + transpose kernel. Writes directly to (B,H,T,D).
Grid: (1, B*T, H).
x: (B*T, H, D) flat view.
cos/sin: (B*T, D) flat.
out: (B, H, T, D)."""
pid_bt = tl.program_id(1)
pid_h = tl.program_id(2)
b = pid_bt // T
t = pid_bt % T
d_offs = tl.arange(0, BLOCK_D)
mask_d = d_offs < D
mask_half = d_offs < half_D
mask_second = (d_offs + half_D) < D
# Load cos/sin ONCE (shared across all heads for this position)
cs_base = pid_bt * stride_cs
cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)
# Load x for this head
x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)
# Rotate
out_first = x_d * cos_vals - x_dh * sin_vals
out_second = x_dh * cos_vals + x_d * sin_vals
# Write to (B,H,T,D) layout directly
out_base = b * stride_out_b + pid_h * stride_out_h + t * stride_out_t
tl.store(out_ptr + out_base + d_offs, out_first, mask=mask_half)
tl.store(out_ptr + out_base + d_offs + half_D, out_second, mask=mask_second)Edit applied successfully.--- a//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
+++ b//home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py
@@ -169,29 +169,29 @@
# Stage 2q: rotate + transpose q
q_flat = q.view(BT, Hq, D)
- q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=q.device)
+ q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
_rope_transpose_kernel[(1, BT, Hq)](
- q_flat, cos_mixed, sin_mixed, q_out,
- BT * D, Hq, D, half_D,
+ q_flat, cos_mixed, sin_mixed, q_rot,
+ T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
+ q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
BLOCK_D=BLOCK_D,
num_warps=4,
)
- q_rot = q_out.view(Hq, B, T, D).transpose(0, 1).contiguous()
# Stage 2k: rotate + transpose k
k_flat = k.view(BT, Hkv, D)
- k_out = torch.empty(Hkv, BT, D, dtype=torch.bfloat16, device=k.device)
+ k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
_rope_transpose_kernel[(1, BT, Hkv)](
- k_flat, cos_mixed, sin_mixed, k_out,
- BT * D, Hkv, D, half_D,
+ k_flat, cos_mixed, sin_mixed, k_rot,
+ T, Hkv, D, half_D,
k_flat.stride(0), k_flat.stride(1),
cos_mixed.stride(0),
+ k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
BLOCK_D=BLOCK_D,
num_warps=4,
)
- k_rot = k_out.view(Hkv, B, T, D).transpose(0, 1).contiguous()
return q_rot, k_rot
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"newString": "<980 chars \u2014 see below>",
"oldString": "<999 chars \u2014 see below>"
} # Stage 2q: rotate + transpose q
q_flat = q.view(BT, Hq, D)
q_rot = torch.empty(B, Hq, T, D, dtype=torch.bfloat16, device=q.device)
_rope_transpose_kernel[(1, BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_rot,
T, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
q_rot.stride(0), q_rot.stride(1), q_rot.stride(2),
BLOCK_D=BLOCK_D,
num_warps=4,
)
# Stage 2k: rotate + transpose k
k_flat = k.view(BT, Hkv, D)
k_rot = torch.empty(B, Hkv, T, D, dtype=torch.bfloat16, device=k.device)
_rope_transpose_kernel[(1, BT, Hkv)](
k_flat, cos_mixed, sin_mixed, k_rot,
T, Hkv, D, half_D,
k_flat.stride(0), k_flat.stride(1),
cos_mixed.stride(0),
k_rot.stride(0), k_rot.stride(1), k_rot.stride(2),
BLOCK_D=BLOCK_D,
num_warps=4,
) # Stage 2q: rotate + transpose q
q_flat = q.view(BT, Hq, D)
q_out = torch.empty(Hq, BT, D, dtype=torch.bfloat16, device=q.device)
_rope_transpose_kernel[(1, BT, Hq)](
q_flat, cos_mixed, sin_mixed, q_out,
BT * D, Hq, D, half_D,
q_flat.stride(0), q_flat.stride(1),
cos_mixed.stride(0),
BLOCK_D=BLOCK_D,
num_warps=4,
)
q_rot = q_out.view(Hq, B, T, D).transpose(0, 1).contiguous()
# Stage 2k: rotate + transpose k
k_flat = k.view(BT, Hkv, D)
k_out = torch.empty(Hkv, BT, D, dtype=torch.bfloat16, device=k.device)
_rope_transpose_kernel[(1, BT, Hkv)](
k_flat, cos_mixed, sin_mixed, k_out,
BT * D, Hkv, D, half_D,
k_flat.stride(0), k_flat.stride(1),
cos_mixed.stride(0),
BLOCK_D=BLOCK_D,
num_warps=4,
)
k_rot = k_out.view(Hkv, B, T, D).transpose(0, 1).contiguous()Edit applied successfully.{
"command": "uv run python check.py 2>&1 && uv run python benchmark.py 2>&1",
"timeout": 300000,
"workdir": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope",
"description": "Check and benchmark direct write"
}PASS
shape=0 variant=eager tflops=0.175 gbps=251.291 ms=0.359
shape=0 variant=compiled tflops=0.512 gbps=734.667 ms=0.123
shape=0 variant=solution tflops=0.585 gbps=839.744 ms=0.108
shape=0 solution_peak_fraction=0.4665
shape=1 variant=eager tflops=0.178 gbps=259.823 ms=0.566
shape=1 variant=compiled tflops=0.465 gbps=679.184 ms=0.216
shape=1 variant=solution tflops=0.644 gbps=940.746 ms=0.156
shape=1 solution_peak_fraction=0.5226
shape=2 variant=eager tflops=0.072 gbps=112.372 ms=0.197
shape=2 variant=compiled tflops=0.198 gbps=309.262 ms=0.072
shape=2 variant=solution tflops=0.200 gbps=312.125 ms=0.071
shape=2 solution_peak_fraction=0.1734
shape=3 variant=eager tflops=0.168 gbps=241.452 ms=1.496
shape=3 variant=compiled tflops=0.458 gbps=657.585 ms=0.549
shape=3 variant=solution tflops=0.741 gbps=1062.669 ms=0.340
shape=3 solution_peak_fraction=0.5904
peak_fraction: 0.3975
RESULT: OK
{
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py",
"offset": 79,
"limit": 40
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
79: stride_x_bt, stride_x_h,
80: stride_cs,
81: stride_out_b, stride_out_h, stride_out_t,
82: BLOCK_D: tl.constexpr,
83: ):
84: """Rotate + transpose kernel. Writes directly to (B,H,T,D).
85: Grid: (1, B*T, H).
86: x: (B*T, H, D) flat view.
87: cos/sin: (B*T, D) flat.
88: out: (B, H, T, D)."""
89: pid_bt = tl.program_id(1)
90: pid_h = tl.program_id(2)
91:
92: b = pid_bt // T
93: t = pid_bt % T
94:
95: d_offs = tl.arange(0, BLOCK_D)
96: mask_d = d_offs < D
97: mask_half = d_offs < half_D
98: mask_second = (d_offs + half_D) < D
99:
100: # Load cos/sin ONCE (shared across all heads for this position)
101: cs_base = pid_bt * stride_cs
102: cos_vals = tl.load(cos_ptr + cs_base + d_offs, mask=mask_d)
103: sin_vals = tl.load(sin_ptr + cs_base + d_offs, mask=mask_d)
104:
105: # Load x for this head
106: x_base = pid_bt * stride_x_bt + pid_h * stride_x_h
107: x_d = tl.load(x_ptr + x_base + d_offs, mask=mask_half)
108: x_dh = tl.load(x_ptr + x_base + d_offs + half_D, mask=mask_second)
109:
110: # Rotate
111: out_first = x_d * cos_vals - x_dh * sin_vals
112: out_second = x_dh * cos_vals + x_d * sin_vals
113:
114: # Write to (B,H,T,D) layout directly
115: out_base = b * stride_out_b + pid_h * stride_out_h + t * stride_out_t
116: tl.store(o
... (217 more chars){
"filePath": "/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py"
}<path>/home/infatoshi/cuda/KernelBench-Hard/problems/09_fmha_preattn_mrope/solution.py</path>
<type>file</type>
<content>
1: """Fused multi-axis RoPE pre-attention prep via Triton kernels.
2:
3: Two-stage with separate q and k rotation kernels:
4: Stage 1: grid=(B*T,). Gather+mix cos/sin once per position => (B*T, D).
5: Stage 2: grid=(B*T, Hq) for q, (B*T, Hkv) for k.
6: Each block loads cos/sin ONCE, applies rotate-half, writes transposed output.
7: """
8: from __future__ import annotations
9:
10: import torch
11: import torch.nn as nn
12: import triton
13: import triton.language as tl
14:
15: OP_TYPE = "rope"
16: SUPPORTED_PRECISIONS = ["bf16"]
17: HARDWARE_REQUIRED = ["RTX_PRO_6000"]
18:
19:
20: @triton.jit
21: def _mix_cos_sin_kernel(
22: cos_cache_ptr, sin_cache_ptr,
23: pos_t_ptr, pos_h_ptr, pos_w_ptr,
24: cos_out_ptr, sin_out_ptr,
25: B, T, D, max_pos,
26: s0, s1, s2, half_D,
27: stride_pos,
28: BLOCK_D: tl.constexpr,
29: ):
30: pid = tl.program_id(0)
31: b = pid // T
32: t = pid % T
33:
34: pos_base = b * stride_pos + t
35: pt = tl.load(pos_t_ptr + pos_base)
36: ph = tl.load(pos_h_ptr + pos_base)
37: pw = tl.load(pos_w_ptr + pos_base)
38:
39: d_offs = tl.arange(0, BLOCK_D)
40: mask = d_offs < D
41:
42: row_t = pt.to(tl.int64) * D
43: row_h = ph.to(tl.int64) * D
44: row_w = pw.to(tl.int64) * D
45:
46: cos_t = tl.load(cos_cache_ptr + row_t + d_offs, mask=mask)
47: cos_h
... (6639 more chars)