@@ -21,11 +21,16 @@
if you want to use these Rapid Storage APIs.

"""
from io import BufferedReader
from typing import Optional, Union
from io import BufferedReader, BytesIO
import asyncio
from typing import List, Optional, Tuple, Union

from google_crc32c import Checksum
from google.api_core import exceptions
from google.api_core.retry_async import AsyncRetry
from google.rpc import status_pb2
from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError


from ._utils import raise_if_no_fast_crc32c
from google.cloud import _storage_v2
@@ -35,10 +40,58 @@
from google.cloud.storage._experimental.asyncio.async_write_object_stream import (
_AsyncWriteObjectStream,
)
from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import (
_BidiStreamRetryManager,
)
from google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy import (
_WriteResumptionStrategy,
_WriteState,
)


_MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB
_DEFAULT_FLUSH_INTERVAL_BYTES = 16 * 1024 * 1024 # 16 MiB
_BIDI_WRITE_REDIRECTED_TYPE_URL = (
"type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError"
)


def _is_write_retryable(exc):
"""Predicate to determine if a write operation should be retried."""
if isinstance(
exc,
(
exceptions.InternalServerError,
exceptions.ServiceUnavailable,
exceptions.DeadlineExceeded,
exceptions.TooManyRequests,
),
):
return True

grpc_error = None
if isinstance(exc, exceptions.Aborted):
grpc_error = exc.errors[0]
trailers = grpc_error.trailing_metadata()
if not trailers:
return False

status_details_bin = None
for key, value in trailers:
if key == "grpc-status-details-bin":
status_details_bin = value
break

if status_details_bin:
status_proto = status_pb2.Status()
try:
status_proto.ParseFromString(status_details_bin)
for detail in status_proto.details:
if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL:
return True
except Exception:
return False
Comment on lines +92 to +93 (Contributor, severity: medium):

The broad `except Exception:` can hide unexpected errors during the parsing of the status details. It would be better to catch a more specific exception, such as `google.protobuf.message.DecodeError`, to ensure other issues aren't silently ignored.
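As a rough illustration of that suggestion, the parse-and-scan step could be isolated so that only a protobuf decode failure is swallowed. This is a sketch, not the PR's code; the helper name `_has_redirect_detail` is hypothetical.

```python
from google.protobuf.message import DecodeError
from google.rpc import status_pb2

_BIDI_WRITE_REDIRECTED_TYPE_URL = (
    "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError"
)


def _has_redirect_detail(status_details_bin: bytes) -> bool:
    """Return True if the serialized google.rpc.Status carries a redirect detail."""
    status_proto = status_pb2.Status()
    try:
        status_proto.ParseFromString(status_details_bin)
    except DecodeError:
        # Only malformed payloads are swallowed; any other bug still surfaces.
        return False
    return any(
        detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL
        for detail in status_proto.details
    )
```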

return False


class AsyncAppendableObjectWriter:
@@ -114,13 +167,7 @@ def __init__(
self.write_handle = write_handle
self.generation = generation

self.write_obj_stream = _AsyncWriteObjectStream(
client=self.client,
bucket_name=self.bucket_name,
object_name=self.object_name,
generation_number=self.generation,
write_handle=self.write_handle,
)
self.write_obj_stream: Optional[_AsyncWriteObjectStream] = None
self._is_stream_open: bool = False
# `offset` is the latest size of the object without staleness.
self.offset: Optional[int] = None
@@ -143,6 +190,8 @@ def __init__(
f"flush_interval must be a multiple of {_MAX_CHUNK_SIZE_BYTES}, but provided {self.flush_interval}"
)
self.bytes_appended_since_last_flush = 0
self._lock = asyncio.Lock()
self._routing_token: Optional[str] = None

async def state_lookup(self) -> int:
"""Returns the persisted_size
@@ -165,7 +214,55 @@ async def state_lookup(self) -> int:
self.persisted_size = response.persisted_size
return self.persisted_size

async def open(self) -> None:
def _on_open_error(self, exc):
"""Extracts routing token and write handle on redirect error during open."""
grpc_error = None
if isinstance(exc, exceptions.Aborted) and exc.errors:
grpc_error = exc.errors[0]

if grpc_error:
if isinstance(grpc_error, BidiWriteObjectRedirectedError):
self._routing_token = grpc_error.routing_token
if grpc_error.write_handle:
self.write_handle = grpc_error.write_handle
return

if hasattr(grpc_error, "trailing_metadata"):
trailers = grpc_error.trailing_metadata()
if not trailers:
return

status_details_bin = None
for key, value in trailers:
if key == "grpc-status-details-bin":
status_details_bin = value
break

if status_details_bin:
status_proto = status_pb2.Status()
try:
status_proto.ParseFromString(status_details_bin)
for detail in status_proto.details:
if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL:
redirect_proto = (
BidiWriteObjectRedirectedError.deserialize(
detail.value
)
)
if redirect_proto.routing_token:
self._routing_token = redirect_proto.routing_token
if redirect_proto.write_handle:
self.write_handle = redirect_proto.write_handle
break
except Exception:
# Could not parse the error, ignore
pass
Comment on lines +257 to +259 (Contributor, severity: medium):

Using a broad `except Exception:` is risky as it can mask bugs. Please consider catching a more specific exception, like `google.protobuf.message.DecodeError`, to make the error handling more robust. Silently passing on any exception could hide important issues.
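A sketch along the same lines for this block, assuming only the protobuf decode step needs to be guarded; the helper name `_extract_redirect` is illustrative and not part of the PR, while `BidiWriteObjectRedirectedError.deserialize` is the same call the PR already uses.

```python
from typing import Optional

from google.protobuf.message import DecodeError
from google.rpc import status_pb2
from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError

_BIDI_WRITE_REDIRECTED_TYPE_URL = (
    "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError"
)


def _extract_redirect(
    status_details_bin: bytes,
) -> Optional[BidiWriteObjectRedirectedError]:
    """Parse the redirect detail out of a serialized google.rpc.Status, if any."""
    status_proto = status_pb2.Status()
    try:
        status_proto.ParseFromString(status_details_bin)
    except DecodeError:
        return None  # malformed trailer payload; nothing to extract
    for detail in status_proto.details:
        if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL:
            return BidiWriteObjectRedirectedError.deserialize(detail.value)
    return None
```

The caller would then set `self._routing_token` and `self.write_handle` from the returned message, keeping the redirect handling in one place.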


async def open(
self,
retry_policy: Optional[AsyncRetry] = None,
metadata: Optional[List[Tuple[str, str]]] = None,
) -> None:
"""Opens the underlying bidi-gRPC stream.

:raises ValueError: If the stream is already open.
@@ -174,62 +271,190 @@ async def open(self) -> None:
if self._is_stream_open:
raise ValueError("Underlying bidi-gRPC stream is already open")

await self.write_obj_stream.open()
self._is_stream_open = True
if self.generation is None:
self.generation = self.write_obj_stream.generation_number
self.write_handle = self.write_obj_stream.write_handle
self.persisted_size = self.write_obj_stream.persisted_size
if retry_policy is None:
retry_policy = AsyncRetry(
predicate=_is_write_retryable, on_error=self._on_open_error
)
else:
original_on_error = retry_policy._on_error

def combined_on_error(exc):
self._on_open_error(exc)
if original_on_error:
original_on_error(exc)

retry_policy = retry_policy.with_predicate(
_is_write_retryable
).with_on_error(combined_on_error)

async def _do_open():
current_metadata = list(metadata) if metadata else []

# Cleanup stream from previous failed attempt, if any.
if self.write_obj_stream:
if self._is_stream_open:
try:
await self.write_obj_stream.close()
except Exception: # ignore cleanup errors
pass
self.write_obj_stream = None
self._is_stream_open = False

self.write_obj_stream = _AsyncWriteObjectStream(
client=self.client,
bucket_name=self.bucket_name,
object_name=self.object_name,
generation_number=self.generation,
write_handle=self.write_handle,
)

            if self._routing_token:
                current_metadata.append(
                    ("x-goog-request-params", f"routing_token={self._routing_token}")
                )
                self._routing_token = None

            await self.write_obj_stream.open(
                metadata=current_metadata if current_metadata else None
            )

            if self.write_obj_stream.generation_number:
                self.generation = self.write_obj_stream.generation_number
            if self.write_obj_stream.write_handle:
                self.write_handle = self.write_obj_stream.write_handle
            if self.write_obj_stream.persisted_size is not None:
                self.persisted_size = self.write_obj_stream.persisted_size

            self._is_stream_open = True

        await retry_policy(_do_open)()

async def _upload_with_retry(
self,
data: bytes,
retry_policy: Optional[AsyncRetry] = None,
metadata: Optional[List[Tuple[str, str]]] = None,
) -> None:
if not self._is_stream_open:
raise ValueError("Underlying bidi-gRPC stream is not open")

if retry_policy is None:
retry_policy = AsyncRetry(predicate=_is_write_retryable)

# Initialize Global State for Retry Strategy
spec = _storage_v2.AppendObjectSpec(
bucket=self.bucket_name,
object=self.object_name,
generation=self.generation,
)
buffer = BytesIO(data)
write_state = _WriteState(
spec=spec,
chunk_size=_MAX_CHUNK_SIZE_BYTES,
user_buffer=buffer,
)
write_state.write_handle = self.write_handle

initial_state = {
"write_state": write_state,
"first_request": True,
}

# Track attempts to manage stream reuse
attempt_count = 0

def stream_opener(
requests,
state,
metadata: Optional[List[Tuple[str, str]]] = None,
):
async def generator():
nonlocal attempt_count
attempt_count += 1

async with self._lock:
current_handle = state["write_state"].write_handle
current_token = state["write_state"].routing_token

should_reopen = (attempt_count > 1) or (current_token is not None)

if should_reopen:
if self.write_obj_stream and self.write_obj_stream._is_stream_open:
await self.write_obj_stream.close()

self.write_obj_stream = _AsyncWriteObjectStream(
client=self.client,
bucket_name=self.bucket_name,
object_name=self.object_name,
generation_number=self.generation,
write_handle=current_handle,
)

current_metadata = list(metadata) if metadata else []
if current_token:
current_metadata.append(
(
"x-goog-request-params",
f"routing_token={current_token}",
)
)

await self.write_obj_stream.open(
metadata=current_metadata if current_metadata else None
)
self._is_stream_open = True

# Let the strategy generate the request sequence
async for request in requests:
await self.write_obj_stream.send(request)

# Signal that we are done sending requests.
await self.write_obj_stream.requests.put(None)

# Process responses
async for response in self.write_obj_stream:
yield response

return generator()

strategy = _WriteResumptionStrategy()
retry_manager = _BidiStreamRetryManager(
strategy, lambda r, s: stream_opener(r, s, metadata=metadata)
)

await retry_manager.execute(initial_state, retry_policy)

# Update the writer's state from the strategy's final state
final_write_state = initial_state["write_state"]
self.persisted_size = final_write_state.persisted_size
self.write_handle = final_write_state.write_handle
self.offset = self.persisted_size

    async def append(self, data: bytes) -> None:
        """Appends data to the Appendable object.

        Calling `self.append` will append bytes at the end of the current size,
        i.e. `self.offset` bytes relative to the beginning of the object.

        This method sends the provided `data` to the GCS server in chunks,
        and persists data in GCS at every `_MAX_BUFFER_SIZE_BYTES` bytes by
        calling `self.simple_flush`.

        :type data: bytes
        :param data: The bytes to append to the object.

        :rtype: None

        :raises ValueError: If the stream is not open (i.e., `open()` has not
            been called).
        """
        if not self._is_stream_open:
            raise ValueError("Stream is not open. Call open() before append().")
        total_bytes = len(data)
        if total_bytes == 0:
            # TODO: add warning.
            return
        if self.offset is None:
            assert self.persisted_size is not None
            self.offset = self.persisted_size

        start_idx = 0
        while start_idx < total_bytes:
            end_idx = min(start_idx + _MAX_CHUNK_SIZE_BYTES, total_bytes)
            data_chunk = data[start_idx:end_idx]
            await self.write_obj_stream.send(
                _storage_v2.BidiWriteObjectRequest(
                    write_offset=self.offset,
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data_chunk,
                        crc32c=int.from_bytes(Checksum(data_chunk).digest(), "big"),
                    ),
                )
            )
            chunk_size = end_idx - start_idx
            self.offset += chunk_size
            self.bytes_appended_since_last_flush += chunk_size
            if self.bytes_appended_since_last_flush >= self.flush_interval:
                await self.simple_flush()
                self.bytes_appended_since_last_flush = 0
            start_idx = end_idx

    async def append(
        self,
        data: bytes,
        retry_policy: Optional[AsyncRetry] = None,
        metadata: Optional[List[Tuple[str, str]]] = None,
    ) -> None:
        """Appends data to the Appendable object with automatic retries.

        :type data: bytes
        :param data: The bytes to append to the object.

        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
        :param retry_policy: (Optional) The retry policy to use for the operation.

        :type metadata: List[Tuple[str, str]]
        :param metadata: (Optional) The metadata to be sent with the request.

        :raises ValueError: If the stream is not open.
        """
        if not data:
            return  # Do nothing for empty data

        await self._upload_with_retry(data, retry_policy, metadata)

async def simple_flush(self) -> None:
"""Flushes the data to the server.