diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py index 66b13e112a..b8d14c0a00 100644 --- a/sagemaker-core/src/sagemaker/core/resources.py +++ b/sagemaker-core/src/sagemaker/core/resources.py @@ -11818,6 +11818,7 @@ def get_all( list_method_kwargs=operation_input_args, ) + ''' class EndpointConfigInternal(Base): """ @@ -12510,6 +12511,7 @@ def get_all( ) ''' + class Experiment(Base): """ Class representing resource Experiment @@ -13036,6 +13038,353 @@ def get_name(self) -> str: logger.error("Name attribute not found for object feature_group") return None + @staticmethod + def _s3_uri_to_arn(s3_uri: str) -> str: + """ + Convert S3 URI to S3 ARN format. + + Args: + s3_uri: S3 URI in format s3://bucket/path or already an ARN + + Returns: + S3 ARN in format arn:aws:s3:::bucket/path + """ + if s3_uri.startswith("arn:aws:s3:::"): + return s3_uri + path = s3_uri.replace("s3://", "") + return f"arn:aws:s3:::{path}" + + def _get_lake_formation_client( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ): + """ + Get a Lake Formation client. + + Args: + session: Boto3 session. If not provided, a new session will be created. + region: AWS region name. + + Returns: + A boto3 Lake Formation client. + """ + boto_session = session or Session() + return boto_session.client("lakeformation", region_name=region) + + def _register_s3_with_lake_formation( + self, + s3_location: str, + session: Optional[Session] = None, + region: Optional[str] = None, + use_service_linked_role: bool = True, + role_arn: Optional[str] = None, + ) -> bool: + """ + Register an S3 location with Lake Formation. + + Args: + s3_location: S3 URI or ARN to register. + session: Boto3 session. + region: AWS region. + use_service_linked_role: Whether to use the Lake Formation service-linked role. + If True, Lake Formation uses its service-linked role for registration. + If False, role_arn must be provided. + role_arn: IAM role ARN to use for registration. Required when + use_service_linked_role is False. + + Returns: + True if registration succeeded or location already registered. + + Raises: + ValueError: If use_service_linked_role is False but role_arn is not provided. + ClientError: If registration fails for unexpected reasons. + """ + if not use_service_linked_role and not role_arn: + raise ValueError("role_arn must be provided when use_service_linked_role is False") + + client = self._get_lake_formation_client(session, region) + resource_arn = self._s3_uri_to_arn(s3_location) + + try: + register_params = {"ResourceArn": resource_arn} + + if use_service_linked_role: + register_params["UseServiceLinkedRole"] = True + else: + register_params["RoleArn"] = role_arn + + client.register_resource(**register_params) + logger.info(f"Successfully registered S3 location: {resource_arn}") + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "AlreadyExistsException": + logger.info(f"S3 location already registered: {resource_arn}") + return True + raise + + def _revoke_iam_allowed_principal( + self, + database_name: str, + table_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> bool: + """ + Revoke IAMAllowedPrincipal permissions from a Glue table. + + Args: + database_name: Glue database name. + table_name: Glue table name. + session: Boto3 session. + region: AWS region. + + Returns: + True if revocation succeeded or permissions didn't exist. + + Raises: + ClientError: If revocation fails for unexpected reasons. + """ + client = self._get_lake_formation_client(session, region) + + try: + client.revoke_permissions( + Principal={"DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS"}, + Resource={ + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + }, + Permissions=["ALL"], + ) + logger.info(f"Revoked IAMAllowedPrincipal from table: {database_name}.{table_name}") + return True + except botocore.exceptions.ClientError as e: + # if the Table doesn't have that permission because the user already revoked it + # then just return True + if e.response["Error"]["Code"] == "InvalidInputException": + logger.info( + f"IAMAllowedPrincipal permissions may not exist on: {database_name}.{table_name}" + ) + return True + raise + + def _grant_lake_formation_permissions( + self, + role_arn: str, + database_name: str, + table_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> bool: + """ + Grant permissions to a role on a Glue table via Lake Formation. + + Args: + role_arn: IAM role ARN to grant permissions to. + database_name: Glue database name. + table_name: Glue table name. + session: Boto3 session. + region: AWS region. + + Returns: + True if grant succeeded or permissions already exist. + + Raises: + ClientError: If grant fails for unexpected reasons. + """ + client = self._get_lake_formation_client(session, region) + permissions = ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] + + try: + client.grant_permissions( + Principal={"DataLakePrincipalIdentifier": role_arn}, + Resource={ + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + }, + Permissions=permissions, + PermissionsWithGrantOption=[], + ) + logger.info(f"Granted permissions to {role_arn} on table: {database_name}.{table_name}") + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "InvalidInputException": + logger.info( + f"Permissions may already exist for {role_arn} on: {database_name}.{table_name}" + ) + return True + raise + + @Base.add_validate_call + def enable_lake_formation( + self, + session: Optional[Session] = None, + region: Optional[StrPipeVar] = None, + use_service_linked_role: bool = True, + registration_role_arn: Optional[str] = None, + wait_for_active: bool = False, + ) -> dict: + """ + Enable Lake Formation governance for this Feature Group's offline store. + + This method: + 1. Optionally waits for Feature Group to reach 'Created' status + 2. Validates Feature Group status is 'Created' + 3. Registers the offline store S3 location as data lake location + 4. Grants the execution role permissions on the Glue table + 5. Revokes IAMAllowedPrincipal permissions from the Glue table + + The role ARN is automatically extracted from the Feature Group's configuration. + Each phase depends on the success of the previous phase - if any phase fails, + subsequent phases are not executed. + + Parameters: + session: Boto3 session. + region: Region name. + use_service_linked_role: Whether to use the Lake Formation service-linked role + for S3 registration. If True, Lake Formation uses its service-linked role. + If False, registration_role_arn must be provided. Default is True. + registration_role_arn: IAM role ARN to use for S3 registration with Lake Formation. + Required when use_service_linked_role is False. This can be different from the + Feature Group's execution role (role_arn) + wait_for_active: If True, waits for the Feature Group to reach 'Created' status + before enabling Lake Formation. Default is False. + + Returns: + Dict with status of each Lake Formation operation: + - s3_registration: bool + - iam_principal_revoked: bool + - permissions_granted: bool + + Raises: + ValueError: If the Feature Group has no offline store configured, + if role_arn is not set on the Feature Group, if use_service_linked_role + is False but registration_role_arn is not provided, or if the Feature Group + is not in 'Created' status. + ClientError: If Lake Formation operations fail. + RuntimeError: If a phase fails and subsequent phases cannot proceed. + """ + # Wait for Created status if requested + if wait_for_active: + self.wait_for_status(target_status="Created") + + # Refresh to get latest state + self.refresh() + + # Validate Feature Group status + if self.feature_group_status not in ["Created"]: + raise ValueError( + f"Feature Group '{self.feature_group_name}' must be in 'Created' status " + f"to enable Lake Formation. Current status: '{self.feature_group_status}'. " + f"Use wait_for_active=True to automatically wait for the Feature Group to be ready." + ) + + # Validate offline store exists + if self.offline_store_config is None or self.offline_store_config == Unassigned(): + raise ValueError( + f"Feature Group '{self.feature_group_name}' does not have an offline store configured. " + "Lake Formation can only be enabled for Feature Groups with offline stores." + ) + + # Get role ARN from Feature Group config + if self.role_arn is None or self.role_arn == Unassigned(): + raise ValueError( + f"Feature Group '{self.feature_group_name}' does not have a role_arn configured. " + "Lake Formation requires a role ARN to grant permissions." + ) + if not use_service_linked_role and registration_role_arn is None: + raise ValueError( + "Either 'use_service_linked_role' must be True or 'registration_role_arn' must be provided." + ) + + # Extract required configuration + s3_config = self.offline_store_config.s3_storage_config + if s3_config is None: + raise ValueError("Offline store S3 configuration is missing") + + resolved_s3_uri = s3_config.resolved_output_s3_uri + if resolved_s3_uri is None or resolved_s3_uri == Unassigned(): + raise ValueError( + "Resolved S3 URI not available. Ensure the Feature Group is in 'Created' status." + ) + + data_catalog_config = self.offline_store_config.data_catalog_config + if data_catalog_config is None: + raise ValueError("Data catalog configuration is missing from offline store config") + + database_name = data_catalog_config.database + table_name = data_catalog_config.table_name + + if not database_name or not table_name: + raise ValueError("Database name and table name are required from data catalog config") + + # Execute Lake Formation setup with fail-fast behavior + results = { + "s3_registration": False, + "iam_principal_revoked": False, + "permissions_granted": False, + } + + # Phase 1: Register S3 with Lake Formation + try: + results["s3_registration"] = self._register_s3_with_lake_formation( + resolved_s3_uri, + session, + region, + use_service_linked_role=use_service_linked_role, + role_arn=registration_role_arn, + ) + except Exception as e: + raise RuntimeError( + f"Failed to register S3 location with Lake Formation. " + f"Subsequent phases skipped. Results: {results}. Error: {e}" + ) from e + + if not results["s3_registration"]: + raise RuntimeError( + f"Failed to register S3 location with Lake Formation. " + f"Subsequent phases skipped. Results: {results}" + ) + + # Phase 2: Grant Lake Formation permissions to the role + try: + results["permissions_granted"] = self._grant_lake_formation_permissions( + self.role_arn, database_name, table_name, session, region + ) + except Exception as e: + raise RuntimeError( + f"Failed to grant Lake Formation permissions. " + f"Subsequent phases skipped. Results: {results}. Error: {e}" + ) from e + + if not results["permissions_granted"]: + raise RuntimeError( + f"Failed to grant Lake Formation permissions. " + f"Subsequent phases skipped. Results: {results}" + ) + + # Phase 3: Revoke IAMAllowedPrincipal permissions + try: + results["iam_principal_revoked"] = self._revoke_iam_allowed_principal( + database_name, table_name, session, region + ) + except Exception as e: + raise RuntimeError( + f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}. Error: {e}" + ) from e + + if not results["iam_principal_revoked"]: + raise RuntimeError( + f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}" + ) + + logger.info(f"Lake Formation setup complete for {self.feature_group_name}: {results}") + return results + def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): @@ -13075,6 +13424,7 @@ def create( description: Optional[StrPipeVar] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), use_pre_prod_offline_store_replicator_lambda: Optional[bool] = Unassigned(), + enable_lake_formation: bool = False, session: Optional[Session] = None, region: Optional[StrPipeVar] = None, ) -> Optional["FeatureGroup"]: @@ -13093,6 +13443,7 @@ def create( description: A free-form description of a FeatureGroup. tags: Tags used to identify Features in each FeatureGroup. use_pre_prod_offline_store_replicator_lambda: + enable_lake_formation: If True, enables Lake Formation governance for the offline store. Requires offline_store_config and role_arn to be set. The role_arn will be extracted from the Feature Group configuration. Defaults to False. session: Boto3 session. region: Region name. @@ -13116,6 +13467,15 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + # Validation for Lake Formation + if enable_lake_formation: + if offline_store_config == Unassigned() or offline_store_config is None: + raise ValueError( + "enable_lake_formation=True requires offline_store_config to be configured" + ) + if role_arn == Unassigned() or role_arn is None: + raise ValueError("enable_lake_formation=True requires role_arn to be specified") + logger.info("Creating feature_group resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" @@ -13148,7 +13508,16 @@ def create( response = client.create_feature_group(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(feature_group_name=feature_group_name, session=session, region=region) + feature_group = cls.get( + feature_group_name=feature_group_name, session=session, region=region + ) + + # Enable Lake Formation if requested + if enable_lake_formation: + feature_group.wait_for_status(target_status="Created") + feature_group.enable_lake_formation(session=session, region=region) + + return feature_group @classmethod @Base.add_validate_call @@ -24630,6 +24999,7 @@ def get_all( list_method_kwargs=operation_input_args, ) + ''' class ModelInternal(Base): """ @@ -24757,6 +25127,8 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") ''' + + class ModelPackage(Base): """ Class representing resource ModelPackage @@ -32242,6 +32614,7 @@ def get_all( list_method_kwargs=operation_input_args, ) + ''' class ProcessingJobInternal(Base): """ @@ -32544,6 +32917,8 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") ''' + + class Project(Base): """ Class representing resource Project @@ -36028,6 +36403,7 @@ def get_all( list_method_kwargs=operation_input_args, ) + ''' class TrainingJobInternal(Base): """ @@ -36359,6 +36735,8 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") ''' + + class TrainingPlan(Base): """ Class representing resource TrainingPlan @@ -37380,6 +37758,7 @@ def get_all( list_method_kwargs=operation_input_args, ) + ''' class TransformJobInternal(Base): """ @@ -37632,6 +38011,8 @@ def stop(self) -> None: logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") ''' + + class Trial(Base): """ Class representing resource Trial diff --git a/sagemaker-core/tests/integ/feature_store/__init__.py b/sagemaker-core/tests/integ/feature_store/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-core/tests/integ/feature_store/test_lake_formation_integ.py b/sagemaker-core/tests/integ/feature_store/test_lake_formation_integ.py new file mode 100644 index 0000000000..1334c85028 --- /dev/null +++ b/sagemaker-core/tests/integ/feature_store/test_lake_formation_integ.py @@ -0,0 +1,454 @@ +""" +Integration tests for Lake Formation with FeatureGroup. + +These tests require: +- AWS credentials with Lake Formation and SageMaker permissions +- An S3 bucket for offline store (uses default SageMaker bucket) +- An IAM role for Feature Store (uses execution role) + +Run with: pytest tests/integ/feature_store/test_lake_formation_integ.py -v -m integ +""" + +import uuid + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.resources import FeatureGroup +from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + OnlineStoreConfig, + S3StorageConfig, +) + +feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + FeatureDefinition(feature_name="feature_value", feature_type="Fractional"), +] + + +@pytest.fixture(scope="module") +def sagemaker_session(): + return Session() + + +@pytest.fixture(scope="module") +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture(scope="module") +def s3_uri(sagemaker_session): + bucket = sagemaker_session.default_bucket() + return f"s3://{bucket}/feature-store-test" + + +@pytest.fixture(scope="module") +def region(): + return "us-west-2" + + +@pytest.fixture(scope="module") +def shared_feature_group_for_negative_tests(s3_uri, role, region): + """ + Create a single FeatureGroup for negative tests that only need to verify + error conditions without modifying the resource. + + This fixture is module-scoped to be created once and shared across tests, + reducing test execution time. + """ + fg_name = f"test-lf-negative-{uuid.uuid4().hex[:8]}" + fg = None + + try: + fg = create_test_feature_group(fg_name, s3_uri, role, region) + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + yield fg + finally: + if fg: + cleanup_feature_group(fg) + + +def generate_feature_group_name(): + """Generate a unique feature group name for testing.""" + return f"test-lf-fg-{uuid.uuid4().hex[:8]}" + + +def create_test_feature_group(name: str, s3_uri: str, role_arn: str, region: str) -> FeatureGroup: + """Create a FeatureGroup with offline store for testing.""" + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + fg = FeatureGroup.create( + feature_group_name=name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role_arn, + region=region, + ) + + return fg + + +def cleanup_feature_group(fg: FeatureGroup): + """ + Delete a FeatureGroup and its associated Glue table. + + Args: + fg: The FeatureGroup to delete. + """ + try: + # Delete the Glue table if it exists + if fg.offline_store_config is not None: + try: + fg.refresh() # Ensure we have latest config + data_catalog_config = fg.offline_store_config.data_catalog_config + if data_catalog_config is not None: + database_name = data_catalog_config.database + table_name = data_catalog_config.table_name + + if database_name and table_name: + glue_client = boto3.client("glue") + try: + glue_client.delete_table(DatabaseName=database_name, Name=table_name) + except ClientError as e: + # Ignore if table doesn't exist + if e.response["Error"]["Code"] != "EntityNotFoundException": + raise + except Exception: + # Don't fail cleanup if Glue table deletion fails + pass + + # Delete the FeatureGroup + fg.delete() + except ClientError: + # Don't fail cleanup if Glue table deletion fails + pass + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_create_feature_group_and_enable_lake_formation(s3_uri, role, region): + """ + Test creating a FeatureGroup and enabling Lake Formation governance. + + This test: + 1. Creates a new FeatureGroup with offline store + 2. Waits for it to reach Created status + 3. Enables Lake Formation governance (registers S3, grants permissions, revokes IAM principals) + 4. Cleans up the FeatureGroup + """ + + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Enable Lake Formation governance + result = fg.enable_lake_formation() + + # Verify all phases completed successfully + assert result["s3_registration"] is True + assert result["permissions_granted"] is True + assert result["iam_principal_revoked"] is True + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_create_feature_group_with_lake_formation_enabled(s3_uri, role, region): + """ + Test creating a FeatureGroup with enable_lake_formation=True. + + This test verifies the integrated workflow where Lake Formation is enabled + automatically during FeatureGroup creation: + 1. Creates a new FeatureGroup with enable_lake_formation=True + 2. Verifies the FeatureGroup is created and Lake Formation is configured + 3. Cleans up the FeatureGroup + """ + + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup with Lake Formation enabled + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role, + enable_lake_formation=True, + ) + + # Verify the FeatureGroup was created + assert fg is not None + assert fg.feature_group_name == fg_name + assert fg.feature_group_status == "Created" + + # Verify Lake Formation is configured by checking we can refresh without errors + fg.refresh() + assert fg.offline_store_config is not None + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +def test_create_feature_group_without_lake_formation(s3_uri, role, region): + """ + Test creating a FeatureGroup without Lake Formation enabled. + + This test verifies that when enable_lake_formation is False or not specified, + the FeatureGroup is created successfully without any Lake Formation operations: + 1. Creates a new FeatureGroup with enable_lake_formation=False (default) + 2. Verifies the FeatureGroup is created successfully + 3. Verifies no Lake Formation operations were performed + 4. Cleans up the FeatureGroup + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup without Lake Formation + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + # Create with enable_lake_formation=False (explicit) + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role, + enable_lake_formation=False, + ) + + # Verify the FeatureGroup was created + assert fg is not None + assert fg.feature_group_name == fg_name + + # Wait for Created status to ensure it's fully provisioned + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Verify offline store is configured + fg.refresh() + assert fg.offline_store_config is not None + assert fg.offline_store_config.s3_storage_config is not None + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +# ============================================================================ +# Negative Integration Tests +# ============================================================================ + + +def test_create_feature_group_with_lake_formation_fails_without_offline_store(role, region): + """ + Test that creating a FeatureGroup with enable_lake_formation=True fails + when no offline store is configured. + + Expected behavior: ValueError should be raised indicating offline store is required. + """ + fg_name = generate_feature_group_name() + + # Attempt to create without offline store but with Lake Formation enabled + with pytest.raises(ValueError) as exc_info: + FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + enable_lake_formation=True, + ) + + # Verify error message mentions offline_store_config requirement + assert "enable_lake_formation=True requires offline_store_config to be configured" in str( + exc_info.value + ) + + +def test_create_feature_group_with_lake_formation_fails_without_role(s3_uri, region): + """ + Test that creating a FeatureGroup with enable_lake_formation=True fails + when no role_arn is provided. + + Expected behavior: ValueError should be raised indicating role_arn is required. + """ + fg_name = generate_feature_group_name() + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + # Attempt to create without role_arn but with Lake Formation enabled + with pytest.raises(ValueError) as exc_info: + FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + enable_lake_formation=True, + ) + + # Verify error message mentions role_arn requirement + assert "enable_lake_formation=True requires role_arn to be specified" in str(exc_info.value) + + +def test_enable_lake_formation_fails_for_non_created_status(s3_uri, role, region): + """ + Test that enable_lake_formation() fails when called on a FeatureGroup + that is not in 'Created' status. + + Expected behavior: ValueError should be raised indicating the Feature Group + must be in 'Created' status. + + Note: This test creates its own FeatureGroup because it needs to test + behavior during the 'Creating' status, which requires a fresh resource. + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Immediately try to enable Lake Formation without waiting for Created status + # The Feature Group will be in 'Creating' status + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation(wait_for_active=False) + + # Verify error message mentions status requirement + error_msg = str(exc_info.value) + assert "must be in 'Created' status to enable Lake Formation" in error_msg + + finally: + # Cleanup + if fg: + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + cleanup_feature_group(fg) + + +def test_enable_lake_formation_without_offline_store(role, region): + """ + Test that enable_lake_formation() fails when called on a FeatureGroup + without an offline store configured. + + Expected behavior: ValueError should be raised indicating offline store is required. + + Note: This test creates a FeatureGroup with only online store, which is a valid + configuration, but Lake Formation cannot be enabled for it. + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create a FeatureGroup with only online store (no offline store) + online_store_config = OnlineStoreConfig(enable_online_store=True) + + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + online_store_config=online_store_config, + role_arn=role, + ) + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + + # Attempt to enable Lake Formation + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation() + # Verify error message mentions offline store requirement + assert "does not have an offline store configured" in str(exc_info.value) + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +def test_enable_lake_formation_fails_with_invalid_registration_role( + shared_feature_group_for_negative_tests, +): + """ + Test that enable_lake_formation() fails when use_service_linked_role=False + but no registration_role_arn is provided. + + Expected behavior: ValueError should be raised indicating registration_role_arn + is required when not using service-linked role. + """ + fg = shared_feature_group_for_negative_tests + + # Attempt to enable Lake Formation without service-linked role and without registration_role_arn + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=None, + ) + + # Verify error message mentions role requirement + error_msg = str(exc_info.value) + assert "registration_role_arn" in error_msg + + +def test_enable_lake_formation_fails_with_nonexistent_role( + shared_feature_group_for_negative_tests, role +): + """ + Test that enable_lake_formation() properly bubbles errors when using + a nonexistent role ARN for Lake Formation registration. + + Expected behavior: RuntimeError or ClientError should be raised with details + about the registration failure. + + Note: This test uses a nonexistent role ARN (current role with random suffix) + to trigger an error during S3 registration with Lake Formation. + """ + fg = shared_feature_group_for_negative_tests + + # Generate a nonexistent role ARN by appending a random string to the current role + nonexistent_role = f"{role}-nonexistent-{uuid.uuid4().hex[:8]}" + + with pytest.raises(RuntimeError) as exc_info: + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=nonexistent_role, + ) + + # Verify we got an appropriate error + error_msg = str(exc_info.value) + print(exc_info) + # Should mention role-related issues (not found, invalid, access denied, etc.) + assert "EntityNotFoundException" in error_msg diff --git a/sagemaker-core/tests/unit/test_feature_store_lakeformation.py b/sagemaker-core/tests/unit/test_feature_store_lakeformation.py new file mode 100644 index 0000000000..974ae2b820 --- /dev/null +++ b/sagemaker-core/tests/unit/test_feature_store_lakeformation.py @@ -0,0 +1,1018 @@ +"""Unit tests for Lake Formation integration with FeatureGroup.""" +from unittest.mock import MagicMock, patch + +import botocore.exceptions +import pytest + +from sagemaker.core.resources import FeatureGroup + + +class TestS3UriToArn: + """Tests for _s3_uri_to_arn static method.""" + + def test_converts_s3_uri_to_arn(self): + """Test S3 URI is converted to ARN format.""" + uri = "s3://my-bucket/my-prefix/data" + result = FeatureGroup._s3_uri_to_arn(uri) + assert result == "arn:aws:s3:::my-bucket/my-prefix/data" + + def test_handles_bucket_only_uri(self): + """Test S3 URI with bucket only.""" + uri = "s3://my-bucket" + result = FeatureGroup._s3_uri_to_arn(uri) + assert result == "arn:aws:s3:::my-bucket" + + def test_returns_arn_unchanged(self): + """Test ARN input is returned unchanged (idempotent).""" + arn = "arn:aws:s3:::my-bucket/path" + result = FeatureGroup._s3_uri_to_arn(arn) + assert result == arn + + +class TestGetLakeFormationClient: + """Tests for _get_lake_formation_client method.""" + + @patch("sagemaker.core.resources.Session") + def test_creates_client_with_default_session(self, mock_session_class): + """Test client creation with default session.""" + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + mock_session_class.return_value = mock_session + + fg = MagicMock(spec=FeatureGroup) + fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) + + client = fg._get_lake_formation_client(region="us-west-2") + + mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") + assert client == mock_client + + def test_creates_client_with_provided_session(self): + """Test client creation with provided session.""" + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + + fg = MagicMock(spec=FeatureGroup) + fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) + + client = fg._get_lake_formation_client(session=mock_session, region="us-west-2") + + mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") + assert client == mock_client + + +class TestRegisterS3WithLakeFormation: + """Tests for _register_s3_with_lake_formation method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + self.fg._register_s3_with_lake_formation = ( + FeatureGroup._register_s3_with_lake_formation.__get__(self.fg) + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_registration_returns_true(self): + """Test successful S3 registration returns True.""" + self.mock_client.register_resource.return_value = {} + + result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert result is True + self.mock_client.register_resource.assert_called_with( + ResourceArn="arn:aws:s3:::test-bucket/prefix", + UseServiceLinkedRole=True, + ) + + def test_already_exists_exception_returns_true(self): + """Test AlreadyExistsException is handled gracefully.""" + self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, + "RegisterResource", + ) + + result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-AlreadyExistsException errors are propagated.""" + self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "RegisterResource", + ) + + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_uses_service_linked_role(self): + """Test UseServiceLinkedRole is set to True.""" + self.mock_client.register_resource.return_value = {} + + self.fg._register_s3_with_lake_formation("s3://bucket/path") + + call_args = self.mock_client.register_resource.call_args + assert call_args[1]["UseServiceLinkedRole"] is True + + def test_uses_custom_role_arn_when_service_linked_role_disabled(self): + """Test custom role ARN is used when use_service_linked_role is False.""" + self.mock_client.register_resource.return_value = {} + custom_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" + + self.fg._register_s3_with_lake_formation( + "s3://bucket/path", + use_service_linked_role=False, + role_arn=custom_role, + ) + + call_args = self.mock_client.register_resource.call_args + assert call_args[1]["RoleArn"] == custom_role + assert "UseServiceLinkedRole" not in call_args[1] + + def test_raises_error_when_role_arn_missing_and_service_linked_role_disabled(self): + """Test ValueError when use_service_linked_role is False but role_arn not provided.""" + with pytest.raises(ValueError) as exc_info: + self.fg._register_s3_with_lake_formation( + "s3://bucket/path", use_service_linked_role=False + ) + + assert "role_arn must be provided when use_service_linked_role is False" in str( + exc_info.value + ) + + +class TestRevokeIamAllowedPrincipal: + """Tests for _revoke_iam_allowed_principal method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__( + self.fg + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_revocation_returns_true(self): + """Test successful revocation returns True.""" + self.mock_client.revoke_permissions.return_value = {} + + result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert result is True + self.mock_client.revoke_permissions.assert_called_once() + + def test_revoke_permissions_call_structure(self): + """Test that revoke_permissions is called with correct parameters.""" + self.mock_client.revoke_permissions.return_value = {} + database_name = "my_database" + table_name = "my_table" + + self.fg._revoke_iam_allowed_principal(database_name, table_name) + + call_args = self.mock_client.revoke_permissions.call_args + assert call_args[1]["Principal"] == { + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + } + assert call_args[1]["Permissions"] == ["ALL"] + assert call_args[1]["Resource"] == { + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + } + + def test_invalid_input_exception_returns_true(self): + """Test InvalidInputException is handled gracefully (permissions may not exist).""" + self.mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Permissions not found"}}, + "RevokePermissions", + ) + + result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-InvalidInputException errors are propagated.""" + self.mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "RevokePermissions", + ) + + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_passes_session_and_region_to_client(self): + """Test session and region are passed to get_lake_formation_client.""" + self.mock_client.revoke_permissions.return_value = {} + mock_session = MagicMock() + + self.fg._revoke_iam_allowed_principal( + "test_database", "test_table", session=mock_session, region="us-west-2" + ) + + self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") + + +class TestGrantLakeFormationPermissions: + """Tests for _grant_lake_formation_permissions method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(self.fg) + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_grant_returns_true(self): + """Test successful permission grant returns True.""" + self.mock_client.grant_permissions.return_value = {} + + result = self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert result is True + self.mock_client.grant_permissions.assert_called_once() + + def test_grant_permissions_call_structure(self): + """Test that grant_permissions is called with correct parameters.""" + self.mock_client.grant_permissions.return_value = {} + role_arn = "arn:aws:iam::123456789012:role/MyExecutionRole" + + self.fg._grant_lake_formation_permissions(role_arn, "my_database", "my_table") + + call_args = self.mock_client.grant_permissions.call_args + assert call_args[1]["Principal"] == {"DataLakePrincipalIdentifier": role_arn} + assert call_args[1]["Resource"] == { + "Table": { + "DatabaseName": "my_database", + "Name": "my_table", + } + } + assert call_args[1]["Permissions"] == ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] + assert call_args[1]["PermissionsWithGrantOption"] == [] + + def test_invalid_input_exception_returns_true(self): + """Test InvalidInputException is handled gracefully (permissions may exist).""" + self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Permissions already exist"}}, + "GrantPermissions", + ) + + result = self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-InvalidInputException errors are propagated.""" + self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "GrantPermissions", + ) + + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_passes_session_and_region_to_client(self): + """Test session and region are passed to get_lake_formation_client.""" + self.mock_client.grant_permissions.return_value = {} + mock_session = MagicMock() + + self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", + "test_database", + "test_table", + session=mock_session, + region="us-west-2", + ) + + self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") + + +class TestEnableLakeFormationValidation: + """Tests for enable_lake_formation validation logic.""" + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_no_offline_store(self, mock_refresh): + """Test that enable_lake_formation raises ValueError when no offline store is configured.""" + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = None + fg.feature_group_status = "Created" + + with pytest.raises(ValueError, match="does not have an offline store configured"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_no_role_arn(self, mock_refresh): + """Test that enable_lake_formation raises ValueError when no role_arn is configured.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = None + fg.feature_group_status = "Created" + + with pytest.raises(ValueError, match="does not have a role_arn configured"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_invalid_status(self, mock_refresh): + """Test enable_lake_formation raises ValueError when Feature Group not in Created status.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_status = "Creating" + + with pytest.raises(ValueError, match="must be in 'Created' status"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_wait_for_active_calls_wait_for_status( + self, mock_revoke, mock_grant, mock_register, mock_refresh, mock_wait + ): + """Test that wait_for_active=True calls wait_for_status with 'Created' target.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call with wait_for_active=True + fg.enable_lake_formation(wait_for_active=True) + + # Verify wait_for_status was called with "Created" + mock_wait.assert_called_once_with(target_status="Created") + # Verify refresh was called after wait + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_wait_for_active_false_does_not_call_wait( + self, mock_revoke, mock_grant, mock_register, mock_refresh, mock_wait + ): + """Test that wait_for_active=False does not call wait_for_status.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call with wait_for_active=False (default) + fg.enable_lake_formation(wait_for_active=False) + + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify refresh was still called + mock_refresh.assert_called_once() + + @pytest.mark.parametrize( + "feature_group_name,role_arn,s3_uri,database_name,table_name", + [ + ("test-fg", "TestRole", "path1", "db1", "table1"), + ("my_feature_group", "ExecutionRole", "data/features", "feature_db", "feature_table"), + ("fg123", "MyRole123", "ml/features/v1", "analytics", "features_v1"), + ("simple", "SimpleRole", "simple-path", "simple_db", "simple_table"), + ( + "complex-name", + "ComplexExecutionRole", + "complex/path/structure", + "complex_database", + "complex_table_name", + ), + ( + "underscore_name", + "Underscore_Role", + "underscore_path", + "underscore_db", + "underscore_table", + ), + ("mixed-123", "Mixed123Role", "mixed/path/123", "mixed_db_123", "mixed_table_123"), + ("x", "XRole", "x", "x", "x"), + ( + "very-long-name", + "VeryLongRoleName", + "very/long/path/structure", + "very_long_database_name", + "very_long_table_name", + ), + ], + ) + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_fail_fast_phase_execution( + self, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + feature_group_name, + role_arn, + s3_uri, + database_name, + table_name, + ): + """ + Test fail-fast behavior for Lake Formation phases. + + If Phase 1 (S3 registration) fails, Phase 2 and 3 should not execute. + If Phase 2 fails, Phase 3 should not execute. + RuntimeError should indicate which phase failed. + """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name=feature_group_name) + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri=f"s3://test-bucket/{s3_uri}", + resolved_output_s3_uri=f"s3://test-bucket/resolved-{s3_uri}", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database_name, table_name=table_name + ), + ) + fg.role_arn = f"arn:aws:iam::123456789012:role/{role_arn}" + fg.feature_group_status = "Created" + + # Test Phase 1 failure - subsequent phases should not be called + mock_register.side_effect = Exception("Phase 1 failed") + mock_grant.return_value = True + mock_revoke.return_value = True + + with pytest.raises( + RuntimeError, match="Failed to register S3 location with Lake Formation" + ): + fg.enable_lake_formation() + + # Verify Phase 1 was called but Phase 2 and 3 were not + mock_register.assert_called_once() + mock_grant.assert_not_called() + mock_revoke.assert_not_called() + + # Reset mocks for Phase 2 failure test + mock_register.reset_mock() + mock_grant.reset_mock() + mock_revoke.reset_mock() + + # Test Phase 2 failure - Phase 3 should not be called + mock_register.side_effect = None + mock_register.return_value = True + mock_grant.side_effect = Exception("Phase 2 failed") + mock_revoke.return_value = True + + with pytest.raises(RuntimeError, match="Failed to grant Lake Formation permissions"): + fg.enable_lake_formation() + + # Verify Phase 1 and 2 were called but Phase 3 was not + mock_register.assert_called_once() + mock_grant.assert_called_once() + mock_revoke.assert_not_called() + + # Reset mocks for Phase 3 failure test + mock_register.reset_mock() + mock_grant.reset_mock() + mock_revoke.reset_mock() + + # Test Phase 3 failure - all phases should be called + mock_register.side_effect = None + mock_register.return_value = True + mock_grant.side_effect = None + mock_grant.return_value = True + mock_revoke.side_effect = Exception("Phase 3 failed") + + with pytest.raises(RuntimeError, match="Failed to revoke IAMAllowedPrincipal permissions"): + fg.enable_lake_formation() + + # Verify all phases were called + mock_register.assert_called_once() + mock_grant.assert_called_once() + mock_revoke.assert_called_once() + + +class TestUnhandledExceptionPropagation: + """Tests for proper propagation of unhandled boto3 exceptions.""" + + def test_register_s3_propagates_unhandled_exceptions(self): + """ + Non-AlreadyExists Errors Propagate from S3 Registration + + For any error from Lake Formation's register_resource API that is not + AlreadyExistsException, the error should be propagated to the caller unchanged. + + """ + fg = MagicMock(spec=FeatureGroup) + fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( + fg + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "RegisterResource", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._register_s3_with_lake_formation("s3://test-bucket/path") + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "RegisterResource" + + def test_revoke_iam_principal_propagates_unhandled_exceptions(self): + """ + Non-InvalidInput Errors Propagate from IAM Principal Revocation + + For any error from Lake Formation's revoke_permissions API that is not + InvalidInputException, the error should be propagated to the caller unchanged. + + """ + fg = MagicMock(spec=FeatureGroup) + fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "RevokePermissions", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._revoke_iam_allowed_principal("test_database", "test_table") + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "RevokePermissions" + + def test_grant_permissions_propagates_unhandled_exceptions(self): + """ + Non-InvalidInput Errors Propagate from Permission Grant + + For any error from Lake Formation's grant_permissions API that is not + InvalidInputException, the error should be propagated to the caller unchanged. + + """ + fg = MagicMock(spec=FeatureGroup) + fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(fg) + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "GrantPermissions", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "GrantPermissions" + + def test_handled_exceptions_do_not_propagate(self): + """ + Verify that specifically handled exceptions (AlreadyExistsException, InvalidInputException) + do NOT propagate but return True instead, while all other exceptions are propagated. + """ + fg = MagicMock(spec=FeatureGroup) + fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( + fg + ) + fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) + fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(fg) + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Test AlreadyExistsException is handled (not propagated) + mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, + "RegisterResource", + ) + result = fg._register_s3_with_lake_formation("s3://test-bucket/path") + assert result is True # Should return True, not raise + + # Test InvalidInputException is handled for revoke (not propagated) + mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, + "RevokePermissions", + ) + result = fg._revoke_iam_allowed_principal("db", "table") + assert result is True # Should return True, not raise + + # Test InvalidInputException is handled for grant (not propagated) + mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, + "GrantPermissions", + ) + result = fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "db", "table" + ) + assert result is True # Should return True, not raise + + +class TestCreateWithLakeFormation: + """Tests for create() method with Lake Formation integration.""" + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature", + [ + ("test-fg", "record_id", "event_time"), + ("my_feature_group", "id", "timestamp"), + ("fg123", "identifier", "time"), + ("simple", "rec_id", "evt_time"), + ("complex-name", "record_identifier", "event_timestamp"), + ("underscore_name", "record_id_field", "event_time_field"), + ("mixed-123", "id_123", "time_123"), + ("x", "x_id", "x_time"), + ("very-long-name", "very_long_record_id", "very_long_event_time"), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroup, "get") + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "enable_lake_formation") + def test_no_lake_formation_operations_when_disabled( + self, + mock_enable_lf, + mock_wait, + mock_get, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + ): + """ + No Lake Formation Operations When Disabled + + For any call to FeatureGroup.create() where enable_lake_formation is False or not specified, + no Lake Formation client methods should be invoked. + + """ + from sagemaker.core.shapes import FeatureDefinition + + # Mock the SageMaker client + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + # Mock the get method to return a feature group + mock_fg = MagicMock(spec=FeatureGroup) + mock_fg.feature_group_name = feature_group_name + mock_get.return_value = mock_fg + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Test 1: enable_lake_formation=False (explicit) + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + enable_lake_formation=False, + ) + + # Verify enable_lake_formation was NOT called + mock_enable_lf.assert_not_called() + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify the feature group was returned + assert result == mock_fg + + # Reset mocks for next test + mock_enable_lf.reset_mock() + mock_wait.reset_mock() + mock_get.reset_mock() + mock_get.return_value = mock_fg + + # Test 2: enable_lake_formation not specified (defaults to False) + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + # enable_lake_formation not specified, should default to False + ) + + # Verify enable_lake_formation was NOT called + mock_enable_lf.assert_not_called() + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify the feature group was returned + assert result == mock_fg + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature,role_arn,s3_uri,database,table", + [ + ("test-fg", "record_id", "event_time", "TestRole", "path1", "db1", "table1"), + ( + "my_feature_group", + "id", + "timestamp", + "ExecutionRole", + "data/features", + "feature_db", + "feature_table", + ), + ( + "fg123", + "identifier", + "time", + "MyRole123", + "ml/features/v1", + "analytics", + "features_v1", + ), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroup, "get") + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "enable_lake_formation") + def test_enable_lake_formation_called_when_enabled( + self, + mock_enable_lf, + mock_wait, + mock_get, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + role_arn, + s3_uri, + database, + table, + ): + """ + Test that enable_lake_formation is called when enable_lake_formation=True. + + This verifies the integration between create() and enable_lake_formation(). + """ + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + # Mock the SageMaker client + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + # Mock the get method to return a feature group + mock_fg = MagicMock(spec=FeatureGroup) + mock_fg.feature_group_name = feature_group_name + mock_fg.wait_for_status = mock_wait + mock_fg.enable_lake_formation = mock_enable_lf + mock_get.return_value = mock_fg + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create offline store config + offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database, table_name=table + ), + ) + + # Create with enable_lake_formation=True + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=f"arn:aws:iam::123456789012:role/{role_arn}", + enable_lake_formation=True, + ) + + # Verify wait_for_status was called with "Created" + mock_wait.assert_called_once_with(target_status="Created") + # Verify enable_lake_formation was called + mock_enable_lf.assert_called_once() + # Verify the feature group was returned + assert result == mock_fg + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature", + [ + ("test-fg", "record_id", "event_time"), + ("my_feature_group", "id", "timestamp"), + ("fg123", "identifier", "time"), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_enable_lake_formation_without_offline_store( + self, mock_get_client, feature_group_name, record_id_feature, event_time_feature + ): + """Test create() raises ValueError when enable_lake_formation=True without offline_store.""" + from sagemaker.core.shapes import FeatureDefinition + + # Mock the SageMaker client + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Test with enable_lake_formation=True but no offline_store_config + with pytest.raises( + ValueError, + match="enable_lake_formation=True requires offline_store_config to be configured", + ): + FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + enable_lake_formation=True, + # offline_store_config not provided + ) + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature,s3_uri,database,table", + [ + ("test-fg", "record_id", "event_time", "path1", "db1", "table1"), + ("my_feature_group", "id", "timestamp", "data/features", "feature_db", "feature_table"), + ("fg123", "identifier", "time", "ml/features/v1", "analytics", "features_v1"), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_enable_lake_formation_without_role_arn( + self, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + s3_uri, + database, + table, + ): + """Test create() raises ValueError when enable_lake_formation=True without role_arn.""" + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + # Mock the SageMaker client + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create offline store config + offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database, table_name=table + ), + ) + + # Test with enable_lake_formation=True but no role_arn + with pytest.raises( + ValueError, match="enable_lake_formation=True requires role_arn to be specified" + ): + FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + enable_lake_formation=True, + # role_arn not provided + )