import torch
import torch.nn as nn
import triton
import triton.language as tl

# Fast initialization monkeypatch to avoid a CPU bottleneck: large CPU tensors
# are filled by sampling on the GPU and copying the result back.
# NOTE(review): this replaces torch.nn.init.normal_ process-wide (for every
# caller, not just Model) as a side effect of importing this module.
_old_normal_ = torch.nn.init.normal_


def _fast_normal_(tensor, mean=0.0, std=1.0, generator=None):
    """Drop-in replacement for ``torch.nn.init.normal_``.

    Fills ``tensor`` in place with samples from N(mean, std**2) and returns it.
    CPU tensors with more than 1e6 elements are sampled on the current CUDA
    device and copied back. Every other case defers to the original
    ``normal_``: small tensors, non-CPU tensors, an explicit ``generator``
    (whose stream cannot be reproduced on the GPU), or no CUDA device.
    """
    if (
        generator is None
        and tensor.device.type == "cpu"
        and tensor.numel() > 1_000_000
        and torch.cuda.is_available()
    ):
        with torch.no_grad():
            tmp = torch.randn(tensor.shape, dtype=tensor.dtype, device="cuda") * std + mean
            tensor.copy_(tmp)
        return tensor
    if generator is not None:
        # Forwarded only when supplied, so torch versions whose normal_ has no
        # `generator` keyword keep working for the common call.
        return _old_normal_(tensor, mean, std, generator=generator)
    return _old_normal_(tensor, mean, std)


torch.nn.init.normal_ = _fast_normal_


