Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,33 @@ def install_requirements() -> List[str]:


def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "triton"]
"""Test dependencies for TE/JAX extensions.

Triton Package Selection:
The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable:

Default (NVTE_USE_PYTORCH_TRITON unset or "0"):
Returns 'triton' - OpenAI's standard package from PyPI.
Install with: pip install triton

NVTE_USE_PYTORCH_TRITON=1:
Returns 'pytorch-triton' - for mixed JAX+PyTorch environments.
Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121

Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder.
"""
use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
"1",
"true",
"yes",
)

triton_package = "pytorch-triton" if use_pytorch_triton else "triton"

return [
"numpy",
triton_package,
]


def xla_path() -> str:
Expand Down
14 changes: 12 additions & 2 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@


def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
"""Install dependencies for TE/PyTorch extensions.

IMPORTANT - PyTorch Index Required for pytorch-triton:
These dependencies MUST be installed using PyTorch's package index:

pip install pytorch-triton --index-url https://download.pytorch.org/whl/<version??>

- pytorch-triton is only available from PyTorch's index (not PyPI)
- The 'pytorch-triton' package on PyPI is a placeholder that will fail
- torch.compile() requires pytorch-triton, not OpenAI's 'triton' package
"""
return [
"torch>=2.1",
"einops",
Expand All @@ -22,7 +32,7 @@ def install_requirements() -> List[str]:
"packaging",
"pydantic",
"nvdlfw-inspect",
"triton",
"pytorch-triton",
]


Expand Down
33 changes: 32 additions & 1 deletion transformer_engine/jax/triton_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,33 @@
IMPORTANT: This module requires Triton to be installed. If you don't have Triton,
use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives).

Install Triton: pip install triton

Triton Package Options:
-----------------------
There are two compatible Triton packages:

1. Standard 'triton' from OpenAI (recommended for JAX-only environments):
pip install triton

2. 'pytorch-triton' from PyTorch's index (for mixed JAX+PyTorch environments):
pip install torch --index-url https://download.pytorch.org/whl/cu121
# pytorch-triton is automatically installed as a dependency

Both packages work with JAX Triton kernels. The pytorch-triton package
has version format "X.Y.Z+<commit_sha>" (e.g., "3.0.0+45fff310c8").

WARNING: Do NOT run 'pip install pytorch-triton' directly! The package on PyPI
is a placeholder that will fail with "RuntimeError: Should never be installed".
The real pytorch-triton only comes bundled with PyTorch from PyTorch's index.


Environment Variables:
NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton
for JAX Triton kernels (suppresses compatibility warnings). Set this
when both JAX and PyTorch are installed in the same environment.

Example:
export NVTE_USE_PYTORCH_TRITON=1


Usage:
Expand All @@ -23,6 +49,11 @@ def lowering(ctx, x, **kwargs):

# Use permutation functions
from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map

# Check Triton package info
from transformer_engine.jax.triton_extensions import get_triton_info
info = get_triton_info()
print(f"Using Triton {info['version']} from {info['source']}")
"""

from .utils import *
Expand Down
161 changes: 149 additions & 12 deletions transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

This module provides utility functions for integrating Triton kernels into
JAX primitives. Triton is only imported when this module is used.

Triton Package Compatibility --> see __init__.py
"""

import hashlib
import os
import warnings
from typing import Any, Callable, Mapping
import zlib

Expand All @@ -17,6 +21,102 @@
import jax.numpy as jnp


# Placeholder package version on PyPI that should never be used
_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1"


def _detect_triton_package():
"""Detect which Triton package is installed and validate compatibility.

Returns:
tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool)

The function detects:
- None: Triton not installed
- Standard triton from OpenAI (versions like "3.1.0")
- Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8")
- Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError)
"""
try:
import triton

triton_version = getattr(triton, "__version__", "unknown")
except ImportError:
return None, False, False
except RuntimeError as e:
# The placeholder pytorch-triton package from PyPI raises:
# RuntimeError: "Should never be installed"
if "Should never be installed" in str(e):
return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True
raise

# Check for placeholder package (version 0.0.1 from PyPI)
is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION

# Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8"
is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8

return triton_version, is_pytorch_triton, is_placeholder


def _check_triton_compatibility():
"""Check Triton package compatibility and emit warnings if necessary."""
triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package()

# Handle placeholder package from PyPI
if is_placeholder:
raise ImportError(
"Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n"
"This is NOT a functional Triton installation.\n\n"
"The placeholder package exists to prevent namespace conflicts. To fix this:\n\n"
"Option 1 - Use standard Triton (recommended for JAX-only environments):\n"
" pip uninstall pytorch-triton triton\n"
" pip install triton\n\n"
"Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n"
" pip uninstall pytorch-triton triton\n"
" pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
" # pytorch-triton is automatically installed as a torch dependency\n\n"
"Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n"
"the broken placeholder. The real pytorch-triton only comes from PyTorch's index."
)

if triton_version is None:
raise ImportError(
"Triton is required for transformer_engine.jax.triton_extensions.\n\n"
"Option 1 - Install standard Triton (recommended for JAX-only):\n"
" pip install triton\n\n"
"Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n"
" pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n"
"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
)

use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes")

if is_pytorch_triton:
if use_pytorch_triton_explicit:
# User explicitly opted in - just log info (no warning)
pass # Silent acknowledgment, no warning needed
else:
# pytorch-triton detected but user didn't explicitly opt in
warnings.warn(
f"Detected pytorch-triton package (version {triton_version}) instead of the"
" standard 'triton' package from OpenAI. This typically happens when PyTorch is"
" installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton"
" kernels. To suppress this warning, set:\n export"
" NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n - Use"
" separate virtual environments for JAX and PyTorch, or\n - Use"
" transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)",
category=UserWarning,
stacklevel=3,
)

return triton_version, is_pytorch_triton


# Perform compatibility check and get triton info
_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility()

try:
from jax._src.lib import gpu_triton
from triton.compiler import compiler as tc
Expand All @@ -30,12 +130,35 @@
) from e


__all__ = ["triton_call_lowering"]
__all__ = ["triton_call_lowering", "get_triton_info"]

# Triton kernel cache (module-level, shared across all kernels)
_TRITON_KERNEL_CACHE = {}


def get_triton_info():
"""Get information about the installed Triton package.

Returns:
dict: Dictionary containing:
- version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8")
- is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index
- is_openai_triton (bool): True if using standard triton from OpenAI/PyPI
- env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set
- source (str): "pytorch" or "openai" indicating the package source
"""
use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes")

return {
"version": _TRITON_VERSION,
"is_pytorch_triton": _IS_PYTORCH_TRITON,
"is_openai_triton": not _IS_PYTORCH_TRITON,
"env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON,
"source": "pytorch" if _IS_PYTORCH_TRITON else "openai",
}


def get_triton_dtype(aval):
"""Convert JAX dtype to Triton type string.

Expand Down Expand Up @@ -142,17 +265,31 @@ def compile_triton(
)

# Create kernel object for JAX
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1, # cluster_dims
)
# From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version

if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str)
num_warps, # arg1: num_warps (int)
num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will soon be the same as main. as this change is made here in: #1921, to be merged. it is just in this PR so I can test triton calls locally with the nitghtly jax container without running into errors because of jax 0.8.2+

compute_capability, # arg6: compute_capability (int)
)
else:
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1,
)

_TRITON_KERNEL_CACHE[cache_key] = kernel
return kernel
Expand Down
Loading