import math

import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "attention"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]

BATCH = 8
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
SEQ_LEN = 1024
PAGE_SIZE = 16


@triton.jit
def _partial_kernel(
    query,
    kv_cache,
    block_table,
    seq_lens,
    partial_m,
    partial_l,
    partial_acc,
    B: tl.constexpr,
    H: tl.constexpr,
    Hkv: tl.constexpr,
    D: tl.constexpr,
    P: tl.constexpr,
    MAX_BLOCKS: tl.constexpr,
    CHUNK: tl.constexpr,
    G: tl.constexpr,
    BG: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Attend one CHUNK of cached tokens for one (batch, kv-head) pair.

    Program ids: axis 0 is the batch index, axis 1 the kv head, axis 2 the
    chunk index. The G query heads that share kv head ``kvh`` (heads
    ``kvh * G .. kvh * G + G - 1``) are processed together as the rows of a
    BG-row tile (BG >= G, padding rows are masked). For every query head the
    kernel stores the chunk's unnormalised softmax statistics in fp32:
    row max ``m``, denominator ``l`` and value accumulator ``acc``; these are
    merged across chunks by ``_reduce_kernel``.
    """
    b = tl.program_id(0)
    kvh = tl.program_id(1)
    part = tl.program_id(2)

    offs_g = tl.arange(0, BG)
    offs_d = tl.arange(0, D)
    offs_m = part * CHUNK + tl.arange(0, CHUNK)
    mask_g = offs_g < G

    seq_len = tl.load(seq_lens + b)
    valid_m = offs_m < seq_len

    # query is addressed as a contiguous (B, H, D) tensor.
    q = tl.load(
        query + (b * H + kvh * G + offs_g[:, None]) * D + offs_d[None, :],
        mask=mask_g[:, None],
        other=0.0,
    )

    # Map logical token positions to physical cache slots through the block
    # table (addressed as a contiguous (B, MAX_BLOCKS) table of page ids).
    page_idx = offs_m // P
    page_off = offs_m - page_idx * P
    # Widen to int64 before building element offsets: with int32 arithmetic
    # the offset overflows once the cache holds >= 2**31 elements.
    phys = tl.load(
        block_table + b * MAX_BLOCKS + page_idx, mask=valid_m, other=0
    ).to(tl.int64)
    # kv_cache is addressed as contiguous (pages, P, Hkv, 2 * D): K occupies
    # the first D channels of the last dim, V the following D channels.
    token_base = ((phys * P + page_off) * Hkv + kvh) * (2 * D)

    k = tl.load(
        kv_cache + token_base[:, None] + offs_d[None, :],
        mask=valid_m[:, None],
        other=0.0,
    )
    scores = tl.dot(q, tl.trans(k), out_dtype=tl.float32) * SCALE
    scores = tl.where(mask_g[:, None] & valid_m[None, :], scores, -float("inf"))

    m = tl.max(scores, axis=1)
    # A row is all -inf when the whole chunk lies past this sequence's length
    # (or the row is group padding). Subtracting a finite stand-in keeps
    # exp(-inf - (-inf)) = NaN out of p / l / acc; such rows yield zeros.
    m_safe = tl.where(m == -float("inf"), 0.0, m)
    p = tl.exp(scores - m_safe[:, None])
    l = tl.sum(p, axis=1)

    v = tl.load(
        kv_cache + token_base[:, None] + D + offs_d[None, :],
        mask=valid_m[:, None],
        other=0.0,
    )
    acc = tl.dot(p.to(tl.bfloat16), v, out_dtype=tl.float32)

    # Partial buffers are laid out as (NUM_PARTS, B, Hkv, G[, D]).
    # The true max (possibly -inf) is stored so the reducer can tell that
    # this chunk contributed nothing.
    ml_base = ((part * B + b) * Hkv + kvh) * G
    tl.store(partial_m + ml_base + offs_g, m, mask=mask_g)
    tl.store(partial_l + ml_base + offs_g, l, mask=mask_g)
    acc_base = (ml_base + offs_g[:, None]) * D + offs_d[None, :]
    tl.store(partial_acc + acc_base, acc, mask=mask_g[:, None])


@triton.jit
def _reduce_kernel(
    partial_m,
    partial_l,
    partial_acc,
    out,
    B: tl.constexpr,
    H: tl.constexpr,
    Hkv: tl.constexpr,
    D: tl.constexpr,
    NUM_PARTS: tl.constexpr,
    G: tl.constexpr,
    BG: tl.constexpr,
):
    """Merge the per-chunk partials of one (batch, kv-head) pair into ``out``.

    Program ids: axis 0 is the batch index, axis 1 the kv head. The
    NUM_PARTS (m, l, acc) triples written by ``_partial_kernel`` are combined
    with max-rescaling (online softmax) and ``acc / l`` is stored for the G
    query heads of this kv head into ``out`` (addressed as contiguous
    (B, H, D)). Heads whose sequence had no valid token receive zeros.
    """
    b = tl.program_id(0)
    kvh = tl.program_id(1)

    offs_g = tl.arange(0, BG)
    offs_d = tl.arange(0, D)
    mask_g = offs_g < G

    m = tl.full((BG, D), -float("inf"), tl.float32)
    l = tl.zeros((BG, D), tl.float32)
    acc = tl.zeros((BG, D), tl.float32)

    for part in tl.static_range(0, NUM_PARTS):
        ml_base = ((part * B + b) * Hkv + kvh) * G
        pm = tl.load(partial_m + ml_base + offs_g, mask=mask_g, other=-float("inf"))
        pl = tl.load(partial_l + ml_base + offs_g, mask=mask_g, other=0.0)
        pa = tl.load(
            partial_acc + (ml_base + offs_g[:, None]) * D + offs_d[None, :],
            mask=mask_g[:, None],
            other=0.0,
        )

        new_m = tl.maximum(m, pm[:, None])
        # While every part seen so far is empty, new_m is -inf; rescale
        # against 0 instead so exp(-inf - (-inf)) cannot produce NaN.
        safe_m = tl.where(new_m == -float("inf"), 0.0, new_m)
        old_scale = tl.exp(m - safe_m)
        part_scale = tl.exp(pm[:, None] - safe_m)
        acc = acc * old_scale + pa * part_scale
        l = l * old_scale + pl[:, None] * part_scale
        m = new_m

    # l == 0 only when no token of the sequence was valid (e.g. seq_len 0).
    result = tl.where(l > 0.0, acc / l, 0.0)
    tl.store(
        out + (b * H + kvh * G + offs_g[:, None]) * D + offs_d[None, :],
        result,
        mask=mask_g[:, None],
    )


# Hand-tuned (chunk_size, partial_warps, reduce_warps) per problem shape,
# keyed by (batch, num_heads, num_kv_heads, head_dim, seq_len).
_TUNED_CONFIGS = {
    (8, 32, 8, 128, 1024): (256, 8, 1),
    (32, 32, 8, 128, 2048): (256, 8, 2),
    (4, 64, 8, 128, 4096): (256, 4, 4),
    (16, 32, 8, 128, 1535): (256, 4, 8),
    (8, 16, 4, 64, 2000): (128, 4, 4),
}
_DEFAULT_CONFIG = (128, 4, 4)


def _select_config(batch: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_len: int):
    """Return ``(chunk_size, partial_warps, reduce_warps)`` for a problem shape.

    Shapes without a tuned entry fall back to ``(128, 4, 4)``.
    """
    key = (batch, num_heads, num_kv_heads, head_dim, seq_len)
    return _TUNED_CONFIGS.get(key, _DEFAULT_CONFIG)


class Model(nn.Module):
    """Grouped-query attention over a paged KV cache, one query per head.

    ``forward`` runs two Triton kernels: ``_partial_kernel`` attends each
    ``chunk_size`` slice of the sequence independently, then
    ``_reduce_kernel`` merges the per-chunk softmax statistics into the
    bf16 output of shape (batch, num_heads, head_dim).
    """

    def __init__(
        self,
        batch: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len: int,
        page_size: int,
    ):
        super().__init__()
        self.batch = batch
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.page_size = page_size
        # Query heads per kv head, padded to a power of two for tl.arange.
        self.group_size = num_heads // num_kv_heads
        self.block_g = triton.next_power_of_2(self.group_size)
        self.chunk_size, self.partial_warps, self.reduce_warps = _select_config(
            batch, num_heads, num_kv_heads, head_dim, seq_len
        )
        # The kernel grid covers exactly num_parts * chunk_size token slots.
        self.num_parts = (seq_len + self.chunk_size - 1) // self.chunk_size
        # NOTE(review): not read by forward; presumably lets the harness move
        # or inspect the module like one with parameters — confirm.
        self.register_buffer("_dummy", torch.zeros(1, dtype=torch.bfloat16), persistent=False)
        # Lazily allocated on the query's device by _ensure_buffers.
        self._out = None
        self._partial_m = None
        self._partial_l = None
        self._partial_acc = None

    def _ensure_buffers(self, query: torch.Tensor):
        """Allocate the output and partial buffers on ``query.device`` once."""
        if self._out is not None and self._out.device == query.device:
            return
        device = query.device
        self._out = torch.empty(
            (self.batch, self.num_heads, self.head_dim), device=device, dtype=torch.bfloat16
        )
        partial_shape = (self.num_parts, self.batch, self.num_kv_heads, self.group_size)
        self._partial_m = torch.empty(partial_shape, device=device, dtype=torch.float32)
        self._partial_l = torch.empty(partial_shape, device=device, dtype=torch.float32)
        self._partial_acc = torch.empty(
            partial_shape + (self.head_dim,), device=device, dtype=torch.float32
        )

    def forward(self, query, kv_cache, block_table, seq_lens):
        """Compute attention output for ``query`` against the paged cache.

        The kernels address ``query`` as contiguous (batch, num_heads,
        head_dim), ``kv_cache`` as contiguous (pages, page_size, num_kv_heads,
        2 * head_dim) with K then V in the last dim, ``block_table`` as
        contiguous (batch, max_blocks) page ids and ``seq_lens`` as one length
        per sequence.

        NOTE(review): tokens beyond ``self.seq_len`` are never visited (the
        grid has ``num_parts`` chunks), so ``seq_lens`` is assumed to be
        <= the construction-time ``seq_len`` — confirm with callers.

        Returns the module-owned output buffer; it is overwritten by the next
        call, so callers that keep results across calls must clone them.
        """
        self._ensure_buffers(query)
        grid = (self.batch, self.num_kv_heads, self.num_parts)
        _partial_kernel[grid](
            query,
            kv_cache,
            block_table,
            seq_lens,
            self._partial_m,
            self._partial_l,
            self._partial_acc,
            self.batch,
            self.num_heads,
            self.num_kv_heads,
            self.head_dim,
            self.page_size,
            block_table.shape[1],
            self.chunk_size,
            self.group_size,
            self.block_g,
            1.0 / math.sqrt(self.head_dim),
            num_warps=self.partial_warps,
        )
        _reduce_kernel[(self.batch, self.num_kv_heads)](
            self._partial_m,
            self._partial_l,
            self._partial_acc,
            self._out,
            self.batch,
            self.num_heads,
            self.num_kv_heads,
            self.head_dim,
            self.num_parts,
            self.group_size,
            self.block_g,
            num_warps=self.reduce_warps,
        )
        return self._out


def get_inputs():
    """Build random bf16 inputs for the default problem size (CPU tensors).

    Pages are assigned to sequences through a random permutation, so each
    sequence's pages are scattered through the cache.
    """
    B = BATCH
    H = NUM_HEADS
    Hkv = NUM_KV_HEADS
    D = HEAD_DIM
    L = SEQ_LEN
    P = PAGE_SIZE
    pages_per_seq = (L + P - 1) // P
    total_pages = max(B * pages_per_seq + 8, 64)
    query = torch.randn(B, H, D, dtype=torch.bfloat16) * 0.1
    kv_cache = torch.randn(total_pages, P, Hkv, 2 * D, dtype=torch.bfloat16) * 0.1
    perm = torch.randperm(total_pages)[: B * pages_per_seq].reshape(B, pages_per_seq).int()
    block_table = perm.contiguous()
    seq_lens = torch.full((B,), L, dtype=torch.int32)
    return [query, kv_cache, block_table, seq_lens]


def get_init_inputs():
    """Return the ``Model`` constructor arguments for the default problem."""
    return [BATCH, NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ_LEN, PAGE_SIZE]