@triton.jit
def moe_swiglu_kernel(
    # Pointers to matrices
    X_ptr, W_gate_ptr, W_up_ptr, Y_ptr, expert_offsets_ptr,
    # Mapping tables (one entry per program)
    expert_ids_ptr, tile_m_ids_ptr, tile_n_ids_ptr,
    # Matrix dimensions (T_perm is currently unused inside the kernel)
    H, I, T_perm,
    # Strides
    stride_x_m, stride_x_h,
    stride_w_gate_e, stride_w_gate_h, stride_w_gate_i,
    stride_w_up_e, stride_w_up_h, stride_w_up_i,
    stride_y_m, stride_y_i,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the grouped SwiGLU.

    Program ``pid`` handles expert ``e = expert_ids[pid]``. Its rows are
    ``offsets[e] + tile_m_ids[pid]*BLOCK_SIZE_M + [0, BLOCK_SIZE_M)``, clipped
    to ``offsets[e+1]``. Its columns are
    ``tile_n_ids[pid]*BLOCK_SIZE_N + [0, BLOCK_SIZE_N)``, clipped to ``I``.
    The tile written is ``silu(X @ W_gate[e]) * (X @ W_up[e])``, stored as bf16.
    """
    pid = tl.program_id(0)

    # Row/expert indices are widened to int64 so that row * stride_x_m and
    # expert * stride_w_*_e cannot wrap around int32 for large problems.
    expert_id = tl.load(expert_ids_ptr + pid).to(tl.int64)
    tile_m_id = tl.load(tile_m_ids_ptr + pid)
    tile_n_id = tl.load(tile_n_ids_ptr + pid)

    # Rows [start_idx, end_idx) of X / Y belong to this expert.
    start_idx = tl.load(expert_offsets_ptr + expert_id).to(tl.int64)
    end_idx = tl.load(expert_offsets_ptr + expert_id + 1).to(tl.int64)

    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    row_idx = start_idx + tile_m_id * BLOCK_SIZE_M + offs_m
    row_mask = row_idx < end_idx
    col_idx = tile_n_id * BLOCK_SIZE_N + offs_n
    col_mask = col_idx < I

    # Pointers for the first K-slice; advanced by BLOCK_SIZE_K each iteration.
    x_ptrs = X_ptr + row_idx[:, None] * stride_x_m + offs_k[None, :] * stride_x_h
    w_gate_ptrs = (
        W_gate_ptr
        + expert_id * stride_w_gate_e
        + offs_k[:, None] * stride_w_gate_h
        + col_idx[None, :] * stride_w_gate_i
    )
    w_up_ptrs = (
        W_up_ptr
        + expert_id * stride_w_up_e
        + offs_k[:, None] * stride_w_up_h
        + col_idx[None, :] * stride_w_up_i
    )

    acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, H, BLOCK_SIZE_K):
        k_mask = (k + offs_k) < H
        x_block = tl.load(x_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)
        w_mask = k_mask[:, None] & col_mask[None, :]
        w_gate_block = tl.load(w_gate_ptrs, mask=w_mask, other=0.0)
        w_up_block = tl.load(w_up_ptrs, mask=w_mask, other=0.0)

        # Both projections share the same X tile.
        acc_gate += tl.dot(x_block, w_gate_block)
        acc_up += tl.dot(x_block, w_up_block)

        x_ptrs += BLOCK_SIZE_K * stride_x_h
        w_gate_ptrs += BLOCK_SIZE_K * stride_w_gate_h
        w_up_ptrs += BLOCK_SIZE_K * stride_w_up_h

    # SwiGLU epilogue in fp32 (accumulators are already fp32): silu(gate) * up.
    fused_swiglu = acc_gate * tl.sigmoid(acc_gate) * acc_up

    y_ptrs = Y_ptr + row_idx[:, None] * stride_y_m + col_idx[None, :] * stride_y_i
    y_mask = row_mask[:, None] & col_mask[None, :]
    tl.store(y_ptrs, fused_swiglu.to(tl.bfloat16), mask=y_mask)


class Model(nn.Module):
    """Up-projection of a top-K MoE FFN with fused SwiGLU."""

    # Launch configuration; overridable on the class or an instance.
    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 32
    NUM_WARPS = 4
    NUM_STAGES = 3

    def __init__(self, T_total: int, H: int, I: int, E: int, K: int):
        super().__init__()
        self.T_total = T_total
        self.H = H
        self.I = I
        self.E = E
        self.K = K
        # Two weight tensors per expert: gate (E, H, I) and up (E, H, I).
        self.W_gate = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        self.W_up = nn.Parameter(torch.empty(E, H, I, dtype=torch.bfloat16))
        nn.init.normal_(self.W_gate, std=0.02)
        nn.init.normal_(self.W_up, std=0.02)

    @staticmethod
    def _build_tile_schedule(
        expert_offsets: torch.Tensor,
        num_experts: int,
        I: int,
        block_m: int,
        block_n: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Flatten every (expert, m-tile, n-tile) triple into per-program tables.

        Returns ``(expert_ids, tile_m_ids, tile_n_ids)``: 1-D int64 tensors on
        ``expert_offsets.device``, one entry per kernel program. Within an
        expert the n-tile index varies fastest. All three tensors are empty
        when there is no work.
        """
        device = expert_offsets.device
        rows_per_expert = (expert_offsets[1:] - expert_offsets[:-1]).to(torch.int64)
        tiles_m = torch.div(rows_per_expert + block_m - 1, block_m, rounding_mode="floor")
        tiles_n = (I + block_n - 1) // block_n
        tiles_per_expert = tiles_m * tiles_n

        # repeat_interleave sizes its output from the data (one host sync).
        expert_ids = torch.repeat_interleave(
            torch.arange(num_experts, device=device, dtype=torch.int64), tiles_per_expert
        )
        total = expert_ids.numel()
        if total == 0:
            return expert_ids, expert_ids.clone(), expert_ids.clone()

        # Exclusive prefix sum: index of each expert's first tile in the grid.
        first_tile = torch.cumsum(tiles_per_expert, dim=0) - tiles_per_expert
        local_tile = torch.arange(total, device=device, dtype=torch.int64) - first_tile[expert_ids]
        tile_n_ids = local_tile % tiles_n
        tile_m_ids = torch.div(local_tile, tiles_n, rounding_mode="floor")
        return expert_ids, tile_m_ids, tile_n_ids

    def forward(
        self,
        hidden_states: torch.Tensor,  # (T_perm, H) bf16
        expert_offsets: torch.Tensor,  # (E+1,) int32
    ) -> torch.Tensor:
        """Return ``(T_perm, I)`` bf16 with ``silu(x @ W_gate[e]) * (x @ W_up[e])`` per row.

        Rows ``expert_offsets[e]:expert_offsets[e+1]`` of ``hidden_states``
        use expert ``e``. NOTE(review): output rows outside every expert range
        are left uninitialized (``torch.empty``).

        Raises:
            ValueError: if ``hidden_states`` does not have ``self.H`` columns.
        """
        T_perm, H = hidden_states.shape
        if H != self.H:
            # The kernel iterates over self.H; a narrower input would be read
            # out of bounds.
            raise ValueError(f"hidden_states has H={H}, expected {self.H}")
        device = hidden_states.device
        out = torch.empty(T_perm, self.I, dtype=torch.bfloat16, device=device)

        # Tile tables are built on, and read by the kernel from, this device.
        expert_offsets = expert_offsets.to(device)
        expert_ids, tile_m_ids, tile_n_ids = self._build_tile_schedule(
            expert_offsets, self.E, self.I, self.BLOCK_SIZE_M, self.BLOCK_SIZE_N
        )
        total_grid_tiles = expert_ids.numel()
        if total_grid_tiles == 0:
            return out

        grid = (total_grid_tiles,)
        moe_swiglu_kernel[grid](
            X_ptr=hidden_states,
            W_gate_ptr=self.W_gate,
            W_up_ptr=self.W_up,
            Y_ptr=out,
            expert_offsets_ptr=expert_offsets,
            expert_ids_ptr=expert_ids,
            tile_m_ids_ptr=tile_m_ids,
            tile_n_ids_ptr=tile_n_ids,
            H=self.H,
            I=self.I,
            T_perm=T_perm,
            stride_x_m=hidden_states.stride(0),
            stride_x_h=hidden_states.stride(1),
            stride_w_gate_e=self.W_gate.stride(0),
            stride_w_gate_h=self.W_gate.stride(1),
            stride_w_gate_i=self.W_gate.stride(2),
            stride_w_up_e=self.W_up.stride(0),
            stride_w_up_h=self.W_up.stride(1),
            stride_w_up_i=self.W_up.stride(2),
            stride_y_m=out.stride(0),
            stride_y_i=out.stride(1),
            BLOCK_SIZE_M=self.BLOCK_SIZE_M,
            BLOCK_SIZE_N=self.BLOCK_SIZE_N,
            BLOCK_SIZE_K=self.BLOCK_SIZE_K,
            num_warps=self.NUM_WARPS,
            num_stages=self.NUM_STAGES,
        )
        return out


# Shims matching reference.py
T_total = 32768
H = 4096
I = 1536
E = 128
K = 8


def _build_routing(T_total: int, E: int, K: int, device: str = "cpu") -> torch.Tensor:
    """Round-robin-ish routing metadata: balanced offsets summing to T_total*K."""
    T_perm = T_total * K
    # Even split with remainder distributed to first experts.
    base = T_perm // E
    rem = T_perm - base * E
    counts = torch.full((E,), base, dtype=torch.int32, device=device)
    counts[:rem] += 1
    offsets = torch.zeros(E + 1, dtype=torch.int32, device=device)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets


def get_inputs():
    """Return ``[hidden_states (T_perm, H) bf16, expert_offsets (E+1,) int32]`` on CPU."""
    T_perm = T_total * K
    hidden_states = torch.randn(T_perm, H, dtype=torch.bfloat16) * 0.1
    expert_offsets = _build_routing(T_total, E, K)
    return [hidden_states, expert_offsets]


def get_init_inputs():
    """Return the positional arguments for ``Model(...)``."""
    return [T_total, H, I, E, K]