"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120). Per expert e: h_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]) """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl def _num_sms() -> int: return torch.cuda.get_device_properties(0).multi_processor_count @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 256}, num_stages=2, num_warps=4), ], key=["H", "I", "E"], ) @triton.jit def _grouped_swiglu_kernel( a_ptr, b_gate_ptr, b_up_ptr, c_ptr, offsets_ptr, E, H, I, stride_am, stride_ak, stride_bg, stride_bh, stride_bi, stride_cm, stride_cn, NUM_SMS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): tidx = tl.program_id(0) iterated_tiles = 0 for g in tl.range(E): m_start = tl.load(offsets_ptr + g) m_end = tl.load(offsets_ptr + g + 1) m_size = m_end - m_start num_m_tiles = tl.cdiv(m_size, BLOCK_M) num_n_tiles = tl.cdiv(I, BLOCK_N) num_tiles = num_m_tiles * num_n_tiles if m_size > 0: while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: gidx = tidx - iterated_tiles tile_m_idx = gidx % num_m_tiles tile_n_idx = gidx // num_m_tiles acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) for k_block in range(0, H, BLOCK_K): offs_k = k_block + tl.arange(0, BLOCK_K) a_ptrs = ( a_ptr + (m_start + offs_am[:, None]) * stride_am + offs_k[None, :] * stride_ak ) bg_ptrs = ( b_gate_ptr + g * stride_bg + offs_k[:, None] * stride_bh + offs_bn[None, :] * stride_bi ) bu_ptrs = ( b_up_ptr + g * stride_bg + offs_k[:, None] * stride_bh + offs_bn[None, :] * stride_bi ) a_mask = (offs_am[:, None] < m_size) & (offs_k[None, :] < H) b_mask = (offs_k[:, None] < H) & (offs_bn[None, :] < I) a = tl.load(a_ptrs, mask=a_mask, other=0.0) bg = tl.load(bg_ptrs, mask=b_mask, other=0.0) bu = tl.load(bu_ptrs, mask=b_mask, other=0.0) acc_gate = tl.dot(a, bg, acc_gate) acc_up = tl.dot(a, bu, acc_up) gate = acc_gate silu_gate = gate * tl.sigmoid(gate) c = (silu_gate * acc_up).to(tl.bfloat16) c_ptrs = ( c_ptr + (m_start + offs_am[:, None]) * stride_cm + offs_bn[None, :] * stride_cn ) c_mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < I) tl.store(c_ptrs, c, mask=c_mask) tidx += NUM_SMS iterated_tiles += num_tiles def grouped_swiglu( hidden_states: torch.Tensor, W_gate: torch.Tensor, W_up: torch.Tensor, expert_offsets: torch.Tensor, ) -> torch.Tensor: T_perm, H = hidden_states.shape E, H_w, I = W_gate.shape assert H == H_w and W_up.shape == W_gate.shape out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device) num_sms = _num_sms() _grouped_swiglu_kernel[(num_sms,)]( hidden_states, W_gate, W_up, out, expert_offsets, E, H, I, hidden_states.stride(0), hidden_states.stride(1), W_gate.stride(0), W_gate.stride(1), W_gate.stride(2), out.stride(0), out.stride(1), NUM_SMS=num_sms, ) return out class Model(nn.Module): def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741 super().__init__() self.T_total = T_total self.H = H self.I = I self.E = E self.K = K self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) nn.init.normal_(self.W_gate, std=0.02) nn.init.normal_(self.W_up, std=0.02) def forward( self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor, ) -> torch.Tensor: return grouped_swiglu( hidden_states.contiguous(), self.W_gate.contiguous(), self.W_up.contiguous(), expert_offsets.contiguous(), ) T_total = 32768 H = 4096 I = 1536 # noqa: E741 E = 128 K = 8 def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor: T_perm = T_total * K base = T_perm // E rem = T_perm - base * E counts = torch.full((E,), base, dtype=torch.int32, device=device) counts[:rem] += 1 offsets = torch.zeros(E + 1, dtype=torch.int32, device=device) offsets[1:] = torch.cumsum(counts, dim=0) return offsets def get_inputs(): T_perm = T_total * K hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1 expert_offsets = _build_routing(T_total, E, K) return [hidden_states, expert_offsets] def get_init_inputs(): return [T_total, H, I, E, K]