diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e69de29bb..fa6e39aa2 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -0,0 +1 @@ +typing-extensions; python_version < "3.11" diff --git a/src/confluent_kafka/aio/_AIOConsumer.py b/src/confluent_kafka/aio/_AIOConsumer.py index 0ff986a62..81d6d0f50 100644 --- a/src/confluent_kafka/aio/_AIOConsumer.py +++ b/src/confluent_kafka/aio/_AIOConsumer.py @@ -16,6 +16,12 @@ import concurrent.futures from typing import Any, Callable, Dict, Optional, Tuple +try: + from typing import Self +except ImportError: + # FIXME: remove once we depend on Python >= 3.11 + from typing_extensions import Self + import confluent_kafka from . import _common as _common @@ -46,6 +52,12 @@ def __init__( self._consumer: confluent_kafka.Consumer = confluent_kafka.Consumer(consumer_conf) + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *_) -> None: + await self.close() + async def _call(self, blocking_task: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: return await _common.async_call(self.executor, blocking_task, *args, **kwargs) diff --git a/src/confluent_kafka/aio/producer/_AIOProducer.py b/src/confluent_kafka/aio/producer/_AIOProducer.py index a5fc28727..8556cd124 100644 --- a/src/confluent_kafka/aio/producer/_AIOProducer.py +++ b/src/confluent_kafka/aio/producer/_AIOProducer.py @@ -17,6 +17,12 @@ import logging from typing import Any, Callable, Dict, Optional +try: + from typing import Self +except ImportError: + # FIXME: remove once we depend on Python >= 3.11 + from typing_extensions import Self + import confluent_kafka from .. import _common as _common @@ -70,6 +76,12 @@ def __init__( if buffer_timeout > 0: self._buffer_timeout_manager.start_timeout_monitoring() + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *_) -> None: + await self.close() + async def close(self) -> None: """Close the producer and cleanup resources diff --git a/tests/test_AIOProducer.py b/tests/test_AIOProducer.py index 6cfa55833..75e117237 100644 --- a/tests/test_AIOProducer.py +++ b/tests/test_AIOProducer.py @@ -80,6 +80,18 @@ async def test_close_method(self, mock_producer, mock_common, basic_config): await producer2.close() assert producer2._is_closed is True + @pytest.mark.asyncio + async def test_async_context_manager(self, mock_producer, mock_common, basic_config): + with AIOProducer(basic_config) as producer: + assert producer._is_closed is False + assert producer._is_closed is True + + with AIOProducer(basic_config) as producer2: + assert producer2._is_closed is False + await producer2.close() + await producer2.close() + assert producer2._is_closed is True + @pytest.mark.asyncio async def test_call_method_executor_usage(self, mock_producer, mock_common, basic_config): producer = AIOProducer(basic_config)