from __future__ import annotations

import torch
import torch.nn as nn
import triton
import triton.language as tl

OP_TYPE = "grouped_gemm_swiglu"
SUPPORTED_PRECISIONS = ["bf16"]
HARDWARE_REQUIRED = ["RTX_PRO_6000", "H100", "B200"]


# Grouped SwiGLU tile kernel: out_e = silu(x_e @ W_gate[e]) * (x_e @ W_up[e]).
#
# Grid axes: 0 -> row tile inside expert e's segment, 1 -> column tile of the
# intermediate dimension I, 2 -> expert index e. Expert e owns rows
# [expert_offsets[e], expert_offsets[e + 1]) of both `x` and `out`.
# Indexing treats `x` as row-major (rows, H), the weights as row-major
# (E, H, I) and `out` as row-major (rows, I).
#
# EVEN_TILES must only be set when H % BLOCK_K == 0 and I % BLOCK_N == 0; it
# lets the K loop skip the reduction/column bounds masks.
@triton.jit
def _swiglu_grouped_kernel(
    x,
    expert_offsets,
    w_gate,
    w_up,
    out,
    H: tl.constexpr,
    I: tl.constexpr,  # noqa: E741
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_TILES: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    expert = tl.program_id(2)

    # Widen to int64 so (row * H) and (row * I) cannot overflow int32 for
    # large permuted token counts.
    start = tl.load(expert_offsets + expert).to(tl.int64)
    end = tl.load(expert_offsets + expert + 1).to(tl.int64)
    count = end - start

    # The grid is sized for the largest expert; tiles past this expert's
    # segment have no rows to produce.
    if pid_m * BLOCK_M >= count:
        return

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    m_mask = offs_m < count
    n_mask = offs_n < I

    rows = start + offs_m  # absolute (int64) row indices into x / out
    x_row_ptrs = x + rows[:, None] * H
    # Scalar per-expert base pointers (int64 offset); per-tile offsets below
    # stay within one expert's (H, I) matrix.
    expert_base = expert.to(tl.int64) * H * I
    wg_base = w_gate + expert_base
    wu_base = w_up + expert_base

    acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, H, BLOCK_K):
        k = k0 + offs_k
        x_ptrs = x_row_ptrs + k[None, :]
        w_offs = k[:, None] * I + offs_n[None, :]
        if EVEN_TILES:
            x_tile = tl.load(x_ptrs, mask=m_mask[:, None], other=0.0)
            wg_tile = tl.load(wg_base + w_offs)
            wu_tile = tl.load(wu_base + w_offs)
        else:
            k_mask = k < H
            x_tile = tl.load(
                x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0
            )
            w_mask = k_mask[:, None] & n_mask[None, :]
            wg_tile = tl.load(wg_base + w_offs, mask=w_mask, other=0.0)
            wu_tile = tl.load(wu_base + w_offs, mask=w_mask, other=0.0)
        acc_gate += tl.dot(x_tile, wg_tile, out_dtype=tl.float32)
        acc_up += tl.dot(x_tile, wu_tile, out_dtype=tl.float32)

    # SiLU(g) = g * sigmoid(g) = g / (1 + exp(-g)), evaluated in fp32.
    y = (acc_gate / (1.0 + tl.exp(-acc_gate))) * acc_up
    tl.store(
        out + rows[:, None] * I + offs_n[None, :],
        y.to(out.dtype.element_ty),
        mask=m_mask[:, None] & n_mask[None, :],
    )


