5 changes: 3 additions & 2 deletions checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py
@@ -23,13 +23,14 @@
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import s3_utils


def get_item_default_temporary_path_class(
path: epath.Path,
) -> type[atomicity_types.TemporaryPath]:
"""Returns the default temporary path class for a given sub-item path."""
if gcs_utils.is_gcs_path(path):
if gcs_utils.is_gcs_path(path) or s3_utils.is_s3_path(path):
return atomicity.CommitFileTemporaryPath
else:
return atomicity.AtomicRenameTemporaryPath
@@ -39,7 +40,7 @@ def get_default_temporary_path_class(
path: epath.Path,
) -> type[atomicity_types.TemporaryPath]:
"""Returns the default temporary path class for a given checkpoint path."""
if gcs_utils.is_gcs_path(path):
if gcs_utils.is_gcs_path(path) or s3_utils.is_s3_path(path):
return atomicity.CommitFileTemporaryPath
else:
return atomicity.AtomicRenameTemporaryPath
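
With this change, S3 paths get the same commit-file strategy as GCS, since S3 likewise lacks an atomic directory rename. A minimal usage sketch (hypothetical bucket name; module paths as in this PR):

from etils import epath

from orbax.checkpoint._src.path import atomicity_defaults

cls = atomicity_defaults.get_default_temporary_path_class(
    epath.Path('s3://my-bucket/ckpt')
)
print(cls.__name__)  # CommitFileTemporaryPath; a local path would instead
                     # yield AtomicRenameTemporaryPath.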
50 changes: 50 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/deleter.py
@@ -29,6 +29,7 @@
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import s3_utils
from orbax.checkpoint._src.path import step as step_lib


@@ -188,6 +189,10 @@ def delete(self, step: int) -> None:
# This is recommended for GCS buckets with HNS enabled and requires
# `_todelete_full_path` to be specified.
self._gcs_rename_step(step, delete_target)
elif s3_utils.is_s3_path(self._directory):
# This is recommended for S3 buckets and requires
# `_todelete_full_path` to be specified.
self._s3_rename_step(step, delete_target)
else:
raise NotImplementedError()
# Attempt to rename to local subdirectory using `todelete_subdir`
@@ -259,6 +264,51 @@ def _gcs_rename_step(
logging.error(message)
raise RuntimeError(message) from e

def _s3_rename_step(
self, step: int, delete_target: epath.Path
):
"""Renames a S3 directory to a temporary location for deletion.

Args:
step: The checkpoint step number.
delete_target: The path to the directory to be renamed.
"""
try:
# Get the bucket name from the source path
bucket_name = urlparse(str(delete_target)).netloc
if not bucket_name:
raise ValueError(
f'Could not parse bucket name from path: {delete_target}'
)

# Construct the destination path inside the `_todelete_full_path` dir.
destination_parent_path = epath.Path(
f's3://{bucket_name}/{self._todelete_full_path}'
)
destination_parent_path.mkdir(parents=True, exist_ok=True)

# Create a unique name for the destination to avoid collisions.
now = datetime.datetime.now()
timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f')
new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}'
dest_path = destination_parent_path / new_name_with_timestamp

logging.info(
'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`',
delete_target,
dest_path,
)

      # Call the high-level rename method. S3 has no atomic directory rename,
      # so this works but may be slow for directories with many objects.
delete_target.rename(dest_path)
logging.info('Successfully renamed step %d to %s', step, dest_path)

except Exception as e:
message = f'Rename failed for step {step}. Error: {e}'
logging.error(message)
raise RuntimeError(message) from e

def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):
"""Renames a step directory to its corresponding todelete_subdir."""
rename_dir = self._directory / self._todelete_subdir
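As a standalone sketch of the destination naming used above (bucket, step directory, and `_todelete_full_path` values are hypothetical):

import datetime

from etils import epath

delete_target = epath.Path('s3://my-bucket/checkpoints/42')
todelete_full_path = 'garbage'  # stands in for `self._todelete_full_path`

timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
dest = (
    epath.Path(f's3://my-bucket/{todelete_full_path}')
    / f'{delete_target.name}-{timestamp}'
)
print(dest)  # e.g. s3://my-bucket/garbage/42-20250101-103000-123456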
39 changes: 39 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/s3_utils.py
@@ -0,0 +1,39 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for interacting with S3 paths."""

from urllib import parse
from etils import epath


_S3_PATH_PREFIX = ('s3://',)


def is_s3_path(path: epath.Path) -> bool:
  """Returns True if the path is an S3 path (i.e. starts with 's3://')."""
  return path.as_posix().startswith(_S3_PATH_PREFIX)


def parse_s3_path(path: epath.PathLike) -> tuple[str, str]:
  """Parses an S3 path into a (bucket, key) pair; the key ends with '/'."""
  parsed = parse.urlparse(str(path))
assert parsed.scheme == 's3', f'Unsupported scheme for S3: {parsed.scheme}'
# Strip the leading slash from the path.
standardized_path = parsed.path
if standardized_path.startswith('/'):
standardized_path = standardized_path[1:]
# Add a trailing slash if it's missing.
if not standardized_path.endswith('/'):
standardized_path = standardized_path + '/'
return parsed.netloc, standardized_path
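
A quick usage sketch of the two helpers, mirroring the tests below (hypothetical paths):

from etils import epath

from orbax.checkpoint._src.path import s3_utils

assert s3_utils.is_s3_path(epath.Path('s3://my-bucket/ckpt'))
assert not s3_utils.is_s3_path(epath.Path('/tmp/ckpt'))

bucket, key = s3_utils.parse_s3_path('s3://my-bucket/ckpt/step_100')
assert (bucket, key) == ('my-bucket', 'ckpt/step_100/')  # trailing '/' added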
74 changes: 74 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/s3_utils_test.py
@@ -0,0 +1,74 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for s3_utils."""

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
from orbax.checkpoint._src.path import s3_utils


class S3UtilsTest(parameterized.TestCase):

@parameterized.parameters(
('s3://bucket/path', True),
('s3://bucket/path/', True),
('s3://bucket', True),
('/local/path', False),
('gs://bucket/path', False),
('/tmp/foo/bar', False),
('file:///path', False),
)
def test_is_s3_path(self, path_str, expected):
path = epath.Path(path_str)
self.assertEqual(s3_utils.is_s3_path(path), expected)

@parameterized.parameters(
('s3://bucket/path/to/file', ('bucket', 'path/to/file/')),
('s3://bucket/path/', ('bucket', 'path/')),
('s3://bucket/path', ('bucket', 'path/')),
('s3://bucket/', ('bucket', '/')),
('s3://bucket', ('bucket', '/')),
('s3://my-bucket/deep/nested/path/', ('my-bucket', 'deep/nested/path/')),
)
def test_parse_s3_path(self, path_str, expected):
bucket, key = s3_utils.parse_s3_path(path_str)
self.assertEqual((bucket, key), expected)

def test_parse_s3_path_strips_leading_slash(self):
"""Test that leading slashes are properly stripped from the path."""
bucket, key = s3_utils.parse_s3_path('s3://bucket//path')
self.assertEqual(bucket, 'bucket')
self.assertEqual(key, '/path/')

def test_parse_s3_path_adds_trailing_slash(self):
"""Test that trailing slashes are added when missing."""
bucket, key = s3_utils.parse_s3_path('s3://bucket/path')
self.assertTrue(key.endswith('/'))
self.assertEqual(key, 'path/')

def test_parse_s3_path_invalid_scheme(self):
"""Test that non-s3 schemes raise an assertion error."""
with self.assertRaises(AssertionError):
s3_utils.parse_s3_path('gs://bucket/path')

def test_parse_s3_path_no_scheme(self):
"""Test that paths without schemes raise an assertion error."""
with self.assertRaises(AssertionError):
s3_utils.parse_s3_path('/local/path')


if __name__ == '__main__':
absltest.main()
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/utils.py
@@ -24,6 +24,7 @@
from etils import epath

from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import s3_utils



@@ -36,6 +37,8 @@ def get_storage_type(path: epath.Path | str) -> str:

if gcs_utils.is_gcs_path(path):
return 'gcs'
elif s3_utils.is_s3_path(path):
return 's3'
else:
return 'other'

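The storage-type tag now distinguishes S3 as well; a quick sketch (hypothetical paths):

from orbax.checkpoint._src.path import utils

assert utils.get_storage_type('s3://my-bucket/ckpt') == 's3'
assert utils.get_storage_type('gs://my-bucket/ckpt') == 'gcs'
assert utils.get_storage_type('/tmp/ckpt') == 'other'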
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/path/utils_test.py
@@ -70,6 +70,7 @@ def test_recursively_copy_files_with_skip_paths(self):

@parameterized.parameters(
('gs://bucket/path', 'gcs'),
('s3://bucket/path', 's3'),
('/tmp/foo/bar', 'other'),
)
def test_get_storage_type(self, path, expected_type):
29 changes: 28 additions & 1 deletion checkpoint/orbax/checkpoint/_src/serialization/serialization.py
@@ -92,16 +92,41 @@ def _get_kvstore_for_gcs(ckpt_path: str):
'path': path_without_bucket,
}


def _get_kvstore_for_s3(ckpt_path: str):
  """Constructs a TensorStore kvstore spec for an S3 path.

  Args:
    ckpt_path: An S3 path of the form s3://<bucket>/<path>.

  Returns:
    A dictionary containing the TensorStore kvstore spec.

  Raises:
    ValueError: if ckpt_path is not a valid S3 path.
  """
m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
if m is None:
raise ValueError(
'The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}'
)
s3_bucket = m.group(1)
path_without_bucket = m.group(2)
return {
'driver': 's3',
'bucket': s3_bucket,
'path': path_without_bucket,
}


def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
"""Constructs a TensorStore spec for the given checkpoint path."""
  # Normalize path to exclude trailing '/'. In the GCS and S3 path cases, we
  # will need to fix the path prefix to add back the stripped '/'.
  ckpt_path = (
      os.path.normpath(ckpt_path)
      .replace('gs:/', 'gs://')
      .replace('s3:/', 's3://')
  )
is_gcs_path = ckpt_path.startswith('gs://')
is_s3_path = ckpt_path.startswith('s3://')
spec = {'driver': 'zarr', 'kvstore': {}}
if ocdbt:
if not is_gcs_path and not os.path.isabs(ckpt_path):
if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path):
raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
base_path = os.path.dirname(ckpt_path)
base_driver_spec = (
@@ -117,6 +142,8 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
else:
if is_gcs_path:
spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
elif is_s3_path:
spec['kvstore'] = _get_kvstore_for_s3(ckpt_path)
else:
spec['kvstore'] = {'driver': ts_utils.DEFAULT_DRIVER, 'path': ckpt_path}
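
For a plain (non-OCDBT) S3 path, the resulting spec has the following shape (hypothetical path; structure follows the code above):

spec = get_tensorstore_spec('s3://my-bucket/ckpt/params')
# {
#     'driver': 'zarr',
#     'kvstore': {'driver': 's3', 'bucket': 'my-bucket', 'path': 'ckpt/params'},
# }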

@@ -56,6 +56,7 @@
ZARR_VER3 = 'zarr3'

_GCS_PATH_RE = r'^gs://([^/]*)/(.*)$'
_S3_PATH_RE = r'^s3://([^/]*)/(.*)$'

# Even if the data is equal to the fill value, we still want to write it
# to the checkpoint. This results in unnecessary writes in some edge
@@ -135,6 +136,24 @@ def _get_kvstore_for_gcs(ckpt_path: str) -> JsonSpec:
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}


def _get_kvstore_for_s3(ckpt_path: str) -> JsonSpec:
"""Constructs a TensorStore kvstore spec for an S3 path."""
# Parse s3://bucket-name/path/to/checkpoint
m = re.fullmatch(_S3_PATH_RE, ckpt_path, re.DOTALL)
if m is None:
raise ValueError(
'The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}'
)
s3_bucket = m.group(1)
path_without_bucket = m.group(2)
return {
'driver': 's3',
'bucket': s3_bucket,
'path': path_without_bucket,
}


def build_kvstore_tspec(
directory: str,
name: str | None = None,
@@ -161,12 +180,13 @@ def build_kvstore_tspec(
default_driver = DEFAULT_DRIVER
  # Normalize path to exclude trailing '/'. In the GCS and S3 path cases, we
  # will need to fix the path prefix to add back the stripped '/'.
directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  directory = (
      os.path.normpath(directory)
      .replace('gs:/', 'gs://')
      .replace('s3:/', 's3://')
  )
is_gcs_path = directory.startswith('gs://')
is_s3_path = directory.startswith('s3://')
kv_spec = {}

if use_ocdbt:
if not is_gcs_path and not os.path.isabs(directory):
if not is_gcs_path and not is_s3_path and not os.path.isabs(directory):
raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
if process_id is not None:
process_id = str(process_id)
@@ -186,7 +206,7 @@
directory = os.path.join(*join_paths)
base_driver_spec = (
directory
if is_gcs_path
if is_gcs_path or is_s3_path
else {'driver': default_driver, 'path': str(directory)}
)
kv_spec.update({
@@ -217,6 +237,8 @@
path = os.path.join(directory, name)
if is_gcs_path:
kv_spec = _get_kvstore_for_gcs(path)
elif is_s3_path:
kv_spec = _get_kvstore_for_s3(path)
else:
kv_spec = {'driver': default_driver, 'path': path}
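
A sketch of the non-OCDBT S3 branch above (keyword names follow the signature shown; the full argument list is not visible in this diff, so treat the call as an assumption):

kv_spec = build_kvstore_tspec(
    's3://my-bucket/ckpt', name='params', use_ocdbt=False
)
# Expected: {'driver': 's3', 'bucket': 'my-bucket', 'path': 'ckpt/params'}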
