diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/search.py b/sagemaker-core/src/sagemaker/core/jumpstart/search.py index 9581897983..3d7609a4cc 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/search.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/search.py @@ -7,6 +7,78 @@ logger = logging.getLogger(__name__) +class _ExpressionNode: + """Base class for expression AST nodes.""" + + def evaluate(self, keywords: List[str]) -> bool: + """Evaluate this node against the given keywords.""" + raise NotImplementedError + + +class _AndNode(_ExpressionNode): + """AND logical operator node.""" + + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): + self.left = left + self.right = right + + def evaluate(self, keywords: List[str]) -> bool: + return self.left.evaluate(keywords) and self.right.evaluate(keywords) + + +class _OrNode(_ExpressionNode): + """OR logical operator node.""" + + def __init__(self, left: _ExpressionNode, right: _ExpressionNode): + self.left = left + self.right = right + + def evaluate(self, keywords: List[str]) -> bool: + return self.left.evaluate(keywords) or self.right.evaluate(keywords) + + +class _NotNode(_ExpressionNode): + """NOT logical operator node.""" + + def __init__(self, operand: _ExpressionNode): + self.operand = operand + + def evaluate(self, keywords: List[str]) -> bool: + return not self.operand.evaluate(keywords) + + +class _PatternNode(_ExpressionNode): + """Pattern matching node for keywords with wildcard support.""" + + def __init__(self, pattern: str): + self.pattern = pattern.strip('"').strip("'") + + def evaluate(self, keywords: List[str]) -> bool: + """Check if any keyword matches this pattern.""" + for keyword in keywords: + if self._matches_pattern(keyword, self.pattern): + return True + return False + + def _matches_pattern(self, keyword: str, pattern: str) -> bool: + """Check if a keyword matches a pattern with wildcard support.""" + if pattern.startswith("*") and pattern.endswith("*"): + # Contains pattern: *text* + stripped = pattern.strip("*") + return stripped in keyword + elif pattern.startswith("*"): + # Ends with pattern: *text + stripped = pattern[1:] + return keyword.endswith(stripped) + elif pattern.endswith("*"): + # Starts with pattern: text* + stripped = pattern[:-1] + return keyword.startswith(stripped) + else: + # Exact match + return keyword == pattern + + class _Filter: """ A filter that evaluates logical expressions against a list of keyword strings. @@ -28,6 +100,7 @@ def __init__(self, expression: str) -> None: Supports AND, OR, NOT, parentheses, and wildcard patterns (*). """ self.expression: str = expression + self._ast: Optional[_ExpressionNode] = None def match(self, keywords: List[str]) -> bool: """ @@ -39,54 +112,90 @@ def match(self, keywords: List[str]) -> bool: Returns: bool: True if the expression evaluates to True for the given keywords, else False. """ - expr: str = self._convert_expression(self.expression) try: - return eval(expr, {"__builtins__": {}}, {"keywords": keywords, "any": any}) + if self._ast is None: + self._ast = self._parse_expression(self.expression) + return self._ast.evaluate(keywords) except Exception: return False - def _convert_expression(self, expr: str) -> str: + def _parse_expression(self, expr: str) -> _ExpressionNode: """ - Convert the logical filter expression into a Python-evaluable string. + Parse the logical filter expression into an AST. Args: - expr (str): The raw expression to convert. + expr (str): The raw expression to parse. Returns: - str: A Python expression string using 'any' and logical operators. + _ExpressionNode: Root node of the parsed expression AST. """ - tokens: List[str] = re.findall( - r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE - ) - - def wildcard_condition(pattern: str) -> str: - pattern = pattern.strip('"').strip("'") - stripped = pattern.strip("*") + tokens = self._tokenize(expr) + result, _ = self._parse_or_expression(tokens, 0) + return result + + def _tokenize(self, expr: str) -> List[str]: + """Tokenize the expression into logical operators, keywords, and parentheses.""" + return re.findall(r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE) + + def _parse_or_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse OR expression (lowest precedence).""" + left, pos = self._parse_and_expression(tokens, pos) + + while pos < len(tokens) and tokens[pos].upper() == "OR": + pos += 1 # Skip OR token + right, pos = self._parse_and_expression(tokens, pos) + left = _OrNode(left, right) + + return left, pos + + def _parse_and_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse AND expression (medium precedence).""" + left, pos = self._parse_not_expression(tokens, pos) + + while pos < len(tokens) and tokens[pos].upper() == "AND": + pos += 1 # Skip AND token + right, pos = self._parse_not_expression(tokens, pos) + left = _AndNode(left, right) + + return left, pos + + def _parse_not_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse NOT expression (highest precedence).""" + if pos < len(tokens) and tokens[pos].upper() == "NOT": + pos += 1 # Skip NOT token + operand, pos = self._parse_primary_expression(tokens, pos) + return _NotNode(operand), pos + else: + return self._parse_primary_expression(tokens, pos) + + def _parse_primary_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]: + """Parse primary expression (parentheses or pattern).""" + if pos >= len(tokens): + raise ValueError("Unexpected end of expression") + + token = tokens[pos] + + if token == "(": + pos += 1 # Skip opening parenthesis + expr, pos = self._parse_or_expression(tokens, pos) + if pos >= len(tokens) or tokens[pos] != ")": + raise ValueError("Missing closing parenthesis") + pos += 1 # Skip closing parenthesis + return expr, pos + elif token == ")": + raise ValueError("Unexpected closing parenthesis") + else: + # Pattern token + return _PatternNode(token), pos + 1 - if pattern.startswith("*") and pattern.endswith("*"): - return f"{repr(stripped)} in k" - elif pattern.startswith("*"): - return f"k.endswith({repr(stripped)})" - elif pattern.endswith("*"): - return f"k.startswith({repr(stripped)})" - else: - return f"k == {repr(pattern)}" - - def convert_token(token: str) -> str: - upper = token.upper() - if upper == "AND": - return "and" - elif upper == "OR": - return "or" - elif upper == "NOT": - return "not" - elif token in ("(", ")"): - return token - else: - return f"any({wildcard_condition(token)} for k in keywords)" - - converted_tokens = [convert_token(tok) for tok in tokens] - return " ".join(converted_tokens) + def _convert_expression(self, expr: str) -> str: + """ + Legacy method for backward compatibility. + This method is no longer used but kept to avoid breaking changes. + """ + # This method is deprecated and should not be used + # It's kept only for backward compatibility + return expr def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]: