"""Kimi-Linear W4A16 hybrid decode -- single-launch megakernel solution. One cooperative CUDA kernel (torch load_inline, launched exactly once per step()) runs the entire per-token forward: all four layers (KDA x3 + MLA), every int4 dequant-GEMV fused (int4 weights streamed once, never materialized to bf16), the short causal conv + silu, the KDA gated-delta state update, the MLA absorbed-latent attention over the compressed KV cache (two streaming passes with an online softmax), the MoE router + top-8 + 9 expert FFNs, both RMSNorms and all residual adds. Stages inside the kernel are separated by a grid-wide sense-reversing barrier (~1us each, 14 per token; the launch is cooperative so all blocks are guaranteed co-resident). Within a stage, finer dependencies use producer-consumer spin counters instead of barriers: the KDA state update spins on per-head readiness of the q/k/v/g epilogues, the MoE router result gates the routed-expert GEMVs, and the expert down-projections spin on per-expert completion of silu(gate)*up. The int4 dequant is exact vs the reference (which materializes bf16((wu - z) * s) and matmuls in fp32): nibbles are splatted into bf16 lanes as (128+n) with one LOP3 (0x4300 | n), HSUB2 against (128+z) is exact, and HMUL2 by the bf16 scale applies the reference's single round-to-nearest. The GEMV accumulates those weights against the bf16-rounded rmsnorm activations in fp32, so router logits come out bit-identical to the reference's cublas path and MoE top-8 boundary flips are avoided. Split-K across blocks writes fp32 partial slices; the last-arriving block per column tile (atomic done counter) sums the slices and runs the per-column epilogue (conv window shift, RoPE, cache append, silu*up, residual add, ...). MLA never materializes k/v: q_nope is absorbed through the k-half of kv_b (q_abs = W_k^T q per head) so attention runs directly against the latent cache; the output latent is pushed back through the v-half of kv_b. 
The MLA latent cache grows one row per step. To keep step() a single kernel launch, the Model owns a capacity buffer: adopting an externally fed state (the first step of a run) sets an ingest flag and the kernel itself copies the old cache rows before the MLA stage reads them; the returned state holds views into the capacity buffer so subsequent steps append in place from inside the kernel. """ from __future__ import annotations from dataclasses import dataclass, field import torch import torch.nn as nn from torch.utils.cpp_extension import load_inline OP_TYPE = "kimi_linear_w4a16_decode" HARDWARE_REQUIRED = ["RTX_PRO_6000"] EPS = 1.0e-6 GROUP_SIZE = 128 NBLK = 188 NTHR = 512 DSMEM = 76800 CACHE_MARGIN = 512 # --------------------------------------------------------------------------- # # CUDA source: one __global__ megakernel + host launcher # --------------------------------------------------------------------------- # _CUDA_SRC = r""" #include #include #include #include #include #include #define NTHR 512 #define NWARP 16 #define NBLK 188 #define DSMEM 76800 #define HID 2304 #define K2_HID 1152 #define G_HID 18 #define NKD 4096 #define K2_NKD 2048 #define G_NKD 32 #define NQ 6144 #define NKVA 576 #define KVL 512 #define KVBN 8192 #define MI 1024 #define K2_MI 512 #define G_MI 8 #define KDA_SCALE 0.08838834764831845f #define MLA_SCALE 0.07216878364870323f #define RSCALE 2.446f #define ROPE_THETA 10000.0f using bf16 = __nv_bfloat16; static __device__ __forceinline__ float b2f(bf16 v) { return __bfloat162float(v); } static __device__ __forceinline__ bf16 f2b(float v) { return __float2bfloat16(v); } static __device__ __forceinline__ float rt(float v) { return b2f(f2b(v)); } // bf16 round-trip // full-rate bf16 round-to-nearest-even on normals (integer trick; avoids CVT pipe). // Matches reference dequant rounding: bf16((wu - z) * s). 
// Integer-only bf16 rounding of an fp32 value: add 0x7fff plus the LSB of the
// kept upper half (ties-to-even), then clear the low 16 bits. The result is
// returned as fp32. No NaN/Inf special-casing is performed here.
static __device__ __forceinline__ float rtb(float v) {
    unsigned u = __float_as_uint(v);
    u = (u + 0x7fffu + ((u >> 16) & 1u)) & 0xffff0000u;
    return __uint_as_float(u);
}

// Per-layer device pointers. The five-slot arrays are indexed by projection
// (order listed on the first line inside the struct); the six-slot arrays hold
// the MoE expert tensors.
struct LayerP {
    // KDA: q,k,v,g,o | MLA: q_proj, kv_a, kv_b, o_proj, (dup)
    const uint8_t* wq[5];   // packed int4 weights (two nibbles per byte, see gemv_tile)
    const float* ws[5];     // fp32 scales (used by the kv_b absorb stages)
    const float* wb[5];     // fp32 -(z*s)
    const bf16* sc[5];      // bf16 scales (SIMD dequant path)
    const bf16* zp[5];      // bf16 (128 + z)
    const bf16 *conv, *betaw, *anorm, *mnorm, *router;
    // gate, up, down, s_gate, s_up, s_down (expert-major bases)
    const uint8_t* ewq[6];
    const bf16* esc[6];
    const bf16* ezp[6];
};

// Kernel-wide parameter block: per-layer weights plus every scratch, state and
// cache pointer the stages read or write.
struct P {
    LayerP L[4];
    bf16 *qkv;                    // [3][4096]
    float *gexp;                  // [4096]
    float *beta;                  // [32]
    bf16 *oatt;                   // [4096]
    bf16 *q6144;                  // [6144]
    bf16 *qabs;                   // [32][512]
    float *olat;                  // [32][512]
    float *logits;                // [64]
    float *xm;                    // [9][1024]
    float *partial;               // split-K slices
    unsigned *done;               // tile counters (self-resetting)
    float *smax;                  // [NBLK][32][2]
    unsigned *bar;                // grid barrier words: bar[0] = arrival count, bar[32] = sense (see gridbar)
    unsigned long long *stamps;   // [64] per-stage clock64 marks (block 0, for profiling)
    bf16 *hid;                    // [2304] residual stream
    const bf16 *hin;              // NOTE(review): presumably the input hidden state; not referenced in the stages visible here
    bf16 *ckv;                    // [cap][512]
    bf16 *kro;                    // [cap][64]
    const bf16 *ckv_src, *kro_src;   // sources for the one-time cache ingest copy in stage_m1
    float *scores;                // [cap][32]
    float *S[3];                  // gated-delta state, indexed by layer l in stage_a1 ([head][128][128] fp32)
    bf16 *cw[3][3];               // conv windows, indexed [layer][stream]; each is [3][NKD] (see conv_epi)
    int pos;                      // row index appended to ckv/kro this step; also the RoPE position
    int ingest;                   // nonzero: stage_m1 copies rows [0,pos) from ckv_src/kro_src
};

// ------------------------------ grid barrier ------------------------------- //
// Sense-reversing barrier across all NBLK blocks.
//   bar    : global words; bar[0] is the arrival counter, bar[32] the published sense.
//   lsense : caller-owned local sense, flipped on every call by thread 0.
// Only thread 0 of each block touches global memory; the two __syncthreads()
// hold the rest of the block before and after. The last arriver (old count ==
// NBLK-1) resets the counter, fences, then publishes the new sense; everyone
// else spins on the volatile sense word.
static __device__ __forceinline__ void gridbar(unsigned* bar, unsigned* lsense) {
    __syncthreads();
    if (threadIdx.x == 0) {
        unsigned s = *lsense ^ 1u;
        *lsense = s;
        __threadfence();   // make this block's prior global writes visible before arriving
        volatile unsigned* cnt = bar;
        volatile unsigned* sense = bar + 32;
        unsigned a = atomicAdd((unsigned*)cnt, 1u);
        if (a == NBLK - 1u) {
            *cnt = 0u;     // reset for the next barrier before releasing the waiters
            __threadfence();
            *sense = s;
        } else {
            while (*sense != s) { }
        }
        __threadfence();
    }
    __syncthreads();
}

// ------------------------------ smem layout -------------------------------- //
// All regions are byte offsets into one dynamic shared buffer. The "MLA
// attention stages" macros start again at offset 0, i.e. they alias the
// GEMV-stage regions above them.
extern __shared__ __align__(16) unsigned char SMEM[];
#define SM_XN      ((float*)(SMEM))                 // [4096]
#define SM_XSUM    ((float*)(SMEM + 16384))         // [32]
#define SM_RED     ((float*)(SMEM + 16512))         // [16*128]
#define SM_TOPK_I  ((int*)(SMEM + 24704))           // [8]
#define SM_TOPK_C  ((float*)(SMEM + 24704 + 32))    // [9]
#define SM_PROBS   ((float*)(SMEM + 24704 + 96))    // [64]
#define SM_MISC    ((float*)(SMEM + 25088))         // [>=512]
// MLA attention stages:
#define SM_QABS_T  ((bf16*)(SMEM))                  // [576][36] (pad 4 kills store bank conflicts)
#define SM_ROWS    ((bf16*)(SMEM + 41472))          // [64][128]
#define SM_STILE   ((float*)(SMEM + 57856))         // [64][32]
#define SM_RMS     ((float*)(SMEM + 66048))         // [2][32] running m/s

// ------------------------------ x staging ---------------------------------- //
// rmsnorm(src) * w -> xn (values are f32 of the bf16-rounded result)
// Block-wide: every thread must call it (contains __syncthreads()).
//   src : K bf16 inputs;  w : K bf16 norm weights;  G : not read by this function.
// Pass 1 copies src into SM_XN as fp32 while accumulating the sum of squares;
// warp partials go to SM_RED[0..NWARP) and thread 0 stores
// rsqrt(mean + 1e-6) in SM_RED[256]. Pass 2 rescales SM_XN in place.
static __device__ void stage_x_norm(const bf16* __restrict__ src, const bf16* __restrict__ w, int K, int G) {
    float* xn = SM_XN;
    float ss = 0.f;
    for (int i = threadIdx.x; i < K; i += NTHR) {
        float v = b2f(src[i]);
        xn[i] = v;
        ss += v * v;
    }
    float* red = SM_RED;
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) ss += __shfl_down_sync(0xffffffffu, ss, o);
    if ((threadIdx.x & 31) == 0) red[threadIdx.x >> 5] = ss;
    __syncthreads();
    if (threadIdx.x == 0) {
        float t = 0.f;
        #pragma unroll
        for (int i = 0; i < NWARP; i++) t += red[i];
        red[256] = rsqrtf(t / (float)K + 1e-6f);
    }
    __syncthreads();
    const float scale = red[256];
    // rt() rounds the normalized value through bf16 before it is used by the GEMVs.
    for (int i = threadIdx.x; i < K; i += NTHR) xn[i] = rt(xn[i] * scale * b2f(w[i]));
    __syncthreads();
}

