From a5c76831927b8439118d79f82a052edaea637811 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 12 Feb 2026 10:08:39 -0800 Subject: [PATCH] Rename sharding_names to out_sharding in NNX Variable metadata This CL renames the sharding_names attribute to out_sharding for better consistency with the sharding API. The new name more clearly indicates the purpose of this metadata field. ## Changes - Bump Flax version to 0.12.4 - Core changes in variablelib.py: - Add sharding_names to out_sharding metadata remapping for backward compatibility - Add deprecated sharding_names property that returns out_sharding with a warning - Update nnx/spmd.py, core/spmd.py, core/meta.py, linen/spmd.py to use out_sharding - Update all NNX tests to use the new attribute name - Update qwix flax_util.py to check for out_sharding first, with fallback to sharding_names - Update maxtext initializers.py to check for out_sharding first - Update documentation and examples to use out_sharding ## Backward Compatibility Existing code using sharding_names will continue to work via: - Metadata remapping during Variable creation - Deprecated Variable.sharding_names property PiperOrigin-RevId: 869269899 --- README.md | 2 +- docs_nnx/flip/4844-var-eager-sharding.md | 4 ++-- docs_nnx/guides/flax_gspmd.ipynb | 6 ++--- docs_nnx/guides/flax_gspmd.md | 6 ++--- docs_nnx/guides/transforms.ipynb | 20 ++++++++--------- docs_nnx/guides/transforms.md | 14 ++++++------ .../nnx_toy_examples/10_fsdp_and_optimizer.py | 6 ++--- flax/core/meta.py | 4 ++-- flax/core/spmd.py | 15 ++++++------- flax/linen/spmd.py | 4 ++-- flax/nnx/spmd.py | 16 +++++++------- flax/nnx/variablelib.py | 17 +++++++++++--- flax/version.py | 2 +- tests/nnx/bridge/wrappers_test.py | 10 ++++----- tests/nnx/nn/linear_test.py | 12 +++++----- tests/nnx/optimizer_test.py | 2 +- tests/nnx/spmd_test.py | 12 +++++----- tests/nnx/transforms_test.py | 22 +++++++++---------- 18 files changed, 92 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index 8752cc3a1..c4004bef6 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ To cite this repository: author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, - version = {0.12.3}, + version = {0.12.4}, year = {2024}, } ``` diff --git a/docs_nnx/flip/4844-var-eager-sharding.md b/docs_nnx/flip/4844-var-eager-sharding.md index 3f3316f81..edd3a268f 100644 --- a/docs_nnx/flip/4844-var-eager-sharding.md +++ b/docs_nnx/flip/4844-var-eager-sharding.md @@ -56,12 +56,12 @@ with jax.set_mesh(mesh): ... ``` -For JAX explicit mode, remove the `sharding_names=` annotation on the `nnx.Variable`. +For JAX explicit mode, remove the `out_sharding=` annotation on the `nnx.Variable`. # Implementation [implementation]: #implementation -When an `nnx.Variable` is created, check for the metadata `sharding_names`, and if present, check if under a valid global mesh context of was supplied with a valid mesh. If no, throw error; if yes, call `jax.lax.with_sharding_constraint` to apply sharding constraint on the value. +When an `nnx.Variable` is created, check for the metadata `out_sharding`, and if present, check if under a valid global mesh context of was supplied with a valid mesh. If no, throw error; if yes, call `jax.lax.with_sharding_constraint` to apply sharding constraint on the value. Note that this only works in auto sharding mode. 
User should use JAX-level APIs to annotate shardings for explicit mode. \ No newline at end of file diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index e833df534..b1e4c1404 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -123,7 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" + "nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" ] }, { @@ -134,7 +134,7 @@ "\n", "Let's begin by sharding the simplest component possible - a Flax variable.\n", "\n", - "When you define a Flax variable, you can pass in a metadata field called `sharding_names`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n", + "When you define a Flax variable, you can pass in a metadata field called `out_sharding`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n", "\n", "**You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook." ] @@ -191,7 +191,7 @@ "with jax.set_mesh(auto_mesh):\n", " w = nnx.Param(\n", " rngs.lecun_normal()((4, 8)),\n", - " sharding_names=(None, 'model')\n", + " out_sharding=(None, 'model')\n", " )\n", " print(w.sharding.spec)\n", " jax.debug.visualize_array_sharding(w) # already sharded!" diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index f5b8a0661..90a7fade5 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -64,14 +64,14 @@ with nnx.use_eager_sharding(False): You can also enable eager sharding on a per-variable basis by passing `eager_sharding=False` during variable initialization. The mesh can also be passed this way. ```{code-cell} ipython3 -nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh) +nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh) ``` ## Shard a single-array model Let's begin by sharding the simplest component possible - a Flax variable. -When you define a Flax variable, you can pass in a metadata field called `sharding_names`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded. +When you define a Flax variable, you can pass in a metadata field called `out_sharding`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded. **You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook. @@ -81,7 +81,7 @@ rngs = nnx.Rngs(0) with jax.set_mesh(auto_mesh): w = nnx.Param( rngs.lecun_normal()((4, 8)), - sharding_names=(None, 'model') + out_sharding=(None, 'model') ) print(w.sharding.spec) jax.debug.visualize_array_sharding(w) # already sharded! 
diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb index 27930ba39..e3ca95090 100644 --- a/docs_nnx/guides/transforms.ipynb +++ b/docs_nnx/guides/transforms.ipynb @@ -815,9 +815,9 @@ "output_type": "stream", "text": [ "Inner m.param.shape = (3, 5)\n", - "Inner m.param.sharding_names = ('a', None)\n", + "Inner m.param.out_sharding = ('a', None)\n", "Outter m.param.shape = (3, 4, 5)\n", - "Outter m.param.sharding_names = ('a', 'b', None)\n" + "Outter m.param.out_sharding = ('a', 'b', None)\n" ] } ], @@ -825,20 +825,20 @@ "mesh = jax.make_mesh((1, 1), ('a', 'b'))\n", "\n", "class Weights(nnx.Module):\n", - " def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):\n", - " self.param = nnx.Param(array, sharding_names=sharding_names)\n", + " def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):\n", + " self.param = nnx.Param(array, out_sharding=out_sharding)\n", "\n", "@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", "def f(m: Weights):\n", " print(f'Inner {m.param.shape = }')\n", - " print(f'Inner {m.param.sharding_names = }')\n", + " print(f'Inner {m.param.out_sharding = }')\n", "\n", "with jax.set_mesh(mesh):\n", - " m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))\n", + " m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))\n", " f(m)\n", "\n", "print(f'Outter {m.param.shape = }')\n", - "print(f'Outter {m.param.sharding_names = }')" + "print(f'Outter {m.param.out_sharding = }')" ] }, { @@ -862,19 +862,19 @@ "output_type": "stream", "text": [ "Outter m.param.shape = (3, 4, 5)\n", - "Outter m.param.sharding_names = ('a', 'b', None)\n" + "Outter m.param.out_sharding = ('a', 'b', None)\n" ] } ], "source": [ "@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", "def init_vmap():\n", - " return Weights(jnp.ones((3, 5)), sharding_names=('a', None))\n", + " return Weights(jnp.ones((3, 5)), out_sharding=('a', None))\n", "\n", "with jax.set_mesh(mesh):\n", " m = init_vmap()\n", "print(f'Outter {m.param.shape = }')\n", - "print(f'Outter {m.param.sharding_names = }')" + "print(f'Outter {m.param.out_sharding = }')" ] } ], diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md index 957e438e6..d5036133e 100644 --- a/docs_nnx/guides/transforms.md +++ b/docs_nnx/guides/transforms.md @@ -391,20 +391,20 @@ Let's see an example of this in action: mesh = jax.make_mesh((1, 1), ('a', 'b')) class Weights(nnx.Module): - def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]): - self.param = nnx.Param(array, sharding_names=sharding_names) + def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]): + self.param = nnx.Param(array, out_sharding=out_sharding) @nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'}) def f(m: Weights): print(f'Inner {m.param.shape = }') - print(f'Inner {m.param.sharding_names = }') + print(f'Inner {m.param.out_sharding = }') with jax.set_mesh(mesh): - m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None)) + m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None)) f(m) print(f'Outter {m.param.shape = }') -print(f'Outter {m.param.sharding_names = }') +print(f'Outter {m.param.out_sharding = }') ``` Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. 
Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`. @@ -414,10 +414,10 @@ You can verify that this also works when `nnx.Module`s are created inside the tr ```{code-cell} ipython3 @nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'}) def init_vmap(): - return Weights(jnp.ones((3, 5)), sharding_names=('a', None)) + return Weights(jnp.ones((3, 5)), out_sharding=('a', None)) with jax.set_mesh(mesh): m = init_vmap() print(f'Outter {m.param.shape = }') -print(f'Outter {m.param.sharding_names = }') +print(f'Outter {m.param.out_sharding = }') ``` diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py index b9695e01a..84fb27961 100644 --- a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py +++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py @@ -56,15 +56,15 @@ class MLP(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.w1 = nnx.Param( nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)), - sharding_names=mesh_rules('embed', 'mlp'), + out_sharding=mesh_rules('embed', 'mlp'), ) self.b1 = nnx.Param( jnp.zeros((dmid,)), - sharding_names=mesh_rules('mlp'), + out_sharding=mesh_rules('mlp'), ) self.w2 = nnx.Param( nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)), - sharding_names=mesh_rules('embed', 'mlp'), + out_sharding=mesh_rules('embed', 'mlp'), ) def __call__(self, x: jax.Array): diff --git a/flax/core/meta.py b/flax/core/meta.py index e37fb1c4b..9ee175c12 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -297,13 +297,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: def to_nnx_metadata(self) -> dict[str, Any]: """Return a dict of metadata that can translate into an `nnx.Variable`.""" metadata = dict(vars(self)) - metadata['sharding_names'] = metadata.pop('names') + metadata['out_sharding'] = metadata.pop('names') return metadata @classmethod def from_nnx_metadata(cls, metadata: dict[str, Any]): """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" - metadata['names'] = metadata.pop('sharding_names') + metadata['names'] = metadata.pop('out_sharding') fields = {x.name for x in dataclasses.fields(cls)} return cls(**{k: v for k, v in metadata.items() if k in fields}) diff --git a/flax/core/spmd.py b/flax/core/spmd.py index 0b809e2b7..1945349d7 100644 --- a/flax/core/spmd.py +++ b/flax/core/spmd.py @@ -24,13 +24,13 @@ Sharding, ) -def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec: +def get_pspec(sharding, sharding_rules = None) -> PartitionSpec: """Given an `nnx.Variable`, return its `PartitionSpec`.""" if get_logical_axis_rules() or sharding_rules: context_rules = get_logical_axis_rules() rules = composite_rules(context_rules, sharding_rules) - return PartitionSpec(*from_sharding_rules(sharding_names, rules)) - return PartitionSpec(*sharding_names) + return PartitionSpec(*from_sharding_rules(sharding, rules)) + return PartitionSpec(*sharding) def _apply_sharding(value, sharding, mesh): if mesh.are_all_axes_explicit: @@ -44,10 +44,9 @@ def _apply_sharding(value, sharding, mesh): def shard_value( - value, sharding_names, sharding_rules, - mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None + value, sharding, sharding_rules, mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None ): - if not sharding_names: + if not sharding: return value if mesh is None: @@ -56,9 
+55,9 @@ def shard_value( if mesh is None: raise ValueError( 'An auto mesh context or metadata is required if creating a variable' - f' with annotation {sharding_names=}. ' + f' with annotation {sharding=}. ' 'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.') - pspec = get_pspec(sharding_names, sharding_rules) + pspec = get_pspec(sharding, sharding_rules) return _apply_sharding(value, NamedSharding(mesh, pspec), mesh) diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 25b104f58..890c65b8b 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -303,7 +303,7 @@ def to_nnx_metadata(self) -> dict[str, Any]: """Return a dict of metadata that can translate into an `nnx.Variable`.""" metadata = vars(self) if 'names' in metadata: - metadata['sharding_names'] = metadata.pop('names') + metadata['out_sharding'] = metadata.pop('names') if 'rules' in metadata: metadata['sharding_rules'] = metadata.pop('rules') return metadata @@ -311,7 +311,7 @@ def to_nnx_metadata(self) -> dict[str, Any]: @classmethod def from_nnx_metadata(cls, metadata: dict[str, Any]): """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" - metadata['names'] = metadata.pop('sharding_names') + metadata['names'] = metadata.pop('out_sharding') metadata['rules'] = metadata.pop('sharding_rules') fields = {x.name for x in dataclasses.fields(cls)} return cls(**{k: v for k, v in metadata.items() if k in fields}) diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 756165af9..cd4a215b6 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -45,9 +45,9 @@ def insert_field(fields, index, value): def _add_axis(x: tp.Any): if isinstance(x, variablelib.Variable): metadata = x.get_metadata() - if 'sharding_names' in metadata and metadata['sharding_names']: - sharding = metadata['sharding_names'] - x.set_metadata(sharding_names=insert_field(sharding, index, axis_name)) + if 'out_sharding' in metadata and metadata['out_sharding']: + sharding = metadata['out_sharding'] + x.set_metadata(out_sharding=insert_field(sharding, index, axis_name)) for k, v in other_meta.items(): if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple): @@ -74,9 +74,9 @@ def remove_field(fields, index, value): def _remove_axis(x: tp.Any): if isinstance(x, variablelib.Variable): - if hasattr(x, 'sharding_names') and x.sharding_names is not None: + if hasattr(x, 'out_sharding') and x.out_sharding is not None: x.set_metadata( - sharding_names=remove_field(x.sharding_names, index, axis_name) + out_sharding=remove_field(x.out_sharding, index, axis_name) ) for k, v in other_meta.items(): @@ -119,7 +119,7 @@ def with_partitioning( """A wrapper over any initializer to add sharding annotation data to a `Variable`.""" return variablelib.with_metadata( initializer, - sharding_names=sharding, + out_sharding=sharding, mesh=mesh, **metadata, ) @@ -128,8 +128,8 @@ def with_partitioning( def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None: """Given an `nnx.Variable`, return its `PartitionSpec`.""" metadata = v.get_metadata() - if 'sharding_names' in metadata and metadata['sharding_names']: - sharding = metadata['sharding_names'] + if 'out_sharding' in metadata and metadata['out_sharding']: + sharding = metadata['out_sharding'] if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata: context_rules = core_spmd.get_logical_axis_rules() local_rules = metadata.get('sharding_rules', ()) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 
108cb896a..9a3f2b536 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -1298,6 +1298,15 @@ def mutable(self) -> bool: def shape(self: Variable[jax.Array]) -> tuple[int, ...]: return self.get_value().shape + @property + def sharding_names(self): + warnings.warn( + "'sharding_names' is deprecated, use 'out_sharding' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get_metadata('out_sharding', None) + def __init__( self, value: A | VariableMetadata[A], @@ -1387,7 +1396,9 @@ def __init__( metadata['on_remove_axis'] = var_t.on_remove_axis if 'sharding' in metadata: - metadata['sharding_names'] = metadata.pop('sharding') + metadata['out_sharding'] = metadata.pop('sharding') + if 'sharding_names' in metadata: + metadata['out_sharding'] = metadata.pop('sharding_names') # run create_value hooks if 'on_create_value' in metadata: @@ -1397,10 +1408,10 @@ def __init__( # run create_value hook value = self.create_value(value) # type: ignore # shard the _value if applicable - if eager_sharding and 'sharding_names' in metadata: + if eager_sharding and 'out_sharding' in metadata: value = core_spmd.shard_value( value, - metadata['sharding_names'], + metadata['out_sharding'], metadata.get('sharding_rules', None), metadata.get('mesh', None), ) diff --git a/flax/version.py b/flax/version.py index 4012bb722..e9cea7448 100644 --- a/flax/version.py +++ b/flax/version.py @@ -13,4 +13,4 @@ # limitations under the License. """Current Flax version at head on Github.""" -__version__ = '0.12.3' +__version__ = '0.12.4' diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 8e827bd24..b94dc2be9 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -174,7 +174,7 @@ def create_sharded_nnx_module(x): self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) self.assertIsInstance(nnx_model.kernel, nnx.Variable) - assert nnx_model.kernel.sharding_names == ('in', 'out') + assert nnx_model.kernel.out_sharding == ('in', 'out') assert nnx_model.kernel[...].sharding.is_equivalent_to( jax.sharding.NamedSharding( self.mesh, jax.sharding.PartitionSpec('in', 'out') @@ -182,7 +182,7 @@ def create_sharded_nnx_module(x): ndim=2, ), f'{nnx_model.kernel[...].sharding = }' - assert nnx_model.bias.sharding_names == ('out-alias',) + assert nnx_model.bias.out_sharding == ('out-alias',) assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),) assert nnx_model.bias[...].sharding.is_equivalent_to( jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out')), @@ -410,7 +410,7 @@ def test_nnx_to_linen_metadata(self): pspec_tree = nn.get_partition_spec(variables) assert y.shape == (1, 64) self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta) - assert variables['params']['kernel'].metadata['sharding_names'] == ('in', 'out') + assert variables['params']['kernel'].metadata['out_sharding'] == ('in', 'out') self.assertEqual(pspec_tree['params']['kernel'], jax.sharding.PartitionSpec('in', 'out')) np.testing.assert_allclose(y, x @ variables['params']['kernel'].value) @@ -519,8 +519,8 @@ def __call__(self, x): w, b = model.inner.dot['w'], model.inner.b np.testing.assert_allclose(model(x), x @ w + b) self.assertIsInstance(w, nnx.Param) - assert hasattr(w, 'sharding_names') and w.sharding_names == ('in', 'out') - assert hasattr(b, 'sharding_names') and b.sharding_names == ('out-alias', ) + assert hasattr(w, 'out_sharding') and 
w.out_sharding == ('in', 'out') + assert hasattr(b, 'out_sharding') and b.out_sharding == ('out-alias', ) def test_linen_nnx_linen(self): # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without diff --git a/tests/nnx/nn/linear_test.py b/tests/nnx/nn/linear_test.py index da7e9faa8..4558b1516 100644 --- a/tests/nnx/nn/linear_test.py +++ b/tests/nnx/nn/linear_test.py @@ -391,16 +391,16 @@ class TestLayersParamsMetadata(parameterized.TestCase): def test(self, module_args_kwargs_initargs): module_cls, args, metadata_argnames = module_args_kwargs_initargs kwargs = {"rngs": nnx.Rngs(0)} - sharding_names = ("din", "dout") + out_sharding = ("din", "dout") metadata_kwargs = { - f"{key}_metadata": {"sharding_names": sharding_names[:le]} + f"{key}_metadata": {"out_sharding": out_sharding[:le]} for key, le, _ in metadata_argnames } mesh = jax.make_mesh( (1, 1), - sharding_names, - axis_types=(jax.sharding.AxisType.Auto,) * len(sharding_names), + out_sharding, + axis_types=(jax.sharding.AxisType.Auto,) * len(out_sharding), ) with jax.set_mesh(mesh): module = module_cls(*args, **metadata_kwargs, **kwargs) @@ -410,8 +410,8 @@ def test(self, module_args_kwargs_initargs): for attr_name, param_name in attrs: attr = getattr(module, attr_name) if attr_name is not None else module param = getattr(attr, param_name) - self.assertIsNotNone(param.sharding_names) - self.assertEqual(param.sharding_names, sharding_names[:le]) + self.assertIsNotNone(param.out_sharding) + self.assertEqual(param.out_sharding, out_sharding[:le]) if __name__ == '__main__': diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index 4ada0ce95..4ab7f22f3 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -91,7 +91,7 @@ def test_sharding_propagation(self): state = nnx.state(optimizer) partition_spec = nnx.get_partition_spec(state) - self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b')) + self.assertEqual(state['opt_state'][0]['mu']['kernel'].out_sharding, ('a', 'b')) self.assertEqual( partition_spec['opt_state'][0]['mu']['kernel'].get_value(), jax.sharding.PartitionSpec('a', 'b'), diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 71e71bc7f..0a86505fe 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -139,7 +139,7 @@ def __init__(self, rngs: nnx.Rngs): 4, kernel_init=nnx.with_metadata( nnx.initializers.lecun_normal(), - sharding_names=('din', 'dout'), + out_sharding=('din', 'dout'), nickname=('in', 'out'), on_add_axis=lambda _, idx, name: kadds.append((idx, name)), on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)), @@ -160,7 +160,7 @@ def __call__(self, x: jax.Array): x = self.linear(x) # test sharding layer axes is not present inside scan test.assertEqual(self.linear.kernel.shape, (4, 4)) - test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout')) + test.assertEqual(self.linear.kernel.out_sharding, ('din', 'dout')) # at least a remove_axis was already called to remove the layer axis test.assertEqual(kremoves[-1], (0, 'layers')) test.assertEqual(bremoves[-1], (0, 'layers')) @@ -175,7 +175,7 @@ def __call__(self, x: jax.Array): with jax.set_mesh(mesh): m = MLP(rngs=nnx.Rngs(0)) self.assertEqual(m.linear.kernel.shape, (5, 4, 4)) - self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out')) self.assertEqual(m.linear.bias.shape, (5, 
4)) # One add_axis called to add the `nnx.vmap` dimension @@ -205,7 +205,7 @@ def test_eager_sharding_context(self, use_eager_sharding): with jax.set_mesh(mesh): w = nnx.Param( rngs.lecun_normal()((4, 8)), - sharding_names=(None, 'model')) + out_sharding=(None, 'model')) if use_eager_sharding: assert has_sharding_spec(w) else: @@ -347,13 +347,13 @@ def test_sharding_axis_types(self, mode): with self.assertRaises(ValueError): nnx.Variable( jnp.ones((4, 4)), - sharding_names=('row', 'col'), + out_sharding=('row', 'col'), mesh=mesh, ) else: v = nnx.Variable( jnp.ones((4, 4)), - sharding_names=('row', 'col'), + out_sharding=('row', 'col'), mesh=mesh, ) self.assertEqual(v.sharding.mesh, mesh) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 5e48bb494..82c545d0e 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -866,8 +866,8 @@ def test_shardmap_with_sharding_names(self, graph): b = nnx.Param(jnp.ones((4,)), sharding_names=(None,)) self.assertIsInstance(w.get_raw_value().sharding, jax.sharding.NamedSharding) - self.assertEqual(w.sharding_names, ('data', None)) - self.assertEqual(b.sharding_names, (None,)) + self.assertEqual(w.out_sharding, ('data', None)) + self.assertEqual(b.out_sharding, (None,)) @nnx.shard_map( mesh=mesh, in_specs=(P('data', None), P(None)), out_specs=P('data', None), @@ -2105,10 +2105,10 @@ def __init__(self, rngs: nnx.Rngs): 3, 3, kernel_init=nnx.with_metadata( - nnx.initializers.lecun_normal(), sharding_names=('din', 'dout') + nnx.initializers.lecun_normal(), out_sharding=('din', 'dout') ), bias_init=nnx.with_metadata( - nnx.initializers.zeros_init(), sharding_names=('dout',) + nnx.initializers.zeros_init(), out_sharding=('dout',) ), rngs=rngs, ) @@ -2120,9 +2120,9 @@ def __call__(self, x: jax.Array): x = self.linear(x) # test sharding layer axes is not present inside scan test.assertEqual(self.linear.kernel.shape, (3, 3)) - test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout')) + test.assertEqual(self.linear.kernel.out_sharding, ('din', 'dout')) test.assertEqual(self.linear.bias.shape, (3,)) - test.assertEqual(self.linear.bias.sharding_names, ('dout',)) + test.assertEqual(self.linear.bias.out_sharding, ('dout',)) return x, None mesh = jax.make_mesh((1, 1, 1), ('layers', 'din', 'dout'), axis_types=(jax.sharding.AxisType.Auto,) * len(('layers', 'din', 'dout'))) @@ -2131,9 +2131,9 @@ def __call__(self, x: jax.Array): # test sharding layers axes is set self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) - self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.bias.shape, (5, 3)) - self.assertEqual(m.linear.bias.sharding_names, ('layers', 'dout')) + self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) x = jnp.ones((1, 3)) with jax.set_mesh(mesh): @@ -2141,9 +2141,9 @@ def __call__(self, x: jax.Array): # test sharding axes is preserved self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) - self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.bias.shape, (5, 3)) - self.assertEqual(m.linear.bias.sharding_names, ('layers', 'dout')) + self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) def test_cache_tracing_simple(self): n = 0 @@ -2972,7 +2972,7 @@ def create_block(rngs: nnx.Rngs): with jax.set_mesh(mesh): m = create_block(nnx.Rngs(0)) 
self.assertEqual(m.kernel.shape, (5, 16, 32)) - self.assertEqual(m.kernel.sharding_names, ('c', 'a', 'b')) + self.assertEqual(m.kernel.out_sharding, ('c', 'a', 'b')) def test_state_axes_from_state(self): class Model(nnx.Module):
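
---

A minimal sketch of the behavior described under "Backward Compatibility" above: the new `out_sharding=` spelling, the remapping of the old `sharding_names=` kwarg during `Variable` creation, and the deprecated `sharding_names` property. This is illustrative only; the mesh shape and axis names are placeholders, not taken from the diffs above, and only the metadata (not the resulting array sharding) is checked.

```python
import warnings

import jax
import jax.numpy as jnp
from flax import nnx

# A trivial 1x1 mesh so the example runs on a single device.
mesh = jax.make_mesh((1, 1), ('data', 'model'))

with jax.set_mesh(mesh):
  # New spelling: annotate the variable with `out_sharding=`.
  w_new = nnx.Param(jnp.ones((4, 8)), out_sharding=(None, 'model'))
  # Old spelling: `sharding_names=` is remapped to `out_sharding` in
  # Variable.__init__, so existing call sites keep working.
  w_old = nnx.Param(jnp.ones((4, 8)), sharding_names=(None, 'model'))

# Both variables expose the annotation under the new metadata name.
assert w_new.out_sharding == (None, 'model')
assert w_old.out_sharding == (None, 'model')

# Reading the old attribute goes through the deprecated `sharding_names`
# property, which returns the `out_sharding` metadata and emits a
# DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  assert w_old.sharding_names == (None, 'model')
assert any(issubclass(c.category, DeprecationWarning) for c in caught)
```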