2,591 changes: 2,591 additions & 0 deletions tests/jax/ffi_hlo/transformer_stablehlo.txt

Large diffs are not rendered by default.

209 changes: 209 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -2,13 +2,16 @@
#
# See LICENSE for license information.

from io import StringIO
import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator
import os
import re

from utils import (
assert_allclose,
@@ -1921,3 +1924,209 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)


@pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
class TestFFICompatibility:
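"""Checks that StableHLO modules recorded from earlier builds still compile and execute against the current TE FFI registrations, so binding changes that break backward compatibility are caught."""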

HLO_DIR = os.path.join(os.path.dirname(__file__), "ffi_hlo")

# Parametrize over every recorded HLO file so each one gets its own test case.
@pytest.fixture(
name="ffi_hlo_name",
params=[f.split(".")[0] for f in os.listdir(HLO_DIR) if f.endswith(".txt")],
)
def hlo_fixture(self, request):
return request.param

@pytest.mark.skipif(
os.getenv("NVTE_JAX_FFI_HLO_GENERATE", "0") != "1", reason="HLO generation not enabled"
)
def test_generate_hlo(self):
"""Run this test with NVTE_JAX_FFI_HLO_GENERATE=1 to generate StableHLO text files for FFI compatibility tests. Use this when intentionally changing FFI bindings and breaking compatibility changes are required.

Instructions:
1. `CUDA_VISIBLE_DEVICES=0 XLA_FLAGS="$XLA_FLAGS --xla_dump_to=./tests/jax/ffi_hlo_dump" NVTE_JAX_FFI_HLO_GENERATE=1 pytest tests/jax/test_custom_call_compute.py::TestFFICompatibility::test_generate_hlo -s`
2. Find `tests/jax/ffi_hlo_dump/jit_train_step_<some numbers>/module.mlir` and copy it into the `tests/jax/ffi_hlo/` directory as `transformer_stablehlo.txt`.
"""
import math
from transformer_engine.common.recipe import NVFP4BlockScaling, Float8CurrentScaling
from transformer_engine.jax import autocast, MeshResource, softmax
from transformer_engine.jax.flax import TransformerLayer
import flax.linen as nn

with autocast(enabled=True, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):

class Model(nn.Module):
"""This module does not represent any meaningful model, it is just to cover all FFI calls."""

@nn.compact
def __call__(self, x):
# Covers most of the FFI calls
x = TransformerLayer(
hidden_dropout=0.0,
attention_dropout=0.0,
intermediate_dropout=0.0,
dtype=jnp.bfloat16,
)(x)

# Arbitrarily call softmax multiple times to cover all softmax FFI calls
x = x.reshape((1, *x.shape))
x = softmax.softmax(x, softmax_fusion_type=softmax.SoftmaxFusionType.SCALED)
mask1 = self.variable(
"collection",
"mask1",
lambda: jax.random.bernoulli(jax.random.PRNGKey(0), shape=x.shape).astype(
jnp.bfloat16
),
).value.astype(jnp.uint8)
x = softmax.softmax(
x, mask=mask1, softmax_fusion_type=softmax.SoftmaxFusionType.SCALED_MASKED
)
mask2 = self.variable(
"collection",
"mask2",
lambda: (1.0 - jnp.tril(jnp.ones_like(x))).astype(jnp.bfloat16),
).value.astype(jnp.uint8)
x = x.reshape((-1, 1, 32, 32))
x = softmax.softmax(
x,
mask=mask2,
softmax_fusion_type=softmax.SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
)
return x

model = Model()
input_shape = (1, 128, 512)
x = jnp.ones(input_shape, dtype=jnp.bfloat16)

var_collect = model.init(jax.random.PRNGKey(0), x)

def f(var_collect, x):
x = model.apply(var_collect, x, rngs={"sr_rng": jax.random.PRNGKey(0)})
x = jnp.mean(x) # fake loss function for value_and_grad
return x

@jax.jit
def train_step(var_collect, x, grouped_kernel):
loss, grads = jax.value_and_grad(f)(var_collect, x)

# Arbitrarily call grouped quantize and GEMM to cover remaining FFI calls
x = x.reshape((-1, x.shape[-1]))
x = grouped_dense(
x,
grouped_kernel,
contracting_dims=((1,), (1,)),
group_sizes=jnp.array([x.shape[0]], dtype=jnp.int32),
quantizer_set=QuantizerFactory.create_set(
n_groups=1,
fp8_recipe=Float8CurrentScaling(),
quantize_meta_set=QuantizeMetaSet(
QuantizeMeta(), QuantizeMeta(), QuantizeMeta()
),
),
)
loss += jnp.mean(x)

return loss, grads

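# Dummy single-group kernel: its values are irrelevant, since we only need the grouped FFI calls to be traced.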
grouped_kernel = jnp.zeros((1, x.shape[-1], x.shape[-1]), dtype=jnp.bfloat16)
train_step(var_collect, x, grouped_kernel)

def _get_hlo_text_from_file(self, hlo_name: str) -> str:
"""Reads the StableHLO text from a file given its name."""
hlo_file_path = os.path.join(self.HLO_DIR, f"{hlo_name}.txt")
with open(hlo_file_path, "r") as f:
hlo_text = f.read()
return hlo_text

def _make_args_based_on_input_tensor_shape_and_dtype(self, stablehlo_text: str):
"""Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly."""
# Parse function signature to extract argument information
# Pattern matches: @main(%arg0: tensor<32x32xbf16>, %arg1: tensor<64xf32>, ...)
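# For a hypothetical signature "@main(%arg0: tensor<1x128x512xbf16>, %arg1: tensor<512xf32>) {",
# this method returns [jnp.ones((1, 128, 512), bfloat16), jnp.ones((512,), float32)].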
pattern = r"@main\((.*?)\{"
match = re.search(pattern, stablehlo_text)

if not match:
raise ValueError("Could not find @main function signature in StableHLO text")

args_str = match.group(1)

# Parse individual arguments
# Pattern matches: %arg0: tensor<32x32xbf16>
arg_pattern = r"%arg(\d+):\s*tensor<([^>]+)>"
arg_matches = re.findall(arg_pattern, args_str)

parsed_args = []
for arg_num, shape_and_dtype_str in arg_matches:
print(f"Parsing argument {arg_num} with shape and dtype: {shape_and_dtype_str}")
# Parse shape: "32x32xbf16" -> [32, 32]
dtype_str = shape_and_dtype_str.split("x")[-1]
shape = [int(dim) for dim in shape_and_dtype_str.split("x")[:-1]]

# Map StableHLO dtype to JAX dtype
dtype_map = {
"bf16": jnp.bfloat16,
"f32": jnp.float32,
"f16": jnp.float16,
"f8E4M3FN": jnp.float8_e4m3fn,
"f8E5M2": jnp.float8_e5m2,
"i32": jnp.int32,
"ui32": jnp.uint32,
}
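# Note: this map only needs to cover dtypes that appear in the recorded HLO;
# extend it if regenerated dumps introduce new element types.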
dtype = dtype_map.get(dtype_str)
if dtype is None:
raise ValueError(f"Unsupported StableHLO dtype in signature: {dtype_str}")

parsed_args.append(jnp.ones(shape, dtype=dtype))
return parsed_args

def test_ffi_compatibility(self, ffi_hlo_name):
"""Tests that the current FFI bindings are compatible with the provided HLO and there are no API mismatches."""
from jax.extend.backend import get_backend

stablehlo_text = self._get_hlo_text_from_file(ffi_hlo_name)
args = self._make_args_based_on_input_tensor_shape_and_dtype(stablehlo_text)

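# Compiling the recorded StableHLO against the current build should fail if an FFI
# handler was removed or renamed; executing it exercises the handlers end to end.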
client = get_backend("cuda")
executable = client.compile_and_load(
stablehlo_text.encode("utf-8"), executable_devices=jax.devices()[:1]
)
results = executable.execute(args)
print(results)  # Nothing to assert here; we only need it to run without error

def test_all_primitive_ffi_tested(self):
"""Ensures that all our TE primitives with FFI bindings are included in the FFI HLO compatibility tests."""
# Open all HLO files and extract primitive FFI names
tested_hlos = set()
for file in os.listdir(self.HLO_DIR):
file_path = os.path.join(self.HLO_DIR, file)
if os.path.isfile(file_path) and file.endswith(".txt"):
with open(file_path, "r") as f:
hlo_text = f.read()
# Extract primitive name from HLO text
pattern = r"stablehlo.custom_call @(.+?)\("
matches = re.findall(pattern, hlo_text)
if matches:
for match in matches:
primitive_name = match
tested_hlos.add(primitive_name)

# Assert that all registered primitives have corresponding FFI tests
import transformer_engine_jax

KNOWN_MISSING_FFI_TESTS = {
# dequantize does not have a JAX primitive currently
"te_dequantize_ffi",
# needs testing
"te_grouped_gemm_d2h_group_sizes_ffi",
}

unmatched_primitives = set()
for primitive_ffi_name in transformer_engine_jax.registrations():
if (
primitive_ffi_name not in tested_hlos
and primitive_ffi_name not in KNOWN_MISSING_FFI_TESTS
):
unmatched_primitives.add(primitive_ffi_name)

assert (
len(unmatched_primitives) == 0
), f"The following primitives do not have FFI tests: {unmatched_primitives}"