// Copy K bf16 values from src into SM_XN as fp32 (no normalization).
// Block-wide call; G is not read.
static __device__ void stage_x_raw(const bf16* __restrict__ src, int K, int G) {
    float* xn = SM_XN;
    for (int i = threadIdx.x; i < K; i += NTHR) xn[i] = b2f(src[i]);
    __syncthreads();
}

// Copy K fp32 values from src into SM_XN unchanged. Block-wide call; G is not read.
static __device__ void stage_x_f32(const float* __restrict__ src, int K, int G) {
    float* xn = SM_XN;
    for (int i = threadIdx.x; i < K; i += NTHR) xn[i] = src[i];
    __syncthreads();
}

// --------------------------- fused int4 GEMV tile --------------------------- //
// One 128-col tile, packed-row range [r0,r1). Row = k-pair; byte lo nibble = even k.
// Returns reduced col sum on threads 0..127.
// Dequant is exact vs the reference (bf16((n - z) * s)) but SIMD: nibbles are
// splatted into bf16 lanes as (128+n) via one LOP3 (0x4300 | n), then
// HSUB2((128+n),(128+z)) is exact and HMUL2(.,s) applies the single bf16-RN
// rounding the reference has. No CVT-pipe traffic in the hot loop.
//
// Block-wide call (contains __syncthreads()). Activations are read from SM_XN.
//   wq     : packed weights, N bytes per packed row (one byte per column)
//   sc, zp : bf16 scale / (128+z) tables, one row of N per 64 packed rows
//   N      : row pitch in columns;  c0 : first column of the tile
//   nvalid : lanes whose first column (4*lane) is >= nvalid skip all loads
// Each lane owns 4 adjacent columns; each warp owns a contiguous slice of the
// row range. Per-warp partials land in SM_RED[warp][128] and threads 0..127
// sum them; other threads return 0.
static __device__ float gemv_tile(const uint8_t* __restrict__ wq, const bf16* __restrict__ sc,
                                  const bf16* __restrict__ zp, int N, int r0, int r1, int c0,
                                  int nvalid) {
    const float* xn = SM_XN;
    float* red = SM_RED;
    const int warp = threadIdx.x >> 5, lane = threadIdx.x & 31;
    const int cc = 4 * lane;
    const int c = c0 + cc;
    float a0 = 0.f, a1 = 0.f, a2 = 0.f, a3 = 0.f;
    if (cc < nvalid) {
        // rows per warp, rounded up; the >> 4 is a divide by NWARP and assumes NWARP == 16
        const int per = (r1 - r0 + NWARP - 1) >> 4;
        int r = r0 + warp * per;
        const int we = min(r1, r + per);
        const uint8_t* p = wq + (size_t)r * N + c;
        // uint32 W = cols c..c+3 for the k-pair (2r, 2r+1); the halfword-lane masks
        // pair columns (c,c+2) and (c+1,c+3).
        #define GPROC(W, XV) do { \
            const unsigned lo02 = ((W) & 0x000F000Fu) | 0x43004300u; \
            const unsigned hi02 = (((W) >> 4) & 0x000F000Fu) | 0x43004300u; \
            const unsigned lo13 = (((W) >> 8) & 0x000F000Fu) | 0x43004300u; \
            const unsigned hi13 = (((W) >> 12) & 0x000F000Fu) | 0x43004300u; \
            const __nv_bfloat162 w02l = __hmul2(__hsub2(*(const __nv_bfloat162*)&lo02, z02), s02); \
            const __nv_bfloat162 w02h = __hmul2(__hsub2(*(const __nv_bfloat162*)&hi02, z02), s02); \
            const __nv_bfloat162 w13l = __hmul2(__hsub2(*(const __nv_bfloat162*)&lo13, z13), s13); \
            const __nv_bfloat162 w13h = __hmul2(__hsub2(*(const __nv_bfloat162*)&hi13, z13), s13); \
            a0 = fmaf(b2f(w02l.x), (XV).x, a0); \
            a2 = fmaf(b2f(w02l.y), (XV).x, a2); \
            a0 = fmaf(b2f(w02h.x), (XV).y, a0); \
            a2 = fmaf(b2f(w02h.y), (XV).y, a2); \
            a1 = fmaf(b2f(w13l.x), (XV).x, a1); \
            a3 = fmaf(b2f(w13l.y), (XV).x, a3); \
            a1 = fmaf(b2f(w13h.x), (XV).y, a1); \
            a3 = fmaf(b2f(w13h.y), (XV).y, a3); \
        } while (0)
        while (r < we) {
            // scale/zero-point group: 64 packed rows share one sc/zp row
            const int g = r >> 6;
            const int seg = min(we, (g + 1) << 6);
            const uint2 sv = *(const uint2*)(sc + (size_t)g * N + c);
            const uint2 zv = *(const uint2*)(zp + (size_t)g * N + c);
            // regroup the four bf16 values as (c,c+2) and (c+1,c+3) to match GPROC's lanes
            const unsigned s02u = __byte_perm(sv.x, sv.y, 0x5410);
            const unsigned s13u = __byte_perm(sv.x, sv.y, 0x7632);
            const unsigned z02u = __byte_perm(zv.x, zv.y, 0x5410);
            const unsigned z13u = __byte_perm(zv.x, zv.y, 0x7632);
            const __nv_bfloat162 s02 = *(const __nv_bfloat162*)&s02u;
            const __nv_bfloat162 s13 = *(const __nv_bfloat162*)&s13u;
            const __nv_bfloat162 z02 = *(const __nv_bfloat162*)&z02u;
            const __nv_bfloat162 z13 = *(const __nv_bfloat162*)&z13u;
            // main loop: 8 packed rows per iteration, all loads issued before the math
            for (; r + 8 <= seg; r += 8, p += 8 * (size_t)N) {
                unsigned wv[8];
                float2 xv[8];
                #pragma unroll
                for (int u = 0; u < 8; ++u) wv[u] = *(const unsigned*)(p + (size_t)u * N);
                #pragma unroll
                for (int u = 0; u < 8; ++u) xv[u] = *(const float2*)(xn + 2 * (r + u));
                #pragma unroll
                for (int u = 0; u < 8; ++u) { GPROC(wv[u], xv[u]); }
            }
            // remainder rows of this group
            for (; r < seg; ++r, p += N) {
                const unsigned w = *(const unsigned*)p;
                const float2 xv = *(const float2*)(xn + 2 * r);
                GPROC(w, xv);
            }
        }
        #undef GPROC
    }
    __syncthreads();
    float4* rw = (float4*)(red + warp * 128 + cc);
    *rw = make_float4(a0, a1, a2, a3);
    __syncthreads();
    float v = 0.f;
    if (threadIdx.x < 128) {
        #pragma unroll
        for (int w = 0; w < NWARP; ++w) v += red[w * 128 + threadIdx.x];
    }
    return v;
}

// write this task's fp32 partial slice; returns true on the block that arrived
// last for the tile (that block then runs the epilogue).
//   v        : per-thread column sum (threads 0..127 store it to slice[threadIdx.x])
//   done_ctr : arrival counter for the tile; the last arriver resets it to 0
//   target   : number of arrivals that completes the tile
// Block-wide call. The return value is uniform across the block (taken from a
// shared flag written by thread 0).
static __device__ bool arrive_tile(float v, float* __restrict__ slice, unsigned* done_ctr, unsigned target) {
    if (threadIdx.x < 128) slice[threadIdx.x] = v;
    __threadfence();   // publish the slice before the arrival is counted
    __shared__ unsigned lastflag;
    __syncthreads();
    if (threadIdx.x == 0) {
        unsigned old = atomicAdd(done_ctr, 1u);
        lastflag = (old == target - 1u) ? 1u : 0u;
        if (lastflag) *done_ctr = 0u;
    }
    __syncthreads();
    if (!lastflag) return false;
    __threadfence();   // fence before the caller reads the other blocks' slices
    return true;
}

// ------------------------- small bf16 row-dot (4 rows) ---------------------- //
// Dot products of rows row0..row0+3 of the bf16 matrix W (row length K) with the
// vector in SM_XN. Four warps share each row, each covering one quarter of K;
// warp sums go through SM_RED and threads 0..3 write the totals to out4[0..3].
// Block-wide call. NOTE(review): the bf16x2 loads and the K >> 2 split look
// like they expect K to be a multiple of 8 -- confirm against callers.
static __device__ void rowdot4(const bf16* __restrict__ W, int row0, int K, float* out4) {
    const float* xn = SM_XN;
    const int warp = threadIdx.x >> 5, lane = threadIdx.x & 31;
    const int row = row0 + (warp >> 2);
    const int q = warp & 3;
    const int kq = K >> 2;
    const bf16* w = W + (size_t)row * K + q * kq;
    const float* x = xn + q * kq;
    float acc = 0.f;
    for (int i = lane * 2; i < kq; i += 64) {
        __nv_bfloat162 t = *(const __nv_bfloat162*)(w + i);
        acc = fmaf(b2f(t.x), x[i], acc);
        acc = fmaf(b2f(t.y), x[i + 1], acc);
    }
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) acc += __shfl_down_sync(0xffffffffu, acc, o);
    float* red = SM_RED;
    __syncthreads();
    if (lane == 0) red[warp] = acc;
    __syncthreads();
    if (threadIdx.x < 4)
        out4[threadIdx.x] = red[threadIdx.x * 4] + red[threadIdx.x * 4 + 1] + red[threadIdx.x * 4 + 2] + red[threadIdx.x * 4 + 3];
    __syncthreads();
}

// ------------------------------- math helpers ------------------------------ //
// Logistic sigmoid.
static __device__ __forceinline__ float sigmoidf_(float x) { return 1.f / (1.f + expf(-x)); }
// SiLU: x * sigmoid(x).
static __device__ __forceinline__ float siluf_(float x) { return x * sigmoidf_(x); }
// Softplus log(1 + e^x); returns x directly above 20.
static __device__ __forceinline__ float softplusf_(float x) { return (x > 20.f) ? x : log1pf(expf(x)); }

