Skip to content

Conversation

@copybara-service
Copy link

Add tree-mode jit (graph=False) to nnx.jit

nnx.jit now accepts a graph parameter (default True). When graph=False,
Modules are treated as regular JAX pytrees instead of going through the NNX
graph protocol. This simpler mode behaves more like JAX as it assumes referential transparency (no sharing), it only propagates updates for Variables, roughly matching the expected behavior of mutable Hijax .

Tree-mode removes the need for graph.update_context (the most complex
part of NNX) and the NNX prefix/Lift APIs such as StateAxes,
StateSharding, and DiffState. The implementation is thus much simpler, easier to maintain and optimize.

Limitations of tree-mode (graph=False):

  • Shared Variable references are not supported (e.g. two sub-modules
    pointing to the same Linear layer will raise an error).
  • Mutable array references cannot be returned from the jitted function.
  • Capturing updates from the backward pass and forwarding state updates
    to captured objects is not available.

The existing graph-mode (graph=True) remains the default and fully
backward compatible. Users are not forced to migrate.

nnx.jit now accepts a `graph` parameter (default `True`). When `graph=False`,
Modules are treated as regular JAX pytrees instead of going through the NNX
graph protocol. This simpler mode behaves more like JAX as it assumes referential transparency (no sharing), it only propagates updates for Variables, roughly matching the expected behavior of mutable Hijax .

Tree-mode removes the need for `graph.update_context` (the most complex
part of NNX) and the NNX prefix/Lift APIs such as `StateAxes`,
`StateSharding`, and `DiffState`. The implementation is thus much simpler, easier to maintain and optimize.

Limitations of tree-mode (graph=False):
- Shared Variable references are not supported (e.g. two sub-modules
  pointing to the same Linear layer will raise an error).
- Mutable array references cannot be returned from the jitted function.
- Capturing updates from the backward pass and forwarding state updates
  to captured objects is not available.

The existing graph-mode (graph=True) remains the default and fully
backward compatible. Users are not forced to migrate.

PiperOrigin-RevId: 867790352
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants