2 changes: 1 addition & 1 deletion README.md
@@ -158,7 +158,7 @@ To cite this repository:
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
- version = {0.12.3},
+ version = {0.12.4},
year = {2024},
}
```
4 changes: 2 additions & 2 deletions docs_nnx/flip/4844-var-eager-sharding.md
@@ -56,12 +56,12 @@ with jax.set_mesh(mesh):
...
```

- For JAX explicit mode, remove the `sharding_names=` annotation on the `nnx.Variable`.
+ For JAX explicit mode, remove the `out_sharding=` annotation on the `nnx.Variable`.
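For illustration, a hedged sketch of annotating at the JAX level instead; `jax.device_put` with a `NamedSharding` is one generic option, and the mesh shape and axis names below are assumptions rather than part of this FLIP:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Hypothetical single-axis mesh for illustration; explicit-mode setups
# additionally configure axis types at mesh creation.
mesh = jax.make_mesh((1,), ('model',))
# Annotate the sharding at the JAX level rather than on the nnx.Variable.
x = jax.device_put(jnp.ones((4, 8)), NamedSharding(mesh, PartitionSpec(None, 'model')))
```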


# Implementation
[implementation]: #implementation

- When an `nnx.Variable` is created, check for the metadata `sharding_names`; if present, check whether it is under a valid global mesh context or was supplied with a valid mesh. If not, throw an error; if so, call `jax.lax.with_sharding_constraint` to apply a sharding constraint to the value.
+ When an `nnx.Variable` is created, check for the metadata `out_sharding`; if present, check whether it is under a valid global mesh context or was supplied with a valid mesh. If not, throw an error; if so, call `jax.lax.with_sharding_constraint` to apply a sharding constraint to the value.

Note that this only works in auto sharding mode. Users should use JAX-level APIs to annotate shardings in explicit mode.
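As a rough sketch of the flow described above (hypothetical helper name and structure, not Flax's actual internals):

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec

def maybe_eager_shard(value, out_sharding=None, mesh=None):
  # Hypothetical sketch of the check described above; not Flax's real code.
  if not out_sharding:
    return value  # no annotation: nothing to do
  if mesh is None:
    # Neither a valid global mesh context nor mesh metadata was supplied.
    raise ValueError('out_sharding requires a valid mesh (context or metadata).')
  # In auto sharding mode, eagerly constrain the value's sharding.
  spec = PartitionSpec(*out_sharding)
  return jax.lax.with_sharding_constraint(value, NamedSharding(mesh, spec))
```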
6 changes: 3 additions & 3 deletions docs_nnx/guides/flax_gspmd.ipynb
@@ -123,7 +123,7 @@
"metadata": {},
"outputs": [],
"source": [
"nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)"
"nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)"
]
},
{
@@ -134,7 +134,7 @@
"\n",
"Let's begin by sharding the simplest component possible - a Flax variable.\n",
"\n",
"When you define a Flax variable, you can pass in a metadata field called `sharding_names`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n",
"When you define a Flax variable, you can pass in a metadata field called `out_sharding`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n",
"\n",
"**You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook."
]
@@ -191,7 +191,7 @@
"with jax.set_mesh(auto_mesh):\n",
" w = nnx.Param(\n",
" rngs.lecun_normal()((4, 8)),\n",
" sharding_names=(None, 'model')\n",
" out_sharding=(None, 'model')\n",
" )\n",
" print(w.sharding.spec)\n",
" jax.debug.visualize_array_sharding(w) # already sharded!"
6 changes: 3 additions & 3 deletions docs_nnx/guides/flax_gspmd.md
@@ -64,14 +64,14 @@ with nnx.use_eager_sharding(False):
You can also enable or disable eager sharding on a per-variable basis by passing `eager_sharding=True` or `eager_sharding=False` during variable initialization. The mesh can also be passed this way.

```{code-cell} ipython3
- nnx.Param(jnp.ones((4, 4)), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)
+ nnx.Param(jnp.ones((4, 4)), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)
```

## Shard a single-array model

Let's begin by sharding the simplest component possible - a Flax variable.

- When you define a Flax variable, you can pass in a metadata field called `sharding_names` to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refers to how an axis of the array should be sharded.
+ When you define a Flax variable, you can pass in a metadata field called `out_sharding` to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refers to how an axis of the array should be sharded.

**You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the resulting variable to be sharded accordingly across those devices. The device mesh can be your actual accelerator mesh, or a dummy CPU mesh like in this notebook.

@@ -81,7 +81,7 @@ rngs = nnx.Rngs(0)
with jax.set_mesh(auto_mesh):
w = nnx.Param(
rngs.lecun_normal()((4, 8)),
- sharding_names=(None, 'model')
+ out_sharding=(None, 'model')
)
print(w.sharding.spec)
jax.debug.visualize_array_sharding(w) # already sharded!
20 changes: 10 additions & 10 deletions docs_nnx/guides/transforms.ipynb
@@ -815,30 +815,30 @@
"output_type": "stream",
"text": [
"Inner m.param.shape = (3, 5)\n",
"Inner m.param.sharding_names = ('a', None)\n",
"Inner m.param.out_sharding = ('a', None)\n",
"Outter m.param.shape = (3, 4, 5)\n",
"Outter m.param.sharding_names = ('a', 'b', None)\n"
"Outter m.param.out_sharding = ('a', 'b', None)\n"
]
}
],
"source": [
"mesh = jax.make_mesh((1, 1), ('a', 'b'))\n",
"\n",
"class Weights(nnx.Module):\n",
" def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):\n",
" self.param = nnx.Param(array, sharding_names=sharding_names)\n",
" def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):\n",
" self.param = nnx.Param(array, out_sharding=out_sharding)\n",
"\n",
"@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})\n",
"def f(m: Weights):\n",
" print(f'Inner {m.param.shape = }')\n",
" print(f'Inner {m.param.sharding_names = }')\n",
" print(f'Inner {m.param.out_sharding = }')\n",
"\n",
"with jax.set_mesh(mesh):\n",
" m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))\n",
" m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))\n",
" f(m)\n",
"\n",
"print(f'Outter {m.param.shape = }')\n",
"print(f'Outter {m.param.sharding_names = }')"
"print(f'Outter {m.param.out_sharding = }')"
]
},
{
@@ -862,19 +862,19 @@
"output_type": "stream",
"text": [
"Outter m.param.shape = (3, 4, 5)\n",
"Outter m.param.sharding_names = ('a', 'b', None)\n"
"Outter m.param.out_sharding = ('a', 'b', None)\n"
]
}
],
"source": [
"@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})\n",
"def init_vmap():\n",
" return Weights(jnp.ones((3, 5)), sharding_names=('a', None))\n",
" return Weights(jnp.ones((3, 5)), out_sharding=('a', None))\n",
"\n",
"with jax.set_mesh(mesh):\n",
" m = init_vmap()\n",
"print(f'Outter {m.param.shape = }')\n",
"print(f'Outter {m.param.sharding_names = }')"
"print(f'Outter {m.param.out_sharding = }')"
]
}
],
14 changes: 7 additions & 7 deletions docs_nnx/guides/transforms.md
@@ -391,20 +391,20 @@ Let's see an example of this in action:
mesh = jax.make_mesh((1, 1), ('a', 'b'))

class Weights(nnx.Module):
- def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):
- self.param = nnx.Param(array, sharding_names=sharding_names)
+ def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):
+ self.param = nnx.Param(array, out_sharding=out_sharding)

@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
print(f'Inner {m.param.shape = }')
- print(f'Inner {m.param.sharding_names = }')
+ print(f'Inner {m.param.out_sharding = }')

with jax.set_mesh(mesh):
- m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))
+ m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))
f(m)

print(f'Outer {m.param.shape = }')
- print(f'Outer {m.param.sharding_names = }')
+ print(f'Outer {m.param.out_sharding = }')
```

Here, you added sharding metadata to the `nnx.Param` variables and used `transform_metadata` to update that metadata to reflect the axis changes. Specifically, you can see that the axis name `b` was removed from the sharding metadata inside `nnx.vmap`, and then added back outside of `nnx.vmap`.
@@ -414,10 +414,10 @@ You can verify that this also works when `nnx.Module`s are created inside the tr
```{code-cell} ipython3
@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
- return Weights(jnp.ones((3, 5)), sharding_names=('a', None))
+ return Weights(jnp.ones((3, 5)), out_sharding=('a', None))

with jax.set_mesh(mesh):
m = init_vmap()
print(f'Outer {m.param.shape = }')
- print(f'Outer {m.param.sharding_names = }')
+ print(f'Outer {m.param.out_sharding = }')
```
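For intuition, the metadata bookkeeping shown above amounts to inserting or removing a single axis name at the vmapped index. A minimal hypothetical sketch (`insert_axis_name`/`remove_axis_name` are illustrative stand-ins, not the actual Flax helpers):

```python
def insert_axis_name(fields: tuple, index: int, name) -> tuple:
  # Add the mapped axis name back at `index` (leaving vmap).
  lst = list(fields)
  lst.insert(index, name)
  return tuple(lst)

def remove_axis_name(fields: tuple, index: int, name) -> tuple:
  # Drop the mapped axis name at `index` (entering vmap).
  lst = list(fields)
  assert lst.pop(index) == name, 'axis name mismatch'
  return tuple(lst)

# Mirrors the example above: axis name 'b' at index 1 disappears inside vmap.
assert remove_axis_name(('a', 'b', None), 1, 'b') == ('a', None)
assert insert_axis_name(('a', None), 1, 'b') == ('a', 'b', None)
```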
6 changes: 3 additions & 3 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -56,15 +56,15 @@ class MLP(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.w1 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
- sharding_names=mesh_rules('embed', 'mlp'),
+ out_sharding=mesh_rules('embed', 'mlp'),
)
self.b1 = nnx.Param(
jnp.zeros((dmid,)),
- sharding_names=mesh_rules('mlp'),
+ out_sharding=mesh_rules('mlp'),
)
self.w2 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
- sharding_names=mesh_rules('embed', 'mlp'),
+ out_sharding=mesh_rules('embed', 'mlp'),
)

def __call__(self, x: jax.Array):
4 changes: 2 additions & 2 deletions flax/core/meta.py
@@ -297,13 +297,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = dict(vars(self))
- metadata['sharding_names'] = metadata.pop('names')
+ metadata['out_sharding'] = metadata.pop('names')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
- metadata['names'] = metadata.pop('sharding_names')
+ metadata['names'] = metadata.pop('out_sharding')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})

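A hedged illustration of the key-renaming round trip these two methods perform, using plain dicts rather than the actual classes:

```python
# Illustration only: `nn.Partitioned.names` corresponds to the
# `nnx.Variable` metadata key 'out_sharding'.
linen_meta = {'names': (None, 'model'), 'mesh': None}

nnx_meta = dict(linen_meta)
nnx_meta['out_sharding'] = nnx_meta.pop('names')   # as in to_nnx_metadata
assert nnx_meta == {'mesh': None, 'out_sharding': (None, 'model')}

back = dict(nnx_meta)
back['names'] = back.pop('out_sharding')           # as in from_nnx_metadata
assert back == linen_meta
```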
15 changes: 7 additions & 8 deletions flax/core/spmd.py
@@ -24,13 +24,13 @@
Sharding,
)

- def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
+ def get_pspec(sharding, sharding_rules = None) -> PartitionSpec:
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
if get_logical_axis_rules() or sharding_rules:
context_rules = get_logical_axis_rules()
rules = composite_rules(context_rules, sharding_rules)
- return PartitionSpec(*from_sharding_rules(sharding_names, rules))
- return PartitionSpec(*sharding_names)
+ return PartitionSpec(*from_sharding_rules(sharding, rules))
+ return PartitionSpec(*sharding)

def _apply_sharding(value, sharding, mesh):
if mesh.are_all_axes_explicit:
@@ -44,10 +44,9 @@ def _apply_sharding(value, sharding, mesh):


def shard_value(
- value, sharding_names, sharding_rules,
- mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None
+ value, sharding, sharding_rules, mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None
):
- if not sharding_names:
+ if not sharding:
return value

if mesh is None:
@@ -56,9 +55,9 @@ def shard_value(
if mesh is None:
raise ValueError(
'An auto mesh context or metadata is required if creating a variable'
- f' with annotation {sharding_names=}. '
+ f' with annotation {sharding=}. '
'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
- pspec = get_pspec(sharding_names, sharding_rules)
+ pspec = get_pspec(sharding, sharding_rules)
return _apply_sharding(value, NamedSharding(mesh, pspec), mesh)


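For intuition, `from_sharding_rules` (defined outside this hunk) resolves logical axis names to mesh axis names before building the `PartitionSpec`. A hypothetical minimal stand-in, with assumed rules that are not from this PR:

```python
from jax.sharding import PartitionSpec

def apply_rules(sharding: tuple, rules) -> tuple:
  # Hypothetical stand-in for from_sharding_rules, for illustration only.
  table = dict(rules)
  return tuple(table.get(axis, axis) for axis in sharding)

rules = (('embed', 'data'), ('mlp', 'model'))  # assumed logical-to-mesh rules
assert apply_rules(('embed', 'mlp'), rules) == ('data', 'model')
print(PartitionSpec(*apply_rules(('embed', 'mlp'), rules)))
```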
4 changes: 2 additions & 2 deletions flax/linen/spmd.py
@@ -303,15 +303,15 @@ def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
if 'names' in metadata:
- metadata['sharding_names'] = metadata.pop('names')
+ metadata['out_sharding'] = metadata.pop('names')
if 'rules' in metadata:
metadata['sharding_rules'] = metadata.pop('rules')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
- metadata['names'] = metadata.pop('sharding_names')
+ metadata['names'] = metadata.pop('out_sharding')
metadata['rules'] = metadata.pop('sharding_rules')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})
16 changes: 8 additions & 8 deletions flax/nnx/spmd.py
@@ -45,9 +45,9 @@ def insert_field(fields, index, value):
def _add_axis(x: tp.Any):
if isinstance(x, variablelib.Variable):
metadata = x.get_metadata()
- if 'sharding_names' in metadata and metadata['sharding_names']:
- sharding = metadata['sharding_names']
- x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
+ if 'out_sharding' in metadata and metadata['out_sharding']:
+ sharding = metadata['out_sharding']
+ x.set_metadata(out_sharding=insert_field(sharding, index, axis_name))

for k, v in other_meta.items():
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
@@ -74,9 +74,9 @@ def remove_field(fields, index, value):

def _remove_axis(x: tp.Any):
if isinstance(x, variablelib.Variable):
- if hasattr(x, 'sharding_names') and x.sharding_names is not None:
+ if hasattr(x, 'out_sharding') and x.out_sharding is not None:
x.set_metadata(
- sharding_names=remove_field(x.sharding_names, index, axis_name)
+ out_sharding=remove_field(x.out_sharding, index, axis_name)
)

for k, v in other_meta.items():
@@ -119,7 +119,7 @@ def with_partitioning(
"""A wrapper over any initializer to add sharding annotation data to a `Variable`."""
return variablelib.with_metadata(
initializer,
- sharding_names=sharding,
+ out_sharding=sharding,
mesh=mesh,
**metadata,
)
@@ -128,8 +128,8 @@
def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
metadata = v.get_metadata()
- if 'sharding_names' in metadata and metadata['sharding_names']:
- sharding = metadata['sharding_names']
+ if 'out_sharding' in metadata and metadata['out_sharding']:
+ sharding = metadata['out_sharding']
if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
context_rules = core_spmd.get_logical_axis_rules()
local_rules = metadata.get('sharding_rules', ())
17 changes: 14 additions & 3 deletions flax/nnx/variablelib.py
@@ -1298,6 +1298,15 @@ def mutable(self) -> bool:
def shape(self: Variable[jax.Array]) -> tuple[int, ...]:
return self.get_value().shape

+ @property
+ def sharding_names(self):
+ warnings.warn(
+ "'sharding_names' is deprecated, use 'out_sharding' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.get_metadata('out_sharding', None)
+
def __init__(
self,
value: A | VariableMetadata[A],
@@ -1387,7 +1396,9 @@ def __init__(
metadata['on_remove_axis'] = var_t.on_remove_axis

if 'sharding' in metadata:
- metadata['sharding_names'] = metadata.pop('sharding')
+ metadata['out_sharding'] = metadata.pop('sharding')
+ if 'sharding_names' in metadata:
+ metadata['out_sharding'] = metadata.pop('sharding_names')

# run create_value hooks
if 'on_create_value' in metadata:
@@ -1397,10 +1408,10 @@
# run create_value hook
value = self.create_value(value) # type: ignore
# shard the _value if applicable
- if eager_sharding and 'sharding_names' in metadata:
+ if eager_sharding and 'out_sharding' in metadata:
value = core_spmd.shard_value(
value,
- metadata['sharding_names'],
+ metadata['out_sharding'],
metadata.get('sharding_rules', None),
metadata.get('mesh', None),
)
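A hedged sketch of the backward compatibility added above (assumed call pattern; `eager_sharding=False` is used so that no mesh is required):

```python
import jax.numpy as jnp
from flax import nnx

# The legacy 'sharding_names=' kwarg is rewritten to 'out_sharding' in __init__.
p = nnx.Param(jnp.ones((4,)), sharding_names=('model',), eager_sharding=False)
assert 'out_sharding' in p.get_metadata()

# Reading the old attribute still works but emits a DeprecationWarning.
print(p.sharding_names)  # ('model',)
```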
2 changes: 1 addition & 1 deletion flax/version.py
@@ -13,4 +13,4 @@
# limitations under the License.

"""Current Flax version at head on Github."""
- __version__ = '0.12.3'
+ __version__ = '0.12.4'