From 0aed8fa2e0771720e22812ec8fbc85e56be42180 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Mon, 15 Dec 2025 19:08:29 +0000 Subject: [PATCH 1/6] add evaluator tagging for jumpstart models --- .../train/evaluate/base_evaluator.py | 19 ++++++++++++++++-- .../src/sagemaker/train/evaluate/execution.py | 20 +++++++++++++------ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 620b7ffe34..63e1aaf76a 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -9,10 +9,11 @@ import logging import re -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from pydantic import BaseModel, validator +from sagemaker.core.common_utils import TagsDict from sagemaker.core.resources import ModelPackageGroup from sagemaker.core.shapes import VpcConfig @@ -411,6 +412,13 @@ def _source_model_package_arn(self) -> Optional[str]: """Get the resolved source model package ARN (None for JumpStart models).""" info = self._get_resolved_model_info() return info.source_model_package_arn if info else None + + @property + def _is_jumpstart_model(self) -> bool: + """Determine if model is a JumpStart model""" + from sagemaker.train.common_utils.model_resolution import _ModelType + info = self._get_resolved_model_info() + return info.model_type == _ModelType.JUMPSTART def _infer_model_package_group_arn(self) -> Optional[str]: """Infer model package group ARN from source model package ARN. @@ -795,6 +803,12 @@ def _start_execution( EvaluationPipelineExecution: Started execution object """ from .execution import EvaluationPipelineExecution + + tags: List[TagsDict] = [] + + if self._is_jumpstart_model: + from sagemaker.core.jumpstart.utils import add_jumpstart_model_info_tags + tags = add_jumpstart_model_info_tags(tags, self.model, "*") execution = EvaluationPipelineExecution.start( eval_type=eval_type, @@ -803,7 +817,8 @@ def _start_execution( role_arn=role_arn, s3_output_path=self.s3_output_path, session=self.sagemaker_session.boto_session if hasattr(self.sagemaker_session, 'boto_session') else None, - region=region + region=region, + tags=tags ) return execution diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index a3fd88397d..6a1322526b 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -16,6 +16,7 @@ # Third-party imports from botocore.exceptions import ClientError from pydantic import BaseModel, Field +from sagemaker.core.common_utils import TagsDict from sagemaker.core.helper.session_helper import Session from sagemaker.core.resources import Pipeline, PipelineExecution, Tag from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter @@ -38,6 +39,7 @@ def _create_evaluation_pipeline( pipeline_definition: str, session: Optional[Any] = None, region: Optional[str] = None, + tags: Optional[List[TagsDict]] = [], ) -> Any: """Helper method to create a SageMaker pipeline for evaluation. @@ -49,6 +51,7 @@ def _create_evaluation_pipeline( pipeline_definition (str): JSON pipeline definition (Jinja2 template). session (Optional[Any]): SageMaker session object. region (Optional[str]): AWS region. 
+ tags (Optional[List[TagsDict]]): List of tags to include in pipeline Returns: Any: Created Pipeline instance (ready for execution). @@ -65,9 +68,9 @@ def _create_evaluation_pipeline( resolved_pipeline_definition = template.render(pipeline_name=pipeline_name) # Create tags for the pipeline - tags = [ + tags = tags.extend([ {"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"} - ] + ]) pipeline = Pipeline.create( pipeline_name=pipeline_name, @@ -163,7 +166,8 @@ def _get_or_create_pipeline( pipeline_definition: str, role_arn: str, session: Optional[Session] = None, - region: Optional[str] = None + region: Optional[str] = None, + tags: Optional[List[TagsDict]] = [], ) -> Pipeline: """Get existing pipeline or create/update it. @@ -177,6 +181,7 @@ def _get_or_create_pipeline( role_arn: IAM role ARN for pipeline execution session: Boto3 session (optional) region: AWS region (optional) + tags (Optional[List[TagsDict]]): List of tags to include in pipeline Returns: Pipeline instance (existing updated or newly created) @@ -202,7 +207,7 @@ def _get_or_create_pipeline( # Get tags using Tag.get_all tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region) - tags = {tag.key: tag.value for tag in tags_list} + tags = tags.extend({tag.key: tag.value for tag in tags_list}) # Validate tag if tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) == "true": @@ -505,7 +510,8 @@ def start( role_arn: str, s3_output_path: Optional[str] = None, session: Optional[Session] = None, - region: Optional[str] = None + region: Optional[str] = None, + tags: Optional[List[TagsDict]] = [], ) -> 'EvaluationPipelineExecution': """Create sagemaker pipeline execution. Optionally creates pipeline. @@ -517,6 +523,7 @@ def start( s3_output_path (Optional[str]): S3 location where evaluation results are stored. session (Optional[Session]): Boto3 session for API calls. region (Optional[str]): AWS region for the pipeline. + tags (Optional[List[TagsDict]]): List of tags to include in pipeline Returns: EvaluationPipelineExecution: Started pipeline execution instance. 
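A minimal sketch of the tag payload this new plumbing is meant to carry through start() (illustrative only: the dict shape mirrors the {"key": ..., "value": ...} entries used in this file, the jumpstart-model-id key is the one asserted by the unit test added later in this series, and add_jumpstart_model_info_tags may attach further keys):

from typing import Dict, List

TagsDict = Dict[str, str]  # assumed shape of one tag entry: {"key": ..., "value": ...}

def build_jumpstart_tags(model_id: str) -> List[TagsDict]:
    """Hypothetical stand-in for the tag list built in _start_execution() for a JumpStart model."""
    tags: List[TagsDict] = []
    # add_jumpstart_model_info_tags() is expected to append entries like this one;
    # the key below is the one checked by the unit test in this patch series.
    tags.append({"key": "sagemaker-sdk:jumpstart-model-id", "value": model_id})
    return tags

# build_jumpstart_tags("dummy-js-model-id")
# -> [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}]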
@@ -547,7 +554,8 @@ def start( pipeline_definition=pipeline_definition, role_arn=role_arn, session=session, - region=region + region=region, + tags=tags, ) # Start pipeline execution via boto3 From 973095b249f0102e7965b4981bc18d1abb006296 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Mon, 15 Dec 2025 20:16:11 +0000 Subject: [PATCH 2/6] fix bug for extending tags --- sagemaker-train/src/sagemaker/train/evaluate/execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index 6a1322526b..cea7c52466 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -68,7 +68,7 @@ def _create_evaluation_pipeline( resolved_pipeline_definition = template.render(pipeline_name=pipeline_name) # Create tags for the pipeline - tags = tags.extend([ + tags.extend([ {"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"} ]) @@ -207,7 +207,7 @@ def _get_or_create_pipeline( # Get tags using Tag.get_all tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region) - tags = tags.extend({tag.key: tag.value for tag in tags_list}) + tags.extend({tag.key: tag.value for tag in tags_list}) # Validate tag if tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) == "true": From fd1cc3f353d56b2fb198c060b81af6bc3be4cc01 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Mon, 15 Dec 2025 20:57:20 +0000 Subject: [PATCH 3/6] bug fix for js tags --- .../src/sagemaker/train/evaluate/execution.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index cea7c52466..3d217b08cf 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -167,7 +167,7 @@ def _get_or_create_pipeline( role_arn: str, session: Optional[Session] = None, region: Optional[str] = None, - tags: Optional[List[TagsDict]] = [], + create_tags: Optional[List[TagsDict]] = [], ) -> Pipeline: """Get existing pipeline or create/update it. 
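For context on the fix in PATCH 2/6 above: list.extend() mutates the list in place and returns None, so rebinding the name drops the freshly extended list. A two-line check (placeholder tag values, not real tag keys):

tags = [{"key": "a", "value": "1"}]
result = tags.extend([{"key": "b", "value": "2"}])
assert result is None                                                     # extend() has no return value
assert tags == [{"key": "a", "value": "1"}, {"key": "b", "value": "2"}]   # extended in place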
@@ -181,7 +181,7 @@ def _get_or_create_pipeline( role_arn: IAM role ARN for pipeline execution session: Boto3 session (optional) region: AWS region (optional) - tags (Optional[List[TagsDict]]): List of tags to include in pipeline + create_tags (Optional[List[TagsDict]]): List of tags to include in pipeline Returns: Pipeline instance (existing updated or newly created) @@ -207,7 +207,7 @@ def _get_or_create_pipeline( # Get tags using Tag.get_all tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region) - tags.extend({tag.key: tag.value for tag in tags_list}) + tags = {tag.key: tag.value for tag in tags_list} # Validate tag if tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) == "true": @@ -230,19 +230,19 @@ def _get_or_create_pipeline( # No matching pipeline found, create new one logger.info(f"No existing pipeline found with prefix {pipeline_name_prefix}, creating new one") - return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region) + return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags) except ClientError as e: error_code = e.response['Error']['Code'] if "ResourceNotFound" in error_code: - return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region) + return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags) else: raise except Exception as e: # If search fails for other reasons, try to create logger.info(f"Error searching for pipeline ({str(e)}), attempting to create new pipeline") - return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region) + return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags) def _start_pipeline_execution( @@ -555,7 +555,7 @@ def start( role_arn=role_arn, session=session, region=region, - tags=tags, + create_tags=tags, ) # Start pipeline execution via boto3 From bfcfae0b9c270aaf966b04193f6f7916070e7d5b Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Mon, 15 Dec 2025 21:25:30 +0000 Subject: [PATCH 4/6] add unit test for js evaluator tagging --- .../unit/train/evaluate/test_execution.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/sagemaker-train/tests/unit/train/evaluate/test_execution.py b/sagemaker-train/tests/unit/train/evaluate/test_execution.py index 8b3402c1ee..a01d500ff5 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_execution.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_execution.py @@ -299,7 +299,37 @@ def test_create_pipeline_when_not_found(self, mock_pipeline_class, mock_create, DEFAULT_ROLE, DEFAULT_PIPELINE_DEFINITION, mock_session, - DEFAULT_REGION + DEFAULT_REGION, + [] + ) + assert result == mock_pipeline + + @patch("sagemaker.train.evaluate.execution._create_evaluation_pipeline") + @patch("sagemaker.train.evaluate.execution.Pipeline") + def test_create_pipeline_when_not_found_with_jumpstart_tags(self, mock_pipeline_class, mock_create, mock_session): + """Test creating pipeline when it doesn't exist.""" + error_response = {"Error": {"Code": "ResourceNotFound"}} + mock_pipeline_class.get.side_effect = ClientError(error_response, "DescribePipeline") + mock_pipeline = MagicMock() + mock_create.return_value = mock_pipeline + create_tags = [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}] + + result = _get_or_create_pipeline( + eval_type=EvalType.BENCHMARK, + 
pipeline_definition=DEFAULT_PIPELINE_DEFINITION, + role_arn=DEFAULT_ROLE, + session=mock_session, + region=DEFAULT_REGION, + create_tags=create_tags + ) + + mock_create.assert_called_once_with( + EvalType.BENCHMARK, + DEFAULT_ROLE, + DEFAULT_PIPELINE_DEFINITION, + mock_session, + DEFAULT_REGION, + create_tags ) assert result == mock_pipeline From 01c3c2c497be439815fde7ea6e5ca46840506d89 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Mon, 19 Jan 2026 13:58:32 -0500 Subject: [PATCH 5/6] chore: upgrade search in jumpstart code --- .../src/sagemaker/core/jumpstart/search.py | 185 ++++++++++++++---- 1 file changed, 147 insertions(+), 38 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/search.py b/sagemaker-core/src/sagemaker/core/jumpstart/search.py index 9581897983..9da8adb4eb 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/search.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/search.py @@ -1,12 +1,84 @@ import re import logging -from typing import List, Iterator, Optional +from typing import List, Iterator, Optional, Union, Any from sagemaker.core.helper.session_helper import Session from sagemaker.core.resources import HubContent logger = logging.getLogger(__name__) +class _ExpressionNode: + """Base class for expression AST nodes.""" + + def evaluate(self, keywords: List[str]) -> bool: + """Evaluate this node against the given keywords.""" + raise NotImplementedError + + +class _AndNode(_ExpressionNode): + """AND logical operator node.""" + + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): + self.left = left + self.right = right + + def evaluate(self, keywords: List[str]) -> bool: + return self.left.evaluate(keywords) and self.right.evaluate(keywords) + + +class _OrNode(_ExpressionNode): + """OR logical operator node.""" + + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): + self.left = left + self.right = right + + def evaluate(self, keywords: List[str]) -> bool: + return self.left.evaluate(keywords) or self.right.evaluate(keywords) + + +class _NotNode(_ExpressionNode): + """NOT logical operator node.""" + + def __init__(self, operand: _ExpressionNode): + self.operand = operand + + def evaluate(self, keywords: List[str]) -> bool: + return not self.operand.evaluate(keywords) + + +class _PatternNode(_ExpressionNode): + """Pattern matching node for keywords with wildcard support.""" + + def __init__(self, pattern: str): + self.pattern = pattern.strip('"').strip("'") + + def evaluate(self, keywords: List[str]) -> bool: + """Check if any keyword matches this pattern.""" + for keyword in keywords: + if self._matches_pattern(keyword, self.pattern): + return True + return False + + def _matches_pattern(self, keyword: str, pattern: str) -> bool: + """Check if a keyword matches a pattern with wildcard support.""" + if pattern.startswith("*") and pattern.endswith("*"): + # Contains pattern: *text* + stripped = pattern.strip("*") + return stripped in keyword + elif pattern.startswith("*"): + # Ends with pattern: *text + stripped = pattern[1:] + return keyword.endswith(stripped) + elif pattern.endswith("*"): + # Starts with pattern: text* + stripped = pattern[:-1] + return keyword.startswith(stripped) + else: + # Exact match + return keyword == pattern + + class _Filter: """ A filter that evaluates logical expressions against a list of keyword strings. @@ -28,6 +100,7 @@ def __init__(self, expression: str) -> None: Supports AND, OR, NOT, parentheses, and wildcard patterns (*). 
""" self.expression: str = expression + self._ast: Optional[_ExpressionNode] = None def match(self, keywords: List[str]) -> bool: """ @@ -39,54 +112,90 @@ def match(self, keywords: List[str]) -> bool: Returns: bool: True if the expression evaluates to True for the given keywords, else False. """ - expr: str = self._convert_expression(self.expression) try: - return eval(expr, {"__builtins__": {}}, {"keywords": keywords, "any": any}) + if self._ast is None: + self._ast = self._parse_expression(self.expression) + return self._ast.evaluate(keywords) except Exception: return False - def _convert_expression(self, expr: str) -> str: + def _parse_expression(self, expr: str) -> _ExpressionNode: """ - Convert the logical filter expression into a Python-evaluable string. + Parse the logical filter expression into an AST. Args: - expr (str): The raw expression to convert. + expr (str): The raw expression to parse. Returns: - str: A Python expression string using 'any' and logical operators. + _ExpressionNode: Root node of the parsed expression AST. """ - tokens: List[str] = re.findall( - r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE - ) + tokens = self._tokenize(expr) + result, _ = self._parse_or_expression(tokens, 0) + return result + + def _tokenize(self, expr: str) -> List[str]: + """Tokenize the expression into logical operators, keywords, and parentheses.""" + return re.findall(r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE) + + def _parse_or_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse OR expression (lowest precedence).""" + left, pos = self._parse_and_expression(tokens, pos) + + while pos < len(tokens) and tokens[pos].upper() == "OR": + pos += 1 # Skip OR token + right, pos = self._parse_and_expression(tokens, pos) + left = _OrNode(left, right) + + return left, pos + + def _parse_and_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse AND expression (medium precedence).""" + left, pos = self._parse_not_expression(tokens, pos) + + while pos < len(tokens) and tokens[pos].upper() == "AND": + pos += 1 # Skip AND token + right, pos = self._parse_not_expression(tokens, pos) + left = _AndNode(left, right) + + return left, pos + + def _parse_not_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse NOT expression (highest precedence).""" + if pos < len(tokens) and tokens[pos].upper() == "NOT": + pos += 1 # Skip NOT token + operand, pos = self._parse_primary_expression(tokens, pos) + return _NotNode(operand), pos + else: + return self._parse_primary_expression(tokens, pos) + + def _parse_primary_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse primary expression (parentheses or pattern).""" + if pos >= len(tokens): + raise ValueError("Unexpected end of expression") + + token = tokens[pos] + + if token == "(": + pos += 1 # Skip opening parenthesis + expr, pos = self._parse_or_expression(tokens, pos) + if pos >= len(tokens) or tokens[pos] != ")": + raise ValueError("Missing closing parenthesis") + pos += 1 # Skip closing parenthesis + return expr, pos + elif token == ")": + raise ValueError("Unexpected closing parenthesis") + else: + # Pattern token + return _PatternNode(token), pos + 1 - def wildcard_condition(pattern: str) -> str: - pattern = pattern.strip('"').strip("'") - stripped = pattern.strip("*") - - if pattern.startswith("*") and pattern.endswith("*"): - return f"{repr(stripped)} in k" - elif 
pattern.startswith("*"): - return f"k.endswith({repr(stripped)})" - elif pattern.endswith("*"): - return f"k.startswith({repr(stripped)})" - else: - return f"k == {repr(pattern)}" - - def convert_token(token: str) -> str: - upper = token.upper() - if upper == "AND": - return "and" - elif upper == "OR": - return "or" - elif upper == "NOT": - return "not" - elif token in ("(", ")"): - return token - else: - return f"any({wildcard_condition(token)} for k in keywords)" - - converted_tokens = [convert_token(tok) for tok in tokens] - return " ".join(converted_tokens) + def _convert_expression(self, expr: str) -> str: + """ + Legacy method for backward compatibility. + This method is no longer used but kept to avoid breaking changes. + """ + # This method is deprecated and should not be used + # It's kept only for backward compatibility + return expr def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]: From b4abee8274574239fe69ab1e53ad5aa029d80cca Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Tue, 20 Jan 2026 10:50:50 -0500 Subject: [PATCH 6/6] fix linting --- .../src/sagemaker/core/jumpstart/search.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/search.py b/sagemaker-core/src/sagemaker/core/jumpstart/search.py index 9da8adb4eb..3d7609a4cc 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/search.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/search.py @@ -1,6 +1,6 @@ import re import logging -from typing import List, Iterator, Optional, Union, Any +from typing import List, Iterator, Optional from sagemaker.core.helper.session_helper import Session from sagemaker.core.resources import HubContent @@ -9,7 +9,7 @@ class _ExpressionNode: """Base class for expression AST nodes.""" - + def evaluate(self, keywords: List[str]) -> bool: """Evaluate this node against the given keywords.""" raise NotImplementedError @@ -17,49 +17,49 @@ def evaluate(self, keywords: List[str]) -> bool: class _AndNode(_ExpressionNode): """AND logical operator node.""" - + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): self.left = left self.right = right - + def evaluate(self, keywords: List[str]) -> bool: return self.left.evaluate(keywords) and self.right.evaluate(keywords) class _OrNode(_ExpressionNode): """OR logical operator node.""" - + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): self.left = left self.right = right - + def evaluate(self, keywords: List[str]) -> bool: return self.left.evaluate(keywords) or self.right.evaluate(keywords) class _NotNode(_ExpressionNode): """NOT logical operator node.""" - + def __init__(self, operand: _ExpressionNode): self.operand = operand - + def evaluate(self, keywords: List[str]) -> bool: return not self.operand.evaluate(keywords) class _PatternNode(_ExpressionNode): """Pattern matching node for keywords with wildcard support.""" - + def __init__(self, pattern: str): self.pattern = pattern.strip('"').strip("'") - + def evaluate(self, keywords: List[str]) -> bool: """Check if any keyword matches this pattern.""" for keyword in keywords: if self._matches_pattern(keyword, self.pattern): return True return False - + def _matches_pattern(self, keyword: str, pattern: str) -> bool: """Check if a keyword matches a pattern with wildcard support.""" if pattern.startswith("*") and pattern.endswith("*"): @@ -140,23 +140,23 @@ def _tokenize(self, expr: str) -> List[str]: def _parse_or_expression(self, tokens: List[str], pos: 
int) -> tuple[_ExpressionNode, int]: """Parse OR expression (lowest precedence).""" left, pos = self._parse_and_expression(tokens, pos) - + while pos < len(tokens) and tokens[pos].upper() == "OR": pos += 1 # Skip OR token right, pos = self._parse_and_expression(tokens, pos) left = _OrNode(left, right) - + return left, pos def _parse_and_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: """Parse AND expression (medium precedence).""" left, pos = self._parse_not_expression(tokens, pos) - + while pos < len(tokens) and tokens[pos].upper() == "AND": pos += 1 # Skip AND token right, pos = self._parse_not_expression(tokens, pos) left = _AndNode(left, right) - + return left, pos def _parse_not_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: @@ -172,9 +172,9 @@ def _parse_primary_expression(self, tokens: List[str], pos: int) -> tuple[_Expre """Parse primary expression (parentheses or pattern).""" if pos >= len(tokens): raise ValueError("Unexpected end of expression") - + token = tokens[pos] - + if token == "(": pos += 1 # Skip opening parenthesis expr, pos = self._parse_or_expression(tokens, pos)
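For reference, a small usage sketch of the rewritten _Filter (illustrative only; the import path is taken from the diff headers above, and the keywords are made-up examples exercising only the behaviour the new AST nodes implement):

from sagemaker.core.jumpstart.search import _Filter

keywords = ["text-generation", "llama", "fine-tunable"]

assert _Filter("llama AND text*").match(keywords)                 # exact match + starts-with wildcard
assert _Filter("NOT bert").match(keywords)                        # negation
assert _Filter("(bert OR llama) AND *tunable").match(keywords)    # grouping + ends-with wildcard
assert not _Filter("bert").match(keywords)                        # no keyword matches -> False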