[MagpieTTS] Mixture-of-Experts #15370
Conversation
Pull request overview
This PR adds Mixture-of-Experts (MoE) support to MagpieTTS by introducing a router + expert FFN module, wiring it into the existing transformer_2501 stack, and adding auxiliary losses + inference-time logging/metrics to support training and evaluation.
Changes:
- Introduces `MoERouter` and `PositionwiseConvFFMoE`, and integrates MoE into `TransformerLayer`/`Transformer`.
- Adds MoE auxiliary losses (`MoELoadBalancingLoss`, `MoERouterZLoss`, `MoEAuxiliaryLoss`) and wires loss/statistics collection into `MagpieTTSModel`.
- Updates inference utilities and example scripts to detect MoE and report FLOPs/architecture summary; adds config + extensive unit/integration tests.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| tests/collections/tts/modules/test_transformer_2501.py | Adds unit tests covering MoE router, MoE FFN, and transformer behavior with/without MoE. |
| tests/collections/tts/modules/test_moe_integration.py | Adds end-to-end integration tests across router → transformer → loss/backward. |
| nemo/collections/tts/modules/transformer_2501.py | Adds MoE params, swaps dense FFN for MoE FFN when enabled, returns routing info. |
| nemo/collections/tts/modules/moe_modules.py | New MoE router + MoE FFN implementation (experts + routing strategies). |
| nemo/collections/tts/modules/magpietts_inference/utils.py | Adds MoE detection, FLOPs computation, and richer return signature for model loading. |
| nemo/collections/tts/modules/magpietts_inference/__init__.py | Exposes new MoE/FLOPs utilities and updates example usage. |
| nemo/collections/tts/modules/__init__.py | Ensures MoE module is imported in the package init. |
| nemo/collections/tts/models/magpietts.py | Filters MoE-only config keys, collects routing info, computes/logs MoE losses and usage stats. |
| nemo/collections/tts/losses/moe_loss.py | New MoE auxiliary losses and expert-usage utility. |
| nemo/collections/tts/losses/__init__.py | Imports the new MoE loss module. |
| examples/tts/magpietts_inference.py | Updates example pipeline to include MoE naming + FLOPs metrics. |
| examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml | Adds a reference MoE training config for MagpieTTS. |
| decoder: |
| n_layers: 12 |
| d_model: 768 |
| d_ffn: 1536 # 3072 / top_k_experts = 3072 / 2, match inference computation as dense baseline |
Does this parameter still make sense? Why not keep it as 3072, and then change the per-expert d_ffn depending on the number of experts at inference?
Yes, this param still makes sense.
`d_ffn` follows the standard MoE convention: it is the per-expert FFN hidden dimension, not a "total" to be divided.
The value 1536 is a deliberate FLOPs-matching choice for this particular config (3072 / top_k = 1536), but it's just one valid setting.
Users can freely set `d_ffn: 3072` to give each expert the full dense FFN dimension (as in Mixtral/Switch Transformer), or any other value.
| Model | Per-expert `d_ffn` | Dense baseline `d_ffn` | Top-K | FFN FLOPs vs dense |
|---|---|---|---|---|
| Switch Transformer | 3072 | 3072 | 1 | 1x (same) |
| Mixtral 8x7B | 14336 | 14336 | 2 | 2x |
Let me improve the YAML comment to something like:
| d_ffn: 1536 # 3072 / top_k_experts = 3072 / 2, match inference computation as dense baseline | |
| d_ffn: 1536 # Per-expert FFN hidden dimension. Set to 1536 (=3072/top_k) for FLOPs-matched MoE, or 3072 for full-capacity experts (2x inference FLOPs vs dense). |
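As a back-of-the-envelope illustration of the FLOPs-matching point above (a rough per-token formula with illustrative names; not the estimator added in `magpietts_inference/utils.py`, and router cost is ignored):

```python
def ffn_flops_per_token(d_model: int, d_ffn: int, active_ffns: int = 1) -> int:
    # Two projections (d_model -> d_ffn -> d_model), ~2*d_model*d_ffn MACs each,
    # times the number of FFNs a token actually passes through (top_k for MoE).
    return active_ffns * 2 * (2 * d_model * d_ffn)

d_model = 768
dense       = ffn_flops_per_token(d_model, d_ffn=3072)                 # dense baseline
moe_matched = ffn_flops_per_token(d_model, d_ffn=1536, active_ffns=2)  # this config, top_k=2
moe_full    = ffn_flops_per_token(d_model, d_ffn=3072, active_ffns=2)  # Mixtral-style experts

print(moe_matched == dense)   # True  -> FLOPs-matched to the dense baseline
print(moe_full / dense)       # 2.0   -> 2x FFN FLOPs vs dense
```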
* update sinkhorn routing algorithm by evaluating convergence in scaling factor.
* update core implementation and unit tests to compute loss and metrics only relevant to valid tokens.
* only apply sinkhorn during training.
* added FLOPs estimation function.
* add patch that can load older models that use the old name router_aux_loss_coeff.

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
@blisc I addressed all your comments. Please have a look.
| Tuple of: |
| all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook). |
| attn_probabilities (list): Attention probabilities from each decoder layer. |
| dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model). |
| moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled, |
| a list of dicts (one per layer) each containing: |
| - 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts). |
| - 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts). |
| - 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k). |
| Tuple of: | |
| all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook). | |
| attn_probabilities (list): Attention probabilities from each decoder layer. | |
| dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model). | |
| moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled, | |
| a list of dicts (one per layer) each containing: | |
| - 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts). | |
| - 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts). | |
| - 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k). | |
| all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook). | |
| attn_probabilities (list): Attention probabilities from each decoder layer. | |
| dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model). | |
| moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled, | |
| a list of dicts (one per layer) each containing: | |
| - 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts). | |
| - 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts). | |
| - 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k). |
This isn't a tuple. It just returns 4 things, not a tuple of 4 things. Please change it and we can merge.
Actually, in Python, it is the comma that creates the tuple, not the parentheses. The Python documentation explicitly states: "It is actually the comma which makes a tuple, not the parentheses" (see docs).
So `return a, b, c, d` is functionally identical to `return (a, b, c, d)`: both return a single tuple object containing 4 items. The current docstring "Returns: Tuple of..." is technically correct.
I verified this behavior to be sure:
In [1]: def foo():
...: return 1, 2, 3, 4
...:
In [2]: foo()
Out[2]: (1, 2, 3, 4)
In [3]: x = foo()
In [4]: type(x)
Out[4]: tuple
In [5]: y = 1, 2, "a", "b"
In [6]: type(y)
Out[6]: tuple
In [7]: type(foo())
Out[7]: tuple
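(For completeness, here is a small sketch of how a caller might aggregate the per-layer `expert_indices` from `moe_routing_info` into expert-usage frequencies; shapes and names are illustrative, not the logging code in `magpietts.py`.)

```python
import torch

num_experts, top_k, n_layers = 8, 2, 12
# Stand-in for the returned moe_routing_info: one dict per decoder layer.
moe_routing_info = [
    {"expert_indices": torch.randint(num_experts, (4, 100, top_k))}  # (B, T, top_k)
    for _ in range(n_layers)
]

usage = torch.zeros(num_experts)
for layer_info in moe_routing_info:
    picks = layer_info["expert_indices"].reshape(-1)                 # all (token, slot) picks
    usage += torch.bincount(picks, minlength=num_experts).float()

print(usage / usage.sum())   # selection frequency per expert, summed over layers
```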
Summary
Adds Mixture-of-Experts (MoE) support to the MagpieTTS decoder. When enabled, each `TransformerLayer` conditionally replaces its dense `PositionwiseConvFF` with a `PositionwiseConvFFMoE` that uses a learned router to dispatch tokens to top-k experts (pointwise Conv1d FFNs), increasing model capacity without a proportional increase in per-token inference cost.
Key components
- `moe_modules.py`: `MoERouter` (top-k / Sinkhorn routing with optional jitter noise) and `PositionwiseConvFFMoE` (sort-based vectorized expert dispatch with `index_add_` scatter); see the sketch after this list. Sinkhorn routing is used only during training; inference falls back to softmax top-k.
- `moe_loss.py`: MSE-based load balancing loss (penalizes deviation from uniform expert usage) and router z-loss (ST-MoE), both mask-aware to exclude padding tokens.
- `ffn_modules.py`: `ConvolutionLayer` and `PositionwiseConvFF` extracted from `transformer_2501.py` to break a circular import with `moe_modules`.
- `transformer_2501.py`: `TransformerLayer`/`Transformer` conditionally instantiate the MoE FFN via `use_moe`; per-layer routing info (logits, probs, expert indices) is collected and returned for model-level loss computation.
- `magpietts.py`: MoE auxiliary loss aggregated across decoder layers in training/validation; per-expert usage statistics and selection frequency logged to WandB.
- `magpietts_inference/utils.py`: Per-component FLOPs estimation (including router cost), architecture summary logging with MoE parameter/FLOPs breakdown.
- `magpietts_lhotse_moe.yaml`: Example config with `num_experts=8`, `top_k=2`, `d_ffn=1536` (FLOPs-matched to dense `d_ffn=3072`), `routing_strategy=top_k`, `router_jitter_noise=0.01`.
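For orientation, a minimal, self-contained sketch of the softmax top-k routing path (illustrative only: it uses `nn.Linear` experts and a plain loop instead of the Conv1d experts and sort-based `index_add_` dispatch in `PositionwiseConvFFMoE`):

```python
import torch
from torch import nn


class TinyTopKMoE(nn.Module):
    """Toy softmax top-k MoE FFN; not the NeMo implementation."""

    def __init__(self, d_model=768, d_ffn=1536, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model))
             for _ in range(num_experts)]
        )

    def forward(self, x):                                   # x: (B, T, d_model)
        router_logits = self.router(x)                      # (B, T, num_experts)
        router_probs = router_logits.softmax(dim=-1)
        top_probs, expert_indices = router_probs.topk(self.top_k, dim=-1)  # (B, T, top_k)
        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)        # renormalize gates

        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):           # loop dispatch for readability
            for k in range(self.top_k):
                routed = expert_indices[..., k] == e        # (B, T) tokens sent to expert e
                if routed.any():
                    out[routed] += top_probs[..., k][routed].unsqueeze(-1) * expert(x[routed])

        # Routing info mirrors the per-layer dict returned when use_moe is enabled.
        return out, {"router_logits": router_logits,
                     "router_probs": router_probs,
                     "expert_indices": expert_indices}
```
Design decisions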
- `kernel_size=1` is enforced; `Conv1d(kernel_size=1)` is equivalent to `nn.Linear` (as the snippet below verifies), matching standard MoE practice and ensuring token-independent dispatch.
- A `ValueError` is raised if both are active.
- `router_load_balancing_loss_coeff` and `router_z_loss_coeff` are filtered out before passing the config to `Transformer.__init__`, keeping the module interface clean of training-only concerns.
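The `Conv1d(kernel_size=1)` ≡ `nn.Linear` equivalence is easy to check directly (a standalone sanity snippet, not project code):

```python
import torch
from torch import nn

d_model, d_ffn = 768, 1536
conv = nn.Conv1d(d_model, d_ffn, kernel_size=1)
linear = nn.Linear(d_model, d_ffn)

# Copy weights: Conv1d weight is (out, in, kernel=1); Linear weight is (out, in).
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 50, d_model)                      # (B, T, d_model)
y_conv = conv(x.transpose(1, 2)).transpose(1, 2)     # Conv1d expects (B, C, T)
y_lin = linear(x)
print(torch.allclose(y_conv, y_lin, atol=1e-5))      # True: a per-token linear map
```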
Tests
25 unit-test functions across 3 files (`test_moe_loss.py`, `test_moe_integration.py`, `test_transformer_2501.py`) covering router behavior, top-k/Sinkhorn selection, FFN dispatch, padding/masking, cross-attention, loss computation, and cross-module integration contracts. 3 L2 CI scripts for MoE fast-dev-run training, zero-shot inference evaluation, and longform inference evaluation.
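For reference, the mask-aware auxiliary-loss idea these tests exercise can be sketched as follows (an illustrative reduction, not the `moe_loss.py` implementation):

```python
import torch

def moe_aux_losses(router_logits, router_probs, token_mask, num_experts):
    """router_logits / router_probs: (B, T, E); token_mask: (B, T), True for valid tokens."""
    probs = router_probs[token_mask]                   # (N_valid, E), padding tokens excluded
    logits = router_logits[token_mask]

    # Load balancing: squared deviation of mean routing probability from the uniform 1/E.
    load_balancing = ((probs.mean(dim=0) - 1.0 / num_experts) ** 2).sum()

    # Router z-loss (ST-MoE): keeps router logits small for numerical stability.
    z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()
    return load_balancing, z_loss
```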