From de40a714fc680bbc344a2009c3afabe187476120 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:19:36 -0800 Subject: [PATCH 1/6] Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax Signed-off-by: tdophung --- .../jax/triton_extensions/__init__.py | 33 ++- .../jax/triton_extensions/utils.py | 201 ++++++++++++++++-- 2 files changed, 221 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index 13a36421bf1..e0aa956fd30 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -9,7 +9,33 @@ IMPORTANT: This module requires Triton to be installed. If you don't have Triton, use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives). -Install Triton: pip install triton + +Triton Package Options: +----------------------- +There are two compatible Triton packages: + +1. Standard 'triton' from OpenAI (recommended for JAX-only environments): + pip install triton + +2. 'pytorch-triton' from PyTorch's index (for mixed JAX+PyTorch environments): + pip install torch --index-url https://download.pytorch.org/whl/cu121 + # pytorch-triton is automatically installed as a dependency + + Both packages work with JAX Triton kernels. The pytorch-triton package + has version format "X.Y.Z+" (e.g., "3.0.0+45fff310c8"). + +WARNING: Do NOT run 'pip install pytorch-triton' directly! The package on PyPI +is a placeholder that will fail with "RuntimeError: Should never be installed". +The real pytorch-triton only comes bundled with PyTorch from PyTorch's index. + + +Environment Variables: + NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton + for JAX Triton kernels (suppresses compatibility warnings). Set this + when both JAX and PyTorch are installed in the same environment. + + Example: + export NVTE_USE_PYTORCH_TRITON=1 Usage: @@ -23,6 +49,11 @@ def lowering(ctx, x, **kwargs): # Use permutation functions from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map + + # Check Triton package info + from transformer_engine.jax.triton_extensions import get_triton_info + info = get_triton_info() + print(f"Using Triton {info['version']} from {info['source']}") """ from .utils import * diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 12d6a9e3de4..97ffc7ba7ff 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -6,9 +6,33 @@ This module provides utility functions for integrating Triton kernels into JAX primitives. Triton is only imported when this module is used. + +Triton Package Compatibility: + There are two Triton packages that can be used: + + 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. + Install with: pip install triton + + 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes + PyTorch-specific patches. Version format: "3.0.0+" + + IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a + placeholder that will NOT work. The real pytorch-triton is only available + from PyTorch's package index and is auto-installed with PyTorch: + pip install torch --index-url https://download.pytorch.org/whl/cu121 + + pytorch-triton has been tested to work with JAX Triton kernels. + +Environment Variables: + NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using + pytorch-triton for JAX Triton kernels (suppresses warnings). This is + useful when both JAX and PyTorch are installed in the same environment. + Default is "0". """ import hashlib +import os +import warnings from typing import Any, Callable, Mapping import zlib @@ -17,6 +41,115 @@ import jax.numpy as jnp +# Placeholder package version on PyPI that should never be used +_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1" + + +def _detect_triton_package(): + """Detect which Triton package is installed and validate compatibility. + + Returns: + tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool) + + The function detects: + - None: Triton not installed + - Standard triton from OpenAI (versions like "3.1.0") + - Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8") + - Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError) + """ + try: + import triton + triton_version = getattr(triton, "__version__", "unknown") + except ImportError: + return None, False, False + except RuntimeError as e: + # The placeholder pytorch-triton package from PyPI raises: + # RuntimeError: "Should never be installed" + if "Should never be installed" in str(e): + return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True + raise + + # Check for placeholder package (version 0.0.1 from PyPI) + is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION + + # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8" + is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8 + + return triton_version, is_pytorch_triton, is_placeholder + + +def _check_triton_compatibility(): + """Check Triton package compatibility and emit warnings if necessary. + + This function handles the case where both JAX and PyTorch may be installed, + each expecting different Triton packages: + - JAX typically uses the standard 'triton' package from OpenAI + - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs + + The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly + acknowledge using pytorch-triton with JAX (suppresses warnings). + + Raises: + ImportError: If triton is not installed or the placeholder package is detected. + """ + triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() + + # Handle placeholder package from PyPI + if is_placeholder: + raise ImportError( + "Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n" + "This is NOT a functional Triton installation.\n\n" + "The placeholder package exists to prevent namespace conflicts. To fix this:\n\n" + "Option 1 - Use standard Triton (recommended for JAX-only environments):\n" + " pip uninstall pytorch-triton triton\n" + " pip install triton\n\n" + "Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n" + " pip uninstall pytorch-triton triton\n" + " pip install torch --index-url https://download.pytorch.org/whl/cu121\n" + " # pytorch-triton is automatically installed as a torch dependency\n\n" + "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n" + "the broken placeholder. The real pytorch-triton only comes from PyTorch's index." + ) + + if triton_version is None: + raise ImportError( + "Triton is required for transformer_engine.jax.triton_extensions.\n\n" + "Option 1 - Install standard Triton (recommended for JAX-only):\n" + " pip install triton\n\n" + "Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n" + " pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n" + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() + use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") + + if is_pytorch_triton: + if use_pytorch_triton_explicit: + # User explicitly opted in - just log info (no warning) + pass # Silent acknowledgment, no warning needed + else: + # pytorch-triton detected but user didn't explicitly opt in + warnings.warn( + f"Detected pytorch-triton package (version {triton_version}) instead of " + f"the standard 'triton' package from OpenAI. This typically happens when " + f"PyTorch is installed alongside JAX.\n\n" + f"pytorch-triton is compatible with JAX Triton kernels. To suppress this " + f"warning, set:\n" + f" export NVTE_USE_PYTORCH_TRITON=1\n\n" + f"Alternatively, for a JAX-only environment:\n" + f" - Use separate virtual environments for JAX and PyTorch, or\n" + f" - Use transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", + category=UserWarning, + stacklevel=3, + ) + + return triton_version, is_pytorch_triton + + +# Perform compatibility check and get triton info +_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility() + try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc @@ -30,12 +163,42 @@ ) from e -__all__ = ["triton_call_lowering"] +__all__ = ["triton_call_lowering", "get_triton_info"] # Triton kernel cache (module-level, shared across all kernels) _TRITON_KERNEL_CACHE = {} +def get_triton_info(): + """Get information about the installed Triton package. + + Returns: + dict: Dictionary containing: + - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8") + - is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index + - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI + - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set + - source (str): "pytorch" or "openai" indicating the package source + + Example: + >>> from transformer_engine.jax.triton_extensions import get_triton_info + >>> info = get_triton_info() + >>> print(f"Triton version: {info['version']} (from {info['source']})") + >>> if info['is_pytorch_triton']: + ... print("Using pytorch-triton - compatible with both PyTorch and JAX") + """ + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() + env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") + + return { + "version": _TRITON_VERSION, + "is_pytorch_triton": _IS_PYTORCH_TRITON, + "is_openai_triton": not _IS_PYTORCH_TRITON, + "env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON, + "source": "pytorch" if _IS_PYTORCH_TRITON else "openai", + } + + def get_triton_dtype(aval): """Convert JAX dtype to Triton type string. @@ -142,17 +305,31 @@ def compile_triton( ) # Create kernel object for JAX - kernel = gpu_triton.TritonKernel( - compiled.name, - num_warps, - compiled.metadata.shared, - compiled.asm["ptx"], - "", # ttir - compute_capability, - 1, - 1, - 1, # cluster_dims - ) + # From jax/jaxlib/gpu/triton_kernels.cc: + from packaging import version + + if version.parse(jax.__version__) >= version.parse("0.8.2"): + kernel = gpu_triton.TritonKernel( + compiled.name, # arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) + ) + else: + kernel = gpu_triton.TritonKernel( + compile.name, + num_warps, + compiled.metadata.shared, + compiled.asm["ptx"], + "", # ttir + compute_capability, + 1, + 1, + 1, + ) _TRITON_KERNEL_CACHE[cache_key] = kernel return kernel From d832ff8bc66e11922460fffa8884989d6aee8bac Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:48:31 -0800 Subject: [PATCH 2/6] change build requirements and installation to reflect new option Signed-off-by: tdophung --- build_tools/jax.py | 25 +++++++++++++++++++++++-- build_tools/pytorch.py | 14 ++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index df78bf3e2f3..14612f57f78 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -19,8 +19,29 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: - """Test dependencies for TE/JAX extensions.""" - return ["numpy", "triton"] + """Test dependencies for TE/JAX extensions. + + Triton Package Selection: + The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable: + + Default (NVTE_USE_PYTORCH_TRITON unset or "0"): + Returns 'triton' - OpenAI's standard package from PyPI. + Install with: pip install triton + + NVTE_USE_PYTORCH_TRITON=1: + Returns 'pytorch-triton' - for mixed JAX+PyTorch environments. + Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121 + + Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. + """ + use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes") + + triton_package = "pytorch-triton" if use_pytorch_triton else "triton" + + return [ + "numpy", + triton_package, + ] def xla_path() -> str: diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index b03ef04fa42..ebc32de551f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -13,7 +13,17 @@ def install_requirements() -> List[str]: - """Install dependencies for TE/PyTorch extensions.""" + """Install dependencies for TE/PyTorch extensions. + + IMPORTANT - PyTorch Index Required for pytorch-triton: + These dependencies MUST be installed using PyTorch's package index: + + pip install pytorch-triton --index-url https://download.pytorch.org/whl/ + + - pytorch-triton is only available from PyTorch's index (not PyPI) + - The 'pytorch-triton' package on PyPI is a placeholder that will fail + - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package + """ return [ "torch>=2.1", "einops", @@ -22,7 +32,7 @@ def install_requirements() -> List[str]: "packaging", "pydantic", "nvdlfw-inspect", - "triton", + "pytorch-triton", ] From 49f097604502bc54eaad057736089bc304b33730 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:54:53 -0800 Subject: [PATCH 3/6] reduce boilerplate comments Signed-off-by: tdophung --- .../jax/triton_extensions/utils.py | 43 +------------------ 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 97ffc7ba7ff..018ed2e0743 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -7,27 +7,7 @@ This module provides utility functions for integrating Triton kernels into JAX primitives. Triton is only imported when this module is used. -Triton Package Compatibility: - There are two Triton packages that can be used: - - 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. - Install with: pip install triton - - 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes - PyTorch-specific patches. Version format: "3.0.0+" - - IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a - placeholder that will NOT work. The real pytorch-triton is only available - from PyTorch's package index and is auto-installed with PyTorch: - pip install torch --index-url https://download.pytorch.org/whl/cu121 - - pytorch-triton has been tested to work with JAX Triton kernels. - -Environment Variables: - NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using - pytorch-triton for JAX Triton kernels (suppresses warnings). This is - useful when both JAX and PyTorch are installed in the same environment. - Default is "0". +Triton Package Compatibility --> see __init__.py """ import hashlib @@ -79,19 +59,7 @@ def _detect_triton_package(): def _check_triton_compatibility(): - """Check Triton package compatibility and emit warnings if necessary. - - This function handles the case where both JAX and PyTorch may be installed, - each expecting different Triton packages: - - JAX typically uses the standard 'triton' package from OpenAI - - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs - - The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly - acknowledge using pytorch-triton with JAX (suppresses warnings). - - Raises: - ImportError: If triton is not installed or the placeholder package is detected. - """ + """Check Triton package compatibility and emit warnings if necessary.""" triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() # Handle placeholder package from PyPI @@ -179,13 +147,6 @@ def get_triton_info(): - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set - source (str): "pytorch" or "openai" indicating the package source - - Example: - >>> from transformer_engine.jax.triton_extensions import get_triton_info - >>> info = get_triton_info() - >>> print(f"Triton version: {info['version']} (from {info['source']})") - >>> if info['is_pytorch_triton']: - ... print("Using pytorch-triton - compatible with both PyTorch and JAX") """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") From d2371e791791f59124e4e663b7299650e46c84a5 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:56:35 -0800 Subject: [PATCH 4/6] format code Signed-off-by: tdophung --- build_tools/pytorch.py | 6 +++--- transformer_engine/jax/triton_extensions/__init__.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index ebc32de551f..19abd7d8293 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,12 +14,12 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions. - + IMPORTANT - PyTorch Index Required for pytorch-triton: These dependencies MUST be installed using PyTorch's package index: - + pip install pytorch-triton --index-url https://download.pytorch.org/whl/ - + - pytorch-triton is only available from PyTorch's index (not PyPI) - The 'pytorch-triton' package on PyPI is a placeholder that will fail - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index e0aa956fd30..6e635d3a2a2 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -33,7 +33,7 @@ NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton for JAX Triton kernels (suppresses compatibility warnings). Set this when both JAX and PyTorch are installed in the same environment. - + Example: export NVTE_USE_PYTORCH_TRITON=1 @@ -49,7 +49,7 @@ def lowering(ctx, x, **kwargs): # Use permutation functions from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map - + # Check Triton package info from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() From 43c34506fc9743c4de10904bc0915dd66e87e6d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 01:18:39 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- build_tools/jax.py | 18 +++-- build_tools/pytorch.py | 6 +- .../jax/triton_extensions/__init__.py | 4 +- .../jax/triton_extensions/utils.py | 65 +++++++++---------- 4 files changed, 48 insertions(+), 45 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 14612f57f78..ec0b4aaef45 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -20,24 +20,28 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions. - + Triton Package Selection: The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable: - + Default (NVTE_USE_PYTORCH_TRITON unset or "0"): Returns 'triton' - OpenAI's standard package from PyPI. Install with: pip install triton - + NVTE_USE_PYTORCH_TRITON=1: Returns 'pytorch-triton' - for mixed JAX+PyTorch environments. Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121 - + Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. """ - use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes") - + use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( + "1", + "true", + "yes", + ) + triton_package = "pytorch-triton" if use_pytorch_triton else "triton" - + return [ "numpy", triton_package, diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index ebc32de551f..19abd7d8293 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,12 +14,12 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions. - + IMPORTANT - PyTorch Index Required for pytorch-triton: These dependencies MUST be installed using PyTorch's package index: - + pip install pytorch-triton --index-url https://download.pytorch.org/whl/ - + - pytorch-triton is only available from PyTorch's index (not PyPI) - The 'pytorch-triton' package on PyPI is a placeholder that will fail - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index e0aa956fd30..6e635d3a2a2 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -33,7 +33,7 @@ NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton for JAX Triton kernels (suppresses compatibility warnings). Set this when both JAX and PyTorch are installed in the same environment. - + Example: export NVTE_USE_PYTORCH_TRITON=1 @@ -49,7 +49,7 @@ def lowering(ctx, x, **kwargs): # Use permutation functions from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map - + # Check Triton package info from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 97ffc7ba7ff..b8b8f5b5bc8 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -9,23 +9,23 @@ Triton Package Compatibility: There are two Triton packages that can be used: - + 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. Install with: pip install triton - + 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes PyTorch-specific patches. Version format: "3.0.0+" - - IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a + + IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a placeholder that will NOT work. The real pytorch-triton is only available from PyTorch's package index and is auto-installed with PyTorch: pip install torch --index-url https://download.pytorch.org/whl/cu121 - + pytorch-triton has been tested to work with JAX Triton kernels. Environment Variables: - NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using - pytorch-triton for JAX Triton kernels (suppresses warnings). This is + NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using + pytorch-triton for JAX Triton kernels (suppresses warnings). This is useful when both JAX and PyTorch are installed in the same environment. Default is "0". """ @@ -47,10 +47,10 @@ def _detect_triton_package(): """Detect which Triton package is installed and validate compatibility. - + Returns: tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool) - + The function detects: - None: Triton not installed - Standard triton from OpenAI (versions like "3.1.0") @@ -59,6 +59,7 @@ def _detect_triton_package(): """ try: import triton + triton_version = getattr(triton, "__version__", "unknown") except ImportError: return None, False, False @@ -68,32 +69,32 @@ def _detect_triton_package(): if "Should never be installed" in str(e): return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True raise - + # Check for placeholder package (version 0.0.1 from PyPI) is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION - + # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8" is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8 - + return triton_version, is_pytorch_triton, is_placeholder def _check_triton_compatibility(): """Check Triton package compatibility and emit warnings if necessary. - + This function handles the case where both JAX and PyTorch may be installed, each expecting different Triton packages: - JAX typically uses the standard 'triton' package from OpenAI - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs - + The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly acknowledge using pytorch-triton with JAX (suppresses warnings). - + Raises: ImportError: If triton is not installed or the placeholder package is detected. """ triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() - + # Handle placeholder package from PyPI if is_placeholder: raise ImportError( @@ -110,7 +111,7 @@ def _check_triton_compatibility(): "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n" "the broken placeholder. The real pytorch-triton only comes from PyTorch's index." ) - + if triton_version is None: raise ImportError( "Triton is required for transformer_engine.jax.triton_extensions.\n\n" @@ -120,10 +121,10 @@ def _check_triton_compatibility(): " pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n" "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) - + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") - + if is_pytorch_triton: if use_pytorch_triton_explicit: # User explicitly opted in - just log info (no warning) @@ -131,19 +132,17 @@ def _check_triton_compatibility(): else: # pytorch-triton detected but user didn't explicitly opt in warnings.warn( - f"Detected pytorch-triton package (version {triton_version}) instead of " - f"the standard 'triton' package from OpenAI. This typically happens when " - f"PyTorch is installed alongside JAX.\n\n" - f"pytorch-triton is compatible with JAX Triton kernels. To suppress this " - f"warning, set:\n" - f" export NVTE_USE_PYTORCH_TRITON=1\n\n" - f"Alternatively, for a JAX-only environment:\n" - f" - Use separate virtual environments for JAX and PyTorch, or\n" - f" - Use transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", + f"Detected pytorch-triton package (version {triton_version}) instead of the" + " standard 'triton' package from OpenAI. This typically happens when PyTorch is" + " installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton" + " kernels. To suppress this warning, set:\n export" + " NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n - Use" + " separate virtual environments for JAX and PyTorch, or\n - Use" + " transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", category=UserWarning, stacklevel=3, ) - + return triton_version, is_pytorch_triton @@ -171,7 +170,7 @@ def _check_triton_compatibility(): def get_triton_info(): """Get information about the installed Triton package. - + Returns: dict: Dictionary containing: - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8") @@ -179,7 +178,7 @@ def get_triton_info(): - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set - source (str): "pytorch" or "openai" indicating the package source - + Example: >>> from transformer_engine.jax.triton_extensions import get_triton_info >>> info = get_triton_info() @@ -189,7 +188,7 @@ def get_triton_info(): """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") - + return { "version": _TRITON_VERSION, "is_pytorch_triton": _IS_PYTORCH_TRITON, @@ -305,7 +304,7 @@ def compile_triton( ) # Create kernel object for JAX - # From jax/jaxlib/gpu/triton_kernels.cc: + # From jax/jaxlib/gpu/triton_kernels.cc: from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): From ba3e1f60148fb92078730a9e31e24cc9c9dcaefb Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 18:37:46 -0800 Subject: [PATCH 6/6] fix typo Signed-off-by: tdophung --- build_tools/jax.py | 18 +++++--- .../jax/triton_extensions/utils.py | 45 +++++++++---------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 14612f57f78..ec0b4aaef45 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -20,24 +20,28 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions. - + Triton Package Selection: The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable: - + Default (NVTE_USE_PYTORCH_TRITON unset or "0"): Returns 'triton' - OpenAI's standard package from PyPI. Install with: pip install triton - + NVTE_USE_PYTORCH_TRITON=1: Returns 'pytorch-triton' - for mixed JAX+PyTorch environments. Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121 - + Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. """ - use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes") - + use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( + "1", + "true", + "yes", + ) + triton_package = "pytorch-triton" if use_pytorch_triton else "triton" - + return [ "numpy", triton_package, diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 018ed2e0743..361e17ef3ee 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -27,10 +27,10 @@ def _detect_triton_package(): """Detect which Triton package is installed and validate compatibility. - + Returns: tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool) - + The function detects: - None: Triton not installed - Standard triton from OpenAI (versions like "3.1.0") @@ -39,6 +39,7 @@ def _detect_triton_package(): """ try: import triton + triton_version = getattr(triton, "__version__", "unknown") except ImportError: return None, False, False @@ -48,20 +49,20 @@ def _detect_triton_package(): if "Should never be installed" in str(e): return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True raise - + # Check for placeholder package (version 0.0.1 from PyPI) is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION - + # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8" is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8 - + return triton_version, is_pytorch_triton, is_placeholder def _check_triton_compatibility(): """Check Triton package compatibility and emit warnings if necessary.""" triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() - + # Handle placeholder package from PyPI if is_placeholder: raise ImportError( @@ -78,7 +79,7 @@ def _check_triton_compatibility(): "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n" "the broken placeholder. The real pytorch-triton only comes from PyTorch's index." ) - + if triton_version is None: raise ImportError( "Triton is required for transformer_engine.jax.triton_extensions.\n\n" @@ -88,10 +89,10 @@ def _check_triton_compatibility(): " pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n" "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) - + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") - + if is_pytorch_triton: if use_pytorch_triton_explicit: # User explicitly opted in - just log info (no warning) @@ -99,19 +100,17 @@ def _check_triton_compatibility(): else: # pytorch-triton detected but user didn't explicitly opt in warnings.warn( - f"Detected pytorch-triton package (version {triton_version}) instead of " - f"the standard 'triton' package from OpenAI. This typically happens when " - f"PyTorch is installed alongside JAX.\n\n" - f"pytorch-triton is compatible with JAX Triton kernels. To suppress this " - f"warning, set:\n" - f" export NVTE_USE_PYTORCH_TRITON=1\n\n" - f"Alternatively, for a JAX-only environment:\n" - f" - Use separate virtual environments for JAX and PyTorch, or\n" - f" - Use transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", + f"Detected pytorch-triton package (version {triton_version}) instead of the" + " standard 'triton' package from OpenAI. This typically happens when PyTorch is" + " installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton" + " kernels. To suppress this warning, set:\n export" + " NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n - Use" + " separate virtual environments for JAX and PyTorch, or\n - Use" + " transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", category=UserWarning, stacklevel=3, ) - + return triton_version, is_pytorch_triton @@ -139,7 +138,7 @@ def _check_triton_compatibility(): def get_triton_info(): """Get information about the installed Triton package. - + Returns: dict: Dictionary containing: - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8") @@ -150,7 +149,7 @@ def get_triton_info(): """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") - + return { "version": _TRITON_VERSION, "is_pytorch_triton": _IS_PYTORCH_TRITON, @@ -266,7 +265,7 @@ def compile_triton( ) # Create kernel object for JAX - # From jax/jaxlib/gpu/triton_kernels.cc: + # From jax/jaxlib/gpu/triton_kernels.cc: from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): @@ -281,7 +280,7 @@ def compile_triton( ) else: kernel = gpu_triton.TritonKernel( - compile.name, + compiled.name, num_warps, compiled.metadata.shared, compiled.asm["ptx"],