class Model(nn.Module):
    """Grouped SwiGLU projection over expert-grouped (permuted) token rows.

    Expert ``e`` owns rows ``[expert_offsets[e], expert_offsets[e + 1])`` of
    ``hidden_states``; for those rows the module computes
    ``silu(x @ W_gate[e]) * (x @ W_up[e])`` with a Triton kernel and writes a
    bf16 result of shape ``(rows, I)``.
    """

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):  # noqa: E741
        super().__init__()
        self.T_total = T_total
        self.H = H
        self.I = I
        self.E = E
        self.K = K
        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        nn.init.normal_(self.W_gate, std=0.02)
        nn.init.normal_(self.W_up, std=0.02)

    @staticmethod
    def _select_config(
        hidden_size: int, intermediate_size: int
    ) -> tuple[int, int, int, int, int]:
        """Return ``(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages)``.

        Only the (2048, 4096) shape uses a 3-stage pipeline; all other shapes
        use the 4-stage default.
        """
        num_stages = 3 if (hidden_size, intermediate_size) == (2048, 4096) else 4
        return 128, 64, 32, 4, num_stages

    @staticmethod
    def _expert_row_bounds(
        expert_offsets: torch.Tensor, num_experts: int, num_rows: int
    ) -> tuple[int, int, int]:
        """Validate ``expert_offsets`` and return ``(first_row, end_row, max_rows)``.

        Copies the offsets to the host (a device sync for GPU tensors).

        Raises:
            ValueError: if the offsets are not a 1-D tensor of length
                ``num_experts + 1``, are decreasing anywhere, or fall outside
                ``[0, num_rows]``.
        """
        if expert_offsets.dim() != 1 or expert_offsets.numel() != num_experts + 1:
            raise ValueError(
                f"expert_offsets must have shape ({num_experts + 1},), "
                f"got {tuple(expert_offsets.shape)}"
            )
        bounds = expert_offsets.detach().to(device="cpu", dtype=torch.int64)
        counts = bounds[1:] - bounds[:-1]
        if bool((counts < 0).any()):
            raise ValueError("expert_offsets must be non-decreasing")
        first, end = int(bounds[0]), int(bounds[-1])
        if first < 0 or end > num_rows:
            raise ValueError(
                f"expert_offsets must lie within [0, {num_rows}], "
                f"got [{first}, {end}]"
            )
        max_rows = int(counts.max()) if counts.numel() else 0
        return first, end, max_rows

    def forward(
        self,
        hidden_states: torch.Tensor,
        expert_offsets: torch.Tensor,
        *,
        max_rows_per_expert: int | None = None,
    ) -> torch.Tensor:
        """Apply the grouped SwiGLU projection.

        Args:
            hidden_states: ``(rows, H)`` bf16 activations grouped by expert.
            expert_offsets: ``E + 1`` row boundaries, one segment per expert.
            max_rows_per_expert: optional upper bound on any expert's row
                count. When given, offset validation (and its host sync) is
                skipped and the caller guarantees the offsets are valid and
                cover every row; when omitted it is derived from the offsets
                and rows outside the covered range are zero-filled.

        Returns:
            A freshly allocated ``(rows, I)`` bf16 tensor.

        Raises:
            ValueError: on a ``hidden_states`` shape mismatch or invalid
                ``expert_offsets`` (see ``_expert_row_bounds``).
        """
        if hidden_states.dim() != 2 or hidden_states.shape[1] != self.H:
            raise ValueError(
                f"hidden_states must have shape (rows, {self.H}), "
                f"got {tuple(hidden_states.shape)}"
            )
        # The kernel indexes rows with a fixed stride of H.
        x = hidden_states.contiguous()
        t_perm = x.shape[0]
        # Fresh buffer per call so earlier results are never overwritten.
        out = torch.empty((t_perm, self.I), dtype=torch.bfloat16, device=x.device)

        if max_rows_per_expert is None:
            first, end, max_rows_per_expert = self._expert_row_bounds(
                expert_offsets, self.E, t_perm
            )
            # Rows owned by no expert are never written by the kernel.
            out[:first].zero_()
            out[end:].zero_()

        if t_perm == 0 or max_rows_per_expert <= 0 or self.I == 0:
            return out

        block_m, block_n, block_k, num_warps, num_stages = self._select_config(
            self.H, self.I
        )
        even_tiles = self.H % block_k == 0 and self.I % block_n == 0
        grid = (
            triton.cdiv(max_rows_per_expert, block_m),
            triton.cdiv(self.I, block_n),
            self.E,
        )
        _swiglu_grouped_kernel[grid](
            x,
            expert_offsets.to(device=x.device),
            self.W_gate,
            self.W_up,
            out,
            self.H,
            self.I,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            EVEN_TILES=even_tiles,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return out


T_total = 32768
H = 4096
I = 1536  # noqa: E741
E = 128
K = 8


def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
    """Return int32 offsets splitting ``T_total * K`` rows as evenly as possible.

    The first ``(T_total * K) % E`` experts receive one extra row.
    """
    t_perm = T_total * K
    base = t_perm // E
    rem = t_perm - base * E
    counts = torch.full((E,), base, dtype=torch.int32, device=device)
    counts[:rem] += 1
    offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets


def get_inputs():
    """Build CPU bf16 activations and balanced expert offsets for the benchmark."""
    t_perm = T_total * K
    hidden_states = torch.randn(t_perm, H, dtype=torch.bfloat16) * 0.1
    expert_offsets = _build_routing(T_total, E, K)
    return [hidden_states, expert_offsets]


def get_init_inputs():
    """Return the positional arguments for ``Model(...)``."""
    return [T_total, H, I, E, K]