Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 26 additions & 33 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,29 +552,19 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
content_index=current_item_content_index,
)
)
if not self._ongoing_response:
logger.debug(
"Skipping truncate because no response is in progress. "
f"Item id: {current_item_id}, "
f"elapsed ms: {elapsed_ms}, "
f"content index: {current_item_content_index}"
)
else:
max_audio_ms: int | None = None
audio_limits = self._get_audio_limits(
current_item_id, current_item_content_index
)
if audio_limits is not None:
_, max_audio_ms = audio_limits
truncated_ms = max(int(elapsed_ms), 0)
if max_audio_ms is not None:
truncated_ms = min(truncated_ms, max_audio_ms)
converted = _ConversionHelper.convert_interrupt(
current_item_id,
current_item_content_index,
truncated_ms,
)
await self._send_raw_message(converted)
max_audio_ms: int | None = None
audio_limits = self._get_audio_limits(current_item_id, current_item_content_index)
if audio_limits is not None:
_, max_audio_ms = audio_limits
truncated_ms = max(int(elapsed_ms), 0)
if max_audio_ms is not None:
truncated_ms = min(truncated_ms, max_audio_ms)
converted = _ConversionHelper.convert_interrupt(
current_item_id,
current_item_content_index,
truncated_ms,
)
await self._send_raw_message(converted)
else:
logger.debug(
"Didn't interrupt bc elapsed ms is < 0. "
Expand Down Expand Up @@ -779,21 +769,24 @@ async def _handle_ws_event(self, event: dict[str, Any]):
effective_elapsed_ms = float(elapsed_override)

if playback_item_id and effective_elapsed_ms is not None:
if not self._ongoing_response:
max_audio_ms: int | None = None
audio_limits = self._get_audio_limits(playback_item_id, playback_content_index)
if audio_limits is not None:
_, max_audio_ms = audio_limits
truncated_ms = max(int(round(effective_elapsed_ms)), 0)
if (
max_audio_ms is not None
and truncated_ms >= max_audio_ms
and not self._ongoing_response
):
logger.debug(
"Skipping truncate because no response is in progress. "
"Skipping truncate because playback appears complete. "
f"Item id: {playback_item_id}, "
f"elapsed ms: {effective_elapsed_ms}, "
f"content index: {playback_content_index}"
f"content index: {playback_content_index}, "
f"audio length ms: {max_audio_ms}"
)
else:
max_audio_ms: int | None = None
audio_limits = self._get_audio_limits(
playback_item_id, playback_content_index
)
if audio_limits is not None:
_, max_audio_ms = audio_limits
truncated_ms = max(int(round(effective_elapsed_ms)), 0)
if max_audio_ms is not None:
truncated_ms = min(truncated_ms, max_audio_ms)
await self._send_raw_message(
Expand Down
66 changes: 65 additions & 1 deletion tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch
Expand Down Expand Up @@ -548,7 +549,6 @@ async def test_transcription_related_and_timeouts_and_speech_started(self, model
# Prepare tracker state to simulate ongoing audio
model._audio_state_tracker.set_audio_format("pcm16")
model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 96)
model._ongoing_response = True

# Patch sending to avoid websocket dependency
monkeypatch.setattr(
Expand Down Expand Up @@ -610,6 +610,70 @@ async def test_transcription_related_and_timeouts_and_speech_started(self, model
assert "transcript_delta" in types
assert "input_audio_timeout_triggered" in types

@pytest.mark.asyncio
async def test_speech_started_skips_truncate_when_audio_complete(self, model, monkeypatch):
    """A speech_started event must not produce a truncate once playback looks finished.

    The model is left with its fixture-provided ``_ongoing_response`` value
    (presumably falsy - confirm against the fixture), and the tracked audio is
    backdated so that the elapsed time is well past the received audio length.
    """
    tracker = model._audio_state_tracker
    tracker.set_audio_format("pcm16")
    # 48_000 bytes of PCM16 corresponds to roughly 1000ms of audio.
    tracker.on_audio_delta("i1", 0, b"a" * 48_000)
    audio_state = tracker.get_state("i1", 0)
    assert audio_state is not None
    # Pretend the first chunk arrived 5 seconds ago, i.e. long after the
    # ~1000ms of audio would have finished playing.
    audio_state.initial_received_time = datetime.now() - timedelta(seconds=5)

    # Capture outgoing messages instead of touching a real websocket.
    monkeypatch.setattr(model, "_send_raw_message", AsyncMock())

    speech_started_event = {
        "type": "input_audio_buffer.speech_started",
        "event_id": "es2",
        "item_id": "i1",
        "audio_start_ms": 0,
        "audio_end_ms": 0,
    }
    await model._handle_ws_event(speech_started_event)

    sent_truncates = []
    for sent_call in model._send_raw_message.await_args_list:
        message = sent_call.args[0]
        if getattr(message, "type", None) == "conversation.item.truncate":
            sent_truncates.append(message)
    assert not sent_truncates

@pytest.mark.asyncio
async def test_speech_started_truncates_when_response_ongoing(self, model, monkeypatch):
    """A speech_started event still truncates while a response is in progress.

    Even though the backdated elapsed time exceeds the received audio, the
    ongoing response means a truncate is sent, with ``audio_end_ms`` equal to
    the received audio length (1000ms) rather than the elapsed time.
    """
    tracker = model._audio_state_tracker
    tracker.set_audio_format("pcm16")
    # 48_000 bytes of PCM16 corresponds to roughly 1000ms of audio.
    tracker.on_audio_delta("i1", 0, b"a" * 48_000)
    audio_state = tracker.get_state("i1", 0)
    assert audio_state is not None
    # Pretend the first chunk arrived 5 seconds ago so elapsed > audio length.
    audio_state.initial_received_time = datetime.now() - timedelta(seconds=5)
    model._ongoing_response = True

    # Capture outgoing messages instead of touching a real websocket.
    monkeypatch.setattr(model, "_send_raw_message", AsyncMock())

    speech_started_event = {
        "type": "input_audio_buffer.speech_started",
        "event_id": "es3",
        "item_id": "i1",
        "audio_start_ms": 0,
        "audio_end_ms": 0,
    }
    await model._handle_ws_event(speech_started_event)

    sent_messages = (
        sent_call.args[0] for sent_call in model._send_raw_message.await_args_list
    )
    sent_truncates = [
        message
        for message in sent_messages
        if getattr(message, "type", None) == "conversation.item.truncate"
    ]
    assert sent_truncates
    assert sent_truncates[0].audio_end_ms == 1000


class TestSendEventAndConfig(TestOpenAIRealtimeWebSocketModel):
@pytest.mark.asyncio
Expand Down
2 changes: 0 additions & 2 deletions tests/realtime/test_playback_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ async def test_interrupt_timing_with_custom_playback_tracker(self, model):

# Set up model with custom tracker directly
model._playback_tracker = custom_tracker
model._ongoing_response = True

# Mock send_raw_message to capture interrupt
model._send_raw_message = AsyncMock()
Expand Down Expand Up @@ -63,7 +62,6 @@ async def test_interrupt_clamps_elapsed_to_audio_length(self, model):
"""Test interrupt clamps elapsed time to the received audio length."""
model._send_raw_message = AsyncMock()
model._audio_state_tracker.set_audio_format("pcm16")
model._ongoing_response = True

# 48_000 bytes of PCM16 at 24kHz equals ~1000ms of audio.
model._audio_state_tracker.on_audio_delta("item_1", 0, b"a" * 48_000)
Expand Down