diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 71e9a17..784405f 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -56,8 +56,8 @@ jobs: - name: Run linting and formatting checks run: just check - - name: Run tests - run: just test + - name: Run tests with coverage + run: just test-coverage release-preview: runs-on: ubuntu-latest diff --git a/justfile b/justfile index 88ae269..f3c3978 100644 --- a/justfile +++ b/justfile @@ -21,6 +21,13 @@ test-package PACKAGE: test-file PACKAGE FILE: cd packages/{{PACKAGE}} && uv run --extra test pytest tests/{{FILE}} -v +# Run tests with coverage enforcement +# Note: mcp package has lower threshold due to optional client integrations (CrewAI, LangChain, etc.) +test-coverage: build + cd packages/oauth && uv run --extra test pytest tests/ -v --cov=src --cov-report=term-missing --cov-fail-under=70 + cd packages/mcp && uv run --extra test pytest tests/ -v --cov=src --cov-report=term-missing --cov-fail-under=65 + cd packages/mcp-fastmcp && uv run --extra test pytest tests/ -v --cov=src --cov-report=term-missing --cov-fail-under=70 + check: uv run ruff check diff --git a/packages/mcp/tests/integration/e2e/__init__.py b/packages/mcp/tests/integration/e2e/__init__.py new file mode 100644 index 0000000..5cc6ad5 --- /dev/null +++ b/packages/mcp/tests/integration/e2e/__init__.py @@ -0,0 +1 @@ +"""End-to-end tests for critical OAuth and MCP flows.""" diff --git a/packages/mcp/tests/integration/e2e/conftest.py b/packages/mcp/tests/integration/e2e/conftest.py new file mode 100644 index 0000000..55ba422 --- /dev/null +++ b/packages/mcp/tests/integration/e2e/conftest.py @@ -0,0 +1,70 @@ +"""Shared fixtures for E2E tests. + +These fixtures extend the base auth_provider fixtures with E2E-specific +configurations for testing complete flows. +""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from keycardai.mcp.server.auth.client_factory import ClientFactory +from keycardai.oauth.types.models import AuthorizationServerMetadata, TokenResponse + +# E2E test constants +E2E_ZONE_ID = "e2e-test" +E2E_ZONE_URL = "https://e2e-test.keycard.cloud" + + +@pytest.fixture +def e2e_oauth_metadata(): + """Standard OAuth metadata for E2E tests.""" + return AuthorizationServerMetadata( + issuer=E2E_ZONE_URL, + authorization_endpoint=f"{E2E_ZONE_URL}/auth", + token_endpoint=f"{E2E_ZONE_URL}/token", + registration_endpoint=f"{E2E_ZONE_URL}/register", + jwks_uri=f"{E2E_ZONE_URL}/.well-known/jwks.json", + ) + + +@pytest.fixture +def e2e_client_factory(e2e_oauth_metadata): + """Create a reusable mock client factory for E2E tests.""" + factory = Mock(spec=ClientFactory) + + # Mock sync client for metadata discovery + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = e2e_oauth_metadata + factory.create_client.return_value = mock_sync_client + + # Mock async client for token exchange + mock_async_client = AsyncMock() + mock_async_client.config = Mock() + mock_async_client.config.client_id = "e2e_test_client" + + def default_exchange(request): + """Generate resource-specific tokens for E2E testing.""" + resource = request.resource if hasattr(request, "resource") else str(request) + # Create deterministic token based on resource + token_suffix = resource.replace("https://", "").replace("/", "_").replace(".", "_") + return TokenResponse( + access_token=f"e2e_token_for_{token_suffix}", + token_type="Bearer", + expires_in=3600, + ) + + mock_async_client.exchange_token.side_effect = default_exchange + factory.create_async_client.return_value = mock_async_client + + return factory, mock_async_client + + +@pytest.fixture +def e2e_auth_provider_config(): + """Standard AuthProvider configuration for E2E tests.""" + return { + "zone_id": E2E_ZONE_ID, + "mcp_server_name": "E2E Test Server", + "mcp_server_url": "http://localhost:8000/", + } diff --git a/packages/mcp/tests/integration/e2e/test_authprovider_e2e.py b/packages/mcp/tests/integration/e2e/test_authprovider_e2e.py new file mode 100644 index 0000000..fce5d48 --- /dev/null +++ b/packages/mcp/tests/integration/e2e/test_authprovider_e2e.py @@ -0,0 +1,263 @@ +"""End-to-end tests for MCP server AuthProvider. + +Tests: AuthProvider init -> JWT verifier creation -> tool invocation. +""" + +from unittest.mock import AsyncMock, Mock + +import pytest +from mcp.server.fastmcp import Context + +from keycardai.mcp.server.auth import ( + AccessContext, + AuthProvider, + AuthProviderConfigurationError, +) +from keycardai.oauth.types.models import TokenResponse + + +def create_mock_context_with_auth_info( + access_token: str = "user_jwt_token", + zone_id: str = "e2e-test", + resource_client_id: str = "", + resource_server_url: str = "http://localhost:8000/", +): + """Helper function to create a mock Context with authentication info.""" + mock_context = Mock(spec=Context) + mock_context.request_context = Mock() + mock_context.request_context.request = Mock() + mock_context.request_context.request.state.keycardai_auth_info = { + "access_token": access_token, + "zone_id": zone_id, + "resource_client_id": resource_client_id, + "resource_server_url": resource_server_url, + } + return mock_context + + +class TestAuthProviderE2EFlow: + """End-to-end tests for AuthProvider initialization to tool execution.""" + + @pytest.mark.asyncio + async def test_authprovider_init_to_tool_invocation( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test complete flow from AuthProvider init to successful tool invocation.""" + factory, mock_async_client = e2e_client_factory + + # Step 1: Initialize AuthProvider + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + # Step 2: Verify JWT verifier can be created + verifier = auth_provider.get_token_verifier() + assert verifier is not None + # Note: AuthProvider.required_scopes defaults to None, verifier converts to empty list + assert verifier.required_scopes == (auth_provider.required_scopes or []) + + # Step 3: Create a tool with grant decorator + @auth_provider.grant("https://api.e2e-test.com") + def e2e_tool(access_ctx: AccessContext, ctx: Context, query: str) -> str: + if access_ctx.has_errors(): + return f"Error: {access_ctx.get_errors()}" + token = access_ctx.access("https://api.e2e-test.com").access_token + return f"Success with token: {token}" + + # Step 4: Create mock context with auth info + mock_context = create_mock_context_with_auth_info() + + # Step 5: Execute the tool + result = await e2e_tool(ctx=mock_context, query="test query") + + # Step 6: Verify complete flow succeeded + assert "Success with token" in result + assert "e2e_token_for_api_e2e-test_com" in result + + @pytest.mark.asyncio + async def test_authprovider_multi_resource_grant( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test AuthProvider with multiple resource grants.""" + factory, mock_async_client = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant(["https://api1.e2e-test.com", "https://api2.e2e-test.com"]) + def multi_resource_tool(access_ctx: AccessContext, ctx: Context) -> dict: + results = {} + for resource in ["https://api1.e2e-test.com", "https://api2.e2e-test.com"]: + if access_ctx.has_resource_error(resource): + results[resource] = "error" + else: + results[resource] = access_ctx.access(resource).access_token + return results + + mock_context = create_mock_context_with_auth_info() + + result = await multi_resource_tool(ctx=mock_context) + + assert "e2e_token_for_api1_e2e-test_com" in result["https://api1.e2e-test.com"] + assert "e2e_token_for_api2_e2e-test_com" in result["https://api2.e2e-test.com"] + + @pytest.mark.asyncio + async def test_authprovider_async_tool(self, e2e_client_factory, e2e_auth_provider_config): + """Test AuthProvider with async tool function.""" + factory, mock_async_client = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.e2e-test.com") + async def async_e2e_tool(access_ctx: AccessContext, ctx: Context, data: str) -> str: + if access_ctx.has_errors(): + return f"Error: {access_ctx.get_errors()}" + token = access_ctx.access("https://api.e2e-test.com").access_token + return f"Async success: {data}, token: {token}" + + mock_context = create_mock_context_with_auth_info() + + result = await async_e2e_tool(ctx=mock_context, data="test_data") + + assert "Async success: test_data" in result + assert "e2e_token_for_api_e2e-test_com" in result + + def test_authprovider_missing_zone_configuration(self): + """Test AuthProvider raises appropriate errors for missing zone config.""" + with pytest.raises(AuthProviderConfigurationError): + AuthProvider( + mcp_server_name="Test Server", mcp_server_url="http://localhost:8000/" + ) + + def test_authprovider_with_explicit_zone_url(self): + """Test AuthProvider accepts explicit zone_url instead of zone_id.""" + # Should not raise - using zone_url directly + auth_provider = AuthProvider( + zone_url="https://custom.keycard.cloud", + mcp_server_name="Test Server", + mcp_server_url="http://localhost:8000/", + ) + + assert auth_provider.zone_url == "https://custom.keycard.cloud" + + @pytest.mark.asyncio + async def test_authprovider_token_exchange_failure(self, e2e_auth_provider_config): + """Test AuthProvider handles token exchange failures gracefully.""" + # Create factory with failing async client + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + authorization_endpoint="https://e2e-test.keycard.cloud/auth", + token_endpoint="https://e2e-test.keycard.cloud/token", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + mock_async_client = AsyncMock() + mock_async_client.exchange_token.side_effect = Exception("Token exchange failed") + factory.create_async_client.return_value = mock_async_client + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.e2e-test.com") + def failing_tool(access_ctx: AccessContext, ctx: Context) -> str: + if access_ctx.has_errors(): + errors = access_ctx.get_errors() + return f"Error occurred: {errors}" + return "Should not reach here" + + mock_context = create_mock_context_with_auth_info() + + result = await failing_tool(ctx=mock_context) + + assert "Error occurred" in result + assert "Token exchange failed" in str(result) + + @pytest.mark.asyncio + async def test_authprovider_partial_token_exchange_failure( + self, e2e_auth_provider_config + ): + """Test AuthProvider handles partial failures with multiple resources.""" + # Create factory where one resource fails + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + authorization_endpoint="https://e2e-test.keycard.cloud/auth", + token_endpoint="https://e2e-test.keycard.cloud/token", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + mock_async_client = AsyncMock() + call_count = [0] + + async def mock_exchange(request): + call_count[0] += 1 + resource = request.resource if hasattr(request, "resource") else str(request) + if "api1" in resource: + return TokenResponse( + access_token="token_for_api1", token_type="Bearer", expires_in=3600 + ) + else: + raise Exception("API2 exchange failed") + + mock_async_client.exchange_token.side_effect = mock_exchange + factory.create_async_client.return_value = mock_async_client + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant(["https://api1.e2e-test.com", "https://api2.e2e-test.com"]) + def partial_tool(access_ctx: AccessContext, ctx: Context) -> dict: + return { + "status": access_ctx.get_status(), + "has_errors": access_ctx.has_errors(), + "successful": access_ctx.get_successful_resources(), + "failed": access_ctx.get_failed_resources(), + } + + mock_context = create_mock_context_with_auth_info() + + result = await partial_tool(ctx=mock_context) + + # Verify partial success/failure handling + assert result["has_errors"] is True + assert "https://api1.e2e-test.com" in result["successful"] + assert "https://api2.e2e-test.com" in result["failed"] + + +class TestAuthProviderVerifier: + """Tests for AuthProvider JWT verifier creation.""" + + def test_verifier_creation_with_default_scopes(self, e2e_auth_provider_config): + """Test verifier is created with default empty scopes.""" + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + verifier = auth_provider.get_token_verifier() + + assert verifier is not None + assert verifier.required_scopes == [] + + def test_verifier_creation_with_custom_scopes(self, e2e_auth_provider_config): + """Test verifier is created with custom required scopes.""" + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + config = {**e2e_auth_provider_config, "required_scopes": ["read", "write"]} + auth_provider = AuthProvider(**config, client_factory=factory) + + verifier = auth_provider.get_token_verifier() + + assert verifier is not None + assert verifier.required_scopes == ["read", "write"] diff --git a/packages/mcp/tests/integration/e2e/test_grant_decorator_e2e.py b/packages/mcp/tests/integration/e2e/test_grant_decorator_e2e.py new file mode 100644 index 0000000..abb33a1 --- /dev/null +++ b/packages/mcp/tests/integration/e2e/test_grant_decorator_e2e.py @@ -0,0 +1,299 @@ +"""End-to-end tests for @grant decorator. + +Tests: token exchange -> AccessContext population -> error handling. +""" + +from unittest.mock import AsyncMock, Mock + +import pytest +from mcp.server.fastmcp import Context + +from keycardai.mcp.server.auth import ( + AccessContext, + AuthProvider, + MissingAccessContextError, + MissingContextError, + ResourceAccessError, +) +from keycardai.oauth.types.models import TokenResponse + + +def create_mock_context_with_auth(): + """Create mock context with authentication info for E2E tests.""" + mock_context = Mock(spec=Context) + mock_context.request_context = Mock() + mock_context.request_context.request = Mock() + mock_context.request_context.request.state.keycardai_auth_info = { + "access_token": "user_token", + "zone_id": "e2e-test", + "resource_client_id": "", + "resource_server_url": "http://localhost:8000/", + } + return mock_context + + +class TestGrantDecoratorE2E: + """End-to-end tests for grant decorator functionality.""" + + @pytest.mark.asyncio + async def test_grant_decorator_token_exchange_success( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test successful token exchange through grant decorator.""" + factory, mock_async_client = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.example.com") + def test_tool(access_ctx: AccessContext, ctx: Context, data: str) -> str: + token = access_ctx.access("https://api.example.com").access_token + return f"Got token: {token}, data: {data}" + + mock_context = create_mock_context_with_auth() + + result = await test_tool(ctx=mock_context, data="test_input") + + assert "Got token:" in result + assert "data: test_input" in result + + @pytest.mark.asyncio + async def test_grant_decorator_token_exchange_failure(self, e2e_auth_provider_config): + """Test error handling when token exchange fails.""" + # Create factory with failing exchange + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + mock_async_client = AsyncMock() + mock_async_client.exchange_token.side_effect = Exception("Token exchange failed") + factory.create_async_client.return_value = mock_async_client + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.example.com") + def test_tool(access_ctx: AccessContext, ctx: Context) -> str: + if access_ctx.has_errors(): + return f"Error: {access_ctx.get_errors()}" + return "Should not reach here" + + mock_context = create_mock_context_with_auth() + + result = await test_tool(ctx=mock_context) + + assert "Error" in result + assert "Token exchange failed" in str(result) + + @pytest.mark.asyncio + async def test_grant_decorator_partial_success(self, e2e_auth_provider_config): + """Test partial success scenario with multiple resources.""" + factory = Mock() + mock_sync_client = Mock() + mock_sync_client.discover_server_metadata.return_value = Mock( + issuer="https://e2e-test.keycard.cloud", + jwks_uri="https://e2e-test.keycard.cloud/.well-known/jwks.json", + ) + factory.create_client.return_value = mock_sync_client + + mock_async_client = AsyncMock() + + async def mock_exchange(request): + resource = request.resource if hasattr(request, "resource") else str(request) + if "api1" in resource: + return TokenResponse( + access_token="token_api1", token_type="Bearer", expires_in=3600 + ) + else: + raise Exception("API2 token exchange failed") + + mock_async_client.exchange_token.side_effect = mock_exchange + factory.create_async_client.return_value = mock_async_client + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant(["https://api1.example.com", "https://api2.example.com"]) + def test_tool(access_ctx: AccessContext, ctx: Context) -> dict: + status = access_ctx.get_status() + successful = access_ctx.get_successful_resources() + failed = access_ctx.get_failed_resources() + return {"status": status, "successful": successful, "failed": failed} + + mock_context = create_mock_context_with_auth() + + result = await test_tool(ctx=mock_context) + + # Verify partial success behavior + assert result["status"] == "partial_error" + assert "https://api1.example.com" in result["successful"] + assert "https://api2.example.com" in result["failed"] + + def test_grant_decorator_missing_access_context( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test that missing AccessContext parameter raises error.""" + factory, _ = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + with pytest.raises(MissingAccessContextError): + + @auth_provider.grant("https://api.example.com") + def bad_tool(ctx: Context, data: str) -> str: # Missing AccessContext + return data + + def test_grant_decorator_missing_context( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test that missing Context parameter raises error.""" + factory, _ = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + with pytest.raises(MissingContextError): + + @auth_provider.grant("https://api.example.com") + def bad_tool(access_ctx: AccessContext, data: str) -> str: # Missing Context + return data + + @pytest.mark.asyncio + async def test_grant_decorator_preserves_function_metadata( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test that grant decorator preserves function name and docstring.""" + factory, _ = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.example.com") + def my_documented_tool(access_ctx: AccessContext, ctx: Context, data: str) -> str: + """This is my tool documentation.""" + return data + + # The wrapper should preserve the function name and docstring + assert my_documented_tool.__name__ == "my_documented_tool" + assert "documentation" in my_documented_tool.__doc__ + + @pytest.mark.asyncio + async def test_grant_decorator_with_no_auth_info( + self, e2e_client_factory, e2e_auth_provider_config + ): + """Test grant decorator handles missing authentication info gracefully.""" + factory, _ = e2e_client_factory + + auth_provider = AuthProvider(**e2e_auth_provider_config, client_factory=factory) + + @auth_provider.grant("https://api.example.com") + def test_tool(access_ctx: AccessContext, ctx: Context) -> str: + if access_ctx.has_error(): + return f"Auth error: {access_ctx.get_error()}" + return "Success" + + # Create context without auth info + mock_context = Mock(spec=Context) + mock_context.request_context = Mock() + mock_context.request_context.request = Mock() + mock_context.request_context.request.state = {} + + result = await test_tool(ctx=mock_context) + + assert "Auth error" in result + assert "No request authentication information" in str(result) + + +class TestAccessContextE2E: + """End-to-end tests for AccessContext behavior.""" + + def test_access_context_token_retrieval(self): + """Test token retrieval from AccessContext.""" + token_response = TokenResponse( + access_token="test_token_abc", token_type="Bearer", expires_in=3600 + ) + + ctx = AccessContext({"https://api.example.com": token_response}) + + retrieved = ctx.access("https://api.example.com") + assert retrieved.access_token == "test_token_abc" + + def test_access_context_missing_resource_error(self): + """Test ResourceAccessError for missing resource.""" + ctx = AccessContext( + { + "https://api.example.com": TokenResponse( + access_token="token", token_type="Bearer" + ) + } + ) + + with pytest.raises(ResourceAccessError): + ctx.access("https://other.api.com") + + def test_access_context_error_states(self): + """Test error state management in AccessContext.""" + ctx = AccessContext() + + # Initially no errors + assert not ctx.has_errors() + assert ctx.get_status() == "success" + + # Set resource error + ctx.set_resource_error("https://api1.com", {"error": "Failed"}) + assert ctx.has_errors() + assert ctx.has_resource_error("https://api1.com") + assert ctx.get_status() == "partial_error" + + # Set global error + ctx.set_error({"error": "Global failure"}) + assert ctx.has_error() + assert ctx.get_status() == "error" + + def test_access_context_set_and_retrieve_token(self): + """Test setting and retrieving tokens dynamically.""" + ctx = AccessContext() + + # Initially empty + assert ctx.get_successful_resources() == [] + + # Set token + token = TokenResponse(access_token="dynamic_token", token_type="Bearer") + ctx.set_token("https://api.test.com", token) + + # Verify retrieval + assert "https://api.test.com" in ctx.get_successful_resources() + retrieved = ctx.access("https://api.test.com") + assert retrieved.access_token == "dynamic_token" + + def test_access_context_error_overrides_token(self): + """Test that setting error clears previous token for resource.""" + ctx = AccessContext() + + # Set token first + token = TokenResponse(access_token="original_token", token_type="Bearer") + ctx.set_token("https://api.test.com", token) + + # Set error for same resource + ctx.set_resource_error("https://api.test.com", {"error": "Now failed"}) + + # Token should no longer be accessible + with pytest.raises(ResourceAccessError): + ctx.access("https://api.test.com") + + # Resource should be in failed list + assert "https://api.test.com" in ctx.get_failed_resources() + assert "https://api.test.com" not in ctx.get_successful_resources() + + def test_access_context_bulk_tokens(self): + """Test setting multiple tokens at once.""" + token1 = TokenResponse(access_token="token1", token_type="Bearer") + token2 = TokenResponse(access_token="token2", token_type="Bearer") + + ctx = AccessContext() + ctx.set_bulk_tokens( + {"https://api1.com": token1, "https://api2.com": token2} + ) + + assert ctx.access("https://api1.com").access_token == "token1" + assert ctx.access("https://api2.com").access_token == "token2" + assert len(ctx.get_successful_resources()) == 2 diff --git a/packages/mcp/tests/integration/e2e/test_oauth_e2e.py b/packages/mcp/tests/integration/e2e/test_oauth_e2e.py new file mode 100644 index 0000000..2a39170 --- /dev/null +++ b/packages/mcp/tests/integration/e2e/test_oauth_e2e.py @@ -0,0 +1,279 @@ +"""End-to-end tests for OAuth flows. + +Tests the complete OAuth flow: discovery -> registration -> token exchange. +""" + +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from keycardai.mcp.client.auth.oauth.discovery import OAuthDiscoveryService +from keycardai.mcp.client.auth.oauth.exchange import OAuthTokenExchangeService +from keycardai.mcp.client.auth.oauth.registration import OAuthClientRegistrationService +from keycardai.mcp.client.auth.storage_facades import OAuthStorage +from keycardai.mcp.client.storage import InMemoryBackend, NamespacedStorage + + +class TestOAuthE2EFlow: + """End-to-end tests for complete OAuth flow.""" + + @pytest.fixture + def oauth_storage(self): + """Create OAuth storage for testing.""" + backend = InMemoryBackend() + base_storage = NamespacedStorage(backend, "e2e:test:oauth") + return OAuthStorage(base_storage) + + @pytest.fixture + def mock_http_client(self): + """Create configurable mock HTTP client.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + return mock_client + + @pytest.mark.asyncio + async def test_complete_oauth_flow_discovery_to_token( + self, oauth_storage, mock_http_client + ): + """Test complete flow: discovery -> registration -> token exchange.""" + # Setup: Configure mock responses for each step + auth_server_url = "https://auth.keycard.test" + + # Step 2: Auth server metadata response + auth_server_response = MagicMock() + auth_server_response.status_code = 200 + auth_server_response.raise_for_status = MagicMock() + auth_server_response.json.return_value = { + "issuer": auth_server_url, + "authorization_endpoint": f"{auth_server_url}/authorize", + "token_endpoint": f"{auth_server_url}/token", + "registration_endpoint": f"{auth_server_url}/register", + "jwks_uri": f"{auth_server_url}/.well-known/jwks.json", + } + + # Step 3: Client registration response + registration_response = MagicMock() + registration_response.status_code = 200 + registration_response.raise_for_status = MagicMock() + registration_response.json.return_value = { + "client_id": "e2e_test_client", + "client_secret": None, + "redirect_uris": ["http://localhost:8080/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "token_endpoint_auth_method": "none", + } + + # Step 4: Token exchange response + token_response = MagicMock() + token_response.status_code = 200 + token_response.json.return_value = { + "access_token": "e2e_access_token_xyz", + "refresh_token": "e2e_refresh_token_abc", + "expires_in": 3600, + "token_type": "Bearer", + } + + # Configure mock to return different responses based on URL + call_log = {"get": [], "post": []} + + async def mock_get(url): + call_log["get"].append(url) + if ".well-known/oauth-authorization-server" in url: + return auth_server_response + raise ValueError(f"Unexpected GET URL: {url}") + + async def mock_post(url, **kwargs): + call_log["post"].append(url) + if "/register" in url: + return registration_response + elif "/token" in url: + return token_response + raise ValueError(f"Unexpected POST URL: {url}") + + mock_http_client.get = mock_get + mock_http_client.post = mock_post + + def client_factory(): + return mock_http_client + + # Execute: Run the complete flow + + # 1. Auth server discovery (simulating resource metadata already available) + resource_metadata = {"authorization_servers": [auth_server_url]} + + discovery_service = OAuthDiscoveryService( + storage=oauth_storage, client_factory=client_factory + ) + + auth_metadata = await discovery_service.discover_auth_server(resource_metadata) + assert auth_metadata["token_endpoint"] == f"{auth_server_url}/token" + assert auth_metadata["registration_endpoint"] == f"{auth_server_url}/register" + + # 2. Client registration + registration_service = OAuthClientRegistrationService( + storage=oauth_storage, client_name="E2E Test Client", client_factory=client_factory + ) + + client_info = await registration_service.get_or_register_client( + auth_metadata, ["http://localhost:8080/callback"] + ) + assert client_info["client_id"] == "e2e_test_client" + + # 3. Token exchange (requires PKCE state) + state = "e2e_test_state" + await oauth_storage.save_pkce_state( + state=state, + pkce_data={ + "code_verifier": "e2e_verifier", + "code_challenge": "e2e_challenge", + "redirect_uri": "http://localhost:8080/callback", + "resource_url": "https://api.test.com", + }, + ttl=timedelta(minutes=10), + ) + + exchange_service = OAuthTokenExchangeService( + storage=oauth_storage, client_factory=client_factory + ) + + tokens = await exchange_service.exchange_code_for_tokens( + code="e2e_auth_code", + state=state, + auth_server_metadata=auth_metadata, + client_info=client_info, + ) + + # Verify: Complete flow succeeded + assert tokens["access_token"] == "e2e_access_token_xyz" + assert tokens["refresh_token"] == "e2e_refresh_token_abc" + + # Verify all steps were called + assert len(call_log["get"]) >= 1 # Auth server discovery + assert len(call_log["post"]) >= 2 # Registration + token exchange + + @pytest.mark.asyncio + async def test_oauth_flow_with_cached_metadata(self, oauth_storage, mock_http_client): + """Test that cached metadata is reused in subsequent flows.""" + # Pre-cache auth server metadata + cached_metadata = { + "issuer": "https://cached.keycard.test", + "authorization_endpoint": "https://cached.keycard.test/authorize", + "token_endpoint": "https://cached.keycard.test/token", + "registration_endpoint": "https://cached.keycard.test/register", + } + await oauth_storage.save_auth_server_metadata(cached_metadata) + + # Pre-cache client registration + cached_client = { + "client_id": "cached_client_123", + "redirect_uris": ["http://localhost:8080/callback"], + } + await oauth_storage.save_client_registration(cached_client) + + # HTTP client should NOT be called for discovery or registration + mock_http_client.get = AsyncMock(side_effect=Exception("Should not be called")) + mock_http_client.post = AsyncMock(side_effect=Exception("Should not be called")) + + def client_factory(): + return mock_http_client + + # Discovery should return cached metadata + discovery_service = OAuthDiscoveryService( + storage=oauth_storage, client_factory=client_factory + ) + + metadata = await discovery_service.discover_auth_server( + {"authorization_servers": ["https://any.server.com"]} + ) + + assert metadata == cached_metadata + + # Registration should return cached client + registration_service = OAuthClientRegistrationService( + storage=oauth_storage, client_name="Test", client_factory=client_factory + ) + + client = await registration_service.get_or_register_client( + cached_metadata, ["http://localhost:8080/callback"] + ) + + assert client == cached_client + + @pytest.mark.asyncio + async def test_oauth_flow_handles_discovery_failure(self, oauth_storage, mock_http_client): + """Test graceful handling when discovery fails.""" + mock_http_client.get = AsyncMock(side_effect=Exception("Network error")) + + def client_factory(): + return mock_http_client + + discovery_service = OAuthDiscoveryService( + storage=oauth_storage, client_factory=client_factory + ) + + with pytest.raises(ValueError, match="Failed to discover"): + await discovery_service.discover_auth_server( + {"authorization_servers": ["https://unreachable.server.com"]} + ) + + @pytest.mark.asyncio + async def test_oauth_flow_handles_registration_failure(self, oauth_storage, mock_http_client): + """Test graceful handling when client registration fails.""" + # Mock successful metadata response + metadata_response = MagicMock() + metadata_response.status_code = 200 + metadata_response.raise_for_status = MagicMock() + + # Mock failed registration response + registration_response = MagicMock() + registration_response.status_code = 400 + registration_response.raise_for_status.side_effect = Exception("Registration failed") + + async def mock_get(url): + return metadata_response + + async def mock_post(url, **kwargs): + if "/register" in url: + return registration_response + raise ValueError(f"Unexpected POST: {url}") + + mock_http_client.get = mock_get + mock_http_client.post = mock_post + + def client_factory(): + return mock_http_client + + registration_service = OAuthClientRegistrationService( + storage=oauth_storage, client_name="Test Client", client_factory=client_factory + ) + + auth_metadata = {"registration_endpoint": "https://auth.test/register"} + + with pytest.raises(Exception, match="Registration failed"): + await registration_service.get_or_register_client( + auth_metadata, ["http://localhost:8080/callback"] + ) + + @pytest.mark.asyncio + async def test_oauth_flow_handles_missing_pkce_state(self, oauth_storage, mock_http_client): + """Test error handling when PKCE state is missing during token exchange.""" + + def client_factory(): + return mock_http_client + + exchange_service = OAuthTokenExchangeService( + storage=oauth_storage, client_factory=client_factory + ) + + auth_metadata = {"token_endpoint": "https://auth.test/token"} + client_info = {"client_id": "test_client"} + + with pytest.raises(ValueError, match="No PKCE state found"): + await exchange_service.exchange_code_for_tokens( + code="some_code", + state="nonexistent_state", + auth_server_metadata=auth_metadata, + client_info=client_info, + )