diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index f25433c96..67bf92855 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -251,6 +251,11 @@ def __call__(self, inputs: Array, out_sharding = None) -> Array:

     Args:
       inputs: The nd-array to be transformed.
+      out_sharding: Optional sharding specification (e.g. a
+        ``jax.sharding.PartitionSpec``) for the output array, intended for
+        JAX's explicit sharding mode, i.e. under a mesh whose axes use
+        ``jax.sharding.AxisType.Explicit``. If ``None`` (default), the
+        compiler determines the output sharding automatically.

     Returns:
       The transformed input.
@@ -398,6 +403,11 @@ def __call__(self, inputs: Array, out_sharding = None) -> Array:

     Args:
       inputs: The nd-array to be transformed.
+      out_sharding: Optional sharding specification (e.g. a
+        ``jax.sharding.PartitionSpec``) for the output array, intended for
+        JAX's explicit sharding mode, i.e. under a mesh whose axes use
+        ``jax.sharding.AxisType.Explicit``. If ``None`` (default), the
+        compiler determines the output sharding automatically.

     Returns:
       The transformed input.
@@ -532,6 +542,11 @@ def __call__(
         the rhs being the learnable kernel. Exactly one of ``einsum_str``
         in the constructor argument and call argument must be not None,
        while the other must be None.
+      out_sharding: Optional sharding specification (e.g. a
+        ``jax.sharding.PartitionSpec``) for the output array, intended for
+        JAX's explicit sharding mode, i.e. under a mesh whose axes use
+        ``jax.sharding.AxisType.Explicit``. If ``None`` (default), the
+        compiler determines the output sharding automatically.

     Returns:
       The transformed input.
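
For reviewers, a minimal usage sketch (not part of the patch) of how the new ``out_sharding`` argument might be exercised under JAX's explicit sharding mode. It assumes a recent JAX that provides ``jax.sharding.AxisType.Explicit``, ``jax.make_mesh(..., axis_types=...)``, and ``jax.sharding.use_mesh``, plus at least two devices; the mesh axis name ``'model'``, the layer sizes, and the CPU device-count flag are illustrative assumptions, not part of the change.

```python
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P
from flax import nnx

# On a CPU-only machine we may need more than one host device; this must run
# before JAX initializes its backend (assumption: recent JAX config option).
jax.config.update('jax_num_cpu_devices', 2)

# A 1-D mesh whose axis is explicitly typed, so array shardings become part
# of what the compiler sees.
mesh = jax.make_mesh((2,), ('model',), axis_types=(AxisType.Explicit,))

model = nnx.Linear(in_features=8, out_features=16, rngs=nnx.Rngs(0))
x = jnp.ones((4, 8))

with jax.sharding.use_mesh(mesh):
  # Ask for the output's feature dimension to be sharded over 'model'.
  # With out_sharding=None (the default) the compiler chooses the sharding.
  y = model(x, out_sharding=P(None, 'model'))
```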