From 2a49924ba1aba490c226753bbb89e6181ee6a207 Mon Sep 17 00:00:00 2001 From: jackopenn Date: Thu, 25 Dec 2025 20:58:44 +0000 Subject: [PATCH] Add out_sharding docstrings to linear layer call methods --- flax/nnx/nn/linear.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index f25433c96..67bf92855 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -251,6 +251,11 @@ def __call__(self, inputs: Array, out_sharding = None) -> Array: Args: inputs: The nd-array to be transformed. + out_sharding: Optional sharding specification (e.g., + ``jax.sharding.PartitionSpec``) for the output array. When using JAX's + explicit sharding mode with a mesh context with ``AxisType.Explicit``. + If ``None`` (default), the compiler automatically determines output + sharding. Returns: The transformed input. @@ -398,6 +403,11 @@ def __call__(self, inputs: Array, out_sharding = None) -> Array: Args: inputs: The nd-array to be transformed. + out_sharding: Optional sharding specification (e.g., + ``jax.sharding.PartitionSpec``) for the output array. When using JAX's + explicit sharding mode with a mesh context with ``AxisType.Explicit``. + If ``None`` (default), the compiler automatically determines output + sharding. Returns: The transformed input. @@ -532,6 +542,11 @@ def __call__( the rhs being the learnable kernel. Exactly one of ``einsum_str`` in the constructor argument and call argument must be not None, while the other must be None. + out_sharding: Optional sharding specification (e.g., + ``jax.sharding.PartitionSpec``) for the output array. When using JAX's + explicit sharding mode with a mesh context with ``AxisType.Explicit``. + If ``None`` (default), the compiler automatically determines output + sharding. Returns: The transformed input.