// Rotate the pair (e, o) by the angle pos * ROPE_THETA^(-j/32) and store the
// rotated even/odd components in *re / *ro.
static __device__ void rope_pair(int pos, int j, float e, float o, float* re, float* ro) {
    const float inv = powf(ROPE_THETA, -(float)j * (1.0f / 32.0f));
    const float ang = (float)pos * inv;
    float s, c;
    sincosf(ang, &s, &c);
    *re = e * c - o * s;
    *ro = o * c + e * s;
}

// conv epilogue for stream js (0=q,1=k,2=v), column c, raw GEMV value v
// Applies the 4-tap filter stored at conv[(js*NKD + c)*4 ..] to the three saved
// window values plus the new (bf16-rounded) value, writes silu(sum) to
// p.qkv[js][c], then shifts the window so the new value becomes the newest tap.
static __device__ void conv_epi(const P& p, int l, int js, int c, float v) {
    const bf16 val = f2b(v);
    const bf16* cwv = p.L[l].conv + ((size_t)js * NKD + c) * 4;
    const float w0 = b2f(cwv[0]), w1 = b2f(cwv[1]), w2 = b2f(cwv[2]), w3 = b2f(cwv[3]);
    bf16* win = p.cw[l][js];
    const bf16 p0 = win[c], p1 = win[NKD + c], p2 = win[2 * NKD + c];
    const float s = w0 * b2f(p0) + w1 * b2f(p1) + w2 * b2f(p2) + w3 * b2f(val);
    p.qkv[(size_t)js * NKD + c] = f2b(siluf_(s));
    win[c] = p1;
    win[NKD + c] = p2;
    win[2 * NKD + c] = val;
}

// per-block redundant softmax + top-8 from logits scratch
// Warp 0 does all the work: softmax over the 64 logits into SM_PROBS, then 8
// rounds of arg-max (ties go to the lower index), masking each winner with -1.
// Indices go to SM_TOPK_I[0..7]; SM_TOPK_C[0..7] receive the winning
// probabilities scaled by RSCALE / (their sum + 1e-9); SM_TOPK_C[8] is set to 1.
// Ends with a block-wide __syncthreads().
static __device__ void topk_local(const P& p) {
    float* probs = SM_PROBS;
    if (threadIdx.x < 32) {
        const int lane = threadIdx.x;
        float l0 = p.logits[lane], l1 = p.logits[32 + lane];
        float m = fmaxf(l0, l1);
        #pragma unroll
        for (int o = 16; o > 0; o >>= 1) m = fmaxf(m, __shfl_xor_sync(0xffffffffu, m, o));
        float e0 = expf(l0 - m), e1 = expf(l1 - m);
        float s = e0 + e1;
        #pragma unroll
        for (int o = 16; o > 0; o >>= 1) s += __shfl_xor_sync(0xffffffffu, s, o);
        probs[lane] = e0 / s;
        probs[32 + lane] = e1 / s;
        __syncwarp();
        float wsum = 0.f;   // only lane 0's copy is accumulated and used
        for (int r = 0; r < 8; ++r) {
            float v0 = probs[lane], v1 = probs[32 + lane];
            float bv = v0;
            int bi = lane;
            if (v1 > bv) { bv = v1; bi = 32 + lane; }
            // xor butterfly: every lane ends with the same (value, index) winner
            #pragma unroll
            for (int o = 16; o > 0; o >>= 1) {
                float ov = __shfl_xor_sync(0xffffffffu, bv, o);
                int oi = __shfl_xor_sync(0xffffffffu, bi, o);
                if (ov > bv || (ov == bv && oi < bi)) { bv = ov; bi = oi; }
            }
            if (lane == 0) {
                SM_TOPK_I[r] = bi;
                SM_TOPK_C[r] = bv;
                probs[bi] = -1.f;   // remove the winner from later rounds
                wsum += bv;
            }
            __syncwarp();
        }
        if (lane == 0) {
            const float inv = RSCALE / (wsum + 1e-9f);
            #pragma unroll
            for (int r = 0; r < 8; ++r) SM_TOPK_C[r] *= inv;
            SM_TOPK_C[8] = 1.0f;
        }
    }
    __syncthreads();
}

// ================================== stages ================================= //
// KDA A1: q,k,v,g GEMVs (conv / gate epilogues) + beta row-dots
// Task space (strided over blocks): 512 GEMV tasks = 4 jobs (q,k,v,g) x 32
// column tiles x 4 split-K chunks of 288 packed rows, followed by 8 beta tasks
// of 4 rows each.
// Counters used: done[job*32 + tile] for split-K arrival, done[192 + head] for
// per-head readiness, and done[RDONE], which block 0 clears here and stage_b2
// later uses as its router-ready counter.
#define RDONE 255
static __device__ void stage_a1(const P& p, int l, const bf16* hsrc) {
    if (blockIdx.x == 0 && threadIdx.x == 0) p.done[RDONE] = 0u;
    stage_x_norm(hsrc, p.L[l].anorm, HID, G_HID);
    for (int t = blockIdx.x; t < 520; t += NBLK) {
        if (t < 512) {
            const int job = t >> 7, r = t & 127, tile = r >> 2, kc = r & 3;
            const int nf0 = job * NKD + tile * 128;
            float v = gemv_tile(p.L[l].wq[job], p.L[l].sc[job], p.L[l].zp[job], NKD,
                                kc * 288, kc * 288 + 288, tile * 128, 128);
            if (arrive_tile(v, p.partial + (size_t)kc * 16384 + nf0, p.done + job * 32 + tile, 4)) {
                if (threadIdx.x < 128) {
                    float out = 0.f;
                    #pragma unroll
                    for (int k = 0; k < 4; ++k) out += p.partial[(size_t)k * 16384 + nf0 + threadIdx.x];
                    const int c = tile * 128 + threadIdx.x;
                    if (job < 3) {
                        conv_epi(p, l, job, c, out);
                    } else {
                        // g stream: store exp(-softplus(.)) of the bf16-rounded value
                        p.gexp[c] = expf(-softplusf_(rt(out)));
                    }
                }
                __threadfence();
                __syncthreads();
                // head `tile` readiness (q,k,v,g each arrive once; beta adds 1 -> 5)
                if (threadIdx.x == 0) atomicAdd(p.done + 192 + tile, 1u);
            }
        } else {
            const int bt = t - 512;
            rowdot4(p.L[l].betaw, bt * 4, HID, SM_MISC);
            if (threadIdx.x < 4) p.beta[bt * 4 + threadIdx.x] = sigmoidf_(rt(SM_MISC[threadIdx.x]));
            __threadfence();
            __syncthreads();
            if (threadIdx.x < 4) atomicAdd(p.done + 192 + bt * 4 + threadIdx.x, 1u);
            __syncthreads();
        }
    }
    // fused A2: gated delta state update; spins on per-head readiness instead of
    // a grid barrier. The 4th (h,jt) task to finish a head resets its counter.
    // 128 tasks = 32 heads x 4 column slices (jt) of 32 columns. Within a task,
    // thread (ic, j) owns state rows ic*8..ic*8+7 of column jt*32 + j.
    {
        // scratch reuses the start of SMEM (the SM_XN region)
        float* kq = (float*)SMEM;
        float* kk = kq + 128;
        float* kv = kq + 256;
        float* kg = kq + 384;
        float* pr = kq + 512;
        float* dl = kq + 1024;
        for (int t = blockIdx.x; t < 128; t += NBLK) {
            const int h = t >> 2, jt = t & 3, c0 = h * 128;
            if (threadIdx.x == 0) {
                while (*(volatile unsigned*)(p.done + 192 + h) < 5u) { }
            }
            __syncthreads();
            __threadfence();
            if (threadIdx.x < 128) {
                kq[threadIdx.x] = b2f(p.qkv[c0 + threadIdx.x]) * KDA_SCALE;
                kk[threadIdx.x] = b2f(p.qkv[NKD + c0 + threadIdx.x]);
                kv[threadIdx.x] = b2f(p.qkv[2 * NKD + c0 + threadIdx.x]);
                kg[threadIdx.x] = p.gexp[c0 + threadIdx.x];
            }
            __syncthreads();
            const int ic = threadIdx.x >> 5, j = threadIdx.x & 31;
            const int col = jt * 32 + j;
            float* Sp = p.S[l] + ((size_t)h * 128 + ic * 8) * 128 + col;
            float sg[8];
            float pp = 0.f;
            // decay the state rows by kg and accumulate their dot with k
            #pragma unroll
            for (int r = 0; r < 8; ++r) {
                const int i = ic * 8 + r;
                const float sv = Sp[(size_t)r * 128] * kg[i];
                sg[r] = sv;
                pp = fmaf(sv, kk[i], pp);
            }
            pr[ic * 32 + j] = pp;
            __syncthreads();
            if (ic == 0) {
                float sum = 0.f;
                #pragma unroll
                for (int q = 0; q < 16; ++q) sum += pr[q * 32 + j];
                // delta for this column: beta * (v - k . decayed_state)
                dl[j] = p.beta[h] * (kv[col] - sum);
            }
            __syncthreads();
            const float d = dl[j];
            float op = 0.f;
            // rank-1 update S += k * delta, written back in place, and output dot with q
            #pragma unroll
            for (int r = 0; r < 8; ++r) {
                const int i = ic * 8 + r;
                const float sn = fmaf(kk[i], d, sg[r]);
                Sp[(size_t)r * 128] = sn;
                op = fmaf(sn, kq[i], op);
            }
            pr[ic * 32 + j] = op;
            __syncthreads();
            if (ic == 0) {
                float sum = 0.f;
                #pragma unroll
                for (int q = 0; q < 16; ++q) sum += pr[q * 32 + j];
                p.oatt[c0 + col] = f2b(sum);
            }
            __syncthreads();
            if (threadIdx.x == 0) {
                // counter goes 5 -> 9 over the four jt tasks; old value 8 marks the last one
                const unsigned v2 = atomicAdd(p.done + 192 + h, 1u);
                if (v2 == 8u) p.done[192 + h] = 0u;
            }
        }
    }
}

// o_proj [4096 -> 2304] + residual (widx: 4 for KDA, 3 for MLA)
// 144 tasks = 18 column tiles x 8 split-K chunks of 256 packed rows. The input
// vector is p.oatt; the finishing block of each tile writes
// p.hid[c] = bf16(hsrc[c] + bf16-rounded projection).
static __device__ void stage_oproj(const P& p, int l, int widx, const bf16* hsrc) {
    stage_x_raw(p.oatt, NKD, G_NKD);
    for (int t = blockIdx.x; t < 144; t += NBLK) {
        const int tile = t >> 3, kc = t & 7;
        const int nf0 = tile * 128;
        float v = gemv_tile(p.L[l].wq[widx], p.L[l].sc[widx], p.L[l].zp[widx], HID,
                            kc * 256, kc * 256 + 256, tile * 128, 128);
        if (arrive_tile(v, p.partial + (size_t)kc * HID + nf0, p.done + tile, 8)) {
            if (threadIdx.x < 128) {
                float out = 0.f;
                #pragma unroll
                for (int k = 0; k < 8; ++k) out += p.partial[(size_t)k * HID + nf0 + threadIdx.x];
                const int c = tile * 128 + threadIdx.x;
                p.hid[c] = f2b(b2f(hsrc[c]) + rt(out));
            }
            __syncthreads();
        }
    }
}

// MoE B2 (router fused): blocks 0..15 compute the router rows and arrive on a
// ready counter; the first 32 task slots are the shared expert (index-free), so
// most blocks overlap the router latency; routed tasks spin briefly on the
// counter before running the local top-8.
// Task space: 288 gate/up tasks (9 expert slots x {gate,up} x 8 column tiles x
// 2 K-halves of 576 packed rows) followed by 324 down tasks (9 slots x 18
// column tiles x 2 K-halves of 256 packed rows). Slot 8 is the shared expert.
// Counters: done[RDONE] router rows ready (16), done[e*8 + tile] gate/up
// split-K arrival, done[176 + e] per-slot xm readiness, done[128 + tile] down
// split-K arrival.
static __device__ void stage_b2(const P& p, int l) {
    stage_x_norm(p.hid, p.L[l].mnorm, HID, G_HID);
    if (blockIdx.x < 16) {
        rowdot4(p.L[l].router, blockIdx.x * 4, HID, SM_MISC);
        if (threadIdx.x < 4) p.logits[blockIdx.x * 4 + threadIdx.x] = rt(SM_MISC[threadIdx.x]);
        __threadfence();
        __syncthreads();
        if (threadIdx.x == 0) atomicAdd(p.done + RDONE, 1u);
    }
    bool have_topk = false;
    // the +172 rotation gives the router blocks (0..15) the first-round slots 172..187
    for (int t = (blockIdx.x + 172) % NBLK; t < 612; t += NBLK) {
        if (t < 288) {
            // gate/up task; slot remap: first 32 tasks = shared expert (router-free)
            const int to = (t < 32) ? (256 + t) : (t - 32);
            const int e = to >> 5, r = to & 31, gu = r >> 4, r2 = r & 15, tile = r2 >> 1, kc = r2 & 1;
            if (e < 8 && !have_topk) {
                if (threadIdx.x == 0) {
                    while (*(volatile unsigned*)(p.done + RDONE) < 16u) { }
                }
                __syncthreads();
                __threadfence();
                topk_local(p);
                have_topk = true;
            }
            const uint8_t* wq;
            const bf16 *sc, *zp;
            if (e < 8) {
                // routed slot: offset the expert-major base by the selected expert index
                const size_t ei = (size_t)SM_TOPK_I[e];
                wq = p.L[l].ewq[gu] + ei * (K2_HID * MI);
                sc = p.L[l].esc[gu] + ei * (G_HID * MI);
                zp = p.L[l].ezp[gu] + ei * (G_HID * MI);
            } else {
                wq = p.L[l].ewq[3 + gu];
                sc = p.L[l].esc[3 + gu];
                zp = p.L[l].ezp[3 + gu];
            }
            const int nf0 = e * MI + tile * 128;
            float v = gemv_tile(wq, sc, zp, MI, kc * 576, kc * 576 + 576, tile * 128, 128);
            if (arrive_tile(v, p.partial + (size_t)(gu * 2 + kc) * 9216 + nf0, p.done + e * 8 + tile, 4)) {
                if (threadIdx.x < 128) {
                    // slices 0,1 = gate K-halves; slices 2,3 = up K-halves
                    const size_t o = (size_t)nf0 + threadIdx.x;
                    const float gs = p.partial[o] + p.partial[9216 + o];
                    const float us = p.partial[2 * 9216 + o] + p.partial[3 * 9216 + o];
                    p.xm[o] = siluf_(gs) * us;
                }
                __threadfence();
                __syncthreads();
                // xm[e] tile readiness for the fused down-GEMV tasks
                if (threadIdx.x == 0) atomicAdd(p.done + 176 + e, 1u);
            }
        } else {
            // down task (fused former B3); shared expert slot ordered first
            const int tb = t - 288;
            const int er = tb / 36, r3 = tb % 36;
            const int e = (er + 8) % 9;
            const int tile = r3 >> 1, kc = r3 & 1;
            if (e < 8 && !have_topk) {
                if (threadIdx.x == 0) {
                    while (*(volatile unsigned*)(p.done + RDONE) < 16u) { }
                }
                __syncthreads();
                __threadfence();
                topk_local(p);
                have_topk = true;
            }
            // wait until all 8 xm tiles of this expert slot are written
            if (threadIdx.x == 0) {
                while (*(volatile unsigned*)(p.done + 176 + e) < 8u) { }
            }
            __syncthreads();
            __threadfence();
            const uint8_t* wq;
            const bf16 *sc, *zp;
            if (e < 8) {
                const size_t ei = (size_t)SM_TOPK_I[e];
                wq = p.L[l].ewq[2] + ei * (K2_MI * HID);
                sc = p.L[l].esc[2] + ei * (G_MI * HID);
                zp = p.L[l].ezp[2] + ei * (G_MI * HID);
            } else {
                wq = p.L[l].ewq[5];
                sc = p.L[l].esc[5];
                zp = p.L[l].ezp[5];
            }
            // replaces the contents of SM_XN with this slot's xm row
            stage_x_f32(p.xm + (size_t)e * MI, MI, G_MI);
            const float coef = (e < 8) ? SM_TOPK_C[e] : 1.0f;
            float v = gemv_tile(wq, sc, zp, HID, kc * 256, kc * 256 + 256, tile * 128, 128) * coef;
            const int nf0 = tile * 128;
            if (arrive_tile(v, p.partial + 73728 + (size_t)(e * 2 + kc) * HID + nf0, p.done + 128 + tile, 18)) {
                if (threadIdx.x < 128) {
                    float y = 0.f;
                    #pragma unroll
                    for (int k = 0; k < 18; ++k) y += p.partial[73728 + (size_t)k * HID + nf0 + threadIdx.x];
                    const int c = tile * 128 + threadIdx.x;
                    p.hid[c] = f2b(b2f(p.hid[c]) + rt(y));
                }
                __syncthreads();
            }
            // release the edone counter (8 ready + 36 down tasks = 44 -> reset at 43)
            if (threadIdx.x == 0) {
                const unsigned v2 = atomicAdd(p.done + 176 + e, 1u);
                if (v2 == 43u) p.done[176 + e] = 0u;
            }
        }
    }
}

// q_abs for one head (absorb q_nope through the kv_b k-half)
// Block-wide call. Reads the head's 128 q values from p.q6144[h*192 ..], the
// layer-3 kv_b weights at columns h*256 .. h*256+127, and writes 512 bf16
// values to p.qabs[h]. Thread (half, cp) handles packed row cp for 64 of the
// 128 columns; the two halves are summed afterwards.
static __device__ void qabs_head(const P& p, int h) {
    float* qn = (float*)SMEM;     // [128]
    float* s_sm = qn + 128;       // [4][128]
    float* b_sm = qn + 640;       // [4][128]
    float* qh = qn + 1152;        // [2][256][2]
    {
        if (threadIdx.x < 128) qn[threadIdx.x] = b2f(p.q6144[h * 192 + threadIdx.x]);
        {
            // stage the fp32 scale and -(z*s) rows for the 4 groups x 128 columns
            const int g = threadIdx.x >> 7, d = threadIdx.x & 127;
            const int col = h * 256 + d;
            s_sm[g * 128 + d] = p.L[3].ws[2][(size_t)g * KVBN + col];
            b_sm[g * 128 + d] = p.L[3].wb[2][(size_t)g * KVBN + col];
        }
        __syncthreads();
        {
            const int cp = threadIdx.x & 255, half = threadIdx.x >> 8;
            const int gcp = cp >> 6;
            const uint8_t* wp = p.L[3].wq[2] + (size_t)cp * KVBN + h * 256 + half * 64;
            const float* sg = s_sm + gcp * 128 + half * 64;
            const float* bg = b_sm + gcp * 128 + half * 64;
            const float* qg = qn + half * 64;
            float alo = 0.f, ahi = 0.f;   // low-nibble / high-nibble accumulators
            #pragma unroll
            for (int dd = 0; dd < 64; dd += 16) {
                const uint4 w16 = *(const uint4*)(wp + dd);
                const unsigned ww[4] = {w16.x, w16.y, w16.z, w16.w};
                #pragma unroll
                for (int u = 0; u < 4; ++u) {
                    const unsigned w = ww[u];
                    const int d0 = dd + u * 4;
                    #pragma unroll
                    for (int i = 0; i < 4; ++i) {
                        const unsigned nlo = (w >> (8 * i)) & 15u;
                        const unsigned nhi = (w >> (8 * i + 4)) & 15u;
                        const float s = sg[d0 + i], b = bg[d0 + i], x = qg[d0 + i];
                        // dequantize as n*s + b, round to bf16 via rtb, then accumulate
                        alo = fmaf(rtb(fmaf((float)nlo, s, b)), x, alo);
                        ahi = fmaf(rtb(fmaf((float)nhi, s, b)), x, ahi);
                    }
                }
            }
            qh[(half * 256 + cp) * 2] = alo;
            qh[(half * 256 + cp) * 2 + 1] = ahi;
        }
        __syncthreads();
        if (threadIdx.x < 256) {
            const float lo = qh[threadIdx.x * 2] + qh[(256 + threadIdx.x) * 2];
            const float hi = qh[threadIdx.x * 2 + 1] + qh[(256 + threadIdx.x) * 2 + 1];
            p.qabs[(size_t)h * KVL + 2 * threadIdx.x] = f2b(lo);
            p.qabs[(size_t)h * KVL + 2 * threadIdx.x + 1] = f2b(hi);
        }
        __syncthreads();
    }
}

