diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index cc52a394d..5bebf159b 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -552,29 +552,19 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
                     content_index=current_item_content_index,
                 )
             )
-            if not self._ongoing_response:
-                logger.debug(
-                    "Skipping truncate because no response is in progress. "
-                    f"Item id: {current_item_id}, "
-                    f"elapsed ms: {elapsed_ms}, "
-                    f"content index: {current_item_content_index}"
-                )
-            else:
-                max_audio_ms: int | None = None
-                audio_limits = self._get_audio_limits(
-                    current_item_id, current_item_content_index
-                )
-                if audio_limits is not None:
-                    _, max_audio_ms = audio_limits
-                truncated_ms = max(int(elapsed_ms), 0)
-                if max_audio_ms is not None:
-                    truncated_ms = min(truncated_ms, max_audio_ms)
-                converted = _ConversionHelper.convert_interrupt(
-                    current_item_id,
-                    current_item_content_index,
-                    truncated_ms,
-                )
-                await self._send_raw_message(converted)
+            max_audio_ms: int | None = None
+            audio_limits = self._get_audio_limits(current_item_id, current_item_content_index)
+            if audio_limits is not None:
+                _, max_audio_ms = audio_limits
+            truncated_ms = max(int(elapsed_ms), 0)
+            if max_audio_ms is not None:
+                truncated_ms = min(truncated_ms, max_audio_ms)
+            converted = _ConversionHelper.convert_interrupt(
+                current_item_id,
+                current_item_content_index,
+                truncated_ms,
+            )
+            await self._send_raw_message(converted)
         else:
             logger.debug(
                 "Didn't interrupt bc elapsed ms is < 0. "
@@ -779,21 +769,24 @@ async def _handle_ws_event(self, event: dict[str, Any]):
                 effective_elapsed_ms = float(elapsed_override)
 
             if playback_item_id and effective_elapsed_ms is not None:
-                if not self._ongoing_response:
+                max_audio_ms: int | None = None
+                audio_limits = self._get_audio_limits(playback_item_id, playback_content_index)
+                if audio_limits is not None:
+                    _, max_audio_ms = audio_limits
+                truncated_ms = max(int(round(effective_elapsed_ms)), 0)
+                if (
+                    max_audio_ms is not None
+                    and truncated_ms >= max_audio_ms
+                    and not self._ongoing_response
+                ):
                     logger.debug(
-                        "Skipping truncate because no response is in progress. "
+                        "Skipping truncate because playback appears complete. "
                         f"Item id: {playback_item_id}, "
                         f"elapsed ms: {effective_elapsed_ms}, "
-                        f"content index: {playback_content_index}"
+                        f"content index: {playback_content_index}, "
+                        f"audio length ms: {max_audio_ms}"
                     )
                 else:
-                    max_audio_ms: int | None = None
-                    audio_limits = self._get_audio_limits(
-                        playback_item_id, playback_content_index
-                    )
-                    if audio_limits is not None:
-                        _, max_audio_ms = audio_limits
-                    truncated_ms = max(int(round(effective_elapsed_ms)), 0)
                     if max_audio_ms is not None:
                         truncated_ms = min(truncated_ms, max_audio_ms)
                     await self._send_raw_message(
diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py
index 2327be653..780f3fc52 100644
--- a/tests/realtime/test_openai_realtime.py
+++ b/tests/realtime/test_openai_realtime.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+from datetime import datetime, timedelta
 from types import SimpleNamespace
 from typing import Any, cast
 from unittest.mock import AsyncMock, Mock, patch
@@ -548,7 +549,6 @@ async def test_transcription_related_and_timeouts_and_speech_started(self, model
         # Prepare tracker state to simulate ongoing audio
         model._audio_state_tracker.set_audio_format("pcm16")
         model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 96)
-        model._ongoing_response = True
 
         # Patch sending to avoid websocket dependency
         monkeypatch.setattr(
@@ -610,6 +610,70 @@ async def test_transcription_related_and_timeouts_and_speech_started(self, model
         assert "transcript_delta" in types
         assert "input_audio_timeout_triggered" in types
 
+    @pytest.mark.asyncio
+    async def test_speech_started_skips_truncate_when_audio_complete(self, model, monkeypatch):
+        model._audio_state_tracker.set_audio_format("pcm16")
+        model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 48_000)
+        state = model._audio_state_tracker.get_state("i1", 0)
+        assert state is not None
+        state.initial_received_time = datetime.now() - timedelta(seconds=5)
+
+        monkeypatch.setattr(
+            model,
+            "_send_raw_message",
+            AsyncMock(),
+        )
+
+        await model._handle_ws_event(
+            {
+                "type": "input_audio_buffer.speech_started",
+                "event_id": "es2",
+                "item_id": "i1",
+                "audio_start_ms": 0,
+                "audio_end_ms": 0,
+            }
+        )
+
+        truncate_events = [
+            call.args[0]
+            for call in model._send_raw_message.await_args_list
+            if getattr(call.args[0], "type", None) == "conversation.item.truncate"
+        ]
+        assert not truncate_events
+
+    @pytest.mark.asyncio
+    async def test_speech_started_truncates_when_response_ongoing(self, model, monkeypatch):
+        model._audio_state_tracker.set_audio_format("pcm16")
+        model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 48_000)
+        state = model._audio_state_tracker.get_state("i1", 0)
+        assert state is not None
+        state.initial_received_time = datetime.now() - timedelta(seconds=5)
+        model._ongoing_response = True
+
+        monkeypatch.setattr(
+            model,
+            "_send_raw_message",
+            AsyncMock(),
+        )
+
+        await model._handle_ws_event(
+            {
+                "type": "input_audio_buffer.speech_started",
+                "event_id": "es3",
+                "item_id": "i1",
+                "audio_start_ms": 0,
+                "audio_end_ms": 0,
+            }
+        )
+
+        truncate_events = [
+            call.args[0]
+            for call in model._send_raw_message.await_args_list
+            if getattr(call.args[0], "type", None) == "conversation.item.truncate"
+        ]
+        assert truncate_events
+        assert truncate_events[0].audio_end_ms == 1000
+
 
 class TestSendEventAndConfig(TestOpenAIRealtimeWebSocketModel):
     @pytest.mark.asyncio
diff --git a/tests/realtime/test_playback_tracker.py b/tests/realtime/test_playback_tracker.py
index 48b83a8a9..16de96287 100644
--- a/tests/realtime/test_playback_tracker.py
+++ b/tests/realtime/test_playback_tracker.py
@@ -28,7 +28,6 @@ async def test_interrupt_timing_with_custom_playback_tracker(self, model):
 
         # Set up model with custom tracker directly
         model._playback_tracker = custom_tracker
-        model._ongoing_response = True
 
         # Mock send_raw_message to capture interrupt
        model._send_raw_message = AsyncMock()
@@ -63,7 +62,6 @@ async def test_interrupt_clamps_elapsed_to_audio_length(self, model):
         """Test interrupt clamps elapsed time to the received audio length."""
         model._send_raw_message = AsyncMock()
         model._audio_state_tracker.set_audio_format("pcm16")
-        model._ongoing_response = True
 
         # 48_000 bytes of PCM16 at 24kHz equals ~1000ms of audio.
         model._audio_state_tracker.on_audio_delta("item_1", 0, b"a" * 48_000)
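
Note (not part of the diff): the new truncation path clamps the elapsed playback time to the length of audio actually received, and the tests rely on 48,000 bytes of pcm16 at 24 kHz mono being roughly 1000 ms. Below is a minimal standalone sketch of that arithmetic and clamping; pcm16_bytes_to_ms and clamp_truncation_ms are hypothetical helper names for illustration only, since in the diff the equivalent logic lives inline in _send_interrupt and _handle_ws_event.

# Standalone sketch, assuming pcm16 at 24 kHz mono (2 bytes per sample),
# per the test comment "48_000 bytes of PCM16 at 24kHz equals ~1000ms of audio."


def pcm16_bytes_to_ms(num_bytes: int, sample_rate_hz: int = 24_000) -> int:
    """Convert a pcm16 mono byte count to milliseconds of audio."""
    samples = num_bytes // 2  # 16-bit samples: 2 bytes each
    return int(samples * 1000 / sample_rate_hz)


def clamp_truncation_ms(elapsed_ms: float, max_audio_ms: int | None) -> int:
    """Clamp the truncation point to [0, length of received audio]."""
    truncated_ms = max(int(elapsed_ms), 0)
    if max_audio_ms is not None:
        truncated_ms = min(truncated_ms, max_audio_ms)
    return truncated_ms


# 48_000 bytes maps to 1000 ms, so 5 seconds of elapsed playback is clamped
# to 1000 ms, which is why the test asserts audio_end_ms == 1000.
assert pcm16_bytes_to_ms(48_000) == 1000
assert clamp_truncation_ms(5000.0, pcm16_bytes_to_ms(48_000)) == 1000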