Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]:
import inspect
import json
import logging
import re
from typing import Any, Awaitable, Callable, Optional, TypeAlias, Union

import jsonschema
Expand Down Expand Up @@ -252,17 +253,27 @@ async def run_async(
f" arg {self.A2UI_JSON_ARG_NAME} "
)

a2ui_json_payload = json.loads(a2ui_json)
a2ui_schema = await self.get_a2ui_schema(tool_context)
try:
# Attempt to parse and validate
a2ui_json_payload = json.loads(a2ui_json)
jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema)

# Auto-wrap single object in list
if not isinstance(a2ui_json_payload, list):
logger.info(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets keep the list wrapping logic

"Received a single JSON object, wrapping in a list for validation."
)
a2ui_json_payload = [a2ui_json_payload]
except (jsonschema.exceptions.ValidationError, json.JSONDecodeError) as e:
logger.warning(f"Initial A2UI JSON validation failed: {e}")

a2ui_schema = await self.get_a2ui_schema(tool_context)
jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema)
# Run Fixer
fixed_a2ui_json = re.sub(r",(?=\s*[\]}])", "", a2ui_json)

if fixed_a2ui_json != a2ui_json:
# Emit Warning
logger.warning("Detected trailing commas in LLM output; applied autofix.")

# Re-parse and Re-validate
a2ui_json_payload = json.loads(fixed_a2ui_json)
jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema)
else:
raise e

logger.info(
f"Validated call to tool {self.TOOL_NAME} with {self.A2UI_JSON_ARG_NAME}"
Expand Down Expand Up @@ -328,4 +339,4 @@ def convert_send_a2ui_to_client_genai_part_to_a2a_part(
converted_part = part_converter.convert_genai_part_to_a2a_part(part)

logger.info(f"Returning converted part: {converted_part}")
return [converted_part] if converted_part else []
return [converted_part] if converted_part else []
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,33 @@ async def test_send_tool_run_async_schema_validation_fail():
assert "'text' is a required property" in result["error"]


@pytest.mark.asyncio
async def test_send_tool_run_async_handles_trailing_comma(caplog):
"""Tests that the tool's run_async can handle and fix a trailing comma in the JSON."""
tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA)
tool_context_mock = MagicMock(spec=ToolContext)
tool_context_mock.state = {}
tool_context_mock.actions = MagicMock(skip_summarization=False)

# Malformed JSON with a trailing comma in the list
malformed_a2ui_str = '[{"type": "Text", "text": "Hello"},]'

args = {
SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: malformed_a2ui_str
}

result = await tool.run_async(args=args, tool_context=tool_context_mock)

# Assert that the fix was successful and the result is correct
expected_a2ui = [{"type": "Text", "text": "Hello"}]
assert result == {
SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY: expected_a2ui
}

# Assert that the warning was logged
assert "Detected trailing commas in LLM output; applied autofix." in caplog.text


# endregion

# region send_a2ui_to_client_part_converter Tests
Expand Down