diff --git a/google/genai/_interactions/_files.py b/google/genai/_interactions/_files.py
index f2ddd9cc5..2e207aef5 100644
--- a/google/genai/_interactions/_files.py
+++ b/google/genai/_interactions/_files.py
@@ -42,13 +42,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
 
 def is_file_content(obj: object) -> TypeGuard[FileContent]:
     return (
-        isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
+        isinstance(obj, bytes)
+        or isinstance(obj, tuple)
+        or isinstance(obj, io.IOBase)
+        or isinstance(obj, os.PathLike)
     )
 
 
 def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
     if not is_file_content(obj):
-        prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
+        prefix = (
+            f"Expected entry at `{key}`"
+            if key is not None
+            else f"Expected file input `{obj!r}`"
+        )
         raise RuntimeError(
             f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
         ) from None
@@ -71,7 +78,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
     elif is_sequence_t(files):
         files = [(key, _transform_file(file)) for key, file in files]
     else:
-        raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
+        raise TypeError(
+            f"Unexpected file type input {type(files)}, expected mapping or sequence"
+        )
 
     return files
 
@@ -80,19 +89,23 @@
 def _transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_file_content(file):
         if isinstance(file, os.PathLike):
             path = pathlib.Path(file)
-            return (path.name, path.read_bytes())
+            # Return an open file handle instead of loading the entire file into memory.
+            # This prevents OOM errors for large files; httpx supports IO[bytes] directly.
+            return (path.name, open(path, "rb"))
         return file
 
     if is_tuple_t(file):
         return (file[0], read_file_content(file[1]), *file[2:])
 
-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError(
+        f"Expected file types input to be a FileContent type or to be a tuple"
+    )
 
 
 def read_file_content(file: FileContent) -> HttpxFileContent:
     if isinstance(file, os.PathLike):
-        return pathlib.Path(file).read_bytes()
+        return open(pathlib.Path(file), "rb")
     return file
 
 
@@ -113,7 +126,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
     elif is_sequence_t(files):
         files = [(key, await _async_transform_file(file)) for key, file in files]
     else:
-        raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
+        raise TypeError(
+            f"Unexpected file type input {type(files)}, expected mapping or sequence"
+        )
 
     return files
 
@@ -121,19 +136,21 @@
 async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_file_content(file):
         if isinstance(file, os.PathLike):
-            path = anyio.Path(file)
-            return (path.name, await path.read_bytes())
+            path = pathlib.Path(file)
+            return (path.name, open(path, "rb"))
         return file
 
     if is_tuple_t(file):
         return (file[0], await async_read_file_content(file[1]), *file[2:])
 
-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError(
+        f"Expected file types input to be a FileContent type or to be a tuple"
+    )
 
 
 async def async_read_file_content(file: FileContent) -> HttpxFileContent:
     if isinstance(file, os.PathLike):
-        return await anyio.Path(file).read_bytes()
+        return open(pathlib.Path(file), "rb")
     return file
 
 
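The sketch below is not part of the patch; it only illustrates the behavioral change above using stdlib names. The old code materialized the whole file via `read_bytes()`, while the new code hands back an open binary handle whose bytes are pulled on demand, so the handle's lifetime becomes the consumer's responsibility (httpx accepts file-like objects for multipart bodies).

```python
import pathlib
import tempfile

# Stand-in for a large upload: an 8 MiB scratch file.
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
    tmp.write(b"\0" * (8 * 1024 * 1024))
path = pathlib.Path(tmp.name)

# Old behavior: the entire file is held in memory at once.
eager = path.read_bytes()
assert len(eager) == 8 * 1024 * 1024

# New behavior: only a descriptor is held; bytes stream on demand.
handle = open(path, "rb")
try:
    first = handle.read(64 * 1024)  # reads just the first 64 KiB
    assert len(first) == 64 * 1024
finally:
    handle.close()  # the caller (or the transport) now owns closing this

path.unlink()
```

One trade-off worth noting in review: handles returned this way stay open until the consumer closes them or they are garbage-collected, so long-lived request objects hold descriptors longer than the old eager read did.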
diff --git a/google/genai/_interactions/_utils/_transform.py b/google/genai/_interactions/_utils/_transform.py
index bb7db1865..0400edb54 100644
--- a/google/genai/_interactions/_utils/_transform.py
+++ b/google/genai/_interactions/_utils/_transform.py
@@ -20,7 +20,12 @@
 import pathlib
 from typing import Any, Mapping, TypeVar, cast
 from datetime import date, datetime
-from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
+from typing_extensions import (
+    Literal,
+    get_args,
+    override,
+    get_type_hints as _get_type_hints,
+)
 
 import anyio
 import pydantic
@@ -196,15 +201,26 @@ def _transform_recursive(
 
     if origin == dict and is_mapping(data):
         items_type = get_args(stripped_type)[1]
-        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+        return {
+            key: _transform_recursive(value, annotation=items_type)
+            for key, value in data.items()
+        }
 
     if (
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
         # Iterable[T]
-        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+        or (
+            is_iterable_type(stripped_type)
+            and is_iterable(data)
+            and not isinstance(data, str)
+        )
         # Sequence[T]
-        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
+        or (
+            is_sequence_type(stripped_type)
+            and is_sequence(data)
+            and not isinstance(data, str)
+        )
     ):
         # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
         # intended as an iterable, so we don't transform it.
@@ -221,7 +237,10 @@ def _transform_recursive(
                 return data
             return list(data)
 
-        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+        return [
+            _transform_recursive(d, annotation=annotation, inner_type=inner_type)
+            for d in data
+        ]
 
     if is_union_type(stripped_type):
         # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@@ -248,7 +267,9 @@
     return data
 
 
-def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+def _format_data(
+    data: object, format_: PropertyFormat, format_template: str | None
+) -> object:
     if isinstance(data, (date, datetime)):
         if format_ == "iso8601":
             return data.isoformat()
@@ -257,22 +278,35 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
         return data.strftime(format_template)
 
     if format_ == "base64" and is_base64_file_input(data):
-        binary: str | bytes | None = None
-
-        if isinstance(data, pathlib.Path):
-            binary = data.read_bytes()
-        elif isinstance(data, io.IOBase):
-            binary = data.read()
+        return _encode_file_to_base64(data)
 
-        if isinstance(binary, str):  # type: ignore[unreachable]
-            binary = binary.encode()
-
-        if not isinstance(binary, bytes):
-            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+    return data
 
-        return base64.b64encode(binary).decode("ascii")
-    return data
 
+def _encode_file_to_base64(data: object) -> str:
+    """Encode file content to base64 using chunked reading to reduce peak memory usage."""
+    CHUNK_SIZE = 3 * 1024 * 1024  # 3MB (must be multiple of 3 for base64)
+    chunks: list[str] = []
+
+    if isinstance(data, pathlib.Path):
+        with open(data, "rb") as f:
+            while True:
+                chunk = f.read(CHUNK_SIZE)
+                if not chunk:
+                    break
+                chunks.append(base64.b64encode(chunk).decode("ascii"))
+    elif isinstance(data, io.IOBase):
+        while True:
+            chunk = data.read(CHUNK_SIZE)
+            if not chunk:
+                break
+            if isinstance(chunk, str):
+                chunk = chunk.encode()
+            chunks.append(base64.b64encode(chunk).decode("ascii"))
+    else:
+        raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")
+
+    return "".join(chunks)
 
 
 def _transform_typeddict(
@@ -292,7 +326,9 @@
             # we do not have a type annotation for this field, leave it as is
            result[key] = value
         else:
-            result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
+            result[_maybe_transform_key(key, type_)] = _transform_recursive(
+                value, annotation=type_
+            )
 
     return result
 
@@ -328,7 +364,9 @@ class Params(TypedDict, total=False):
 
     It should be noted that the transformations that this function does are not represented in the type system.
     """
-    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
+    transformed = await _async_transform_recursive(
+        data, annotation=cast(type, expected_type)
+    )
     return cast(_T, transformed)
 
 
@@ -362,15 +400,26 @@ async def _async_transform_recursive(
 
     if origin == dict and is_mapping(data):
         items_type = get_args(stripped_type)[1]
-        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+        return {
+            key: _transform_recursive(value, annotation=items_type)
+            for key, value in data.items()
+        }
 
     if (
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
         # Iterable[T]
-        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+        or (
+            is_iterable_type(stripped_type)
+            and is_iterable(data)
+            and not isinstance(data, str)
+        )
         # Sequence[T]
-        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
+        or (
+            is_sequence_type(stripped_type)
+            and is_sequence(data)
+            and not isinstance(data, str)
+        )
     ):
         # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
         # intended as an iterable, so we don't transform it.
@@ -387,7 +436,12 @@ async def _async_transform_recursive(
                 return data
             return list(data)
 
-        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+        return [
+            await _async_transform_recursive(
+                d, annotation=annotation, inner_type=inner_type
+            )
+            for d in data
+        ]
 
     if is_union_type(stripped_type):
         # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@@ -395,7 +449,9 @@
         # TODO: there may be edge cases where the same normalized field name will transform to two different names
         # in different subtypes.
         for subtype in get_args(stripped_type):
-            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
+            data = await _async_transform_recursive(
+                data, annotation=annotation, inner_type=subtype
+            )
         return data
 
     if isinstance(data, pydantic.BaseModel):
@@ -409,12 +465,16 @@
     annotations = get_args(annotated_type)[1:]
     for annotation in annotations:
         if isinstance(annotation, PropertyInfo) and annotation.format is not None:
-            return await _async_format_data(data, annotation.format, annotation.format_template)
+            return await _async_format_data(
+                data, annotation.format, annotation.format_template
+            )
 
     return data
 
 
-async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+async def _async_format_data(
+    data: object, format_: PropertyFormat, format_template: str | None
+) -> object:
     if isinstance(data, (date, datetime)):
         if format_ == "iso8601":
             return data.isoformat()
@@ -423,22 +483,35 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
         return data.strftime(format_template)
 
     if format_ == "base64" and is_base64_file_input(data):
-        binary: str | bytes | None = None
-
-        if isinstance(data, pathlib.Path):
-            binary = await anyio.Path(data).read_bytes()
-        elif isinstance(data, io.IOBase):
-            binary = data.read()
-
-        if isinstance(binary, str):  # type: ignore[unreachable]
-            binary = binary.encode()
+        return await _async_encode_file_to_base64(data)
 
-        if not isinstance(binary, bytes):
-            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+    return data
 
-        return base64.b64encode(binary).decode("ascii")
-    return data
 
+async def _async_encode_file_to_base64(data: object) -> str:
+    """Encode file content to base64 using chunked reading to reduce peak memory usage."""
+    CHUNK_SIZE = 3 * 1024 * 1024  # 3MB (must be multiple of 3 for base64)
+    chunks: list[str] = []
+
+    if isinstance(data, pathlib.Path):
+        async with await anyio.Path(data).open("rb") as f:
+            while True:
+                chunk = await f.read(CHUNK_SIZE)
+                if not chunk:
+                    break
+                chunks.append(base64.b64encode(chunk).decode("ascii"))
+    elif isinstance(data, io.IOBase):
+        while True:
+            chunk = data.read(CHUNK_SIZE)
+            if not chunk:
+                break
+            if isinstance(chunk, str):
+                chunk = chunk.encode()
+            chunks.append(base64.b64encode(chunk).decode("ascii"))
+    else:
+        raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")
+
+    return "".join(chunks)
 
 
 async def _async_transform_typeddict(
@@ -458,7 +531,9 @@
             # we do not have a type annotation for this field, leave it as is
             result[key] = value
         else:
-            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
+            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
+                value, annotation=type_
+            )
 
     return result
 
@@ -469,4 +544,6 @@ def get_type_hints(
     globalns: Mapping[str, Any] | None = None,
     localns: Mapping[str, Any] | None = None,
     include_extras: bool = False,
 ) -> dict[str, Any]:
-    return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
+    return _get_type_hints(
+        obj, globalns=globalns, localns=localns, include_extras=include_extras
+    )
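Not part of the patch: a quick check of the invariant called out in the `CHUNK_SIZE` comment above. Base64 encodes 3-byte groups into 4 characters, so per-chunk encodings concatenate losslessly only when every chunk except the last is a multiple of 3 bytes long. The `chunked_b64` helper here is hypothetical, written just to mirror the patch's strategy.

```python
import base64
import os

payload = os.urandom(10_000)  # arbitrary binary data


def chunked_b64(data: bytes, chunk_size: int) -> str:
    # Mirrors the patch's strategy: encode fixed-size chunks
    # independently and concatenate the results.
    return "".join(
        base64.b64encode(data[i : i + chunk_size]).decode("ascii")
        for i in range(0, len(data), chunk_size)
    )


one_shot = base64.b64encode(payload).decode("ascii")

# A multiple-of-3 chunk size produces no mid-stream '=' padding, so the
# concatenation equals the one-shot encoding.
assert chunked_b64(payload, 3 * 1024) == one_shot

# A non-multiple-of-3 chunk size pads at every chunk boundary and
# corrupts the stream.
assert chunked_b64(payload, 1000) != one_shot
```

The same invariant is worth double-checking for the `io.IOBase` branch in the patch: a buffered binary `read(CHUNK_SIZE)` returns full chunks until EOF, but raw streams may return short reads, and `str` chunks from text-mode files can encode to a byte length that is not a multiple of 3.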
diff --git a/google/genai/_interactions/_utils/_utils.py b/google/genai/_interactions/_utils/_utils.py
index 31daaec5d..764055ba0 100644
--- a/google/genai/_interactions/_utils/_utils.py
+++ b/google/genai/_interactions/_utils/_utils.py
@@ -92,7 +92,9 @@ def _extract_items(
     if is_list(obj):
         files: list[tuple[str, FileTypes]] = []
         for entry in obj:
-            assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
+            assert_is_file_content(
+                entry, key=flattened_key + "[]" if flattened_key else ""
+            )
             files.append((flattened_key + "[]", cast(FileTypes, entry)))
         return files
 
@@ -132,7 +134,9 @@ def _extract_items(
                     item,
                     path,
                     index=index,
-                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
+                    flattened_key=flattened_key + "[]"
+                    if flattened_key is not None
+                    else "[]",
                 )
                 for item in obj
             ]
@@ -282,7 +286,12 @@ def wrapper(*args: object, **kwargs: object) -> object:
         else:  # no break
             if len(variants) > 1:
                 variations = human_join(
-                    ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
+                    [
+                        "("
+                        + human_join([quote(arg) for arg in variant], final="and")
+                        + ")"
+                        for variant in variants
+                    ]
                 )
                 msg = f"Missing required arguments; Expected either {variations} arguments to be given"
             else:
@@ -380,9 +389,8 @@ def removesuffix(string: str, suffix: str) -> str:
 
 
 def file_from_path(path: str) -> FileTypes:
-    contents = Path(path).read_bytes()
     file_name = os.path.basename(path)
-    return (file_name, contents)
+    return (file_name, open(Path(path), "rb"))
 
 
 def get_required_header(headers: HeadersLike, header: str) -> str:
@@ -394,7 +402,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
         return v
 
     # to deal with the case where the header looks like Stainless-Event-Id
-    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
+    intercaps_header = re.sub(
+        r"([^\w])(\w)",
+        lambda pat: pat.group(1) + pat.group(2).upper(),
+        header.capitalize(),
+    )
 
     for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
         value = headers.get(normalized_header)
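Also not part of the patch, but handy when reviewing the `intercaps_header` reflow above: the regex upper-cases the first word character after every non-word character, turning a lowercase header name into its `Stainless-Event-Id`-style form. The `intercaps` helper below is hypothetical; it just wraps the same expression so it can be exercised standalone.

```python
import re


def intercaps(header: str) -> str:
    # Same expression as in get_required_header.
    return re.sub(
        r"([^\w])(\w)",
        lambda pat: pat.group(1) + pat.group(2).upper(),
        header.capitalize(),
    )


# capitalize() lifts the first letter; the substitution then lifts each
# letter that follows a hyphen (or any other non-word character).
assert intercaps("stainless-event-id") == "Stainless-Event-Id"
assert intercaps("x-api-key") == "X-Api-Key"
```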