
[MagpieTTS] Mixture-of-Experts #15370

Merged
XuesongYang merged 19 commits into NVIDIA-NeMo:main from XuesongYang:xueyang/pr-moe
Feb 13, 2026

Conversation

@XuesongYang
Collaborator

@XuesongYang XuesongYang commented Feb 8, 2026

Summary

Adds Mixture-of-Experts (MoE) support to the MagpieTTS decoder. When enabled, each TransformerLayer conditionally replaces its dense PositionwiseConvFF with a PositionwiseConvFFMoE that uses a learned router to dispatch tokens to the top-k experts (pointwise Conv1d FFNs), increasing model capacity without a proportional increase in per-token inference cost.
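
The core idea, roughly: a small linear gate scores each token against every expert, and only the top-k experts run for that token. Below is a minimal, hypothetical sketch of that routing step in PyTorch; SimpleTopKRouter and its exact outputs are illustrative assumptions, not the PR's MoERouter implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTopKRouter(nn.Module):
    """Illustrative top-k router sketch (not the PR's MoERouter)."""

    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # x: (B, T, d_model) -> per-token scores over experts: (B, T, num_experts)
        logits = self.gate(x)
        probs = F.softmax(logits, dim=-1)
        # Each token keeps only its top_k experts and their gate weights.
        gate_weights, expert_indices = probs.topk(self.top_k, dim=-1)
        return logits, probs, gate_weights, expert_indices
```

Each token's FFN output is then the gate-weighted sum of the outputs of its selected experts, so per-token compute grows with top_k rather than with num_experts.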

Key components

  • moe_modules.py: MoERouter (top-k / Sinkhorn routing with optional jitter noise) and PositionwiseConvFFMoE (sort-based vectorized expert dispatch with index_add_ scatter). Sinkhorn routing is used only during training; inference falls back to softmax top-k.
  • moe_loss.py: MSE-based load balancing loss (penalizes deviation from uniform expert usage) and router z-loss (ST-MoE), both mask-aware to exclude padding tokens; a rough sketch of both losses appears after this list.
  • ffn_modules.py: ConvolutionLayer and PositionwiseConvFF extracted from transformer_2501.py to break a circular import with moe_modules.
  • transformer_2501.py: TransformerLayer / Transformer conditionally instantiate the MoE FFN via use_moe; per-layer routing info (logits, probs, expert indices) is collected and returned for model-level loss computation.
  • magpietts.py: MoE auxiliary loss aggregated across decoder layers in training/validation; per-expert usage statistics and selection-frequency logged to WandB.
  • magpietts_inference/utils.py: Per-component FLOPs estimation (including router cost), architecture summary logging with MoE parameter/FLOPs breakdown.
  • magpietts_lhotse_moe.yaml: Example config with num_experts=8, top_k=2, d_ffn=1536 (FLOPs-matched to dense d_ffn=3072), routing_strategy=top_k, router_jitter_noise=0.01.
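
For the two auxiliary losses above, here is a rough, hedged sketch of what a mask-aware MSE load-balancing loss and an ST-MoE style router z-loss could look like; the function names, exact reductions, and uniform target are assumptions for illustration, not the code in moe_loss.py.

```python
import torch
import torch.nn.functional as F

def load_balancing_mse(router_probs, mask, num_experts):
    """MSE between per-expert mean routing probability and the uniform target 1/num_experts.

    router_probs: (B, T, num_experts); mask: (B, T), 1 for valid tokens, 0 for padding.
    """
    mask = mask.unsqueeze(-1).to(router_probs.dtype)               # (B, T, 1)
    num_valid = mask.sum().clamp(min=1.0)
    mean_prob = (router_probs * mask).sum(dim=(0, 1)) / num_valid  # (num_experts,)
    uniform = torch.full_like(mean_prob, 1.0 / num_experts)
    return F.mse_loss(mean_prob, uniform)

def router_z_loss(router_logits, mask):
    """ST-MoE style z-loss: squared logsumexp of router logits, averaged over valid tokens."""
    z = torch.logsumexp(router_logits, dim=-1) ** 2                # (B, T)
    mask = mask.to(z.dtype)
    return (z * mask).sum() / mask.sum().clamp(min=1.0)
```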

Design decisions

  • Experts are pointwise linear: kernel_size=1 is enforced; Conv1d(kernel_size=1) is equivalent to nn.Linear, matching standard MoE practice and ensuring token-independent dispatch (see the equivalence check after this list).
  • Context audio included in MoE loss: load-balancing and z-loss are computed over the full decoder input (context + target, excluding padding) so the router sees representative token distributions and avoids misleading gradients.
  • Sinkhorn + load balancing is rejected: Sinkhorn already produces a doubly stochastic routing matrix that enforces balanced routing; enabling load-balancing on top is redundant, so the config validator raises ValueError if both are active.
  • Loss coefficients live at model level: router_load_balancing_loss_coeff and router_z_loss_coeff are filtered out before passing the config to Transformer.__init__, keeping the module interface clean of training-only concerns.
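
As a quick sanity check of the kernel_size=1 design decision, the snippet below verifies that a 1x1 Conv1d computes the same per-token affine map as nn.Linear once the weights are copied over; it is purely illustrative and independent of the PR's modules.

```python
import torch
import torch.nn as nn

# Illustration only: a 1x1 Conv1d applies the same affine map to every time step,
# i.e. it acts as a per-token linear layer (token-independent dispatch).
torch.manual_seed(0)
d_model, d_ffn, B, T = 8, 16, 2, 5

conv = nn.Conv1d(d_model, d_ffn, kernel_size=1)
linear = nn.Linear(d_model, d_ffn)
# Copy the conv weights into the linear layer so both compute the same function.
linear.weight.data.copy_(conv.weight.data.squeeze(-1))
linear.bias.data.copy_(conv.bias.data)

x = torch.randn(B, T, d_model)
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d expects (B, C, T)
out_linear = linear(x)
assert torch.allclose(out_conv, out_linear, atol=1e-6)
```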

Tests

25 unit-test functions across 3 files (test_moe_loss.py, test_moe_integration.py, test_transformer_2501.py) cover router behavior, top-k/Sinkhorn selection, FFN dispatch, padding/masking, cross-attention, loss computation, and cross-module integration contracts. Also adds 3 L2 CI scripts for MoE fast-dev-run training, zero-shot inference evaluation, and longform inference evaluation.
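
For flavor, this is the kind of shape/selection check such router tests might contain, reusing the hypothetical SimpleTopKRouter sketched earlier; it is not the PR's actual test code.

```python
import torch

def test_router_topk_shapes():
    # Illustrative only: verify router output shapes and expert index ranges.
    B, T, d_model, num_experts, top_k = 2, 7, 16, 8, 2
    router = SimpleTopKRouter(d_model, num_experts, top_k)
    logits, probs, gate_weights, expert_indices = router(torch.randn(B, T, d_model))
    assert logits.shape == (B, T, num_experts)
    assert expert_indices.shape == (B, T, top_k)
    assert (expert_indices >= 0).all() and (expert_indices < num_experts).all()
```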

Contributor

Copilot AI left a comment


Pull request overview

This PR adds Mixture-of-Experts (MoE) support to MagpieTTS by introducing a router + expert FFN module, wiring it into the existing transformer_2501 stack, and adding auxiliary losses + inference-time logging/metrics to support training and evaluation.

Changes:

  • Introduces MoERouter and PositionwiseConvFFMoE, and integrates MoE into TransformerLayer / Transformer.
  • Adds MoE auxiliary losses (MoELoadBalancingLoss, MoERouterZLoss, MoEAuxiliaryLoss) and wires loss/statistics collection into MagpieTTSModel.
  • Updates inference utilities and example scripts to detect MoE and report FLOPs/architecture summary; adds config + extensive unit/integration tests.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 11 comments.

| File | Description |
| --- | --- |
| tests/collections/tts/modules/test_transformer_2501.py | Adds unit tests covering MoE router, MoE FFN, and transformer behavior with/without MoE. |
| tests/collections/tts/modules/test_moe_integration.py | Adds end-to-end integration tests across router → transformer → loss/backward. |
| nemo/collections/tts/modules/transformer_2501.py | Adds MoE params, swaps dense FFN for MoE FFN when enabled, returns routing info. |
| nemo/collections/tts/modules/moe_modules.py | New MoE router + MoE FFN implementation (experts + routing strategies). |
| nemo/collections/tts/modules/magpietts_inference/utils.py | Adds MoE detection, FLOPs computation, and richer return signature for model loading. |
| nemo/collections/tts/modules/magpietts_inference/__init__.py | Exposes new MoE/FLOPs utilities and updates example usage. |
| nemo/collections/tts/modules/__init__.py | Ensures the MoE module is imported in the package init. |
| nemo/collections/tts/models/magpietts.py | Filters MoE-only config keys, collects routing info, computes/logs MoE losses and usage stats. |
| nemo/collections/tts/losses/moe_loss.py | New MoE auxiliary losses and expert-usage utility. |
| nemo/collections/tts/losses/__init__.py | Imports the new MoE loss module. |
| examples/tts/magpietts_inference.py | Updates example pipeline to include MoE naming + FLOPs metrics. |
| examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml | Adds a reference MoE training config for MagpieTTS. |


decoder:
n_layers: 12
d_model: 768
d_ffn: 1536 # 3072 / top_k_experts = 3072 / 2, match inference computation as dense baseline
Collaborator


Does this parameter still make sense? Why not keep it as 3072, and then change the per-expert d_ffn depending on the number of experts at inference?

Collaborator Author


Yes, this param still makes sense.

d_ffn follows the standard MoE convention: it's the per-expert FFN hidden dimension, not a "total" to be divided.

The value 1536 is a deliberate FLOPs-matching choice for this particular config (3072 / top_k = 1536), but it's just one valid setting.

Users can freely set d_ffn: 3072 to give each expert the full dense FFN dimension (as in Mixtral/Switch Transformer), or any other value.

| Model | Per-expert d_ffn | Dense baseline d_ffn | Top-k | FFN FLOPs vs dense |
| --- | --- | --- | --- | --- |
| Switch Transformer | 3072 | 3072 | 1 | 1x (same) |
| Mixtral 8x7B | 14336 | 14336 | 2 | 2x |

Let me improve the YAML comment to something like:

Suggested change
d_ffn: 1536 # 3072 / top_k_experts = 3072 / 2, match inference computation as dense baseline
d_ffn: 1536 # Per-expert FFN hidden dimension. Set to 1536 (=3072/top_k) for FLOPs-matched MoE, or 3072 for full-capacity experts (2x inference FLOPs vs dense).
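
A back-of-the-envelope check of the FLOPs-matching claim, counting only the two FFN matmuls per token and ignoring router and activation costs; the constant factors here are illustrative, not the PR's FLOPs estimator.

```python
# Rough per-token FFN FLOPs: two matmuls of ~2*d_model*d_ffn multiply-adds each,
# multiplied by the number of experts that actually run per token.
d_model = 768

def ffn_flops(d_ffn, active_experts=1):
    return active_experts * 2 * (2 * d_model * d_ffn)

dense = ffn_flops(d_ffn=3072)                            # dense baseline
moe_matched = ffn_flops(d_ffn=1536, active_experts=2)    # top_k=2, d_ffn=3072/2
moe_full = ffn_flops(d_ffn=3072, active_experts=2)       # full-capacity experts

assert moe_matched == dense      # FLOPs-matched to the dense baseline
assert moe_full == 2 * dense     # ~2x inference FLOPs vs dense
```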

Xuesong Yang and others added 3 commits February 12, 2026 12:08
* update sinkhorn routing algorithm by evaluating convergence in scaling factor.
* update core implementation and unit tests to compute loss and metrics only relevant to valid tokens.
* only apply sinkhorn during training.
* added FLOPs estimation function.
* add patch that can load older models that have the old name router_aux_loss_coeff.

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
… checkpoints

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…FLOPs utilities

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: XuesongYang <XuesongYang@users.noreply.github.com>
@XuesongYang XuesongYang changed the title [magpietts] Add Mixture-of-Experts to MagpieTTS [MagpieTTS] Mixture-of-Experts Feb 13, 2026
@XuesongYang XuesongYang requested a review from blisc February 13, 2026 09:40
@XuesongYang
Copy link
Collaborator Author

@blisc I addressed all your comments. Please have a look.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Comment on lines +1264 to +1272
Tuple of:
all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook).
attn_probabilities (list): Attention probabilities from each decoder layer.
dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model).
moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled,
a list of dicts (one per layer) each containing:
- 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts).
- 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts).
- 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k).
Collaborator


Suggested change
Tuple of:
all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook).
attn_probabilities (list): Attention probabilities from each decoder layer.
dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model).
moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled,
a list of dicts (one per layer) each containing:
- 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts).
- 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts).
- 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k).
all_code_logits (torch.Tensor): Logits of shape (B, T', num_codebooks * num_tokens_per_codebook).
attn_probabilities (list): Attention probabilities from each decoder layer.
dec_output (torch.Tensor): Raw decoder output of shape (B, T', d_model).
moe_routing_info (list or None): None if MoE is disabled. If MoE is enabled,
a list of dicts (one per layer) each containing:
- 'router_logits' (torch.Tensor): Raw router logits (B, T, num_experts).
- 'router_probs' (torch.Tensor): Router probabilities (B, T, num_experts).
- 'expert_indices' (torch.Tensor): Selected expert indices (B, T, top_k).

This isn't a tuple. It just returns 4 things, not a tuple of 4 things. Please change it and we can merge.

Collaborator Author


Actually, in Python, it is the comma that creates the tuple, not the parentheses. The Python documentation explicitly states: "It is actually the comma which makes a tuple, not the parentheses" (see docs).

So return a, b, c, d is functionally identical to return (a, b, c, d): both return a single tuple object containing 4 items. The current docstring "Returns: Tuple of..." is technically correct.

I verified this behavior as well, to be sure:

In [1]: def foo():
   ...:     return 1, 2, 3, 4
   ...:

In [2]: foo()
Out[2]: (1, 2, 3, 4)

In [3]: x = foo()

In [4]: type(x)
Out[4]: tuple

In [5]: y = 1, 2, "a", "b"

In [6]: type(y)
Out[6]: tuple

In [7]: type(foo())
Out[7]: tuple

@XuesongYang XuesongYang merged commit 68eaee6 into NVIDIA-NeMo:main Feb 13, 2026
414 of 420 checks passed
@XuesongYang XuesongYang deleted the xueyang/pr-moe branch February 13, 2026 23:33