diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py b/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py
index ccc853aa6..fccb1697d 100644
--- a/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py
+++ b/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py
@@ -23,13 +23,14 @@
 from orbax.checkpoint._src.path import atomicity
 from orbax.checkpoint._src.path import atomicity_types
 from orbax.checkpoint._src.path import gcs_utils
+from orbax.checkpoint._src.path import s3_utils
 
 
 def get_item_default_temporary_path_class(
     path: epath.Path,
 ) -> type[atomicity_types.TemporaryPath]:
   """Returns the default temporary path class for a given sub-item path."""
-  if gcs_utils.is_gcs_path(path):
+  if gcs_utils.is_gcs_path(path) or s3_utils.is_s3_path(path):
     return atomicity.CommitFileTemporaryPath
   else:
     return atomicity.AtomicRenameTemporaryPath
@@ -39,7 +40,7 @@ def get_default_temporary_path_class(
     path: epath.Path,
 ) -> type[atomicity_types.TemporaryPath]:
   """Returns the default temporary path class for a given checkpoint path."""
-  if gcs_utils.is_gcs_path(path):
+  if gcs_utils.is_gcs_path(path) or s3_utils.is_s3_path(path):
     return atomicity.CommitFileTemporaryPath
   else:
     return atomicity.AtomicRenameTemporaryPath
diff --git a/checkpoint/orbax/checkpoint/_src/path/deleter.py b/checkpoint/orbax/checkpoint/_src/path/deleter.py
index d24483653..5a7d20c98 100644
--- a/checkpoint/orbax/checkpoint/_src/path/deleter.py
+++ b/checkpoint/orbax/checkpoint/_src/path/deleter.py
@@ -29,6 +29,7 @@
 from orbax.checkpoint._src.logging import event_tracking
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.path import gcs_utils
+from orbax.checkpoint._src.path import s3_utils
 from orbax.checkpoint._src.path import step as step_lib
 
 
@@ -188,6 +189,10 @@ def delete(self, step: int) -> None:
       # This is recommended for GCS buckets with HNS enabled and requires
       # `_todelete_full_path` to be specified.
       self._gcs_rename_step(step, delete_target)
+    elif s3_utils.is_s3_path(self._directory):
+      # This is recommended for S3 buckets and requires
+      # `_todelete_full_path` to be specified.
+      self._s3_rename_step(step, delete_target)
     else:
       raise NotImplementedError()
     # Attempt to rename to local subdirectory using `todelete_subdir`
@@ -259,6 +264,52 @@
       logging.error(message)
       raise RuntimeError(message) from e
 
+  def _s3_rename_step(
+      self, step: int, delete_target: epath.Path
+  ):
+    """Renames an S3 directory to a temporary location for deletion.
+
+    Args:
+      step: The checkpoint step number.
+      delete_target: The path to the directory to be renamed.
+    """
+    try:
+      # Get the bucket name from the source path.
+      bucket_name = urlparse(str(delete_target)).netloc
+      if not bucket_name:
+        raise ValueError(
+            f'Could not parse bucket name from path: {delete_target}'
+        )
+
+      # Construct the destination path inside the `_todelete_full_path` dir.
+      destination_parent_path = epath.Path(
+          f's3://{bucket_name}/{self._todelete_full_path}'
+      )
+      destination_parent_path.mkdir(parents=True, exist_ok=True)
+
+      # Create a unique name for the destination to avoid collisions.
+      now = datetime.datetime.now()
+      timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f')
+      new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}'
+      dest_path = destination_parent_path / new_name_with_timestamp
+
+      logging.info(
+          'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`',
+          delete_target,
+          dest_path,
+      )
+
+      # Call the high-level rename method.
+      # On S3 this is typically implemented as a copy followed by a delete,
+      # so it is functional but can be slow for large checkpoints.
+      delete_target.rename(dest_path)
+      logging.info('Successfully renamed step %d to %s', step, dest_path)
+
+    except Exception as e:
+      message = f'Rename failed for step {step}. Error: {e}'
+      logging.error(message)
+      raise RuntimeError(message) from e
+
   def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):
     """Renames a step directory to its corresponding todelete_subdir."""
     rename_dir = self._directory / self._todelete_subdir
diff --git a/checkpoint/orbax/checkpoint/_src/path/s3_utils.py b/checkpoint/orbax/checkpoint/_src/path/s3_utils.py
new file mode 100644
index 000000000..2ef11bf1b
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/path/s3_utils.py
@@ -0,0 +1,40 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for interacting with S3 paths."""
+
+from urllib import parse
+from etils import epath
+
+
+_S3_PATH_PREFIX = ('s3://',)
+
+
+def is_s3_path(path: epath.Path) -> bool:
+  """Returns True if the given path is an S3 path."""
+  return path.as_posix().startswith(_S3_PATH_PREFIX)
+
+
+def parse_s3_path(path: epath.PathLike) -> tuple[str, str]:
+  """Parses an S3 path into `(bucket, key)`; the key gets a trailing slash."""
+  parsed = parse.urlparse(str(path))
+  assert parsed.scheme == 's3', f'Unsupported scheme for S3: {parsed.scheme}'
+  # Strip the leading slash from the path.
+  standardized_path = parsed.path
+  if standardized_path.startswith('/'):
+    standardized_path = standardized_path[1:]
+  # Add a trailing slash if it's missing.
+  if not standardized_path.endswith('/'):
+    standardized_path = standardized_path + '/'
+  return parsed.netloc, standardized_path
diff --git a/checkpoint/orbax/checkpoint/_src/path/s3_utils_test.py b/checkpoint/orbax/checkpoint/_src/path/s3_utils_test.py
new file mode 100644
index 000000000..d5cbc6caf
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/path/s3_utils_test.py
@@ -0,0 +1,74 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for s3_utils.""" + +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +from orbax.checkpoint._src.path import s3_utils + + +class S3UtilsTest(parameterized.TestCase): + + @parameterized.parameters( + ('s3://bucket/path', True), + ('s3://bucket/path/', True), + ('s3://bucket', True), + ('/local/path', False), + ('gs://bucket/path', False), + ('/tmp/foo/bar', False), + ('file:///path', False), + ) + def test_is_s3_path(self, path_str, expected): + path = epath.Path(path_str) + self.assertEqual(s3_utils.is_s3_path(path), expected) + + @parameterized.parameters( + ('s3://bucket/path/to/file', ('bucket', 'path/to/file/')), + ('s3://bucket/path/', ('bucket', 'path/')), + ('s3://bucket/path', ('bucket', 'path/')), + ('s3://bucket/', ('bucket', '/')), + ('s3://bucket', ('bucket', '/')), + ('s3://my-bucket/deep/nested/path/', ('my-bucket', 'deep/nested/path/')), + ) + def test_parse_s3_path(self, path_str, expected): + bucket, key = s3_utils.parse_s3_path(path_str) + self.assertEqual((bucket, key), expected) + + def test_parse_s3_path_strips_leading_slash(self): + """Test that leading slashes are properly stripped from the path.""" + bucket, key = s3_utils.parse_s3_path('s3://bucket//path') + self.assertEqual(bucket, 'bucket') + self.assertEqual(key, '/path/') + + def test_parse_s3_path_adds_trailing_slash(self): + """Test that trailing slashes are added when missing.""" + bucket, key = s3_utils.parse_s3_path('s3://bucket/path') + self.assertTrue(key.endswith('/')) + self.assertEqual(key, 'path/') + + def test_parse_s3_path_invalid_scheme(self): + """Test that non-s3 schemes raise an assertion error.""" + with self.assertRaises(AssertionError): + s3_utils.parse_s3_path('gs://bucket/path') + + def test_parse_s3_path_no_scheme(self): + """Test that paths without schemes raise an assertion error.""" + with self.assertRaises(AssertionError): + s3_utils.parse_s3_path('/local/path') + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/path/utils.py b/checkpoint/orbax/checkpoint/_src/path/utils.py index 3da79bd2f..c3504e340 100644 --- a/checkpoint/orbax/checkpoint/_src/path/utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/utils.py @@ -24,6 +24,7 @@ from etils import epath from orbax.checkpoint._src.path import gcs_utils +from orbax.checkpoint._src.path import s3_utils @@ -36,6 +37,8 @@ def get_storage_type(path: epath.Path | str) -> str: if gcs_utils.is_gcs_path(path): return 'gcs' + elif s3_utils.is_s3_path(path): + return 's3' else: return 'other' diff --git a/checkpoint/orbax/checkpoint/_src/path/utils_test.py b/checkpoint/orbax/checkpoint/_src/path/utils_test.py index e3adab4e0..d2bc3542d 100644 --- a/checkpoint/orbax/checkpoint/_src/path/utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/utils_test.py @@ -70,6 +70,7 @@ def test_recursively_copy_files_with_skip_paths(self): @parameterized.parameters( ('gs://bucket/path', 'gcs'), + ('s3://bucket/path', 's3'), ('/tmp/foo/bar', 'other'), ) def test_get_storage_type(self, path, expected_type): diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py index ac6682eba..f65b27427 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py @@ -92,6 +92,30 @@ def _get_kvstore_for_gcs(ckpt_path: str): 'path': path_without_bucket, } +def _get_kvstore_for_s3(ckpt_path: str): + 
"""Constructs a TensorStore kvstore spec for a S3 path. + + Args: + ckpt_path: A S3 path of the form s3:///. + Returns: + A dictionary containing the TensorStore kvstore spec. + Raises: + ValueError: if ckpt_path is not a valid S3 path. + """ + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError( + 'The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}' + ) + s3_bucket = m.group(1) + path_without_bucket = m.group(2) + return { + 'driver': 's3', + 'bucket': s3_bucket, + 'path': path_without_bucket, + } + def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): """Constructs a TensorStore spec for the given checkpoint path.""" @@ -99,9 +123,10 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): # fix the path prefix to add back the stripped '/'. ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://') is_gcs_path = ckpt_path.startswith('gs://') + is_s3_path = ckpt_path.startswith('s3://') spec = {'driver': 'zarr', 'kvstore': {}} if ocdbt: - if not is_gcs_path and not os.path.isabs(ckpt_path): + if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path): raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') base_path = os.path.dirname(ckpt_path) base_driver_spec = ( @@ -117,6 +142,8 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): else: if is_gcs_path: spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path) + elif is_s3_path: + spec['kvstore'] = _get_kvstore_for_s3(ckpt_path) else: spec['kvstore'] = {'driver': ts_utils.DEFAULT_DRIVER, 'path': ckpt_path} diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index bd971b1fd..c8e215424 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -56,6 +56,7 @@ ZARR_VER3 = 'zarr3' _GCS_PATH_RE = r'^gs://([^/]*)/(.*)$' +_S3_PATH_RE = r'^s3://([^/]*)/(.*)$' # Even if the data is equal to the fill value, we still want to write it # to the checkpoint. This results in unnecessary writes in some edge @@ -135,6 +136,24 @@ def _get_kvstore_for_gcs(ckpt_path: str) -> JsonSpec: return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket} +def _get_kvstore_for_s3(ckpt_path: str) -> JsonSpec: + """Constructs a TensorStore kvstore spec for an S3 path.""" + # Parse s3://bucket-name/path/to/checkpoint + m = re.fullmatch(_S3_PATH_RE, ckpt_path, re.DOTALL) + if m is None: + raise ValueError( + 'The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}' + ) + s3_bucket = m.group(1) + path_without_bucket = m.group(2) + return { + 'driver': 's3', + 'bucket': s3_bucket, + 'path': path_without_bucket, + } + + def build_kvstore_tspec( directory: str, name: str | None = None, @@ -161,12 +180,13 @@ def build_kvstore_tspec( default_driver = DEFAULT_DRIVER # Normalize path to exclude trailing '/'. In GCS path case, we will need to # fix the path prefix to add back the stripped '/'. 
-  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
+  # Normalize path to exclude trailing '/'. For GCS and S3 paths, we will need
+  # to fix the path prefix to add back the stripped '/'.
+  directory = os.path.normpath(directory)
+  directory = directory.replace('gs:/', 'gs://').replace('s3:/', 's3://')
   is_gcs_path = directory.startswith('gs://')
+  is_s3_path = directory.startswith('s3://')
 
   kv_spec = {}
   if use_ocdbt:
-    if not is_gcs_path and not os.path.isabs(directory):
+    if not is_gcs_path and not is_s3_path and not os.path.isabs(directory):
       raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
     if process_id is not None:
       process_id = str(process_id)
@@ -186,7 +206,7 @@
       directory = os.path.join(*join_paths)
   base_driver_spec = (
       directory
-      if is_gcs_path
+      if is_gcs_path or is_s3_path
       else {'driver': default_driver, 'path': str(directory)}
   )
   kv_spec.update({
@@ -217,6 +237,8 @@
     path = os.path.join(directory, name)
     if is_gcs_path:
       kv_spec = _get_kvstore_for_gcs(path)
+    elif is_s3_path:
+      kv_spec = _get_kvstore_for_s3(path)
     else:
       kv_spec = {'driver': default_driver, 'path': path}
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py
index aa3420bfa..1464f3707 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py
@@ -169,6 +169,36 @@
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='regular_path',
+          directory='s3://s3_bucket/object_path',
+      ),
+      dict(
+          testcase_name='path_with_single_slash',
+          directory='s3:/s3_bucket/object_path',
+      ),
+  )
+  def test_file_kvstore_with_s3_path(
+      self,
+      directory: str,
+  ):
+    tspec = self.array_write_spec_constructor(
+        directory=directory,
+        relative_array_filename=self.param_name,
+        use_zarr3=False,
+        use_ocdbt=False,
+    )
+    self.assertFalse(tspec.metadata.use_ocdbt)
+    self.assertEqual(
+        tspec.json['kvstore'],
+        {
+            'driver': 's3',
+            'bucket': 's3_bucket',
+            'path': f'object_path/{self.param_name}',
+        },
+    )
+
   @parameterized.named_parameters(
       dict(
           testcase_name='local_fs_path',
@@ -232,6 +262,39 @@
     )
     self.assertEqual(kvstore_tspec['path'], self.param_name)
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='regular_path',
+          directory='s3://s3_bucket/object_path',
+          expected_directory=None,
+      ),
+      dict(
+          testcase_name='path_with_single_slash',
+          directory='s3:/s3_bucket/object_path',
+          expected_directory='s3://s3_bucket/object_path',
+      ),
+  )
+  def test_ocdbt_kvstore_with_s3_path(
+      self,
+      directory: str,
+      expected_directory: str | None,
+  ):
+    tspec = self.array_write_spec_constructor(
+        directory=directory,
+        relative_array_filename=self.param_name,
+        use_zarr3=False,
+        use_ocdbt=True,
+        process_id=0,
+    )
+    self.assertTrue(tspec.metadata.use_ocdbt)
+    kvstore_tspec = tspec.json['kvstore']
+    self.assertEqual(kvstore_tspec['driver'], 'ocdbt')
+    self.assertEqual(
+        kvstore_tspec['base'],
+        os.path.join(expected_directory or directory, 'ocdbt.process_0'),
+    )
+    self.assertEqual(kvstore_tspec['path'], self.param_name)
+
   @parameterized.product(use_zarr3=(True, False))
   def test_ocdbt_kvstore_default_target_data_file_size(self, use_zarr3: bool):
     tspec = self.array_write_spec_constructor(
@@ -598,6 +661,10 @@ def test_maybe_cloud_storage(self):
     gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
     self.assertTrue(ts_utils.is_remote_storage(gs_spec))
 
+    s3_path = 's3://some-bucket/path'
+    s3_spec = serialization.get_tensorstore_spec(s3_path, ocdbt=True)
+    self.assertTrue(ts_utils.is_remote_storage(s3_spec))
+
     local_path = '/tmp/checkpoint'
     local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
     self.assertFalse(ts_utils.is_remote_storage(local_spec))
diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py
index 1634c3033..a1fb08965 100644
--- a/checkpoint/orbax/checkpoint/utils.py
+++ b/checkpoint/orbax/checkpoint/utils.py
@@ -27,6 +27,7 @@
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.path import atomicity
 from orbax.checkpoint._src.path import gcs_utils
+from orbax.checkpoint._src.path import s3_utils
 from orbax.checkpoint._src.path import step as step_lib
 from orbax.checkpoint._src.tree import utils as tree_utils
 
@@ -49,6 +50,7 @@
 is_primary_host = multihost.is_primary_host
 is_gcs_path = gcs_utils.is_gcs_path
+is_s3_path = s3_utils.is_s3_path
 checkpoint_steps = step_lib.checkpoint_steps
 any_checkpoint_step = step_lib.any_checkpoint_step
 is_checkpoint_finalized = step_lib.is_path_finalized
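For reviewers, a quick sketch of how the new pieces compose end to end. The bucket and run names are hypothetical, and it assumes a TensorStore build that includes the `s3` kvstore driver:

```python
from etils import epath

from orbax.checkpoint._src.path import s3_utils
from orbax.checkpoint._src.serialization import serialization

# Path classification: anything under the 's3://' scheme is treated as S3,
# so it gets CommitFileTemporaryPath semantics and 's3' storage-type tagging.
assert s3_utils.is_s3_path(epath.Path('s3://my-bucket/ckpt'))
assert not s3_utils.is_s3_path(epath.Path('/tmp/ckpt'))

# Bucket/key splitting; the key is normalized to carry a trailing slash.
bucket, key = s3_utils.parse_s3_path('s3://my-bucket/run_1/step_100')
assert (bucket, key) == ('my-bucket', 'run_1/step_100/')

# Spec construction routes S3 paths to TensorStore's 's3' kvstore driver.
spec = serialization.get_tensorstore_spec('s3://my-bucket/ckpt/params')
assert spec['kvstore'] == {
    'driver': 's3',
    'bucket': 'my-bucket',
    'path': 'ckpt/params',
}
```

With `ocdbt=True`, the S3 directory is passed through as the OCDBT `base` (e.g. `s3://my-bucket/ckpt/ocdbt.process_0`), mirroring the GCS behavior exercised in `test_ocdbt_kvstore_with_s3_path`.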