// MLA M1: q_proj + kv_a GEMVs (rope + cache-append epilogues)
// 318 tasks = 53 column tiles x 6 split-K chunks of 192 packed rows. Tiles
// 0..47 belong to q_proj (N = NQ), tiles 48..52 to kv_a (N = NKVA); the last
// kv_a tile has only 64 valid columns. After the GEMV tasks every block helps
// zero p.olat and, when p.ingest is set, copy the old cache rows; blocks 0..31
// then each compute q_abs for one head.
static __device__ void stage_m1(const P& p, const bf16* hsrc) {
    if (blockIdx.x == 0 && threadIdx.x == 0) p.done[RDONE] = 0u;
    stage_x_norm(hsrc, p.L[3].anorm, HID, G_HID);
    for (int t = blockIdx.x; t < 318; t += NBLK) {
        const int tj = t / 6, kc = t % 6;
        const bool isq = (tj < 48);
        const int c0 = isq ? tj * 128 : (tj - 48) * 128;
        const int N = isq ? NQ : NKVA;
        const int nvalid = (!isq && tj == 52) ? 64 : 128;
        const int wj = isq ? 0 : 1;
        const int nf0 = tj * 128;
        float v = gemv_tile(p.L[3].wq[wj], p.L[3].sc[wj], p.L[3].zp[wj], N,
                            kc * 192, kc * 192 + 192, c0, nvalid);
        if (arrive_tile(v, p.partial + (size_t)kc * 6784 + nf0, p.done + tj, 6)) {
            if (threadIdx.x < nvalid) {
                float out = 0.f;
                #pragma unroll
                for (int k = 0; k < 6; ++k) out += p.partial[(size_t)k * 6784 + nf0 + threadIdx.x];
                if (isq) {
                    // each head owns 192 q columns: 128 stored as-is, then 64 rotated in pairs
                    const int col = c0 + threadIdx.x;
                    const int rr = col % 192;
                    if (rr < 128) {
                        p.q6144[col] = f2b(out);
                    } else if (!(rr & 1)) {
                        // even column of a pair: also reduce the odd neighbour and rotate both
                        float odd = 0.f;
                        #pragma unroll
                        for (int k = 0; k < 6; ++k) odd += p.partial[(size_t)k * 6784 + nf0 + threadIdx.x + 1];
                        float re, ro;
                        rope_pair(p.pos, (rr - 128) >> 1, rt(out), rt(odd), &re, &ro);
                        p.q6144[col] = f2b(re);
                        p.q6144[col + 1] = f2b(ro);
                    }
                } else {
                    // kv_a: columns below KVL append to ckv row pos, the rest are rotated into kro
                    const int c2 = c0 + threadIdx.x;
                    if (c2 < KVL) {
                        p.ckv[(size_t)p.pos * KVL + c2] = f2b(out);
                    } else if (!(c2 & 1)) {
                        float odd = 0.f;
                        #pragma unroll
                        for (int k = 0; k < 6; ++k) odd += p.partial[(size_t)k * 6784 + nf0 + threadIdx.x + 1];
                        float re, ro;
                        rope_pair(p.pos, (c2 - KVL) >> 1, rt(out), rt(odd), &re, &ro);
                        p.kro[(size_t)p.pos * 64 + (c2 - KVL)] = f2b(re);
                        p.kro[(size_t)p.pos * 64 + (c2 - KVL) + 1] = f2b(ro);
                    }
                }
            }
            __threadfence();
            __syncthreads();
            if (isq && threadIdx.x == 0) {
                // arrive on the head whose q_nope window this tile covers (at most one)
                const int h0 = (2 * tj) / 3;
                #pragma unroll
                for (int hh = h0 - 1; hh <= h0 + 1; ++hh) {
                    if (hh >= 0 && hh < 32) {
                        // [ta, tb2] = first / last tile touching columns 192*hh .. 192*hh+127
                        const int ta = (3 * hh) >> 1, tb2 = (192 * hh + 127) >> 7;
                        if (tj >= ta && tj <= tb2) atomicAdd(p.done + 72 + hh, 1u);
                    }
                }
            }
            __syncthreads();
        }
    }
    // olat zeroing + one-time cache ingest (consumed by M3/M4, barriers away)
    {
        const int gt = blockIdx.x * NTHR + threadIdx.x;
        for (int i = gt; i < 32 * KVL; i += NBLK * NTHR) p.olat[i] = 0.f;
        if (p.ingest) {
            // 16-byte copies: KVL/8 uint4 per ckv row, 8 uint4 per kro row, rows [0, pos)
            const uint4* cs = (const uint4*)p.ckv_src;
            uint4* cd = (uint4*)p.ckv;
            const int nc = p.pos * (KVL / 8);
            for (int i = gt; i < nc; i += NBLK * NTHR) cd[i] = cs[i];
            const uint4* ks = (const uint4*)p.kro_src;
            uint4* kd = (uint4*)p.kro;
            const int nk = p.pos * 8;
            for (int i = gt; i < nk; i += NBLK * NTHR) kd[i] = ks[i];
        }
    }
    // fused q_abs (former M2): per-head spin on the q_nope tile readiness
    for (int t = blockIdx.x; t < 32; t += NBLK) {
        const int h = t;
        if (threadIdx.x == 0) {
            // even heads are covered by one tile, odd heads straddle two
            const unsigned tgt = 1u + (unsigned)(h & 1);
            while (*(volatile unsigned*)(p.done + 72 + h) < tgt) { }
            p.done[72 + h] = 0u;   // single consumer per head
        }
        __syncthreads();
        __threadfence();
        qabs_head(p, h);
    }
}

// MLA M3: scores pass (streams cache once), online softmax partials per block
// Cache rows [0, pos] are split into NBLK contiguous chunks, one per block.
// For each row the block computes 32 head scores as the dot of the 576-wide
// query (512 q_abs values + 64 rotated q values per head) with the cache row
// (ckv | kro), scaled by MLA_SCALE. Scores go to p.scores; the block's running
// (max, sum) per head goes to p.smax[blockIdx.x].
static __device__ void stage_m3(const P& p) {
    bf16* qat = SM_QABS_T;
    // coalesced reads, smem-side scatter for the [k][h] transpose
    for (int i = threadIdx.x; i < 32 * KVL; i += NTHR) {
        const int h = i >> 9, k = i & 511;
        qat[k * 36 + h] = p.qabs[i];
    }
    for (int i = threadIdx.x; i < 32 * 64; i += NTHR) {
        const int h = i >> 6, r = i & 63;
        qat[(KVL + r) * 36 + h] = p.q6144[h * 192 + 128 + r];
    }
    float* rm = SM_RMS;
    float* rs = SM_RMS + 32;
    if (threadIdx.x < 32) {
        rm[threadIdx.x] = -INFINITY;
        rs[threadIdx.x] = 0.f;
    }
    __syncthreads();
    const int chunk = (p.pos + NBLK) / NBLK;   // ceil((pos + 1) / NBLK)
    const int l0 = blockIdx.x * chunk;
    const int lend = min(p.pos + 1, l0 + chunk);
    // thread (tl, hg): cache row tl of the current 64-row tile, heads 4*hg .. 4*hg+3
    const int tl = threadIdx.x >> 3, hg = threadIdx.x & 7;
    for (int lt = l0; lt < lend; lt += 64) {
        const int nrows = min(64, lend - lt);
        float acc0 = 0.f, acc1 = 0.f, acc2 = 0.f, acc3 = 0.f;
        // five k-chunks: four 128-wide slices of ckv, then the 64-wide kro row
        for (int kc = 0; kc < 5; ++kc) {
            const int kw = (kc < 4) ? 128 : 64;
            __syncthreads();
            if (kc < 4) {
                for (int i = threadIdx.x; i < nrows * 16; i += NTHR) {
                    const int rr = i >> 4, seg = i & 15;
                    ((uint4*)(SM_ROWS + rr * 128))[seg] =
                        ((const uint4*)(p.ckv + (size_t)(lt + rr) * KVL + kc * 128))[seg];
                }
            } else {
                for (int i = threadIdx.x; i < nrows * 8; i += NTHR) {
                    const int rr = i >> 3, seg = i & 7;
                    ((uint4*)(SM_ROWS + rr * 128))[seg] =
                        ((const uint4*)(p.kro + (size_t)(lt + rr) * 64))[seg];
                }
            }
            __syncthreads();
            if (tl < nrows) {
                const bf16* arow = SM_ROWS + tl * 128;
                const bf16* bt = qat + (size_t)kc * 128 * 36 + hg * 4;
                #pragma unroll 4
                for (int kk = 0; kk < kw; kk += 2) {
                    const __nv_bfloat162 av = *(const __nv_bfloat162*)(arow + kk);
                    const float a0 = b2f(av.x), a1 = b2f(av.y);
                    const uint2 b0 = *(const uint2*)(bt + kk * 36);
                    const uint2 b1 = *(const uint2*)(bt + kk * 36 + 36);
                    const __nv_bfloat162 b0lo = *(const __nv_bfloat162*)&b0.x;
                    const __nv_bfloat162 b0hi = *(const __nv_bfloat162*)&b0.y;
                    const __nv_bfloat162 b1lo = *(const __nv_bfloat162*)&b1.x;
                    const __nv_bfloat162 b1hi = *(const __nv_bfloat162*)&b1.y;
                    acc0 = fmaf(a0, b2f(b0lo.x), fmaf(a1, b2f(b1lo.x), acc0));
                    acc1 = fmaf(a0, b2f(b0lo.y), fmaf(a1, b2f(b1lo.y), acc1));
                    acc2 = fmaf(a0, b2f(b0hi.x), fmaf(a1, b2f(b1hi.x), acc2));
                    acc3 = fmaf(a0, b2f(b0hi.y), fmaf(a1, b2f(b1hi.y), acc3));
                }
            }
        }
        __syncthreads();
        if (tl < nrows) {
            const float4 sc = make_float4(acc0 * MLA_SCALE, acc1 * MLA_SCALE, acc2 * MLA_SCALE, acc3 * MLA_SCALE);
            *(float4*)(SM_STILE + tl * 32 + hg * 4) = sc;
            *(float4*)(p.scores + (size_t)(lt + tl) * 32 + hg * 4) = sc;
        }
        __syncthreads();
        if (threadIdx.x < 32) {
            // online softmax: one thread per head folds this tile into (max, sum)
            float m = rm[threadIdx.x], s = rs[threadIdx.x];
            for (int rr = 0; rr < nrows; ++rr) {
                const float v = SM_STILE[rr * 32 + threadIdx.x];
                const float mn = fmaxf(m, v);
                s = s * expf(m - mn) + expf(v - mn);
                m = mn;
            }
            rm[threadIdx.x] = m;
            rs[threadIdx.x] = s;
        }
        __syncthreads();
    }
    if (threadIdx.x < 32) {
        p.smax[(size_t)blockIdx.x * 64 + threadIdx.x * 2] = rm[threadIdx.x];
        p.smax[(size_t)blockIdx.x * 64 + threadIdx.x * 2 + 1] = rs[threadIdx.x];
    }
}

