Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 63 additions & 29 deletions flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):

def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
node = cls.__new__(cls, *args, **kwargs)
vars_obj = vars(node)
object.__setattr__(node, '_pytree__state', PytreeState())
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
cls._pytree_meta_construct(node, *args, **kwargs)
Expand Down Expand Up @@ -498,6 +497,11 @@ def __init_subclass__(
**kwargs,
) -> None:
super().__init_subclass__(**kwargs)
if slots := getattr(cls, '__slots__', ()):
raise TypeError(
'Pytree currently does not support __slots__, '
f"found __slots__={slots} in '{cls.__name__}'."
)
cls._pytree__is_pytree = pytree

graph.register_graph_node_type(
Expand Down Expand Up @@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self):
else:
key_fn = None
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_keys: list[str | int] = []
node_attrs: list[tuple[tp.Any, tp.Any]] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_items, key=key_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
static_keys: list[str | int] = []
static_attrs: list[tp.Any] = []
for key, value in sorted(obj_items, key=key_fn):
# get string representation of the key because
# node_attributes keys are strings
key_str = _get_str(key)
if key_str in node_attributes and node_attributes[key_str]:
node_keys.append(key)
node_attrs.append((
jax.tree_util.GetAttrKey(name)
if isinstance(name, str)
else jax.tree_util.SequenceKey(name),
jax.tree_util.GetAttrKey(key)
if isinstance(key, str)
else jax.tree_util.SequenceKey(key),
value,
))
else:
static_attrs.append((name, value))
static_keys.append(key)
static_attrs.append(value)

return node_attrs, (tuple(node_names), tuple(static_attrs))
return (
node_attrs,
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
)

def _pytree__flatten(self):
obj_items = vars(self).items()
Expand All @@ -899,34 +911,43 @@ def _pytree__flatten(self):
else:
key_fn = None
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_keys: list[str | int] = []
node_attrs: list[tp.Any] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_items, key=key_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
static_keys: list[str | int] = []
static_attrs: list[tp.Any] = []
for key, value in sorted(obj_items, key=key_fn):
# get string representation of the key because
# node_attributes keys are strings
key_str = _get_str(key)
if key_str in node_attributes and node_attributes[key_str]:
node_keys.append(key)
node_attrs.append(value)
else:
static_attrs.append((name, value))
static_keys.append(key)
static_attrs.append(value)

return node_attrs, (tuple(node_names), tuple(static_attrs))
return (
node_attrs,
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
)

@classmethod
def _pytree__unflatten(
cls,
static: tuple[tuple[str, ...], tuple[tuple[str, tp.Any], ...]],
static: tuple[tuple[str | int, ...], tuple[str | int, ...], tuple[tp.Any, ...]],
node_attrs: tp.Iterable[tp.Any],
):
node_names, static_attrs = static
node_keys, static_keys, static_attrs = static
obj = object.__new__(cls)
vars_obj = vars(obj)
if cls._pytree__has_int_keys:
node_names = tuple(
str(name) if isinstance(name, int) else name for name in node_names
)
for name, value in zip(node_names, node_attrs, strict=True):
node_keys_iter = map(_get_str, node_keys)
static_keys_iter = map(_get_str, static_keys)
else:
node_keys_iter = node_keys
static_keys_iter = static_keys
for name, value in zip(node_keys_iter, node_attrs, strict=True):
object.__setattr__(obj, name, value)
for name, value in static_attrs:
for name, value in zip(static_keys_iter, static_attrs, strict=True):
object.__setattr__(obj, name, value)
return obj

Expand All @@ -946,7 +967,16 @@ def _graph_node_flatten(self):
def _graph_node_set_key(self, key, value: tp.Any):
if self._pytree__has_int_keys and isinstance(key, int):
key = str(key)
setattr(self, key, value)
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
hasattr(self, key)
and isinstance(variable := getattr(self, key), Variable)
and isinstance(value, Variable)
):
variable.update_from_state(value)
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key):
if self._pytree__has_int_keys and isinstance(key, int):
Expand Down Expand Up @@ -978,7 +1008,8 @@ def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
(str(name) if isinstance(name, int) else name, value)
for name, value in attributes
)
vars(self).update(attributes)
for name, value in attributes:
object.__setattr__(self, name, value)

if tp.TYPE_CHECKING:
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
Expand All @@ -1000,4 +1031,7 @@ def _maybe_int(x):
try:
return int(x)
except (ValueError, TypeError):
return x
return x

def _get_str(x):
return x if isinstance(x, str) else str(x)
71 changes: 71 additions & 0 deletions tests/nnx/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,77 @@ def test_list_mutable_sequence(self):

self.assertEqual(l[1:3], [6, 7])

def test_list_fori_loop(self):
class Foo(nnx.Module):
def __init__(self):
self.layers = nnx.List([
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
])

def batch_loop_body(i, carry):
return carry

net = Foo()
jax.lax.fori_loop(0, 2, batch_loop_body, net)

def test_list_pytree_default_behavior(self):
ls = nnx.List([jnp.array(1), jnp.array(2), jnp.array(3)])
leaves = jax.tree_util.tree_leaves(ls)
self.assertLen(leaves, 3)
np.testing.assert_array_equal(leaves[0], jnp.array(1))
np.testing.assert_array_equal(leaves[1], jnp.array(2))
np.testing.assert_array_equal(leaves[2], jnp.array(3))

def test_list_pytree_static_elements(self):
ls = nnx.List([nnx.static(10), nnx.static(20), nnx.static(30)])
leaves = jax.tree_util.tree_leaves(ls)
self.assertEmpty(leaves)

def test_list_pytree_data_elements(self):
ls = nnx.List([nnx.data(1), nnx.data(2), nnx.data(3)])
leaves = jax.tree_util.tree_leaves(ls)
self.assertLen(leaves, 3)
self.assertEqual(leaves[0], 1)
self.assertEqual(leaves[1], 2)
self.assertEqual(leaves[2], 3)

def test_list_pytree_mixed_static_data(self):
ls = nnx.List([
nnx.data(jnp.array(1)),
nnx.static(100),
nnx.data(jnp.array(2)),
nnx.static(200),
])
leaves = jax.tree_util.tree_leaves(ls)
self.assertLen(leaves, 2)
np.testing.assert_array_equal(leaves[0], jnp.array(1))
np.testing.assert_array_equal(leaves[1], jnp.array(2))

def test_list_pytree_flatten_unflatten(self):
ls = nnx.List([nnx.data(10), nnx.static('hello'), nnx.data(20)])
leaves, treedef = jax.tree_util.tree_flatten(ls)
self.assertLen(leaves, 2)
self.assertEqual(leaves[0], 10)
self.assertEqual(leaves[1], 20)

new_leaves = [x * 2 for x in leaves]
new_ls = jax.tree_util.tree_unflatten(treedef, new_leaves)
self.assertEqual(new_ls[0], 20)
self.assertEqual(new_ls[1], 'hello')
self.assertEqual(new_ls[2], 40)

def test_list_pytree_jit(self):
ls = nnx.List([nnx.data(jnp.array(1.0)), nnx.static(999)])

@jax.jit
def double(ls):
return jax.tree.map(lambda x: x * 2, ls)

result = double(ls)
np.testing.assert_array_equal(result[0], jnp.array(2.0))
self.assertEqual(result[1], 999)


if __name__ == '__main__':
absltest.main()
Expand Down
Loading