53 changes: 53 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -1290,6 +1290,59 @@ def test_quantize_dact_dbias_mxfp8_scaling(
)


class TestQuantizeWithVmap:
"""Test vmap support for quantization primitives."""

@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE])
def test_vmap_quantize(self, in_dtype, scaling_mode, q_layout):
"""Test that vmap works with tex.quantize using the general batcher."""
# Determine q_dtype based on scaling mode
if scaling_mode.is_nvfp4_scaling:
q_dtype = jnp.float4_e2m1fn
else:
q_dtype = jnp.float8_e4m3fn

# Create batched input (E, M, K) - E experts
E, M, K = 4, 64, 128
key = jax.random.PRNGKey(0)
batched_input = jax.random.uniform(key, (E, M, K), in_dtype)

# Create per-expert quantizers
quantizers = [
QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
)
for _ in range(E)
]

# Stack quantizers for vmap
stacked_quantizers = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *quantizers)

# Vmap over expert dimension
def quantize_single(x, quantizer):
return tex.quantize(x, quantizer=quantizer, flatten_axis=-1)

vmapped_quantize = jax.vmap(quantize_single, in_axes=(0, 0))
result = vmapped_quantize(batched_input, stacked_quantizers)

# Verify shapes
assert result.data.shape == (E, M, K)
assert result.scale_inv.shape[0] == E # Per-expert scales

# Compare with calling quantize for each expert individually
individual_results = []
for i in range(E):
res_i = tex.quantize(batched_input[i], quantizer=quantizers[i], flatten_axis=-1)
individual_results.append(res_i.data)

expected = jnp.stack(individual_results, axis=0)
assert_allclose(result.data, expected, dtype=quantizers[0].q_dtype)
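
(Side note, not part of the diff: the stack-then-vmap pattern above can be shown without any TE machinery. The sketch below uses a hypothetical dict-of-scales pytree in place of the per-expert Quantizer objects, purely to illustrate how jax.tree_util.tree_map plus jnp.stack turn a list of per-expert pytrees into a single vmappable pytree.)

# Minimal, TE-free sketch of the stack-then-vmap pattern; the "scaler" dicts
# are hypothetical stand-ins for the Quantizer pytrees used in the test above.
import jax
import jax.numpy as jnp

E, M, K = 4, 64, 128
x = jax.random.uniform(jax.random.PRNGKey(0), (E, M, K))

# One parameter pytree per expert.
scalers = [{"scale": jnp.float32(1.0 + i)} for i in range(E)]
# Stack leaf-wise so every leaf gains a leading expert axis of size E.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *scalers)

def scale_one(xi, scaler):
    # Each vmapped call sees one (M, K) slice and one scalar "scale" leaf.
    return xi * scaler["scale"]

out = jax.vmap(scale_one, in_axes=(0, 0))(x, stacked)
assert out.shape == (E, M, K)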


valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e4m3fn),
(jnp.float8_e5m2, jnp.float8_e4m3fn),
221 changes: 221 additions & 0 deletions tests/jax/test_einsum.py
@@ -0,0 +1,221 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for TE einsum operation with FP8 quantization."""

import jax
import jax.numpy as jnp
import pytest
from jax import value_and_grad

from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax.einsum import einsum
from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeMeta,
QuantizeMetaSet,
)
from transformer_engine.jax.quantize import helper


# Test parameters
DTYPES = [jnp.bfloat16]
# (B, S, M, E, C, H)
# B: Batch size
# S: Sequence length (number of tokens)
# M: Model dimension (hidden size)
# E: Number of experts
# C: Capacity (max tokens per expert)
# H: Hidden dimension (MLP intermediate size)
MOE_CASES = [
(2, 32, 128, 4, 32, 64),
]

# Get supported recipes
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]


@pytest.fixture(autouse=True, scope="module")
def init():
"""WAR for CUDA uninitialize error"""
# Calling customcalls before jax may cause CUDA uninitialize error
_ = jnp.zeros(0)
yield


class TestMoEMLPWithRecipes:
"""Test MoE MLP operations with different FP8 recipes and gradients."""

def _get_quantizer_sets(self, recipe, num_experts):
return QuantizerFactory.create_set(
n_quantizer_sets=num_experts,
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)

def _einsum(self, equation, *operands, quantizer_sets=None, quantizer_dim=None, fallback=False):
out = einsum(
equation,
*operands,
quantizer_sets=quantizer_sets,
quantizer_dim=quantizer_dim,
fallback=fallback,
)
return jnp.mean(out)

def _ref_einsum(self, equation, *operands):
out = jnp.einsum(equation, *operands)
return jnp.mean(out)

@pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES)
@pytest_parametrize_wrapper("recipe", supported_recipes)
def test_mlp_up_grad(self, B, S, M, E, C, H, recipe):
"""Test MLP up: EBCM,EMH->EBCH with gradients and different recipes."""
# Create per-expert quantizers
quantizer_sets = self._get_quantizer_sets(recipe, E)
dispatched = jax.random.normal(
jax.random.PRNGKey(0), (E, B, C, M), dtype=jnp.bfloat16
) / jnp.sqrt(M)
weights = jax.random.normal(jax.random.PRNGKey(1), (E, M, H), dtype=jnp.bfloat16)

# Compute with TE einsum with quantization
loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))(
"EBCM,EMH->EBCH", dispatched, weights, quantizer_sets=quantizer_sets, quantizer_dim="E"
)

# Compute reference (BF16)
loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))(
"EBCM,EMH->EBCH", dispatched, weights
)

# Verify shapes and no NaNs
assert grads_te[0].shape == dispatched.shape
assert grads_te[1].shape == weights.shape
assert not jnp.isnan(loss_te)
assert jnp.all(jnp.isfinite(grads_te[0]))
assert jnp.all(jnp.isfinite(grads_te[1]))

# Compare with reference (with FP8 tolerance)
assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype)
assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype)
assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype)

@pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES)
@pytest_parametrize_wrapper("recipe", supported_recipes)
def test_mlp_down_grad(self, B, S, M, E, C, H, recipe):
"""Test MLP down: EBCH,EHM->EBCM with gradients and different recipes."""
# Create per-expert quantizers
quantizer_sets = self._get_quantizer_sets(recipe, E)

hidden = jax.random.normal(
jax.random.PRNGKey(0), (E, B, C, H), dtype=jnp.bfloat16
) / jnp.sqrt(H)
weights = jax.random.normal(jax.random.PRNGKey(1), (E, H, M), dtype=jnp.bfloat16)

# Compute with TE einsum with quantization
loss_te, grads_te = value_and_grad(self._einsum, argnums=(1, 2))(
"EBCH,EHM->EBCM", hidden, weights, quantizer_sets=quantizer_sets, quantizer_dim="E"
)

# Compute reference (BF16)
loss_ref, grads_ref = value_and_grad(self._ref_einsum, argnums=(1, 2))(
"EBCH,EHM->EBCM", hidden, weights
)

# Verify shapes and no NaNs
assert grads_te[0].shape == hidden.shape
assert grads_te[1].shape == weights.shape
assert not jnp.isnan(loss_te)
assert jnp.all(jnp.isfinite(grads_te[0]))
assert jnp.all(jnp.isfinite(grads_te[1]))

# Compare with reference (with FP8 tolerance)
assert_allclose(loss_te, loss_ref, dtype=quantizer_sets[0].x.q_dtype)
assert_allclose(grads_te[0], grads_ref[0], dtype=quantizer_sets[0].dgrad.q_dtype)
assert_allclose(grads_te[1], grads_ref[1], dtype=quantizer_sets[0].dgrad.q_dtype)

@pytest_parametrize_wrapper("B,S,M,E,C,H", MOE_CASES)
@pytest_parametrize_wrapper("recipe", supported_recipes)
def test_full_moe_grad(self, B, S, M, E, C, H, recipe):
"""Test full MoE pipeline (all 4 einsums) with gradients and different recipes."""
# Create per-expert quantizers for each einsum
mlp_up_quantizer_sets = self._get_quantizer_sets(recipe, E)
mlp_down_quantizer_sets = self._get_quantizer_sets(recipe, E)

tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M), dtype=jnp.bfloat16) / jnp.sqrt(
M
)
routing = jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C), dtype=jnp.bfloat16)
routing = jax.nn.softmax(routing, axis=-1) # Normalize routing weights
up_weights = jax.random.normal(
jax.random.PRNGKey(2), (E, M, H), dtype=jnp.bfloat16
) / jnp.sqrt(H)
down_weights = jax.random.normal(
jax.random.PRNGKey(3), (E, H, M), dtype=jnp.bfloat16
) / jnp.sqrt(M)

# TE implementation with quantization
def full_moe_te(tokens, routing, up_w, down_w):
"""Complete MoE pipeline with TE einsum."""
dispatched = einsum("BSM,BSEC->EBCM", tokens, routing, fallback=True)
hidden = einsum(
"EBCM,EMH->EBCH",
dispatched,
up_w,
quantizer_sets=mlp_up_quantizer_sets,
quantizer_dim="E",
)
expert_out = einsum(
"EBCH,EHM->EBCM",
hidden,
down_w,
quantizer_sets=mlp_down_quantizer_sets,
quantizer_dim="E",
)
output = einsum("EBCM,BSEC->BSM", expert_out, routing, fallback=True)
return jnp.sum(output)

# Reference implementation with jnp.einsum
def full_moe_ref(tokens, routing, up_w, down_w):
"""Complete MoE pipeline with jnp.einsum."""
dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing)
hidden = jnp.einsum("EBCM,EMH->EBCH", dispatched, up_w)
expert_out = jnp.einsum("EBCH,EHM->EBCM", hidden, down_w)
output = jnp.einsum("EBCM,BSEC->BSM", expert_out, routing)
return jnp.sum(output)

loss_te, grads_te = value_and_grad(full_moe_te, argnums=(0, 1, 2, 3))(
tokens, routing, up_weights, down_weights
)

loss_ref, grads_ref = value_and_grad(full_moe_ref, argnums=(0, 1, 2, 3))(
tokens, routing, up_weights, down_weights
)

# Verify all gradient shapes
        assert grads_te[0].shape == tokens.shape, "tokens grad shape mismatch"
        assert grads_te[1].shape == routing.shape, "routing grad shape mismatch"
        assert grads_te[2].shape == up_weights.shape, "up_weights grad shape mismatch"
        assert grads_te[3].shape == down_weights.shape, "down_weights grad shape mismatch"

# Verify no NaNs or Infs
assert not jnp.isnan(loss_te), "Loss is NaN"
assert jnp.isfinite(loss_te), "Loss is Inf"
assert jnp.all(jnp.isfinite(grads_te[0])), "tokens grad has NaN/Inf"
assert jnp.all(jnp.isfinite(grads_te[1])), "routing grad has NaN/Inf"
assert jnp.all(jnp.isfinite(grads_te[2])), "up_weights grad has NaN/Inf"
assert jnp.all(jnp.isfinite(grads_te[3])), "down_weights grad has NaN/Inf"

# Compare with reference (with FP8 tolerance)
assert_allclose(loss_te, loss_ref, dtype=mlp_up_quantizer_sets[0].x.q_dtype)
assert_allclose(grads_te[0], grads_ref[0], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype)
assert_allclose(grads_te[1], grads_ref[1], dtype=mlp_up_quantizer_sets[0].dgrad.q_dtype)
assert_allclose(grads_te[2], grads_ref[2], dtype=mlp_down_quantizer_sets[0].x.q_dtype)
assert_allclose(grads_te[3], grads_ref[3], dtype=mlp_down_quantizer_sets[0].dgrad.q_dtype)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
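
(For reference, not part of the test file: the fallback dispatch contraction "BSM,BSEC->EBCM" used in the full-MoE test computes dispatched[e, b, c, m] = sum_s tokens[b, s, m] * routing[b, s, e, c]. The sketch below checks the einsum against an explicit broadcast-and-sum formulation on small shapes.)

# Standalone sketch: what "BSM,BSEC->EBCM" computes in the MoE dispatch step.
import jax
import jax.numpy as jnp

B, S, M, E, C = 2, 8, 16, 4, 8
tokens = jax.random.normal(jax.random.PRNGKey(0), (B, S, M))
routing = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(1), (B, S, E, C)), axis=-1)

dispatched = jnp.einsum("BSM,BSEC->EBCM", tokens, routing)

# Explicit broadcast-and-sum over the token axis S, then move E to the front:
# expected[e, b, c, m] = sum_s tokens[b, s, m] * routing[b, s, e, c]
expected = jnp.transpose(
    jnp.sum(routing[..., None] * tokens[:, :, None, None, :], axis=1), (1, 0, 2, 3)
)
assert dispatched.shape == (E, B, C, M)
assert jnp.allclose(dispatched, expected, atol=1e-5)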
36 changes: 36 additions & 0 deletions transformer_engine/jax/cpp_extensions/amax.py
@@ -160,6 +160,18 @@ def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))

@staticmethod
def batcher(batched_args, batch_dims, *, amax_scope, transpose_batch_sequence):
"""Batcher for amax calculation - returns single amax value."""
return AmaxCalculationPrimitive.batcher_impl(
batched_args,
batch_dims,
static_kwargs={
"amax_scope": amax_scope,
"transpose_batch_sequence": transpose_batch_sequence,
},
)


register_primitive(AmaxCalculationPrimitive, outer_only=True)
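
(Aside, not part of this file: the contract the batcher above has to satisfy can be stated with plain JAX. Vmapping a reduction-to-scalar amax over a leading expert axis must yield one amax per batch element, identical to reducing over all non-batch axes, as the sketch below checks.)

# Standalone sketch of the batched-amax semantics (plain JAX, no TE internals).
import jax
import jax.numpy as jnp

def reference_amax(x):
    # amax of a single (unbatched) tensor is its maximum absolute value
    return jnp.max(jnp.abs(x))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 64, 128))  # (E, M, K)
per_expert_amax = jax.vmap(reference_amax)(x)                # shape (E,)
assert per_expert_amax.shape == (4,)
assert jnp.allclose(per_expert_amax, jnp.max(jnp.abs(x), axis=(1, 2)))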

@@ -370,6 +382,30 @@ def shardy_sharding_rule(
output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",)
return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec))

@staticmethod
def batcher(
batched_args,
batch_dims,
*,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""Batcher for RHT amax calculation - returns 2 amax values."""
return RHTAmaxCalculationPrimitive.batcher_impl(
batched_args,
batch_dims,
static_kwargs={
"amax_scope": amax_scope,
"transpose_batch_sequence": transpose_batch_sequence,
"rht_matrix_random_sign_mask_t": rht_matrix_random_sign_mask_t,
"produce_regular_amax": produce_regular_amax,
"flatten_axis": flatten_axis,
},
)


register_primitive(RHTAmaxCalculationPrimitive)
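
(Aside, not part of this file: a batcher with two outputs must report a batch dimension for each output. In the TE-free sketch below, the two reductions are hypothetical stand-ins for the regular and post-RHT amax; vmap gives both outputs a leading expert axis.)

# Standalone sketch: vmap over a function with two scalar outputs, as a
# stand-in for the (amax, post_rht_amax) pair produced by this primitive.
import jax
import jax.numpy as jnp

def two_amaxes(x):
    # Hypothetical stand-ins: a plain amax and the amax of a transformed view
    # (any transform of x would do here).
    return jnp.max(jnp.abs(x)), jnp.max(jnp.abs(x @ x.T))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 64, 128))  # (E, M, K)
amax, post_amax = jax.vmap(two_amaxes)(x)
assert amax.shape == (4,) and post_amax.shape == (4,)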