// o = olat[h] through the kv_b v-half, one head
// Block-wide call. Loads the head's 512 fp32 olat values into shared memory,
// then thread (cc, d) accumulates output column d over packed rows
// cc*64 .. cc*64+63 of the layer-3 kv_b weights at column h*256 + 128 + d. The
// four cc partials are summed and written as bf16 to p.oatt[h*128 + d].
static __device__ void m5_head(const P& p, int h) {
    float* ol = (float*)SMEM;     // [512]
    float* red2 = ol + 528;       // [4][128]
    {
        if (threadIdx.x < 512) ol[threadIdx.x] = p.olat[(size_t)h * KVL + threadIdx.x];
        __syncthreads();
        {
            const int cc = threadIdx.x >> 7, d = threadIdx.x & 127;
            const int col = h * 256 + 128 + d;
            const uint8_t* wp = p.L[3].wq[2] + (size_t)(cc * 64) * KVBN + col;
            const float s = p.L[3].ws[2][(size_t)cc * KVBN + col];
            const float bz = p.L[3].wb[2][(size_t)cc * KVBN + col];
            float qd = 0.f;
            for (int r0 = 0; r0 < 64; r0 += 16) {
                // batch 16 byte loads ahead of the math
                unsigned bb[16];
                #pragma unroll
                for (int u = 0; u < 16; ++u) bb[u] = wp[(size_t)(r0 + u) * KVBN];
                #pragma unroll
                for (int u = 0; u < 16; ++u) {
                    const int r = r0 + u;
                    qd = fmaf(rtb(fmaf((float)(bb[u] & 15u), s, bz)), ol[cc * 128 + 2 * r], qd);
                    qd = fmaf(rtb(fmaf((float)((bb[u] >> 4) & 15u), s, bz)), ol[cc * 128 + 2 * r + 1], qd);
                }
            }
            red2[cc * 128 + d] = qd;
        }
        __syncthreads();
        if (threadIdx.x < 128) {
            const float o = red2[threadIdx.x] + red2[128 + threadIdx.x] + red2[256 + threadIdx.x] + red2[384 + threadIdx.x];
            p.oatt[h * 128 + threadIdx.x] = f2b(o);
        }
        __syncthreads();
    }
}

// MLA M4: finalize softmax; weighted-sum pass (streams cache again) -> olat.
// Each thread owns 2 heads (h0 and h0 + 4) x one c-quad, i.e. 8 accumulators;
// the block as a whole covers 8 heads. One 8-byte row load feeds 8 FMAs and
// the p reads are warp-uniform broadcasts, so smem traffic stays near
// the FMA floor.
static __device__ void stage_m4(const P& p) {
    // Shared-memory carve-up (byte offsets into SMEM).
    bf16* rows = (bf16*)SMEM;                 // [64][512]
    float* stile = (float*)(SMEM + 65536);    // [64][32]
    float* mrz = (float*)(SMEM + 73728);      // [2][32]
    float* red = mrz + 64;                    // [16][32]
    {
        // Combine the per-block (max, sum) pairs stored in p.smax
        // ([block][head][2]) into a global max mrz[h] and an inverse normaliser
        // mrz[32 + h].  Every block repeats this reduction for itself.
        // h = lane = head, st = warp index (0..15) striding over the blocks.
        const int h = threadIdx.x & 31, st = threadIdx.x >> 5;
        float m = -INFINITY;
        for (int b = st; b < NBLK; b += 16) m = fmaxf(m, p.smax[(size_t)b * 64 + h * 2]);
        red[st * 32 + h] = m;
        __syncthreads();
        if (st == 0) {
            float mm = red[h];
#pragma unroll
            for (int q = 1; q < 16; ++q) mm = fmaxf(mm, red[q * 32 + h]);
            mrz[h] = mm;
        }
        __syncthreads();
        const float M = mrz[h];
        float s = 0.f;
        for (int b = st; b < NBLK; b += 16) {
            const float mb = p.smax[(size_t)b * 64 + h * 2];
            const float sb = p.smax[(size_t)b * 64 + h * 2 + 1];
            // Rescale each block's partial sum from its own max to the global max.
            s += sb * expf(mb - M);
        }
        red[st * 32 + h] = s;
        __syncthreads();
        if (st == 0) {
            float ss = 0.f;
#pragma unroll
            for (int q = 0; q < 16; ++q) ss += red[q * 32 + h];
            mrz[32 + h] = 1.0f / ss;
        }
        __syncthreads();
    }
    // 188 blocks = 47 l-chunks x 4 head-groups: each block accumulates 8 heads
    // over its chunk (rows re-read by the 4 groups come from L2), cutting the
    // global fp32 atomic count 4x vs an all-heads-per-block layout.
    const int chunk = (p.pos + 47) / 47;   // ceil((pos + 1) / 47) cache rows per l-chunk
    const int lc = blockIdx.x >> 2, hg4 = blockIdx.x & 3;
    const int l0 = lc * chunk;
    const int lend = min(p.pos + 1, l0 + chunk);   // may be <= l0: the block then adds zeros
    const int cq = threadIdx.x & 127;   // c = 4*cq .. 4*cq+3
    const int hq2 = threadIdx.x >> 7;   // h = hg4*8 + hq2 + 4m, m in {0,1}
    const int h0 = hg4 * 8 + hq2;
    float acc[8];
#pragma unroll
    for (int i = 0; i < 8; ++i) acc[i] = 0.f;
    for (int lt = l0; lt < lend; lt += 64) {
        const int nrows = min(64, lend - lt);
        // Previous tile (or the reduction above) must be fully consumed before
        // stile/rows are overwritten.
        __syncthreads();
        // Normalised probabilities for all 32 heads of this tile.
        for (int i = threadIdx.x; i < nrows * 32; i += NTHR) {
            const int rr = i >> 5, h = i & 31;
            stile[i] = expf(p.scores[(size_t)(lt + rr) * 32 + h] - mrz[h]) * mrz[32 + h];
        }
        // Latent cache rows, 64 x 16-byte segments per 512-element bf16 row.
        for (int i = threadIdx.x; i < nrows * 64; i += NTHR) {
            const int rr = i >> 6, seg = i & 63;
            ((uint4*)(rows + rr * 512))[seg] = ((const uint4*)(p.ckv + (size_t)(lt + rr) * KVL))[seg];
        }
        __syncthreads();
        for (int rr = 0; rr < nrows; ++rr) {
            // One 8-byte read yields this thread's four latent columns of row rr.
            const uint2 rv = *(const uint2*)(rows + rr * 512 + cq * 4);
            const __nv_bfloat162 c01 = *(const __nv_bfloat162*)&rv.x;
            const __nv_bfloat162 c23 = *(const __nv_bfloat162*)&rv.y;
            const float x0 = b2f(c01.x), x1 = b2f(c01.y), x2 = b2f(c23.x), x3 = b2f(c23.y);
            // h0 depends only on threadIdx.x >> 7, so pr is the same address for
            // the whole warp.
            const float* pr = stile + rr * 32 + h0;
#pragma unroll
            for (int m = 0; m < 2; ++m) {
                const float pv = pr[4 * m];
                acc[m * 4 + 0] = fmaf(pv, x0, acc[m * 4 + 0]);
                acc[m * 4 + 1] = fmaf(pv, x1, acc[m * 4 + 1]);
                acc[m * 4 + 2] = fmaf(pv, x2, acc[m * 4 + 2]);
                acc[m * 4 + 3] = fmaf(pv, x3, acc[m * 4 + 3]);
            }
        }
    }
    // Fold this block's partial sums into the global fp32 output latent.
    // NOTE(review): p.olat must already be zero on entry; the code that clears it
    // is not in this function -- confirm in the earlier MLA stages.
#pragma unroll
    for (int m = 0; m < 2; ++m)
#pragma unroll
        for (int j = 0; j < 4; ++j)
            atomicAdd(&p.olat[(size_t)(h0 + 4 * m) * KVL + cq * 4 + j], acc[m * 4 + j]);
    // Fence, then wait for the whole block, before thread 0 signals that this
    // block's contribution to head-group hg4 is complete.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) atomicAdd(p.done + 104 + hg4, 1u);
    // fused former M5: o = olat through the kv_b v-half, spinning per head-group
    // (47 l-chunk blocks per group; +8 consumer arrivals -> reset at 54)
    // With NBLK == 188 > 32 the loop body runs at most once, for blocks 0..31.
    for (int t = blockIdx.x; t < 32; t += NBLK) {
        const int h = t;
        if (threadIdx.x == 0) {
            // Wait until all 47 producer blocks of this head's group have arrived.
            while (*(volatile unsigned*)(p.done + 104 + (h >> 3)) < 47u) { }
        }
        __syncthreads();
        __threadfence();
        m5_head(p, h);
        if (threadIdx.x == 0) {
            // atomicAdd returns the pre-increment value: the 55th arrival
            // (47 producers + 8 consumers) sees 54 and re-arms the counter.
            const unsigned v2 = atomicAdd(p.done + 104 + (h >> 3), 1u);
            if (v2 == 54u) p.done[104 + (h >> 3)] = 0u;
        }
    }
}

//
================================ megakernel =============================== // #define STAMP() do { if (blockIdx.x == 0 && threadIdx.x == 0) p.stamps[sct] = clock64(); sct++; } while (0) __global__ void __launch_bounds__(NTHR, 1) megakernel(P p) { // the global sense cell persists across launches (27 barriers per step flips // its parity); seed the block-local sense from it or the first barrier of the // next step releases early on the stale value. __shared__ unsigned lsense; if (threadIdx.x == 0) lsense = *(volatile unsigned*)(p.bar + 32); __syncthreads(); int sct = 0; STAMP(); for (int l = 0; l < 4; ++l) { const bf16* hsrc = (l == 0) ? p.hin : p.hid; if (l < 3) { stage_a1(p, l, hsrc); gridbar(p.bar, &lsense); STAMP(); stage_oproj(p, l, 4, hsrc); gridbar(p.bar, &lsense); STAMP(); } else { stage_m1(p, hsrc); gridbar(p.bar, &lsense); STAMP(); stage_m3(p); gridbar(p.bar, &lsense); STAMP(); stage_m4(p); gridbar(p.bar, &lsense); STAMP(); stage_oproj(p, 3, 3, hsrc); gridbar(p.bar, &lsense); STAMP(); } stage_b2(p, l); gridbar(p.bar, &lsense); STAMP(); } } // ================================ host glue ================================ // static std::unordered_map g_models; static int64_t g_next = 1; int64_t mk_setup(std::vector ts) { TORCH_CHECK((int)ts.size() == 4 * 48 + 15, "bad tensor count ", ts.size()); P p{}; size_t i = 0; for (int l = 0; l < 4; ++l) { LayerP& L = p.L[l]; for (int j = 0; j < 5; ++j) L.wq[j] = (const uint8_t*)ts[i++].data_ptr(); for (int j = 0; j < 5; ++j) L.ws[j] = (const float*)ts[i++].data_ptr(); for (int j = 0; j < 5; ++j) L.wb[j] = (const float*)ts[i++].data_ptr(); for (int j = 0; j < 5; ++j) L.sc[j] = (const bf16*)ts[i++].data_ptr(); for (int j = 0; j < 5; ++j) L.zp[j] = (const bf16*)ts[i++].data_ptr(); L.conv = (const bf16*)ts[i++].data_ptr(); L.betaw = (const bf16*)ts[i++].data_ptr(); L.anorm = (const bf16*)ts[i++].data_ptr(); L.mnorm = (const bf16*)ts[i++].data_ptr(); L.router = (const bf16*)ts[i++].data_ptr(); for (int j = 0; j < 6; ++j) 
L.ewq[j] = (const uint8_t*)ts[i++].data_ptr(); for (int j = 0; j < 6; ++j) L.esc[j] = (const bf16*)ts[i++].data_ptr(); for (int j = 0; j < 6; ++j) L.ezp[j] = (const bf16*)ts[i++].data_ptr(); } p.qkv = (bf16*)ts[i++].data_ptr(); p.gexp = (float*)ts[i++].data_ptr(); p.beta = (float*)ts[i++].data_ptr(); p.oatt = (bf16*)ts[i++].data_ptr(); p.q6144 = (bf16*)ts[i++].data_ptr(); p.qabs = (bf16*)ts[i++].data_ptr(); p.olat = (float*)ts[i++].data_ptr(); p.logits = (float*)ts[i++].data_ptr(); p.xm = (float*)ts[i++].data_ptr(); p.partial = (float*)ts[i++].data_ptr(); p.done = (unsigned*)ts[i++].data_ptr(); p.smax = (float*)ts[i++].data_ptr(); p.bar = (unsigned*)ts[i++].data_ptr(); p.hid = (bf16*)ts[i++].data_ptr(); p.stamps = (unsigned long long*)ts[i++].data_ptr(); static bool s_ready = false; if (!s_ready) { cudaError_t e = cudaFuncSetAttribute((const void*)megakernel, cudaFuncAttributeMaxDynamicSharedMemorySize, DSMEM); TORCH_CHECK(e == cudaSuccess, "smem attr: ", cudaGetErrorString(e)); int nb = 0; e = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&nb, (const void*)megakernel, NTHR, DSMEM); TORCH_CHECK(e == cudaSuccess && nb >= 1, "occupancy too low for cooperative launch"); int dev = 0, sms = 0, coop = 0; cudaGetDevice(&dev); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); cudaDeviceGetAttribute(&coop, cudaDevAttrCooperativeLaunch, dev); TORCH_CHECK(coop == 1, "cooperative launch unsupported"); TORCH_CHECK(sms * nb >= NBLK, "grid too large: ", sms, "x", nb, " < ", NBLK); s_ready = true; } const int64_t id = g_next++; g_models[id] = p; return id; } void mk_step(int64_t id, int64_t hin, int64_t ckv, int64_t kro, int64_t csrc, int64_t ksrc, int64_t scores, int64_t pos, int64_t ingest, int64_t S0, int64_t q0, int64_t k0, int64_t v0, int64_t S1, int64_t q1, int64_t k1, int64_t v1, int64_t S2, int64_t q2, int64_t k2, int64_t v2) { P p = g_models[id]; p.hin = (const bf16*)hin; p.ckv = (bf16*)ckv; p.kro = (bf16*)kro; p.ckv_src = (const bf16*)csrc; p.kro_src = 
(const bf16*)ksrc; p.scores = (float*)scores; p.pos = (int)pos; p.ingest = (int)ingest; p.S[0] = (float*)S0; p.cw[0][0] = (bf16*)q0; p.cw[0][1] = (bf16*)k0; p.cw[0][2] = (bf16*)v0; p.S[1] = (float*)S1; p.cw[1][0] = (bf16*)q1; p.cw[1][1] = (bf16*)k1; p.cw[1][2] = (bf16*)v1; p.S[2] = (float*)S2; p.cw[2][0] = (bf16*)q2; p.cw[2][1] = (bf16*)k2; p.cw[2][2] = (bf16*)v2; void* args[] = { &p }; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaError_t e = cudaLaunchCooperativeKernel((const void*)megakernel, dim3(NBLK), dim3(NTHR), args, DSMEM, stream); TORCH_CHECK(e == cudaSuccess, "megakernel launch: ", cudaGetErrorString(e)); } """ _CPP_SRC = r""" #include #include int64_t mk_setup(std::vector ts); void mk_step(int64_t id, int64_t hin, int64_t ckv, int64_t kro, int64_t csrc, int64_t ksrc, int64_t scores, int64_t pos, int64_t ingest, int64_t S0, int64_t q0, int64_t k0, int64_t v0, int64_t S1, int64_t q1, int64_t k1, int64_t v1, int64_t S2, int64_t q2, int64_t k2, int64_t v2); """ _EXT = None def _ext(): global _EXT if _EXT is None: import os _EXT = load_inline( name="kimi_mega_v1", cpp_sources=[_CPP_SRC], cuda_sources=[_CUDA_SRC], functions=["mk_setup", "mk_step"], verbose=os.environ.get("KIMI_MEGA_VERBOSE", "0") == "1", extra_cuda_cflags=["-O3", "-arch=sm_120", "--ptxas-options=-v", "-lineinfo"], ) return _EXT # --------------------------------------------------------------------------- # # module tree (state_dict-compatible with reference.py) # --------------------------------------------------------------------------- # @dataclass(frozen=True) class Config: hidden: int = 2304 kda_heads: int = 32 kda_head_dim: int = 128 short_conv: int = 4 mla_heads: int = 32 kv_lora: int = 512 qk_nope: int = 128 qk_rope: int = 64 v_head: int = 128 rope_theta: float = 10000.0 n_experts: int = 64 n_active: int = 8 n_shared: int = 1 moe_inter: int = 1024 routed_scaling: float = 2.446 group: int = 128 pattern: tuple = ("K", "K", "K", "M") dtype: torch.dtype = 
field(default=torch.bfloat16) def build_config(shape: dict) -> Config: return Config(n_experts=int(shape.get("n_experts", 64))) class QuantLinear(nn.Module): def __init__(self, in_f: int, out_f: int, group: int = GROUP_SIZE): super().__init__() self.in_f, self.out_f, self.group = in_f, out_f, group ng = in_f // group self.register_buffer("w_q", torch.zeros(in_f // 2, out_f, dtype=torch.uint8)) self.register_buffer("scales", torch.zeros(ng, out_f, dtype=torch.bfloat16)) self.register_buffer("zeros", torch.zeros(ng, out_f, dtype=torch.bfloat16)) class QuantExperts(nn.Module): def __init__(self, n: int, in_f: int, out_f: int, group: int = GROUP_SIZE): super().__init__() self.n, self.in_f, self.out_f, self.group = n, in_f, out_f, group ng = in_f // group self.register_buffer("w_q", torch.zeros(n, in_f // 2, out_f, dtype=torch.uint8)) self.register_buffer("scales", torch.zeros(n, ng, out_f, dtype=torch.bfloat16)) self.register_buffer("zeros", torch.zeros(n, ng, out_f, dtype=torch.bfloat16)) class KDA(nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg H, Dk, d = cfg.kda_heads, cfg.kda_head_dim, cfg.hidden self.q_proj = QuantLinear(d, H * Dk, cfg.group) self.k_proj = QuantLinear(d, H * Dk, cfg.group) self.v_proj = QuantLinear(d, H * Dk, cfg.group) self.g_proj = QuantLinear(d, H * Dk, cfg.group) self.beta_proj = nn.Linear(d, H, bias=False, dtype=cfg.dtype) self.conv_w = nn.Parameter(torch.zeros(3, H * Dk, cfg.short_conv, dtype=cfg.dtype)) self.o_proj = QuantLinear(H * Dk, d, cfg.group) self.scale = Dk ** -0.5 class MLA(nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg H, d = cfg.mla_heads, cfg.hidden self.q_proj = QuantLinear(d, H * (cfg.qk_nope + cfg.qk_rope), cfg.group) self.kv_a = QuantLinear(d, cfg.kv_lora + cfg.qk_rope, cfg.group) self.kv_b = QuantLinear(cfg.kv_lora, H * (cfg.qk_nope + cfg.v_head), cfg.group) self.o_proj = QuantLinear(H * cfg.v_head, d, cfg.group) self.scale = (cfg.qk_nope + cfg.qk_rope) ** -0.5 
class MoE(nn.Module):
    """Parameter container for one MoE feed-forward (weights only, no forward).

    Holds the bf16 router plus int4 gate/up/down stacks for the ``n_experts``
    routed experts and for the ``n_shared`` shared experts.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        d, m, E = cfg.hidden, cfg.moe_inter, cfg.n_experts
        self.router = nn.Linear(d, E, bias=False, dtype=cfg.dtype)
        self.gate = QuantExperts(E, d, m, cfg.group)
        self.up = QuantExperts(E, d, m, cfg.group)
        self.down = QuantExperts(E, m, d, cfg.group)
        self.s_gate = QuantExperts(cfg.n_shared, d, m, cfg.group)
        self.s_up = QuantExperts(cfg.n_shared, d, m, cfg.group)
        self.s_down = QuantExperts(cfg.n_shared, m, d, cfg.group)


class Block(nn.Module):
    """One layer: two norm weight vectors, an attention module and an MoE.

    ``kind == "K"`` selects ``KDA``; any other value selects ``MLA``.
    """

    def __init__(self, cfg: Config, kind: str):
        super().__init__()
        self.kind = kind
        self.attn_norm = nn.Parameter(torch.ones(cfg.hidden, dtype=cfg.dtype))
        self.moe_norm = nn.Parameter(torch.ones(cfg.hidden, dtype=cfg.dtype))
        self.attn = KDA(cfg) if kind == "K" else MLA(cfg)
        self.moe = MoE(cfg)


class Model(nn.Module):
    """Four-block model whose ``step`` runs through the compiled extension.

    The module tree only stores weights.  ``_prepare`` hands their device
    pointers to the extension once; the instance also owns the growable MLA
    cache buffers (``_ckv``, ``_kro``, ``_scores``).
    """

    def __init__(self, cfg: Config):
        super().__init__()
        # The kernel source hard-codes these sizes, so reject any other config.
        assert cfg.hidden == 2304 and cfg.kda_heads == 32 and cfg.kda_head_dim == 128
        assert cfg.n_experts == 64 and cfg.n_active == 8 and cfg.n_shared == 1
        assert cfg.moe_inter == 1024 and cfg.kv_lora == 512 and cfg.qk_rope == 64
        assert cfg.qk_nope == 128 and cfg.v_head == 128 and cfg.short_conv == 4
        assert tuple(cfg.pattern) == ("K", "K", "K", "M")
        self.cfg = cfg
        self.blocks = nn.ModuleList(Block(cfg, k) for k in cfg.pattern)
        self._ready = False      # True once _prepare() has registered the pointers
        self._id = None          # handle returned by mk_setup
        self._ckv = None         # capacity buffer for the latent cache
        self._kro = None         # capacity buffer for the rope-key cache
        self._scores = None      # scratch for attention scores, sized with the cache
        self._explen = -1        # cache length expected on the next in-place step
        self._ckv_ptr = -1       # data_ptr of _ckv, used to recognise our own views
        self._kro_ptr = -1       # data_ptr of _kro
        # Loading a state dict forces a fresh _prepare() on the next step().
        self.register_load_state_dict_post_hook(Model._invalidate)

    @staticmethod
    def _invalidate(module, incompatible_keys):
        """Post-load hook: mark the registered pointer table as stale."""
        module._ready = False

    # ------------------------------------------------------------------ #
    def _prepare(self) -> None:
        """Flatten weights and scratch buffers and register them with the extension.

        The order of ``flat`` is positional: per block 5 ``wq``, 5 ``ws``,
        5 ``wb``, 5 ``sc``, 5 ``zp``, then conv, betaw, attn_norm, moe_norm,
        router, then 6 ``ewq``, 6 ``esc``, 6 ``ezp``; finally the 15 scratch
        tensors.  ``mk_setup`` reads them back in exactly this order.

        Raises:
            AssertionError: if the model is not on CUDA or a tensor is not a
                contiguous CUDA tensor.
        """
        dev = self.blocks[0].attn_norm.device
        assert dev.type == "cuda", "solution requires CUDA"
        hold = []
        flat = []
        for blk in self.blocks:
            a = blk.attn
            if blk.kind == "K":
                mats = [a.q_proj, a.k_proj, a.v_proj, a.g_proj, a.o_proj]
                conv, betaw = a.conv_w, a.beta_proj.weight
            else:
                # MLA has four matrices; q_proj is repeated to fill the fifth slot.
                mats = [a.q_proj, a.kv_a, a.kv_b, a.o_proj, a.q_proj]
                conv, betaw = blk.attn_norm, blk.attn_norm  # unused fillers
            wq = [m.w_q for m in mats]
            # fp32 scale and fp32 bias (-scale * zero) per group and column.
            ws = [m.scales.float().contiguous() for m in mats]
            wb = [(-(m.scales.float() * m.zeros.float())).contiguous() for m in mats]
            sc = [m.scales for m in mats]
            # Zero points shifted by 128 and stored as bf16.
            zp = [(m.zeros.float() + 128.0).to(torch.bfloat16).contiguous() for m in mats]
            ex = [blk.moe.gate, blk.moe.up, blk.moe.down,
                  blk.moe.s_gate, blk.moe.s_up, blk.moe.s_down]
            ewq = [e.w_q for e in ex]
            esc = [e.scales for e in ex]
            ezp = [(e.zeros.float() + 128.0).to(torch.bfloat16).contiguous() for e in ex]
            flat += wq + ws + wb + sc + zp
            flat += [conv, betaw, blk.attn_norm, blk.moe_norm, blk.moe.router.weight]
            flat += ewq + esc + ezp
            # Derived tensors are referenced only from here; keep them alive.
            hold += ws + wb + zp + ezp
        bf, f32 = torch.bfloat16, torch.float32
        # Scratch buffers used by the kernel.  "done" and "bar" start zeroed;
        # "hid" is also the tensor step() returns.
        sc = {
            "qkv": torch.empty(3 * 4096, dtype=bf, device=dev),
            "gexp": torch.empty(4096, dtype=f32, device=dev),
            "beta": torch.empty(32, dtype=f32, device=dev),
            "oatt": torch.empty(4096, dtype=bf, device=dev),
            "q6144": torch.empty(6144, dtype=bf, device=dev),
            "qabs": torch.empty(32 * 512, dtype=bf, device=dev),
            "olat": torch.empty(32 * 512, dtype=f32, device=dev),
            "logits": torch.empty(64, dtype=f32, device=dev),
            "xm": torch.empty(9 * 1024, dtype=f32, device=dev),
            "partial": torch.empty(131072, dtype=f32, device=dev),
            "done": torch.zeros(256, dtype=torch.int32, device=dev),
            "smax": torch.empty(NBLK * 64, dtype=f32, device=dev),
            "bar": torch.zeros(64, dtype=torch.int32, device=dev),
            "hid": torch.empty(2304, dtype=bf, device=dev),
            "stamps": torch.zeros(64, dtype=torch.int64, device=dev),
        }
        flat += [sc[k] for k in ("qkv", "gexp", "beta", "oatt", "q6144", "qabs", "olat",
                                 "logits", "xm", "partial", "done", "smax", "bar", "hid",
                                 "stamps")]
        # The extension stores raw pointers and assumes dense layouts.
        for t in flat:
            assert t.is_cuda and t.is_contiguous()
        self._hold = hold
        self._scratch = sc
        self._id = _ext().mk_setup(flat)
        self._dev = dev
        # cache invalidated on (re)prepare
        self._explen = -1
        self._ckv_ptr = -1
        self._kro_ptr = -1
        self._ready = True

    #
------------------------------------------------------------------ # def step(self, hidden, state): if not self._ready: self._prepare() mla = state[3] ckv_in = mla["c_kv"] kro_in = mla["k_rope"] L = ckv_in.shape[0] if (L == self._explen and ckv_in.data_ptr() == self._ckv_ptr and kro_in.data_ptr() == self._kro_ptr): ingest = 0 src_c = 0 src_k = 0 else: # adopt external state (first step of a run) assert ckv_in.is_contiguous() and kro_in.is_contiguous() assert ckv_in.dtype == torch.bfloat16 and ckv_in.shape[1] == 512 cap = L + 1 + CACHE_MARGIN if self._ckv is None or self._ckv.shape[0] < cap: self._ckv = torch.empty(cap, 512, dtype=torch.bfloat16, device=self._dev) self._kro = torch.empty(cap, 64, dtype=torch.bfloat16, device=self._dev) self._scores = torch.empty(cap * 32, dtype=torch.float32, device=self._dev) ingest = 1 src_c = ckv_in.data_ptr() src_k = kro_in.data_ptr() self._ckv_ptr = self._ckv.data_ptr() self._kro_ptr = self._kro.data_ptr() if hidden.dtype != torch.bfloat16 or not hidden.is_contiguous(): hidden = hidden.to(torch.bfloat16).contiguous() s0, s1, s2 = state[0], state[1], state[2] _ext().mk_step( self._id, hidden.data_ptr(), self._ckv_ptr, self._kro_ptr, src_c, src_k, self._scores.data_ptr(), L, ingest, s0["S"].data_ptr(), s0["cq"].data_ptr(), s0["ck"].data_ptr(), s0["cv"].data_ptr(), s1["S"].data_ptr(), s1["cq"].data_ptr(), s1["ck"].data_ptr(), s1["cv"].data_ptr(), s2["S"].data_ptr(), s2["cq"].data_ptr(), s2["ck"].data_ptr(), s2["cv"].data_ptr(), ) self._explen = L + 1 mla["c_kv"] = self._ckv[: L + 1] mla["k_rope"] = self._kro[: L + 1] return self._scratch["hid"], state # --------------------------------------------------------------------------- # # state / input helpers (same contract as reference.py) # --------------------------------------------------------------------------- # def init_state(cfg: Config, context_len: int, seed: int) -> list: dev = torch.device("cuda:0") g = torch.Generator(device=dev).manual_seed(seed) H, Dk = cfg.kda_heads, 
cfg.kda_head_dim C = H * Dk state = [] for kind in cfg.pattern: if kind == "K": state.append({ "S": torch.randn(H, Dk, Dk, device=dev, generator=g) * 0.05, "cq": torch.randn(cfg.short_conv - 1, C, device=dev, generator=g, dtype=cfg.dtype) * 0.1, "ck": torch.randn(cfg.short_conv - 1, C, device=dev, generator=g, dtype=cfg.dtype) * 0.1, "cv": torch.randn(cfg.short_conv - 1, C, device=dev, generator=g, dtype=cfg.dtype) * 0.1, }) else: state.append({ "c_kv": torch.randn(context_len, cfg.kv_lora, device=dev, generator=g, dtype=cfg.dtype) * 0.1, "k_rope": torch.randn(context_len, cfg.qk_rope, device=dev, generator=g, dtype=cfg.dtype) * 0.1, }) return state def init_token(cfg: Config, seed: int) -> torch.Tensor: dev = torch.device("cuda:0") g = torch.Generator(device=dev).manual_seed(seed + 1) return torch.randn(cfg.hidden, device=dev, generator=g, dtype=cfg.dtype) * 0.25