Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/agent_patterns/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ For example, you could model the translation task above as tool calls instead: r

See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this.
See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`.
See the [`agents_as_tools_structured.py`](./agents_as_tools_structured.py) file for a structured-input variant using `Agent.as_tool()` parameters.

## LLM-as-a-judge

Expand Down
64 changes: 64 additions & 0 deletions examples/agent_patterns/agents_as_tools_structured.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio

from pydantic import BaseModel, Field

from agents import Agent, Runner

"""
This example shows structured input for agent-as-tool calls.
"""


class TranslationInput(BaseModel):
    """Structured arguments for the translator agent-as-tool.

    Passed to `Agent.as_tool(parameters=...)` so the orchestrator's LLM sees a
    typed JSON schema (text/source/target) instead of a single free-form input.
    """

    text: str = Field(description="Text to translate.")
    source: str = Field(description="Source language code or name.")
    target: str = Field(description="Target language code or name.")


# Worker agent: performs the actual translation. Used both directly (first run
# in main()) and wrapped as a tool on the orchestrator below.
translator = Agent(
    name="translator",
    instructions=(
        "Translate the input text into the target language. "
        "If the target is not clear, ask the user for clarification."
    ),
)

# Dispatcher agent: never translates itself; it must call the structured tool.
orchestrator = Agent(
    name="orchestrator",
    instructions=(
        "You are a task dispatcher. Always call the tool with sufficient input. "
        "Do not handle the translation yourself."
    ),
    tools=[
        translator.as_tool(
            tool_name="translate_text",
            tool_description=(
                "Translate text between languages. Provide text, source language, "
                "and target language."
            ),
            # `parameters` makes the tool accept the typed fields of
            # TranslationInput rather than a single free-form string.
            parameters=TranslationInput,
            # By default, the structured input is forwarded to the nested agent
            # in a simplified format. Set include_input_schema=True to also
            # include the full JSON Schema:
            # include_input_schema=True,
            # Alternatively, build a custom prompt from the structured data:
            # input_builder=lambda options: (
            #     f'Translate the text "{options["params"]["text"]}" '
            #     f'from {options["params"]["source"]} to {options["params"]["target"]}.'
            # ),
        )
    ],
)


async def main() -> None:
    """Run the same query twice: once on the translator directly, once through
    the orchestrator so the model calls it via the structured tool."""
    query = 'Translate "Hola" from Spanish to French.'

    direct_result = await Runner.run(translator, query)
    print(f"Translator agent direct run: {direct_result.final_output}")

    orchestrated_result = await Runner.run(orchestrator, query)
    print(f"Translator agent as tool: {orchestrated_result.final_output}")


if __name__ == "__main__":
    asyncio.run(main())
165 changes: 150 additions & 15 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@
import asyncio
import dataclasses
import inspect
import json
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast

from openai.types.responses.response_prompt_param import ResponsePromptParam
from pydantic import BaseModel, TypeAdapter, ValidationError
from typing_extensions import NotRequired, TypeAlias, TypedDict

from . import _debug
from .agent_output import AgentOutputSchemaBase
from .agent_tool_input import (
AgentAsToolInput,
StructuredToolInputBuilder,
build_structured_input_schema_info,
resolve_agent_tool_input,
)
from .agent_tool_state import (
consume_agent_tool_run_result,
peek_agent_tool_run_result,
record_agent_tool_run_result,
)
from .exceptions import ModelBehaviorError
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .logger import logger
Expand All @@ -29,16 +39,18 @@
from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext
from .strict_schema import ensure_strict_json_schema
from .tool import (
FunctionTool,
FunctionToolResult,
Tool,
ToolErrorFunction,
_extract_tool_argument_json_error,
default_tool_error_function,
function_tool,
)
from .tool_context import ToolContext
from .util import _transforms
from .tracing import SpanError
from .util import _error_tracing, _transforms
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
Expand Down Expand Up @@ -442,6 +454,9 @@ def as_tool(
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
needs_approval: bool
| Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False,
parameters: type[Any] | None = None,
input_builder: StructuredToolInputBuilder | None = None,
include_input_schema: bool = False,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -467,19 +482,85 @@ def as_tool(
failure_error_function: If provided, generate an error message when the tool (agent) run
fails. The message is sent to the LLM. If None, the exception is raised instead.
needs_approval: Bool or callable to decide if this agent tool should pause for approval.
parameters: Structured input type for the tool arguments (dataclass or Pydantic model).
input_builder: Optional function to build the nested agent input from structured data.
include_input_schema: Whether to include the full JSON schema in structured input.
"""

@function_tool(
name_override=tool_name or _transforms.transform_string_function_style(self.name),
description_override=tool_description or "",
is_enabled=is_enabled,
needs_approval=needs_approval,
failure_error_function=failure_error_function,
def _is_supported_parameters(value: Any) -> bool:
    """Return True if `value` is a type usable as structured tool parameters
    (a dataclass type or a Pydantic BaseModel subclass)."""
    # Reject instances and non-type values outright; only classes qualify.
    if not isinstance(value, type):
        return False
    # Dataclass types are accepted as-is.
    if dataclasses.is_dataclass(value):
        return True
    # Otherwise the class must be a Pydantic model.
    return issubclass(value, BaseModel)

tool_name_resolved = tool_name or _transforms.transform_string_function_style(self.name)
tool_description_resolved = tool_description or ""
has_custom_parameters = parameters is not None
include_schema = bool(include_input_schema and has_custom_parameters)
should_capture_tool_input = bool(
has_custom_parameters or include_schema or input_builder is not None
)

if parameters is None:
params_adapter = TypeAdapter(AgentAsToolInput)
params_schema = ensure_strict_json_schema(params_adapter.json_schema())
else:
if not _is_supported_parameters(parameters):
raise TypeError("Agent tool parameters must be a dataclass or Pydantic model type.")
params_adapter = TypeAdapter(parameters)
params_schema = ensure_strict_json_schema(params_adapter.json_schema())

schema_info = build_structured_input_schema_info(
params_schema,
include_json_schema=include_schema,
)
async def run_agent(context: ToolContext, input: str) -> Any:

def _normalize_tool_input(parsed: Any) -> Any:
    """Convert validated params into JSON-compatible Python data.

    Serialization failures are surfaced as ModelBehaviorError so the caller
    treats them like other bad-tool-input conditions.
    """
    # Prefer JSON mode so structured params (datetime/UUID/Decimal, etc.) serialize cleanly.
    try:
        return params_adapter.dump_python(parsed, mode="json")
    except Exception as exc:
        raise ModelBehaviorError(
            f"Failed to serialize structured tool input for {tool_name_resolved}: {exc}"
        ) from exc

async def _run_agent_impl(context: ToolContext, input_json: str) -> Any:
from .run import DEFAULT_MAX_TURNS, Runner
from .tool_context import ToolContext

try:
json_data = json.loads(input_json) if input_json else {}
except Exception as exc:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool_name_resolved}")
else:
logger.debug(f"Invalid JSON input for tool {tool_name_resolved}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool_name_resolved}: {input_json}"
) from exc

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking tool {tool_name_resolved}")
else:
logger.debug(f"Invoking tool {tool_name_resolved} with input {input_json}")

try:
parsed_params = params_adapter.validate_python(json_data)
except ValidationError as exc:
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool_name_resolved}: {exc}"
) from exc

params_data = _normalize_tool_input(parsed_params)
resolved_input = await resolve_agent_tool_input(
params=params_data,
schema_info=schema_info if should_capture_tool_input else None,
input_builder=input_builder,
)
if not isinstance(resolved_input, str) and not isinstance(resolved_input, list):
raise ModelBehaviorError("Agent tool called with invalid input")

resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
if isinstance(context, ToolContext):
# Use a fresh ToolContext to avoid sharing approval state with parent runs.
Expand All @@ -491,10 +572,20 @@ async def run_agent(context: ToolContext, input: str) -> Any:
tool_arguments=context.tool_arguments,
tool_call=context.tool_call,
)
if should_capture_tool_input:
nested_context.tool_input = params_data
elif isinstance(context, RunContextWrapper):
nested_context = context.context
if should_capture_tool_input:
nested_context = RunContextWrapper(context=context.context)
nested_context.tool_input = params_data
else:
nested_context = context.context
else:
nested_context = context
if should_capture_tool_input:
nested_context = RunContextWrapper(context=context)
nested_context.tool_input = params_data
else:
nested_context = context
run_result: RunResult | RunResultStreaming | None = None
resume_state: RunState | None = None
should_record_run_result = True
Expand Down Expand Up @@ -575,7 +666,7 @@ def _apply_nested_approvals(
if on_stream is not None:
run_result_streaming = Runner.run_streamed(
starting_agent=cast(Agent[Any], self),
input=resume_state or input,
input=resume_state or resolved_input,
context=None if resume_state is not None else cast(Any, nested_context),
run_config=run_config,
max_turns=resolved_max_turns,
Expand Down Expand Up @@ -639,7 +730,7 @@ async def dispatch_stream_events() -> None:
else:
run_result = await Runner.run(
starting_agent=cast(Agent[Any], self),
input=resume_state or input,
input=resume_state or resolved_input,
context=None if resume_state is not None else cast(Any, nested_context),
run_config=run_config,
max_turns=resolved_max_turns,
Expand All @@ -663,8 +754,52 @@ async def dispatch_stream_events() -> None:

return run_result.final_output

# Mark the function tool as an agent tool.
run_agent_tool = run_agent
async def _run_agent_tool(context: ToolContext, input_json: str) -> Any:
    """Invoke the nested agent and translate failures via failure_error_function.

    With no failure handler configured, exceptions propagate unchanged.
    Otherwise the handler's message is returned to the LLM, and the error is
    attached to the current tracing span and logged.
    """
    try:
        return await _run_agent_impl(context, input_json)
    except Exception as exc:
        if failure_error_function is None:
            raise

        # Handler may be sync or async; await only when needed.
        result = failure_error_function(context, exc)
        if inspect.isawaitable(result):
            result = await result

        # JSON-argument decode errors get a distinct span message so they are
        # distinguishable from other (non-fatal) tool failures in traces.
        json_decode_error = _extract_tool_argument_json_error(exc)
        if json_decode_error is not None:
            span_error_message = "Error running tool"
            span_error_detail = str(json_decode_error)
        else:
            span_error_message = "Error running tool (non-fatal)"
            span_error_detail = str(exc)

        _error_tracing.attach_error_to_current_span(
            SpanError(
                message=span_error_message,
                data={
                    "tool_name": tool_name_resolved,
                    "error": span_error_detail,
                },
            )
        )
        # When tool data logging is suppressed, omit the raw input from logs.
        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Tool {tool_name_resolved} failed")
        else:
            logger.error(
                f"Tool {tool_name_resolved} failed: {input_json} {exc}",
                exc_info=exc,
            )
        return result

run_agent_tool = FunctionTool(
name=tool_name_resolved,
description=tool_description_resolved,
params_json_schema=params_schema,
on_invoke_tool=_run_agent_tool,
strict_json_schema=True,
is_enabled=is_enabled,
needs_approval=needs_approval,
)
run_agent_tool._is_agent_tool = True
run_agent_tool._agent_instance = self

Expand Down
Loading