From c50220d32dfe5283c3524254282eff67c51cdb65 Mon Sep 17 00:00:00 2001 From: tomas Date: Wed, 17 Dec 2025 12:50:09 +0000 Subject: [PATCH 01/10] feat: Add support to refresh federated auth access token --- deepnote_toolkit/sql/sql_execution.py | 94 ++++++++++++++++++++++++--- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 02892dc..77cb8ea 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -1,13 +1,22 @@ import base64 import contextlib import json +import logging import re +import sys +from typing import Any, Literal import uuid + +if sys.version_info >= (3, 11): + from typing import Never +else: + from typing_extensions import Never import warnings from urllib.parse import quote import google.oauth2.credentials import numpy as np +from pydantic import BaseModel, ValidationError import requests from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -33,6 +42,18 @@ from deepnote_toolkit.sql.sql_utils import is_single_select_query from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url +logger = logging.getLogger(__name__) + + +class IntegrationFederatedAuthParams(BaseModel): + integrationType: Literal["trino", "big-query"] + integrationId: str + userId: str + + +class FederatedAuthResponseData(BaseModel): + accessToken: str + def compile_sql_query( skip_jinja_template_render, @@ -247,6 +268,68 @@ def _generate_temporary_credentials(integration_id): return quote(data["username"]), quote(data["password"]) +def _get_federated_auth_credentials(integration_id: str, user_id: str) -> str: + url = get_absolute_userpod_api_url( + f"integrations/federated-auth-token/{integration_id}" + ) + + # Add project credentials in detached mode + headers = get_project_auth_headers() + + response = requests.post(url, json={"userId": user_id}, timeout=10, headers=headers) + + data = FederatedAuthResponseData.model_validate_json(response.json()) + + return data.accessToken + + +def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None: + if "iamParams" not in sql_alchemy_dict: + return + + integration_id = sql_alchemy_dict["iamParams"]["integrationId"] + + temporaryUsername, temporaryPassword = _generate_temporary_credentials( + integration_id + ) + + sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( + sql_alchemy_dict["url"], temporaryUsername, temporaryPassword + ) + + +def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: + if "federatedAuthParams" not in sql_alchemy_dict: + return + + try: + federated_auth_params = IntegrationFederatedAuthParams.model_validate( + sql_alchemy_dict["federatedAuthParams"] + ) + except ValidationError as e: + logger.error( + f"Invalid federated auth params, try updating toolkit version: {e}" + ) + return + + access_token = _get_federated_auth_credentials( + federated_auth_params.integrationId, federated_auth_params.userId + ) + + match federated_auth_params.integrationType: + case "trino": + sql_alchemy_dict["params"]["connect_args"]["http_headers"][ + "Authorization" + ] = f"Bearer {access_token}" + case "big-query": + sql_alchemy_dict["params"]["access_token"] = access_token + case _: + _check_never: Never = federated_auth_params.integrationType + raise ValueError( + f"Unsupported integration type: {federated_auth_params.integrationType}" + ) + + @contextlib.contextmanager def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): server = None @@ -346,16 +429,9 @@ def _query_data_source( ): sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) - if "iamParams" in sql_alchemy_dict: - integration_id = sql_alchemy_dict["iamParams"]["integrationId"] - - temporaryUsername, temporaryPassword = _generate_temporary_credentials( - integration_id - ) + _handle_iam_params(sql_alchemy_dict) - sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( - sql_alchemy_dict["url"], temporaryUsername, temporaryPassword - ) + _handle_federated_auth_params(sql_alchemy_dict) with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url: if url is None: From 895e221bdb5bfed70b0ce1215a988c3905a88f52 Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 19 Dec 2025 15:45:59 +0000 Subject: [PATCH 02/10] refactor(sql_execution): Update federated auth handling and improve type hints - Refactored `_get_federated_auth_credentials` to return `FederatedAuthResponseData` instead of just the access token. - Updated `IntegrationFederatedAuthParams` to include `userPodAuthContextToken` and changed the handling of integration types. - Improved error logging for unsupported integration types. - Cleaned up imports and ensured consistent use of type hints. --- deepnote_toolkit/sql/sql_execution.py | 59 +++++++++++++-------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 77cb8ea..9aba91d 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -3,26 +3,20 @@ import json import logging import re -import sys -from typing import Any, Literal import uuid - -if sys.version_info >= (3, 11): - from typing import Never -else: - from typing_extensions import Never import warnings +from typing import Any from urllib.parse import quote import google.oauth2.credentials import numpy as np -from pydantic import BaseModel, ValidationError import requests from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from google.api_core.client_info import ClientInfo from google.cloud import bigquery from packaging.version import parse as parse_version +from pydantic import BaseModel, ValidationError from sqlalchemy.engine import URL, create_engine, make_url from sqlalchemy.exc import ResourceClosedError @@ -46,12 +40,12 @@ class IntegrationFederatedAuthParams(BaseModel): - integrationType: Literal["trino", "big-query"] integrationId: str - userId: str + userPodAuthContextToken: str class FederatedAuthResponseData(BaseModel): + integrationType: str accessToken: str @@ -268,37 +262,42 @@ def _generate_temporary_credentials(integration_id): return quote(data["username"]), quote(data["password"]) -def _get_federated_auth_credentials(integration_id: str, user_id: str) -> str: +def _get_federated_auth_credentials(integration_id: str, user_pod_auth_context_token: str) -> FederatedAuthResponseData: url = get_absolute_userpod_api_url( f"integrations/federated-auth-token/{integration_id}" ) # Add project credentials in detached mode headers = get_project_auth_headers() + headers["UserPodAuthContextToken"] = user_pod_auth_context_token - response = requests.post(url, json={"userId": user_id}, timeout=10, headers=headers) + response = requests.post(url, timeout=10, headers=headers) - data = FederatedAuthResponseData.model_validate_json(response.json()) + data = FederatedAuthResponseData.model_validate(response.json()) - return data.accessToken + return data def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None: + """Apply IAM credentials to the connection URL in-place.""" + if "iamParams" not in sql_alchemy_dict: return integration_id = sql_alchemy_dict["iamParams"]["integrationId"] - temporaryUsername, temporaryPassword = _generate_temporary_credentials( + temporary_username, temporary_password = _generate_temporary_credentials( integration_id ) sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( - sql_alchemy_dict["url"], temporaryUsername, temporaryPassword + sql_alchemy_dict["url"], temporary_username, temporary_password ) def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: + """Fetch and apply federated auth credentials to connection params in-place.""" + if "federatedAuthParams" not in sql_alchemy_dict: return @@ -308,26 +307,24 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ) except ValidationError as e: logger.error( - f"Invalid federated auth params, try updating toolkit version: {e}" + "Invalid federated auth params, try updating toolkit version:", exc_info=e ) return - access_token = _get_federated_auth_credentials( - federated_auth_params.integrationId, federated_auth_params.userId + federated_auth = _get_federated_auth_credentials( + federated_auth_params.integrationId, federated_auth_params.userPodAuthContextToken ) - match federated_auth_params.integrationType: - case "trino": - sql_alchemy_dict["params"]["connect_args"]["http_headers"][ - "Authorization" - ] = f"Bearer {access_token}" - case "big-query": - sql_alchemy_dict["params"]["access_token"] = access_token - case _: - _check_never: Never = federated_auth_params.integrationType - raise ValueError( - f"Unsupported integration type: {federated_auth_params.integrationType}" - ) + if federated_auth.integrationType == "trino": + sql_alchemy_dict["params"]["connect_args"]["http_headers"][ + "Authorization" + ] = f"Bearer {federated_auth.access_token}" + elif federated_auth.integrationType == "big-query": + sql_alchemy_dict["params"]["access_token"] = federated_auth.access_token + else: + logger.error( + "Unsupported integration type: %s, try updating toolkit version", federated_auth.integrationType + ) @contextlib.contextmanager From 66b99793adc13c8bf517e464589aaec381993e7a Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 19 Dec 2025 16:18:46 +0000 Subject: [PATCH 03/10] Fix typo --- deepnote_toolkit/sql/sql_execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 9aba91d..0d25651 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -318,9 +318,9 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: if federated_auth.integrationType == "trino": sql_alchemy_dict["params"]["connect_args"]["http_headers"][ "Authorization" - ] = f"Bearer {federated_auth.access_token}" + ] = f"Bearer {federated_auth.accessToken}" elif federated_auth.integrationType == "big-query": - sql_alchemy_dict["params"]["access_token"] = federated_auth.access_token + sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken else: logger.error( "Unsupported integration type: %s, try updating toolkit version", federated_auth.integrationType From 8d312361103aeedb4186d4a3a5b2a0526e4b4db1 Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 19 Dec 2025 17:55:57 +0000 Subject: [PATCH 04/10] Fix federated auth key name, raise on http error --- deepnote_toolkit/sql/sql_execution.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 0d25651..9bc7d2f 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -41,7 +41,7 @@ class IntegrationFederatedAuthParams(BaseModel): integrationId: str - userPodAuthContextToken: str + authContextToken: str class FederatedAuthResponseData(BaseModel): @@ -257,12 +257,16 @@ def _generate_temporary_credentials(integration_id): response = requests.post(url, timeout=10, headers=headers) + response.raise_for_status() + data = response.json() return quote(data["username"]), quote(data["password"]) -def _get_federated_auth_credentials(integration_id: str, user_pod_auth_context_token: str) -> FederatedAuthResponseData: +def _get_federated_auth_credentials( + integration_id: str, user_pod_auth_context_token: str +) -> FederatedAuthResponseData: url = get_absolute_userpod_api_url( f"integrations/federated-auth-token/{integration_id}" ) @@ -273,6 +277,8 @@ def _get_federated_auth_credentials(integration_id: str, user_pod_auth_context_t response = requests.post(url, timeout=10, headers=headers) + response.raise_for_status() + data = FederatedAuthResponseData.model_validate(response.json()) return data @@ -312,7 +318,7 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: return federated_auth = _get_federated_auth_credentials( - federated_auth_params.integrationId, federated_auth_params.userPodAuthContextToken + federated_auth_params.integrationId, federated_auth_params.authContextToken ) if federated_auth.integrationType == "trino": @@ -323,7 +329,8 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken else: logger.error( - "Unsupported integration type: %s, try updating toolkit version", federated_auth.integrationType + "Unsupported integration type: %s, try updating toolkit version", + federated_auth.integrationType, ) From 42c90f9068832b1a7a0bd6b260d2e49b88fe7cf8 Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 19 Dec 2025 21:25:50 +0000 Subject: [PATCH 05/10] Add tests for federated auth token fetch from webapp --- deepnote_toolkit/sql/sql_execution.py | 4 + tests/unit/test_sql_execution.py | 219 ++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 9bc7d2f..91a1d33 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -267,6 +267,8 @@ def _generate_temporary_credentials(integration_id): def _get_federated_auth_credentials( integration_id: str, user_pod_auth_context_token: str ) -> FederatedAuthResponseData: + """Get federated auth credentials for the given integration ID and user pod auth context token.""" + url = get_absolute_userpod_api_url( f"integrations/federated-auth-token/{integration_id}" ) @@ -327,6 +329,8 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ] = f"Bearer {federated_auth.accessToken}" elif federated_auth.integrationType == "big-query": sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken + elif federated_auth.integrationType == "snowflake": + logger.warning("Snowflake federated auth is not supported yet, using the original connection URL") else: logger.error( "Unsupported integration type: %s, try updating toolkit version", diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 5a2c41f..e516a50 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -585,3 +585,222 @@ def test_all_dataframes_serialize_to_parquet(self, key, df): df_cleaned.to_parquet(in_memory_file) except: # noqa: E722 self.fail(f"serializing to parquet failed for {key}") + + +class TestFederatedAuth(unittest.TestCase): + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_trino(self, mock_get_credentials): + """Test that Trino federated auth updates the Authorization header with Bearer token.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return Trino credentials + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="trino", + accessToken="test-trino-access-token", + ) + + # Create a sql_alchemy_dict with federatedAuthParams and the expected structure + sql_alchemy_dict = { + "url": "trino://user@localhost:8080/catalog", + "params": { + "connect_args": { + "http_headers": { + "Authorization": "Bearer old-token", + } + } + }, + "federatedAuthParams": { + "integrationId": "test-integration-id", + "authContextToken": "test-auth-context-token", + }, + } + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the API was called with correct params + mock_get_credentials.assert_called_once_with( + "test-integration-id", "test-auth-context-token" + ) + + # Verify the Authorization header was updated with the new token + self.assertEqual( + sql_alchemy_dict["params"]["connect_args"]["http_headers"]["Authorization"], + "Bearer test-trino-access-token", + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_bigquery(self, mock_get_credentials): + """Test that BigQuery federated auth updates the access_token in params.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return BigQuery credentials + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="big-query", + accessToken="test-bigquery-access-token", + ) + + # Create a sql_alchemy_dict with federatedAuthParams + sql_alchemy_dict = { + "url": "bigquery://?user_supplied_client=true", + "params": { + "access_token": "old-access-token", + "project": "test-project", + }, + "federatedAuthParams": { + "integrationId": "test-bigquery-integration-id", + "authContextToken": "test-bigquery-auth-context-token", + }, + } + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the API was called with correct params + mock_get_credentials.assert_called_once_with( + "test-bigquery-integration-id", "test-bigquery-auth-context-token" + ) + + # Verify the access_token was updated with the new token + self.assertEqual( + sql_alchemy_dict["params"]["access_token"], + "test-bigquery-access-token", + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger): + """Test that Snowflake federated auth logs a warning since it's not supported yet.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return Snowflake credentials + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="snowflake", + accessToken="test-snowflake-access-token", + ) + + # Create a sql_alchemy_dict with federatedAuthParams + sql_alchemy_dict = { + "url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces", + "params": {}, + "federatedAuthParams": { + "integrationId": "test-snowflake-integration-id", + "authContextToken": "test-snowflake-auth-context-token", + }, + } + + # Store original params to verify they remain unchanged + original_params = sql_alchemy_dict["params"].copy() + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the API was called with correct params + mock_get_credentials.assert_called_once_with( + "test-snowflake-integration-id", "test-snowflake-auth-context-token" + ) + + # Verify a warning was logged + mock_logger.warning.assert_called_once_with( + "Snowflake federated auth is not supported yet, using the original connection URL" + ) + + # Verify params were NOT modified (snowflake is not supported yet) + self.assertEqual(sql_alchemy_dict["params"], original_params) + + def test_federated_auth_params_not_present(self): + """Test that no action is taken when federatedAuthParams is not present.""" + from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params + + # Create a sql_alchemy_dict without federatedAuthParams + sql_alchemy_dict = { + "url": "trino://user@localhost:8080/catalog", + "params": { + "connect_args": { + "http_headers": {"Authorization": "Bearer original-token"} + } + }, + } + + original_dict = json.loads(json.dumps(sql_alchemy_dict)) + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the dict was not modified + self.assertEqual(sql_alchemy_dict, original_dict) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + def test_federated_auth_params_invalid_params(self, mock_logger): + """Test that invalid federated auth params logs an error and returns early.""" + from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params + + # Create a sql_alchemy_dict with invalid federatedAuthParams (missing required fields) + sql_alchemy_dict = { + "url": "trino://user@localhost:8080/catalog", + "params": {}, + "federatedAuthParams": { + "invalidField": "value", + }, + } + + original_dict = json.loads(json.dumps(sql_alchemy_dict)) + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify an error was logged + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args + self.assertIn("Invalid federated auth params", call_args[0][0]) + + self.assertEqual(sql_alchemy_dict, original_dict) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_unsupported_integration_type( + self, mock_get_credentials, mock_logger + ): + """Test that unsupported integration type logs an error.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return unknown integration type + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="unknown-database", + accessToken="test-token", + ) + + # Create a sql_alchemy_dict with federatedAuthParams + sql_alchemy_dict = { + "url": "unknown://host/db", + "params": {}, + "federatedAuthParams": { + "integrationId": "test-integration-id", + "authContextToken": "test-auth-context-token", + }, + } + + original_dict = json.loads(json.dumps(sql_alchemy_dict)) + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify an error was logged for unsupported integration type + mock_logger.error.assert_called_once_with( + "Unsupported integration type: %s, try updating toolkit version", + "unknown-database", + ) + + self.assertEqual(sql_alchemy_dict, original_dict) From 7fad890cb56533a0e271c43c7f67a14ba5031b29 Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 19 Dec 2025 21:31:58 +0000 Subject: [PATCH 06/10] Reformat code --- deepnote_toolkit/sql/sql_execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 91a1d33..6e65589 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -330,7 +330,9 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: elif federated_auth.integrationType == "big-query": sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken elif federated_auth.integrationType == "snowflake": - logger.warning("Snowflake federated auth is not supported yet, using the original connection URL") + logger.warning( + "Snowflake federated auth is not supported yet, using the original connection URL" + ) else: logger.error( "Unsupported integration type: %s, try updating toolkit version", From e7c998b9ef66fecd0cfc51707bd134bac6bf852e Mon Sep 17 00:00:00 2001 From: tomas Date: Mon, 22 Dec 2025 08:13:29 +0000 Subject: [PATCH 07/10] Catch KeyError when assigning refreshed access token, use logger.exception instead of logger.error --- deepnote_toolkit/sql/sql_execution.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 6e65589..86ff5ca 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -313,10 +313,8 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: federated_auth_params = IntegrationFederatedAuthParams.model_validate( sql_alchemy_dict["federatedAuthParams"] ) - except ValidationError as e: - logger.error( - "Invalid federated auth params, try updating toolkit version:", exc_info=e - ) + except ValidationError: + logger.exception("Invalid federated auth params, try updating toolkit version") return federated_auth = _get_federated_auth_credentials( @@ -324,11 +322,21 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ) if federated_auth.integrationType == "trino": - sql_alchemy_dict["params"]["connect_args"]["http_headers"][ - "Authorization" - ] = f"Bearer {federated_auth.accessToken}" + try: + sql_alchemy_dict["params"]["connect_args"]["http_headers"][ + "Authorization" + ] = f"Bearer {federated_auth.accessToken}" + except KeyError: + logger.exception( + "Invalid federated auth params, try updating toolkit version" + ) elif federated_auth.integrationType == "big-query": - sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken + try: + sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken + except KeyError: + logger.exception( + "Invalid federated auth params, try updating toolkit version" + ) elif federated_auth.integrationType == "snowflake": logger.warning( "Snowflake federated auth is not supported yet, using the original connection URL" From 674f289835c6e134ed7b33e8d70c113fa3153f84 Mon Sep 17 00:00:00 2001 From: tomas Date: Mon, 22 Dec 2025 08:45:43 +0000 Subject: [PATCH 08/10] fix(tests): Update error logging in federated auth tests to use logger.exception - Changed the assertion in the TestFederatedAuth class to verify that an exception was logged instead of an error. - Updated the test to check for the correct logging of invalid federated auth parameters. --- tests/unit/test_sql_execution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index e516a50..54845cf 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -758,9 +758,9 @@ def test_federated_auth_params_invalid_params(self, mock_logger): # Call the function _handle_federated_auth_params(sql_alchemy_dict) - # Verify an error was logged - mock_logger.error.assert_called_once() - call_args = mock_logger.error.call_args + # Verify an exception was logged + mock_logger.exception.assert_called_once() + call_args = mock_logger.exception.call_args self.assertIn("Invalid federated auth params", call_args[0][0]) self.assertEqual(sql_alchemy_dict, original_dict) From 532a45ff50e5175b0832a8c7c86c952491c847f6 Mon Sep 17 00:00:00 2001 From: tomas Date: Mon, 22 Dec 2025 08:55:26 +0000 Subject: [PATCH 09/10] Add tests for federated auth access token KeyError --- tests/unit/test_sql_execution.py | 101 ++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 54845cf..b3a636f 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -1,4 +1,5 @@ import base64 +import copy import datetime import io import json @@ -731,7 +732,7 @@ def test_federated_auth_params_not_present(self): }, } - original_dict = json.loads(json.dumps(sql_alchemy_dict)) + original_dict = copy.deepcopy(sql_alchemy_dict) # Call the function _handle_federated_auth_params(sql_alchemy_dict) @@ -753,7 +754,7 @@ def test_federated_auth_params_invalid_params(self, mock_logger): }, } - original_dict = json.loads(json.dumps(sql_alchemy_dict)) + original_dict = copy.deepcopy(sql_alchemy_dict) # Call the function _handle_federated_auth_params(sql_alchemy_dict) @@ -792,7 +793,7 @@ def test_federated_auth_params_unsupported_integration_type( }, } - original_dict = json.loads(json.dumps(sql_alchemy_dict)) + original_dict = copy.deepcopy(sql_alchemy_dict) # Call the function _handle_federated_auth_params(sql_alchemy_dict) @@ -804,3 +805,97 @@ def test_federated_auth_params_unsupported_integration_type( ) self.assertEqual(sql_alchemy_dict, original_dict) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_trino_missing_http_headers( + self, mock_get_credentials, mock_logger + ): + """Test that Trino federated auth logs exception when connect_args is missing http_headers.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return Trino credentials + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="trino", + accessToken="test-trino-access-token", + ) + + # Create a sql_alchemy_dict with connect_args but missing http_headers + sql_alchemy_dict = { + "url": "trino://user@localhost:8080/catalog", + "params": { + "connect_args": { + # http_headers is missing + } + }, + "federatedAuthParams": { + "integrationId": "test-integration-id", + "authContextToken": "test-auth-context-token", + }, + } + + original_dict = copy.deepcopy(sql_alchemy_dict) + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the API was called with correct params + mock_get_credentials.assert_called_once_with( + "test-integration-id", "test-auth-context-token" + ) + + # Verify an exception was logged for missing http_headers + mock_logger.exception.assert_called_once() + call_args = mock_logger.exception.call_args + self.assertIn("Invalid federated auth params", call_args[0][0]) + + # Verify the dict was not modified + self.assertEqual(sql_alchemy_dict, original_dict) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials") + def test_federated_auth_params_bigquery_missing_params( + self, mock_get_credentials, mock_logger + ): + """Test that BigQuery federated auth logs exception when params key is missing.""" + from deepnote_toolkit.sql.sql_execution import ( + FederatedAuthResponseData, + _handle_federated_auth_params, + ) + + # Setup mock to return BigQuery credentials + mock_get_credentials.return_value = FederatedAuthResponseData( + integrationType="big-query", + accessToken="test-bigquery-access-token", + ) + + # Create a sql_alchemy_dict without params key (will cause KeyError) + sql_alchemy_dict = { + "url": "bigquery://?user_supplied_client=true", + # params key is missing entirely + "federatedAuthParams": { + "integrationId": "test-bigquery-integration-id", + "authContextToken": "test-bigquery-auth-context-token", + }, + } + + original_dict = copy.deepcopy(sql_alchemy_dict) + + # Call the function + _handle_federated_auth_params(sql_alchemy_dict) + + # Verify the API was called with correct params + mock_get_credentials.assert_called_once_with( + "test-bigquery-integration-id", "test-bigquery-auth-context-token" + ) + + # Verify an exception was logged for missing params + mock_logger.exception.assert_called_once() + call_args = mock_logger.exception.call_args + self.assertIn("Invalid federated auth params", call_args[0][0]) + + # Verify the dict was not modified + self.assertEqual(sql_alchemy_dict, original_dict) From d277ff61640da9fd8ee32a73419a814714af1962 Mon Sep 17 00:00:00 2001 From: tomas Date: Mon, 22 Dec 2025 09:02:10 +0000 Subject: [PATCH 10/10] Use deepcopy instead of dict copy --- tests/unit/test_sql_execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index b3a636f..8ce4966 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -700,7 +700,7 @@ def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger } # Store original params to verify they remain unchanged - original_params = sql_alchemy_dict["params"].copy() + original_params = copy.deepcopy(sql_alchemy_dict["params"]) # Call the function _handle_federated_auth_params(sql_alchemy_dict)