diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 82edc5ee8..41fd800a7 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs): def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: node = cls.__new__(cls, *args, **kwargs) - vars_obj = vars(node) object.__setattr__(node, '_pytree__state', PytreeState()) object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes) cls._pytree_meta_construct(node, *args, **kwargs) @@ -498,6 +497,11 @@ def __init_subclass__( **kwargs, ) -> None: super().__init_subclass__(**kwargs) + if slots := getattr(cls, '__slots__', ()): + raise TypeError( + 'Pytree currently does not support __slots__, ' + f"found __slots__={slots} in '{cls.__name__}'." + ) cls._pytree__is_pytree = pytree graph.register_graph_node_type( @@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self): else: key_fn = None node_attributes = self._pytree__nodes - node_names: list[str] = [] + node_keys: list[str | int] = [] node_attrs: list[tuple[tp.Any, tp.Any]] = [] - static_attrs: list[tuple[str, tp.Any]] = [] - for name, value in sorted(obj_items, key=key_fn): - if name in node_attributes and node_attributes[name]: - node_names.append(name) + static_keys: list[str | int] = [] + static_attrs: list[tp.Any] = [] + for key, value in sorted(obj_items, key=key_fn): + # get string representation of the key because + # node_attributes keys are strings + key_str = _get_str(key) + if key_str in node_attributes and node_attributes[key_str]: + node_keys.append(key) node_attrs.append(( - jax.tree_util.GetAttrKey(name) - if isinstance(name, str) - else jax.tree_util.SequenceKey(name), + jax.tree_util.GetAttrKey(key) + if isinstance(key, str) + else jax.tree_util.SequenceKey(key), value, )) else: - static_attrs.append((name, value)) + static_keys.append(key) + static_attrs.append(value) - return node_attrs, (tuple(node_names), tuple(static_attrs)) + return ( + node_attrs, + (tuple(node_keys), tuple(static_keys), tuple(static_attrs)), + ) def _pytree__flatten(self): obj_items = vars(self).items() @@ -899,34 +911,43 @@ def _pytree__flatten(self): else: key_fn = None node_attributes = self._pytree__nodes - node_names: list[str] = [] + node_keys: list[str | int] = [] node_attrs: list[tp.Any] = [] - static_attrs: list[tuple[str, tp.Any]] = [] - for name, value in sorted(obj_items, key=key_fn): - if name in node_attributes and node_attributes[name]: - node_names.append(name) + static_keys: list[str | int] = [] + static_attrs: list[tp.Any] = [] + for key, value in sorted(obj_items, key=key_fn): + # get string representation of the key because + # node_attributes keys are strings + key_str = _get_str(key) + if key_str in node_attributes and node_attributes[key_str]: + node_keys.append(key) node_attrs.append(value) else: - static_attrs.append((name, value)) + static_keys.append(key) + static_attrs.append(value) - return node_attrs, (tuple(node_names), tuple(static_attrs)) + return ( + node_attrs, + (tuple(node_keys), tuple(static_keys), tuple(static_attrs)), + ) @classmethod def _pytree__unflatten( cls, - static: tuple[tuple[str, ...], tuple[tuple[str, tp.Any], ...]], + static: tuple[tuple[str | int, ...], tuple[str | int, ...], tuple[tp.Any, ...]], node_attrs: tp.Iterable[tp.Any], ): - node_names, static_attrs = static + node_keys, static_keys, static_attrs = static obj = object.__new__(cls) - vars_obj = vars(obj) if cls._pytree__has_int_keys: - node_names = tuple( - str(name) if isinstance(name, int) else name for name in node_names - ) - for name, value in zip(node_names, node_attrs, strict=True): + node_keys_iter = map(_get_str, node_keys) + static_keys_iter = map(_get_str, static_keys) + else: + node_keys_iter = node_keys + static_keys_iter = static_keys + for name, value in zip(node_keys_iter, node_attrs, strict=True): object.__setattr__(obj, name, value) - for name, value in static_attrs: + for name, value in zip(static_keys_iter, static_attrs, strict=True): object.__setattr__(obj, name, value) return obj @@ -946,7 +967,16 @@ def _graph_node_flatten(self): def _graph_node_set_key(self, key, value: tp.Any): if self._pytree__has_int_keys and isinstance(key, int): key = str(key) - setattr(self, key, value) + if not isinstance(key, str): + raise KeyError(f'Invalid key: {key!r}') + elif ( + hasattr(self, key) + and isinstance(variable := getattr(self, key), Variable) + and isinstance(value, Variable) + ): + variable.update_from_state(value) + else: + setattr(self, key, value) def _graph_node_pop_key(self, key): if self._pytree__has_int_keys and isinstance(key, int): @@ -978,7 +1008,8 @@ def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]): (str(name) if isinstance(name, int) else name, value) for name, value in attributes ) - vars(self).update(attributes) + for name, value in attributes: + object.__setattr__(self, name, value) if tp.TYPE_CHECKING: def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ... @@ -1000,4 +1031,7 @@ def _maybe_int(x): try: return int(x) except (ValueError, TypeError): - return x \ No newline at end of file + return x + +def _get_str(x): + return x if isinstance(x, str) else str(x) \ No newline at end of file diff --git a/tests/nnx/helpers_test.py b/tests/nnx/helpers_test.py index bf00d4ac4..b56f26d5a 100644 --- a/tests/nnx/helpers_test.py +++ b/tests/nnx/helpers_test.py @@ -159,6 +159,77 @@ def test_list_mutable_sequence(self): self.assertEqual(l[1:3], [6, 7]) + def test_list_fori_loop(self): + class Foo(nnx.Module): + def __init__(self): + self.layers = nnx.List([ + nnx.Linear(1, 1, rngs=nnx.Rngs(0)), + nnx.Linear(1, 1, rngs=nnx.Rngs(0)), + ]) + + def batch_loop_body(i, carry): + return carry + + net = Foo() + jax.lax.fori_loop(0, 2, batch_loop_body, net) + + def test_list_pytree_default_behavior(self): + ls = nnx.List([jnp.array(1), jnp.array(2), jnp.array(3)]) + leaves = jax.tree_util.tree_leaves(ls) + self.assertLen(leaves, 3) + np.testing.assert_array_equal(leaves[0], jnp.array(1)) + np.testing.assert_array_equal(leaves[1], jnp.array(2)) + np.testing.assert_array_equal(leaves[2], jnp.array(3)) + + def test_list_pytree_static_elements(self): + ls = nnx.List([nnx.static(10), nnx.static(20), nnx.static(30)]) + leaves = jax.tree_util.tree_leaves(ls) + self.assertEmpty(leaves) + + def test_list_pytree_data_elements(self): + ls = nnx.List([nnx.data(1), nnx.data(2), nnx.data(3)]) + leaves = jax.tree_util.tree_leaves(ls) + self.assertLen(leaves, 3) + self.assertEqual(leaves[0], 1) + self.assertEqual(leaves[1], 2) + self.assertEqual(leaves[2], 3) + + def test_list_pytree_mixed_static_data(self): + ls = nnx.List([ + nnx.data(jnp.array(1)), + nnx.static(100), + nnx.data(jnp.array(2)), + nnx.static(200), + ]) + leaves = jax.tree_util.tree_leaves(ls) + self.assertLen(leaves, 2) + np.testing.assert_array_equal(leaves[0], jnp.array(1)) + np.testing.assert_array_equal(leaves[1], jnp.array(2)) + + def test_list_pytree_flatten_unflatten(self): + ls = nnx.List([nnx.data(10), nnx.static('hello'), nnx.data(20)]) + leaves, treedef = jax.tree_util.tree_flatten(ls) + self.assertLen(leaves, 2) + self.assertEqual(leaves[0], 10) + self.assertEqual(leaves[1], 20) + + new_leaves = [x * 2 for x in leaves] + new_ls = jax.tree_util.tree_unflatten(treedef, new_leaves) + self.assertEqual(new_ls[0], 20) + self.assertEqual(new_ls[1], 'hello') + self.assertEqual(new_ls[2], 40) + + def test_list_pytree_jit(self): + ls = nnx.List([nnx.data(jnp.array(1.0)), nnx.static(999)]) + + @jax.jit + def double(ls): + return jax.tree.map(lambda x: x * 2, ls) + + result = double(ls) + np.testing.assert_array_equal(result[0], jnp.array(2.0)) + self.assertEqual(result[1], 999) + if __name__ == '__main__': absltest.main()