Skip to content

Conversation

@xiaoxi-wangfj
Copy link
Contributor

@xiaoxi-wangfj xiaoxi-wangfj commented Dec 26, 2025

Description

This PR introduces blockwise, scaling-aware FP8 transpose optimizations for FP8 MoE that enable a casting-free, FP8-centric MoE dataflow in TransformerEngine by eliminating unnecessary cast and re-quantization steps, while maintaining numerical stability in existing FP8 training workflows.

This PR is designed to be used in conjunction with PR NVIDIA/Megatron-LM#2764

Further optimizations are introduced via two additional PRs:

Background / Motivation

The design and theoretical background of this PR are described in the paper:
FP8-Flow-MoE: A Casting-Free FP8 Recipe without Double Quantization Error

The follow figure illustrates the optimized MoE dataflow and highlights the key optimization points (marked as ①–⑤).

FP8FLOW-MoE

1. FP8 Quantization Before Dispatch (DeepEP → GroupedLinear)

Quantization is performed before DeepEP dispatch, and row-wise FP8 tensors are directly fed into GroupedLinear.

  • Keeps dispatch → permute → expert computation entirely in FP8
  • Float8BlockwiseQTensor is propagated with a COMPACT layout (for _rowwise_scale_inv) along the dispatch → permute → GroupedLinear path, avoiding layout-induced .T.contiguous() calls and reducing unnecessary memory copies.

(Shown as marker ① in the figure)

2. Scaling-Aware FP8 Transpose for Wgrad

GroupedLinear requires:

  • row-wise FP8 for Fprop/Dgrad
  • column-wise FP8 for Wgrad

To avoid dequantize → transpose → requantize , this PR introduces scaling_aware_fp8_transpose, which:

  • Converts row-wise FP8 to column-wise FP8 via exponent manipulation only
  • Preserves scale consistency across layouts
  • reduce cpu overhead

(Shown as marker ④ in the figure)

3. Fused Permute + Padding / Unpermute + Unpadding

We fuse two memory movement operators along the MoE path:

  • permute + pad in the forward pass
  • unpermute + unpad in the backward pass

For details of this optimization, please refer to PR #1921

(Shown as marker ② in the figure)

4. Fused Activation + Quantization

Activation and FP8 quantization are fused into a single kernel, Produces FP8 outputs directly, while enabling FP8 persistence

(Shown as marker ③ in the figure)

5. Add fine-grained recompute moe_expert

Because the entire dispatch → permute → GroupedLinear path stays in FP8, we enable fine-grained recomputation at the moe_expert level:

  • Saves ~50% peak activation memory and avoids recomputation of the router compared to recomputing the full module moe level

(Shown as marker ⑤ in the figure)

Performance Results

We evaluate FP8-Flow-MoE on DeepSeek-V3 (671B) to validate scalability and robustness under realistic large-scale training conditions.

Throughput

Measured throughput (TGS, tokens/GPU/s) under different expert parallelism (EP) on DeepSeek-V3 (671B) :

  • vs. BF16
    +6% (EP8), +8% (EP16), +16% (EP32)

  • vs. TransformerEngine blockwise FP8 recipe
    +3% (EP8), +8% (EP16), up to +21% (EP32)

Memory Efficiency

With AC = selective checkpointing and recompute-modules = moe_expert:

  • At EP8:
    • ~8 GB lower peak memory vs. BF16
    • ~16.5 GB lower peak memory vs. blockwise FP8

Numerical Accuracy

We trained for >200B tokens. The loss deviation of FP8-Flow-MoE stays within 0.19% compared to both BF16 baselines and TransformerEngine blockwise FP8 recipe, with no observed instability or divergence.

Limitations

  • Currently validated on NVIDIA Hopper architecture with blockwise FP8 recipe

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Megatron-LM: Added fused FP8 kernels for activation + quantization in fused_bias_swiglu.py and fused_weighted_swiglu_quant.py
  • Megatron-LM: Integrated FP8 dispatch and expert recomputation support in Megatron-LM fused_a2a.py
  • TransformerEngine: Added support for Float8BlockwiseQTensor inputs in grouped_linear.py
  • TransformerEngine: Added scaling_aware_fp8_transpose operator in triton/blockwise_scaling_aware_fp8_transpose.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@rich-junwang
Copy link

Quick question, does this work with mxfp8 or it only applies to fp8? Thanks.

@xiaoxi-wangfj
Copy link
Contributor Author

xiaoxi-wangfj commented Dec 29, 2025

Quick question, does this work with mxfp8 or it only applies to fp8? Thanks.

@rich-junwang , Thanks for asking. This PR primarily targets the standard FP8 blockwise recipe, and the current scaling-aware FP8 transpose implementation is specialized for blockwise=128 scaling.

Among the five optimizations described in this PR:
Items 1, 2, and 4 are specific to the FP8 blockwise recipe.
Items 3 and 5 are not tied to a specific FP8 recipe and can be applied more generally, independent of whether blockwise scaling is used.

1. add fp8 rowwise scaling-aware transpose op for wgrad columwise.
2. support Float8BlockwiseQTensor input in grouped_linear.
3. _rowwise_scale_inv is propagated with a COMPACT layout along the `dispatch → permute → GroupedLinear` path.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: dantesuu@gmail.com
Co-authored-by: xzhu@zhejianglab.org
Co-authored-by: 123sssmmm@gmail.com
@xiaoxi-wangfj xiaoxi-wangfj marked this pull request as ready for review December 29, 2025 08:18
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 29, 2025

Greptile Summary

This PR implements a casting-free FP8 MoE dataflow optimization by introducing scaling-aware FP8 transpose operations that eliminate unnecessary dequantization/requantization steps.

