diff --git a/src/agents/tool.py b/src/agents/tool.py index fc3194613..692d8b790 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import json import weakref @@ -900,9 +901,9 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: result = await the_func(*args, **kwargs_dict) else: if schema.takes_context: - result = the_func(ctx, *args, **kwargs_dict) + result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict) else: - result = the_func(*args, **kwargs_dict) + result = await asyncio.to_thread(the_func, *args, **kwargs_dict) if _debug.DONT_LOG_TOOL_DATA: logger.debug(f"Tool {schema.name} completed.") diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 3597f48c3..01d1e3c6b 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -1,5 +1,8 @@ +import asyncio +import contextlib import json -from typing import Any +import time +from typing import Any, Callable import pytest from pydantic import BaseModel @@ -87,6 +90,70 @@ async def test_simple_function(): ) +@pytest.mark.asyncio +async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None: + calls = {"to_thread": 0, "func": 0} + + def sync_func() -> str: + calls["func"] += 1 + return "ok" + + async def fake_to_thread( + func: Callable[..., Any], + /, + *args: Any, + **kwargs: Any, + ) -> Any: + calls["to_thread"] += 1 + return func(*args, **kwargs) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + + tool = function_tool(sync_func) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ) + assert result == "ok" + assert calls["to_thread"] == 1 + assert calls["func"] == 1 + + +@pytest.mark.asyncio +async def test_sync_function_does_not_block_event_loop() -> None: + def sync_func() -> str: + time.sleep(0.2) + return "ok" + + tool = function_tool(sync_func) + + async def run_tool() -> Any: + return await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ) + + tool_task: asyncio.Task[Any] = asyncio.create_task(run_tool()) + background_task: asyncio.Task[None] = asyncio.create_task(asyncio.sleep(0.01)) + + done, pending = await asyncio.wait( + {tool_task, background_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + try: + assert background_task in done + assert tool_task in pending + assert await tool_task == "ok" + finally: + if not background_task.done(): + background_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await background_task + if not tool_task.done(): + tool_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await tool_task + + class Foo(BaseModel): a: int b: int = 5