"""Grouped GEMM + fused SwiGLU up-projection for top-K MoE FFN. Each expert e computes: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) Uses a Triton kernel with dual MMA accumulators and periodic accumulator reset to work around an SM120 code-generation issue. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl OP_TYPE = "grouped_gemm_swiglu" SUPPORTED_PRECISIONS = ["bf16"] HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"] T_total = 32768 H = 4096 I = 1536 # noqa: E741 E = 128 K = 8 @triton.jit def _kernel( hidden_states_ptr, W_gate_ptr, W_up_ptr, out_ptr, tile_to_expert_ptr, tile_row_start_ptr, tile_num_rows_ptr, H_dim, I_dim, stride_hs_m, stride_hs_k, stride_wg_e, stride_wg_k, stride_wg_n, stride_wu_e, stride_wu_k, stride_wu_n, stride_out_m, stride_out_n, num_n_tiles, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, RESET_EVERY: tl.constexpr, ): pid = tl.program_id(0) m_tile = pid // num_n_tiles n_tile = pid % num_n_tiles expert = tl.load(tile_to_expert_ptr + m_tile).to(tl.int32) row_start = tl.load(tile_row_start_ptr + m_tile).to(tl.int32) num_rows = tl.load(tile_num_rows_ptr + m_tile).to(tl.int32) rm = tl.arange(0, BLOCK_M) rn = n_tile * BLOCK_N + tl.arange(0, BLOCK_N) wg_off_e = expert * stride_wg_e wu_off_e = expert * stride_wu_e final_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) final_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k_outer in range(0, H_dim, BLOCK_K * RESET_EVERY): mma_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) mma_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) k_end = tl.minimum(k_outer + BLOCK_K * RESET_EVERY, H_dim) for k in range(k_outer, k_end, BLOCK_K): k_rem = H_dim - k rk = tl.arange(0, BLOCK_K) hs_off = ((row_start + rm[:, None]) * stride_hs_m + (k + rk[None, :]) * stride_hs_k) hs_mask = (rm[:, None] < num_rows) & (rk[None, :] < k_rem) x = tl.load(hidden_states_ptr + hs_off, mask=hs_mask, other=0.0) wg_off = (wg_off_e + (k + rk[:, None]) * stride_wg_k + rn[None, :] * stride_wg_n) wg_mask = (rk[:, None] < k_rem) & (rn[None, :] < I_dim) w_gate = tl.load(W_gate_ptr + wg_off, mask=wg_mask, other=0.0) wu_off = (wu_off_e + (k + rk[:, None]) * stride_wu_k + rn[None, :] * stride_wu_n) w_up = tl.load(W_up_ptr + wu_off, mask=wg_mask, other=0.0) mma_gate += tl.dot(x, w_gate) mma_up += tl.dot(x, w_up) final_gate += mma_gate final_up += mma_up silu_gate = final_gate * tl.sigmoid(final_gate) result = (silu_gate * final_up).to(tl.bfloat16) out_off = ((row_start + rm[:, None]) * stride_out_m + rn[None, :] * stride_out_n) out_mask = (rm[:, None] < num_rows) & (rn[None, :] < I_dim) tl.store(out_ptr + out_off, result, mask=out_mask) def _build_tile_map(expert_offsets, E, BLOCK_M, device): counts = expert_offsets[1:] - expert_offsets[:-1] num_tiles_per_expert = (counts + BLOCK_M - 1) // BLOCK_M total = int(num_tiles_per_expert.sum().item()) if total == 0: z = torch.zeros(1, dtype=torch.int32, device=device) return z, z, z, 0 tile_to_expert = torch.repeat_interleave( torch.arange(E, device=device), num_tiles_per_expert, ).to(torch.int32) parts = [torch.arange(int(nt.item()), device=device, dtype=torch.int32) for nt in num_tiles_per_expert if nt > 0] local_tile_idx = torch.cat(parts) first = expert_offsets[:-1].to(torch.int32) tile_row_start = (first[tile_to_expert.long()] + local_tile_idx * BLOCK_M).to(torch.int32) n_e_expanded = counts[tile_to_expert.long()] tile_num_rows = torch.clamp( n_e_expanded - local_tile_idx * BLOCK_M, 0, BLOCK_M, ).to(torch.int32) return tile_to_expert, tile_row_start, tile_num_rows, total class Model(nn.Module): def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741 super().__init__() self.T_total = T_total self.H = H self.I = I self.E = E self.K = K self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) nn.init.normal_(self.W_gate, std=0.02) nn.init.normal_(self.W_up, std=0.02) def forward(self, hidden_states, expert_offsets): T_perm_dim, H_dim = hidden_states.shape I_dim = self.I device = hidden_states.device out = torch.empty(T_perm_dim, I_dim, dtype=torch.bfloat16, device=device) # Per-shape optimal config (determined empirically). if I_dim >= 4096: BM, BN, BK, NW, RE = 64, 256, 32, 8, 8 elif H_dim >= 4000 and I_dim >= 1500: BM, BN, BK, NW, RE = 128, 128, 32, 8, 8 elif H_dim >= 2000 and I_dim >= 1000: BM, BN, BK, NW, RE = 64, 128, 32, 8, 16 else: BM, BN, BK, NW, RE = 64, 64, 64, 4, 16 t2e, trs, tnr, num_m_tiles = _build_tile_map( expert_offsets, self.E, BM, device, ) if num_m_tiles == 0: return out n_ntiles = (I_dim + BN - 1) // BN grid = (num_m_tiles * n_ntiles,) _kernel[grid]( hidden_states, self.W_gate, self.W_up, out, t2e, trs, tnr, H_dim, I_dim, H_dim, 1, H_dim * I_dim, I_dim, 1, H_dim * I_dim, I_dim, 1, I_dim, 1, n_ntiles, BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=BK, RESET_EVERY=RE, num_warps=NW, ) return out def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor: T_perm = T_total * K base = T_perm // E rem = T_perm - base * E counts = torch.full((E,), base, dtype=torch.int32, device=device) counts[:rem] += 1 offsets = torch.zeros(E + 1, dtype=torch.int32, device=device) offsets[1:] = torch.cumsum(counts, dim=0) return offsets def get_inputs(): T_perm = T_total * K hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1 expert_offsets = _build_routing(T_total, E, K) return [hidden_states, expert_offsets] def get_init_inputs(): return [T_total, H, I, E, K]