diff --git a/sagemaker-core/src/sagemaker/core/__init__.py b/sagemaker-core/src/sagemaker/core/__init__.py index 27dd2e0d72..97192083a7 100644 --- a/sagemaker-core/src/sagemaker/core/__init__.py +++ b/sagemaker-core/src/sagemaker/core/__init__.py @@ -12,5 +12,8 @@ ) from sagemaker.core.transformer import Transformer # noqa: F401 +# Partner App +from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 + # Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner # They are not re-exported from core to avoid circular dependencies diff --git a/sagemaker-core/src/sagemaker/core/partner_app/__init__.py b/sagemaker-core/src/sagemaker/core/partner_app/__init__.py new file mode 100644 index 0000000000..87ab21acec --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/partner_app/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""__init__ file for sagemaker.core.partner_app""" +from __future__ import absolute_import + +from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 diff --git a/sagemaker-core/src/sagemaker/core/partner_app/auth_provider.py b/sagemaker-core/src/sagemaker/core/partner_app/auth_provider.py new file mode 100644 index 0000000000..7abdb71e0b --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/partner_app/auth_provider.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""The SageMaker partner application SDK auth module""" +from __future__ import absolute_import + +import os +import re +from typing import Dict, Tuple + +import boto3 +from botocore.auth import SigV4Auth +from botocore.credentials import Credentials +from requests.auth import AuthBase +from requests.models import PreparedRequest +from sagemaker.core.partner_app.auth_utils import PartnerAppAuthUtils + +SERVICE_NAME = "sagemaker" +AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*" + + +class RequestsAuth(AuthBase): + """Requests authentication class for SigV4 header generation. + + This class is used to generate the SigV4 header and add it to the request headers. + """ + + def __init__(self, sigv4: SigV4Auth, app_arn: str): + """Initialize the RequestsAuth class. + + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + """ + self.sigv4 = sigv4 + self.app_arn = app_arn + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + """Callback function to generate the SigV4 header and add it to the request headers. + + Args: + request (PreparedRequest): PreparedRequest object + + Returns: + PreparedRequest: PreparedRequest object with the SigV4 header added + """ + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=request.url, + method=request.method, + headers=request.headers, + body=request.body, + ) + request.url = url + request.headers.update(signed_headers) + + return request + + +class PartnerAppAuthProvider: + """The SageMaker partner application SDK auth provider class""" + + def __init__(self, credentials: Credentials = None): + """Initialize the PartnerAppAuthProvider class. + + Args: + credentials (Credentials, optional): AWS credentials. Defaults to None. + Raises: + ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid. + """ + self.app_arn = os.getenv("AWS_PARTNER_APP_ARN") + if self.app_arn is None: + raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable") + + app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn) + if app_arn_regex_match is None: + raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable") + + split_arn = self.app_arn.split(":") + self.region = split_arn[3] + + self.credentials = ( + credentials if credentials is not None else boto3.Session().get_credentials() + ) + self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region) + + def get_signed_request( + self, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. + + Args: + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + + Returns: + tuple: (url, headers) + """ + return PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=url, + method=method, + headers=headers, + body=body, + ) + + def get_auth(self) -> RequestsAuth: + """Returns the callback class (RequestsAuth) used for generating the SigV4 header. + + Returns: + RequestsAuth: Callback Object which will calculate the header just before + request submission. + """ + + return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"]) diff --git a/sagemaker-core/src/sagemaker/core/partner_app/auth_utils.py b/sagemaker-core/src/sagemaker/core/partner_app/auth_utils.py new file mode 100644 index 0000000000..eb1dcacaa9 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/partner_app/auth_utils.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Partner App Auth Utils Module""" + +from __future__ import absolute_import + +from hashlib import sha256 +import functools +from typing import Tuple, Dict + +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest + +HEADER_CONNECTION = "Connection" +HEADER_X_AMZ_TARGET = "X-Amz-Target" +HEADER_AUTHORIZATION = "Authorization" +HEADER_PARTNER_APP_SERVER_ARN = "X-SageMaker-Partner-App-Server-Arn" +HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization" +HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256" +CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi" + +PAYLOAD_BUFFER = 1024 * 1024 +EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" + + +class PartnerAppAuthUtils: + """Partner App Auth Utils Class""" + + @staticmethod + def get_signed_request( + sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. + + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + Returns: + tuple: (url, headers) + """ + # Move API key to X-Amz-Partner-App-Authorization + if HEADER_AUTHORIZATION in headers: + headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION] + + # App Arn + headers[HEADER_PARTNER_APP_SERVER_ARN] = app_arn + + # IAM Action + headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION + + # Body + headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body) + + # Connection header is excluded from server-side signature calculation + connection_header = headers[HEADER_CONNECTION] if HEADER_CONNECTION in headers else None + + if HEADER_CONNECTION in headers: + del headers[HEADER_CONNECTION] + + # Spaces are encoded as %20 + url = url.replace("+", "%20") + + # Calculate SigV4 header + aws_request = AWSRequest( + method=method, + url=url, + headers=headers, + data=body, + ) + sigv4.add_auth(aws_request) + + # Reassemble headers + final_headers = dict(aws_request.headers.items()) + if connection_header is not None: + final_headers[HEADER_CONNECTION] = connection_header + + return (url, final_headers) + + @staticmethod + def get_body_header(body: object): + """Calculate the body header for the SigV4 header. + + Args: + body (object): Request body + """ + if body and hasattr(body, "seek"): + position = body.tell() + read_chunksize = functools.partial(body.read, PAYLOAD_BUFFER) + checksum = sha256() + for chunk in iter(read_chunksize, b""): + checksum.update(chunk) + hex_checksum = checksum.hexdigest() + body.seek(position) + return hex_checksum + + if body and not isinstance(body, bytes): + # Body is of a class we don't recognize, so don't sign the payload + return UNSIGNED_PAYLOAD + + if body: + # The request serialization has ensured that + # request.body is a bytes() type. + return sha256(body).hexdigest() + + # Body is None + return EMPTY_SHA256_HASH diff --git a/sagemaker-core/tests/unit/sagemaker/core/partner_app/__init__.py b/sagemaker-core/tests/unit/sagemaker/core/partner_app/__init__.py new file mode 100644 index 0000000000..a6987bc6a6 --- /dev/null +++ b/sagemaker-core/tests/unit/sagemaker/core/partner_app/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_provider.py b/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_provider.py new file mode 100644 index 0000000000..b7b512948c --- /dev/null +++ b/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_provider.py @@ -0,0 +1,152 @@ +from __future__ import absolute_import + +import os +import unittest +from unittest.mock import patch, MagicMock +from requests import PreparedRequest +from sagemaker.core.partner_app.auth_provider import RequestsAuth, PartnerAppAuthProvider + + +class TestRequestsAuth(unittest.TestCase): + + @patch("sagemaker.core.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + @patch("sagemaker.core.partner_app.auth_provider.SigV4Auth") + def test_requests_auth_call(self, mock_sigv4_auth, mock_get_signed_request): + # Prepare mock data + mock_signed_url = "https://returned-url.test.com/" + mock_signed_headers = {"Authorization": "SigV4", "x-amz-date": "20241016T120000Z"} + mock_get_signed_request.return_value = (mock_signed_url, mock_signed_headers) + + # Create the objects needed for testing + app_arn = "arn:aws:lambda:us-west-2:123456789012:sagemaker:test" + under_test = RequestsAuth(sigv4=mock_sigv4_auth, app_arn=app_arn) + + # Create a prepared request object to simulate an actual request + request = PreparedRequest() + request.method = "GET" + request_url = "https://test.com" + request.url = request_url + request_headers = {} + request.headers = request_headers + request.body = "{}" + + # Call the method under test + updated_request = under_test(request) + + # Assertions to verify the behavior + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4_auth, + app_arn=app_arn, + url=request_url, + method="GET", + headers=request_headers, + body=request.body, + ) + + self.assertEqual(updated_request.url, mock_signed_url) + self.assertIn("Authorization", updated_request.headers) + self.assertIn("x-amz-date", updated_request.headers) + self.assertEqual(updated_request.headers["Authorization"], "SigV4") + self.assertEqual(updated_request.headers["x-amz-date"], "20241016T120000Z") + + +class TestPartnerAppAuthProvider(unittest.TestCase): + + @patch("sagemaker.core.partner_app.auth_provider.boto3.Session") + @patch("sagemaker.core.partner_app.auth_provider.SigV4Auth") + @patch("sagemaker.core.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + def test_get_signed_request( + self, mock_get_signed_request, mock_sigv4auth_class, mock_boto3_session + ): + # Set up environment variable + test_app_arn = "arn:aws-us-gov:sagemaker:us-west-2:123456789012:partner-app/my-app" + os.environ["AWS_PARTNER_APP_ARN"] = test_app_arn + + # Mock the return value of boto3.Session().get_credentials() + mock_credentials = MagicMock() + mock_boto3_session.return_value.get_credentials.return_value = mock_credentials + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Mock return value for get_signed_request + mock_get_signed_request.return_value = { + "Authorization": "SigV4", + "x-amz-date": "20241016T120000Z", + } + + # Call get_signed_request method + signed_request = provider.get_signed_request( + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert that the get_signed_request method was called with correct parameters + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4auth_instance, + app_arn=test_app_arn, + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert the response matches the mocked return value + self.assertEqual(signed_request["Authorization"], "SigV4") + self.assertEqual(signed_request["x-amz-date"], "20241016T120000Z") + + @patch("sagemaker.core.partner_app.auth_provider.SigV4Auth") + def test_get_auth(self, mock_sigv4auth_class): + # Set up environment variable + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:sagemaker:us-west-2:123456789012:partner-app/app-abc" + ) + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Call get_auth method + auth_instance = provider.get_auth() + + # Assert that the returned object is a RequestsAuth instance + self.assertIsInstance(auth_instance, RequestsAuth) + + # Assert that RequestsAuth was initialized with correct arguments + self.assertEqual(auth_instance.sigv4, mock_sigv4auth_instance) + self.assertEqual(auth_instance.app_arn, os.environ["AWS_PARTNER_APP_ARN"]) + + def test_init_raises_value_error_with_missing_app_arn(self): + # Remove the environment variable + if "AWS_PARTNER_APP_ARN" in os.environ: + del os.environ["AWS_PARTNER_APP_ARN"] + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is not set + with self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify the AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) + + def test_init_raises_value_error_with_invalid_app_arn(self): + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:lambda:us-west-2:123456789012:function:my-function" + ) + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is not set + with self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify a valid AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) diff --git a/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_utils.py b/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_utils.py new file mode 100644 index 0000000000..75bf7cf64c --- /dev/null +++ b/sagemaker-core/tests/unit/sagemaker/core/partner_app/test_auth_utils.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import + +import unittest +from unittest.mock import Mock, patch +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from hashlib import sha256 + +from sagemaker.core.partner_app.auth_utils import ( + PartnerAppAuthUtils, + EMPTY_SHA256_HASH, + UNSIGNED_PAYLOAD, +) + + +class TestPartnerAppAuthUtils(unittest.TestCase): + def setUp(self): + self.sigv4_mock = Mock(spec=SigV4Auth) + self.app_arn = "arn:aws:sagemaker:us-west-2:123456789012:partner-app/abc123" + self.url = "https://partner-app-abc123.us-west-2.amazonaws.com?fileName=Jupyter+interactive" + self.method = "POST" + self.headers = {"Authorization": "API_KEY", "Connection": "conn"} + self.body = b'{"key": "value"}' # Byte type body for hashing + + @patch("sagemaker.core.partner_app.auth_utils.AWSRequest") + def test_get_signed_request_with_body(self, AWSRequestMock): + aws_request_mock = Mock(spec=AWSRequest) + AWSRequestMock.return_value = aws_request_mock + + expected_hash = sha256(self.body).hexdigest() + # Authorization still has the original value as the sigv4 mock does not add this header + expected_sign_headers = { + "Authorization": "API_KEY", + "X-Amz-Partner-App-Authorization": "API_KEY", + "X-SageMaker-Partner-App-Server-Arn": self.app_arn, + "X-Amz-Target": "SageMaker.CallPartnerAppApi", + "X-Amz-Content-SHA256": expected_hash, + } + aws_request_mock.headers = expected_sign_headers + + # Mock the add_auth method on the SigV4Auth + self.sigv4_mock.add_auth = Mock() + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, self.body + ) + + # Assert X-SageMaker-Partner-App-Server-Arn header is correct + self.assertEqual(signed_headers["X-SageMaker-Partner-App-Server-Arn"], self.app_arn) + + # Assert the Authorization header was moved to X-Amz-Partner-App-Authorization + self.assertIn("X-Amz-Partner-App-Authorization", signed_headers) + + # Assert X-Amz-Content-SHA256 is set + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], expected_hash) + + # Assert the Connection header is reserved + self.assertEqual(signed_headers["Connection"], "conn") + + expected_canonical_url = self.url.replace("+", "%20") + # Assert AWSRequestMock was called + AWSRequestMock.assert_called_once_with( + method=self.method, + url=expected_canonical_url, + headers=expected_sign_headers, + data=self.body, + ) + + def test_get_signed_request_with_no_body(self): + body = None + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Assert X-Amz-Content-SHA256 is EMPTY_SHA256_HASH + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], EMPTY_SHA256_HASH) + + def test_get_signed_request_with_bytes_body(self): + body = Mock() + body.seek = Mock() + body.tell = Mock(return_value=0) + body.read = Mock(side_effect=[b"test", b""]) + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Verify the seek method was called + body.seek.assert_called() + + # Calculate the expected checksum for the body + checksum = sha256(b"test").hexdigest() + + # Assert X-Amz-Content-SHA256 is the calculated checksum + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], checksum) + + def test_get_body_header_unsigned_payload(self): + body = {"key": "value"} + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is UNSIGNED_PAYLOAD for unrecognized body type + self.assertEqual(result, UNSIGNED_PAYLOAD) + + def test_get_body_header_empty_body(self): + body = None + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is EMPTY_SHA256_HASH for empty body + self.assertEqual(result, EMPTY_SHA256_HASH)