Key Changes

  • New Triton Kernel (blockwise_scaling_aware_fp8_transpose.py): Implements in-domain FP8 transpose by manipulating exponents directly rather than dequantizing to higher precision. The kernel extracts FP8 exponent/mantissa, computes scale adjustments (k = exp_target - exp_source), and applies adjustments to produce column-wise FP8 data with appropriate scaling factors.

  • Float8BlockwiseQTensor Integration (float8_blockwise_tensor.py): Added split_scaling_aware_fp8_transpose() method that wraps the Triton kernel, handles splits with zero token counts, and manages format conversions between COMPACT and GEMM_READY layouts.

  • GroupedLinear Support (grouped_linear.py): Modified forward/backward passes to detect Float8BlockwiseQTensor inputs and route them through the new scaling-aware transpose path instead of the standard tex.split_quantize().

  • Format Handling (permutation.py): Fixed conditional logic to avoid unnecessary transposes when rowwise_scale_inv shapes match rowwise_data (COMPACT format).

Architecture Integration

The changes fit well into the existing FP8 infrastructure. The new path activates when blockwise-quantized tensors flow through GroupedLinear, enabling the "casting-free" dataflow described in the paper. The implementation maintains compatibility with existing delayed scaling and per-tensor recipes.

Issues Found

  1. Potential exponent underflow in the Triton kernel when exp_new < 0 could cause wraparound
  2. In-place format mutation in split_scaling_aware_fp8_transpose() modifies self._data_format which may have side effects
  3. Scale difference validation needed for extreme cases where target_si and si differ by orders of magnitude

Confidence Score: 4/5

  • This PR is safe to merge with minor risk from potential edge cases in the FP8 exponent manipulation logic
  • The implementation is well-structured and integrates cleanly into existing FP8 infrastructure. The core Triton kernel logic is mathematically sound for FP8 scaling-aware transpose. However, there's a minor concern about exponent underflow handling that should be validated. The changes are additive (new code path for Float8BlockwiseQTensor) and don't break existing functionality. The PR claims validation with >200B tokens of training showing <0.19% loss deviation, which is strong evidence of correctness.
  • Pay close attention to transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py lines 100-101 for the exponent underflow edge case

Important Files Changed

Filename Overview
transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py New Triton kernel implementing FP8 transpose with scaling awareness. Converts row-wise FP8 to column-wise via exponent manipulation. Core logic appears sound but has potential edge case with exponent overflow.
transformer_engine/pytorch/module/grouped_linear.py Added Float8BlockwiseQTensor support to GroupedLinear forward/backward passes. Changes integrate the new scaling-aware transpose path when blockwise quantized tensors are used.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Added split_scaling_aware_fp8_transpose method that wraps the Triton kernel. Handles m_splits with zero values and format conversions. Logic appears correct but modifies data_format in-place.

Sequence Diagram

sequenceDiagram
    participant User as User Input
    participant GL as GroupedLinear
    participant FB as Float8BlockwiseQTensor
    participant TK as Triton Kernel
    participant GEMM as GroupedGEMM
    
    Note over User,GEMM: Forward Pass - FP8 Quantization Before Dispatch
    User->>GL: forward(inp, m_splits)
    GL->>GL: Check if inp is Float8BlockwiseQTensor
    alt Input is Float8BlockwiseQTensor
        GL->>FB: split_scaling_aware_fp8_transpose(m_splits)
        FB->>FB: Convert GEMM_READY → COMPACT format
        FB->>TK: blockwise_scaling_aware_fp8_transpose()
        Note over TK: Row-wise FP8 → Column-wise FP8<br/>via exponent manipulation
        TK->>TK: Extract sign, exp, mantissa
        TK->>TK: Compute k = exp_target - exp_source
        TK->>TK: Adjust exponent: exp_new = exp - k
        TK->>TK: Transpose and write columnwise data
        TK-->>FB: [rowwise_data, rowwise_scale_inv_t,<br/>columnwise_data, columnwise_scale_inv]
        FB-->>GL: List of Float8BlockwiseQTensorStorage
    else Input is standard tensor
        GL->>GL: tex.split_quantize(inp_view, m_splits)
    end
    GL->>GEMM: general_grouped_gemm(weights, inputmats)
    GEMM-->>GL: output
    GL-->>User: result
    
    Note over User,GEMM: Backward Pass - Scaling-Aware Transpose for Gradients
    User->>GL: backward(grad_output)
    GL->>GL: Check if grad_output is Float8BlockwiseQTensor
    alt grad_output is Float8BlockwiseQTensor
        GL->>FB: split_scaling_aware_fp8_transpose(m_splits)
        FB->>TK: blockwise_scaling_aware_fp8_transpose()
        TK-->>FB: transposed gradients
        FB-->>GL: quantized grad tensors
    else Standard grad_output
        GL->>GL: tex.split_quantize()
    end
    GL->>GEMM: dgrad computation
    GL->>GEMM: wgrad computation
    GEMM-->>GL: gradients
    GL-->>User: dgrad, wgrad
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (3)

  1. transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py, line 100-101 (link)

    logic: potential exponent underflow not handled - when exp_new < 0 (not just < 1), the result could wrap around since exp is unsigned. consider clamping to 0 for negative values

  2. transformer_engine/pytorch/tensor/float8_blockwise_tensor.py, line 453-456 (link)

    style: modifying self._data_format in-place during a method call can cause unexpected side effects if the tensor is reused. consider creating a new tensor or documenting this mutation clearly

  3. transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py, line 94-99 (link)

    style: verify that the scaling adjustment works correctly when target_si and si differ by more than 127 in exponent (i.e., orders of magnitude apart). have you validated this scaling adjustment with extreme scale differences where target_si and si magnitudes differ significantly?

5 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants