86 changes: 71 additions & 15 deletions src/zarr/abc/codec.py
@@ -1,12 +1,14 @@
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
from collections.abc import Iterator, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar

from typing_extensions import ReadOnly, TypedDict

from zarr.abc.metadata import Metadata
from zarr.abc.store import ByteGetter, ByteSetter
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.common import NamedConfig, concurrent_map
from zarr.core.config import config
@@ -15,7 +17,7 @@
from collections.abc import Awaitable, Callable, Iterable
from typing import Self

from zarr.abc.store import ByteGetter, ByteSetter, Store
from zarr.abc.store import Store
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
@@ -28,10 +30,13 @@
"ArrayBytesCodecPartialDecodeMixin",
"ArrayBytesCodecPartialEncodeMixin",
"BaseCodec",
"BatchInfo",
"BytesBytesCodec",
"CodecInput",
"CodecOutput",
"CodecPipeline",
"ReadBatchInfo",
"WriteBatchInfo",
]

CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
@@ -59,6 +64,58 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
"""The widest type of JSON-like input that could specify a codec."""


TByteOperator = TypeVar("TByteOperator", bound="ByteGetter")


@dataclass(frozen=True)
class BatchInfo(Generic[TByteOperator]):
"""Information about a chunk to be read/written from/to the store.

This class is generic over the byte operator type:
- BatchInfo[ByteGetter] (aliased as ReadBatchInfo) for read operations
- BatchInfo[ByteSetter] (aliased as WriteBatchInfo) for write operations

Attributes
----------
byte_operator : TByteOperator
Used to fetch or write the chunk's bytes in the store:
a ByteGetter for reads, a ByteSetter for writes.
array_spec : ArraySpec
Specification of the chunk array (shape, dtype, fill value, etc.).
chunk_selection : SelectorTuple
Slice selection determining which parts of the chunk to read or encode.
out_selection : SelectorTuple
Slice selection locating the chunk data in the output array (for reads)
or the value array (for writes).
is_complete_chunk : bool
Whether this represents a complete chunk (vs. a partial chunk at array boundaries).
"""

byte_operator: TByteOperator
array_spec: ArraySpec
chunk_selection: SelectorTuple
out_selection: SelectorTuple
is_complete_chunk: bool

def __iter__(self) -> Iterator[Any]:
"""Iterate over fields for backwards compatibility with tuple unpacking."""
yield self.byte_operator
yield self.array_spec
yield self.chunk_selection
yield self.out_selection
yield self.is_complete_chunk

def __getitem__(self, index: int) -> Any:
"""Index access for backwards compatibility with tuple indexing."""
return list(self)[index]


ReadBatchInfo: TypeAlias = BatchInfo[ByteGetter]
"""Information about a chunk to be read from the store and decoded."""

WriteBatchInfo: TypeAlias = BatchInfo[ByteSetter]
"""Information about a chunk to be encoded and written to the store."""


class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
"""Generic base class for codecs.

@@ -412,7 +469,7 @@ async def encode(
@abstractmethod
async def read(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
batch_info: Iterable[ReadBatchInfo],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
@@ -421,25 +478,24 @@ async def read(

Parameters
----------
batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]]
batch_info : Iterable[ReadBatchInfo]
Ordered set of information about the chunks.
The first slice selection determines which parts of the chunk will be fetched.
The second slice selection determines where in the output array the chunk data will be written.
The ByteGetter is used to fetch the necessary bytes.
The chunk spec contains information about the construction of an array from the bytes.
See ReadBatchInfo for details on the fields.

If the Store returns ``None`` for a chunk, the chunk was never written,
and the implementation must fill the corresponding region of ``out``
with the array's fill value.

out : NDBuffer
The output buffer that decoded chunk data is written into.
drop_axes : tuple[int, ...]
Axes to drop from the chunk data when writing into the output array.
"""
...
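
With named fields, a pipeline implementation's inner read loop becomes self-describing. A hypothetical sketch of the fetch-and-fill logic (decoding, concurrency, and drop_axes handling omitted; decode_chunk is an assumed helper, not part of this diff):

async def read(self, batch_info, out, drop_axes=()):
    for info in batch_info:
        chunk_bytes = await info.byte_operator.get(prototype=info.array_spec.prototype)
        if chunk_bytes is None:
            # The chunk was never written: fill the output region with the fill value.
            out[info.out_selection] = info.array_spec.fill_value
        else:
            chunk_array = await self.decode_chunk(chunk_bytes, info.array_spec)  # assumed helper
            out[info.out_selection] = chunk_array[info.chunk_selection]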

@abstractmethod
async def write(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
batch_info: Iterable[WriteBatchInfo],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
@@ -449,13 +505,13 @@ async def write(

Parameters
----------
batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]]
batch_info : Iterable[WriteBatchInfo]
Ordered set of information about the chunks.
The first slice selection determines which parts of the chunk will be encoded.
The second slice selection determines where in the value array the chunk data is located.
The ByteSetter is used to fetch and write the necessary bytes.
The chunk spec contains information about the chunk.
See WriteBatchInfo for details on the fields.
value : NDBuffer
The data to write.
drop_axes : tuple[int, ...]
Axes to drop from the chunk data when reading from the value array.
"""
...
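
Symmetrically for write, a hypothetical sketch of consuming WriteBatchInfo (merging with existing chunk data is omitted; encode_chunk is an assumed helper, not part of this diff):

async def write(self, batch_info, value, drop_axes=()):
    for info in batch_info:
        chunk_array = value[info.out_selection]
        encoded = await self.encode_chunk(chunk_array, info.array_spec)  # assumed helper
        if encoded is None:
            # Chunks that encode to nothing (e.g. all fill value) are deleted.
            await info.byte_operator.delete()
        else:
            await info.byte_operator.set(encoded)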

50 changes: 26 additions & 24 deletions src/zarr/codecs/sharding.py
@@ -16,6 +16,8 @@
ArrayBytesCodecPartialEncodeMixin,
Codec,
CodecPipeline,
ReadBatchInfo,
WriteBatchInfo,
)
from zarr.abc.store import (
ByteGetter,
@@ -358,12 +360,12 @@ async def _decode_single(
# decoding chunks and writing them into the output buffer
await self.codec_pipeline.read(
[
(
_ShardingByteGetter(shard_dict, chunk_coords),
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
ReadBatchInfo(
byte_operator=_ShardingByteGetter(shard_dict, chunk_coords),
array_spec=chunk_spec,
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
@@ -430,12 +432,12 @@ async def _decode_partial_single(
# decoding chunks and writing them into the output buffer
await self.codec_pipeline.read(
[
(
_ShardingByteGetter(shard_dict, chunk_coords),
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
ReadBatchInfo(
byte_operator=_ShardingByteGetter(shard_dict, chunk_coords),
array_spec=chunk_spec,
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
@@ -469,12 +471,12 @@ async def _encode_single(

await self.codec_pipeline.write(
[
(
_ShardingByteSetter(shard_builder, chunk_coords),
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
WriteBatchInfo(
byte_operator=_ShardingByteSetter(shard_builder, chunk_coords),
array_spec=chunk_spec,
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
@@ -515,12 +517,12 @@ async def _encode_partial_single(

await self.codec_pipeline.write(
[
(
_ShardingByteSetter(shard_dict, chunk_coords),
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
WriteBatchInfo(
byte_operator=_ShardingByteSetter(shard_dict, chunk_coords),
array_spec=chunk_spec,
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
36 changes: 23 additions & 13 deletions src/zarr/core/array.py
@@ -23,7 +23,14 @@
from typing_extensions import deprecated

import zarr
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
from zarr.abc.codec import (
ArrayArrayCodec,
ArrayBytesCodec,
BytesBytesCodec,
Codec,
ReadBatchInfo,
WriteBatchInfo,
)
from zarr.abc.numcodec import Numcodec, _is_numcodec
from zarr.codecs._v2 import V2Codec
from zarr.codecs.bytes import BytesCodec
@@ -1564,12 +1571,15 @@ async def _get_selection(
# reading chunks and decoding them
await self.codec_pipeline.read(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
chunk_selection,
out_selection,
is_complete_chunk,
ReadBatchInfo(
byte_operator=self.store_path
/ self.metadata.encode_chunk_key(chunk_coords),
array_spec=self.metadata.get_chunk_spec(
chunk_coords, _config, prototype=prototype
),
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],
@@ -1735,12 +1745,12 @@ async def _set_selection(
# merging with existing data and encoding chunks
await self.codec_pipeline.write(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
chunk_selection,
out_selection,
is_complete_chunk,
WriteBatchInfo(
byte_operator=self.store_path / self.metadata.encode_chunk_key(chunk_coords),
array_spec=self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
chunk_selection=chunk_selection,
out_selection=out_selection,
is_complete_chunk=is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],