Add Optimization Cookbook #5117
Conversation
I feel we could simplify the intro by doing the following:

    model = nnx.Sequential(
        nnx.Linear(2, 8, rngs=rngs),
        nnx.relu,
        nnx.Linear(8, 8, rngs=rngs),
    )
    optimizer = nnx.Optimizer(
        model,
        tx=optax.adam(1e-3),
        wrt=nnx.Param,
    )
    ...
    @nnx.jit
    def train_step(model, optimizer, ema, x, y):
        loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
        loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
        optimizer.update(model, grads)
        ema.update(model)
        return loss
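(Aside for readers following along: `ema` is left undefined in the snippet above. A minimal sketch of what such a helper could look like, using Polyak averaging via optax.incremental_update; the EMA class itself is hypothetical, not part of the Flax API:)

    import jax
    import optax
    import flax.nnx as nnx

    # Hypothetical EMA holder implied by `ema.update(model)` above; not a Flax API.
    class EMA(nnx.Object):
        def __init__(self, model, decay=0.999):
            self.decay = decay
            # Start the moving average at the current parameter values.
            self.params = jax.tree.map(lambda x: x, nnx.state(model, nnx.Param))

        def update(self, model):
            # Polyak averaging: avg = decay * avg + (1 - decay) * current.
            self.params = optax.incremental_update(
                nnx.state(model, nnx.Param), self.params, step_size=1 - self.decay)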
docs_nnx/guides/opt_cookbook.rst (Outdated)

    model = nnx_model(rngs)
    state = nnx.state(model, nnx.Param)
    rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}
    param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
I think this is enough:

    - param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
    + param_tys = nnx.map_state(lambda p, v: p[-1], state)
We could also use jax.tree.map_with_path as in the JAX example.
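(For context, a label tree like `param_tys` is typically consumed by optax.multi_transform. A minimal sketch, assuming `state`, `rates`, and `model` are defined as in the quoted diff and that the label tree matches the parameter structure optax sees:)

    import optax
    import flax.nnx as nnx

    # Label each parameter by the last key in its path ('kernel' or 'bias') ...
    param_tys = nnx.map_state(lambda p, v: p[-1], state)
    # ... then route every label to the matching transform from `rates`.
    tx = optax.multi_transform(rates, param_tys)
    optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)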
docs_nnx/guides/opt_cookbook.rst (Outdated)

    axis_types=(AxisType.Explicit, AxisType.Explicit))
    jax.set_mesh(mesh)

    ghost_model = jax.eval_shape(lambda: nnx_model(nnx.Rngs(0), out_sharding=P('x', 'y')))
Instead of creating this fake model, it would be a good opportunity to create the optimizer_sharding API on Variable before finishing this guide.
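(Purely illustrative: one possible shape such an API could take. The optimizer_sharding field is hypothetical; it does not exist in Flax at the time of this discussion. The sketch leans on the fact that nnx Variables accept arbitrary metadata kwargs:)

    import jax.numpy as jnp
    import flax.nnx as nnx
    from jax.sharding import PartitionSpec as P

    # Hypothetical: attach optimizer-state sharding as Variable metadata that
    # nnx.Optimizer could later read, instead of building a ghost model.
    w = nnx.Param(
        jnp.zeros((8, 8)),
        sharding=P('x', 'y'),             # sharding of the parameter itself
        optimizer_sharding=P('x', None),  # hypothetical metadata field
    )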
After fully reading the guide, I'm getting the sense that having the JAX versions makes the explanations a bit longer and slightly harder to follow (because you have to mentally filter for the version you're interested in), and the JAX version doesn't necessarily make the NNX version easier to understand.
Fair enough! I'll convert it to NNX-only.
What does this PR do?
This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:
This document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviation from the often more intuitive pure-JAX version.
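(A minimal illustration of that claim, assuming a toy squared-error loss: the pure-JAX and Flax NNX training steps line up almost one-to-one.)

    import jax
    import jax.numpy as jnp
    import optax
    import flax.nnx as nnx

    # Pure JAX: parameters and optimizer state are explicit pytrees.
    params = {'w': jnp.ones((2, 1))}
    tx = optax.adam(1e-3)
    opt_state = tx.init(params)

    def jax_step(params, opt_state, x, y):
        loss, grads = jax.value_and_grad(
            lambda p: jnp.sum((x @ p['w'] - y) ** 2))(params)
        updates, opt_state = tx.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    # Flax NNX: the model owns its parameters; the step mirrors the JAX one.
    model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=nnx.Param)

    def nnx_step(model, optimizer, x, y):
        loss, grads = nnx.value_and_grad(
            lambda m: jnp.sum((m(x) - y) ** 2))(model)
        optimizer.update(model, grads)
        return loss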
Warnings: