
Conversation

@divye-joshi commented Jan 26, 2026

Description

This PR adds native object-oriented pooling layers to flax.nnx.

I am currently migrating from Keras to Flax NNX and noticed that, while functional pooling exists in linen, NNX is missing the modular, object-oriented equivalents (e.g., nnx.AvgPool, nnx.MaxPool). Currently, users have to import the functional pooling helpers from linen or wrap them manually, which breaks the object-oriented flow of NNX.

Additionally, I noticed that GlobalAveragePool was missing entirely, so I have implemented it as a standard NNX module to simplify workflows for those coming from other frameworks like Keras or PyTorch.

Changes Made

  • Added flax/nnx/nn/pooling.py: Implemented MaxPool, MinPool, AvgPool, and GlobalAveragePool as nnx.Module subclasses (a minimal sketch follows this list).
  • Updated flax/nnx/__init__.py: Exposed the new layers in the top-level nnx namespace.
  • Documentation: Added pooling.rst to the API reference and updated the index.
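
For reference, here is a minimal sketch of the wrapper pattern, assuming MaxPool simply delegates to the functional max_pool that already exists in flax.linen; the exact signature and defaults in the PR may differ:

from flax import nnx
from flax.linen import max_pool  # existing functional implementation

class MaxPool(nnx.Module):
    # Illustrative sketch: hyperparameters are stored as plain attributes.
    def __init__(self, window_shape, strides=None, padding="VALID"):
        self.window_shape = window_shape
        self.strides = strides
        self.padding = padding

    def __call__(self, x):
        # Delegate to the functional pooling helper.
        return max_pool(x, self.window_shape, strides=self.strides, padding=self.padding)

AvgPool and MinPool follow the same pattern with avg_pool and min_pool.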

Usage Example

Before (Linen Functional Pooling)

from flax import linen as nn

# Max Pool
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

# Average Pool
x = nn.avg_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

# Global Average Pool (manual)
x = x.mean(axis=(1, 2))

After (NNX Object-Oriented Pooling)

from flax import nnx

# Max Pool
x = nnx.MaxPool(window_shape=(2, 2), strides=(2, 2))(x)

# Average Pool
x = nnx.AvgPool(window_shape=(3, 3), strides=(2, 2), padding="SAME")(x)

# Global Average Pool
x = nnx.GlobalAveragePool()(x)

**Closes #5202**

@samanklesaria (Collaborator) left a comment

Looks mostly good to me. I do wonder about the utility of a GlobalAveragePool layer. The goal here, as I see it, is to make things more familiar to PyTorch users. But PyTorch doesn't seem to have a GlobalAveragePool layer either. Given that it's so easy for the user to just use jnp.mean themselves, I'd argue for removing GlobalAveragePool.
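
For an NHWC batch, that manual version is a single call (the shapes below are only illustrative):

import jax.numpy as jnp

x = jnp.ones((8, 32, 32, 64))      # NHWC: (batch, height, width, channels)
pooled = jnp.mean(x, axis=(1, 2))  # global average pool -> shape (8, 64)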

@divye-joshi (Author)

@samanklesaria
Thanks for the review!

You make a great point regarding GlobalAveragePool. Since it's not a standard layer in PyTorch and is trivial to implement with jnp.mean, I agree that removing it helps keep the API surface minimal and focused.

I've updated the PR to remove GlobalAveragePool from the code and documentation. MaxPool, MinPool, and AvgPool remain as discussed. Ready for another look!

@samanklesaria (Collaborator)

@divye-joshi Thinking over this again, I don't think object-oriented versions of the pooling layers make sense for Flax. Functional versions of max_pool and avg_pool already exist, and the marginal benefit of copying PyTorch's API isn't worth the additional maintenance burden of adding new classes to Flax.



Successfully merging this pull request may close these issues:

feat(nnx): Missing object-oriented pooling layers in NNX
