The information contained in this blog represents the opinion of Evergrid.AI as of the date presented. AMD and/or Evergrid.AI have no obligation to update any forward-looking content in the blog. AMD is not responsible for the content of this third-party blog and does not necessarily endorse the comments/claims/information presented therein. GD-84.
MoE models mimic the low-power-consumption pattern of the human brain: functions are divided into specialized regions, and only part of them is activated via adaptive routing when thinking.
Human-brain cortex, from an Oxford University research paper (archived from the internet)
The first truly workable MoE version in CUDA was Switch Transformer [1], later improved by Mistral [2] through upcycling dense models:
SwitchTransformer-MoE
Later, DeepSeek V2/V3/R1 [3][4][5] improved MoE by introducing shared experts [3] and gating bias [4][5], which finally led to auxiliary-loss-free MoE models [4][5]. This is essentially because, when shared experts are used (the DeepSeek team chose one), the expert-routing imbalance problem can be mitigated by penalizing a bias score over a large pool of experts (256) [11].
The MoE layer is implemented as multiple expert FFN layers plus a gating function that routes activations according to top-k gating scores (with bias in DeepSeek V3/R1), and it produces logits by running Group GEMM over the selected FFN layers.
The operation relies heavily on radix-sorting logic underneath. With MoE Align & Sort, ML researchers and practitioners can sort tokens in the order of their expert IDs.
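To make the semantics concrete, below is a minimal host-side reference sketch (not the SGLang kernel itself; the function name and padding sentinel are illustrative choices of ours) of what MoE Align & Sort produces: token indices grouped by expert ID, with each expert's segment padded to a multiple of the Group GEMM block size.

```cpp
// Illustrative host-side reference of MoE Align & Sort semantics (not the GPU kernel).
// Tokens (after top-k routing) are grouped by expert ID, and every expert's segment is
// padded to a multiple of `block_size` so Group GEMM can consume whole blocks per expert.
#include <cstdint>
#include <vector>

void moe_align_block_size_ref(const std::vector<int32_t>& topk_ids,  // flattened [num_tokens * top_k]
                              int num_experts, int block_size,
                              std::vector<int32_t>& sorted_token_ids, // token slots ordered by expert, padded
                              std::vector<int32_t>& expert_ids,       // owning expert of each block
                              int32_t& num_tokens_post_pad) {
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) counts[e]++;                    // 1) histogram of tokens per expert
  std::vector<int32_t> offsets(num_experts + 1, 0);
  num_tokens_post_pad = 0;
  for (int e = 0; e < num_experts; ++e) {                    // 2) padded prefix sum (cumsum)
    int padded = (counts[e] + block_size - 1) / block_size * block_size;
    offsets[e] = num_tokens_post_pad;
    num_tokens_post_pad += padded;
  }
  offsets[num_experts] = num_tokens_post_pad;
  sorted_token_ids.assign(num_tokens_post_pad, (int32_t)topk_ids.size()); // pad slots point past the end
  expert_ids.assign(num_tokens_post_pad / block_size, 0);
  for (int e = 0; e < num_experts; ++e)
    for (int b = offsets[e] / block_size; b < offsets[e + 1] / block_size; ++b)
      expert_ids[b] = e;                                     // each block belongs to exactly one expert
  std::vector<int32_t> cursor(offsets.begin(), offsets.end() - 1);
  for (int32_t i = 0; i < (int32_t)topk_ids.size(); ++i)     // 3) scatter token indices by expert
    sorted_token_ids[cursor[topk_ids[i]]++] = i;
}
```

The GPU kernels discussed below parallelize exactly these three steps (histogram, cumsum, placement) across threads and blocks.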
In some applications, such as TransformerEngine [6][7], the operation was implemented with the deprecated cub::DeviceRadixSort, and permute was implemented to record the src (left) to dest (right) mapping, whose gradient is unpermute.
MoE permute illustration
cub::DeviceRadixSort relies heavily on shared memory, which makes it slightly slower than an implementation based on __shfl_xor_sync that keeps data in thread-local registers; more importantly, it does not allow alignment sorting.
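For reference, the shuffle-based approach keeps partial results entirely in registers. A minimal warp-level butterfly reduction with __shfl_xor_sync (illustrative only, written for a 32-lane NVIDIA warp) looks like this:

```cpp
// Minimal warp-level butterfly (XOR) reduction: every lane ends up holding the warp sum,
// using only registers -- no shared memory traffic is involved.
__device__ __forceinline__ int warp_reduce_sum(int val) {
  constexpr unsigned kFullMask = 0xffffffffu;   // 32-lane warp on NVIDIA GPUs
  #pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(kFullMask, val, offset);
  return val;
}
```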
Alignment sorting is important for Group GEMM efficiency, since it lets experts process tokens in whole blocks.
The MoE Align & Sort algorithm in SGLang employed alignment sorting, yet it was not efficient when serving large-scale prefill operations for MoE models with up to 256 experts. The problem was identified in issue#2732. The then-current implementation split MoE Align & Sort into two kernel launches:
We propose and implement AMD-friendly CUDA kernels based on our MoE Align & Sort algorithm, so profiling and analysis on the AMD platform are fully considered.
Using RocProfiler-Compute on different workloads, we can clearly see in a trace profile that the first kernel takes about 330K cycles and the second kernel about 80K cycles, even without counting the overhead of multiple kernel launches:
MoE align kernel 1
MoE align kernel 2
In ROCm SDK 6.3.0, omniperf has been rebranded as rocprof-compute. Despite actively supporting AMD Instinct MI300X/MI300A GPUs, it is not shipped with ROCm SDK 6.3.0 by default. But setting up the ROCm compute profiler takes nothing more than three simple steps, as demonstrated in Tools-dockerhub.
Now, after applying the optimization we proposed in PR#3613, the on-chip overhead is immediately reduced from the previous 410K cycles to 200K cycles:
Enable efficient multi-blocks MoE-align execution in SGLang
By fully enabling concurrent multi-block execution with arbitrary expert numbers (MAX_EXPERT_NUMBER==256), and with aggressive use of shared memory (5 kB LDS) and registers (52 VGPRs, 48 SGPRs), the MoE Align & Sort logic was crafted to achieve speedups of up to 3x~7x over the Triton baseline:
Optimization benchmark (all cases): A100
Optimization benchmark (all cases): MI100 (gfx908)
With Rocprof-Compute, we can easily collect key indicators for a captured kernel and visualize them on a remote GUI server:
Start Rocprof-Compute in server side
To summarize, on AMD Instinct MI300X/MI300A GPUs, the proposed efficient multi-block MoE Align & Sort algorithm aggressively uses both vector registers (52 per wave) with no register spills (the initial thread-block size was tuned to its best value), and LDS (5 kB per CU) with only a 6.8% bank-conflict rate.
We also analyzed the roofline model of MoE Align & Sort. The roofline model shows that the kernel operates in the memory-bound region.
In the AMD Compute Profile section, we give details of the profiling data and an analysis of our algorithm design on the ROCm platform.
Essentially, MI300X/MI300A is the world's first high-performance AI accelerator architecture based on a multi-die design. As a result, fine-tuning operations on this chip differs slightly from doing so on NVIDIA's platform.
The fundamental rule is that synchronization among XCDs (Accelerated Computing Dies) is costly; it is better to make full use of the XCDs and their L2 cache locality affinity to increase performance.
We should avoid expensive synchronization by either using the lowest-speed compute die (XCD7 for MI300X, XCD5 for MI300A) when the grid size is smaller than the number of XCDs per chip (8 for MI300X, 6 for MI300A), or adapting the grid size to a multiple of the number of XCDs per chip when it exceeds that threshold.
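As a sketch of the second rule, a hypothetical helper (the constant and function name below are ours, not SGLang's) that pads the grid size to a multiple of the XCD count could look like:

```cpp
// Hypothetical helper: round the grid size up to a multiple of the number of XCDs so
// blocks are spread evenly across dies. kNumXcds = 8 applies to MI300X; MI300A has 6 XCDs.
constexpr int kNumXcds = 8;

inline int adapt_grid_size(int grid_size) {
  if (grid_size <= kNumXcds) return grid_size;              // small grids: keep the launch as-is
  return (grid_size + kNumXcds - 1) / kNumXcds * kNumXcds;  // otherwise pad to a multiple of XCDs
}
```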
Launching cooperative kernels with hipCooperativeLaunch may increase L2 cache pressure (related to the texture addresser stall rate and busy rate) when data exchange among blocks (especially die-to-die exchange) increases.
In this example, the implementation from the previous main branch uses 39 active CUs, which is nearly reasonable since essentially two dies were used.
Our implementation uses 66 active CUs in multi-block execution spanning two dies, so die-to-die exchange is inevitable in the block-wise reduction. We will submit a further V4 optimization to SGLang later this quarter.
Details will be further discussed in profiling section.
The SGLang team used a Triton-first approach to implement the logic and achieved great success with day-0 support of DeepSeek V3 in December 2024.
SGLang's MoE launches a fused MoE kernel implemented in Triton.
Before the kernel launch, the MoE Align & Sort algorithm is applied. The MoE Align & Sort Triton kernel is split into four phases, where direct accesses to DRAM are employed without shared memory, in contrast to the vectorized Triton version.
Multiple launches and inefficient use of LDS, local caches, and registers (VGPRs, for example) lead to inefficient execution of a single small workload compared with the single-block CUDA counterpart.
The CUDA implementation was then finally split into two phases, with only the second phase accelerated across multiple blocks.
FasterTransformer
Before Mistral [2] and DeepSeek V2 [3], open dense models were more popular in inference scenarios. This was when FasterTransformer [8] was born.
In the FasterTransformer [8] project, initiated by NVIDIA, MoE models are supported essentially via cub::DeviceRadixSort and kernels such as moe_softmax (which is essentially softmax built on cub::BlockReduce), moe_top_k and its fused version topk_gating_softmax, permute to reorder latent vector logits, and finally group GEMM.
Hence fusion is largely (in terms of cost) limited to top-k gating softmax and biased top-k gating softmax, which were later incorporated into SGLang.
Megatron
Megatron, before the publication of this article, largely uses the FasterTransformer approach for FP16/BF16, but adds the gradient operation of permute, namely unpermute, to facilitate training workloads.
That means MoE is also not efficiently fused.
vLLM
SGLang uses many vLLM kernels, but vLLM's fused MoE was initially contributed by the SGLang team, so they deploy the same approach.
CK
The first version of an AMD-friendly fused MoE was proposed in CK#1634 on Nov 26, 2024. Later, MoE Align & Sort was added in CK#1771 and CK#1840.
The high-level idea is to fuse MoE sorting with Group GEMM. MoE sorting in CK largely employs the SGLang team's approach, except for the CK pipeliner and partitioner.
CK fused MoE High Level Idea[9]
Fusing per_group_token_quant (for online fp8 quantization), MoE sorting, and Group GEMM can be resolved immediately by incorporating radix-sort computing logic into the Group GEMM pipeliner: count occurrences to compute offsets, followed by parallel placement.
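A stripped-down sketch of that counting-sort pattern, reduced to a single hypothetical thread block (not the CK pipeline itself), is shown below:

```cpp
// Simplified counting-sort pass over expert IDs (illustrative only, single thread block):
// 1) count occurrences per expert, 2) exclusive prefix sum to get segment offsets,
// 3) place each token index into its expert's segment in parallel.
// Launch with dynamic shared memory of num_experts * sizeof(int).
__global__ void counting_sort_by_expert(const int* __restrict__ expert_ids, int num_tokens,
                                        int num_experts, int* __restrict__ sorted_ids,
                                        int* __restrict__ offsets /* size num_experts + 1 */) {
  extern __shared__ int smem[];                  // per-expert counters, later reused as cursors
  for (int e = threadIdx.x; e < num_experts; e += blockDim.x) smem[e] = 0;
  __syncthreads();
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x)
    atomicAdd(&smem[expert_ids[i]], 1);          // 1) histogram
  __syncthreads();
  if (threadIdx.x == 0) {                        // 2) serial exclusive scan (fine for <= 256 experts)
    int running = 0;
    for (int e = 0; e < num_experts; ++e) {
      offsets[e] = running; running += smem[e]; smem[e] = offsets[e];
    }
    offsets[num_experts] = running;
  }
  __syncthreads();
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
    int slot = atomicAdd(&smem[expert_ids[i]], 1);  // 3) parallel placement via per-expert cursor
    sorted_ids[slot] = i;
  }
}
```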
One of the most critical problems is how the two kinds of workloads (radix sorting and Group GEMM) are balanced.
In AMD data-center chips, a Group GEMM fragment is more likely to be evenly distributed to all the available blocks within an XCD, while data exchange among blocks on different CUs goes through the low-speed L2 cache and the L2 cache fabric if multiple XCDs are involved.
Writing CK kernels requires writing the host-side CK solution launcher, the device entry of the kernel, the tile partitioner, and the stage pipeliner.
How the AMD CK partitioner and stage pipeliner for fused MoE translate into the final assembly is also very interesting, yet it is out of the scope of this article.
But just remember that its MoE Align & Sort is part of the producer code:
So MoE Align & Sort in the AMD CK solution almost aligns with the SGLang main implementation, except for the partitioner and pipeliner.
Note that this implementation does not always promise the best performance on the AMD platform (see the asm MoE in AITER).
AMD's CDNA3 architecture does not support Graphcore-like on-chip shuffling magics (we abstracted and generalized on-chip shuffling as the Remapping Op of PopART [12] & PopRT in 2023), which are now supported on NVIDIA H100/H200/B200 through highly efficient on-chip SM-to-SM communication.
As a result, cheaply adapting the data layout among blocks to its best form will be a very interesting part of the AMD open-source solution.
Hence, philosophically, tile-based fusion code for these two different workloads may not always exceed the non-fused version. Details of this research will be presented in our V4 release.
AITER
AI Tensor Engine For ROCm[10]
AITER was introduced earlier this year to incorporate LLM kernels used across different projects. It supports fused MoE via CK MoE, an asm version of MoE via hipModule, and a Triton fused MoE.
Hence it is only partially open source, given the opaque assembly and the development schedule for MI300X developers.
The claimed 3x acceleration [10] of fused MoE in AITER was verified by Bruce Xu [13] and essentially comes from the acceleration observed in group GEMMs of different shapes: GEMMs where each expert's FFN weights multiply a block of hidden states of tokens.
The proof is that the asm GEMM delivers almost a 3x improvement in PR#199:
ASM Flat Matrix Multiply
Notably, there are still cases where Triton kernels adapted from the SGLang community are selected. To run Triton kernels efficiently on MI300X/MI300A, thread blocks are mapped onto dies using multi-die-architecture-specific logic:
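A common trick on MI300-class GPUs, sketched below with hypothetical names (the actual AITER/Triton logic differs in detail), is to remap the block index so that logically adjacent blocks land on the same XCD and therefore share that die's L2 cache:

```cpp
// Hypothetical block-index swizzle for a multi-die (chiplet) GPU. Hardware dispatches
// consecutive block IDs round-robin across XCDs; this remap groups consecutive logical
// blocks onto the same XCD so they share that die's L2 cache.
// Assumes grid_size has been padded to a multiple of num_xcds (see the rule above).
__device__ __forceinline__ int remap_block_id_for_xcd(int bid, int grid_size, int num_xcds) {
  int xcd     = bid % num_xcds;                    // die this block is physically dispatched to
  int slot    = bid / num_xcds;                    // its position within that die
  int per_xcd = grid_size / num_xcds;              // blocks per die
  return xcd * per_xcd + slot;                     // logical ID: a contiguous range per die
}
```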
Besides, various AMD chip intrinsics have been used in CK fused MoE, such as
and so forth. These presumably contribute to the final assembly version of fused MoE.
For example, using __builtin_nontemporal_load, we can skip the L2 cache and leave more space in the L2 cache lines for data predicted to be reused.
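A minimal illustration (assuming compilation with hipcc under ROCm, where these Clang builtins are available in device code) might look like:

```cpp
// Streaming copy using Clang's non-temporal builtins (available in HIP device code):
// loads and stores are marked non-temporal so the data does not pollute the caches,
// leaving cache capacity for values that will actually be reused.
__global__ void streaming_copy(const float* __restrict__ src, float* __restrict__ dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = __builtin_nontemporal_load(&src[i]);  // hint: do not keep this line in cache
    __builtin_nontemporal_store(v, &dst[i]);        // hint: write-around store
  }
}
```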
Cutlass v3.8
Fused MoE is not yet publicly supported in NVIDIA Cutlass 3.8.0 at the time of writing this article.
Hence no MoE Align & Sort is available in this repo.
TRT-LLM
Before v0.16.0, TRT-LLM basically follows the FasterTransformer approach. Since v0.17.0, the MoE part has been disclosed.
The algorithm employs a multi-block execution scheme and consists of three different sections (D-C-P):
Our proposed efficient multi-block MoE Align & Sort algorithm
Parallel unaligned local cumsum
Our proposed parallel local unaligned cumsum
The algorithm was first proposed and implemented by us in PR#2970.
We load-balance the cumsum execution in each block across kElementsPerThr (16) threads, where kElementsPerThr + kElementsPerThr + threadIdx.x add operations need to be processed per thread.
Hence the wavefront finishes faster than in the single-thread version of the current repo, and we observed a 30% improvement with this version of the implementation.
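Below is a simplified sketch of the idea (not the exact PR#2970 kernel; constants and names are illustrative):

```cpp
// Simplified sketch of a per-block "unaligned local cumsum": each thread owns a chunk of
// kElementsPerThr counters, scans it sequentially, then shifts it by the totals of the chunks
// owned by lower-ranked threads. That is roughly kElementsPerThr + kElementsPerThr + threadIdx.x
// additions per thread, instead of one thread scanning the whole array.
constexpr int kElementsPerThr = 16;

__device__ void local_cumsum(int* data, int num_elems) {
  __shared__ int chunk_total[256];                 // one slot per thread; assumes blockDim.x <= 256
  int begin = threadIdx.x * kElementsPerThr;
  int end   = min(begin + kElementsPerThr, num_elems);
  int sum = 0;
  for (int i = begin; i < end; ++i) { sum += data[i]; data[i] = sum; }  // scan own chunk
  chunk_total[threadIdx.x] = sum;
  __syncthreads();
  int offset = 0;
  for (int t = 0; t < threadIdx.x; ++t) offset += chunk_total[t];       // threadIdx.x extra additions
  for (int i = begin; i < end; ++i) data[i] += offset;                  // shift own chunk by the offset
}
```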
Reduce unaligned cumsum
Once we get the local unaligned cumsum in each block, we proceed to a block-wise reduction over the cumsums stored in the pre-allocated HBM buffer.
We chose FRAG_SIZE_M (16) x FRAG_SIZE_N (16) x FRAGS_PER_BLOCK (4) SRAM fragments for the block-wise reduction, where FRAGS_PER_BLOCK is tunable:
Block-Wise Reduction
On the AMD platform, the calculation is performed on a 1-warp-to-load / 1-warp-to-compute basis, while on the NVIDIA platform it is 2 warps to load and 1 warp to compute.
The design takes full advantage of AMD's 64 SIMD lanes in the CDNA3 architecture. And the number of blocks is always a multiple of the number of XCDs in this multi-die chip.
FRAGS_PER_BLOCK was set to 4 to facilitate the reuse of SMEM across multiple rounds.
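A heavily simplified sketch of this stage (illustrative only; the production kernel additionally pipelines the loading and computing warps as described above, and goes on to form the aligned global cumsum) is:

```cpp
// Simplified sketch of the block-wise reduction stage: the per-block local cumsums were
// written to a pre-allocated HBM buffer local[num_blocks][num_experts]; here one block
// reduces them over the block dimension through FRAG_M x FRAG_N shared-memory fragments.
constexpr int FRAG_M = 16;   // rows (producer blocks) per fragment
constexpr int FRAG_N = 16;   // columns (experts) per fragment

__global__ void reduce_local_sums(const int* __restrict__ local, int num_blocks,
                                  int num_experts, int* __restrict__ per_expert_total) {
  __shared__ int frag[FRAG_M][FRAG_N];
  for (int col0 = 0; col0 < num_experts; col0 += FRAG_N) {       // sweep expert columns
    int acc = 0;
    for (int row0 = 0; row0 < num_blocks; row0 += FRAG_M) {      // sweep producer-block rows
      // cooperative load of one FRAG_M x FRAG_N tile into LDS
      for (int idx = threadIdx.x; idx < FRAG_M * FRAG_N; idx += blockDim.x) {
        int r = row0 + idx / FRAG_N, c = col0 + idx % FRAG_N;
        frag[idx / FRAG_N][idx % FRAG_N] =
            (r < num_blocks && c < num_experts) ? local[r * num_experts + c] : 0;
      }
      __syncthreads();
      if (threadIdx.x < FRAG_N)                                   // one lane per expert column
        for (int r = 0; r < FRAG_M; ++r) acc += frag[r][threadIdx.x];
      __syncthreads();                                            // fragment can be overwritten next round
    }
    if (threadIdx.x < FRAG_N && col0 + threadIdx.x < num_experts)
      per_expert_total[col0 + threadIdx.x] = acc;
  }
}
```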
Align global cumsum & store global cumsum
We improved the vectorization code and take care of loop tails when the input data size is not aligned with the kElementsPerAccess constant.
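The pattern, in a simplified illustrative form (the helper below is not the exact kernel code), is:

```cpp
// Sketch of a vectorized store with tail handling: the bulk of the output is written through
// 128-bit (int4) stores, and the remainder that is not a multiple of kElementsPerAccess
// falls back to scalar stores. Assumes dst is 16-byte aligned.
constexpr int kElementsPerAccess = 4;        // 4 x int32 = one 128-bit access

__device__ void fill_vectorized(int* __restrict__ dst, int value, int n) {
  int4 packed = make_int4(value, value, value, value);
  int n_vec = n / kElementsPerAccess * kElementsPerAccess;
  for (int i = threadIdx.x * kElementsPerAccess; i < n_vec;
       i += blockDim.x * kElementsPerAccess)
    *reinterpret_cast<int4*>(dst + i) = packed;        // coalesced 128-bit stores
  for (int i = n_vec + threadIdx.x; i < n; i += blockDim.x)
    dst[i] = value;                                     // scalar loop tail
}
```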
The benchmarks show the coalescing rate is improved but still limited to 30%. We will work on it in the V4 release.
Writing AMD-friendly CUDA
Writing a PyTorch extension enables automatic translation of a CUDA kernel to a HIP kernel with the ROCm SDK.
However, there are cases where the HIP kernel works differently from the CUDA kernel:
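The most common difference is the wavefront width: 64 lanes on AMD CDNA GPUs versus 32 on NVIDIA. A minimal sketch of guarding warp-level code so it behaves correctly on both platforms:

```cpp
// One concrete difference after hipification: the wavefront is 64 lanes wide on AMD CDNA
// GPUs (vs. 32 on NVIDIA), so ballots are 64-bit and any lane mask or "per-warp" sizing
// hard-coded for 32 lanes silently breaks under ROCm.
__device__ int count_active_lanes(bool pred) {
#if defined(__HIP_PLATFORM_AMD__)
  unsigned long long mask = __ballot(pred);            // 64-bit ballot over a 64-lane wavefront
  return __popcll(mask);
#else
  unsigned mask = __ballot_sync(0xffffffffu, pred);    // 32-bit ballot over a 32-lane warp
  return __popc(mask);
#endif
}
```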
We conducted extensive tests without CUDA graph capture for large workloads of DeepSeek V3 models; hence the number of experts was set to 256. The algorithm currently does not support running under CUDA graph capture, and we will resolve this issue later in the V4 release.
Due to the virtualization of GPU machines and the number of CPUs allocated to the test node, performance may vary from run to run compared with bare-metal tests.
Hence, we use the Triton implementation as the baseline to demonstrate the acceleration multiple and the efficiency of our proposed MoE Align & Sort algorithm.
Each test was verified before benchmarking. During the benchmark, we observed that Triton on the AMD platform ran significantly longer than on NVIDIA at the time of testing. We hence recommend further optimization of Triton's MLIR for a more efficient lowering process compared with NVIDIA Triton.
For AMD Triton, we observed that MI300X is about 1.5x faster than MI100, hence the improvement multiple on MI300X is not as significant as on MI100. Moreover, even though MI300X is generally believed to be faster than MI100, in our tests the algorithm performed better on MI100 than on MI300X.
This is partially attributed to the fact that, for a memory-bound op on a multi-die chip, communication among the dies lowers the execution speed.
On both platforms we observed significant improvements after applying our proposed algorithm, whereas the existing CUDA implementation cost almost the same time as Triton.
AMD system preparation
In order to make the best use of an AMD heterogeneous system, it is recommended to do some checks first.
Benchmark on MI100
git clone https://github.com/yiakwy-xpu-ml-framework-team/AMD-sglang-benchmark-fork.git -b optimize_moe_align_v3 && cd sgl-kernel && python setup_rocm.py install
Correctness across different combinations of input token counts and expert counts can be verified:
cd ../benchmark/kernels/fused_moe_trition && python benchmark_deepseekv3_moe_align_blocks.py --verify
| num_tokens | experts | SGLang | Triton (AMD) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 79.36 | 426.71 | MI100 |
| 16384 | 256 | 86.4 | 681.12 | MI100 |
| 16384 x 128 | 256 | 3047.68 | 62442.85 | MI100 |
| 32768 x 128 | 256 | 7211.37 | 129388.43 | MI100 |
Benchmark on A100
| num_tokens | experts | SGLang | Triton (NV) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 77.44 | 124.92 | A100 |
| 16384 | 256 | \ | \ | A100 |
| 16384 x 128 | 256 | 5966.81 | 17396.51 | A100 |
| 32768 x 128 | 256 | 12450.05 | 34711.14 | A100 |
Benchmark on H200
| num_tokens | experts | SGLang | Triton (NV) | GPU |
|---|---|---|---|---|
| 8192 | 256 | \ | \ | H200 |
| 16384 | 256 | \ | \ | H200 |
| 16384 x 128 | 256 | 4508.42 | 12361.15 | H200 |
| 32768 x 128 | 256 | 9023.48 | 24683.70 | H200 |
Benchmark on MI300X
| num_tokens | experts | SGLang | Triton (AMD) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 88.16 | 281.64 | MI300X |
| 16384 | 256 | 134.02 | 448.88 | MI300X |
| 16384 x 128 | 256 | 6865.64 | 43266.09 | MI300X |
| 32768 x 128 | 256 | 13431.80 | 89788.58 | MI300X |
Setup
In ROCm 6.3.3, setting up rocprof-compute can be done in three easy steps; details can be found here: https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/tree/main
Profiling Results of Vector L1 Cache
The workload is 16384 tokens x (top 8 out of 256 experts) unless otherwise specified.
| kernel | VGPRs | SGPRs | active CUs | vector L1 cache hit rate | coalescing rate / utilization |
|---|---|---|---|---|---|
| old main moe_align_block_size_kernel (k1) | 20 | 48 | 3 | 0% | 25% / 7% |
| old main count_and_sort_expert_tokens_kernel (k2) | 8 | 32 | 39 | 27% | NaN |
| our moe_align_block_size_kernel | 52 | 48 | 66 | 61% | 36% / 18% |
We maximize the usage of VGPRs but reduce the total usage of SGPRs in our algorithm. The data also indicates zero VGPR/SGPR spills, meaning healthy register usage and no performance penalty for this kernel.
The vector L1 cache (vL1D) is a unit local to each CU; its hit rate records the cache-line hit rate when data is requested from the L2 cache by a CU. 30% of the L2 cache requests were coalesced by the vL1D's texture addresser, and a 61% hit rate was achieved, which can also be improved later if necessary.
When data is requested from a CU to the vL1D's address processing unit (the texture addresser), the complex has four states to decide whether to accept the data request or roll it back to the CU via the data processor unit in the vL1D.
Details of this micro-architectural behavior can be found in the AMD CDNA3 ISA and the rocProfiler-compute docs.
Our vL1D addresser stall
We witnessed an 18.61% data-waiting stall rate from the vector L1 cache in this algorithm design.
The imbalance between data reads and writes is greatly reduced: from 8 kB of read ops and 27 B of write ops to a combination of 109 B of read ops, 468 B of write ops, and 202 B of atomic ops.
Profiling Results of L2 Cache
In the CDNA3 architecture, the L2 cache is shared by all CUs and is the main entry point for sharing data among thread blocks distributed to different CUs.
With its multi-channel and address-interleaving design, requests to the L2 cache can be largely handled concurrently.
Moreover, with AMD-specific intrinsics such as __builtin_nontemporal_load, we can bypass the L2 cache for data we don't need to visit again.
The details of L2 cache study will be revealed in V4 release.
The new algorithm significantly accelerates MoE Align & Sort on both CUDA and ROCm platforms, by up to 3x~7x, by maximizing the usage of LDS and vector registers.
We also observed that a memory-bound op may perform worse on a multi-die chip than on a single-die chip. This indicates a new fine-tuning direction when programming device code for multi-die chips such as MI300X/MI300A and B200/B300.
However, the details of the algorithm can still be polished to improve the cache hit rate and the main-memory coalescing rate.
Special thanks to Prof. Zhang Han (hanzhangqin8@gmail.com) and Dr. Wang YunHong (yunhongwang2000@gmail.com) from the NUS team for the collaboration on MI100/MI250 performance verification, Zev Rekhter (connect@evergrid.ai) for the collaboration on MI300X performance verification, Shuyi Fan (fsygd1996@163.com) for the collaboration on H200 verification, and BBuf (1182563586@qq.com) for discussing and reviewing the solution in SGLang.
Note that this is independent work from the SGLang community.
I also express my deep thanks to Bingqing, Peng Sun, and ShawHai, who each spared time to review the article and give suggestions for revision.
Also see evergrid.ai and huggingface sites.