Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 103 additions & 9 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import contextlib
import json
import logging
import re
import uuid
import warnings
from typing import Any
from urllib.parse import quote

import google.oauth2.credentials
Expand All @@ -14,6 +16,7 @@
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel, ValidationError
from sqlalchemy.engine import URL, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand All @@ -33,6 +36,18 @@
from deepnote_toolkit.sql.sql_utils import is_single_select_query
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class IntegrationFederatedAuthParams(BaseModel):
    """Federated-auth parameters attached to a connection description.

    ``_handle_federated_auth_params`` validates the value stored under
    ``sql_alchemy_dict["federatedAuthParams"]`` against this model, so the
    field names mirror the keys of that incoming mapping.
    """

    # ID of the integration; used as the last path segment of the
    # ``integrations/federated-auth-token/<id>`` userpod API URL.
    integrationId: str
    # Token forwarded in the ``UserPodAuthContextToken`` request header when
    # the federated-auth token is requested.
    authContextToken: str


class FederatedAuthResponseData(BaseModel):
    """JSON body of the federated-auth-token endpoint, after validation.

    ``_get_federated_auth_credentials`` builds this from ``response.json()``.
    """

    # Kind of integration the token belongs to. This module branches on the
    # values "trino", "big-query" and "snowflake"; anything else is logged as
    # unsupported.
    integrationType: str
    # Access token applied to the connection params: as a
    # ``Bearer`` Authorization header for "trino", or as
    # ``params["access_token"]`` for "big-query".
    accessToken: str


def compile_sql_query(
skip_jinja_template_render,
Expand Down Expand Up @@ -242,11 +257,97 @@ def _generate_temporary_credentials(integration_id):

response = requests.post(url, timeout=10, headers=headers)

response.raise_for_status()

data = response.json()

return quote(data["username"]), quote(data["password"])


def _get_federated_auth_credentials(
    integration_id: str, user_pod_auth_context_token: str
) -> FederatedAuthResponseData:
    """Request a federated-auth access token for a single integration.

    Sends a POST (10 second timeout) to the userpod API endpoint
    ``integrations/federated-auth-token/<integration_id>`` with the project
    auth headers plus the caller-supplied auth context token.

    Args:
        integration_id: Integration ID, interpolated into the endpoint path.
        user_pod_auth_context_token: Value sent in the
            ``UserPodAuthContextToken`` request header.

    Returns:
        The response JSON validated as ``FederatedAuthResponseData``.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status,
        and a pydantic ``ValidationError`` if the response body does not
        match ``FederatedAuthResponseData``.
    """
    endpoint = get_absolute_userpod_api_url(
        f"integrations/federated-auth-token/{integration_id}"
    )

    # Start from the project auth headers (these carry the project
    # credentials needed in detached mode), then attach the context token.
    request_headers = get_project_auth_headers()
    request_headers["UserPodAuthContextToken"] = user_pod_auth_context_token

    api_response = requests.post(endpoint, timeout=10, headers=request_headers)
    api_response.raise_for_status()

    return FederatedAuthResponseData.model_validate(api_response.json())


def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Apply IAM credentials to the connection URL in-place."""

if "iamParams" not in sql_alchemy_dict:
return

integration_id = sql_alchemy_dict["iamParams"]["integrationId"]

temporary_username, temporary_password = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporary_username, temporary_password
)


def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
"""Fetch and apply federated auth credentials to connection params in-place."""

if "federatedAuthParams" not in sql_alchemy_dict:
return

try:
federated_auth_params = IntegrationFederatedAuthParams.model_validate(
sql_alchemy_dict["federatedAuthParams"]
)
except ValidationError:
logger.exception("Invalid federated auth params, try updating toolkit version")
return

federated_auth = _get_federated_auth_credentials(
federated_auth_params.integrationId, federated_auth_params.authContextToken
)

if federated_auth.integrationType == "trino":
try:
sql_alchemy_dict["params"]["connect_args"]["http_headers"][
"Authorization"
] = f"Bearer {federated_auth.accessToken}"
except KeyError:
logger.exception(
"Invalid federated auth params, try updating toolkit version"
)
elif federated_auth.integrationType == "big-query":
try:
sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken
except KeyError:
logger.exception(
"Invalid federated auth params, try updating toolkit version"
)
elif federated_auth.integrationType == "snowflake":
logger.warning(
"Snowflake federated auth is not supported yet, using the original connection URL"
)
else:
logger.error(
"Unsupported integration type: %s, try updating toolkit version",
federated_auth.integrationType,
)


@contextlib.contextmanager
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
server = None
Expand Down Expand Up @@ -346,16 +447,9 @@ def _query_data_source(
):
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)

if "iamParams" in sql_alchemy_dict:
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
_handle_iam_params(sql_alchemy_dict)

temporaryUsername, temporaryPassword = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
)
_handle_federated_auth_params(sql_alchemy_dict)

with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
if url is None:
Expand Down
Loading