From 5fbad196372aa642479549c664d02a9aa4bc41e9 Mon Sep 17 00:00:00 2001
From: Tom Augspurger
Date: Wed, 17 Dec 2025 11:01:04 -0600
Subject: [PATCH] Refactor Codec interface

This refactors the interface of `Codec.read` and `Codec.write` to move
from an `Iterable[tuple[...]]` to an `Iterable[BatchInfo]`.

Two things motivate this change:

1. Readability: I struggle to remember what the 4th member of these
   complex tuples means. Having the name `info.out_selection` to remind
   me is helpful.
2. Possible future-proofing: right now, any change to the interface is
   a hard break, since the number of elements in the tuple will change.
   There may be a class of changes to the interface where we can add
   additional information to `BatchInfo` without breaking backwards
   compatibility.

I don't want to oversell motivation 2, though. If something is
important enough to add to the interface, then presumably we expect
implementations to, you know, use it.
---
 src/zarr/abc/codec.py           |  86 +++++++++++++++++++++-----
 src/zarr/codecs/sharding.py     |  50 ++++++++--------
 src/zarr/core/array.py          |  36 +++++++++----
 src/zarr/core/codec_pipeline.py | 103 ++++++++++++++------------------
 tests/test_abc/test_codec.py    |  81 ++++++++++++++++++++++++-
 tests/test_config.py            |   7 +--
 6 files changed, 249 insertions(+), 114 deletions(-)

diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py
index d41c457b4e..a041edfd79 100644
--- a/src/zarr/abc/codec.py
+++ b/src/zarr/abc/codec.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
+from collections.abc import Iterator, Mapping
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar
 
 from typing_extensions import ReadOnly, TypedDict
 
 from zarr.abc.metadata import Metadata
+from zarr.abc.store import ByteGetter, ByteSetter
 from zarr.core.buffer import Buffer, NDBuffer
 from zarr.core.common import NamedConfig, concurrent_map
 from zarr.core.config import config
@@ -15,7 +17,7 @@
     from collections.abc import Awaitable, Callable, Iterable
     from typing import Self
 
-    from zarr.abc.store import ByteGetter, ByteSetter, Store
+    from zarr.abc.store import Store
     from zarr.core.array_spec import ArraySpec
     from zarr.core.chunk_grids import ChunkGrid
     from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
@@ -28,10 +30,13 @@
     "ArrayBytesCodecPartialDecodeMixin",
     "ArrayBytesCodecPartialEncodeMixin",
     "BaseCodec",
+    "BatchInfo",
     "BytesBytesCodec",
     "CodecInput",
     "CodecOutput",
     "CodecPipeline",
+    "ReadBatchInfo",
+    "WriteBatchInfo",
 ]
 
 CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
@@ -59,6 +64,58 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
 """The widest type of JSON-like input that could specify a codec."""
 
 
+TByteOperator = TypeVar("TByteOperator", bound="ByteGetter")
+
+
+@dataclass(frozen=True)
+class BatchInfo(Generic[TByteOperator]):
+    """Information about a chunk to be read/written from/to the store.
+
+    This class is generic over the byte operator type:
+    - BatchInfo[ByteGetter] (aliased as ReadBatchInfo) for read operations
+    - BatchInfo[ByteSetter] (aliased as WriteBatchInfo) for write operations
+
+    Attributes
+    ----------
+    byte_operator : TByteOperator
+        Used to fetch/write the chunk bytes from/to the store.
+        For reads, this is a ByteGetter. For writes, this is a ByteSetter.
+    array_spec : ArraySpec
+        Specification of the chunk array (shape, dtype, fill value, etc.).
+    chunk_selection : SelectorTuple
+        Slice selection determining which parts of the chunk to read/encode.
+    out_selection : SelectorTuple
+        Slice selection locating the chunk data in the output array (reads) or the value array (writes).
+    is_complete_chunk : bool
+        Whether this represents a complete chunk (vs. a partial chunk at array boundaries).
+    """
+
+    byte_operator: TByteOperator
+    array_spec: ArraySpec
+    chunk_selection: SelectorTuple
+    out_selection: SelectorTuple
+    is_complete_chunk: bool
+
+    def __iter__(self) -> Iterator[Any]:
+        """Iterate over fields for backwards compatibility with tuple unpacking."""
+        yield self.byte_operator
+        yield self.array_spec
+        yield self.chunk_selection
+        yield self.out_selection
+        yield self.is_complete_chunk
+
+    def __getitem__(self, index: int) -> Any:
+        """Index access for backwards compatibility with tuple indexing."""
+        return list(self)[index]
+
+
+ReadBatchInfo: TypeAlias = BatchInfo[ByteGetter]
+"""Information about a chunk to be read from the store and decoded."""
+
+WriteBatchInfo: TypeAlias = BatchInfo[ByteSetter]
+"""Information about a chunk to be encoded and written to the store."""
+
+
 class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
     """Generic base class for codecs.
 
@@ -412,7 +469,7 @@ async def encode(
     @abstractmethod
     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[ReadBatchInfo],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -421,25 +478,24 @@ async def read(
 
         Parameters
         ----------
-        batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]]
+        batch_info : Iterable[ReadBatchInfo]
             Ordered set of information about the chunks.
-            The first slice selection determines which parts of the chunk will be fetched.
-            The second slice selection determines where in the output array the chunk data will be written.
-            The ByteGetter is used to fetch the necessary bytes.
-            The chunk spec contains information about the construction of an array from the bytes.
+            See ReadBatchInfo for details on the fields.
 
             If the Store returns ``None`` for a chunk, then the chunk was not written
            and the implementation must set the values of that chunk (or ``out``) to
            the fill value for the array.
         out : NDBuffer
+        drop_axes : tuple[int, ...]
+            Axes to drop from the decoded chunk data before it is written into ``out``.
         """
         ...
 
     @abstractmethod
     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[WriteBatchInfo],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -449,13 +505,13 @@ async def write(
 
         Parameters
         ----------
-        batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]]
+        batch_info : Iterable[WriteBatchInfo]
             Ordered set of information about the chunks.
-            The first slice selection determines which parts of the chunk will be encoded.
-            The second slice selection determines where in the value array the chunk data is located.
-            The ByteSetter is used to fetch and write the necessary bytes.
-            The chunk spec contains information about the chunk.
+            See WriteBatchInfo for details on the fields.
         value : NDBuffer
+            The data to write.
+        drop_axes : tuple[int, ...]
+            Axes to drop when merging data from ``value`` into the stored chunks.
         """
         ...
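
For third-party `CodecPipeline` implementations, the migration is mostly
mechanical: tuple unpacking becomes attribute access on `ReadBatchInfo` /
`WriteBatchInfo`. As a rough illustration (not part of this patch), a custom
`read` against the new signature could look like the sketch below. The class
name `MyCodecPipeline` is made up; `decode_batch` is the existing
`BatchedCodecPipeline` helper used elsewhere in this patch, `drop_axes`
handling is omitted, and `array_spec.fill_value` is used directly where the
in-tree pipeline goes through its `fill_value_or_default` helper.

    from collections.abc import Iterable

    from zarr.abc.codec import ReadBatchInfo
    from zarr.core.buffer import NDBuffer
    from zarr.core.codec_pipeline import BatchedCodecPipeline


    class MyCodecPipeline(BatchedCodecPipeline):
        async def read(
            self,
            batch_info: Iterable[ReadBatchInfo],
            out: NDBuffer,
            drop_axes: tuple[int, ...] = (),
        ) -> None:
            # drop_axes handling omitted for brevity.
            for info in batch_info:
                chunk_bytes = await info.byte_operator.get(info.array_spec.prototype)
                if chunk_bytes is None:
                    # Missing chunk: fill the output selection with the fill value,
                    # as the read() docstring requires.
                    out[info.out_selection] = info.array_spec.fill_value
                    continue
                (chunk_array,) = await self.decode_batch([(chunk_bytes, info.array_spec)])
                if chunk_array is None:
                    out[info.out_selection] = info.array_spec.fill_value
                else:
                    out[info.out_selection] = chunk_array[info.chunk_selection]
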
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 8124ea44ea..803d9ebc44 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -16,6 +16,8 @@ ArrayBytesCodecPartialEncodeMixin, Codec, CodecPipeline, + ReadBatchInfo, + WriteBatchInfo, ) from zarr.abc.store import ( ByteGetter, @@ -358,12 +360,12 @@ async def _decode_single( # decoding chunks and writing them into the output buffer await self.codec_pipeline.read( [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + ReadBatchInfo( + byte_operator=_ShardingByteGetter(shard_dict, chunk_coords), + array_spec=chunk_spec, + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_shard, ) for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer ], @@ -430,12 +432,12 @@ async def _decode_partial_single( # decoding chunks and writing them into the output buffer await self.codec_pipeline.read( [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + ReadBatchInfo( + byte_operator=_ShardingByteGetter(shard_dict, chunk_coords), + array_spec=chunk_spec, + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_shard, ) for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer ], @@ -469,12 +471,12 @@ async def _encode_single( await self.codec_pipeline.write( [ - ( - _ShardingByteSetter(shard_builder, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + WriteBatchInfo( + byte_operator=_ShardingByteSetter(shard_builder, chunk_coords), + array_spec=chunk_spec, + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_shard, ) for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer ], @@ -515,12 +517,12 @@ async def _encode_partial_single( await self.codec_pipeline.write( [ - ( - _ShardingByteSetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, + WriteBatchInfo( + byte_operator=_ShardingByteSetter(shard_dict, chunk_coords), + array_spec=chunk_spec, + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_shard, ) for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer ], diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 6b20ee950d..a9c72066bc 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -23,7 +23,14 @@ from typing_extensions import deprecated import zarr -from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec +from zarr.abc.codec import ( + ArrayArrayCodec, + ArrayBytesCodec, + BytesBytesCodec, + Codec, + ReadBatchInfo, + WriteBatchInfo, +) from zarr.abc.numcodec import Numcodec, _is_numcodec from zarr.codecs._v2 import V2Codec from zarr.codecs.bytes import BytesCodec @@ -1564,12 +1571,15 @@ async def _get_selection( # reading chunks and decoding them await self.codec_pipeline.read( [ - ( - self.store_path / self.metadata.encode_chunk_key(chunk_coords), - self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), - chunk_selection, - out_selection, - is_complete_chunk, + ReadBatchInfo( + byte_operator=self.store_path + / self.metadata.encode_chunk_key(chunk_coords), + array_spec=self.metadata.get_chunk_spec( + chunk_coords, _config, prototype=prototype 
+ ), + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_chunk, ) for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer ], @@ -1735,12 +1745,12 @@ async def _set_selection( # merging with existing data and encoding chunks await self.codec_pipeline.write( [ - ( - self.store_path / self.metadata.encode_chunk_key(chunk_coords), - self.metadata.get_chunk_spec(chunk_coords, _config, prototype), - chunk_selection, - out_selection, - is_complete_chunk, + WriteBatchInfo( + byte_operator=self.store_path / self.metadata.encode_chunk_key(chunk_coords), + array_spec=self.metadata.get_chunk_spec(chunk_coords, _config, prototype), + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_chunk, ) for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer ], diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..5970546480 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -13,6 +13,8 @@ BytesBytesCodec, Codec, CodecPipeline, + ReadBatchInfo, + WriteBatchInfo, ) from zarr.core.common import concurrent_map from zarr.core.config import config @@ -248,48 +250,43 @@ async def encode_partial_batch( async def read_batch( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ReadBatchInfo], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info_list = list(batch_info) if self.supports_partial_decode: chunk_array_batch = await self.decode_partial_batch( [ - (byte_getter, chunk_selection, chunk_spec) - for byte_getter, chunk_spec, chunk_selection, *_ in batch_info + (info.byte_operator, info.chunk_selection, info.array_spec) + for info in batch_info_list ] ) - for chunk_array, (_, chunk_spec, _, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, info in zip(chunk_array_batch, batch_info_list, strict=False): if chunk_array is not None: - out[out_selection] = chunk_array + out[info.out_selection] = chunk_array else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[info.out_selection] = fill_value_or_default(info.array_spec) else: chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], + [(info.byte_operator, info.array_spec.prototype) for info in batch_info_list], lambda byte_getter, prototype: byte_getter.get(prototype), config.get("async.concurrency"), ) chunk_array_batch = await self.decode_batch( [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (chunk_bytes, info.array_spec) + for chunk_bytes, info in zip(chunk_bytes_batch, batch_info_list, strict=False) ], ) - for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, info in zip(chunk_array_batch, batch_info_list, strict=False): if chunk_array is not None: - tmp = chunk_array[chunk_selection] + tmp = chunk_array[info.chunk_selection] if drop_axes != (): tmp = tmp.squeeze(axis=drop_axes) - out[out_selection] = tmp + out[info.out_selection] = tmp else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[info.out_selection] = fill_value_or_default(info.array_spec) def _merge_chunk_array( self, @@ -337,24 +334,30 @@ def _merge_chunk_array( async def write_batch( self, - 
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteBatchInfo], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info_list = list(batch_info) if self.supports_partial_encode: # Pass scalar values as is if len(value.shape) == 0: await self.encode_partial_batch( [ - (byte_setter, value, chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + (info.byte_operator, value, info.chunk_selection, info.array_spec) + for info in batch_info_list ], ) else: await self.encode_partial_batch( [ - (byte_setter, value[out_selection], chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + ( + info.byte_operator, + value[info.out_selection], + info.chunk_selection, + info.array_spec, + ) + for info in batch_info_list ], ) @@ -371,20 +374,18 @@ async def _read_key( chunk_bytes_batch = await concurrent_map( [ ( - None if is_complete_chunk else byte_setter, - chunk_spec.prototype, + None if info.is_complete_chunk else info.byte_operator, + info.array_spec.prototype, ) - for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info + for info in batch_info_list ], _read_key, config.get("async.concurrency"), ) chunk_array_decoded = await self.decode_batch( [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (chunk_bytes, info.array_spec) + for chunk_bytes, info in zip(chunk_bytes_batch, batch_info_list, strict=False) ], ) @@ -392,29 +393,21 @@ async def _read_key( self._merge_chunk_array( chunk_array, value, - out_selection, - chunk_spec, - chunk_selection, - is_complete_chunk, + info.out_selection, + info.array_spec, + info.chunk_selection, + info.is_complete_chunk, drop_axes, ) - for chunk_array, ( - _, - chunk_spec, - chunk_selection, - out_selection, - is_complete_chunk, - ) in zip(chunk_array_decoded, batch_info, strict=False) + for chunk_array, info in zip(chunk_array_decoded, batch_info_list, strict=False) ] chunk_array_batch: list[NDBuffer | None] = [] - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_merged, batch_info, strict=False - ): + for chunk_array, info in zip(chunk_array_merged, batch_info_list, strict=False): if chunk_array is None: chunk_array_batch.append(None) # type: ignore[unreachable] else: - if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( - fill_value_or_default(chunk_spec) + if not info.array_spec.config.write_empty_chunks and chunk_array.all_equal( + fill_value_or_default(info.array_spec) ): chunk_array_batch.append(None) else: @@ -422,10 +415,8 @@ async def _read_key( chunk_bytes_batch = await self.encode_batch( [ - (chunk_array, chunk_spec) - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_batch, batch_info, strict=False - ) + (chunk_array, info.array_spec) + for chunk_array, info in zip(chunk_array_batch, batch_info_list, strict=False) ], ) @@ -437,10 +428,8 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non await concurrent_map( [ - (byte_setter, chunk_bytes) - for chunk_bytes, (byte_setter, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (info.byte_operator, chunk_bytes) + for chunk_bytes, info in zip(chunk_bytes_batch, batch_info_list, strict=False) ], _write_key, config.get("async.concurrency"), @@ -466,7 +455,7 @@ async def encode( async def read( self, - batch_info: Iterable[tuple[ByteGetter, 
ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ReadBatchInfo], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -481,7 +470,7 @@ async def read( async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteBatchInfo], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: diff --git a/tests/test_abc/test_codec.py b/tests/test_abc/test_codec.py index e0f9ddb7bb..aa1cba93b8 100644 --- a/tests/test_abc/test_codec.py +++ b/tests/test_abc/test_codec.py @@ -1,6 +1,14 @@ from __future__ import annotations -from zarr.abc.codec import _check_codecjson_v2 +from typing import TYPE_CHECKING + +from zarr.abc.codec import ReadBatchInfo, WriteBatchInfo, _check_codecjson_v2 +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import Buffer, BufferPrototype, default_buffer_prototype +from zarr.dtype import UInt8 + +if TYPE_CHECKING: + from zarr.abc.store import ByteRequest def test_check_codecjson_v2_valid() -> None: @@ -10,3 +18,74 @@ def test_check_codecjson_v2_valid() -> None: assert _check_codecjson_v2({"id": "gzip"}) assert not _check_codecjson_v2({"id": 10}) assert not _check_codecjson_v2([10, 11]) + + +def test_read_batch_info_iterable() -> None: + class ByteGetter: + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + pass + + byte_getter = ByteGetter() + info = ReadBatchInfo( + byte_operator=byte_getter, + array_spec=ArraySpec( + shape=(16, 16), + dtype=UInt8(), + fill_value=0, + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ), + chunk_selection=(0, 0), + out_selection=(0, 0), + is_complete_chunk=True, + ) + + assert tuple(info) == ( + info.byte_operator, + info.array_spec, + info.chunk_selection, + info.out_selection, + info.is_complete_chunk, + ) + + +def test_write_batch_info_iterable() -> None: + class ByteSetter: + async def get( + self, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Buffer | None: + pass + + async def set(self, value: Buffer) -> None: + pass + + async def delete(self) -> None: + pass + + async def set_if_not_exists(self, default: Buffer) -> None: + pass + + byte_setter = ByteSetter() + + info = WriteBatchInfo( + byte_operator=byte_setter, + array_spec=ArraySpec( + shape=(16, 16), + dtype=UInt8(), + fill_value=0, + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ), + chunk_selection=(0, 0), + out_selection=(0, 0), + is_complete_chunk=True, + ) + assert tuple(info) == ( + info.byte_operator, + info.array_spec, + info.chunk_selection, + info.out_selection, + info.is_complete_chunk, + ) diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..97aa8919f2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,8 +9,8 @@ import zarr from zarr import zeros -from zarr.abc.codec import CodecPipeline -from zarr.abc.store import ByteSetter, Store +from zarr.abc.codec import CodecPipeline, WriteBatchInfo +from zarr.abc.store import Store from zarr.codecs import ( BloscCodec, BytesCodec, @@ -22,7 +22,6 @@ from zarr.core.buffer.core import Buffer from zarr.core.codec_pipeline import BatchedCodecPipeline from zarr.core.config import BadConfigError, config -from zarr.core.indexing import SelectorTuple from zarr.errors import ZarrUserWarning from zarr.registry import ( fully_qualified_name, @@ -140,7 +139,7 @@ def 
test_config_codec_pipeline_class(store: Store) -> None: class MockCodecPipeline(BatchedCodecPipeline): async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteBatchInfo], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None:
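
Because `BatchInfo` implements `__iter__` and `__getitem__`, downstream code
that still unpacks the old five-element tuples keeps working. A small,
self-contained sketch mirroring the new tests (the `DummyByteGetter` below is
a throwaway stand-in for a real store-backed `ByteGetter`):

    from zarr.abc.codec import ReadBatchInfo
    from zarr.core.array_spec import ArrayConfig, ArraySpec
    from zarr.core.buffer import default_buffer_prototype
    from zarr.dtype import UInt8


    class DummyByteGetter:
        """Minimal object satisfying the ByteGetter protocol, for illustration only."""

        async def get(self, prototype, byte_range=None):
            return None


    spec = ArraySpec(
        shape=(16, 16),
        dtype=UInt8(),
        fill_value=0,
        config=ArrayConfig(order="C", write_empty_chunks=True),
        prototype=default_buffer_prototype(),
    )
    info = ReadBatchInfo(
        byte_operator=DummyByteGetter(),
        array_spec=spec,
        chunk_selection=(slice(0, 16), slice(0, 16)),
        out_selection=(slice(0, 16), slice(0, 16)),
        is_complete_chunk=True,
    )

    # New style: named attribute access.
    assert info.array_spec is spec

    # Old style still works via __iter__ / __getitem__.
    byte_getter, chunk_spec, chunk_selection, out_selection, is_complete_chunk = info
    assert chunk_spec is spec
    assert info[3] == out_selection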