"""Grouped GEMM + fused SwiGLU for MoE up-projection (Triton, SM120). Design: - Gate/up weights are packed once into a single (E, H, 2I) tensor whose columns interleave gate/up pairs (g0,u0,g1,u1,...). Each tile then needs a single B stream and a single tl.dot; the SwiGLU pairing is register-local in the mma accumulator layout (adjacent column pairs live in the same thread), so the epilogue split costs no shuffles. - Each program finds its expert by an in-register scan of expert_offsets (no host sync, no extra kernel). Grid is sized for the worst case; the few surplus programs exit immediately. - A-row indices are clamped instead of masked: out-of-slice rows load arbitrary in-bounds data and are discarded by the masked store. The k-loop therefore has no load masks at all. """ from __future__ import annotations import torch import torch.nn as nn import triton import triton.language as tl @triton.jit def _grouped_swiglu_kernel( x_ptr, # (T_perm, H) bf16 w_ptr, # (E, H, 2I) bf16, gate/up column-interleaved out_ptr, # (T_perm, I) bf16 offs_ptr, # (E+1,) int32 T_perm, H: tl.constexpr, I: tl.constexpr, E: tl.constexpr, E_POW2: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # packed width (2x output cols per tile) BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, ): pid = tl.program_id(0) num_pid_m = tl.cdiv(T_perm, BLOCK_M) + E num_pid_n = tl.cdiv(2 * I, BLOCK_N) num_pid_in_group = GROUP_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_M group_size_m = min(num_pid_m - first_pid_m, GROUP_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # map flat m-tile -> (expert, tile within expert), all in registers eidx = tl.arange(0, E_POW2) offs_vec = tl.load(offs_ptr + eidx, mask=eidx <= E, other=2147483647) next_vec = tl.load(offs_ptr + eidx + 1, mask=eidx < E, other=2147483647) counts = tl.where(eidx < E, next_vec - offs_vec, 0) tiles = tl.cdiv(counts, BLOCK_M) incl = tl.cumsum(tiles, axis=0) total_m_tiles = tl.sum(tiles, axis=0) if pid_m >= total_m_tiles: return e = tl.sum((incl <= pid_m).to(tl.int32), axis=0) tile_start_e = tl.sum(tl.where(eidx == e, incl - tiles, 0), axis=0) row_start = tl.load(offs_ptr + e) row_end = tl.load(offs_ptr + e + 1) row0 = row_start + (pid_m - tile_start_e) * BLOCK_M rm = row0 + tl.arange(0, BLOCK_M) rm_ld = tl.minimum(rm, T_perm - 1) # clamp; garbage rows masked at store rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) rk = tl.arange(0, BLOCK_K) x_ptrs = x_ptr + rm_ld[:, None] * H + rk[None, :] w_off = e.to(tl.int64) * H * (2 * I) w_ptrs = w_ptr + w_off + rk[:, None] * (2 * I) + rn[None, :] acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for _k in range(0, tl.cdiv(H, BLOCK_K)): a = tl.load(x_ptrs) w = tl.load(w_ptrs) acc = tl.dot(a, w, acc) x_ptrs += BLOCK_K w_ptrs += BLOCK_K * (2 * I) g, u = tl.split(tl.reshape(acc, (BLOCK_M, BLOCK_N // 2, 2))) out = g * tl.sigmoid(g) * u on = pid_n * (BLOCK_N // 2) + tl.arange(0, BLOCK_N // 2) out_ptrs = out_ptr + rm[:, None] * I + on[None, :] tl.store(out_ptrs, out.to(tl.bfloat16), mask=(rm < row_end)[:, None]) def _pick_config(T_perm: int, H: int, I: int, E: int): # noqa: E741 """(BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_warps, num_stages)""" if T_perm * I >= 64 * 1024 * 1024: return (128, 256, 64, 8, 8, 3) return (128, 128, 64, 8, 4, 3) _launch_cache: dict = {} def grouped_swiglu( hidden_states: torch.Tensor, w_packed: torch.Tensor, # (E, H, 2I) interleaved expert_offsets: torch.Tensor, I: int, # noqa: E741 ) -> torch.Tensor: T_perm, H = hidden_states.shape E = w_packed.shape[0] out = torch.empty(T_perm, I, dtype=torch.bfloat16, device=hidden_states.device) if T_perm == 0: return out cfg = _pick_config(T_perm, H, I, E) BM, BN, BK, GM, warps, stages = cfg E_POW2 = triton.next_power_of_2(E + 1) grid0 = (triton.cdiv(T_perm, BM) + E) * triton.cdiv(2 * I, BN) args = (hidden_states, w_packed, out, expert_offsets, T_perm, H, I, E, E_POW2, BM, BN, BK, GM) # Fast path: re-launch the cached compiled kernel directly, skipping the # Triton JIT dispatch layer (~6us/call). Specialization safety: the key # pins every value the binder specializes on; fresh torch allocations are # always >=16B aligned, so pointer-alignment specialization is stable. key = (T_perm, H, I, E, cfg, hidden_states.device.index) compiled = _launch_cache.get(key) if compiled is not None and ( hidden_states.data_ptr() | w_packed.data_ptr() | out.data_ptr() | expert_offsets.data_ptr() ) % 16 == 0: stream = torch.cuda.current_stream(hidden_states.device).cuda_stream compiled.run(grid0, 1, 1, stream, compiled.function, compiled.packed_metadata, None, None, None, *args) return out compiled = _grouped_swiglu_kernel[(grid0,)]( *args, num_warps=warps, num_stages=stages, ) _launch_cache[key] = compiled return out class Model(nn.Module): def __init__(self, T_total: int, H: int, I: int, E: int, K: int): # noqa: E741 super().__init__() self.T_total = T_total self.H = H self.I = I self.E = E self.K = K self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16)) nn.init.normal_(self.W_gate, std=0.02) nn.init.normal_(self.W_up, std=0.02) self._w_packed: torch.Tensor | None = None self.register_load_state_dict_pre_hook(self._invalidate_cache) def _invalidate_cache(self, *args, **kwargs): self._w_packed = None def _packed(self) -> torch.Tensor: wp = self._w_packed if ( wp is None or wp.device != self.W_gate.device or wp.shape[1] != self.H ): E, H, I = self.W_gate.shape # noqa: E741 wp = torch.empty(E, H, 2 * I, dtype=torch.bfloat16, device=self.W_gate.device) wp[:, :, 0::2] = self.W_gate.detach() wp[:, :, 1::2] = self.W_up.detach() self._w_packed = wp return wp def forward( self, hidden_states: torch.Tensor, expert_offsets: torch.Tensor, ) -> torch.Tensor: return grouped_swiglu(hidden_states, self._packed(), expert_offsets, self.I) T_total = 32768 H = 4096 I = 1536 # noqa: E741 E = 128 K = 8 def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor: T_perm = T_total * K base = T_perm // E rem = T_perm - base * E counts = torch.full((E,), base, dtype=torch.int32, device=device) counts[:rem] += 1 offsets = torch.zeros(E + 1, dtype=torch.int32, device=device) offsets[1:] = torch.cumsum(counts, dim=0) return offsets def get_inputs(): T_perm = T_total * K hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1 expert_offsets = _build_routing(T_total, E, K) return [hidden_states, expert_offsets] def get_init_inputs(): return [T_total, H, I, E, K]