Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
import json
import weakref
Expand Down Expand Up @@ -900,9 +901,9 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
result = await the_func(*args, **kwargs_dict)
else:
if schema.takes_context:
result = the_func(ctx, *args, **kwargs_dict)
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
else:
result = the_func(*args, **kwargs_dict)
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool {schema.name} completed.")
Expand Down
69 changes: 68 additions & 1 deletion tests/test_function_tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import contextlib
import json
from typing import Any
import time
from typing import Any, Callable

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -87,6 +90,70 @@ async def test_simple_function():
)


@pytest.mark.asyncio
async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None:
calls = {"to_thread": 0, "func": 0}

def sync_func() -> str:
calls["func"] += 1
return "ok"

async def fake_to_thread(
func: Callable[..., Any],
/,
*args: Any,
**kwargs: Any,
) -> Any:
calls["to_thread"] += 1
return func(*args, **kwargs)

monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)

tool = function_tool(sync_func)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)
assert result == "ok"
assert calls["to_thread"] == 1
assert calls["func"] == 1


@pytest.mark.asyncio
async def test_sync_function_does_not_block_event_loop() -> None:
def sync_func() -> str:
time.sleep(0.2)
return "ok"

tool = function_tool(sync_func)

async def run_tool() -> Any:
return await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)

tool_task: asyncio.Task[Any] = asyncio.create_task(run_tool())
background_task: asyncio.Task[None] = asyncio.create_task(asyncio.sleep(0.01))

done, pending = await asyncio.wait(
{tool_task, background_task},
return_when=asyncio.FIRST_COMPLETED,
)

try:
assert background_task in done
assert tool_task in pending
assert await tool_task == "ok"
finally:
if not background_task.done():
background_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await background_task
if not tool_task.done():
tool_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await tool_task


class Foo(BaseModel):
a: int
b: int = 5
Expand Down