Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions export/orbax/export/data_processors/data_processor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member

import abc
from collections.abc import Set


class DataProcessor(abc.ABC):
Expand Down
2 changes: 1 addition & 1 deletion export/orbax/export/data_processors/tf_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""The data processor module for tensroflow-based pre/post-processors."""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence, Set
import functools
from typing import Any, Tuple, cast

Expand Down
80 changes: 63 additions & 17 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@
ApplyFn = orbax_export_typing.ApplyFn


def _is_apply_fn_info(
apply_fn: (
orbax_export_typing.ApplyFn
| orbax_export_typing.ApplyFnInfo
| Mapping[str, orbax_export_typing.ApplyFn]
| Mapping[str, orbax_export_typing.ApplyFnInfo]
),
) -> bool:
if isinstance(apply_fn, orbax_export_typing.ApplyFnInfo):
return True
if isinstance(apply_fn, Mapping):
for v in apply_fn.values():
if isinstance(v, orbax_export_typing.ApplyFnInfo):
return True
return False


class JaxModule(orbax_module_base.OrbaxModuleBase):
"""An exportable module for JAX functions and parameters.

Expand All @@ -41,7 +58,12 @@ class JaxModule(orbax_module_base.OrbaxModuleBase):
def __init__(
self,
params: PyTree,
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
apply_fn: (
orbax_export_typing.ApplyFn
| orbax_export_typing.ApplyFnInfo
| Mapping[str, orbax_export_typing.ApplyFn]
| Mapping[str, orbax_export_typing.ApplyFnInfo]
),
trainable: Optional[Union[bool, PyTree]] = None,
input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None] = None,
input_polymorphic_shape_symbol_values: Union[
Expand All @@ -60,10 +82,12 @@ def __init__(
Args:
params: a pytree of JAX parameters or parameter specs (e.g.
`jax.ShapeDtypeStruct`s).
apply_fn: A JAX ``ApplyFn`` (i.e. of signature ``apply_fn(params, x)``),
or a mapping of method key to ``ApplyFn``. If it is an ``ApplyFn``, it
will be assigned a key ``constants.DEFAULT_METHOD_KEY`` automatically,
which can be used to look up the TF function converted from it.
apply_fn: A single `ApplyFn` (taking `model_params` and `model_inputs`), a
single `ApplyFnInfo` object (containing `ApplyFn` and input/output
keys), or a mapping of method keys to `ApplyFn`s or `ApplyFnInfo`
objects. If it is a single ``ApplyFn`` or ``ApplyFnInfo``, it will be
assigned a key ``constants.DEFAULT_METHOD_KEY`` automatically, which can
be used to look up the TF function converted from it.
trainable: a pytree in the same structure as ``params`` and boolean leaves
to tell if a parameter is trainable. Alternatively, it can be a single
boolean value to tell if all the parameters are trainable or not. By
Expand Down Expand Up @@ -115,8 +139,12 @@ def __init__(
OrbaxNativeSerializationType.
jax2obm_options: Options for jax2obm conversion.

raises:
ValueError: If the export version is not supported.
Raises:
ValueError: If `jax2obm_kwargs` and `jax2obm_options` are both provided,
or if `input_polymorphic_shape_symbol_values` or `ApplyFnInfo` are
provided but `export_version` is not
`constants.ExportModelType.ORBAX_MODEL`, or if `export_version` is not
supported.
"""
if jax2obm_kwargs is not None:
if jax2obm_options is not None:
Expand All @@ -129,14 +157,18 @@ def __init__(
DeprecationWarning,
)
self._export_version = export_version
if (
input_polymorphic_shape_symbol_values is not None
and export_version != constants.ExportModelType.ORBAX_MODEL
):
raise ValueError(
'`input_polymorphic_shape_symbol_values` is only supported for'
' constants.ExportModelType.ORBAX_MODEL.'
)

if export_version != constants.ExportModelType.ORBAX_MODEL:
if input_polymorphic_shape_symbol_values is not None:
raise ValueError(
'`input_polymorphic_shape_symbol_values` is only supported for'
' constants.ExportModelType.ORBAX_MODEL.'
)
if _is_apply_fn_info(apply_fn):
raise ValueError(
'`ApplyFnInfo` is only supported for'
' constants.ExportModelType.ORBAX_MODEL.'
)

match export_version:
case constants.ExportModelType.TF_SAVEDMODEL:
Expand Down Expand Up @@ -168,8 +200,22 @@ def __init__(
)

@property
def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map."""
def apply_fn_map(
self,
) -> Mapping[
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
]:
"""Returns a mapping from method keys to ApplyFn or ApplyFnInfo objects.

Each value in the mapping is either an `ApplyFn` (a callable that takes
model parameters and inputs) or an `ApplyFnInfo` object. `ApplyFnInfo`
wraps an `ApplyFn` along with its input and output keys, and is used for
specifying preprocessing/postprocessing dependencies when exporting to
`constants.ExportModelType.ORBAX_MODEL` format.

If a single `ApplyFn` or `ApplyFnInfo` was provided during initialization,
it is keyed by `constants.DEFAULT_METHOD_KEY`.
"""
return self._export_module.apply_fn_map

@property
Expand Down
35 changes: 29 additions & 6 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ class ObmModule(orbax_module_base.OrbaxModuleBase):
def __init__(
self,
params: PyTree,
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
apply_fn: (
orbax_export_typing.ApplyFn
| orbax_export_typing.ApplyFnInfo
| Mapping[
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
]
),
*,
input_polymorphic_shape: Any = None,
input_polymorphic_shape_symbol_values: Union[
Expand All @@ -51,7 +57,11 @@ def __init__(

Args:
params: The model parameter specs (e.g. `jax.ShapeDtypeStruct`s).
apply_fn: The apply_fn for the model.
apply_fn: A single `ApplyFn` (taking `model_params` and `model_inputs`), a
single `ApplyFnInfo` object (containing `ApplyFn` and input/output
keys), or a mapping of method keys to `ApplyFn`s or `ApplyFnInfo`
objects. If it is a single ``ApplyFn`` or ``ApplyFnInfo``, it will be
assigned a key ``constants.DEFAULT_METHOD_KEY`` automatically.
input_polymorphic_shape: polymorphic shape for the inputs of `apply_fn`.
input_polymorphic_shape_symbol_values: optional mapping of symbol names
presented in `input_polymorphic_shape` to discrete values (e.g. {'b':
Expand Down Expand Up @@ -150,20 +160,29 @@ def _jax2obm_kwargs_to_options(

def _normalize_apply_fn_map(
self,
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
apply_fn: (
orbax_export_typing.ApplyFn
| orbax_export_typing.ApplyFnInfo
| Mapping[
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
]
),
input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None],
input_polymorphic_shape_symbol_values: Union[
PyTree, Mapping[str, PyTree], None
],
) -> Tuple[
Mapping[str, ApplyFn],
Mapping[
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
],
Mapping[str, Union[PyTree, None]],
Mapping[str, Union[PyTree, None]],
]:
"""Converts all the inputs to maps that share the same keys."""

# Single apply_fn case. Will use the default method key.
if callable(apply_fn):
if not isinstance(apply_fn, Mapping):
apply_fn: orbax_export_typing.ApplyFnInfo | orbax_export_typing.ApplyFn
apply_fn_map = {constants.DEFAULT_METHOD_KEY: apply_fn}
input_polymorphic_shape_map = {
constants.DEFAULT_METHOD_KEY: input_polymorphic_shape
Expand Down Expand Up @@ -302,7 +321,11 @@ def export_module(
return self

@property
def apply_fn_map(self) -> Mapping[str, ApplyFn]:
def apply_fn_map(
self,
) -> Mapping[
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
]:
"""Returns the apply_fn_map from function name to jit'd apply function."""
return self._apply_fn_map

Expand Down
5 changes: 3 additions & 2 deletions export/orbax/export/oex_orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

"""Pipeline: pre-processor + model-function + post-processor."""

from collections.abc import Sequence, Set
import dataclasses
from typing import Any, Dict, List, Sequence, Tuple, TypeVar
from typing import Any

from absl import logging
import jax
import jaxtyping
# TODO: b/448900820 - remove this unused import.
from orbax.export.data_processors import data_processor_base
from orbax.export.modules import obm_module
63 changes: 53 additions & 10 deletions export/orbax/export/serving_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class ServingConfig:
# Optional sequence of `DataProcessor`s to be applied after the main model
# function.
postprocessors: Sequence[data_processor_base.DataProcessor] = ()
# Optional sequence of `DataProcessor`s to be applied. `DataProcessor` is a
# new abstraction for constructing pipelines for Orbax Model export. This
# field is mutually exclusive with `tf_preprocessor`, `preprocessors`,
# `tf_postprocessor`, and `postprocessors`. If this field is used, the
# `DataProcessor`s and the model function will be ordered based on their
# input and output keys using topological sorting.
data_processors: Sequence[data_processor_base.DataProcessor] = ()
# A nested structure of tf.saved_model.experimental.TrackableResource that are
# used in `tf_preprocessor` and/or `tf_postprocessor`. If a TrackableResource
# an attritute of the `tf_preprocessor` (or `tf_postprocessor`), and the
Expand Down Expand Up @@ -108,6 +115,18 @@ class ServingConfig:
preprocess_output_passthrough_enabled: bool = False

def __post_init__(self):
"""Post-initialization checks for ServingConfig.

Raises:
ValueError: If `obm_kwargs` and `obm_export_options` are both set.
ValueError: If `signature_key` is not set.
ValueError: If `data_processors` is set along with `tf_preprocessor`,
`preprocessors`, `tf_postprocessor`, or `postprocessors`.
ValueError: If a processor in `data_processors` does not have
`input_keys` or `output_keys`.
ValueError: If `tf_preprocessor` and `preprocessors` are both set.
ValueError: If `tf_postprocessor` and `postprocessors` are both set.
"""
if self.obm_kwargs:
if self.obm_export_options is not None:
raise ValueError(
Expand All @@ -123,16 +142,40 @@ def __post_init__(self):
)
if not self.signature_key:
raise ValueError('`signature_key` must be set.')
if self.tf_preprocessor and self.preprocessors:
raise ValueError(
'`tf_preprocessor` and `preprocessors` cannot be set at the same'
' time.'
)
if self.tf_postprocessor and self.postprocessors:
raise ValueError(
'`tf_postprocessor` and `postprocessors` cannot be set at the same'
' time.'
)
if self.data_processors:
if (
self.tf_preprocessor
or self.preprocessors
or self.tf_postprocessor
or self.postprocessors
):
raise ValueError(
'`data_processors` cannot be set at the same time as'
' `tf_preprocessor`, `preprocessors`, `tf_postprocessor` or'
' `postprocessors`.'
)
for processor in self.data_processors:
if not processor.input_keys:
raise ValueError(
f'Processor {processor.name} in `data_processors` must have'
' `input_keys`.'
)
if not processor.output_keys:
raise ValueError(
f'Processor {processor.name} in `data_processors` must have'
' `output_keys`.'
)
else:
if self.tf_preprocessor and self.preprocessors:
raise ValueError(
'`tf_preprocessor` and `preprocessors` cannot be set at the same'
' time.'
)
if self.tf_postprocessor and self.postprocessors:
raise ValueError(
'`tf_postprocessor` and `postprocessors` cannot be set at the same'
' time.'
)

def get_signature_keys(self) -> Sequence[str]:
if isinstance(self.signature_key, str):
Expand Down
11 changes: 9 additions & 2 deletions export/orbax/export/serving_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
from orbax.export import obm_configs
from orbax.export import serving_config
from orbax.export.data_processors import data_processor_base
import tensorflow as tf


ServingConfig = serving_config.ServingConfig


class ServingConfigTest(tf.test.TestCase):
class _TestDataProcessor(data_processor_base.DataProcessor):

def prepare(self):
pass


class ServingConfigTest(tf.test.TestCase, parameterized.TestCase):

def test_obm_kwargs_deprecation(self):
batch_opts = obm_configs.BatchOptions(
Expand Down
25 changes: 24 additions & 1 deletion export/orbax/export/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

"""Common typing for export."""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence, Set
import dataclasses
from typing import Any, TypeVar, Union
import jaxtyping
import tensorflow as tf
Expand All @@ -31,3 +32,25 @@
# ApplyFn take two arguments, the first one is the model_params, the second one
# is the model_inputs.
ApplyFn = Callable[[PyTree, PyTree], PyTree]


@dataclasses.dataclass
class ApplyFnInfo:
"""Information about an apply function.

Attributes:
apply_fn: The apply function, which takes `model_params` and `model_inputs`
as arguments. `model_inputs` must be a dictionary with keys matching
`input_keys`. The function must return a dictionary with keys matching
`output_keys`.
input_keys: The keys of the input dict that the `apply_fn` expects. These
keys are also used to determine the topological ordering of the `apply_fn`
and other `DataProcessor`s in the pipeline.
output_keys: The keys of the output dict that the `apply_fn` produces. These
keys are also used to determine the topological ordering of the `apply_fn`
and other `DataProcessor`s in the pipeline.
"""

apply_fn: ApplyFn
input_keys: Set[str]
output_keys: Set[str]
Loading